add EqualitySets
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 10 Oct 2022 22:38:21 +0000 (15:38 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 10 Oct 2022 22:38:21 +0000 (15:38 -0700)
src/bigint_presentation_code/toom_cook.py

index 2209d8b780f3189588dbe7b1317157615d8d836f..16ef5d0db9ae5bb26c711577647e11eb16a0686b 100644 (file)
@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
 import builtins
 from collections import defaultdict
 from enum import Enum, unique
-from typing import Iterable, Mapping, TYPE_CHECKING
+from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING
 
 from nmutil.plain_data import plain_data
 
@@ -869,22 +869,74 @@ class LiveInterval:
         return LiveInterval(assignment=self.assignment, last_use=last_use)
 
 
-class LiveIntervals(Mapping[SSAVal, LiveInterval]):
+@final
+class EqualitySet(AbstractSet[SSAVal]):
+    def __init__(self, items):
+        # type: (Iterable[SSAVal]) -> None
+        self.__items = frozenset(items)
+
+    def __contains__(self, x):
+        # type: (object) -> bool
+        return x in self.__items
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+
+@final
+class EqualitySets(Mapping[SSAVal, EqualitySet]):
+    def __init__(self, ops):
+        # type: (Iterable[Op]) -> None
+        indexes = {}  # type: dict[SSAVal, int]
+        sets = []  # type: list[set[SSAVal]]
+        for op in ops:
+            for val in (*op.input_ssa_vals(), *op.output_ssa_vals()):
+                if val not in indexes:
+                    indexes[val] = len(sets)
+                    sets.append({val})
+            for e in op.get_equality_constraints():
+                lhs_index = indexes[e.lhs]
+                rhs_index = indexes[e.rhs]
+                sets[lhs_index] |= sets[rhs_index]
+                for val in sets[rhs_index]:
+                    indexes[val] = lhs_index
+
+        equality_sets = [EqualitySet(i) for i in sets]
+        self.__map = {k: equality_sets[v] for k, v in indexes.items()}
+
+    def __getitem__(self, key):
+        # type: (SSAVal) -> EqualitySet
+        return self.__map[key]
+
+    def __iter__(self):
+        return iter(self.__map)
+
+
+@final
+class LiveIntervals(Mapping[EqualitySet, LiveInterval]):
     def __init__(self, ops):
         # type: (list[Op]) -> None
-        live_intervals = {}  # type: dict[SSAVal, LiveInterval]
+        self.__equality_sets = eqsets = EqualitySets(ops)
+        live_intervals = {}  # type: dict[EqualitySet, LiveInterval]
         for op_idx, op in enumerate(ops):
             for val in op.input_ssa_vals():
-                live_intervals[val] += op_idx
+                live_intervals[eqsets[val]] += op_idx
             for val in op.output_ssa_vals():
-                if val in live_intervals:
-                    raise ValueError(f"multiple instructions must not write "
-                                     f"to the same SSA value: {val}")
-                live_intervals[val] = LiveInterval(op_idx)
+                if eqsets[val] not in live_intervals:
+                    live_intervals[eqsets[val]] = LiveInterval(op_idx)
+                else:
+                    live_intervals[eqsets[val]] += op_idx
         self.__live_intervals = live_intervals
 
+    @property
+    def equality_sets(self):
+        return self.__equality_sets
+
     def __getitem__(self, key):
-        # type: (SSAVal) -> LiveInterval
+        # type: (EqualitySet) -> LiveInterval
         return self.__live_intervals[key]
 
     def __iter__(self):
@@ -893,14 +945,13 @@ class LiveIntervals(Mapping[SSAVal, LiveInterval]):
 
 @plain_data()
 class AllocationFailed:
-    __slots__ = "op_idx", "arg", "live_intervals", "free_regs"
+    __slots__ = "op_idx", "arg", "live_intervals"
 
-    def __init__(self, op_idx, arg, live_intervals, free_regs):
-        # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None
+    def __init__(self, op_idx, arg, live_intervals):
+        # type: (int, SSAVal | VecArg, LiveIntervals) -> None
         self.op_idx = op_idx
         self.arg = arg
         self.live_intervals = live_intervals
-        self.free_regs = free_regs
 
 
 def try_allocate_registers_without_spilling(ops):