working on adding signed multiplication -- needed for toom-cook
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
index 96dc3187e7076309b06dc33bab4f7c945d0b95b5..2f548a4886587d97c39a2ca789b590146157c723 100644 (file)
@@ -5,37 +5,57 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES,
                                                   BaseSimState, Fn,
                                                   GenAsmState, OpKind,
                                                   PostRASimState,
-                                                  PreRASimState)
+                                                  PreRASimState, SSAVal)
 from bigint_presentation_code.register_allocator import allocate_registers
-from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul,
+                                                toom_cook_mul)
 
 
-class SimpleMul192x192:
-    def __init__(self):
+def simple_umul(fn, lhs, rhs):
+    # type: (Fn, SSAVal, SSAVal) -> SSAVal
+    return simple_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+                      rhs_signed=False, name="simple_umul")
+
+
+class Mul:
+    def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
+        # type: (Callable[[Fn, SSAVal, SSAVal], SSAVal], int, int) -> None
         super().__init__()
         self.fn = fn = Fn()
         self.dest_offset = 0
-        self.lhs_offset = 48 + self.dest_offset
-        self.rhs_offset = 24 + self.lhs_offset
+        self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words
+        self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES
+        self.lhs_size_in_words = lhs_size_in_words
+        self.lhs_size_in_bytes = self.lhs_size_in_words * GPR_SIZE_IN_BYTES
+        self.rhs_size_in_words = rhs_size_in_words
+        self.rhs_size_in_bytes = self.rhs_size_in_words * GPR_SIZE_IN_BYTES
+        self.lhs_offset = self.dest_size_in_bytes + self.dest_offset
+        self.rhs_offset = self.lhs_size_in_bytes + self.lhs_offset
         self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3,
                                        name="ptr_in").outputs[0]
-        setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3],
-                                  maxvl=3, name="setvl3")
+        lhs_setvl = fn.append_new_op(
+            kind=OpKind.SetVLI, immediates=[lhs_size_in_words],
+            maxvl=lhs_size_in_words, name="lhs_setvl")
         load_lhs = fn.append_new_op(
             kind=OpKind.SvLd, immediates=[self.lhs_offset],
-            input_vals=[self.ptr_in, setvl3.outputs[0]],
-            name="load_lhs", maxvl=3)
+            input_vals=[self.ptr_in, lhs_setvl.outputs[0]],
+            name="load_lhs", maxvl=lhs_size_in_words)
+        rhs_setvl = fn.append_new_op(
+            kind=OpKind.SetVLI, immediates=[rhs_size_in_words],
+            maxvl=rhs_size_in_words, name="rhs_setvl")
         load_rhs = fn.append_new_op(
             kind=OpKind.SvLd, immediates=[self.rhs_offset],
-            input_vals=[self.ptr_in, setvl3.outputs[0]],
+            input_vals=[self.ptr_in, rhs_setvl.outputs[0]],
             name="load_rhs", maxvl=3)
-        retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
-        setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6],
-                                  maxvl=6, name="setvl6")
+        retval = mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
+        dest_setvl = fn.append_new_op(
+            kind=OpKind.SetVLI, immediates=[self.dest_size_in_words],
+            maxvl=self.dest_size_in_words, name="dest_setvl")
         fn.append_new_op(
             kind=OpKind.SvStd,
-            input_vals=[retval, self.ptr_in, setvl6.outputs[0]],
-            immediates=[self.dest_offset], maxvl=6, name="store_dest")
+            input_vals=[retval, self.ptr_in, dest_setvl.outputs[0]],
+            immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
+            name="store_dest")
 
 
 class TestToomCook(unittest.TestCase):
@@ -207,21 +227,26 @@ class TestToomCook(unittest.TestCase):
         )
 
     def test_simple_mul_192x192_pre_ra_sim(self):
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
+
         def create_sim_state(code):
-            # type: (SimpleMul192x192) -> BaseSimState
+            # type: (Mul) -> BaseSimState
             return PreRASimState(ssa_vals={}, memory={})
         self.tst_simple_mul_192x192_sim(create_sim_state)
 
     def test_simple_mul_192x192_post_ra_sim(self):
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
+
         def create_sim_state(code):
-            # type: (SimpleMul192x192) -> BaseSimState
+            # type: (Mul) -> BaseSimState
             ssa_val_to_loc_map = allocate_registers(code.fn)
             return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map,
                                   memory={}, loc_values={})
         self.tst_simple_mul_192x192_sim(create_sim_state)
 
     def tst_simple_mul_192x192_sim(self, create_sim_state):
-        # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None
+        # type: (Callable[[Mul], BaseSimState]) -> None
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
         # test multiplying:
         #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
         # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
@@ -231,7 +256,7 @@ class TestToomCook(unittest.TestCase):
         #     "_3931783239312079_7261727469627261", base=0)
         # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
         #                   'little')
-        code = SimpleMul192x192()
+        code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         state = create_sim_state(code)
         ptr_in = 0x100
         dest_ptr = ptr_in + code.dest_offset
