try_allocate_registers_without_spilling works!
[bigint-presentation-code.git] / src / bigint_presentation_code / test_register_allocator.py
index 43675a973807db9e5f23d9a8c46b5cfd34faa6ee..8ba74ebbf87d00439a9e2077eb3d83f4e63e9cbf 100644 (file)
 import unittest
 
-from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange,
+                                                  GPRType, GlobalMem, Op, OpAddSubE,
+                                                  OpClearCY, OpConcat, OpCopy,
+                                                  OpFuncArg, OpInputMem, OpLI,
+                                                  OpLoad, OpStore, XERBit)
 from bigint_presentation_code.register_allocator import (
-    AllocationFailed, allocate_registers,
+    AllocationFailed, allocate_registers, MergedRegSet,
     try_allocate_registers_without_spilling)
 
 
+class TestMergedRegSet(unittest.TestCase):
+    maxDiff = None
+
+    def test_from_equality_constraint(self):
+        op0 = OpLI(0, length=1)
+        op1 = OpLI(0, length=2)
+        op2 = OpLI(0, length=3)
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            op0.out,
+            op1.out,
+            op2.out,
+        ]), MergedRegSet({
+            op0.out: 0,
+            op1.out: 1,
+            op2.out: 3,
+        }.items()))
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            op1.out,
+            op0.out,
+            op2.out,
+        ]), MergedRegSet({
+            op1.out: 0,
+            op0.out: 2,
+            op2.out: 3,
+        }.items()))
+
+
 class TestRegisterAllocator(unittest.TestCase):
-    pass  # no tests yet, just testing importing
+    maxDiff = None
+
+    def test_try_alloc_fail(self):
+        ops = []  # type: list[Op]
+        op0 = OpLI(0, length=52)
+        ops.append(op0)
+        op1 = OpLI(0, length=64)
+        ops.append(op1)
+        op2 = OpConcat([op0.out, op1.out])
+        ops.append(op2)
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+        self.assertEqual(
+            repr(reg_assignments),
+            "AllocationFailed("
+            "node=IGNode(#0, merged_reg_set=MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "edges={}, reg=None), "
+            "live_intervals=LiveIntervals("
+            "live_intervals={"
+            "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): "
+            "LiveInterval(first_write=0, last_use=2)}, "
+            "merged_reg_sets=MergedRegSets(data={"
+            "<#0.out>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "<#1.out>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "<#2.dest>: MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), "
+            "reg_sets_live_after={"
+            "0: OFSet([MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+            "1: OFSet([MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+            "2: OFSet()}), "
+            "interference_graph=InterferenceGraph(nodes={"
+            "...: IGNode(#0, "
+            "merged_reg_set=MergedRegSet(["
+            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "edges={}, reg=None)}))"
+        )
+
+    def test_try_alloc_bigint_inc(self):
+        ops = []  # type: list[Op]
+        op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+        ops.append(op0)
+        op1 = OpCopy(op0.out, GPRType())
+        ops.append(op1)
+        arg = op1.dest
+        op2 = OpInputMem()
+        ops.append(op2)
+        mem = op2.out
+        op3 = OpLoad(arg, offset=0, mem=mem, length=32)
+        ops.append(op3)
+        a = op3.RT
+        op4 = OpLI(1)
+        ops.append(op4)
+        b_0 = op4.out
+        op5 = OpLI(0, length=31)
+        ops.append(op5)
+        b_rest = op5.out
+        op6 = OpConcat([b_0, b_rest])
+        ops.append(op6)
+        b = op6.dest
+        op7 = OpClearCY()
+        ops.append(op7)
+        cy = op7.out
+        op8 = OpAddSubE(a, b, cy, is_sub=False)
+        ops.append(op8)
+        s = op8.RT
+        op9 = OpStore(s, arg, offset=0, mem_in=mem)
+        ops.append(op9)
+        mem = op9.mem_out
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+
+        expected_reg_assignments = {
+            op0.out: GPRRange(start=3, length=1),
+            op1.dest: GPRRange(start=3, length=1),
+            op2.out: GlobalMem.GlobalMem,
+            op3.RT: GPRRange(start=78, length=32),
+            op4.out: GPRRange(start=46, length=1),
+            op5.out: GPRRange(start=47, length=31),
+            op6.dest: GPRRange(start=46, length=32),
+            op7.out: XERBit.CY,
+            op8.RT: GPRRange(start=14, length=32),
+            op8.CY_out: XERBit.CY,
+            op9.mem_out: GlobalMem.GlobalMem,
+        }
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
+        # type: (list[GPRRange], GPRRange) -> None
+        li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)]
+        ops = [*li_ops]  # type: list[Op]
+        concat = OpConcat([i.out for i in li_ops])
+        ops.append(concat)
+
+        reg_assignments = try_allocate_registers_without_spilling(ops)
+
+        expected_reg_assignments = {concat.dest: expected_dest_reg}
+        for li_op, reg in zip(li_ops, expected_regs):
+            expected_reg_assignments[li_op.out] = reg
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def test_try_alloc_concat_1(self):
+        self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
+
+    def test_try_alloc_concat_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
+
+    def test_try_alloc_concat_3_5(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_5_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_1_2_3_4_5_6(self):
+        self.tst_try_alloc_concat([
+            GPRRange(14, 1),
+            GPRRange(15, 2),
+            GPRRange(17, 3),
+            GPRRange(20, 4),
+            GPRRange(24, 5),
+            GPRRange(29, 6),
+        ], GPRRange(14, 21))
 
 
 if __name__ == "__main__":