working on rewriting compiler ir to fix reg alloc issues
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 28 Oct 2022 09:24:23 +0000 (02:24 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 28 Oct 2022 09:24:23 +0000 (02:24 -0700)
20 files changed:
src/bigint_presentation_code/_tests/__init__.py [new file with mode: 0644]
src/bigint_presentation_code/_tests/test_compiler_ir.py [new file with mode: 0644]
src/bigint_presentation_code/_tests/test_matrix.py [new file with mode: 0644]
src/bigint_presentation_code/_tests/test_register_allocator.py [new file with mode: 0644]
src/bigint_presentation_code/_tests/test_toom_cook.py [new file with mode: 0644]
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/matrix.py
src/bigint_presentation_code/py.typed [new file with mode: 0644]
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/test_compiler_ir.py [deleted file]
src/bigint_presentation_code/test_matrix.py [deleted file]
src/bigint_presentation_code/test_register_allocator.py [deleted file]
src/bigint_presentation_code/test_toom_cook.py [deleted file]
src/bigint_presentation_code/toom_cook.py
src/bigint_presentation_code/type_util.py [new file with mode: 0644]
src/bigint_presentation_code/type_util.pyi [new file with mode: 0644]
src/bigint_presentation_code/util.py
src/bigint_presentation_code/util.pyi [deleted file]
typings/cached_property.pyi

diff --git a/src/bigint_presentation_code/_tests/__init__.py b/src/bigint_presentation_code/_tests/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py
new file mode 100644 (file)
index 0000000..820c305
--- /dev/null
@@ -0,0 +1,120 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+                                                  GlobalMem, GPRRange, GPRType,
+                                                  OpBigIntAddSub, OpConcat,
+                                                  OpCopy, OpFuncArg,
+                                                  OpInputMem, OpLI, OpLoad,
+                                                  OpSetCA, OpSetVLImm, OpStore,
+                                                  RegLoc, SSAVal, XERBit,
+                                                  generate_assembly,
+                                                  op_set_to_list)
+
+
+class TestCompilerIR(unittest.TestCase):
+    maxDiff = None
+
+    def test_op_set_to_list(self):
+        fn = Fn()
+        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+        op1 = OpCopy(fn, op0.out, GPRType())
+        arg = op1.dest
+        op2 = OpInputMem(fn)
+        mem = op2.out
+        op3 = OpSetVLImm(fn, 32)
+        vl = op3.out
+        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+        a = op4.RT
+        op5 = OpLI(fn, 1)
+        b_0 = op5.out
+        op6 = OpSetVLImm(fn, 31)
+        vl = op6.out
+        op7 = OpLI(fn, 0, vl=vl)
+        b_rest = op7.out
+        op8 = OpConcat(fn, [b_0, b_rest])
+        b = op8.dest
+        op9 = OpSetVLImm(fn, 32)
+        vl = op9.out
+        op10 = OpSetCA(fn, False)
+        ca = op10.out
+        op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+        s = op11.out
+        op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        mem = op12.mem_out
+
+        expected_ops = [
+            op10,  # OpSetCA(fn, False)
+            op9,  # OpSetVLImm(fn, 32)
+            op6,  # OpSetVLImm(fn, 31)
+            op5,  # OpLI(fn, 1)
+            op3,  # OpSetVLImm(fn, 32)
+            op2,  # OpInputMem(fn)
+            op0,  # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+            op7,  # OpLI(fn, 0, vl=vl)
+            op1,  # OpCopy(fn, op0.out, GPRType())
+            op8,  # OpConcat(fn, [b_0, b_rest])
+            op4,  # OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+            op11,  # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+            op12,  # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        ]
+
+        ops = op_set_to_list(fn.ops[::-1])
+        if ops != expected_ops:
+            self.assertEqual(repr(ops), repr(expected_ops))
+
+    def tst_generate_assembly(self, use_reg_alloc=False):
+        fn = Fn()
+        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+        op1 = OpCopy(fn, op0.out, GPRType())
+        arg = op1.dest
+        op2 = OpInputMem(fn)
+        mem = op2.out
+        op3 = OpSetVLImm(fn, 32)
+        vl = op3.out
+        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+        a = op4.RT
+        op5 = OpLI(fn, 0, vl=vl)
+        b = op5.out
+        op6 = OpSetCA(fn, True)
+        ca = op6.out
+        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+        s = op7.out
+        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        mem = op8.mem_out
+
+        assigned_registers = {
+            op0.out: GPRRange(start=3, length=1),
+            op1.dest: GPRRange(start=3, length=1),
+            op2.out: GlobalMem.GlobalMem,
+            op3.out: VL.VL_MAXVL,
+            op4.RT: GPRRange(start=78, length=32),
+            op5.out: GPRRange(start=46, length=32),
+            op6.out: XERBit.CA,
+            op7.out: GPRRange(start=14, length=32),
+            op7.CA_out: XERBit.CA,
+            op8.mem_out: GlobalMem.GlobalMem,
+        }  # type: dict[SSAVal, RegLoc] | None
+
+        if use_reg_alloc:
+            assigned_registers = None
+
+        asm = generate_assembly(fn.ops, assigned_registers)
+        self.assertEqual(asm, [
+            "setvl 0, 0, 32, 0, 1, 1",
+            "sv.ld *78, 0(3)",
+            "sv.addi *46, 0, 0",
+            "subfic 0, 0, -1",
+            "sv.adde *14, *78, *46",
+            "sv.std *14, 0(3)",
+            "bclr 20, 0, 0",
+        ])
+
+    def test_generate_assembly(self):
+        self.tst_generate_assembly()
+
+    def test_generate_assembly_with_register_allocator(self):
+        self.tst_generate_assembly(use_reg_alloc=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/bigint_presentation_code/_tests/test_matrix.py b/src/bigint_presentation_code/_tests/test_matrix.py
new file mode 100644 (file)
index 0000000..1a56df0
--- /dev/null
@@ -0,0 +1,113 @@
+import unittest
+from fractions import Fraction
+
+from bigint_presentation_code.matrix import Matrix, SpecialMatrix
+
+
+class TestMatrix(unittest.TestCase):
+    def test_repr(self):
+        self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
+                                            3, 4, 5])),
+                         'Matrix(height=2, width=3, data=[\n'
+                         '    0, 1, 2,\n'
+                         '    3, 4, 5,\n'
+                         '])')
+        self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
+                                            3, 4, 5])),
+                         'Matrix(height=2, width=3, data=[\n'
+                         '    0, 1, Fraction(2, 3),\n'
+                         '    3, 4, 5,\n'
+                         '])')
+        self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
+        self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
+
+    def test_eq(self):
+        self.assertFalse(Matrix(1, 1) == 5)
+        self.assertFalse(5 == Matrix(1, 1))
+        self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
+        self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
+        self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
+        self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
+        self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
+
+    def test_add(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+                         + Matrix(2, 2, [40, 30, 20, 10]),
+                         Matrix(2, 2, [41, 32, 23, 14]))
+
+    def test_identity(self):
+        self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity),
+                         Matrix(2, 2, [1, 0,
+                                       0, 1]))
+        self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity),
+                         Matrix(1, 3, [1, 0, 0]))
+        self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity),
+                         Matrix(2, 3, [1, 0, 0,
+                                       0, 1, 0]))
+        self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity),
+                         Matrix(3, 3, [1, 0, 0,
+                                       0, 1, 0,
+                                       0, 0, 1]))
+
+    def test_sub(self):
+        self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
+                         - Matrix(2, 2, [-1, -2, -3, -4]),
+                         Matrix(2, 2, [41, 32, 23, 14]))
+
+    def test_neg(self):
+        self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
+                         Matrix(2, 2, [-40, -30, -20, -10]))
+
+    def test_mul(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
+                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+        self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
+                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+
+    def test_matmul(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+                         @ Matrix(2, 2, [4, 3, 2, 1]),
+                         Matrix(2, 2, [8, 5, 20, 13]))
+        self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
+                         @ Matrix(2, 1, [1, 2]),
+                         Matrix(3, 1, [16, 10, 4]))
+
+    def test_inverse(self):
+        self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
+        self.assertEqual(Matrix(1, 1, [2]).inverse(),
+                         Matrix(1, 1, [Fraction(1, 2)]))
+        self.assertEqual(Matrix(1, 1, [1]).inverse(),
+                         Matrix(1, 1, [1]))
+        self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
+                         Matrix(2, 2, [1, 0, -1, 1]))
+        self.assertEqual(Matrix(3, 3, [0, 1, 0,
+                                       1, 0, 0,
+                                       0, 0, 1]).inverse(),
+                         Matrix(3, 3, [0, 1, 0,
+                                       1, 0, 0,
+                                       0, 0, 1]))
+        _1_2 = Fraction(1, 2)
+        _1_3 = Fraction(1, 3)
+        _1_6 = Fraction(1, 6)
+        self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
+                                       1, 1, 1, 1, 1,
+                                       1, -1, 1, -1, 1,
+                                       1, -2, 4, -8, 16,
+                                       0, 0, 0, 0, 1]).inverse(),
+                         Matrix(5, 5, [1, 0, 0, 0, 0,
+                                       _1_2, _1_3, -1, _1_6, -2,
+                                       -1, _1_2, _1_2, 0, -1,
+                                       -_1_2, _1_6, _1_2, -_1_6, 2,
+                                       0, 0, 0, 0, 1]))
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(1, 1, [0]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [0, 0, 1, 1]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [1, 0, 1, 0]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [1, 1, 1, 1]).inverse()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py
new file mode 100644 (file)
index 0000000..1eff254
--- /dev/null
@@ -0,0 +1,201 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+                                                  GlobalMem, GPRRange, GPRType,
+                                                  OpBigIntAddSub, OpConcat,
+                                                  OpCopy, OpFuncArg,
+                                                  OpInputMem, OpLI, OpLoad,
+                                                  OpSetCA, OpSetVLImm, OpStore,
+                                                  XERBit)
+from bigint_presentation_code.register_allocator import (
+    AllocationFailed, MergedRegSet, allocate_registers,
+    try_allocate_registers_without_spilling)
+
+
+class TestMergedRegSet(unittest.TestCase):
+    maxDiff = None
+
+    def test_from_equality_constraint(self):
+        fn = Fn()
+        li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
+        li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
+        li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            li0x1.out,
+            li0x2.out,
+            li0x3.out,
+        ]), MergedRegSet({
+            li0x1.out: 0,
+            li0x2.out: 1,
+            li0x3.out: 3,
+        }.items()))
+        self.assertEqual(MergedRegSet.from_equality_constraint([
+            li0x2.out,
+            li0x1.out,
+            li0x3.out,
+        ]), MergedRegSet({
+            li0x2.out: 0,
+            li0x1.out: 2,
+            li0x3.out: 3,
+        }.items()))
+
+
+class TestRegisterAllocator(unittest.TestCase):
+    maxDiff = None
+
+    def test_try_alloc_fail(self):
+        fn = Fn()
+        op0 = OpSetVLImm(fn, 52)
+        op1 = OpLI(fn, 0, vl=op0.out)
+        op2 = OpSetVLImm(fn, 64)
+        op3 = OpLI(fn, 0, vl=op2.out)
+        op4 = OpConcat(fn, [op1.out, op3.out])
+
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+        self.assertEqual(
+            repr(reg_assignments),
+            "AllocationFailed("
+            "node=IGNode(#0, merged_reg_set=MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "edges={}, reg=None), "
+            "live_intervals=LiveIntervals(live_intervals={"
+            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
+            "LiveInterval(first_write=0, last_use=1), "
+            "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]): "
+            "LiveInterval(first_write=1, last_use=4), "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
+            "LiveInterval(first_write=2, last_use=3)}, "
+            "merged_reg_sets=MergedRegSets(data={"
+            "<#0.out: KnownVLType(length=52)>: "
+            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
+            "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "<#2.out: KnownVLType(length=64)>: "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
+            "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])}), "
+            "reg_sets_live_after={"
+            "0: OFSet([MergedRegSet(["
+            "(<#0.out: KnownVLType(length=52)>, 0)])]), "
+            "1: OFSet([MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
+            "2: OFSet([MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
+            "3: OFSet([MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
+            "4: OFSet()}), "
+            "interference_graph=InterferenceGraph(nodes={"
+            "...: IGNode(#0, merged_reg_set=MergedRegSet(["
+            "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
+            "...: IGNode(#1, merged_reg_set=MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
+            "...: IGNode(#2, merged_reg_set=MergedRegSet(["
+            "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
+        )
+
+    def test_try_alloc_bigint_inc(self):
+        fn = Fn()
+        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+        op1 = OpCopy(fn, op0.out, GPRType())
+        arg = op1.dest
+        op2 = OpInputMem(fn)
+        mem = op2.out
+        op3 = OpSetVLImm(fn, 32)
+        vl = op3.out
+        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+        a = op4.RT
+        op5 = OpLI(fn, 0, vl=vl)
+        b = op5.out
+        op6 = OpSetCA(fn, True)
+        ca = op6.out
+        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+        s = op7.out
+        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        mem = op8.mem_out
+
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+
+        expected_reg_assignments = {
+            op0.out: GPRRange(start=3, length=1),
+            op1.dest: GPRRange(start=3, length=1),
+            op2.out: GlobalMem.GlobalMem,
+            op3.out: VL.VL_MAXVL,
+            op4.RT: GPRRange(start=78, length=32),
+            op5.out: GPRRange(start=46, length=32),
+            op6.out: XERBit.CA,
+            op7.out: GPRRange(start=14, length=32),
+            op7.CA_out: XERBit.CA,
+            op8.mem_out: GlobalMem.GlobalMem,
+        }
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
+        # type: (list[GPRRange], GPRRange) -> None
+        fn = Fn()
+        inputs = []
+        expected_reg_assignments = {}
+        for i, r in enumerate(expected_regs):
+            vl = OpSetVLImm(fn, r.length).out
+            expected_reg_assignments[vl] = VL.VL_MAXVL
+            inp = OpLI(fn, i, vl=vl).out
+            inputs.append(inp)
+            expected_reg_assignments[inp] = r
+        concat = OpConcat(fn, inputs)
+        expected_reg_assignments[concat.dest] = expected_dest_reg
+
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+
+        for inp, reg in zip(inputs, expected_regs):
+            expected_reg_assignments[inp] = reg
+
+        self.assertEqual(reg_assignments, expected_reg_assignments)
+
+    def test_try_alloc_concat_1(self):
+        self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
+
+    def test_try_alloc_concat_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
+
+    def test_try_alloc_concat_3_5(self):
+        self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_5_3(self):
+        self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
+                                  GPRRange(3, 8))
+
+    def test_try_alloc_concat_1_2_3_4_5_6(self):
+        self.tst_try_alloc_concat([
+            GPRRange(14, 1),
+            GPRRange(15, 2),
+            GPRRange(17, 3),
+            GPRRange(20, 4),
+            GPRRange(24, 5),
+            GPRRange(29, 6),
+        ], GPRRange(14, 21))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py
new file mode 100644 (file)
index 0000000..6fff570
--- /dev/null
@@ -0,0 +1,378 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
+                                                  GlobalMem, GPRRange,
+                                                  GPRRangeType, OpCopy,
+                                                  OpFuncArg, OpInputMem,
+                                                  OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
+                                                  generate_assembly)
+from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.util import FMap
+
+
+class SimpleMul192x192:
+    def __init__(self):
+        self.fn = fn = Fn()
+        self.mem_in = mem = OpInputMem(fn).out
+        self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
+        self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
+        self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
+        dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
+        vl = OpSetVLImm(fn, 3).out
+        lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
+        rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
+        retval = simple_mul(fn, lhs, rhs)
+        vl = OpSetVLImm(fn, 6).out
+        self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
+                               mem_in=mem, vl=vl).mem_out
+
+
+class TestToomCook(unittest.TestCase):
+    maxDiff = None
+
+    def test_toom_2_repr(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        # print(repr(repr(TOOM_2)))
+        self.assertEqual(
+            repr(TOOM_2),
+            "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
+            "eval_points=(0, 1, POINT_AT_INFINITY), "
+            "lhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+            " rhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+            " prod_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpSub(lhs="
+            "EvalOpSub(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({"
+            "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
+        )
+
+    def test_toom_2_5_repr(self):
+        TOOM_2_5 = ToomCookInstance.make_toom_2_5()
+        # print(repr(repr(TOOM_2_5)))
+        self.assertEqual(
+            repr(TOOM_2_5),
+            "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
+            "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=2, rhs=0, "
+            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=1, rhs=0, "
+            "poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({"
+            "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
+            "EvalOpSub(lhs="
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=2, rhs=0, "
+            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=1, rhs=0, "
+            "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
+            "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=2, rhs=0, "
+            "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
+            "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpSub(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=1, rhs=0, "
+            "poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, "
+            "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
+            "prod_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs=EvalOpInput(lhs=2, rhs=0, "
+            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+            "rhs=2, "
+            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
+            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
+            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
+            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
+            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
+        )
+
+    def test_reversed_toom_2_5_repr(self):
+        TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
+        # print(repr(repr(TOOM_2_5)))
+        self.assertEqual(
+            repr(TOOM_2_5),
+            "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
+            "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpSub(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+            " rhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs=EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
+            "EvalOpSub(lhs=EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
+            " prod_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+            "rhs=2, "
+            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
+            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
+            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
+            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "poly=EvalOpPoly("
+            "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
+            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
+        )
+
+    def test_simple_mul_192x192_pre_ra_sim(self):
+        # test multiplying:
+        #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
+        # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
+        # ==
+        # int("0x00074736574206e_6f69746163696c70"
+        #     "_69746c756d207469_622d3438333e2d32"
+        #     "_3931783239312079_7261727469627261", base=0)
+        # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
+        #                   'little')
+        code = SimpleMul192x192()
+        dest_ptr = 0x100
+        state = PreRASimState(
+            gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
+            stack_slots={}, fixed_gprs={
+                code.dest_ptr_in: (dest_ptr,),
+                code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
+                              0x000191acb262e15b),
+                code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
+                              0x4a37c0567bcbab53)
+            })
+        code.fn.pre_ra_sim(state)
+        expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
+        OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
+        expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
+        mem_out = state.global_mems[code.mem_out]
+        out_bytes = bytes(
+            mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
+        self.assertEqual(out_bytes, expected_bytes)
+
+    def test_simple_mul_192x192_ops(self):
+        code = SimpleMul192x192()
+        fn = code.fn
+        self.assertEqual([repr(v) for v in fn.ops], [
+            'OpInputMem(#0, <#0.out: GlobalMemType()>)',
+            'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
+            'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
+            'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
+            'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
+            'vl=None)',
+            'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
+            'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
+            'src=<#2.out: <fixed(<r4..len=3>)>>, '
+            'vl=<#5.out: KnownVLType(length=3)>)',
+            'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
+            'src=<#3.out: <fixed(<r7..len=3>)>>, '
+            'vl=<#5.out: KnownVLType(length=3)>)',
+            'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
+            '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
+            'src=<#7.dest: <gpr_ty[3]>>)',
+            'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
+            'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
+            'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
+            '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
+            'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
+            '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
+            'OpSetCA(#15, <#15.out: CAType()>, value=False)',
+            'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
+            'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
+            'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
+            'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+            'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
+            'is_sub=False, vl=None)',
+            'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
+            '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
+            '<#17.out: <gpr_ty[1]>>))',
+            'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
+            '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
+            'OpSetCA(#21, <#21.out: CAType()>, value=False)',
+            'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
+            'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
+            'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
+            'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+            'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
+            'is_sub=False, vl=None)',
+            'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
+            '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
+            '<#23.out: <gpr_ty[1]>>))',
+            'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
+            'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
+            'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
+            'mem_in=<#0.out: GlobalMemType()>, '
+            '<#26.mem_out: GlobalMemType()>, '
+            'vl=<#25.out: KnownVLType(length=6)>)'
+        ])
+
+    # FIXME: register allocator currently allocates wrong registers
+    @unittest.expectedFailure
+    def test_simple_mul_192x192_reg_alloc(self):
+        code = SimpleMul192x192()
+        fn = code.fn
+        assigned_registers = allocate_registers(fn.ops)
+        self.assertEqual(assigned_registers, {
+            fn.ops[13].RS: GPRRange(9),  # type: ignore
+            fn.ops[14].results[0]: GPRRange(6),  # type: ignore
+            fn.ops[14].results[1]: GPRRange(7, length=3),  # type: ignore
+            fn.ops[15].out: XERBit.CA,  # type: ignore
+            fn.ops[16].out: GPRRange(7, length=3),  # type: ignore
+            fn.ops[16].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[17].out: GPRRange(10),  # type: ignore
+            fn.ops[17].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[18].dest: GPRRange(6, length=5),  # type: ignore
+            fn.ops[19].RT: GPRRange(3, length=3),  # type: ignore
+            fn.ops[19].RS: GPRRange(9),  # type: ignore
+            fn.ops[20].results[0]: GPRRange(6, length=2),  # type: ignore
+            fn.ops[20].results[1]: GPRRange(8, length=3),  # type: ignore
+            fn.ops[21].out: XERBit.CA,  # type: ignore
+            fn.ops[22].out: GPRRange(8, length=3),  # type: ignore
+            fn.ops[22].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[23].out: GPRRange(11),  # type: ignore
+            fn.ops[23].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[24].dest: GPRRange(6, length=6),  # type: ignore
+            fn.ops[25].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[26].mem_out: GlobalMem.GlobalMem,  # type: ignore
+            fn.ops[0].out: GlobalMem.GlobalMem,  # type: ignore
+            fn.ops[1].out: GPRRange(3),  # type: ignore
+            fn.ops[2].out: GPRRange(4, length=3),  # type: ignore
+            fn.ops[3].out: GPRRange(7, length=3),  # type: ignore
+            fn.ops[4].dest: GPRRange(12),  # type: ignore
+            fn.ops[5].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[6].dest: GPRRange(17, length=3),  # type: ignore
+            fn.ops[7].dest: GPRRange(14, length=3),  # type: ignore
+            fn.ops[8].results[0]: GPRRange(14),  # type: ignore
+            fn.ops[8].results[1]: GPRRange(15),  # type: ignore
+            fn.ops[8].results[2]: GPRRange(16),  # type: ignore
+            fn.ops[9].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[10].out: GPRRange(9),  # type: ignore
+            fn.ops[11].RT: GPRRange(6, length=3),  # type: ignore
+            fn.ops[11].RS: GPRRange(9),  # type: ignore
+            fn.ops[12].dest: GPRRange(6, length=4),  # type: ignore
+            fn.ops[13].RT: GPRRange(3, length=3)  # type: ignore
+        })
+        self.fail("register allocator currently allocates wrong registers")
+
+    # FIXME: register allocator currently allocates wrong registers
+    @unittest.expectedFailure
+    def test_simple_mul_192x192_asm(self):
+        code = SimpleMul192x192()
+        asm = generate_assembly(code.fn.ops)
+        self.assertEqual(asm, [
+            'or 12, 3, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *17, *4, *4',
+            'sv.or *14, *7, *7',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'addi 9, 0, 0',
+            'sv.maddedu *6, *17, 14, 9',
+            'sv.maddedu *3, *17, 15, 9',
+            'addic 0, 0, 0',
+            'sv.adde *7, *3, *7',
+            'adde 10, 9, 9',
+            'sv.maddedu *3, *17, 16, 9',
+            'addic 0, 0, 0',
+            'sv.adde *8, *3, *8',
+            'adde 11, 9, 9',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'sv.std *6, 0(12)',
+            'bclr 20, 0, 0'
+        ])
+        self.fail("register allocator currently allocates wrong registers")
+
+
+if __name__ == "__main__":
+    unittest.main()
index 77e44a22b4ce2c4a890a8918f37fa53a076ec433..c5741741f559bdfffb88503e88611e1687c160d2 100644 (file)
@@ -1,3 +1,4 @@
+# type: ignore
 """
 Compiler IR for Toom-Cook algorithm generator for SVP64
 
@@ -12,7 +13,8 @@ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
 
 from nmutil.plain_data import fields, plain_data
 
-from bigint_presentation_code.util import FMap, OFSet, OSet, final
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, OFSet, OSet
 
 
 class ABCEnumMeta(EnumMeta, ABCMeta):
index eacceb410d8726d49d32f9be1192b8446c46a707..666df1402c2745d57de81bc3430ec3484efa429e 100644 (file)
@@ -1,20 +1,23 @@
+from collections import defaultdict
 import enum
 from enum import Enum, unique
-from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
-from weakref import WeakValueDictionary
+from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
+                                           OSet, FMap)
+from functools import lru_cache
 
 
 @final
 class Fn:
     def __init__(self):
         self.ops = []  # type: list[Op]
-        op_names = WeakValueDictionary()
-        self.__op_names = op_names  # type: WeakValueDictionary[str, Op]
+        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
         self.__next_name_suffix = 2
 
     def _add_op_with_unused_name(self, op, name=""):
@@ -32,278 +35,464 @@ class Fn:
             self.__next_name_suffix += 1
 
     def __repr__(self):
+        # type: () -> str
         return "<Fn>"
 
 
 @unique
 @final
-class RegKind(Enum):
-    GPR = enum.auto()
+class BaseTy(Enum):
+    I64 = enum.auto()
     CA = enum.auto()
     VL_MAXVL = enum.auto()
 
     @cached_property
     def only_scalar(self):
-        if self is RegKind.GPR:
+        # type: () -> bool
+        if self is BaseTy.I64:
             return False
-        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+        elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
             return True
         else:
             assert_never(self)
 
     @cached_property
-    def reg_count(self):
-        if self is RegKind.GPR:
+    def max_reg_len(self):
+        # type: () -> int
+        if self is BaseTy.I64:
             return 128
-        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+        elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
             return 1
         else:
             assert_never(self)
 
     def __repr__(self):
-        return "RegKind." + self._name_
+        return "BaseTy." + self._name_
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OperandType:
-    __slots__ = "kind", "vec"
+class Ty:
+    __slots__ = "base_ty", "reg_len"
 
-    def __init__(self, kind, vec):
-        # type: (RegKind, bool) -> None
-        self.kind = kind
-        if kind.only_scalar and vec:
-            raise ValueError(f"kind={kind} must have vec=False")
-        self.vec = vec
-
-    def get_length(self, maxvl):
-        # type: (int) -> int
-        # here's where subvl and elwid would be accounted for
-        if self.vec:
-            return maxvl
-        return 1
+    @staticmethod
+    def validate(base_ty, reg_len):
+        # type: (BaseTy, int) -> str | None
+        """ return a string with the error if the combination is invalid,
+        otherwise return None
+        """
+        if base_ty.only_scalar and reg_len != 1:
+            return f"can't create a vector of an only-scalar type: {base_ty}"
+        if reg_len < 1 or reg_len > base_ty.max_reg_len:
+            return "reg_len out of range"
+        return None
+
+    def __init__(self, base_ty, reg_len):
+        # type: (BaseTy, int) -> None
+        msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+        if msg is not None:
+            raise ValueError(msg)
+        self.base_ty = base_ty
+        self.reg_len = reg_len
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@unique
 @final
-class RegShape:
-    __slots__ = "kind", "length"
+class LocKind(Enum):
+    GPR = enum.auto()
+    StackI64 = enum.auto()
+    CA = enum.auto()
+    VL_MAXVL = enum.auto()
 
-    def __init__(self, kind, length=1):
-        # type: (RegKind, int) -> None
-        self.kind = kind
-        if length < 1 or length > kind.reg_count:
-            raise ValueError("invalid length")
-        self.length = length
+    @cached_property
+    def base_ty(self):
+        # type: () -> BaseTy
+        if self is LocKind.GPR or self is LocKind.StackI64:
+            return BaseTy.I64
+        if self is LocKind.CA:
+            return BaseTy.CA
+        if self is LocKind.VL_MAXVL:
+            return BaseTy.VL_MAXVL
+        else:
+            assert_never(self)
 
-    def try_concat(self, *others):
-        # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
-        kind = self.kind
-        length = self.length
-        for other in others:
-            if isinstance(other, (Reg, RegClass)):
-                other = other.shape
-            if other is None:
-                return None
-            if other.kind != self.kind:
-                return None
-            length += other.length
-        if length > kind.reg_count:
-            return None
-        return RegShape(kind=kind, length=length)
+    @cached_property
+    def loc_count(self):
+        # type: () -> int
+        if self is LocKind.StackI64:
+            return 1024
+        if self is LocKind.GPR or self is LocKind.CA \
+                or self is LocKind.VL_MAXVL:
+            return self.base_ty.max_reg_len
+        else:
+            assert_never(self)
+
+    def __repr__(self):
+        return "LocKind." + self._name_
 
 
-@plain_data(frozen=True, unsafe_hash=True)
 @final
-class Reg:
-    __slots__ = "shape", "start"
-
-    def __init__(self, shape, start):
-        # type: (RegShape, int) -> None
-        self.shape = shape
-        if start < 0 or start + shape.length > shape.kind.reg_count:
-            raise ValueError("start not in valid range")
-        self.start = start
+@unique
+class LocSubKind(Enum):
+    BASE_GPR = enum.auto()
+    SV_EXTRA2_VGPR = enum.auto()
+    SV_EXTRA2_SGPR = enum.auto()
+    SV_EXTRA3_VGPR = enum.auto()
+    SV_EXTRA3_SGPR = enum.auto()
+    StackI64 = enum.auto()
+    CA = enum.auto()
+    VL_MAXVL = enum.auto()
 
-    @property
+    @cached_property
     def kind(self):
-        return self.shape.kind
+        # type: () -> LocKind
+        # pyright fails typechecking when using `in` here:
+        # reported: https://github.com/microsoft/pyright/issues/4102
+        if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
+                or self is LocSubKind.SV_EXTRA2_SGPR \
+                or self is LocSubKind.SV_EXTRA3_VGPR \
+                or self is LocSubKind.SV_EXTRA3_SGPR:
+            return LocKind.GPR
+        if self is LocSubKind.StackI64:
+            return LocKind.StackI64
+        if self is LocSubKind.CA:
+            return LocKind.CA
+        if self is LocSubKind.VL_MAXVL:
+            return LocKind.VL_MAXVL
+        assert_never(self)
 
     @property
-    def length(self):
-        return self.shape.length
+    def base_ty(self):
+        return self.kind.base_ty
+
+    @lru_cache()
+    def allocatable_locs(self, ty):
+        # type: (Ty) -> LocSet
+        if ty.base_ty != self.base_ty:
+            raise ValueError("type mismatch")
+        raise NotImplementedError  # FIXME: finish
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericTy:
+    __slots__ = "base_ty", "is_vec"
+
+    def __init__(self, base_ty, is_vec):
+        # type: (BaseTy, bool) -> None
+        self.base_ty = base_ty
+        if base_ty.only_scalar and is_vec:
+            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
+        self.is_vec = is_vec
+
+    def instantiate(self, maxvl):
+        # type: (int) -> Ty
+        # here's where subvl and elwid would be accounted for
+        if self.is_vec:
+            return Ty(self.base_ty, maxvl)
+        return Ty(self.base_ty, 1)
+
+    def can_instantiate_to(self, ty):
+        # type: (Ty) -> bool
+        if self.base_ty != ty.base_ty:
+            return False
+        if self.is_vec:
+            return True
+        return ty.reg_len == 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Loc:
+    __slots__ = "kind", "start", "reg_len"
+
+    @staticmethod
+    def validate(kind, start, reg_len):
+        # type: (LocKind, int, int) -> str | None
+        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
+        if msg is not None:
+            return msg
+        if reg_len > kind.loc_count:
+            return "invalid reg_len"
+        if start < 0 or start + reg_len > kind.loc_count:
+            return "start not in valid range"
+        return None
+
+    @staticmethod
+    def try_make(kind, start, reg_len):
+        # type: (LocKind, int, int) -> Loc | None
+        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
+        if msg is None:
+            return None
+        return Loc(kind=kind, start=start, reg_len=reg_len)
+
+    def __init__(self, kind, start, reg_len):
+        # type: (LocKind, int, int) -> None
+        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+        if msg is not None:
+            raise ValueError(msg)
+        self.kind = kind
+        self.reg_len = reg_len
+        self.start = start
 
     def conflicts(self, other):
-        # type: (Reg) -> bool
-        return (self.kind == other.kind
+        # type: (Loc) -> bool
+        return (self.kind != other.kind
                 and self.start < other.stop and other.start < self.stop)
 
+    @staticmethod
+    def make_ty(kind, reg_len):
+        # type: (LocKind, int) -> Ty
+        return Ty(base_ty=kind.base_ty, reg_len=reg_len)
+
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        return self.make_ty(kind=self.kind, reg_len=self.reg_len)
+
     @property
     def stop(self):
-        return self.start + self.length
+        # type: () -> int
+        return self.start + self.reg_len
 
     def try_concat(self, *others):
-        # type: (*Reg | None) -> Reg | None
-        shape = self.shape.try_concat(*others)
-        if shape is None:
-            return None
+        # type: (*Loc | None) -> Loc | None
+        reg_len = self.reg_len
         stop = self.stop
         for other in others:
-            assert other is not None, "already caught by RegShape.try_concat"
+            if other is None or other.kind != self.kind:
+                return None
             if stop != other.start:
                 return None
             stop = other.stop
-        return Reg(shape, self.start)
+            reg_len += other.reg_len
+        return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
 
 
+@plain_data(frozen=True, eq=False, repr=False)
 @final
-class RegClass(AbstractSet[Reg]):
-    def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
-        # type: (Iterable[Reg | int], RegShape | None, int) -> None
-        for reg_or_start in regs_or_starts:
-            if isinstance(reg_or_start, Reg):
-                if shape is None:
-                    shape = reg_or_start.shape
-                elif shape != reg_or_start.shape:
-                    raise ValueError(f"conflicting RegShapes: {shape} and "
-                                     f"{reg_or_start.shape}")
-                start = reg_or_start.start
-            else:
-                start = reg_or_start
-            if start < 0:
-                raise ValueError("a Reg's start is out of range")
-            starts_bitset |= 1 << start
-        if starts_bitset == 0:
-            shape = None
-        self.__shape = shape
-        self.__starts_bitset = starts_bitset
-        if shape is None:
-            if starts_bitset != 0:
-                raise ValueError("non-empty RegClass must have non-None shape")
+class LocSet(AbstractSet[Loc]):
+    __slots__ = "starts", "ty"
+
+    def __init__(self, __locs=()):
+        # type: (Iterable[Loc]) -> None
+        if isinstance(__locs, LocSet):
+            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
+            self.ty = __locs.ty  # type: Ty | None
             return
-        if self.stops_bitset >= 1 << shape.kind.reg_count:
-            raise ValueError("a Reg's start is out of range")
-
-    @property
-    def shape(self):
-        # type: () -> RegShape | None
-        return self.__shape
-
-    @property
-    def starts_bitset(self):
-        # type: () -> int
-        return self.__starts_bitset
-
-    @property
-    def stops_bitset(self):
-        # type: () -> int
-        if self.__shape is None:
-            return 0
-        return self.__starts_bitset << self.__shape.length
-
-    @cached_property
-    def starts(self):
-        # type: () -> OFSet[int]
-        if self.length is None:
-            return OFSet()
-        # TODO: fixme
-        # return OFSet(for i in range(self.length))
+        starts = {i: BitSet() for i in LocKind}
+        ty = None
+        for loc in __locs:
+            if ty is None:
+                ty = loc.ty
+            if ty != loc.ty:
+                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+            starts[loc.kind].add(loc.start)
+        self.starts = FMap(
+            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
+        self.ty = ty
 
     @cached_property
     def stops(self):
-        # type: () -> OFSet[int]
-        if self.__shape is None:
-            return OFSet()
-        return OFSet(i + self.__shape.length for i in self.__starts)
+        # type: () -> FMap[LocKind, FBitSet]
+        if self.ty is None:
+            return FMap()
+        sh = self.ty.reg_len
+        return FMap(
+            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
 
     @property
-    def kind(self):
-        if self.__shape is None:
+    def kinds(self):
+        # type: () -> AbstractSet[LocKind]
+        return self.starts.keys()
+
+    @property
+    def reg_len(self):
+        # type: () -> int | None
+        if self.ty is None:
             return None
-        return self.__shape.kind
+        return self.ty.reg_len
 
     @property
-    def length(self):
-        """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
-        if self.__shape is None:
+    def base_ty(self):
+        # type: () -> BaseTy | None
+        if self.ty is None:
             return None
-        return self.__shape.length
+        return self.ty.base_ty
 
     def concat(self, *others):
-        # type: (*RegClass) -> RegClass
-        shape = self.__shape
-        if shape is None:
-            return RegClass()
-        shape = shape.try_concat(*others)
-        if shape is None:
-            return RegClass()
-        starts = OSet(self.starts)
-        offset = shape.length
+        # type: (*LocSet) -> LocSet
+        if self.ty is None:
+            return LocSet()
+        base_ty = self.ty.base_ty
+        reg_len = self.ty.reg_len
+        starts = {k: BitSet(v) for k, v in self.starts.items()}
         for other in others:
-            assert other.__shape is not None, \
-                "already caught by RegShape.try_concat"
-            starts &= OSet(i - offset for i in other.starts)
-            offset += other.__shape.length
-        return RegClass(starts, shape=shape)
-
-    def __contains__(self, reg):
-        # type: (Reg) -> bool
-        return reg.shape == self.shape and reg.start in self.starts
+            if other.ty is None:
+                return LocSet()
+            if other.ty.base_ty != base_ty:
+                return LocSet()
+            for kind, other_starts in other.starts.items():
+                if kind not in starts:
+                    continue
+                starts[kind].bits &= other_starts.bits >> reg_len
+                if starts[kind] == 0:
+                    del starts[kind]
+                    if len(starts) == 0:
+                        return LocSet()
+            reg_len += other.ty.reg_len
+
+        def locs():
+            # type: () -> Iterable[Loc]
+            for kind, v in starts.items():
+                for start in v:
+                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
+                    if loc is not None:
+                        yield loc
+        return LocSet(locs())
+
+    def __contains__(self, loc):
+        # type: (Loc | Any) -> bool
+        if not isinstance(loc, Loc) or loc.ty == self.ty:
+            return False
+        if loc.kind not in self.starts:
+            return False
+        return loc.start in self.starts[loc.kind]
 
     def __iter__(self):
-        # type: () -> Iterator[Reg]
-        if self.shape is None:
+        # type: () -> Iterator[Loc]
+        if self.ty is None:
             return
-        for start in self.starts:
-            yield Reg(shape=self.shape, start=start)
+        for kind, starts in self.starts.items():
+            for start in starts:
+                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
+
+    @cached_property
+    def __len(self):
+        return sum((len(v) for v in self.starts.values()), 0)
 
     def __len__(self):
-        return len(self.starts)
+        return self.__len
 
-    def __hash__(self):
+    @cached_property
+    def __hash(self):
         return super()._hash()
 
+    def __hash__(self):
+        return self.__hash
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class Operand:
-    __slots__ = "ty", "regs"
-
-    def __init__(self, ty, regs=None):
-        # type: (OperandType, OFSet[int] | None) -> None
-        pass
-
-
-OT_VGPR = OperandType(RegKind.GPR, vec=True)
-OT_SGPR = OperandType(RegKind.GPR, vec=False)
-OT_CA = OperandType(RegKind.CA, vec=False)
-OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+class GenericOperandDesc:
+    """generic Op operand descriptor"""
+    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index"
+
+    def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None):
+        # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None
+        self.ty = ty
+        self.sub_kinds = OFSet(sub_kinds)
+        if len(self.sub_kinds) == 0:
+            raise ValueError("sub_kinds can't be empty")
+        self.fixed_loc = fixed_loc
+        if fixed_loc is not None:
+            if tied_input_index is not None:
+                raise ValueError("operand can't be both tied and fixed")
+            if not ty.can_instantiate_to(fixed_loc.ty):
+                raise ValueError(
+                    f"fixed_loc has incompatible type for given generic "
+                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
+            if len(self.sub_kinds) != 1:
+                raise ValueError(
+                    "multiple sub_kinds not allowed for fixed operand")
+            for sub_kind in self.sub_kinds:
+                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
+                    raise ValueError(
+                        f"fixed_loc not in given sub_kind: "
+                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
+        for sub_kind in self.sub_kinds:
+            if sub_kind.base_ty != ty.base_ty:
+                raise ValueError(f"sub_kind is incompatible with type: "
+                                 f"sub_kind={sub_kind} ty={ty}")
+        if tied_input_index is not None and tied_input_index < 0:
+            raise ValueError("invalid tied_input_index")
+        self.tied_input_index = tied_input_index
+
+    def tied_to_input(self, tied_input_index):
+        # type: (int) -> Self
+        return GenericOperandDesc(self.ty, self.sub_kinds,
+                                  tied_input_index=tied_input_index)
+
+    def with_fixed_loc(self, fixed_loc):
+        # type: (Loc) -> Self
+        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+
+    def instantiate(self, maxvl):
+        # type: (int) -> OperandDesc
+        ty = self.ty.instantiate(maxvl=maxvl)
+
+        def locs():
+            # type: () -> Iterable[Loc]
+            if self.fixed_loc is not None:
+                if ty != self.fixed_loc.ty:
+                    raise ValueError(
+                        f"instantiation failed: type mismatch with fixed_loc: "
+                        f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+                yield self.fixed_loc
+                return
+            for sub_kind in self.sub_kinds:
+                yield from sub_kind.allocatable_locs(ty)
+        return OperandDesc(loc_set=LocSet(locs()),
+                           tied_input_index=self.tied_input_index)
 
 
 @plain_data(frozen=True, unsafe_hash=True)
-class TiedOutput:
-    __slots__ = "input_index", "output_index"
-
-    def __init__(self, input_index, output_index):
-        # type: (int, int) -> None
-        self.input_index = input_index
-        self.output_index = output_index
-
-
-Constraint = Union[TiedOutput, NoReturn]
+@final
+class OperandDesc:
+    """Op operand descriptor"""
+    __slots__ = "loc_set", "tied_input_index"
+
+    def __init__(self, loc_set, tied_input_index):
+        # type: (LocSet, int | None) -> None
+        if len(loc_set) == 0:
+            raise ValueError("loc_set must not be empty")
+        self.loc_set = loc_set
+        self.tied_input_index = tied_input_index
+
+
+OD_BASE_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.BASE_GPR])
+OD_EXTRA3_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
+OD_EXTRA3_VGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
+OD_EXTRA2_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
+OD_EXTRA2_VGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
+OD_CA = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
+    sub_kinds=[LocSubKind.CA])
+OD_VL = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
+    sub_kinds=[LocSubKind.VL_MAXVL])
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OpProperties:
-    __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+class GenericOpProperties:
+    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                  "is_copy", "is_load_immediate", "has_side_effects")
 
     def __init__(self, demo_asm,  # type: str
-                 inputs,  # type: Iterable[OperandType]
-                 outputs,  # type: Iterable[OperandType]
+                 inputs,  # type: Iterable[GenericOperandDesc]
+                 outputs,  # type: Iterable[GenericOperandDesc]
                  immediates,  # type: Iterable[range]
-                 constraints,  # type: Iterable[Constraint]
                  is_copy=False,  # type: bool
                  is_load_immediate=False,  # type: bool
                  has_side_effects=False,  # type: bool
@@ -313,17 +502,22 @@ class OpProperties:
         self.inputs = tuple(inputs)
         self.outputs = tuple(outputs)
         self.immediates = tuple(immediates)
-        self.constraints = tuple(constraints)
         self.is_copy = is_copy
         self.is_load_immediate = is_load_immediate
         self.has_side_effects = has_side_effects
 
+    def instantiate(self, maxvl):
+        # type: (int) -> OpProperties
+        raise NotImplementedError  # FIXME: finish
+
+
+# FIXME: add OpProperties
 
 @unique
 @final
 class OpKind(Enum):
     def __init__(self, properties):
-        # type: (OpProperties) -> None
+        # type: (GenericOpProperties) -> None
         super().__init__()
         self.properties = properties
 
@@ -451,13 +645,13 @@ class SSAVal:
         self.sliced_op_outputs = tuple(processed)
 
     def __add__(self, other):
-        # type: (SSAVal) -> SSAVal
+        # type: (SSAVal | Any) -> SSAVal
         if not isinstance(other, SSAVal):
             return NotImplemented
         return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
 
     def __radd__(self, other):
-        # type: (SSAVal) -> SSAVal
+        # type: (SSAVal | Any) -> SSAVal
         if isinstance(other, SSAVal):
             return other.__add__(self)
         return NotImplemented
@@ -465,7 +659,7 @@ class SSAVal:
     @cached_property
     def expanded_sliced_op_outputs(self):
         # type: () -> tuple[tuple[Op, int, int], ...]
-        retval = []
+        retval = []  # type: list[tuple[Op, int, int]]
         for op, output_index, range_ in self.sliced_op_outputs:
             for i in range_:
                 retval.append((op, output_index, i))
@@ -490,7 +684,7 @@ class SSAVal:
         # type: () -> str
         if len(self.sliced_op_outputs) == 0:
             return "SSAVal([])"
-        parts = []
+        parts = []  # type: list[str]
         for op, output_index, range_ in self.sliced_op_outputs:
             out_len = op.properties.outputs[output_index].get_length(op.maxvl)
             parts.append(f"<{op.name}#{output_index}>")
@@ -513,13 +707,14 @@ class Op:
         self.maxvl = maxvl
         outputs_len = len(self.properties.outputs)
         self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
-        self.name = fn._add_op_with_unused_name(self, name)
+        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore
 
     @property
     def properties(self):
         return self.kind.properties
 
     def __eq__(self, other):
+        # type: (Op | Any) -> bool
         if isinstance(other, Op):
             return self is other
         return NotImplemented
index 89c3ea23b2b4554cc1836a6773adc643713882c4..0674c02269b3bdcf4ef3dbe741e7acab562ee905 100644 (file)
@@ -1,10 +1,9 @@
-import operator
 from enum import Enum, unique
 from fractions import Fraction
-from numbers import Rational
+import operator
 from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar
 
-from bigint_presentation_code.util import final
+from bigint_presentation_code.type_util import final
 
 _T = TypeVar("_T")
 _T2 = TypeVar("_T2")
@@ -103,7 +102,7 @@ class Matrix(Generic[_T]):
         return retval
 
     def __truediv__(self, rhs):
-        # type: (Rational | int) -> Matrix
+        # type: (_T | int) -> Matrix[_T]
         retval = self.copy()
         for i in self.indexes():
             retval[i] /= rhs  # type: ignore
@@ -128,7 +127,7 @@ class Matrix(Generic[_T]):
         return lhs.__matmul__(self)
 
     def __elementwise_bin_op(self, rhs, op):
-        # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
+        # type: (Matrix[_T], Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
         if self.height != rhs.height or self.width != rhs.width:
             raise ValueError(
                 "matrix dimensions must match for element-wise operations")
@@ -172,8 +171,8 @@ class Matrix(Generic[_T]):
         # type: () -> str
         if self.height == 0 or self.width == 0:
             return f"Matrix(height={self.height}, width={self.width})"
-        lines = []
-        line = []
+        lines = []  # type: list[str]
+        line = []  # type: list[str]
         for row in range(self.height):
             line.clear()
             for col in range(self.width):
@@ -183,16 +182,16 @@ class Matrix(Generic[_T]):
                 else:
                     line.append(repr(el))
             lines.append(", ".join(line))
-        lines = ",\n    ".join(lines)
+        lines_str = ",\n    ".join(lines)
         element_type = ""
         if self.element_type is not Fraction:
             element_type = f"element_type={self.element_type}, "
         return (f"Matrix(height={self.height}, width={self.width}, "
                 f"{element_type}data=[\n"
-                f"    {lines},\n])")
+                f"    {lines_str},\n])")
 
     def __eq__(self, rhs):
