register_allocator2.py works!
author Jacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 08:23:59 +0000 (00:23 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 08:23:59 +0000 (00:23 -0800)
src/bigint_presentation_code/_tests/test_compiler_ir2.py
src/bigint_presentation_code/_tests/test_register_allocator2.py [new file with mode: 0644]
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py

index 116326a7c1a588ed2014333792aaa7f726051ec1..e53f871a588fcb73ee7e9711e4da63e2cc27a252 100644 (file)
@@ -1,7 +1,7 @@
 import unittest
 
 from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
-                                                   FnAnalysis, OpKind, OpStage,
+                                                   FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage,
                                                    PreRASimState, ProgramPoint,
                                                    SSAVal)
 
@@ -49,10 +49,9 @@ class TestCompilerIR(unittest.TestCase):
     def test_fn_analysis(self):
         fn, _arg = self.make_add_fn()
         fn_analysis = FnAnalysis(fn)
-        print(repr(fn_analysis))
         self.assertEqual(
-            repr(fn_analysis),
-            "FnAnalysis(fn=<Fn>, uses=FMap({"
+            repr(fn_analysis.uses),
+            "FMap({"
             "<arg.outputs[0]: <I64>>: OFSet(["
             "<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
             "<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
@@ -66,8 +65,11 @@ class TestCompilerIR(unittest.TestCase):
             "<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
             "<add.outputs[0]: <I64*32>>: OFSet(["
             "<st.input_uses[0]: <I64*32>>]), "
-            "<add.outputs[1]: <CA>>: OFSet()}), "
-            "op_indexes=FMap({"
+            "<add.outputs[1]: <CA>>: OFSet()})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.op_indexes),
+            "FMap({"
             "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
             "immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
             "name='arg'): 0, "
@@ -97,16 +99,22 @@ class TestCompilerIR(unittest.TestCase):
             "<vl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<st.input_uses[0]: <I64*32>>, "
             "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
-            "immediates=[0], outputs=(), name='st'): 6}), "
-            "live_ranges=FMap({"
+            "immediates=[0], outputs=(), name='st'): 6})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.live_ranges),
+            "FMap({"
             "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
             "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
             "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
             "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
             "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
             "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
-            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
-            "live_at=FMap({"
+            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.live_at),
+            "FMap({"
             "<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
             "<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
             "<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
@@ -142,16 +150,22 @@ class TestCompilerIR(unittest.TestCase):
             "<ops[6]:Early>: OFSet(["
             "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
             "<add.outputs[0]: <I64*32>>]), "
-            "<ops[6]:Late>: OFSet()}), "
-            "def_program_ranges=FMap({"
+            "<ops[6]:Late>: OFSet()})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.def_program_ranges),
+            "FMap({"
             "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
             "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
             "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
             "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
             "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
             "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
-            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
-            "use_program_points=FMap({"
+            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.use_program_points),
+            "FMap({"
             "<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
             "<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
             "<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
@@ -161,8 +175,11 @@ class TestCompilerIR(unittest.TestCase):
             "<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
             "<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
             "<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
-            "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>}), "
-            "all_program_points=<range:ops[0]:Early..ops[7]:Early>)")
+            "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.all_program_points),
+            "<range:ops[0]:Early..ops[7]:Early>")
 
     def test_repr(self):
         fn, _arg = self.make_add_fn()
@@ -286,7 +303,7 @@ class TestCompilerIR(unittest.TestCase):
             "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None, "
+            "tied_input_index=2, spread_index=None, "
             "write_stage=OpStage.Early)), maxvl=32)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
@@ -331,9 +348,15 @@ class TestCompilerIR(unittest.TestCase):
             "input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
             "immediates=[], "
             "outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), "
+            "immediates=[32], "
+            "outputs=(<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='ld.inp1.setvl')",
             "Op(kind=OpKind.SvLd, "
             "input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
-            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<ld.input_uses[0]: <I64>>, "
             "<ld.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[0], "
