working on code some more
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_register_allocator.py
index 78046fe83ef9395db4e78e92aa9d1ec90dec282f..0d3ccd261a64ef2424352de23ea9840c9065ea53 100644 (file)
@@ -1,5 +1,6 @@
 import sys
 import unittest
+from pathlib import Path
 
 from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
                                                   SSAVal)
@@ -41,22 +42,37 @@ class TestRegisterAllocator(unittest.TestCase):
 
         self.assertEqual(
             repr(reg_assignments),
-            "{<add.outputs[0]: <I64*32>>: "
+            "{"
+            "<add.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.inp1.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<add.inp0.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
-            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<st.inp1.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<st.inp0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<add.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<ca.outputs[0]: <CA>>: "
@@ -69,30 +85,17 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.inp0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<vl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<arg.out0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
-            "<arg.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
+            "}"
         )
 
     def test_gen_asm(self):
@@ -249,71 +252,51 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
             "}"
         )
-        # FIXME: is_copy_related is not correct, it's missing a bunch of
-        # edges (which aren't interference edges)
-        self.assertEqual(graphs, {
-            'initial': r"""graph {
-    graph [pack = true]
-    "0" [label = "<arg.outputs[0]: <I64>>: 0"]
-    "1" [label = "<arg.out0.copy.outputs[0]: <I64>>: 0"]
-    "2" [label = "<vl.outputs[0]: <VL_MAXVL>>: 0"]
-    "3" [label = "<ld.inp0.copy.outputs[0]: <I64>>: 0"]
-    "4" [label = "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "5" [label = "<ld.outputs[0]: <I64*32>>: 0"]
-    "6" [label = "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "7" [label = "<ld.out0.copy.outputs[0]: <I64*32>>: 0"]
-    "8" [label = "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "9" [label = "<li.outputs[0]: <I64*32>>: 0"]
-    "10" [label = "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "11" [label = "<li.out0.copy.outputs[0]: <I64*32>>: 0"]
-    "12" [label = "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "13" [label = "<add.inp0.copy.outputs[0]: <I64*32>>: 0"]
-    "14" [label = "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "15" [label = "<add.inp1.copy.outputs[0]: <I64*32>>: 0"]
-    "16" [label = "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "17" [label = "<add.outputs[0]: <I64*32>>: 0"]
-    "18" [label = "<ca.outputs[0]: <CA>>: 0\n<add.outputs[1]: <CA>>: 0"]
-    "19" [label = "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "20" [label = "<add.out0.copy.outputs[0]: <I64*32>>: 0"]
-    "21" [label = "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "22" [label = "<st.inp0.copy.outputs[0]: <I64*32>>: 0"]
-    "23" [label = "<st.inp1.copy.outputs[0]: <I64>>: 0"]
-    "24" [label = "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: 0"]
-    "0" -- "1" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "0" -- "3" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "0" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "1" -- "3" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "3" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "1" -- "5" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "7" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "9" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "11" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "13" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "15" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "20" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "22" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "1" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "3" -- "5" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "3" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "5" -- "7" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "5" -- "13" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "7" -- "9" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "7" -- "11" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "7" -- "13" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "9" -- "11" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "9" -- "15" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "11" -- "13" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "11" -- "15" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "13" -- "15" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "13" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "15" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-    "17" -- "20" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "17" -- "22" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "20" -- "22" [label = "copy related", color = "blue", style = "dashed", decorate = true]
-    "22" -- "23" [label = "interferes", color = "darkred", style = "bold", decorate = true]
-}"""
-        })
+        # load expected graphs
+        data_path = Path(__file__).with_suffix("")
+        data_path /= "test_register_allocate_graphs"
+        data_path /= "expected"
+        expected_graphs = {}  # type: dict[str, str]
+        expected_graph_names = [
+            'initial',
+            'step_0_simplify',
+            'step_1_simplify',
+            'step_2_simplify',
+            'step_3_simplify',
+            'step_4_simplify',
+            'step_5_simplify',
+            'step_6_simplify',
+            'step_7_simplify',
+            'step_8_simplify',
+            'step_9_simplify',
+            'step_10_simplify',
+            'step_11_simplify',
+            'step_12_copy_merge',
+            'step_12_copy_merge_result',
+            'step_13_copy_merge',
+            'step_13_copy_merge_result',
+            'step_14_copy_merge',
+            'step_14_copy_merge_result',
+            'step_15_copy_merge',
+            'step_15_copy_merge_result',
+            'step_16_simplify',
+            'step_17_freeze',
+            'step_18_freeze',
+            'step_19_simplify',
+            'step_20_copy_merge',
+            'step_20_copy_merge_result',
+            'step_21_copy_merge',
+            'step_21_copy_merge_result',
+            'step_22_simplify',
+            'step_23_copy_merge',
+            'step_23_copy_merge_result',
+            'step_24_simplify',
+            'final',
+        ]
+        for name in expected_graph_names:
+            file_path = (data_path / name).with_suffix(".dot")
+            expected_graphs[name] = file_path.read_text(encoding="utf-8")
+        self.assertEqual(graphs, expected_graphs)
 
     def test_register_allocate_spread(self):
         fn = Fn()