-        # type: (object) -> bool
+        # type: (Matrix[Any] | Any) -> bool
         if not isinstance(rhs, Matrix):
             return NotImplemented
         return (self.height == rhs.height
diff --git a/src/bigint_presentation_code/py.typed b/src/bigint_presentation_code/py.typed
new file mode 100644 (file)
index 0000000..e69de29
index b8269e4bc354f6c31b5433b10d19b71a39a38c1e..cc794e9b926efc906ae8bd6906c1de9bdb7034ce 100644 (file)
@@ -12,7 +12,8 @@ from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                   RegLoc, RegType, SSAVal)
-from bigint_presentation_code.util import OFSet, OSet, final
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import OFSet, OSet
 
 _RegType = TypeVar("_RegType", bound=RegType)
 
diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py
deleted file mode 100644 (file)
index 820c305..0000000
+++ /dev/null
@@ -1,120 +0,0 @@
-import unittest
-
-from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
-                                                  GlobalMem, GPRRange, GPRType,
-                                                  OpBigIntAddSub, OpConcat,
-                                                  OpCopy, OpFuncArg,
-                                                  OpInputMem, OpLI, OpLoad,
-                                                  OpSetCA, OpSetVLImm, OpStore,
-                                                  RegLoc, SSAVal, XERBit,
-                                                  generate_assembly,
-                                                  op_set_to_list)
-
-
-class TestCompilerIR(unittest.TestCase):
-    maxDiff = None
-
-    def test_op_set_to_list(self):
-        fn = Fn()
-        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
-        op1 = OpCopy(fn, op0.out, GPRType())
-        arg = op1.dest
-        op2 = OpInputMem(fn)
-        mem = op2.out
-        op3 = OpSetVLImm(fn, 32)
-        vl = op3.out
-        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
-        a = op4.RT
-        op5 = OpLI(fn, 1)
-        b_0 = op5.out
-        op6 = OpSetVLImm(fn, 31)
-        vl = op6.out
-        op7 = OpLI(fn, 0, vl=vl)
-        b_rest = op7.out
-        op8 = OpConcat(fn, [b_0, b_rest])
-        b = op8.dest
-        op9 = OpSetVLImm(fn, 32)
-        vl = op9.out
-        op10 = OpSetCA(fn, False)
-        ca = op10.out
-        op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
-        s = op11.out
-        op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
-        mem = op12.mem_out
-
-        expected_ops = [
-            op10,  # OpSetCA(fn, False)
-            op9,  # OpSetVLImm(fn, 32)
-            op6,  # OpSetVLImm(fn, 31)
-            op5,  # OpLI(fn, 1)
-            op3,  # OpSetVLImm(fn, 32)
-            op2,  # OpInputMem(fn)
-            op0,  # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
-            op7,  # OpLI(fn, 0, vl=vl)
-            op1,  # OpCopy(fn, op0.out, GPRType())
-            op8,  # OpConcat(fn, [b_0, b_rest])
-            op4,  # OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
-            op11,  # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
-            op12,  # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
-        ]
-
-        ops = op_set_to_list(fn.ops[::-1])
-        if ops != expected_ops:
-            self.assertEqual(repr(ops), repr(expected_ops))
-
-    def tst_generate_assembly(self, use_reg_alloc=False):
-        fn = Fn()
-        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
-        op1 = OpCopy(fn, op0.out, GPRType())
-        arg = op1.dest
-        op2 = OpInputMem(fn)
-        mem = op2.out
-        op3 = OpSetVLImm(fn, 32)
-        vl = op3.out
-        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
-        a = op4.RT
-        op5 = OpLI(fn, 0, vl=vl)
-        b = op5.out
-        op6 = OpSetCA(fn, True)
-        ca = op6.out
-        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
-        s = op7.out
-        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
-        mem = op8.mem_out
-
-        assigned_registers = {
-            op0.out: GPRRange(start=3, length=1),
-            op1.dest: GPRRange(start=3, length=1),
-            op2.out: GlobalMem.GlobalMem,
-            op3.out: VL.VL_MAXVL,
-            op4.RT: GPRRange(start=78, length=32),
-            op5.out: GPRRange(start=46, length=32),
-            op6.out: XERBit.CA,
-            op7.out: GPRRange(start=14, length=32),
-            op7.CA_out: XERBit.CA,
-            op8.mem_out: GlobalMem.GlobalMem,
-        }  # type: dict[SSAVal, RegLoc] | None
-
-        if use_reg_alloc:
-            assigned_registers = None
-
-        asm = generate_assembly(fn.ops, assigned_registers)
-        self.assertEqual(asm, [
-            "setvl 0, 0, 32, 0, 1, 1",
-            "sv.ld *78, 0(3)",
-            "sv.addi *46, 0, 0",
-            "subfic 0, 0, -1",
-            "sv.adde *14, *78, *46",
-            "sv.std *14, 0(3)",
-            "bclr 20, 0, 0",
-        ])
-
-    def test_generate_assembly(self):
-        self.tst_generate_assembly()
-
-    def test_generate_assembly_with_register_allocator(self):
-        self.tst_generate_assembly(use_reg_alloc=True)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/bigint_presentation_code/test_matrix.py b/src/bigint_presentation_code/test_matrix.py
deleted file mode 100644 (file)
index 1a56df0..0000000
+++ /dev/null
@@ -1,113 +0,0 @@
-import unittest
-from fractions import Fraction
-
-from bigint_presentation_code.matrix import Matrix, SpecialMatrix
-
-
-class TestMatrix(unittest.TestCase):
-    def test_repr(self):
-        self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
-                                            3, 4, 5])),
-                         'Matrix(height=2, width=3, data=[\n'
-                         '    0, 1, 2,\n'
-                         '    3, 4, 5,\n'
-                         '])')
-        self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
-                                            3, 4, 5])),
-                         'Matrix(height=2, width=3, data=[\n'
-                         '    0, 1, Fraction(2, 3),\n'
-                         '    3, 4, 5,\n'
-                         '])')
-        self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
-        self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
-
-    def test_eq(self):
-        self.assertFalse(Matrix(1, 1) == 5)
-        self.assertFalse(5 == Matrix(1, 1))
-        self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
-        self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
-        self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
-        self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
-        self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
-
-    def test_add(self):
-        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
-                         + Matrix(2, 2, [40, 30, 20, 10]),
-                         Matrix(2, 2, [41, 32, 23, 14]))
-
-    def test_identity(self):
-        self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity),
-                         Matrix(2, 2, [1, 0,
-                                       0, 1]))
-        self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity),
-                         Matrix(1, 3, [1, 0, 0]))
-        self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity),
-                         Matrix(2, 3, [1, 0, 0,
-                                       0, 1, 0]))
-        self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity),
-                         Matrix(3, 3, [1, 0, 0,
-                                       0, 1, 0,
-                                       0, 0, 1]))
-
-    def test_sub(self):
-        self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
-                         - Matrix(2, 2, [-1, -2, -3, -4]),
-                         Matrix(2, 2, [41, 32, 23, 14]))
-
-    def test_neg(self):
-        self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
-                         Matrix(2, 2, [-40, -30, -20, -10]))
-
-    def test_mul(self):
-        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
-                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
-        self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
-                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
-
-    def test_matmul(self):
-        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
-                         @ Matrix(2, 2, [4, 3, 2, 1]),
-                         Matrix(2, 2, [8, 5, 20, 13]))
-        self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
-                         @ Matrix(2, 1, [1, 2]),
-                         Matrix(3, 1, [16, 10, 4]))
-
-    def test_inverse(self):
-        self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
-        self.assertEqual(Matrix(1, 1, [2]).inverse(),
-                         Matrix(1, 1, [Fraction(1, 2)]))
-        self.assertEqual(Matrix(1, 1, [1]).inverse(),
-                         Matrix(1, 1, [1]))
-        self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
-                         Matrix(2, 2, [1, 0, -1, 1]))
-        self.assertEqual(Matrix(3, 3, [0, 1, 0,
-                                       1, 0, 0,
-                                       0, 0, 1]).inverse(),
-                         Matrix(3, 3, [0, 1, 0,
-                                       1, 0, 0,
-                                       0, 0, 1]))
-        _1_2 = Fraction(1, 2)
-        _1_3 = Fraction(1, 3)
-        _1_6 = Fraction(1, 6)
-        self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
-                                       1, 1, 1, 1, 1,
-                                       1, -1, 1, -1, 1,
-                                       1, -2, 4, -8, 16,
-                                       0, 0, 0, 0, 1]).inverse(),
-                         Matrix(5, 5, [1, 0, 0, 0, 0,
-                                       _1_2, _1_3, -1, _1_6, -2,
-                                       -1, _1_2, _1_2, 0, -1,
-                                       -_1_2, _1_6, _1_2, -_1_6, 2,
-                                       0, 0, 0, 0, 1]))
-        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(1, 1, [0]).inverse()
-        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [0, 0, 1, 1]).inverse()
-        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [1, 0, 1, 0]).inverse()
-        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [1, 1, 1, 1]).inverse()
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py
deleted file mode 100644 (file)
index 1eff254..0000000
+++ /dev/null
@@ -1,201 +0,0 @@
-import unittest
-
-from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
-                                                  GlobalMem, GPRRange, GPRType,
-                                                  OpBigIntAddSub, OpConcat,
-                                                  OpCopy, OpFuncArg,
-                                                  OpInputMem, OpLI, OpLoad,
-                                                  OpSetCA, OpSetVLImm, OpStore,
-                                                  XERBit)
-from bigint_presentation_code.register_allocator import (
-    AllocationFailed, MergedRegSet, allocate_registers,
-    try_allocate_registers_without_spilling)
-
-
-class TestMergedRegSet(unittest.TestCase):
-    maxDiff = None
-
-    def test_from_equality_constraint(self):
-        fn = Fn()
-        li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
-        li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
-        li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
-        self.assertEqual(MergedRegSet.from_equality_constraint([
-            li0x1.out,
-            li0x2.out,
-            li0x3.out,
-        ]), MergedRegSet({
-            li0x1.out: 0,
-            li0x2.out: 1,
-            li0x3.out: 3,
-        }.items()))
-        self.assertEqual(MergedRegSet.from_equality_constraint([
-            li0x2.out,
-            li0x1.out,
-            li0x3.out,
-        ]), MergedRegSet({
-            li0x2.out: 0,
-            li0x1.out: 2,
-            li0x3.out: 3,
-        }.items()))
-
-
-class TestRegisterAllocator(unittest.TestCase):
-    maxDiff = None
-
-    def test_try_alloc_fail(self):
-        fn = Fn()
-        op0 = OpSetVLImm(fn, 52)
-        op1 = OpLI(fn, 0, vl=op0.out)
-        op2 = OpSetVLImm(fn, 64)
-        op3 = OpLI(fn, 0, vl=op2.out)
-        op4 = OpConcat(fn, [op1.out, op3.out])
-
-        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
-        self.assertEqual(
-            repr(reg_assignments),
-            "AllocationFailed("
-            "node=IGNode(#0, merged_reg_set=MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]), "
-            "edges={}, reg=None), "
-            "live_intervals=LiveIntervals(live_intervals={"
-            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
-            "LiveInterval(first_write=0, last_use=1), "
-            "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]): "
-            "LiveInterval(first_write=1, last_use=4), "
-            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
-            "LiveInterval(first_write=2, last_use=3)}, "
-            "merged_reg_sets=MergedRegSets(data={"
-            "<#0.out: KnownVLType(length=52)>: "
-            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
-            "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]), "
-            "<#2.out: KnownVLType(length=64)>: "
-            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
-            "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]), "
-            "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)])}), "
-            "reg_sets_live_after={"
-            "0: OFSet([MergedRegSet(["
-            "(<#0.out: KnownVLType(length=52)>, 0)])]), "
-            "1: OFSet([MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
-            "2: OFSet([MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]), "
-            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
-            "3: OFSet([MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
-            "4: OFSet()}), "
-            "interference_graph=InterferenceGraph(nodes={"
-            "...: IGNode(#0, merged_reg_set=MergedRegSet(["
-            "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
-            "...: IGNode(#1, merged_reg_set=MergedRegSet(["
-            "(<#4.dest: <gpr_ty[116]>>, 0), "
-            "(<#1.out: <gpr_ty[52]>>, 0), "
-            "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
-            "...: IGNode(#2, merged_reg_set=MergedRegSet(["
-            "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
-        )
-
-    def test_try_alloc_bigint_inc(self):
-        fn = Fn()
-        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
-        op1 = OpCopy(fn, op0.out, GPRType())
-        arg = op1.dest
-        op2 = OpInputMem(fn)
-        mem = op2.out
-        op3 = OpSetVLImm(fn, 32)
-        vl = op3.out
-        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
-        a = op4.RT
-        op5 = OpLI(fn, 0, vl=vl)
-        b = op5.out
-        op6 = OpSetCA(fn, True)
-        ca = op6.out
-        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
-        s = op7.out
-        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
-        mem = op8.mem_out
-
-        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
-
-        expected_reg_assignments = {
-            op0.out: GPRRange(start=3, length=1),
-            op1.dest: GPRRange(start=3, length=1),
-            op2.out: GlobalMem.GlobalMem,
-            op3.out: VL.VL_MAXVL,
-            op4.RT: GPRRange(start=78, length=32),
-            op5.out: GPRRange(start=46, length=32),
-            op6.out: XERBit.CA,
-            op7.out: GPRRange(start=14, length=32),
-            op7.CA_out: XERBit.CA,
-            op8.mem_out: GlobalMem.GlobalMem,
-        }
-
-        self.assertEqual(reg_assignments, expected_reg_assignments)
-
-    def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
-        # type: (list[GPRRange], GPRRange) -> None
-        fn = Fn()
-        inputs = []
-        expected_reg_assignments = {}
-        for i, r in enumerate(expected_regs):
-            vl = OpSetVLImm(fn, r.length).out
-            expected_reg_assignments[vl] = VL.VL_MAXVL
-            inp = OpLI(fn, i, vl=vl).out
-            inputs.append(inp)
-            expected_reg_assignments[inp] = r
-        concat = OpConcat(fn, inputs)
-        expected_reg_assignments[concat.dest] = expected_dest_reg
-
-        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
-
-        for inp, reg in zip(inputs, expected_regs):
-            expected_reg_assignments[inp] = reg
-
-        self.assertEqual(reg_assignments, expected_reg_assignments)
-
-    def test_try_alloc_concat_1(self):
-        self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
-
-    def test_try_alloc_concat_3(self):
-        self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
-
-    def test_try_alloc_concat_3_5(self):
-        self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
-                                  GPRRange(3, 8))
-
-    def test_try_alloc_concat_5_3(self):
-        self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
-                                  GPRRange(3, 8))
-
-    def test_try_alloc_concat_1_2_3_4_5_6(self):
-        self.tst_try_alloc_concat([
-            GPRRange(14, 1),
-            GPRRange(15, 2),
-            GPRRange(17, 3),
-            GPRRange(20, 4),
-            GPRRange(24, 5),
-            GPRRange(29, 6),
-        ], GPRRange(14, 21))
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py
deleted file mode 100644 (file)
index 6fff570..0000000
+++ /dev/null
@@ -1,378 +0,0 @@
-import unittest
-
-from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
-                                                  GlobalMem, GPRRange,
-                                                  GPRRangeType, OpCopy,
-                                                  OpFuncArg, OpInputMem,
-                                                  OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
-                                                  generate_assembly)
-from bigint_presentation_code.register_allocator import allocate_registers
-from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
-from bigint_presentation_code.util import FMap
-
-
-class SimpleMul192x192:
-    def __init__(self):
-        self.fn = fn = Fn()
-        self.mem_in = mem = OpInputMem(fn).out
-        self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
-        self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
-        self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
-        dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
-        vl = OpSetVLImm(fn, 3).out
-        lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
-        rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
-        retval = simple_mul(fn, lhs, rhs)
-        vl = OpSetVLImm(fn, 6).out
-        self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
-                               mem_in=mem, vl=vl).mem_out
-
-
-class TestToomCook(unittest.TestCase):
-    maxDiff = None
-
-    def test_toom_2_repr(self):
-        TOOM_2 = ToomCookInstance.make_toom_2()
-        # print(repr(repr(TOOM_2)))
-        self.assertEqual(
-            repr(TOOM_2),
-            "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
-            "eval_points=(0, 1, POINT_AT_INFINITY), "
-            "lhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
-            " rhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
-            " prod_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpSub(lhs="
-            "EvalOpSub(lhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({"
-            "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
-        )
-
-    def test_toom_2_5_repr(self):
-        TOOM_2_5 = ToomCookInstance.make_toom_2_5()
-        # print(repr(repr(TOOM_2_5)))
-        self.assertEqual(
-            repr(TOOM_2_5),
-            "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
-            "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs="
-            "EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=2, rhs=0, "
-            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=1, rhs=0, "
-            "poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({"
-            "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
-            "EvalOpSub(lhs="
-            "EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=2, rhs=0, "
-            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=1, rhs=0, "
-            "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
-            "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
-            "EvalOpInput(lhs=2, rhs=0, "
-            "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
-            "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
-            "EvalOpSub(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=1, rhs=0, "
-            "poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
-            "EvalOpInput(lhs=1, rhs=0, "
-            "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
-            "prod_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "rhs=EvalOpInput(lhs=2, rhs=0, "
-            "poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
-            "rhs=2, "
-            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
-            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
-            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
-            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
-            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
-        )
-
-    def test_reversed_toom_2_5_repr(self):
-        TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
-        # print(repr(repr(TOOM_2_5)))
-        self.assertEqual(
-            repr(TOOM_2_5),
-            "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
-            "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
-            "EvalOpSub(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
-            " rhs_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpAdd(lhs=EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
-            "EvalOpSub(lhs=EvalOpAdd(lhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
-            " prod_eval_ops=("
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
-            "rhs=2, "
-            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
-            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
-            "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
-            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
-            "rhs="
-            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
-            "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
-            "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
-            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
-            "poly=EvalOpPoly("
-            "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
-            "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
-        )
-
-    def test_simple_mul_192x192_pre_ra_sim(self):
-        # test multiplying:
-        #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
-        # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
-        # ==
-        # int("0x00074736574206e_6f69746163696c70"
-        #     "_69746c756d207469_622d3438333e2d32"
-        #     "_3931783239312079_7261727469627261", base=0)
-        # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
-        #                   'little')
-        code = SimpleMul192x192()
-        dest_ptr = 0x100
-        state = PreRASimState(
-            gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
-            stack_slots={}, fixed_gprs={
-                code.dest_ptr_in: (dest_ptr,),
-                code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
-                              0x000191acb262e15b),
-                code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
-                              0x4a37c0567bcbab53)
-            })
-        code.fn.pre_ra_sim(state)
-        expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
-        OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
-        expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
-        mem_out = state.global_mems[code.mem_out]
-        out_bytes = bytes(
-            mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
-        self.assertEqual(out_bytes, expected_bytes)
-
-    def test_simple_mul_192x192_ops(self):
-        code = SimpleMul192x192()
-        fn = code.fn
-        self.assertEqual([repr(v) for v in fn.ops], [
-            'OpInputMem(#0, <#0.out: GlobalMemType()>)',
-            'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
-            'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
-            'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
-            'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
-            'vl=None)',
-            'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
-            'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
-            'src=<#2.out: <fixed(<r4..len=3>)>>, '
-            'vl=<#5.out: KnownVLType(length=3)>)',
-            'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
-            'src=<#3.out: <fixed(<r7..len=3>)>>, '
-            'vl=<#5.out: KnownVLType(length=3)>)',
-            'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
-            '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
-            'src=<#7.dest: <gpr_ty[3]>>)',
-            'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
-            'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
-            'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
-            '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
-            'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
-            '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
-            'OpSetCA(#15, <#15.out: CAType()>, value=False)',
-            'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
-            'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
-            'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
-            'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
-            'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
-            'is_sub=False, vl=None)',
-            'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
-            '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
-            '<#17.out: <gpr_ty[1]>>))',
-            'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
-            '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
-            'OpSetCA(#21, <#21.out: CAType()>, value=False)',
-            'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
-            'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
-            'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
-            'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
-            'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
-            'is_sub=False, vl=None)',
-            'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
-            '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
-            '<#23.out: <gpr_ty[1]>>))',
-            'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
-            'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
-            'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
-            'mem_in=<#0.out: GlobalMemType()>, '
-            '<#26.mem_out: GlobalMemType()>, '
-            'vl=<#25.out: KnownVLType(length=6)>)'
-        ])
-
-    # FIXME: register allocator currently allocates wrong registers
-    @unittest.expectedFailure
-    def test_simple_mul_192x192_reg_alloc(self):
-        code = SimpleMul192x192()
-        fn = code.fn
-        assigned_registers = allocate_registers(fn.ops)
-        self.assertEqual(assigned_registers, {
-            fn.ops[13].RS: GPRRange(9),  # type: ignore
-            fn.ops[14].results[0]: GPRRange(6),  # type: ignore
-            fn.ops[14].results[1]: GPRRange(7, length=3),  # type: ignore
-            fn.ops[15].out: XERBit.CA,  # type: ignore
-            fn.ops[16].out: GPRRange(7, length=3),  # type: ignore
-            fn.ops[16].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[17].out: GPRRange(10),  # type: ignore
-            fn.ops[17].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[18].dest: GPRRange(6, length=5),  # type: ignore
-            fn.ops[19].RT: GPRRange(3, length=3),  # type: ignore
-            fn.ops[19].RS: GPRRange(9),  # type: ignore
-            fn.ops[20].results[0]: GPRRange(6, length=2),  # type: ignore
-            fn.ops[20].results[1]: GPRRange(8, length=3),  # type: ignore
-            fn.ops[21].out: XERBit.CA,  # type: ignore
-            fn.ops[22].out: GPRRange(8, length=3),  # type: ignore
-            fn.ops[22].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[23].out: GPRRange(11),  # type: ignore
-            fn.ops[23].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[24].dest: GPRRange(6, length=6),  # type: ignore
-            fn.ops[25].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[26].mem_out: GlobalMem.GlobalMem,  # type: ignore
-            fn.ops[0].out: GlobalMem.GlobalMem,  # type: ignore
-            fn.ops[1].out: GPRRange(3),  # type: ignore
-            fn.ops[2].out: GPRRange(4, length=3),  # type: ignore
-            fn.ops[3].out: GPRRange(7, length=3),  # type: ignore
-            fn.ops[4].dest: GPRRange(12),  # type: ignore
-            fn.ops[5].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[6].dest: GPRRange(17, length=3),  # type: ignore
-            fn.ops[7].dest: GPRRange(14, length=3),  # type: ignore
-            fn.ops[8].results[0]: GPRRange(14),  # type: ignore
-            fn.ops[8].results[1]: GPRRange(15),  # type: ignore
-            fn.ops[8].results[2]: GPRRange(16),  # type: ignore
-            fn.ops[9].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[10].out: GPRRange(9),  # type: ignore
-            fn.ops[11].RT: GPRRange(6, length=3),  # type: ignore
-            fn.ops[11].RS: GPRRange(9),  # type: ignore
-            fn.ops[12].dest: GPRRange(6, length=4),  # type: ignore
-            fn.ops[13].RT: GPRRange(3, length=3)  # type: ignore
-        })
-        self.fail("register allocator currently allocates wrong registers")
-
-    # FIXME: register allocator currently allocates wrong registers
-    @unittest.expectedFailure
-    def test_simple_mul_192x192_asm(self):
-        code = SimpleMul192x192()
-        asm = generate_assembly(code.fn.ops)
-        self.assertEqual(asm, [
-            'or 12, 3, 3',
-            'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *17, *4, *4',
-            'sv.or *14, *7, *7',
-            'setvl 0, 0, 3, 0, 1, 1',
-            'addi 9, 0, 0',
-            'sv.maddedu *6, *17, 14, 9',
-            'sv.maddedu *3, *17, 15, 9',
-            'addic 0, 0, 0',
-            'sv.adde *7, *3, *7',
-            'adde 10, 9, 9',
-            'sv.maddedu *3, *17, 16, 9',
-            'addic 0, 0, 0',
-            'sv.adde *8, *3, *8',
-            'adde 11, 9, 9',
-            'setvl 0, 0, 6, 0, 1, 1',
-            'sv.std *6, 0(12)',
-            'bclr 20, 0, 0'
-        ])
-        self.fail("register allocator currently allocates wrong registers")
-
-
-if __name__ == "__main__":
-    unittest.main()
index 9e3ec74e8842431bb0c3df895c8053e5b51bbd28..246f65454a861be15fbfb00954f9f9a37dac4899 100644 (file)
@@ -4,13 +4,16 @@ Toom-Cook multiplication algorithm generator for SVP64
 from abc import abstractmethod
 from enum import Enum
 from fractions import Fraction
-from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
+from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange
+from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub,
+                                                  OpBigIntMulDiv, OpConcat,
+                                                  OpLI, OpSetCA, OpSetVLImm,
+                                                  OpSplit, SSAGPRRange)
 from bigint_presentation_code.matrix import Matrix