@@ -352,8 +375,14 @@ class TestCompilerIR(unittest.TestCase):
             "immediates=[], "
             "outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
             "name='ld.out0.copy')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), "
+            "immediates=[32], "
+            "outputs=(<li.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='li.inp0.setvl')",
             "Op(kind=OpKind.SvLI, "
-            "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<li.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
             "immediates=[0], "
             "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
@@ -404,10 +433,16 @@ class TestCompilerIR(unittest.TestCase):
             "immediates=[], "
             "outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
             "name='add.inp1.copy')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), "
+            "immediates=[32], "
+            "outputs=(<add.inp3.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='add.inp3.setvl')",
             "Op(kind=OpKind.SvAddE, "
             "input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
             "<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
-            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<add.input_uses[0]: <I64*32>>, "
             "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
             "<add.input_uses[3]: <VL_MAXVL>>), "
@@ -448,9 +483,16 @@ class TestCompilerIR(unittest.TestCase):
             "immediates=[], "
             "outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
             "name='st.inp1.copy')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), "
+            "immediates=[32], "
+            "outputs=(<st.inp2.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='st.inp2.setvl')",
             "Op(kind=OpKind.SvStd, "
             "input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
-            "<st.inp1.copy.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "<st.inp1.copy.outputs[0]: <I64>>, "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>], "
             "input_uses=(<st.input_uses[0]: <I64*32>>, "
             "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
             "immediates=[0], "
@@ -474,7 +516,7 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
-            "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+            "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SetVLI, "
@@ -488,7 +530,7 @@ class TestCompilerIR(unittest.TestCase):
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
-            "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+            "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early),), "
             "outputs=("
@@ -497,6 +539,13 @@ class TestCompilerIR(unittest.TestCase):
             "ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=1)",
+            "OpProperties(kind=OpKind.SetVLI, "
+            "inputs=(), "
+            "outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvLd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -533,9 +582,16 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=32)",
+            "OpProperties(kind=OpKind.SetVLI, "
+            "inputs=(), "
+            "outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvLI, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -567,7 +623,7 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SetCA, "
@@ -588,7 +644,7 @@ class TestCompilerIR(unittest.TestCase):
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -611,7 +667,7 @@ class TestCompilerIR(unittest.TestCase):
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -623,6 +679,13 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=32)",
+            "OpProperties(kind=OpKind.SetVLI, "
+            "inputs=(), "
+            "outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvAddE, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -648,7 +711,7 @@ class TestCompilerIR(unittest.TestCase):
             "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None, "
+            "tied_input_index=2, spread_index=None, "
             "write_stage=OpStage.Early)), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
@@ -670,7 +733,7 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
@@ -684,7 +747,7 @@ class TestCompilerIR(unittest.TestCase):
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
-            "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+            "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -700,7 +763,7 @@ class TestCompilerIR(unittest.TestCase):
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
-            "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+            "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early),), "
             "outputs=("
@@ -709,6 +772,13 @@ class TestCompilerIR(unittest.TestCase):
             "ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),), maxvl=1)",
+            "OpProperties(kind=OpKind.SetVLI, "
+            "inputs=(), "
+            "outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -810,6 +880,59 @@ class TestCompilerIR(unittest.TestCase):
             "0x001f0: <0x0000000000000000>,\n"
             "0x001f8: <0x0000000000000000>})")
 
+    def test_gen_asm(self):
+        fn, _arg = self.make_add_fn()
+        fn.pre_ra_insert_copies()
+        VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)
+        CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1)
+        state = GenAsmState(allocated_locs={
+            fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+            fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+            fn.ops[2].outputs[0]: VL_LOC,
+            fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+            fn.ops[4].outputs[0]: VL_LOC,
+            fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[6].outputs[0]: VL_LOC,
+            fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[8].outputs[0]: VL_LOC,
+            fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+            fn.ops[10].outputs[0]: VL_LOC,
+            fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+            fn.ops[12].outputs[0]: CA_LOC,
+            fn.ops[13].outputs[0]: VL_LOC,
+            fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[15].outputs[0]: VL_LOC,
+            fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+            fn.ops[17].outputs[0]: VL_LOC,
+            fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[18].outputs[1]: CA_LOC,
+            fn.ops[19].outputs[0]: VL_LOC,
+            fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[21].outputs[0]: VL_LOC,
+            fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+            fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+            fn.ops[24].outputs[0]: VL_LOC,
+        })
+        fn.gen_asm(state)
+        self.assertEqual(state.output, [
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.ld *32, 0(3)',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.addi *64, 0, 0',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'subfc 0, 0, 0',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.adde *32, *32, *64',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.std *32, 0(3)',
+        ])
+
 
 if __name__ == "__main__":
     _ = unittest.main()
