construct interference graph
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 13 Oct 2022 03:34:41 +0000 (20:34 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 13 Oct 2022 03:34:41 +0000 (20:34 -0700)
src/bigint_presentation_code/toom_cook.py

index e8e97bfac60bad18506e39f53c867765441d14c4..ce26d68f79b42e776cae1bb171f0a79bb43008d5 100644 (file)
@@ -9,6 +9,7 @@ from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, unique, EnumMeta
 from functools import lru_cache
+from itertools import combinations
 from typing import (Sequence, AbstractSet, Iterable, Mapping,
                     TYPE_CHECKING, Sequence, Sized, TypeVar, Generic)
 
@@ -259,16 +260,16 @@ class StackSlot(GPRRangeOrStackLoc):
         return self.length
 
 
-_RegType = TypeVar("_RegType", bound=RegType)
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
 
 
 @plain_data(frozen=True, eq=False)
 @final
-class SSAVal(Generic[_RegType]):
+class SSAVal(Generic[_RegT_co]):
     __slots__ = "op", "arg_name", "ty", "arg_index"
 
     def __init__(self, op, arg_name, ty):
-        # type: (Op, str, _RegType) -> None
+        # type: (Op, str, _RegT_co) -> None
         self.op = op
         """the Op that writes this SSAVal"""
 
@@ -319,13 +320,18 @@ class Op(metaclass=ABCMeta):
         if False:
             yield ...
 
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        if False:
+            yield ...
+
     def __init__(self):
         pass
 
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpCopy(Op, Generic[_RegType]):
+class OpCopy(Op, Generic[_RegT_co]):
     __slots__ = "dest", "src"
 
     def inputs(self):
@@ -337,7 +343,7 @@ class OpCopy(Op, Generic[_RegType]):
         return {"dest": self.dest}
 
     def __init__(self, src):
-        # type: (SSAVal[_RegType]) -> None
+        # type: (SSAVal[_RegT_co]) -> None
         self.dest = SSAVal(self, "dest", src.ty)
         self.src = src
 
@@ -425,6 +431,11 @@ class OpAddSubE(Op):
         self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
         self.is_sub = is_sub
 
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.RA
+        yield self.RT, self.RB
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -452,6 +463,15 @@ class OpBigIntMulDiv(Op):
         # type: () -> Iterable[EqualityConstraint]
         yield EqualityConstraint([self.RC], [self.RS])
 
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.RA
+        yield self.RT, self.RB
+        yield self.RT, self.RC
+        yield self.RT, self.RS
+        yield self.RS, self.RA
+        yield self.RS, self.RB
+
 
 @final
 @unique
@@ -481,6 +501,11 @@ class OpBigIntShift(Op):
         self.sh = sh
         self.kind = kind
 
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.inp
+        yield self.RT, self.sh
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -539,6 +564,11 @@ class OpLoad(Op):
         self.offset = offset
         self.mem = mem
 
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        if self.RT.ty.length > 1:
+            yield self.RT, self.RA
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -662,12 +692,19 @@ class LiveInterval:
         last_use = max(self.last_use, use)
         return LiveInterval(first_write=self.first_write, last_use=last_use)
 
+    @property
+    def live_after_op_range(self):
+        """the range of op indexes where self is live immediately after the
+        Op at each index
+        """
+        return range(self.first_write, self.last_use)
+
 
 @final
-class MergedRegSet(Mapping[SSAVal[_RegType], int]):
+class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
     def __init__(self, reg_set):
-        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
-        self.__items = {}  # type: dict[SSAVal[_RegType], int]
+        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
+        self.__items = {}  # type: dict[SSAVal[_RegT_co], int]
         if isinstance(reg_set, SSAVal):
             reg_set = [(reg_set, 0)]
         for ssa_val, offset in reg_set:
@@ -711,7 +748,7 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]):
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
-        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
+        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
         if len(constraint_sequence) == 1:
             # any type allowed with len = 1
             return MergedRegSet(constraint_sequence[0])
@@ -742,22 +779,22 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]):
         return range(self.__start, self.__stop)
 
     def offset_by(self, amount):
-        # type: (int) -> MergedRegSet[_RegType]
+        # type: (int) -> MergedRegSet[_RegT_co]
         return MergedRegSet((k, v + amount) for k, v in self.items())
 
     def normalized(self):
-        # type: () -> MergedRegSet[_RegType]
+        # type: () -> MergedRegSet[_RegT_co]
         return self.offset_by(-self.start)
 
     def with_offset_to_match(self, target):