-from bigint_presentation_code.util import Literal, OSet, final
+from bigint_presentation_code.type_util import Literal, final
 
 
 @final
@@ -158,8 +161,8 @@ class EvalOpPoly:
         return f"EvalOpPoly({self.coefficients})"
 
 
-_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
-_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")
+_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
+_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
 
 
 @plain_data(frozen=True, unsafe_hash=True)
@@ -238,7 +241,7 @@ class EvalOpInput(EvalOp[int, Literal[0]]):
     __slots__ = ()
 
     def __init__(self, lhs, rhs=0):
-        # type: (...) -> None
+        # type: (int, int) -> None
         if lhs < 0:
             raise ValueError("Input part_index (lhs) must be >= 0")
         if rhs != 0:
diff --git a/src/bigint_presentation_code/type_util.py b/src/bigint_presentation_code/type_util.py
new file mode 100644 (file)
index 0000000..ed7296f
--- /dev/null
@@ -0,0 +1,32 @@
+from typing import TYPE_CHECKING, Any, NoReturn, Union
+
+if TYPE_CHECKING:
+    from typing_extensions import Literal, Self, final
+else:
+    def final(v):
+        return v
+
+    class _Literal:
+        def __getitem__(self, v):
+            if isinstance(v, tuple):
+                return Union[tuple(type(i) for i in v)]
+            return type(v)
+
+    Literal = _Literal()
+
+    Self = Any
+
+
+# pyright currently doesn't like typing_extensions' definition
+# -- added to typing in python 3.11
+def assert_never(arg):
+    # type: (NoReturn) -> NoReturn
+    raise AssertionError("got to code that's supposed to be unreachable")
+
+
+__all__ = [
+    "assert_never",
+    "final",
+    "Literal",
+    "Self",
+]
diff --git a/src/bigint_presentation_code/type_util.pyi b/src/bigint_presentation_code/type_util.pyi
new file mode 100644 (file)
index 0000000..630ca20
--- /dev/null
@@ -0,0 +1,19 @@
+from typing import NoReturn, TypeVar
+
+from typing_extensions import Literal, Self, final
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+
+# pyright currently doesn't like typing_extensions' definition
+# -- added to typing in python 3.11
+def assert_never(arg: NoReturn) -> NoReturn: ...
+
+
+__all__ = [
+    "assert_never",
+    "final",
+    "Literal",
+    "Self",
+]
index aeea240c545aaf458772a515c78252ac1b928fc9..4b3978741eb5290cfd4257846ecbf5bd214aff07 100644 (file)
@@ -1,50 +1,25 @@
 from abc import abstractmethod
