From: Jacob Lifshay Date: Tue, 8 Nov 2022 06:53:38 +0000 (-0800) Subject: register allocation and simulation works for simple mul 192x192! X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=088016bf30dfe3104094530cb1f909d3645cd6d5;p=bigint-presentation-code.git register allocation and simulation works for simple mul 192x192! --- diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir2.py b/src/bigint_presentation_code/_tests/test_compiler_ir2.py index e53f871..833dbc9 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir2.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir2.py @@ -1,9 +1,9 @@ import unittest -from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn, +from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, BaseTy, Fn, FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage, PreRASimState, ProgramPoint, - SSAVal) + SSAVal, Ty) class TestCompilerIR(unittest.TestCase): @@ -807,46 +807,14 @@ class TestCompilerIR(unittest.TestCase): size_in_bytes=GPR_SIZE_IN_BYTES) self.assertEqual( repr(state), - "PreRASimState(ssa_vals={>: (0x100,)}, " - "memory={\n" + "PreRASimState(memory={\n" "0x00100: <0xffffffffffffffff>,\n" - "0x00108: <0xabcdef0123456789>})") - fn.pre_ra_sim(state) + "0x00108: <0xabcdef0123456789>}, " + "ssa_vals={>: (0x100,)})") + fn.sim(state) self.assertEqual( repr(state), - "PreRASimState(ssa_vals={\n" - ">: (0x100,),\n" - ">: (0x20,),\n" - ">: (\n" - " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x1,),\n" - ">: (\n" - " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x0,),\n" - "}, memory={\n" + "PreRASimState(memory={\n" "0x00100: <0x0000000000000000>,\n" "0x00108: <0xabcdef012345678a>,\n" "0x00110: <0x0000000000000000>,\n" @@ -878,7 +846,39 @@ class TestCompilerIR(unittest.TestCase): "0x001e0: <0x0000000000000000>,\n" "0x001e8: <0x0000000000000000>,\n" "0x001f0: <0x0000000000000000>,\n" - "0x001f8: <0x0000000000000000>})") + "0x001f8: <0x0000000000000000>}, ssa_vals={\n" + ">: (0x100,),\n" + ">: (0x20,),\n" + ">: (\n" + " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (0x1,),\n" + ">: (\n" + " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (0x0,),\n" + "})") def test_gen_asm(self): fn, _arg = self.make_add_fn() @@ -933,6 +933,136 @@ class TestCompilerIR(unittest.TestCase): 'sv.std *32, 0(3)', ]) + def test_spread(self): + fn = Fn() + maxvl = 4 + vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], + name="vl", maxvl=maxvl).outputs[0] + li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + name="li", maxvl=maxvl).outputs[0] + spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], + name="spread", maxvl=maxvl) + self.assertEqual(spread_op.outputs[0].ty_before_spread, + Ty(base_ty=BaseTy.I64, reg_len=maxvl)) + _concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl], + name="concat", maxvl=maxvl) + self.assertEqual([repr(op.properties) for op in fn.ops], [ + "OpProperties(kind=OpKind.SetVLI, inputs=(" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late)," + "), maxvl=4)", + "OpProperties(kind=OpKind.SvLI, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)," + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)," + "), maxvl=4)", + "OpProperties(kind=OpKind.Spread, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=0, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=1, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=2, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=3, " + "write_stage=OpStage.Late)" + "), maxvl=4)", + "OpProperties(kind=OpKind.Concat, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=0, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=1, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=2, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=3, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late)," + "), maxvl=4)", + ]) + self.assertEqual([repr(op) for op in fn.ops], [ + "Op(kind=OpKind.SetVLI, input_vals=[" + "], input_uses=(" + "), immediates=[4], outputs=(" + ">," + "), name='vl')", + "Op(kind=OpKind.SvLI, input_vals=[" + ">" + "], input_uses=(" + ">," + "), immediates=[0], outputs=(" + ">," + "), name='li')", + "Op(kind=OpKind.Spread, input_vals=[" + ">, " + ">" + "], input_uses=(" + ">, " + ">" + "), immediates=[], outputs=(" + ">, " + ">, " + ">, " + ">" + "), name='spread')", + "Op(kind=OpKind.Concat, input_vals=[" + ">, " + ">, " + ">, " + ">, " + ">" + "], input_uses=(" + ">, " + ">, " + ">, " + ">, " + ">" + "), immediates=[], outputs=(" + ">," + "), name='concat')", + ]) + if __name__ == "__main__": _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_register_allocator2.py b/src/bigint_presentation_code/_tests/test_register_allocator2.py index 697417f..f34ed98 100644 --- a/src/bigint_presentation_code/_tests/test_register_allocator2.py +++ b/src/bigint_presentation_code/_tests/test_register_allocator2.py @@ -1,6 +1,8 @@ +import sys import unittest -from bigint_presentation_code.compiler_ir2 import Fn, GenAsmState, OpKind, SSAVal +from bigint_presentation_code.compiler_ir2 import (Fn, GenAsmState, OpKind, + SSAVal) from bigint_presentation_code.register_allocator2 import allocate_registers @@ -177,6 +179,315 @@ class TestCompilerIR(unittest.TestCase): 'sv.std *14, 0(3)', ]) + def test_register_allocate_spread(self): + fn = Fn() + maxvl = 32 + vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], + name="vl", maxvl=maxvl).outputs[0] + li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + name="li", maxvl=maxvl).outputs[0] + spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], + name="spread", maxvl=maxvl).outputs + _concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread[::-1], vl], + name="concat", maxvl=maxvl) + reg_assignments = allocate_registers(fn, debug_out=sys.stdout) + + self.assertEqual( + repr(reg_assignments), + "{>: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=21, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=22, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=24, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=25, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=26, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=27, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=28, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=29, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=30, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=31, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=32, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=33, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=34, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=35, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=36, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=37, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=38, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=39, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=40, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=41, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=42, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=43, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=44, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=45, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=21, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=22, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=24, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=25, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=26, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=27, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=28, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=29, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=30, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=31, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=32, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=33, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=34, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=35, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=36, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=37, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=38, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=39, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=40, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=41, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=42, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=43, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=44, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=45, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=9, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=10, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=11, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=12, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=47, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=48, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=49, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=50, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=51, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=52, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=53, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=54, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=55, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=56, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=57, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=58, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=59, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=60, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=61, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=62, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=63, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=64, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=65, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=66, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=67, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}" + ) + state = GenAsmState(reg_assignments) + fn.gen_asm(state) + self.assertEqual(state.output, [ + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.addi *14, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'or 67, 14, 14', + 'or 66, 15, 15', + 'or 65, 16, 16', + 'or 64, 17, 17', + 'or 63, 18, 18', + 'or 62, 19, 19', + 'or 61, 20, 20', + 'or 60, 21, 21', + 'or 59, 22, 22', + 'or 58, 23, 23', + 'or 57, 24, 24', + 'or 56, 25, 25', + 'or 55, 26, 26', + 'or 54, 27, 27', + 'or 53, 28, 28', + 'or 52, 29, 29', + 'or 51, 30, 30', + 'or 50, 31, 31', + 'or 49, 32, 32', + 'or 48, 33, 33', + 'or 47, 34, 34', + 'or 46, 35, 35', + 'or 12, 36, 36', + 'or 11, 37, 37', + 'or 10, 38, 38', + 'or 9, 39, 39', + 'or 8, 40, 40', + 'or 7, 41, 41', + 'or 6, 42, 42', + 'or 5, 43, 43', + 'or 4, 44, 44', + 'or 3, 45, 45', + 'or 14, 3, 3', + 'or 15, 4, 4', + 'or 16, 5, 5', + 'or 17, 6, 6', + 'or 18, 7, 7', + 'or 19, 8, 8', + 'or 20, 9, 9', + 'or 21, 10, 10', + 'or 22, 11, 11', + 'or 23, 12, 12', + 'or 24, 46, 46', + 'or 25, 47, 47', + 'or 26, 48, 48', + 'or 27, 49, 49', + 'or 28, 50, 50', + 'or 29, 51, 51', + 'or 30, 52, 52', + 'or 31, 53, 53', + 'or 32, 54, 54', + 'or 33, 55, 55', + 'or 34, 56, 56', + 'or 35, 57, 57', + 'or 36, 58, 58', + 'or 37, 59, 59', + 'or 38, 60, 60', + 'or 39, 61, 61', + 'or 40, 62, 62', + 'or 41, 63, 63', + 'or 42, 64, 64', + 'or 43, 65, 65', + 'or 44, 66, 66', + 'or 45, 67, 67', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1']) + if __name__ == "__main__": _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 994c951..3188430 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,7 +1,10 @@ import unittest +from typing import Callable -from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn, +from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, + BaseSimState, Fn, GenAsmState, OpKind, + PostRASimState, PreRASimState) from bigint_presentation_code.register_allocator2 import allocate_registers from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul @@ -204,6 +207,21 @@ class TestToomCook(unittest.TestCase): ) def test_simple_mul_192x192_pre_ra_sim(self): + def create_sim_state(code): + # type: (SimpleMul192x192) -> BaseSimState + return PreRASimState(ssa_vals={}, memory={}) + self.tst_simple_mul_192x192_sim(create_sim_state) + + def test_simple_mul_192x192_post_ra_sim(self): + def create_sim_state(code): + # type: (SimpleMul192x192) -> BaseSimState + ssa_val_to_loc_map = allocate_registers(code.fn) + return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map, + memory={}, loc_values={}) + self.tst_simple_mul_192x192_sim(create_sim_state) + + def tst_simple_mul_192x192_sim(self, create_sim_state): + # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None # test multiplying: # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 @@ -214,18 +232,19 @@ class TestToomCook(unittest.TestCase): # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", # 'little') code = SimpleMul192x192() + state = create_sim_state(code) ptr_in = 0x100 dest_ptr = ptr_in + code.dest_offset lhs_ptr = ptr_in + code.lhs_offset rhs_ptr = ptr_in + code.rhs_offset - state = PreRASimState(ssa_vals={code.ptr_in: (ptr_in,)}, memory={}) + state[code.ptr_in] = ptr_in, state.store(lhs_ptr, 0x821a2342132c5b57) state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e) state.store(lhs_ptr + 16, 0x000191acb262e15b) state.store(rhs_ptr, 0x208a49071aeec507) state.store(rhs_ptr + 8, 0xcf1f597598194ae6) state.store(rhs_ptr + 16, 0x4a37c0567bcbab53) - code.fn.pre_ra_sim(state) + code.fn.sim(state) expected_bytes = b"arbitrary 192x192->384-bit multiplication test" OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') @@ -458,28 +477,413 @@ class TestToomCook(unittest.TestCase): "name='store_dest')", ]) - # FIXME: register allocator currently allocates wrong registers - @unittest.expectedFailure def test_simple_mul_192x192_reg_alloc(self): code = SimpleMul192x192() fn = code.fn assigned_registers = allocate_registers(fn) - self.assertEqual(assigned_registers, { - }) - self.fail("register allocator currently allocates wrong registers") + self.assertEqual( + repr(assigned_registers), "{" + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=6), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=6), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=6), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=9, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=10, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=11, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=12, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=9, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=11, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=12, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=9, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=11, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=12, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1)" + "}") - # FIXME: register allocator currently allocates wrong registers - @unittest.expectedFailure def test_simple_mul_192x192_asm(self): - self.skipTest("WIP") code = SimpleMul192x192() fn = code.fn assigned_registers = allocate_registers(fn) gen_asm_state = GenAsmState(assigned_registers) fn.gen_asm(gen_asm_state) self.assertEqual(gen_asm_state.output, [ + 'or 23, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 6, 23, 23', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.ld *3, 48(6)', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *20, *3, *3', + 'or 6, 23, 23', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.ld *3, 72(6)', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or/mrr *5, *3, *3', + 'or 4, 5, 5', + 'or 14, 6, 6', + 'or 19, 7, 7', + 'setvl 0, 0, 3, 0, 1, 1', + 'addi 3, 0, 0', + 'or 18, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *8, *20, *20', + 'or 7, 4, 4', + 'or 6, 18, 18', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.maddedu *3, *8, 7, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 15, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 17, 3, 3', + 'or 12, 4, 4', + 'or 11, 5, 5', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *8, *20, *20', + 'or 7, 14, 14', + 'or 3, 18, 18', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.maddedu *4, *8, 7, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or/mrr *6, *4, *4', + 'or 14, 3, 3', + 'or 3, 12, 12', + 'or 4, 11, 11', + 'or 5, 15, 15', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *6, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.adde *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 16, 3, 3', + 'or 15, 4, 4', + 'or 12, 5, 5', + 'or 4, 14, 14', + 'addze *3, *4', + 'or 11, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *8, *20, *20', + 'or 7, 19, 19', + 'or 3, 18, 18', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.maddedu *4, *8, 7, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or/mrr *6, *4, *4', + 'or 14, 3, 3', + 'or 3, 15, 15', + 'or 4, 12, 12', + 'or 5, 11, 11', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *6, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.adde *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 12, 3, 3', + 'or 11, 4, 4', + 'or 10, 5, 5', + 'or 4, 14, 14', + 'addze *3, *4', + 'or 9, 3, 3', + 'setvl 0, 0, 6, 0, 1, 1', + 'or 3, 17, 17', + 'or 4, 16, 16', + 'or 5, 12, 12', + 'or 6, 11, 11', + 'or 7, 10, 10', + 'or 8, 9, 9', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.or/mrr *4, *3, *3', + 'or 3, 23, 23', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.std *4, 0(3)' ]) - self.fail("register allocator currently allocates wrong registers") if __name__ == "__main__": diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index bd3d38c..d3b52e8 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -12,7 +12,8 @@ from nmutil.plain_data import fields, plain_data from bigint_presentation_code.type_util import (Literal, Self, assert_never, final) -from bigint_presentation_code.util import BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet +from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta, + OFSet, OSet) @final @@ -54,10 +55,10 @@ class Fn: self.append_op(retval) return retval - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None + def sim(self, state): + # type: (BaseSimState) -> None for op in self.ops: - op.pre_ra_sim(state) + op.sim(state) def gen_asm(self, state): # type: (GenAsmState) -> None @@ -641,25 +642,57 @@ SPECIAL_GPRS = ( ) -@plain_data(frozen=True, eq=False) +@final +class _LocSetHashHelper(AbstractSet[Loc]): + """helper to more quickly compute LocSet's hash""" + + def __init__(self, locs): + # type: (Iterable[Loc]) -> None + super().__init__() + self.locs = list(locs) + + def __hash__(self): + # type: () -> int + return super()._hash() + + def __contains__(self, x): + # type: (Loc | Any) -> bool + return x in self.locs + + def __iter__(self): + # type: () -> Iterator[Loc] + return iter(self.locs) + + def __len__(self): + return len(self.locs) + + +@plain_data(frozen=True, eq=False, repr=False) @final class LocSet(AbstractSet[Loc], metaclass=InternedMeta): - __slots__ = "starts", "ty" + __slots__ = "starts", "ty", "_LocSet__hash" def __init__(self, __locs=()): # type: (Iterable[Loc]) -> None if isinstance(__locs, LocSet): self.starts = __locs.starts # type: FMap[LocKind, FBitSet] self.ty = __locs.ty # type: Ty | None + self._LocSet__hash = __locs._LocSet__hash # type: int return starts = {i: BitSet() for i in LocKind} - ty = None - for loc in __locs: - if ty is None: - ty = loc.ty - if ty != loc.ty: - raise ValueError(f"conflicting types: {ty} != {loc.ty}") - starts[loc.kind].add(loc.start) + ty = None # type: None | Ty + + def locs(): + # type: () -> Iterable[Loc] + nonlocal ty + for loc in __locs: + if ty is None: + ty = loc.ty + if ty != loc.ty: + raise ValueError(f"conflicting types: {ty} != {loc.ty}") + starts[loc.kind].add(loc.start) + yield loc + self._LocSet__hash = _LocSetHashHelper(locs()).__hash__() self.starts = FMap( (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) self.ty = ty @@ -747,7 +780,7 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta): return self.__len def __hash__(self): - return super()._hash() + return self._LocSet__hash def __eq__(self, __other): # type: (LocSet | Any) -> bool @@ -766,6 +799,14 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta): else: return sum(other.conflicts(i) for i in self) + def __repr__(self): + items = [] # type: list[str] + for name in fields(self): + if name.startswith("_"): + continue + items.append(f"{name}={getattr(self, name)!r}") + return f"LocSet({', '.join(items)})" + @plain_data(frozen=True, unsafe_hash=True) @final @@ -821,6 +862,13 @@ class GenericOperandDesc(metaclass=InternedMeta): raise ValueError("operand can't be both spread and vector") self.write_stage = write_stage + @cached_property + def ty_before_spread(self): + # type: () -> GenericTy + if self.spread: + return GenericTy(base_ty=self.ty.base_ty, is_vec=True) + return self.ty + def tied_to_input(self, tied_input_index): # type: (int) -> Self return GenericOperandDesc(self.ty, self.sub_kinds, @@ -846,21 +894,21 @@ class GenericOperandDesc(metaclass=InternedMeta): rep_count = 1 if self.spread: rep_count = maxvl - maxvl = 1 - ty = self.ty.instantiate(maxvl=maxvl) + ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl) - def locs(): + def locs_before_spread(): # type: () -> Iterable[Loc] if self.fixed_loc is not None: - if ty != self.fixed_loc.ty: + if ty_before_spread != self.fixed_loc.ty: raise ValueError( f"instantiation failed: type mismatch with fixed_loc: " - f"instantiated type: {ty} fixed_loc: {self.fixed_loc}") + f"instantiated type: {ty_before_spread} " + f"fixed_loc: {self.fixed_loc}") yield self.fixed_loc return for sub_kind in self.sub_kinds: - yield from sub_kind.allocatable_locs(ty) - loc_set_before_spread = LocSet(locs()) + yield from sub_kind.allocatable_locs(ty_before_spread) + loc_set_before_spread = LocSet(locs_before_spread()) for idx in range(rep_count): if not self.spread: idx = None @@ -1045,9 +1093,9 @@ class OpProperties(metaclass=InternedMeta): IMM_S16 = range(-1 << 15, 1 << 15) -_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None] -_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN] -_PRE_RA_SIMS = {} # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2] +_SIM_FN = Callable[["Op", "BaseSimState"], None] +_SIM_FN2 = Callable[[], _SIM_FN] +_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2] _GEN_ASM_FN = Callable[["Op", "GenAsmState"], None] _GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN] _GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2] @@ -1075,9 +1123,9 @@ class OpKind(Enum): return "OpKind." + self._name_ @cached_property - def pre_ra_sim(self): - # type: () -> _PRE_RA_SIM_FN - return _PRE_RA_SIMS[self.properties]() + def sim(self): + # type: () -> _SIM_FN + return _SIM_FNS[self.properties]() @cached_property def gen_asm(self): @@ -1085,9 +1133,9 @@ class OpKind(Enum): return _GEN_ASMS[self.properties]() @staticmethod - def __clearca_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = False, + def __clearca_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = False, @staticmethod def __clearca_gen_asm(op, state): @@ -1098,13 +1146,13 @@ class OpKind(Enum): inputs=[], outputs=[OD_CA.with_write_stage(OpStage.Late)], ) - _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim + _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm @staticmethod - def __setca_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = True, + def __setca_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = True, @staticmethod def __setca_gen_asm(op, state): @@ -1115,23 +1163,23 @@ class OpKind(Enum): inputs=[], outputs=[OD_CA.with_write_stage(OpStage.Late)], ) - _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim + _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm @staticmethod - def __svadde_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.input_vals[0]] - RB = state.ssa_vals[op.input_vals[1]] - carry, = state.ssa_vals[op.input_vals[2]] - VL, = state.ssa_vals[op.input_vals[3]] + def __svadde_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = RA[i] + RB[i] + carry RT.append(v & GPR_VALUE_MASK) carry = (v >> GPR_SIZE_IN_BITS) != 0 - state.ssa_vals[op.outputs[0]] = tuple(RT) - state.ssa_vals[op.outputs[1]] = carry, + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, @staticmethod def __svadde_gen_asm(op, state): @@ -1145,19 +1193,19 @@ class OpKind(Enum): inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], ) - _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim + _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm @staticmethod - def __addze_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA, = state.ssa_vals[op.input_vals[0]] - carry, = state.ssa_vals[op.input_vals[1]] + def __addze_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] + carry, = state[op.input_vals[1]] v = RA + carry RT = v & GPR_VALUE_MASK carry = (v >> GPR_SIZE_IN_BITS) != 0 - state.ssa_vals[op.outputs[0]] = RT, - state.ssa_vals[op.outputs[1]] = carry, + state[op.outputs[0]] = RT, + state[op.outputs[1]] = carry, @staticmethod def __addze_gen_asm(op, state): @@ -1170,23 +1218,23 @@ class OpKind(Enum): inputs=[OD_BASE_SGPR, OD_CA], outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)], ) - _PRE_RA_SIMS[AddZE] = lambda: OpKind.__addze_pre_ra_sim + _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm @staticmethod - def __svsubfe_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.input_vals[0]] - RB = state.ssa_vals[op.input_vals[1]] - carry, = state.ssa_vals[op.input_vals[2]] - VL, = state.ssa_vals[op.input_vals[3]] + def __svsubfe_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry RT.append(v & GPR_VALUE_MASK) carry = (v >> GPR_SIZE_IN_BITS) != 0 - state.ssa_vals[op.outputs[0]] = tuple(RT) - state.ssa_vals[op.outputs[1]] = carry, + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, @staticmethod def __svsubfe_gen_asm(op, state): @@ -1200,23 +1248,23 @@ class OpKind(Enum): inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], ) - _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim + _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm @staticmethod - def __svmaddedu_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.input_vals[0]] - RB, = state.ssa_vals[op.input_vals[1]] - carry, = state.ssa_vals[op.input_vals[2]] - VL, = state.ssa_vals[op.input_vals[3]] + def __svmaddedu_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB, = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = RA[i] * RB + carry RT.append(v & GPR_VALUE_MASK) carry = v >> GPR_SIZE_IN_BITS - state.ssa_vals[op.outputs[0]] = tuple(RT) - state.ssa_vals[op.outputs[1]] = carry, + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, @staticmethod def __svmaddedu_gen_asm(op, state): @@ -1231,13 +1279,13 @@ class OpKind(Enum): inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)], ) - _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim + _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm @staticmethod - def __setvli_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = op.immediates[0], + def __setvli_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = op.immediates[0], @staticmethod def __setvli_gen_asm(op, state): @@ -1251,15 +1299,15 @@ class OpKind(Enum): immediates=[range(1, 65)], is_load_immediate=True, ) - _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim + _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm @staticmethod - def __svli_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - VL, = state.ssa_vals[op.input_vals[0]] + def __svli_sim(op, state): + # type: (Op, BaseSimState) -> None + VL, = state[op.input_vals[0]] imm = op.immediates[0] & GPR_VALUE_MASK - state.ssa_vals[op.outputs[0]] = (imm,) * VL + state[op.outputs[0]] = (imm,) * VL @staticmethod def __svli_gen_asm(op, state): @@ -1274,14 +1322,14 @@ class OpKind(Enum): immediates=[IMM_S16], is_load_immediate=True, ) - _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim + _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm @staticmethod - def __li_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None + def __li_sim(op, state): + # type: (Op, BaseSimState) -> None imm = op.immediates[0] & GPR_VALUE_MASK - state.ssa_vals[op.outputs[0]] = imm, + state[op.outputs[0]] = imm, @staticmethod def __li_gen_asm(op, state): @@ -1296,13 +1344,13 @@ class OpKind(Enum): immediates=[IMM_S16], is_load_immediate=True, ) - _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim + _SIM_FNS[LI] = lambda: OpKind.__li_sim _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm @staticmethod - def __veccopytoreg_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + def __veccopytoreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] @staticmethod def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state): @@ -1359,13 +1407,13 @@ class OpKind(Enum): outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], is_copy=True, ) - _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim + _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm @staticmethod - def __veccopyfromreg_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + def __veccopyfromreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] @staticmethod def __veccopyfromreg_gen_asm(op, state): @@ -1385,13 +1433,13 @@ class OpKind(Enum): )], is_copy=True, ) - _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim + _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm @staticmethod - def __copytoreg_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + def __copytoreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] @staticmethod def __copytoreg_gen_asm(op, state): @@ -1415,13 +1463,13 @@ class OpKind(Enum): )], is_copy=True, ) - _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim + _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm @staticmethod - def __copyfromreg_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + def __copyfromreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] @staticmethod def __copyfromreg_gen_asm(op, state): @@ -1445,14 +1493,14 @@ class OpKind(Enum): )], is_copy=True, ) - _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim + _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm @staticmethod - def __concat_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = tuple( - state.ssa_vals[i][0] for i in op.input_vals[:-1]) + def __concat_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = tuple( + state[i][0] for i in op.input_vals[:-1]) @staticmethod def __concat_gen_asm(op, state): @@ -1471,14 +1519,14 @@ class OpKind(Enum): outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], is_copy=True, ) - _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim + _SIM_FNS[Concat] = lambda: OpKind.__concat_sim _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm @staticmethod - def __spread_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]): - state.ssa_vals[op.outputs[idx]] = inp, + def __spread_sim(op, state): + # type: (Op, BaseSimState) -> None + for idx, inp in enumerate(state[op.input_vals[0]]): + state[op.outputs[idx]] = inp, @staticmethod def __spread_gen_asm(op, state): @@ -1498,20 +1546,20 @@ class OpKind(Enum): )], is_copy=True, ) - _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim + _SIM_FNS[Spread] = lambda: OpKind.__spread_sim _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm @staticmethod - def __svld_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA, = state.ssa_vals[op.input_vals[0]] - VL, = state.ssa_vals[op.input_vals[1]] + def __svld_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] + VL, = state[op.input_vals[1]] addr = RA + op.immediates[0] RT = [] # type: list[int] for i in range(VL): v = state.load(addr + GPR_SIZE_IN_BYTES * i) RT.append(v & GPR_VALUE_MASK) - state.ssa_vals[op.outputs[0]] = tuple(RT) + state[op.outputs[0]] = tuple(RT) @staticmethod def __svld_gen_asm(op, state): @@ -1526,16 +1574,16 @@ class OpKind(Enum): outputs=[OD_EXTRA3_VGPR], immediates=[IMM_S16], ) - _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim + _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm @staticmethod - def __ld_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RA, = state.ssa_vals[op.input_vals[0]] + def __ld_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] addr = RA + op.immediates[0] v = state.load(addr) - state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK, + state[op.outputs[0]] = v & GPR_VALUE_MASK, @staticmethod def __ld_gen_asm(op, state): @@ -1550,15 +1598,15 @@ class OpKind(Enum): outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], immediates=[IMM_S16], ) - _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim + _SIM_FNS[Ld] = lambda: OpKind.__ld_sim _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm @staticmethod - def __svstd_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RS = state.ssa_vals[op.input_vals[0]] - RA, = state.ssa_vals[op.input_vals[1]] - VL, = state.ssa_vals[op.input_vals[2]] + def __svstd_sim(op, state): + # type: (Op, BaseSimState) -> None + RS = state[op.input_vals[0]] + RA, = state[op.input_vals[1]] + VL, = state[op.input_vals[2]] addr = RA + op.immediates[0] for i in range(VL): state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) @@ -1577,14 +1625,14 @@ class OpKind(Enum): immediates=[IMM_S16], has_side_effects=True, ) - _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim + _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm @staticmethod - def __std_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None - RS, = state.ssa_vals[op.input_vals[0]] - RA, = state.ssa_vals[op.input_vals[1]] + def __std_sim(op, state): + # type: (Op, BaseSimState) -> None + RS, = state[op.input_vals[0]] + RA, = state[op.input_vals[1]] addr = RA + op.immediates[0] state.store(addr, value=RS) @@ -1602,12 +1650,12 @@ class OpKind(Enum): immediates=[IMM_S16], has_side_effects=True, ) - _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim + _SIM_FNS[Std] = lambda: OpKind.__std_sim _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm @staticmethod - def __funcargr3_pre_ra_sim(op, state): - # type: (Op, PreRASimState) -> None + def __funcargr3_sim(op, state): + # type: (Op, BaseSimState) -> None pass # return value set before simulation @staticmethod @@ -1620,7 +1668,7 @@ class OpKind(Enum): outputs=[OD_BASE_SGPR.with_fixed_loc( Loc(kind=LocKind.GPR, start=3, reg_len=1))], ) - _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim + _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm @@ -1942,32 +1990,37 @@ class Op: field_vals_str = ", ".join(field_vals) return f"Op({field_vals_str})" - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None + def sim(self, state): + # type: (BaseSimState) -> None for inp in self.input_vals: - if inp not in state.ssa_vals: + try: + val = state[inp] + except KeyError: raise ValueError(f"SSAVal {inp} not yet assigned when " f"running {self}") - if len(state.ssa_vals[inp]) != inp.ty.reg_len: + if len(val) != inp.ty.reg_len: raise ValueError( f"value of SSAVal {inp} has wrong number of elements: " f"expected {inp.ty.reg_len} found " - f"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}") - for out in self.outputs: - if out in state.ssa_vals: - if self.kind is OpKind.FuncArgR3: - continue - raise ValueError(f"SSAVal {out} already assigned before " - f"running {self}") - self.kind.pre_ra_sim(self, state) + f"{len(val)}: {val!r}") + if isinstance(state, PreRASimState): + for out in self.outputs: + if out in state.ssa_vals: + if self.kind is OpKind.FuncArgR3: + continue + raise ValueError(f"SSAVal {out} already assigned before " + f"running {self}") + self.kind.sim(self, state) for out in self.outputs: - if out not in state.ssa_vals: + try: + val = state[out] + except KeyError: raise ValueError(f"running {self} failed to assign to {out}") - if len(state.ssa_vals[out]) != out.ty.reg_len: + if len(val) != out.ty.reg_len: raise ValueError( f"value of SSAVal {out} has wrong number of elements: " f"expected {out.ty.reg_len} found " - f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}") + f"{len(val)}: {val!r}") def gen_asm(self, state): # type: (GenAsmState) -> None @@ -1986,13 +2039,12 @@ GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 @plain_data(frozen=True, repr=False) -@final -class PreRASimState: - __slots__ = "ssa_vals", "memory" +class BaseSimState(metaclass=ABCMeta): + __slots__ = "memory", - def __init__(self, ssa_vals, memory): - # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None - self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]] + def __init__(self, memory): + # type: (dict[int, int]) -> None + super().__init__() self.memory = memory # type: dict[int, int] def load_byte(self, addr): @@ -2049,6 +2101,44 @@ class PreRASimState: items_str = ",\n".join(items) return f"{{\n{items_str}}}" + def __repr__(self): + # type: () -> str + field_vals = [] # type: list[str] + for name in fields(self): + try: + value = getattr(self, name) + except AttributeError: + field_vals.append(f"{name}=") + continue + repr_fn = getattr(self, f"_{name}__repr", None) + if callable(repr_fn): + field_vals.append(f"{name}={repr_fn()}") + else: + field_vals.append(f"{name}={value!r}") + field_vals_str = ", ".join(field_vals) + return f"{self.__class__.__name__}({field_vals_str})" + + @abstractmethod + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + ... + + @abstractmethod + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + ... + + +@plain_data(frozen=True, repr=False) +@final +class PreRASimState(BaseSimState): + __slots__ = "ssa_vals", + + def __init__(self, ssa_vals, memory): + # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None + super().__init__(memory) + self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]] + def _ssa_vals__repr(self): # type: () -> str if len(self.ssa_vals) == 0: @@ -2073,22 +2163,65 @@ class PreRASimState: items_str = ",\n".join(items) return f"{{\n{items_str},\n}}" - def __repr__(self): + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + return self.ssa_vals[ssa_val] + + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + if len(value) != ssa_val.ty.reg_len: + raise ValueError("value has wrong len") + self.ssa_vals[ssa_val] = value + + +@plain_data(frozen=True, repr=False) +@final +class PostRASimState(BaseSimState): + __slots__ = "ssa_val_to_loc_map", "loc_values" + + def __init__(self, ssa_val_to_loc_map, memory, loc_values): + # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None + super().__init__(memory) + self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map) + for ssa_val, loc in self.ssa_val_to_loc_map.items(): + if ssa_val.ty != loc.ty: + raise ValueError( + f"type mismatch for SSAVal and Loc: {ssa_val} {loc}") + self.loc_values = loc_values + for loc in self.loc_values.keys(): + if loc.reg_len != 1: + raise ValueError( + "loc_values must only contain Locs with reg_len=1, all " + "larger Locs will be split into reg_len=1 sub-Locs") + + def _loc_values__repr(self): # type: () -> str - field_vals = [] # type: list[str] - for name in fields(self): - try: - value = getattr(self, name) - except AttributeError: - field_vals.append(f"{name}=") - continue - repr_fn = getattr(self, f"_{name}__repr", None) - if callable(repr_fn): - field_vals.append(f"{name}={repr_fn()}") - else: - field_vals.append(f"{name}={value!r}") - field_vals_str = ", ".join(field_vals) - return f"PreRASimState({field_vals_str})" + locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start)) + items = [] # type: list[str] + for loc in locs: + items.append(f"{loc}: 0x{self.loc_values[loc]:x}") + items_str = ",\n".join(items) + return f"{{\n{items_str},\n}}" + + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + loc = self.ssa_val_to_loc_map[ssa_val] + subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1) + retval = [] # type: list[int] + for i in range(loc.reg_len): + subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i) + retval.append(self.loc_values.get(subloc, 0)) + return tuple(retval) + + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + if len(value) != ssa_val.ty.reg_len: + raise ValueError("value has wrong len") + loc = self.ssa_val_to_loc_map[ssa_val] + subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1) + for i in range(loc.reg_len): + subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i) + self.loc_values[subloc] = value[i] @plain_data(frozen=True) diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py index 198eebc..278693d 100644 --- a/src/bigint_presentation_code/register_allocator2.py +++ b/src/bigint_presentation_code/register_allocator2.py @@ -6,7 +6,7 @@ this uses an algorithm based on: """ from itertools import combinations -from typing import Iterable, Iterator, Mapping +from typing import Iterable, Iterator, Mapping, TextIO from cached_property import cached_property from nmutil.plain_data import plain_data @@ -71,22 +71,19 @@ class MergedSSAVal(metaclass=InternedMeta): reg_len = self.ty.reg_len loc_set = None # type: None | LocSet for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items(): - def_spread_idx = ssa_val.defining_descriptor.spread_index or 0 - def locs(): # type: () -> Iterable[Loc] for loc in ssa_val.def_loc_set_before_spread: disallowed_by_use = False for use in fn_analysis.uses[ssa_val]: - use_spread_idx = \ - use.defining_descriptor.spread_index or 0 # calculate the start for the use's Loc before spread # e.g. if the def's Loc before spread starts at r6 - # and the def's spread_index is 5 - # and the use's spread_index is 3 + # and the def's reg_offset_in_unspread is 5 + # and the use's reg_offset_in_unspread is 3 # then the use's Loc before spread starts at r8 # because 8 == 6 + 5 - 3 - start = loc.start + def_spread_idx - use_spread_idx + start = (loc.start + ssa_val.reg_offset_in_unspread + - use.reg_offset_in_unspread) use_loc = Loc.try_make( loc.kind, start=start, reg_len=use.ty_before_spread.reg_len) @@ -201,8 +198,9 @@ class MergedSSAVal(metaclass=InternedMeta): return ProgramRange(start=start, stop=stop) def __repr__(self): - return (f"MergedSSAVal({self.fn_analysis}, " - f"ssa_val_offsets={self.ssa_val_offsets})") + return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, " + f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, " + f"live_interval={self.live_interval})") @final @@ -464,18 +462,26 @@ class AllocationFailedError(Exception): return self.__repr__() -def allocate_registers(fn): - # type: (Fn) -> dict[SSAVal, Loc] +def allocate_registers(fn, debug_out=None): + # type: (Fn, TextIO | None) -> dict[SSAVal, Loc] # inserts enough copies that no manual spilling is necessary, all # spilling is done by the register allocator naturally allocating SSAVals # to stack slots fn.pre_ra_insert_copies() + if debug_out is not None: + print(f"After pre_ra_insert_copies():\n{fn.ops}", + file=debug_out, flush=True) + fn_analysis = FnAnalysis(fn) interference_graph = InterferenceGraph.minimally_merged(fn_analysis) - for ssa_vals in fn_analysis.live_at.values(): + if debug_out is not None: + print(f"After InterferenceGraph.minimally_merged():\n" + f"{interference_graph}", file=debug_out, flush=True) + + for pp, ssa_vals in fn_analysis.live_at.items(): live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal] for ssa_val in ssa_vals: live_merged_ssa_vals.add( @@ -484,6 +490,13 @@ def allocate_registers(fn): if i.loc_set.max_conflicts_with(j.loc_set) != 0: interference_graph.nodes[i].add_edge( interference_graph.nodes[j]) + if debug_out is not None: + print(f"processed {pp} out of {fn_analysis.all_program_points}", + file=debug_out, flush=True) + + if debug_out is not None: + print(f"After adding interference graph edges:\n" + f"{interference_graph}", file=debug_out, flush=True) nodes_remaining = OSet(interference_graph.nodes.values()) @@ -521,6 +534,10 @@ def allocate_registers(fn): node_stack.append(best_node) nodes_remaining.remove(best_node) + if debug_out is not None: + print(f"After deciding node allocation order:\n" + f"{node_stack}", file=debug_out, flush=True) + retval = {} # type: dict[SSAVal, Loc] while len(node_stack) > 0: @@ -543,7 +560,15 @@ def allocate_registers(fn): "failed to allocate Loc for IGNode", node=node, interference_graph=interference_graph) + if debug_out is not None: + print(f"After allocating Loc for node:\n{node}", + file=debug_out, flush=True) + for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items(): retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset) + if debug_out is not None: + print(f"final Locs for all SSAVals:\n{retval}", + file=debug_out, flush=True) + return retval