copy-merging works afaict! -- some tests still broken: out-of-date
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
index eadc0809afa4edb60705021033ef54affbae2489..dc3bb8d8985cad70b95a7acf419eb7266f9bb912 100644 (file)
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+import sys
 import unittest
 from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
 
@@ -9,6 +10,7 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
                                                   PostRASimState,
                                                   PreRASimState, SSAVal)
 from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
                                                 simple_mul)
 from bigint_presentation_code.util import OSet
@@ -77,21 +79,21 @@ class Mul:
             name="store_dest")
 
 
-def get_post_ra_state_factory(code):
-    # type: (Mul) -> _StateFactory
-    ssa_val_to_loc_map = allocate_registers(code.fn)
-
-    @contextmanager
-    def state_factory():
-        yield PostRASimState(
-            ssa_val_to_loc_map=ssa_val_to_loc_map,
-            memory={}, loc_values={})
-    return state_factory
-
-
 class TestToomCook(unittest.TestCase):
     maxDiff = None
 
+    def get_post_ra_state_factory(self, code):
+        # type: (Mul) -> _StateFactory
+        ssa_val_to_loc_map = allocate_registers(
+            code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
+
+        @contextmanager
+        def state_factory():
+            yield PostRASimState(
+                ssa_val_to_loc_map=ssa_val_to_loc_map,
+                memory={}, loc_values={})
+        return state_factory
+
     def test_toom_2_repr(self):
         TOOM_2 = ToomCookInstance.make_toom_2()
         # print(repr(repr(TOOM_2)))
@@ -269,7 +271,7 @@ class TestToomCook(unittest.TestCase):
             for rhs_signed in False, True:
                 self.tst_simple_mul_192x192_sim(
                     lhs_signed=lhs_signed, rhs_signed=rhs_signed,
-                    get_state_factory=get_post_ra_state_factory)
+                    get_state_factory=self.get_post_ra_state_factory)
 
     def tst_simple_mul_192x192_sim(
         self, lhs_signed,  # type: bool
@@ -485,7 +487,8 @@ class TestToomCook(unittest.TestCase):
     def test_simple_mul_192x192_reg_alloc(self):
         code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         self.assertEqual(
             repr(assigned_registers), "{"
             "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
@@ -865,150 +868,85 @@ class TestToomCook(unittest.TestCase):
     def test_simple_mul_192x192_asm(self):
         code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         gen_asm_state = GenAsmState(assigned_registers)
         fn.gen_asm(gen_asm_state)
         self.assertEqual(gen_asm_state.output, [
-            'or 27, 3, 3',
+            'or 9, 3, 3',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 6, 27, 27',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.ld *3, 48(6)',
+            'sv.ld *20, 48(9)',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *24, *3, *3',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 6, 27, 27',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.ld *3, 72(6)',
+            'sv.ld *10, 72(9)',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or/mrr *5, *3, *3',
-            'or 4, 5, 5',
-            'or 14, 6, 6',
-            'or 23, 7, 7',
-            'addi 3, 0, 0',
-            'or 22, 3, 3',
+            'addi 6, 0, 0',
+            'or 8, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'addi 3, 0, 0',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 4, 4',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *14, *20, 10, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 19, 6, 6',
+            'or 17, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 21, 3, 3',
-            'or 12, 4, 4',
-            'or 11, 5, 5',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 14, 14',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *3, *20, 11, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 18, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 17, 3, 3',
-            'or 16, 4, 4',
-            'or 15, 5, 5',
-            'addi 3, 0, 0',
-            'or 8, 3, 3',
-            'addi 3, 0, 0',
-            'or 14, 3, 3',
+            'addi 18, 0, 0',
+            'addi 7, 0, 0',
             'setvl 0, 0, 5, 0, 1, 1',
-            'or 3, 12, 12',
-            'or 4, 11, 11',
-            'or 5, 19, 19',
-            'or 6, 8, 8',
-            'or 7, 8, 8',
+            'or 19, 18, 18',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *8, *3, *3',
-            'or 3, 17, 17',
-            'or 4, 16, 16',
-            'or 5, 15, 15',
-            'or 6, 18, 18',
-            'or 7, 14, 14',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
             'addic 0, 0, 0',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *14, *8, *8',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *8, *3, *3',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.adde *3, *14, *8',
+            'sv.adde *15, *15, *3',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
-            'or 20, 3, 3',
-            'or 19, 4, 4',
-            'or 18, 5, 5',
-            'or 17, 6, 6',
-            'or 16, 7, 7',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 23, 23',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *3, *20, 12, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 15, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 14, 3, 3',
-            'or 12, 4, 4',
-            'or 11, 5, 5',
             'setvl 0, 0, 4, 0, 1, 1',
-            'or 3, 19, 19',
-            'or 4, 18, 18',
-            'or 5, 17, 17',
-            'or 6, 16, 16',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *7, *3, *3',
-            'or 3, 14, 14',
-            'or 4, 12, 12',
-            'or 5, 11, 11',
-            'or 6, 15, 15',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
             'addic 0, 0, 0',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *14, *7, *7',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *7, *3, *3',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.adde *3, *14, *7',
+            'sv.adde *16, *16, *3',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
-            'or 12, 3, 3',
-            'or 11, 4, 4',
-            'or 10, 5, 5',
-            'or 9, 6, 6',
             'setvl 0, 0, 6, 0, 1, 1',
-            'or 3, 21, 21',
-            'or 4, 20, 20',
-            'or 5, 12, 12',
-            'or 6, 11, 11',
-            'or 7, 10, 10',
-            'or 8, 9, 9',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
-            'sv.or/mrr *4, *3, *3',
-            'or 3, 27, 27',
             'setvl 0, 0, 6, 0, 1, 1',
-            'sv.std *4, 0(3)'
+            'sv.std *14, 0(9)',
         ])
 
     def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
@@ -1122,27 +1060,28 @@ class TestToomCook(unittest.TestCase):
     def test_toom_2_mul_256x256_uu_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=False, rhs_signed=False,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_su_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=True, rhs_signed=False,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_us_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=False, rhs_signed=True,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_ss_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=True, rhs_signed=True,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_asm(self):
         code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         gen_asm_state = GenAsmState(assigned_registers)
         fn.gen_asm(gen_asm_state)
         self.assertEqual(gen_asm_state.output, [