-from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator,
-                    Mapping, MutableSet, NoReturn, TypeVar, Union)
+from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
+                    TypeVar, overload)
 
-if TYPE_CHECKING:
-    from typing_extensions import Literal, Self, final
-else:
-    def final(v):
-        return v
-
-    class _Literal:
-        def __getitem__(self, v):
-            if isinstance(v, tuple):
-                return Union[tuple(type(i) for i in v)]
-            return type(v)
-
-    Literal = _Literal()
-
-    Self = Any
+from bigint_presentation_code.type_util import Self, final
 
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
 
 __all__ = [
-    "assert_never",
     "BaseBitSet",
     "bit_count",
     "BitSet",
     "FBitSet",
-    "final",
     "FMap",
-    "Literal",
     "OFSet",
     "OSet",
-    "Self",
     "top_set_bit_index",
     "trailing_zero_count",
 ]
 
 
-# pyright currently doesn't like typing_extensions' definition
-# -- added to typing in python 3.11
-def assert_never(arg):
-    # type: (NoReturn) -> NoReturn
-    raise AssertionError("got to code that's supposed to be unreachable")
-
-
 class OFSet(AbstractSet[_T_co]):
     """ ordered frozen set """
     __slots__ = "__items",
@@ -54,18 +29,23 @@ class OFSet(AbstractSet[_T_co]):
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         return x in self.__items
 
     def __iter__(self):
