try_allocate_registers_without_spilling is completed, but untested
author: Jacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 05:52:15 +0000 (22:52 -0700)
committer: Jacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 05:53:53 +0000 (22:53 -0700)
src/bigint_presentation_code/toom_cook.py

index ce26d68f79b42e776cae1bb171f0a79bb43008d5..88e8f7e27c2e3715ec73c90b5fa2df192e4af7c9 100644 (file)
@@ -11,7 +11,7 @@ from enum import Enum, unique, EnumMeta
 from functools import lru_cache
 from itertools import combinations
 from typing import (Sequence, AbstractSet, Iterable, Mapping,
-                    TYPE_CHECKING, Sequence, Sized, TypeVar, Generic)
+                    TYPE_CHECKING, Sequence, TypeVar, Generic)
 
 from nmutil.plain_data import plain_data
 
@@ -26,11 +26,7 @@ class ABCEnumMeta(EnumMeta, ABCMeta):
     pass
 
 
-class PhysLoc(metaclass=ABCMeta):
-    __slots__ = ()
-
-
-class RegLoc(PhysLoc):
+class RegLoc(metaclass=ABCMeta):
     __slots__ = ()
 
     @abstractmethod
@@ -38,9 +34,15 @@ class RegLoc(PhysLoc):
         # type: (RegLoc) -> bool
         ...
 
-
-class GPRRangeOrStackLoc(PhysLoc, Sized):
-    __slots__ = ()
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> RegLoc
+        if self not in subreg_type.reg_class:  # self must be a member
+            raise ValueError(f"register not a member of subreg_type: "
+                             f"reg={self} subreg_type={subreg_type}")
+        if offset != 0:  # this default only accepts offset 0
+            raise ValueError(f"non-zero sub-register offset not supported "
+                             f"for register: {self}")
+        return self  # offset 0 is the register itself
 
 
 GPR_COUNT = 128
@@ -48,7 +50,7 @@ GPR_COUNT = 128
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]):
+class GPRRange(RegLoc, Sequence["GPRRange"]):
     __slots__ = "start", "length"
 
     def __init__(self, start, length=None):
@@ -110,6 +112,15 @@ class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]):
             return self.stop > other.start and other.stop > self.start
         return False
 
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> GPRRange
+        if not isinstance(subreg_type, GPRRangeType):
+            raise ValueError(f"subreg_type is not a "
+                             f"GPRRangeType: {subreg_type}")
+        if not 0 <= offset <= self.length - subreg_type.length:  # within self
+            raise ValueError(f"sub-register offset is out of range: {offset}")
+        return GPRRange(self.start + offset, subreg_type.length)
+
 
 SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
 
@@ -143,9 +154,14 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
 
 @final
 class RegClass(AbstractSet[RegLoc]):
+    """ an ordered set of registers.
+    earlier registers are preferred by the register allocator.
+    """
     def __init__(self, regs):
         # type: (Iterable[RegLoc]) -> None
-        self.__regs = frozenset(regs)
+
+        # use dict to maintain order
+        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]
 
     def __len__(self):
         return len(self.__regs)
@@ -160,6 +176,17 @@ class RegClass(AbstractSet[RegLoc]):
     def __hash__(self):
         return super()._hash()
 
+    @lru_cache(maxsize=None, typed=True)
+    def max_conflicts_with(self, other):
+        # type: (RegClass | RegLoc) -> int
+        """the largest number of registers in `self` that a single register
+        from `other` can conflict with (0 if `other` is an empty class)
+        """
+        if isinstance(other, RegClass):
+            return max((self.max_conflicts_with(i) for i in other), default=0)
+        else:
+            return sum(other.conflicts(i) for i in self)
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 class RegType(metaclass=ABCMeta):
@@ -183,7 +210,7 @@ class GPRRangeType(RegType):
         self.length = length
 
     @staticmethod
-    @lru_cache()
+    @lru_cache(maxsize=None)
     def __get_reg_class(length):
         # type: (int) -> RegClass
         regs = []
@@ -243,21 +270,77 @@ class GlobalMemType(RegType):
         return RegClass([GlobalMem.GlobalMem])
 
 