@@ -253,7 +278,8 @@ class TestToomCook(unittest.TestCase):
         self.assertEqual(out_bytes, expected_bytes)
 
     def test_simple_mul_192x192_ops(self):
-        code = SimpleMul192x192()
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
+        code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
         self.assertEqual([repr(v) for v in fn.ops], [
             "Op(kind=OpKind.FuncArgR3, "
@@ -264,18 +290,23 @@ class TestToomCook(unittest.TestCase):
             "Op(kind=OpKind.SetVLI, "
             "input_vals=[], "
             "input_uses=(), immediates=[3], "
-            "outputs=(<setvl3.outputs[0]: <VL_MAXVL>>,), "
-            "name='setvl3')",
+            "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='lhs_setvl')",
             "Op(kind=OpKind.SvLd, "
             "input_vals=[<ptr_in.outputs[0]: <I64>>, "
-            "<setvl3.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<load_lhs.input_uses[0]: <I64>>, "
             "<load_lhs.input_uses[1]: <VL_MAXVL>>), immediates=[48], "
             "outputs=(<load_lhs.outputs[0]: <I64*3>>,), "
             "name='load_lhs')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[3], "
+            "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='rhs_setvl')",
             "Op(kind=OpKind.SvLd, "
             "input_vals=[<ptr_in.outputs[0]: <I64>>, "
-            "<setvl3.outputs[0]: <VL_MAXVL>>], "
+            "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<load_rhs.input_uses[0]: <I64>>, "
             "<load_rhs.input_uses[1]: <VL_MAXVL>>), immediates=[72], "
             "outputs=(<load_rhs.outputs[0]: <I64*3>>,), "
@@ -283,11 +314,11 @@ class TestToomCook(unittest.TestCase):
             "Op(kind=OpKind.SetVLI, "
             "input_vals=[], "
             "input_uses=(), immediates=[3], "
-            "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
-            "name='rhs_setvl')",
+            "outputs=(<rhs_setvl2.outputs[0]: <VL_MAXVL>>,), "
+            "name='rhs_setvl2')",
             "Op(kind=OpKind.Spread, "
             "input_vals=[<load_rhs.outputs[0]: <I64*3>>, "
-            "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<rhs_setvl2.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<rhs_spread.input_uses[0]: <I64*3>>, "
             "<rhs_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
             "outputs=(<rhs_spread.outputs[0]: <I64>>, "
@@ -297,8 +328,8 @@ class TestToomCook(unittest.TestCase):
             "Op(kind=OpKind.SetVLI, "
             "input_vals=[], "
             "input_uses=(), immediates=[3], "
-            "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
-            "name='lhs_setvl')",
+            "outputs=(<lhs_setvl3.outputs[0]: <VL_MAXVL>>,), "
+            "name='lhs_setvl3')",
             "Op(kind=OpKind.LI, "
             "input_vals=[], "
             "input_uses=(), immediates=[0], "
@@ -308,7 +339,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
             "<rhs_spread.outputs[0]: <I64>>, "
             "<zero.outputs[0]: <I64>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<mul0.input_uses[0]: <I64*3>>, "
             "<mul0.input_uses[1]: <I64>>, "
             "<mul0.input_uses[2]: <I64>>, "
@@ -318,7 +349,7 @@ class TestToomCook(unittest.TestCase):
             "name='mul0')",
             "Op(kind=OpKind.Spread, "
             "input_vals=[<mul0.outputs[0]: <I64*3>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<mul0_rt_spread.input_uses[0]: <I64*3>>, "
             "<mul0_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
             "outputs=(<mul0_rt_spread.outputs[0]: <I64>>, "
@@ -329,7 +360,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
             "<rhs_spread.outputs[1]: <I64>>, "
             "<zero.outputs[0]: <I64>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<mul1.input_uses[0]: <I64*3>>, "
             "<mul1.input_uses[1]: <I64>>, "
             "<mul1.input_uses[2]: <I64>>, "
@@ -341,7 +372,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<mul0_rt_spread.outputs[1]: <I64>>, "
             "<mul0_rt_spread.outputs[2]: <I64>>, "
             "<mul0.outputs[1]: <I64>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add1_rb_concat.input_uses[0]: <I64>>, "
             "<add1_rb_concat.input_uses[1]: <I64>>, "
             "<add1_rb_concat.input_uses[2]: <I64>>, "
@@ -357,7 +388,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<mul1.outputs[0]: <I64*3>>, "
             "<add1_rb_concat.outputs[0]: <I64*3>>, "
             "<clear_ca1.outputs[0]: <CA>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add1.input_uses[0]: <I64*3>>, "
             "<add1.input_uses[1]: <I64*3>>, "
             "<add1.input_uses[2]: <CA>>, "
@@ -367,7 +398,7 @@ class TestToomCook(unittest.TestCase):
             "name='add1')",
             "Op(kind=OpKind.Spread, "
             "input_vals=[<add1.outputs[0]: <I64*3>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add1_rt_spread.input_uses[0]: <I64*3>>, "
             "<add1_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
             "outputs=(<add1_rt_spread.outputs[0]: <I64>>, "
@@ -386,7 +417,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
             "<rhs_spread.outputs[2]: <I64>>, "
             "<zero.outputs[0]: <I64>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<mul2.input_uses[0]: <I64*3>>, "
             "<mul2.input_uses[1]: <I64>>, "
             "<mul2.input_uses[2]: <I64>>, "
@@ -398,7 +429,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<add1_rt_spread.outputs[1]: <I64>>, "
             "<add1_rt_spread.outputs[2]: <I64>>, "
             "<add_hi1.outputs[0]: <I64>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add2_rb_concat.input_uses[0]: <I64>>, "
             "<add2_rb_concat.input_uses[1]: <I64>>, "
             "<add2_rb_concat.input_uses[2]: <I64>>, "
@@ -414,7 +445,7 @@ class TestToomCook(unittest.TestCase):
             "input_vals=[<mul2.outputs[0]: <I64*3>>, "
             "<add2_rb_concat.outputs[0]: <I64*3>>, "
             "<clear_ca2.outputs[0]: <CA>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add2.input_uses[0]: <I64*3>>, "
             "<add2.input_uses[1]: <I64*3>>, "
             "<add2.input_uses[2]: <CA>>, "
@@ -424,7 +455,7 @@ class TestToomCook(unittest.TestCase):
             "name='add2')",
             "Op(kind=OpKind.Spread, "
             "input_vals=[<add2.outputs[0]: <I64*3>>, "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add2_rt_spread.input_uses[0]: <I64*3>>, "
             "<add2_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
             "outputs=(<add2_rt_spread.outputs[0]: <I64>>, "
@@ -464,12 +495,12 @@ class TestToomCook(unittest.TestCase):
             "Op(kind=OpKind.SetVLI, "
             "input_vals=[], "
             "input_uses=(), immediates=[6], "
-            "outputs=(<setvl6.outputs[0]: <VL_MAXVL>>,), "
-            "name='setvl6')",
+            "outputs=(<dest_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='dest_setvl')",
             "Op(kind=OpKind.SvStd, "
             "input_vals=[<concat_retval.outputs[0]: <I64*6>>, "
             "<ptr_in.outputs[0]: <I64>>, "
-            "<setvl6.outputs[0]: <VL_MAXVL>>], "
+            "<dest_setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<store_dest.input_uses[0]: <I64*6>>, "
             "<store_dest.input_uses[1]: <I64>>, "
             "<store_dest.input_uses[2]: <VL_MAXVL>>), immediates=[0], "
@@ -478,7 +509,8 @@ class TestToomCook(unittest.TestCase):
         ])
 
     def test_simple_mul_192x192_reg_alloc(self):