-        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
+        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
         for ssa_val, offset in self.items():
             if ssa_val in target:
                 return self.offset_by(target[ssa_val] - offset)
         raise ValueError("can't change offset to match unrelated MergedRegSet")
 
     def __getitem__(self, item):
-        # type: (SSAVal[_RegType]) -> int
+        # type: (SSAVal[_RegT_co]) -> int
         return self.__items[item]
 
     def __iter__(self):
@@ -774,10 +811,10 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]):
 
 
 @final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
     def __init__(self, ops):
         # type: (Iterable[Op]) -> None
-        merged_sets = {}  # type: dict[SSAVal, MergedRegSet]
+        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegT_co]]
         for op in ops:
             for val in (*op.inputs().values(), *op.outputs().values()):
                 if val not in merged_sets:
@@ -803,13 +840,16 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet]):
     def __len__(self):
         return len(self.__map)
 
+    def __repr__(self):
+        return f"MergedRegSets(data={self.__map})"
+
 
 @final
-class LiveIntervals(Mapping[MergedRegSet, LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
     def __init__(self, ops):
         # type: (list[Op]) -> None
         self.__merged_reg_sets = MergedRegSets(ops)
-        live_intervals = {}  # type: dict[MergedRegSet, LiveInterval]
+        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
         for op_idx, op in enumerate(ops):
             for val in op.inputs().values():
                 live_intervals[self.__merged_reg_sets[val]] += op_idx
@@ -820,26 +860,42 @@ class LiveIntervals(Mapping[MergedRegSet, LiveInterval]):
                 else:
                     live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
+        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
+        live_after += (set() for _ in ops)
+        for reg_set, live_interval in self.__live_intervals.items():
+            for i in live_interval.live_after_op_range:
+                live_after[i].add(reg_set)
+        self.__live_after = [frozenset(i) for i in live_after]
 
     @property
     def merged_reg_sets(self):
         return self.__merged_reg_sets
 
     def __getitem__(self, key):
-        # type: (MergedRegSet) -> LiveInterval
+        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
         return self.__live_intervals[key]
 
     def __iter__(self):
         return iter(self.__live_intervals)
 
+    def reg_sets_live_after(self, op_index):
+        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+        return self.__live_after[op_index]
+
+    def __repr__(self):
+        reg_sets_live_after = dict(enumerate(self.__live_after))
+        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
+                f"merged_reg_sets={self.merged_reg_sets}, "
+                f"reg_sets_live_after={reg_sets_live_after})")
+
 
 @final
-class IGNode:
+class IGNode(Generic[_RegT_co]):
     """ interference graph node """
     __slots__ = "merged_reg_set", "edges"
 
     def __init__(self, merged_reg_set, edges=()):
-        # type: (MergedRegSet, Iterable[IGNode]) -> None
+        # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None
         self.merged_reg_set = merged_reg_set
         self.edges = set(edges)
 
@@ -871,18 +927,24 @@ class IGNode:
 
 
 @final
-class InterferenceGraph(Mapping[MergedRegSet, IGNode]):
+class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
     def __init__(self, merged_reg_sets):
-        # type: (Iterable[MergedRegSet]) -> None
+        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
         self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
 
     def __getitem__(self, key):
-        # type: (MergedRegSet) -> IGNode
+        # type: (MergedRegSet[_RegT_co]) -> IGNode
         return self.__nodes[key]
 
     def __iter__(self):
         return iter(self.__nodes)
 
+    def __repr__(self):
+        nodes = {}
+        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
+        nodes_text = ", ".join(nodes_text)
+        return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
 
 @plain_data()
 class AllocationFailed:
@@ -899,6 +961,15 @@ def try_allocate_registers_without_spilling(ops):
     # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
 
     live_intervals = LiveIntervals(ops)
+    merged_reg_sets = live_intervals.merged_reg_sets
+    interference_graph = InterferenceGraph(merged_reg_sets.values())
+    for op_idx, op in enumerate(ops):
+        reg_sets = live_intervals.reg_sets_live_after(op_idx)
+        for i, j in combinations(reg_sets, 2):
+            interference_graph[i].add_edge(interference_graph[j])
+        for i, j in op.get_extra_interferences():
+            interference_graph[merged_reg_sets[i]].add_edge(
+                interference_graph[merged_reg_sets[j]])
 
     raise NotImplementedError