+        # type: () -> Iterator[_T_co]
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def __hash__(self):
+        # type: () -> int
         return self._hash()
 
     def __repr__(self):
+        # type: () -> str
         if len(self) == 0:
             return "OFSet()"
         return f"OFSet({list(self)})"
@@ -80,12 +60,15 @@ class OSet(MutableSet[_T]):
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         return x in self.__items
 
     def __iter__(self):
+        # type: () -> Iterator[_T]
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def add(self, value):
@@ -97,6 +80,7 @@ class OSet(MutableSet[_T]):
         self.__items.pop(value, None)
 
     def __repr__(self):
+        # type: () -> str
         if len(self) == 0:
             return "OSet()"
         return f"OSet({list(self)})"
@@ -106,6 +90,21 @@ class FMap(Mapping[_T, _T_co]):
     """ordered frozen hashable mapping"""
     __slots__ = "__items", "__hash"
 
+    @overload
+    def __init__(self, items):
+        # type: (Mapping[_T, _T_co]) -> None
+        ...
+
+    @overload
+    def __init__(self, items):
+        # type: (Iterable[tuple[_T, _T_co]]) -> None
+        ...
+
+    @overload
+    def __init__(self):
+        # type: () -> None
+        ...
+
     def __init__(self, items=()):
         # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
         self.__items = dict(items)  # type: dict[_T, _T_co]