diff --git a/src/bigint_presentation_code/_tests/test_register_allocator2.py b/src/bigint_presentation_code/_tests/test_register_allocator2.py
new file mode 100644 (file)
index 0000000..697417f
--- /dev/null
@@ -0,0 +1,182 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir2 import Fn, GenAsmState, OpKind, SSAVal
+from bigint_presentation_code.register_allocator2 import allocate_registers
+
+
+class TestCompilerIR(unittest.TestCase):
+    maxDiff = None
+
+    def make_add_fn(self):
+        # type: () -> tuple[Fn, SSAVal]
+        fn = Fn()
+        op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
+        arg = op0.outputs[0]
+        MAXVL = 32
+        op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
+        vl = op1.outputs[0]
+        op2 = fn.append_new_op(
+            OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
+            name="ld")
+        a = op2.outputs[0]
+        op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+                               maxvl=MAXVL, name="li")
+        b = op3.outputs[0]
+        op4 = fn.append_new_op(OpKind.SetCA, name="ca")
+        ca = op4.outputs[0]
+        op5 = fn.append_new_op(
+            OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
+        s = op5.outputs[0]
+        _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+                             immediates=[0], maxvl=MAXVL, name="st")
+        return fn, arg
+
+    def test_register_allocate(self):
+        fn, _arg = self.make_add_fn()
+        reg_assignments = allocate_registers(fn)
+
+        self.assertEqual(
+            repr(reg_assignments),
+            "{<add.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.inp1.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<add.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<st.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ca.outputs[0]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<vl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+        )
+
+    def test_gen_asm(self):
+        fn, _arg = self.make_add_fn()
+        reg_assignments = allocate_registers(fn)
+
+        self.assertEqual(
+            repr(reg_assignments),
+            "{<add.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.inp1.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<add.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<st.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ca.outputs[0]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<vl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+        )
+        state = GenAsmState(reg_assignments)
+        fn.gen_asm(state)
+        self.assertEqual(state.output, [
+            'or 4, 3, 3',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'or 3, 4, 4',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.ld *14, 0(3)',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.or *46, *14, *14',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.addi *14, 0, 0',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'subfc 0, 0, 0',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.or *78, *46, *46',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.or *46, *14, *14',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.adde *14, *78, *46',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'or 3, 4, 4',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.std *14, 0(3)',
+        ])
+
+
+if __name__ == "__main__":
+    _ = unittest.main()
index 9f911969c7f96235d472f7fb7b15cac83c9491a9..b256a07b7a58868bdcce9d203cdf96389fcd6dc6 100644 (file)
@@ -4,7 +4,7 @@ from enum import Enum, unique
 from functools import lru_cache, total_ordering
 from io import StringIO
 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
-                    Mapping, Sequence, TypeVar, overload)
+                    Mapping, Sequence, TypeVar, Union, overload)
 from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
@@ -59,10 +59,16 @@ class Fn:
         for op in self.ops:
             op.pre_ra_sim(state)
 
+    def gen_asm(self, state):
+        # type: (GenAsmState) -> None
+        for op in self.ops:
+            op.gen_asm(state)
+
     def pre_ra_insert_copies(self):
         # type: () -> None
         orig_ops = list(self.ops)
         copied_outputs = {}  # type: dict[SSAVal, SSAVal]
+        setvli_outputs = {}  # type: dict[SSAVal, Op]
         self.ops.clear()
         for op in orig_ops:
             for i in range(len(op.input_vals)):
