From 2800dce3e173b96fe06d556bc21532a075289146 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 10 Oct 2022 15:38:21 -0700 Subject: [PATCH] add EqualitySets --- src/bigint_presentation_code/toom_cook.py | 77 +++++++++++++++++++---- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 2209d8b..16ef5d0 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod import builtins from collections import defaultdict from enum import Enum, unique -from typing import Iterable, Mapping, TYPE_CHECKING +from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING from nmutil.plain_data import plain_data @@ -869,22 +869,74 @@ class LiveInterval: return LiveInterval(assignment=self.assignment, last_use=last_use) -class LiveIntervals(Mapping[SSAVal, LiveInterval]): +@final +class EqualitySet(AbstractSet[SSAVal]): + def __init__(self, items): + # type: (Iterable[SSAVal]) -> None + self.__items = frozenset(items) + + def __contains__(self, x): + # type: (object) -> bool + return x in self.__items + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + +@final +class EqualitySets(Mapping[SSAVal, EqualitySet]): + def __init__(self, ops): + # type: (Iterable[Op]) -> None + indexes = {} # type: dict[SSAVal, int] + sets = [] # type: list[set[SSAVal]] + for op in ops: + for val in (*op.input_ssa_vals(), *op.output_ssa_vals()): + if val not in indexes: + indexes[val] = len(sets) + sets.append({val}) + for e in op.get_equality_constraints(): + lhs_index = indexes[e.lhs] + rhs_index = indexes[e.rhs] + sets[lhs_index] |= sets[rhs_index] + for val in sets[rhs_index]: + indexes[val] = lhs_index + + equality_sets = [EqualitySet(i) for i in sets] + self.__map = {k: equality_sets[v] for k, v in indexes.items()} + + def __getitem__(self, key): + # type: (SSAVal) -> EqualitySet + return self.__map[key] + + def __iter__(self): + return iter(self.__map) + + +@final +class LiveIntervals(Mapping[EqualitySet, LiveInterval]): def __init__(self, ops): # type: (list[Op]) -> None - live_intervals = {} # type: dict[SSAVal, LiveInterval] + self.__equality_sets = eqsets = EqualitySets(ops) + live_intervals = {} # type: dict[EqualitySet, LiveInterval] for op_idx, op in enumerate(ops): for val in op.input_ssa_vals(): - live_intervals[val] += op_idx + live_intervals[eqsets[val]] += op_idx for val in op.output_ssa_vals(): - if val in live_intervals: - raise ValueError(f"multiple instructions must not write " - f"to the same SSA value: {val}") - live_intervals[val] = LiveInterval(op_idx) + if eqsets[val] not in live_intervals: + live_intervals[eqsets[val]] = LiveInterval(op_idx) + else: + live_intervals[eqsets[val]] += op_idx self.__live_intervals = live_intervals + @property + def equality_sets(self): + return self.__equality_sets + def __getitem__(self, key): - # type: (SSAVal) -> LiveInterval + # type: (EqualitySet) -> LiveInterval return self.__live_intervals[key] def __iter__(self): @@ -893,14 +945,13 @@ class LiveIntervals(Mapping[SSAVal, LiveInterval]): @plain_data() class AllocationFailed: - __slots__ = "op_idx", "arg", "live_intervals", "free_regs" + __slots__ = "op_idx", "arg", "live_intervals" - def __init__(self, op_idx, arg, live_intervals, free_regs): - # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None + def __init__(self, op_idx, arg, live_intervals): + # type: (int, SSAVal | VecArg, LiveIntervals) -> None self.op_idx = op_idx self.arg = arg self.live_intervals = live_intervals - self.free_regs = free_regs def try_allocate_registers_without_spilling(ops): -- 2.30.2