test_op_set_to_list works
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
index 297e3e5391c9e1a8821430e8a1bd789907ad30c3..75b342243fa95cd7223647b6204f0d33d7e36cf0 100644 (file)
@@ -20,7 +20,7 @@ else:
         return v
 
 
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+_RegType = TypeVar("_RegType", bound=RegType)
 
 
 @plain_data(unsafe_hash=True, order=True, frozen=True)
@@ -59,10 +59,10 @@ class LiveInterval:
 
 
 @final
-class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
+class MergedRegSet(Mapping[SSAVal[_RegType], int]):
     def __init__(self, reg_set):
-        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
-        self.__items = {}  # type: dict[SSAVal[_RegT_co], int]
+        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
+        self.__items = {}  # type: dict[SSAVal[_RegType], int]
         if isinstance(reg_set, SSAVal):
             reg_set = [(reg_set, 0)]
         for ssa_val, offset in reg_set:
@@ -107,7 +107,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
-        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
+        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
         if len(constraint_sequence) == 1:
             # any type allowed with len = 1
             return MergedRegSet(constraint_sequence[0])
@@ -138,22 +138,22 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
         return range(self.__start, self.__stop)
 
     def offset_by(self, amount):
-        # type: (int) -> MergedRegSet[_RegT_co]
+        # type: (int) -> MergedRegSet[_RegType]
         return MergedRegSet((k, v + amount) for k, v in self.items())
 
     def normalized(self):
-        # type: () -> MergedRegSet[_RegT_co]
+        # type: () -> MergedRegSet[_RegType]
         return self.offset_by(-self.start)
 
     def with_offset_to_match(self, target):
-        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
+        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
         for ssa_val, offset in self.items():
             if ssa_val in target:
                 return self.offset_by(target[ssa_val] - offset)
         raise ValueError("can't change offset to match unrelated MergedRegSet")
 
     def __getitem__(self, item):
-        # type: (SSAVal[_RegT_co]) -> int
+        # type: (SSAVal[_RegType]) -> int
         return self.__items[item]
 
     def __iter__(self):
@@ -170,10 +170,10 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
 
 
 @final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
     def __init__(self, ops):
         # type: (Iterable[Op]) -> None
-        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegT_co]]
+        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
         for op in ops:
             for val in (*op.inputs().values(), *op.outputs().values()):
                 if val not in merged_sets:
@@ -204,11 +204,11 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
 
 
 @final
-class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
     def __init__(self, ops):
         # type: (list[Op]) -> None
         self.__merged_reg_sets = MergedRegSets(ops)
-        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
+        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
         for op_idx, op in enumerate(ops):
             for val in op.inputs().values():
                 live_intervals[self.__merged_reg_sets[val]] += op_idx
@@ -219,7 +219,7 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
                 else:
                     live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
-        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
+        live_after = []  # type: list[set[MergedRegSet[_RegType]]]
         live_after += (set() for _ in ops)
         for reg_set, live_interval in self.__live_intervals.items():
             for i in live_interval.live_after_op_range:
@@ -231,14 +231,14 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
         return self.__merged_reg_sets
 
     def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
+        # type: (MergedRegSet[_RegType]) -> LiveInterval
         return self.__live_intervals[key]
 
     def __iter__(self):
         return iter(self.__live_intervals)
 
     def reg_sets_live_after(self, op_index):
-        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+        # type: (int) -> frozenset[MergedRegSet[_RegType]]
         return self.__live_after[op_index]
 
     def __repr__(self):
@@ -249,12 +249,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
 
 
 @final
-class IGNode(Generic[_RegT_co]):
+class IGNode(Generic[_RegType]):
     """ interference graph node """
     __slots__ = "merged_reg_set", "edges", "reg"
 
     def __init__(self, merged_reg_set, edges=(), reg=None):
-        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
+        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
         self.merged_reg_set = merged_reg_set
         self.edges = set(edges)
         self.reg = reg
@@ -300,13 +300,13 @@ class IGNode(Generic[_RegT_co]):
 
 
 @final
-class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
+class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
     def __init__(self, merged_reg_sets):
-        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
+        # type: (Iterable[MergedRegSet[_RegType]]) -> None
         self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
 
     def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> IGNode
+        # type: (MergedRegSet[_RegType]) -> IGNode
         return self.__nodes[key]
 
     def __iter__(self):