-@plain_data()
+@plain_data(frozen=True, unsafe_hash=True)
 @final
-class StackSlot(GPRRangeOrStackLoc):
-    """a stack slot. Use OpCopy to load from/store into this stack slot."""
-    __slots__ = "offset", "length"
+class StackSlot(RegLoc):
+    __slots__ = "start_slot", "length_in_slots",
 
-    def __init__(self, offset=None, length=1):
-        # type: (int | None, int) -> None
-        self.offset = offset
-        if length < 1:
-            raise ValueError("invalid length")
-        self.length = length
+    def __init__(self, start_slot, length_in_slots):
+        # type: (int, int) -> None
+        self.start_slot = start_slot
+        if length_in_slots < 1:
+            raise ValueError("invalid length_in_slots")
+        self.length_in_slots = length_in_slots
 
-    def __len__(self):
-        return self.length
+    @property
+    def stop_slot(self):
+        return self.start_slot + self.length_in_slots  # one past the last slot
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if not isinstance(other, StackSlot):
+            return False  # only another stack slot can overlap
+        return (other.stop_slot > self.start_slot
+                and self.stop_slot > other.start_slot)
+
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> StackSlot
+        if not isinstance(subreg_type, StackSlotType):
+            raise ValueError(f"subreg_type is not a "
+                             f"StackSlotType: {subreg_type}")
+        if not 0 <= offset <= self.length_in_slots - subreg_type.length_in_slots:
+            raise ValueError(f"sub-register offset is out of range: {offset}")
+        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
+
+
+STACK_SLOT_COUNT = 128
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class StackSlotType(RegType):
+    __slots__ = "length_in_slots",
+
+    def __init__(self, length_in_slots=1):
+        # type: (int) -> None
+        if length_in_slots < 1:  # a type must span at least one slot
+            raise ValueError("invalid length_in_slots")
+        self.length_in_slots = length_in_slots
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def __get_reg_class(length_in_slots):
+        # type: (int) -> RegClass
+        # starts 0..last_start inclusive, so the slot run ending exactly at
+        # STACK_SLOT_COUNT is allocatable too (range() excludes its stop)
+        last_start = STACK_SLOT_COUNT - length_in_slots
+        return RegClass(StackSlot(start, length_in_slots)
+                        for start in range(last_start + 1))
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return StackSlotType.__get_reg_class(self.length_in_slots)
+
+    @final
+    def __eq__(self, other):
+        if isinstance(other, StackSlotType):
+            return self.length_in_slots == other.length_in_slots
+        return False  # any other type is unequal
+
+    @final
+    def __hash__(self):
+        return hash(self.length_in_slots)  # consistent with __eq__
 
 
 _RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
@@ -329,6 +412,44 @@ class Op(metaclass=ABCMeta):
         pass
 
 
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoadFromStackSlot(Op):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}  # the only input operand
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}  # the only output operand
+
+    def __init__(self, src):
+        # type: (SSAVal[GPRRangeType]) -> None
+        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
+        self.src = src  # NOTE(review): GPR in, stack slot out; name says load
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStoreToStackSlot(Op):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}  # the only input operand
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}  # the only output operand
+
+    def __init__(self, src):
+        # type: (SSAVal[StackSlotType]) -> None
+        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
+        self.src = src  # NOTE(review): stack slot in, GPR out; name says store
+
+
 @plain_data(unsafe_hash=True, frozen=True)
 @final
 class OpCopy(Op, Generic[_RegT_co]):
@@ -745,6 +866,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
         self.__start = start  # type: int
         self.__stop = stop  # type: int
         self.__ty = ty  # type: RegType
+        self.__hash = hash(frozenset(self.items()))
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
@@ -804,7 +926,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
         return len(self.__items)
 
     def __hash__(self):
-        return hash(frozenset(self.items()))
+        return self.__hash
 
     def __repr__(self):
         return f"MergedRegSet({list(self.__items.items())})"
@@ -892,12 +1014,13 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
 @final
 class IGNode(Generic[_RegT_co]):
     """ interference graph node """
-    __slots__ = "merged_reg_set", "edges"
+    __slots__ = "merged_reg_set", "edges", "reg"
 