-        code = SimpleMul192x192()
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
+        code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
         assigned_registers = allocate_registers(fn)
         self.assertEqual(
@@ -491,7 +523,7 @@ class TestToomCook(unittest.TestCase):
             "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
             "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<setvl6.outputs[0]: <VL_MAXVL>>: "
+            "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
@@ -717,7 +749,7 @@ class TestToomCook(unittest.TestCase):
             "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
             "<zero.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
+            "<lhs_setvl3.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<rhs_spread.out2.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
@@ -737,7 +769,7 @@ class TestToomCook(unittest.TestCase):
             "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
             "<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+            "<rhs_setvl2.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
@@ -749,6 +781,8 @@ class TestToomCook(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
             "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
             "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
@@ -759,7 +793,7 @@ class TestToomCook(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
-            "<setvl3.outputs[0]: <VL_MAXVL>>: "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<ptr_in.out0.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
@@ -768,7 +802,8 @@ class TestToomCook(unittest.TestCase):
             "}")
 
     def test_simple_mul_192x192_asm(self):
-        code = SimpleMul192x192()
+        self.skipTest("WIP")  # FIXME: finish fixing simple_mul
+        code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
         assigned_registers = allocate_registers(fn)
         gen_asm_state = GenAsmState(assigned_registers)
@@ -781,6 +816,7 @@ class TestToomCook(unittest.TestCase):
             'sv.ld *3, 48(6)',
             'setvl 0, 0, 3, 0, 1, 1',
             'sv.or *20, *3, *3',
+            'setvl 0, 0, 3, 0, 1, 1',
             'or 6, 23, 23',
             'setvl 0, 0, 3, 0, 1, 1',
             'sv.ld *3, 72(6)',
@@ -885,6 +921,23 @@ class TestToomCook(unittest.TestCase):
             'sv.std *4, 0(3)'
         ])
 
+    def test_toom_2_mul_256x256_asm(self):
+        self.skipTest("WIP")  # FIXME: finish
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        instances = TOOM_2, TOOM_2
+
+        def mul(fn, lhs, rhs):
+            # type: (Fn, SSAVal, SSAVal) -> SSAVal
+            return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+                                 rhs_signed=False, instances=instances)
+        code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3)
+        fn = code.fn
+        assigned_registers = allocate_registers(fn)
+        gen_asm_state = GenAsmState(assigned_registers)
+        fn.gen_asm(gen_asm_state)
+        self.assertEqual(gen_asm_state.output, [
+        ])
+
 
 if __name__ == "__main__":
     unittest.main()