@@ -120,20 +119,23 @@ class FMap(Mapping[_T, _T_co]):
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def __eq__(self, other):
-        # type: (object) -> bool
+        # type: (FMap[Any, Any] | Any) -> bool
         if isinstance(other, FMap):
             return self.__items == other.__items
         return super().__eq__(other)
 
     def __hash__(self):
+        # type: () -> int
         if self.__hash is None:
             self.__hash = hash(frozenset(self.items()))
         return self.__hash
 
     def __repr__(self):
+        # type: () -> str
         return f"FMap({self.__items})"
 
 
@@ -153,7 +155,7 @@ def top_set_bit_index(v, default=-1):
 
 try:
     # added in cpython 3.10
-    bit_count = int.bit_count  # type: ignore[attr]
+    bit_count = int.bit_count  # type: ignore
 except AttributeError:
     def bit_count(v):
         # type: (int) -> int
@@ -177,16 +179,20 @@ class BaseBitSet(AbstractSet[int]):
 
     def __init__(self, items=(), bits=0):
         # type: (Iterable[int], int) -> None
-        for item in items:
-            if item < 0:
-                raise ValueError("can't store negative integers")
-            bits |= 1 << item
+        if isinstance(items, BaseBitSet):
+            bits |= items.bits
+        else:
+            for item in items:
+                if item < 0:
+                    raise ValueError("can't store negative integers")
+                bits |= 1 << item
         if bits < 0:
             raise ValueError("can't store an infinite set")
         self.__bits = bits
 
     @property
     def bits(self):