-    def __init__(self, merged_reg_set, edges=()):
-        # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None
+    def __init__(self, merged_reg_set, edges=(), reg=None):
+        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
         self.merged_reg_set = merged_reg_set
         self.edges = set(edges)
+        self.reg = reg
 
     def add_edge(self, other):
         # type: (IGNode) -> None
@@ -923,7 +1046,20 @@ class IGNode(Generic[_RegT_co]):
         edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
         return (f"IGNode(#{nodes[self]}, "
                 f"merged_reg_set={self.merged_reg_set}, "
-                f"edges={edges})")
+                f"edges={edges}, "
+                f"reg={self.reg})")
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return self.merged_reg_set.ty.reg_class  # from the merged set's type
+
+    def reg_conflicts_with_neighbors(self, reg):
+        # type: (RegLoc) -> bool
+        # neighbors that have no register assigned yet cannot conflict
+        return any(
+            adjacent.reg is not None and adjacent.reg.conflicts(reg)
+            for adjacent in self.edges)
 
 
 @final
@@ -948,17 +1084,17 @@ class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
 
 @plain_data()
 class AllocationFailed:
-    __slots__ = "op_idx", "arg", "live_intervals"
+    __slots__ = "node", "live_intervals", "interference_graph"
 
-    def __init__(self, op_idx, arg, live_intervals):
-        # type: (int, SSAVal, LiveIntervals) -> None
-        self.op_idx = op_idx
-        self.arg = arg
+    def __init__(self, node, live_intervals, interference_graph):
+        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+        self.node = node
         self.live_intervals = live_intervals
+        self.interference_graph = interference_graph
 
 
 def try_allocate_registers_without_spilling(ops):
-    # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
+    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
 
     live_intervals = LiveIntervals(ops)
     merged_reg_sets = live_intervals.merged_reg_sets
@@ -966,12 +1102,74 @@ def try_allocate_registers_without_spilling(ops):
     for op_idx, op in enumerate(ops):
         reg_sets = live_intervals.reg_sets_live_after(op_idx)
         for i, j in combinations(reg_sets, 2):
-            interference_graph[i].add_edge(interference_graph[j])
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
         for i, j in op.get_extra_interferences():
-            interference_graph[merged_reg_sets[i]].add_edge(
-                interference_graph[merged_reg_sets[j]])
+            i = merged_reg_sets[i]
+            j = merged_reg_sets[j]
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
+
+    nodes_remaining = set(interference_graph.values())
+
+    def local_colorability_score(node):
+        # type: (IGNode) -> int
+        """ returns a positive integer if node is locally colorable, returns
+        zero or a negative integer if node isn't known to be locally
+        colorable, the more negative the value, the less colorable
+        """
+        if node not in nodes_remaining:
+            raise ValueError()
+        retval = len(node.reg_class)
+        for neighbor in node.edges:
+            if neighbor in nodes_remaining:
+                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+        return retval
+
+    node_stack = []  # type: list[IGNode]
+    while True:
+        best_node = None  # type: None | IGNode
+        best_score = 0
+        for node in nodes_remaining:
+            score = local_colorability_score(node)
+            if best_node is None or score > best_score:
+                best_node = node
+                best_score = score
+            if best_score > 0:
+                # it's locally colorable, no need to find a better one
+                break
+
+        if best_node is None:
+            break
+        node_stack.append(best_node)
+        nodes_remaining.remove(best_node)
+
+    retval = {}  # type: dict[SSAVal, RegLoc]
+
+    while len(node_stack) > 0:
+        node = node_stack.pop()
+        if node.reg is not None:
+            if node.reg_conflicts_with_neighbors(node.reg):
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+        else:
+            # pick the first non-conflicting register in node.reg_class, since
+            # register classes are ordered from most preferred to least
+            # preferred register.
+            for reg in node.reg_class:
+                if not node.reg_conflicts_with_neighbors(reg):
+                    node.reg = reg
+                    break
+            if node.reg is None:
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+
+        for ssa_val, offset in node.merged_reg_set.items():
+            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
 
-    raise NotImplementedError
+    return retval
 
 
 def allocate_registers(ops):