from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
RegLoc, RegType, SSAVal)
+from bigint_presentation_code.ordered_set import OFSet, OSet
if TYPE_CHECKING:
- from typing_extensions import Self, final
+ from typing_extensions import final
else:
def final(v):
return v
self.__start = start # type: int
self.__stop = stop # type: int
self.__ty = ty # type: RegType
- self.__hash = hash(frozenset(self.items()))
+ self.__hash = hash(OFSet(self.items()))
@staticmethod
def from_equality_constraint(constraint_sequence):
for e in op.get_equality_constraints():
lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
- lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
- rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
- full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+ items = [] # type: list[tuple[SSAVal, int]]
+ for i in e.lhs:
+ s = merged_sets[i].with_offset_to_match(lhs_set)
+ items.extend(s.items())
+ for i in e.rhs:
+ s = merged_sets[i].with_offset_to_match(rhs_set)
+ items.extend(s.items())
+ full_set = MergedRegSet(items)
for val in full_set.keys():
merged_sets[val] = full_set
else:
live_intervals[reg_set] += op_idx
self.__live_intervals = live_intervals
- live_after = [] # type: list[set[MergedRegSet[_RegType]]]
- live_after += (set() for _ in ops)
+ live_after = [] # type: list[OSet[MergedRegSet[_RegType]]]
+ live_after += (OSet() for _ in ops)
for reg_set, live_interval in self.__live_intervals.items():
for i in live_interval.live_after_op_range:
live_after[i].add(reg_set)
- self.__live_after = [frozenset(i) for i in live_after]
+ self.__live_after = [OFSet(i) for i in live_after]
@property
def merged_reg_sets(self):
def __iter__(self):
return iter(self.__live_intervals)
+ def __len__(self):
+ return len(self.__live_intervals)
+
def reg_sets_live_after(self, op_index):
- # type: (int) -> frozenset[MergedRegSet[_RegType]]
+ # type: (int) -> OFSet[MergedRegSet[_RegType]]
return self.__live_after[op_index]
def __repr__(self):
def __init__(self, merged_reg_set, edges=(), reg=None):
# type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
self.merged_reg_set = merged_reg_set
- self.edges = set(edges)
+ self.edges = OSet(edges)
self.reg = reg
def add_edge(self, other):
def __iter__(self):
return iter(self.__nodes)
+ def __len__(self):
+ return len(self.__nodes)
+
def __repr__(self):
nodes = {}
nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
interference_graph[i].add_edge(interference_graph[j])
- nodes_remaining = set(interference_graph.values())
+ nodes_remaining = OSet(interference_graph.values())
def local_colorability_score(node):
# type: (IGNode) -> int