@@ -84,12 +90,22 @@ class Fn:
                     op.input_vals[i] = mv.outputs[0]
                 elif inp.ty.base_ty is BaseTy.CA \
                         or inp.ty.base_ty is BaseTy.VL_MAXVL:
-                    # all copies would be no-ops, so we don't need to copy
+                    # all copies would be no-ops, so we don't need to copy,
+                    # though we do need to rematerialize SetVLI ops right
+                    # before the ops that use VL
+                    if inp in setvli_outputs:
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI,
+                            immediates=setvli_outputs[inp].immediates,
+                            name=f"{op.name}.inp{i}.setvl")
+                        inp = setvl.outputs[0]
                     op.input_vals[i] = inp
                 else:
                     assert_never(inp.ty.base_ty)
             self.ops.append(op)
             for i, out in enumerate(op.outputs):
+                if op.kind is OpKind.SetVLI:
+                    setvli_outputs[out] = op
                 if out.ty.base_ty is BaseTy.I64:
                     maxvl = out.ty.reg_len
                     if out.ty.reg_len != 1:
@@ -263,7 +279,7 @@ class ProgramRange(Sequence[ProgramPoint]):
         return f"<range:{start}..{stop}>"
 
 
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
 @final
 class FnAnalysis:
     __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
@@ -335,6 +351,10 @@ class FnAnalysis:
         # type: () -> int
         return hash(self.fn)
 
+    def __repr__(self):
+        # type: () -> str
+        return "<FnAnalysis>"
+
 
 @unique
 @final
@@ -425,7 +445,7 @@ class LocKind(Enum):
     def loc_count(self):
         # type: () -> int
         if self is LocKind.StackI64:
-            return 1024
+            return 512
         if self is LocKind.GPR or self is LocKind.CA \
                 or self is LocKind.VL_MAXVL:
             return self.base_ty.max_reg_len
@@ -606,11 +626,11 @@ class Loc:
         # type: (Ty, int) -> Loc
         if subloc_ty.base_ty != self.kind.base_ty:
             raise ValueError("BaseTy mismatch")
-        start = self.start + offset
-        if offset < 0 or start + subloc_ty.reg_len > self.reg_len:
+        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
             raise ValueError("invalid sub-Loc: offset and/or "
                              "subloc_ty.reg_len out of range")
