try_allocate_registers_without_spilling works!
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
index 75b342243fa95cd7223647b6204f0d33d7e36cf0..b44c32dd76023c948ff9be447d34eb16e02a77bf 100644 (file)
@@ -12,9 +12,10 @@ from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                   RegLoc, RegType, SSAVal)
+from bigint_presentation_code.ordered_set import OFSet, OSet
 
 if TYPE_CHECKING:
-    from typing_extensions import Self, final
+    from typing_extensions import final
 else:
     def final(v):
         return v
@@ -103,7 +104,7 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]):
         self.__start = start  # type: int
         self.__stop = stop  # type: int
         self.__ty = ty  # type: RegType
-        self.__hash = hash(frozenset(self.items()))
+        self.__hash = hash(OFSet(self.items()))
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
@@ -181,9 +182,14 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
             for e in op.get_equality_constraints():
                 lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                 rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
-                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
-                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
-                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+                items = []  # type: list[tuple[SSAVal, int]]
+                for i in e.lhs:
+                    s = merged_sets[i].with_offset_to_match(lhs_set)
+                    items.extend(s.items())
+                for i in e.rhs:
+                    s = merged_sets[i].with_offset_to_match(rhs_set)
+                    items.extend(s.items())
+                full_set = MergedRegSet(items)
                 for val in full_set.keys():
                     merged_sets[val] = full_set
 
@@ -219,12 +225,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
                 else:
                     live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
-        live_after = []  # type: list[set[MergedRegSet[_RegType]]]
-        live_after += (set() for _ in ops)
+        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
+        live_after += (OSet() for _ in ops)
         for reg_set, live_interval in self.__live_intervals.items():
             for i in live_interval.live_after_op_range:
                 live_after[i].add(reg_set)
-        self.__live_after = [frozenset(i) for i in live_after]
+        self.__live_after = [OFSet(i) for i in live_after]
 
     @property
     def merged_reg_sets(self):
@@ -237,8 +243,11 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
     def __iter__(self):
         return iter(self.__live_intervals)
 
+    def __len__(self):
+        return len(self.__live_intervals)
+
     def reg_sets_live_after(self, op_index):
-        # type: (int) -> frozenset[MergedRegSet[_RegType]]
+        # type: (int) -> OFSet[MergedRegSet[_RegType]]
         return self.__live_after[op_index]
 
     def __repr__(self):
@@ -256,7 +265,7 @@ class IGNode(Generic[_RegType]):
     def __init__(self, merged_reg_set, edges=(), reg=None):
         # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
         self.merged_reg_set = merged_reg_set
-        self.edges = set(edges)
+        self.edges = OSet(edges)
         self.reg = reg
 
     def add_edge(self, other):
@@ -312,6 +321,9 @@ class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
     def __iter__(self):
         return iter(self.__nodes)
 
+    def __len__(self):
+        return len(self.__nodes)
+
     def __repr__(self):
         nodes = {}
         nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
@@ -347,7 +359,7 @@ def try_allocate_registers_without_spilling(ops):
             if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                 interference_graph[i].add_edge(interference_graph[j])
 
-    nodes_remaining = set(interference_graph.values())
+    nodes_remaining = OSet(interference_graph.values())
 
     def local_colorability_score(node):
         # type: (IGNode) -> int