+        # type: () -> int
         return self.__bits
 
     @bits.setter
@@ -199,6 +205,7 @@ class BaseBitSet(AbstractSet[int]):
         self.__bits = bits
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         if isinstance(x, int) and x >= 0:
             return (1 << x) & self.bits != 0
         return False
@@ -220,9 +227,11 @@ class BaseBitSet(AbstractSet[int]):
             bits -= 1 << index
 
     def __len__(self):
+        # type: () -> int
         return bit_count(self.bits)
 
     def __repr__(self):
+        # type: () -> str
         if self.bits == 0:
             return f"{self.__class__.__name__}()"
         if self.bits > 0xFFFFFFFF and len(self) < 10:
@@ -231,7 +240,7 @@ class BaseBitSet(AbstractSet[int]):
         return f"{self.__class__.__name__}(bits={hex(self.bits)})"
 
     def __eq__(self, other):
-        # type: (object) -> bool
+        # type: (Any) -> bool
         if not isinstance(other, BaseBitSet):
             return super().__eq__(other)
         return self.bits == other.bits
@@ -320,6 +329,7 @@ class BitSet(BaseBitSet, MutableSet[int]):
             self.bits &= ~(1 << value)
 
     def clear(self):
+        # type: () -> None
         self.bits = 0
 
     def __ior__(self, it):