-        return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len)
+        return Loc(kind=self.kind,
+                   start=self.start + offset, reg_len=subloc_ty.reg_len)
 
 
 SPECIAL_GPRS = (
@@ -726,13 +746,26 @@ class LocSet(AbstractSet[Loc]):
     def __len__(self):
         return self.__len
 
+    __HASHES = {}  # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int]
+
     @cached_property
     def __hash(self):
-        return super()._hash()
+        # cache hashes to avoid slow LocSet iteration
+        key = self.ty, self.starts
+        retval = self.__HASHES.get(key, None)
+        if retval is None:
+            self.__HASHES[key] = retval = super(LocSet, self)._hash()
+        return retval
 
     def __hash__(self):
         return self.__hash
 
+    def __eq__(self, __other):
+        # type: (LocSet | Any) -> bool
+        if isinstance(__other, LocSet):
+            return self.ty == __other.ty and self.starts == __other.starts
+        return super().__eq__(__other)
+
     @lru_cache(maxsize=None, typed=True)
     def max_conflicts_with(self, other):
         # type: (LocSet | Loc) -> int
@@ -861,7 +894,7 @@ class OperandDesc:
             raise ValueError("loc_set_before_spread must not be empty")
         self.loc_set_before_spread = loc_set_before_spread
         self.tied_input_index = tied_input_index
-        if self.tied_input_index is not None and self.spread_index is not None:
+        if self.tied_input_index is not None and spread_index is not None:
             raise ValueError("operand can't be both spread and tied")
         self.spread_index = spread_index
         self.write_stage = write_stage
@@ -1121,7 +1154,7 @@ class OpKind(Enum):
     SvAddE = GenericOpProperties(
         demo_asm="sv.adde *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
-        outputs=[OD_EXTRA3_VGPR, OD_CA],
+        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
     _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
     _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
@@ -1151,7 +1184,7 @@ class OpKind(Enum):
     SvSubFE = GenericOpProperties(
         demo_asm="sv.subfe *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
-        outputs=[OD_EXTRA3_VGPR, OD_CA],
+        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
     _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
     _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
@@ -1258,26 +1291,50 @@ class OpKind(Enum):
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
     @staticmethod
-    def __veccopytoreg_gen_asm(op, state):
-        # type: (Op, GenAsmState) -> None
-        src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
-        dest_loc = state.loc(op.outputs[0], LocKind.GPR)
-        RT = state.vgpr(dest_loc)
+    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
+        # type: (Loc, Loc, bool, GenAsmState) -> None
+        sv = "sv." if is_vec else ""
+        rev = ""
+        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+            rev = "/mrr"
         if src_loc == dest_loc:
             return  # no-op
-        assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \
-            "checked by loc()"
+        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
+        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
         if src_loc.kind is LocKind.StackI64:
+            if dest_loc.kind is LocKind.StackI64:
+                raise ValueError(
+                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
+            elif dest_loc.kind is not LocKind.GPR:
+                assert_never(dest_loc.kind)
             src = state.stack(src_loc)
-            state.writeln(f"sv.ld {RT}, {src}")
-            return
-        elif src_loc.kind is not LocKind.GPR:
+            dest = state.gpr(dest_loc, is_vec=is_vec)
+            state.writeln(f"{sv}ld {dest}, {src}")
+        elif dest_loc.kind is LocKind.StackI64:
+            if src_loc.kind is not LocKind.GPR:
+                assert_never(src_loc.kind)
+            src = state.gpr(src_loc, is_vec=is_vec)
+            dest = state.stack(dest_loc)
+            state.writeln(f"{sv}std {src}, {dest}")
+        elif src_loc.kind is LocKind.GPR:
+            if dest_loc.kind is not LocKind.GPR:
+                assert_never(dest_loc.kind)
+            src = state.gpr(src_loc, is_vec=is_vec)
+            dest = state.gpr(dest_loc, is_vec=is_vec)
+            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
+        else:
             assert_never(src_loc.kind)
-        rev = ""
-        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
-            rev = "/mrr"
-        src = state.vgpr(src_loc)
-        state.writeln(f"sv.or{rev} {RT}, {src}, {src}")
+
+    @staticmethod
+    def __veccopytoreg_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(
+                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=True, state=state)
 
     VecCopyToReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
@@ -1296,11 +1353,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __veccopyfromreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(
+                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+            is_vec=True, state=state)
     VecCopyFromReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1319,11 +1379,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __copytoreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(
+                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=False, state=state)
     CopyToReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1346,11 +1409,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __copyfromreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(
+                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+            is_vec=False, state=state)
     CopyFromReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1374,11 +1440,13 @@ class OpKind(Enum):
         state.ssa_vals[op.outputs[0]] = tuple(
             state.ssa_vals[i][0] for i in op.input_vals[:-1])
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __concat_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=True, state=state)
     Concat = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1398,11 +1466,13 @@ class OpKind(Enum):
         for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
             state.ssa_vals[op.outputs[idx]] = inp,
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __spread_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(op.outputs, LocKind.GPR),
+            is_vec=True, state=state)
     Spread = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1429,11 +1499,13 @@ class OpKind(Enum):
             RT.append(v & GPR_VALUE_MASK)
         state.ssa_vals[op.outputs[0]] = tuple(RT)
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __svld_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RA = state.sgpr(op.input_vals[0])
+        RT = state.vgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"sv.ld {RT}, {imm}({RA})")
     SvLd = GenericOpProperties(
         demo_asm="sv.ld *RT, imm(RA)",
         inputs=[OD_EXTRA3_SGPR, OD_VL],
@@ -1451,11 +1523,13 @@ class OpKind(Enum):
         v = state.load(addr)
         state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __ld_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RA = state.sgpr(op.input_vals[0])
+        RT = state.sgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"ld {RT}, {imm}({RA})")
     Ld = GenericOpProperties(
         demo_asm="ld RT, imm(RA)",
         inputs=[OD_BASE_SGPR],
@@ -1475,11 +1549,13 @@ class OpKind(Enum):
         for i in range(VL):
             state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __svstd_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RS = state.vgpr(op.input_vals[0])
+        RA = state.sgpr(op.input_vals[1])
+        imm = op.immediates[0]
+        state.writeln(f"sv.std {RS}, {imm}({RA})")
     SvStd = GenericOpProperties(
         demo_asm="sv.std *RS, imm(RA)",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
@@ -1498,13 +1574,15 @@ class OpKind(Enum):
         addr = RA + op.immediates[0]
         state.store(addr, value=RS)
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __std_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RS = state.sgpr(op.input_vals[0])
+        RA = state.sgpr(op.input_vals[1])
+        imm = op.immediates[0]
+        state.writeln(f"std {RS}, {imm}({RA})")
     Std = GenericOpProperties(
-        demo_asm="std RT, imm(RA)",
+        demo_asm="std RS, imm(RA)",
         inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
         outputs=[],
         immediates=[IMM_S16],
@@ -1877,6 +1955,15 @@ class Op:
                     f"expected {out.ty.reg_len} found "
                     f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
 
+    def gen_asm(self, state):
+        # type: (GenAsmState) -> None
+        all_loc_kinds = tuple(LocKind)
+        for inp in self.input_vals:
+            state.loc(inp, expected_kinds=all_loc_kinds)
+        for out in self.outputs:
+            state.loc(out, expected_kinds=all_loc_kinds)
+        self.kind.gen_asm(self, state)
+
 
 GPR_SIZE_IN_BYTES = 8
 BITS_IN_BYTE = 8
@@ -1994,7 +2081,7 @@ class PreRASimState:
 class GenAsmState:
     __slots__ = "allocated_locs", "output"
 
-    def __init__(self, allocated_locs, output):
+    def __init__(self, allocated_locs, output=None):
         # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
         super().__init__()
         self.allocated_locs = FMap(allocated_locs)
@@ -2006,38 +2093,49 @@ class GenAsmState:
             output = []
         self.output = output
 
-    def loc(self, ssa_val_or_loc, expected_kinds):
-        # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc
-        if isinstance(ssa_val_or_loc, SSAVal):
-            retval = self.allocated_locs[ssa_val_or_loc]
-        else:
-            retval = ssa_val_or_loc
+    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
+
+    def loc(self, ssa_val_or_locs, expected_kinds):
+        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
+        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
+            ssa_val_or_locs = [ssa_val_or_locs]
+        locs = []  # type: list[Loc]
+        for i in ssa_val_or_locs:
+            if isinstance(i, SSAVal):
+                locs.append(self.allocated_locs[i])
+            else:
+                locs.append(i)
+        if len(locs) == 0:
+            raise ValueError("invalid Loc sequence: must not be empty")
+        retval = locs[0].try_concat(*locs[1:])
+        if retval is None:
+            raise ValueError("invalid Loc sequence: try_concat failed")
         if isinstance(expected_kinds, LocKind):
             expected_kinds = expected_kinds,
         if retval.kind not in expected_kinds:
             if len(expected_kinds) == 1:
                 expected_kinds = expected_kinds[0]
-            raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found "
+            raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                              f"{retval.kind} expected {expected_kinds}")
         return retval
 
-    def gpr(self, ssa_val_or_loc, is_vec):
-        # type: (SSAVal | Loc, bool) -> str
-        loc = self.loc(ssa_val_or_loc, LocKind.GPR)
+    def gpr(self, ssa_val_or_locs, is_vec):
+        # type: (__SSA_VAL_OR_LOCS, bool) -> str
+        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
         vec_str = "*" if is_vec else ""
         return vec_str + str(loc.start)
 
-    def sgpr(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        return self.gpr(ssa_val_or_loc, is_vec=False)
+    def sgpr(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        return self.gpr(ssa_val_or_locs, is_vec=False)
 
-    def vgpr(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        return self.gpr(ssa_val_or_loc, is_vec=True)
+    def vgpr(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        return self.gpr(ssa_val_or_locs, is_vec=True)
 
-    def stack(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        loc = self.loc(ssa_val_or_loc, LocKind.StackI64)
+    def stack(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
         return f"{loc.start}(1)"
 
     def writeln(self, *line_segments):
index d3ca3983c9dfb114fb0a7454484e9a6c3688a200..aefa2914f2cc58ad6a55994f04fbd113cf6758d8 100644 (file)
@@ -58,7 +58,7 @@ class BadMergedSSAVal(ValueError):
     pass
 
 
-@plain_data(frozen=True)
+@plain_data(frozen=True, repr=False)
 @final
 class MergedSSAVal:
     """a set of `SSAVal`s along with their offsets, all register allocated as
@@ -236,6 +236,10 @@ class MergedSSAVal:
             stop = max(stop, live_range.stop)
         return ProgramRange(start=start, stop=stop)
 
+    def __repr__(self):
+        return (f"MergedSSAVal({self.fn_analysis}, "
+                f"ssa_val_offsets={self.ssa_val_offsets})")
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
@@ -354,13 +358,15 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
             self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
         return retval
 
-    def __repr__(self):
-        # type: () -> str
-        s = ",\n".join(repr(v) for v in self.__map.values())
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
         return f"MergedSSAValToIGNodeMap({{{s}}})"
 
 
-@plain_data(frozen=True)
+@plain_data(frozen=True, repr=False)
 @final
 class InterferenceGraph:
     __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
@@ -400,6 +406,23 @@ class InterferenceGraph:
                     retval.merge(out.tied_input.ssa_val, out)
         return retval
 
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        s = self.nodes.__repr__(repr_state)
+        return f"InterferenceGraph(nodes={s}, <...>)"
+
+
+@plain_data(repr=False)
+class IGNodeReprState:
+    __slots__ = "node_ids", "did_full_repr"
+
+    def __init__(self):
+        super().__init__()
+        self.node_ids = {}  # type: dict[IGNode, int]
+        self.did_full_repr = OSet()  # type: OSet[IGNode]
+
 
 @final
 class IGNode:
@@ -427,17 +450,20 @@ class IGNode:
         # type: () -> int
         return hash(self.merged_ssa_val)
 
-    def __repr__(self, nodes=None):
-        # type: (None | dict[IGNode, int]) -> str
-        if nodes is None:
-            nodes = {}
-        if self in nodes:
-            return f"<IGNode #{nodes[self]}>"
-        nodes[self] = len(nodes)
-        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
-        return (f"IGNode(#{nodes[self]}, "
+    def __repr__(self, repr_state=None, short=False):
+        # type: (None | IGNodeReprState, bool) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        node_id = repr_state.node_ids.get(self, None)
+        if node_id is None:
+            repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
+        if short or self in repr_state.did_full_repr:
+            return f"<IGNode #{node_id}>"
+        repr_state.did_full_repr.add(self)
+        edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+        return (f"IGNode(#{node_id}, "
                 f"merged_ssa_val={self.merged_ssa_val}, "
-                f"edges={edges}, "
+                f"edges={{{edges}}}, "
                 f"loc={self.loc})")
 
     @property
@@ -460,6 +486,19 @@ class AllocationFailedError(Exception):
         self.node = node
         self.interference_graph = interference_graph
 
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        return (f"{__class__.__name__}({self.args[0]!r}, "
+                f"node={self.node.__repr__(repr_state, True)}, "
+                f"interference_graph="
+                f"{self.interference_graph.__repr__(repr_state)})")
+
+    def __str__(self):
+        # type: () -> str
+        return self.__repr__()
+
 
 def allocate_registers(fn):
     # type: (Fn) -> dict[SSAVal, Loc]