try_allocate_registers_without_spilling works!
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 09:50:28 +0000 (02:50 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 09:50:28 +0000 (02:50 -0700)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/ordered_set.py [new file with mode: 0644]
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/test_compiler_ir.py
src/bigint_presentation_code/test_register_allocator.py

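diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py
index 24b86494e37c71f10c9491e697efb5f0823e37b8..bfcf159c25e4fc2cced574fdf019977d0fcd102a 100644 (file)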
@@ -6,12 +6,13 @@ from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, EnumMeta, unique
 from functools import lru_cache
-from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
-                    TypeVar, cast)
+from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
 
 from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
+from bigint_presentation_code.ordered_set import OFSet, OSet
+
 if TYPE_CHECKING:
     from typing_extensions import final
 else:
@@ -111,8 +112,8 @@ class GPRRange(RegLoc, Sequence["GPRRange"]):
 
     def get_subreg_at_offset(self, subreg_type, offset):
         # type: (RegType, int) -> GPRRange
-        if not isinstance(subreg_type, GPRRangeType):
-            raise ValueError(f"subreg_type is not a "
+        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
+            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                              f"GPRRangeType: {subreg_type}")
         if offset < 0 or offset + subreg_type.length > self.stop:
             raise ValueError(f"sub-register offset is out of range: {offset}")
@@ -150,30 +151,11 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
 
 
 @final
-class RegClass(AbstractSet[RegLoc]):
+class RegClass(OFSet[RegLoc]):
     """ an ordered set of registers.
     earlier registers are preferred by the register allocator.
     """
 
-    def __init__(self, regs):
-        # type: (Iterable[RegLoc]) -> None
-
-        # use dict to maintain order
-        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]
-
-    def __len__(self):
-        return len(self.__regs)
-
-    def __iter__(self):
-        return iter(self.__regs)
-
-    def __contains__(self, v):
-        # type: (RegLoc) -> bool
-        return v in self.__regs
-
-    def __hash__(self):
-        return super()._hash()
-
     @lru_cache(maxsize=None, typed=True)
     def max_conflicts_with(self, other):
         # type: (RegClass | RegLoc) -> int
@@ -251,12 +233,11 @@ class GPRType(GPRRangeType):
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class FixedGPRRangeType(GPRRangeType):
+class FixedGPRRangeType(RegType):
     __slots__ = "reg",
 
     def __init__(self, reg):
         # type: (GPRRange) -> None
-        super().__init__(length=reg.length)
         self.reg = reg
 
     @property
@@ -264,6 +245,11 @@ class FixedGPRRangeType(GPRRangeType):
         # type: () -> RegClass
         return RegClass([self.reg])
 
+    @property
+    def length(self):
+        # type: () -> int
+        return self.reg.length
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
@@ -384,7 +370,9 @@ class SSAVal(Generic[_RegType]):
     def __hash__(self):
         return hash((id(self.op), self.arg_name))
 
-    def __repr__(self):
+    def __repr__(self, long=False):
+        if not long:
+            return f"<#{self.op.id}.{self.arg_name}>"
         fields_list = []
         for name in fields(self):
             v = getattr(self, name, None)
@@ -449,17 +437,33 @@ class Op(metaclass=ABCMeta):
 
     @cached_property
     def id(self):
+        # type: () -> int
+        # use cached_property rather than done in init so id is usable even if
+        # init hasn't run
         retval = Op.__NEXT_ID
         Op.__NEXT_ID += 1
         return retval
 
+    def __init__(self):
+        self.id  # initialize
+
     @final
     def __repr__(self, just_id=False):
         fields_list = [f"#{self.id}"]
+        outputs = None
+        try:
+            outputs = self.outputs()
+        except AttributeError:
+            pass
         if not just_id:
             for name in fields(self):
                 v = getattr(self, name, _NOT_SET)
-                fields_list.append(f"{name}={v!r}")
+                if ((outputs is None or name in outputs)
+                        and isinstance(v, SSAVal)):
+                    v = v.__repr__(long=True)
+                else:
+                    v = repr(v)
+                fields_list.append(f"{name}={v}")
         fields_str = ', '.join(fields_list)
         return f"{self.__class__.__name__}({fields_str})"
 
@@ -479,6 +483,7 @@ class OpLoadFromStackSlot(Op):
 
     def __init__(self, src):
         # type: (SSAVal[GPRRangeType]) -> None
+        super().__init__()
         self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
         self.src = src
 
@@ -498,6 +503,7 @@ class OpStoreToStackSlot(Op):
 
     def __init__(self, src):
         # type: (SSAVal[StackSlotType]) -> None
+        super().__init__()
         self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
         self.src = src
 
@@ -520,11 +526,17 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
 
     def __init__(self, src, dest_ty=None):
         # type: (SSAVal[_RegSrcType], _RegType | None) -> None
+        super().__init__()
         if dest_ty is None:
             dest_ty = cast(_RegType, src.ty)
         if isinstance(src.ty, GPRRangeType) \
+                and isinstance(dest_ty, FixedGPRRangeType):
+            if src.ty.length != dest_ty.reg.length:
+                raise ValueError(f"incompatible source and destination "
+                                 f"types: {src.ty} and {dest_ty}")
+        elif isinstance(src.ty, FixedGPRRangeType) \
                 and isinstance(dest_ty, GPRRangeType):
-            if src.ty.length != dest_ty.length:
+            if src.ty.reg.length != dest_ty.length:
                 raise ValueError(f"incompatible source and destination "
                                  f"types: {src.ty} and {dest_ty}")
         elif src.ty != dest_ty:
@@ -550,6 +562,7 @@ class OpConcat(Op):
 
     def __init__(self, sources):
         # type: (Iterable[SSAVal[GPRRangeType]]) -> None
+        super().__init__()
         sources = tuple(sources)
         self.dest = SSAVal(self, "dest", GPRRangeType(
             sum(i.ty.length for i in sources)))
@@ -575,6 +588,7 @@ class OpSplit(Op):
 
     def __init__(self, src, split_indexes):
         # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+        super().__init__()
         ranges = []  # type: list[GPRRangeType]
         last = 0
         for i in split_indexes:
@@ -608,6 +622,7 @@ class OpAddSubE(Op):
 
     def __init__(self, RA, RB, CY_in, is_sub):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+        super().__init__()
         if RA.ty != RB.ty:
             raise TypeError(f"source types must match: "
                             f"{RA} doesn't match {RB}")
@@ -639,6 +654,7 @@ class OpBigIntMulDiv(Op):
 
     def __init__(self, RA, RB, RC, is_div):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
         self.RB = RB
@@ -683,6 +699,7 @@ class OpBigIntShift(Op):
 
     def __init__(self, inp, sh, kind):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", inp.ty)
         self.inp = inp
         self.sh = sh
@@ -709,6 +726,7 @@ class OpLI(Op):
 
     def __init__(self, value, length=1):
         # type: (int, int) -> None
+        super().__init__()
         self.out = SSAVal(self, "out", GPRRangeType(length))
         self.value = value
 
@@ -728,6 +746,7 @@ class OpClearCY(Op):
 
     def __init__(self):
         # type: () -> None
+        super().__init__()
         self.out = SSAVal(self, "out", CYType())
 
 
@@ -746,6 +765,7 @@ class OpLoad(Op):
 
     def __init__(self, RA, offset, mem, length=1):
         # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", GPRRangeType(length))
         self.RA = RA
         self.offset = offset
@@ -772,6 +792,7 @@ class OpStore(Op):
 
     def __init__(self, RS, RA, offset, mem_in):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+        super().__init__()
         self.RS = RS
         self.RA = RA
         self.offset = offset
@@ -794,6 +815,7 @@ class OpFuncArg(Op):
 
     def __init__(self, ty):
         # type: (FixedGPRRangeType) -> None
+        super().__init__()
         self.out = SSAVal(self, "out", ty)
 
 
@@ -812,6 +834,7 @@ class OpInputMem(Op):
 
     def __init__(self):
         # type: () -> None
+        super().__init__()
         self.out = SSAVal(self, "out", GlobalMemType())
 
 
@@ -830,7 +853,7 @@ def op_set_to_list(ops):
         ops_to_pending_input_count_map[op] = input_count
         worklists[input_count][op] = None
     retval = []  # type: list[Op]
-    ready_vals = set()  # type: set[SSAVal]
+    ready_vals = OSet()  # type: OSet[SSAVal]
     while len(worklists[0]) != 0:
         writing_op = next(iter(worklists[0]))
         del worklists[0][writing_op]
diff --git a/src/bigint_presentation_code/ordered_set.py b/src/bigint_presentation_code/ordered_set.py
new file mode 100644 (file)
index 0000000..018f97b
--- /dev/null
@@ -0,0 +1,59 @@
+from typing import AbstractSet, Iterable, MutableSet, TypeVar
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+
+class OFSet(AbstractSet[_T_co]):
+    """ ordered frozen set """
+
+    def __init__(self, items=()):
+        # type: (Iterable[_T_co]) -> None
+        self.__items = {v: None for v in items}
+
+    def __contains__(self, x):
+        return x in self.__items
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def __hash__(self):
+        return self._hash()
+
+    def __repr__(self):
+        if len(self) == 0:
+            return "OFSet()"
+        return f"OFSet({list(self)})"
+
+
+class OSet(MutableSet[_T]):
+    """ ordered mutable set """
+
+    def __init__(self, items=()):
+        # type: (Iterable[_T]) -> None
+        self.__items = {v: None for v in items}
+
+    def __contains__(self, x):
+        return x in self.__items
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def add(self, value):
+        # type: (_T) -> None
+        self.__items[value] = None
+
+    def discard(self, value):
+        # type: (_T) -> None
+        self.__items.pop(value, None)
+
+    def __repr__(self):
+        if len(self) == 0:
+            return "OSet()"
+        return f"OSet({list(self)})"
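diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py
index 75b342243fa95cd7223647b6204f0d33d7e36cf0..b44c32dd76023c948ff9be447d34eb16e02a77bf 100644 (file)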
@@ -12,9 +12,10 @@ from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                   RegLoc, RegType, SSAVal)
+from bigint_presentation_code.ordered_set import OFSet, OSet
 
 if TYPE_CHECKING:
-    from typing_extensions import Self, final
+    from typing_extensions import final
 else:
     def final(v):
         return v
@@ -103,7 +104,7 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]):
         self.__start = start  # type: int
         self.__stop = stop  # type: int
         self.__ty = ty  # type: RegType
-        self.__hash = hash(frozenset(self.items()))
+        self.__hash = hash(OFSet(self.items()))
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
@@ -181,9 +182,14 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
             for e in op.get_equality_constraints():
                 lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                 rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
-                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
-                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
-                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+                items = []  # type: list[tuple[SSAVal, int]]
+                for i in e.lhs:
+                    s = merged_sets[i].with_offset_to_match(lhs_set)
+                    items.extend(s.items())
+                for i in e.rhs:
+                    s = merged_sets[i].with_offset_to_match(rhs_set)
+                    items.extend(s.items())
+                full_set = MergedRegSet(items)
                 for val in full_set.keys():
                     merged_sets[val] = full_set
 
@@ -219,12 +225,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
                 else:
                     live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
-        live_after = []  # type: list[set[MergedRegSet[_RegType]]]
-        live_after += (set() for _ in ops)
+        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
+        live_after += (OSet() for _ in ops)
         for reg_set, live_interval in self.__live_intervals.items():
             for i in live_interval.live_after_op_range:
                 live_after[i].add(reg_set)
-        self.__live_after = [frozenset(i) for i in live_after]
+        self.__live_after = [OFSet(i) for i in live_after]
 
     @property
     def merged_reg_sets(self):
@@ -237,8 +243,11 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
     def __iter__(self):
         return iter(self.__live_intervals)
 
+    def __len__(self):
+        return len(self.__live_intervals)
+
     def reg_sets_live_after(self, op_index):
-        # type: (int) -> frozenset[MergedRegSet[_RegType]]
+        # type: (int) -> OFSet[MergedRegSet[_RegType]]
         return self.__live_after[op_index]
 
     def __repr__(self):
@@ -256,7 +265,7 @@ class IGNode(Generic[_RegType]):
     def __init__(self, merged_reg_set, edges=(), reg=None):
         # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
         self.merged_reg_set = merged_reg_set
-        self.edges = set(edges)
+        self.edges = OSet(edges)
         self.reg = reg
 
     def add_edge(self, other):
@@ -312,6 +321,9 @@ class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
     def __iter__(self):
         return iter(self.__nodes)
 
+    def __len__(self):
+        return len(self.__nodes)
+
     def __repr__(self):
         nodes = {}
         nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
@@ -347,7 +359,7 @@ def try_allocate_registers_without_spilling(ops):
             if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                 interference_graph[i].add_edge(interference_graph[j])
 
-    nodes_remaining = set(interference_graph.values())
+    nodes_remaining = OSet(interference_graph.values())
 
     def local_colorability_score(node):
         # type: (IGNode) -> int
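diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py
index 26d5272a040fb178d3cd92c5a33e0c0b6bed0430..1f30547806c5203a9ab8d2855d6c70122eeaca22 100644 (file)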
@@ -9,7 +9,7 @@ class TestCompilerIR(unittest.TestCase):
     maxDiff = None
 
     def test_op_set_to_list(self):
-        ops = []  # list[Op]
+        ops = []  # type: list[Op]
         op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
         ops.append(op0)
         op1 = OpCopy(op0.out, GPRType())
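diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py
index 43675a973807db9e5f23d9a8c46b5cfd34faa6ee..8ba74ebbf87d00439a9e2077eb3d83f4e63e9cbf 100644 (file)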
 import unittest
 
-from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange,
+                                                  GPRType, GlobalMem, Op, OpAddSubE,
+                                                  OpClearCY, OpConcat, OpCopy,
+                                                  OpFuncArg, OpInputMem, OpLI,
+                                                  OpLoad, OpStore, XERBit)
 from bigint_presentation_code.register_allocator import (
-    AllocationFailed, allocate_registers,
+    AllocationFailed, allocate_registers, MergedRegSet,
     try_allocate_registers_without_spilling)
 
 
+class TestMergedRegSet(unittest.TestCase):
+    maxDiff = None
+
+    def test_from_equality_constraint(self):
+        op0 = OpLI(0, length=1)
+        op1 = OpLI(0, length=2)
+        op2 = OpLI(0, length=3)
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            op0.out,
+            op1.out,
+            op2.out,
+        ]), MergedRegSet({
+            op0.out: 0,
+            op1.out: 1,
+            op2.out: 3,
+        }.items()))
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            op1.out,
+            op0.out,
+            op2.out,
+        ]), MergedRegSet({
+            op1.out: 0,
+            op0.out: 2,
+            op2.out: 3,
+        }.items()))
+
+
 class TestRegisterAllocator(unittest.TestCase):
-    pass  # no tests yet, just testing importing
+    maxDiff = None
+
+    def test_try_alloc_fail(self):
+        ops = []  # type: list[Op]
+        op0 = OpLI(0, length=52)
+        ops.append(op0)
+        op1 = OpLI(0, length=64)
+        ops.append(op1)
+        op2 = OpConcat([op0.out, op1.out])
+        ops.append(op2)
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+        self.assertEqual(
+            repr(reg_assignments),
+            "AllocationFailed("
+            "node=IGNode(#0, merged_reg_set=MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "edges={}, reg=None), "
+            "live_intervals=LiveIntervals("
+            "live_intervals={"
+            "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): "
+            "LiveInterval(first_write=0, last_use=2)}, "
+            "merged_reg_sets=MergedRegSets(data={"
+            "<#0.out>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "<#1.out>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "<#2.dest>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), "
+            "reg_sets_live_after={"
+            "0: OFSet([MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+            "1: OFSet([MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+            "2: OFSet()}), "
+            "interference_graph=InterferenceGraph(nodes={"
+            "...: IGNode(#0, "
+            "merged_reg_set=MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "edges={}, reg=None)}))"
+        )
+
+    def test_try_alloc_bigint_inc(self):
+        ops = []  # type: list[Op]
+        op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+        ops.append(op0)
+        op1 = OpCopy(op0.out, GPRType())
+        ops.append(op1)
+        arg = op1.dest
+        op2 = OpInputMem()
+        ops.append(op2)
+        mem = op2.out
+        op3 = OpLoad(arg, offset=0, mem=mem, length=32)
+        ops.append(op3)
+        a = op3.RT
+        op4 = OpLI(1)
+        ops.append(op4)
+        b_0 = op4.out
+        op5 = OpLI(0, length=31)
+        ops.append(op5)
+        b_rest = op5.out
+        op6 = OpConcat([b_0, b_rest])
+        ops.append(op6)
+        b = op6.dest
+        op7 = OpClearCY()
+        ops.append(op7)
+        cy = op7.out
+        op8 = OpAddSubE(a, b, cy, is_sub=False)
+        ops.append(op8)
+        s = op8.RT
+        op9 = OpStore(s, arg, offset=0, mem_in=mem)
+        ops.append(op9)
+        mem = op9.mem_out
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+
+        expected_reg_assignments = {
+            op0.out: GPRRange(start=3, length=1),
+            op1.dest: GPRRange(start=3, length=1),
+            op2.out: GlobalMem.GlobalMem,
+            op3.RT: GPRRange(start=78, length=32),
+            op4.out: GPRRange(start=46, length=1),
+            op5.out: GPRRange(start=47, length=31),
+            op6.dest: GPRRange(start=46, length=32),
+            op7.out: XERBit.CY,
+            op8.RT: GPRRange(start=14, length=32),
+            op8.CY_out: XERBit.CY,
+            op9.mem_out: GlobalMem.GlobalMem,
+        }
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
+        # type: (list[GPRRange], GPRRange) -> None
+        li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)]
+        ops = [*li_ops]  # type: list[Op]
+        concat = OpConcat([i.out for i in li_ops])
+        ops.append(concat)
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+
+        expected_reg_assignments = {concat.dest: expected_dest_reg}
+        for li_op, reg in zip(li_ops, expected_regs):
+            expected_reg_assignments[li_op.out] = reg
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def test_try_alloc_concat_1(self):
+        self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
+
+    def test_try_alloc_concat_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
+
+    def test_try_alloc_concat_3_5(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_5_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_1_2_3_4_5_6(self):
+        self.tst_try_alloc_concat([
+            GPRRange(14, 1),
+            GPRRange(15, 2),
+            GPRRange(17, 3),
+            GPRRange(20, 4),
+            GPRRange(24, 5),
+            GPRRange(29, 6),
+        ], GPRRange(14, 21))
 
 
 if __name__ == "__main__":