@@ -361,4 +371,5 @@ class FBitSet(BaseBitSet):
         return True
 
     def __hash__(self):
+        # type: () -> int
         return super()._hash()
diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi
deleted file mode 100644 (file)
index 6315823..0000000
+++ /dev/null
@@ -1,190 +0,0 @@
-from abc import abstractmethod
-from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
-                    NoReturn, TypeVar, overload)
-
-from typing_extensions import Literal, Self, final
-
-_T_co = TypeVar("_T_co", covariant=True)
-_T = TypeVar("_T")
-
-__all__ = [
-    "assert_never",
-    "BaseBitSet",
-    "bit_count",
-    "BitSet",
-    "FBitSet",
-    "final",
-    "FMap",
-    "Literal",
-    "OFSet",
-    "OSet",
-    "Self",
-    "top_set_bit_index",
-    "trailing_zero_count",
-]
-
-
-# pyright currently doesn't like typing_extensions' definition
-# -- added to typing in python 3.11
-def assert_never(arg):
-    # type: (NoReturn) -> NoReturn
-    raise AssertionError("got to code that's supposed to be unreachable")
-
-
-class OFSet(AbstractSet[_T_co]):
-    """ ordered frozen set """
-
-    def __init__(self, items: Iterable[_T_co] = ()):
-        ...
-
-    def __contains__(self, x: object) -> bool:
-        ...
-
-    def __iter__(self) -> Iterator[_T_co]:
-        ...
-
-    def __len__(self) -> int:
-        ...
-
-    def __hash__(self) -> int:
-        ...
-
-    def __repr__(self) -> str:
-        ...
-
-
-class OSet(MutableSet[_T]):
-    """ ordered mutable set """
-
-    def __init__(self, items: Iterable[_T] = ()):
-        ...
-
-    def __contains__(self, x: object) -> bool:
-        ...
-
-    def __iter__(self) -> Iterator[_T]:
-        ...
-
-    def __len__(self) -> int:
-        ...
-
-    def add(self, value: _T) -> None:
-        ...
-
-    def discard(self, value: _T) -> None:
-        ...
-
-    def __repr__(self) -> str:
-        ...
-
-
-class FMap(Mapping[_T, _T_co]):
-    """ordered frozen hashable mapping"""
-    @overload
-    def __init__(self, items: Mapping[_T, _T_co]): ...
-    @overload
-    def __init__(self, items: Iterable[tuple[_T, _T_co]]): ...
-    @overload
-    def __init__(self): ...
-
-    def __getitem__(self, item: _T) -> _T_co:
-        ...
-
-    def __iter__(self) -> Iterator[_T]:
-        ...
-
-    def __len__(self) -> int:
-        ...
-
-    def __eq__(self, other: object) -> bool:
-        ...
-
-    def __hash__(self) -> int:
-        ...
-
-    def __repr__(self) -> str:
-        ...
-
-
-def trailing_zero_count(v: int, default: int = -1) -> int: ...
-def top_set_bit_index(v: int, default: int = -1) -> int: ...
-def bit_count(v: int) -> int: ...
-
-
-class BaseBitSet(AbstractSet[int]):
-    @classmethod
-    @abstractmethod
-    def _frozen(cls) -> bool: ...
-
-    @classmethod
-    def _from_bits(cls, bits: int) -> Self: ...
-
-    def __init__(self, items: Iterable[int] = (), bits: int = 0): ...
-
-    @property
-    def bits(self) -> int:
-        ...
-
-    @bits.setter
-    def bits(self, bits: int) -> None: ...
-
-    def __contains__(self, x: object) -> bool: ...
-
-    def __iter__(self) -> Iterator[int]: ...
-
-    def __reversed__(self) -> Iterator[int]: ...
-
-    def __len__(self) -> int: ...
-
-    def __repr__(self) -> str: ...
-
-    def __eq__(self, other: object) -> bool: ...
-
-    def __and__(self, other: Iterable[Any]) -> Self: ...
-
-    __rand__ = __and__
-
-    def __or__(self, other: Iterable[Any]) -> Self: ...
-
-    __ror__ = __or__
-
-    def __xor__(self, other: Iterable[Any]) -> Self: ...
-
-    __rxor__ = __xor__
-
-    def __sub__(self, other: Iterable[Any]) -> Self: ...
-
-    def __rsub__(self, other: Iterable[Any]) -> Self: ...
-
-    def isdisjoint(self, other: Iterable[Any]) -> bool: ...
-
-
-class BitSet(BaseBitSet, MutableSet[int]):
-    @final
-    @classmethod
-    def _frozen(cls) -> Literal[False]: ...
-
-    def add(self, value: int) -> None: ...
-
-    def discard(self, value: int) -> None: ...
-
-    def clear(self) -> None: ...
-
-    def __ior__(self, it: AbstractSet[Any]) -> Self: ...
-
-    def __iand__(self, it: AbstractSet[Any]) -> Self: ...
-
-    def __ixor__(self, it: AbstractSet[Any]) -> Self: ...
-
-    def __isub__(self, it: AbstractSet[Any]) -> Self: ...
-
-
-class FBitSet(BaseBitSet):
-    @property
-    def bits(self) -> int: ...
-
-    @final
-    @classmethod
-    def _frozen(cls) -> Literal[True]: ...
-
-    def __hash__(self) -> int: ...
index b8b1f305ef4ed3a758944974b08b048fb8ef5a56..5ec70856242022a25d79008a336cd358460e61d7 100644 (file)
@@ -1,15 +1 @@
-from typing import Any, Callable, Generic, TypeVar, overload
-
-_T = TypeVar("_T")
-
-
-class cached_property(Generic[_T]):
-    def __init__(self, func: Callable[[Any], _T]) -> None: ...
-
-    @overload
-    def __get__(self, instance: None,
-                owner: type[Any] | None = ...) -> cached_property[_T]: ...
-
-    @overload
-    def __get__(self, instance: object,
-                owner: type[Any] | None = ...) -> _T: ...
+cached_property = property