From 2c979da1bcdf8b086fa084ea6248583c8668aa1b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 1 Dec 2022 00:05:17 -0800 Subject: [PATCH] working on adding MergedSSAVal.__mergable_check that ensures copy merging won't merge things that are illegal to merge --- .../_tests/test_compiler_ir.py | 74 +++---- src/bigint_presentation_code/compiler_ir.py | 200 +++++++++++++++++- .../register_allocator.py | 32 ++- 3 files changed, 260 insertions(+), 46 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py index 904f66d..60d8b44 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -253,14 +253,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([3])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=1)", + "write_stage=OpStage.Early),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvLd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -276,7 +276,7 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", + "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SvLI, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -287,14 +287,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", + "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SetCA, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvAddE, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -321,7 +321,7 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " "tied_input_index=2, spread_index=None, " - "write_stage=OpStage.Early)), maxvl=32)", + "write_stage=OpStage.Early)), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SvStd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -337,7 +337,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " "write_stage=OpStage.Early)), " - "outputs=(), maxvl=32)", + "outputs=(), maxvl=32, copy_reg_len=0)", ]) def test_pre_ra_insert_copies(self): @@ -429,7 +429,7 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([3])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=1)", + "write_stage=OpStage.Early),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.CopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -442,14 +442,14 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.CopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -462,14 +462,14 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvLd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -485,14 +485,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", + "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -508,14 +508,14 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvLI, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -526,14 +526,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", + "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -549,21 +549,21 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.SetCA, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -579,14 +579,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -602,14 +602,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvAddE, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -636,14 +636,14 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " "tied_input_index=2, spread_index=None, " - "write_stage=OpStage.Early)), maxvl=32)", + "write_stage=OpStage.Early)), maxvl=32, copy_reg_len=0)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -659,14 +659,14 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -682,7 +682,7 @@ class TestCompilerIR(unittest.TestCase): "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", + "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)", "OpProperties(kind=OpKind.CopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -695,14 +695,14 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", + "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)", "OpProperties(kind=OpKind.SvStd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" @@ -718,7 +718,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " "tied_input_index=None, spread_index=None, " "write_stage=OpStage.Early)), " - "outputs=(), maxvl=32)", + "outputs=(), maxvl=32, copy_reg_len=0)", ]) def test_sim(self): @@ -878,7 +878,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.VL_MAXVL: FBitSet([0])}), " "ty=), tied_input_index=None, spread_index=None, " "write_stage=OpStage.Late)," - "), maxvl=4)", + "), maxvl=4, copy_reg_len=0)", "OpProperties(kind=OpKind.SvLI, inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), " @@ -889,7 +889,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " "ty=), tied_input_index=None, spread_index=None, " "write_stage=OpStage.Early)," - "), maxvl=4)", + "), maxvl=4, copy_reg_len=0)", "OpProperties(kind=OpKind.Spread, inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " @@ -916,7 +916,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " "ty=), tied_input_index=None, spread_index=3, " "write_stage=OpStage.Late)" - "), maxvl=4)", + "), maxvl=4, copy_reg_len=4)", "OpProperties(kind=OpKind.Concat, inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " @@ -943,7 +943,7 @@ class TestCompilerIR(unittest.TestCase): "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " "ty=), tied_input_index=None, spread_index=None, " "write_stage=OpStage.Late)," - "), maxvl=4)", + "), maxvl=4, copy_reg_len=4)", ]) self.assertEqual( fn.ops_to_str(), diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 69f8d8e..e3a4ed5 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -311,6 +311,19 @@ class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta): return f"" +@plain_data(frozen=True, unsafe_hash=True) +@final +class SSAValSubReg(metaclass=InternedMeta): + __slots__ = "ssa_val", "reg_idx" + + def __init__(self, ssa_val, reg_idx): + # type: (SSAVal, int) -> None + if reg_idx < 0 or reg_idx >= ssa_val.ty.reg_len: + raise ValueError("reg_idx out of range") + self.ssa_val = ssa_val + self.reg_idx = reg_idx + + @plain_data(frozen=True, eq=False, repr=False) @final class FnAnalysis: @@ -354,6 +367,9 @@ class FnAnalysis: live_at[program_point].add(ssa_val) self.live_ranges = FMap(live_ranges) self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items()) + self.copies # initialize + self.const_ssa_vals # initialize + self.const_ssa_val_sub_regs # initialize def __get_def_program_range(self, ssa_val): # type: (SSAVal) -> ProgramRange @@ -387,6 +403,78 @@ class FnAnalysis: # type: () -> str return "" + @cached_property + def copies(self): + # type: () -> FMap[SSAValSubReg, SSAValSubReg] + """ map from SSAValSubRegs to the original SSAValSubRegs that they are + a copy of, looking through all layers of copies. The map excludes all + SSAValSubRegs that aren't copies of other SSAValSubRegs. + """ + retval = {} # type: dict[SSAValSubReg, SSAValSubReg] + for op in self.op_indexes.keys(): + if not op.properties.is_copy: + continue + copy_reg_len = op.properties.copy_reg_len + copy_inputs = [] # type: list[SSAValSubReg] + for inp in op.input_vals[:op.properties.copy_inputs_len]: + for inp_sub_reg in inp.ssa_val_sub_regs: + # propagate copies of copies + inp_sub_reg = retval.get(inp_sub_reg, inp_sub_reg) + copy_inputs.append(inp_sub_reg) + assert len(copy_inputs) == copy_reg_len, "logic error" + copy_outputs = [] # type: list[SSAValSubReg] + for out in op.outputs[:op.properties.copy_outputs_len]: + copy_outputs.extend(out.ssa_val_sub_regs) + assert len(copy_outputs) == copy_reg_len, "logic error" + for inp, out in zip(copy_inputs, copy_outputs): + retval[out] = inp + return FMap(retval) + + @cached_property + def const_ssa_vals(self): + # type: () -> FMap[SSAVal, tuple[int, ...]] + state = ConstPropagationState( + ssa_vals={}, memory={}, skipped_ops=OSet()) + self.fn.sim(state) + return FMap(state.ssa_vals) + + @cached_property + def const_ssa_val_sub_regs(self): + # type: () -> FMap[SSAValSubReg, int] + retval = {} # type: dict[SSAValSubReg, int] + for ssa_val, const_val in self.const_ssa_vals.items(): + assert ssa_val.ty.reg_len == len(const_val), "logic error" + for reg_idx, v in enumerate(const_val): + retval[SSAValSubReg(ssa_val, reg_idx)] = v + return FMap(retval) + + def are_always_equal(self, a, b): + # type: (SSAValSubReg, SSAValSubReg) -> bool + """check if a and b are known to be always equal to each other. + This means they can be allocated to the same location if other + constraints don't prevent that. + + this can happen for a number of reasons, such as: + * a and b are copies of the same thing + * a and b are known to be constants and they have the same value + """ + if a.ssa_val.base_ty != b.ssa_val.base_ty: + return False # can't be equal, they have different types + # look through copies + a = self.copies.get(a, a) + b = self.copies.get(b, b) + if a == b: + return True + # check if they have the same constant value + try: + a_const_val = self.const_ssa_val_sub_regs[a] + b_const_val = self.const_ssa_val_sub_regs[b] + if a_const_val == b_const_val: + return True + except KeyError: + pass + return False + @unique @final @@ -1015,7 +1103,7 @@ class GenericOpProperties(metaclass=InternedMeta): @plain_data(frozen=True, unsafe_hash=True) @final class OpProperties(metaclass=InternedMeta): - __slots__ = "kind", "inputs", "outputs", "maxvl" + __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len" def __init__(self, kind, maxvl): # type: (OpKind, int) -> None @@ -1029,6 +1117,17 @@ class OpProperties(metaclass=InternedMeta): outputs.extend(out.instantiate(maxvl=maxvl)) self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...] self.maxvl = maxvl # type: int + copy_input_reg_len = 0 + for inp in self.inputs[:self.copy_inputs_len]: + copy_input_reg_len += inp.ty.reg_len + copy_output_reg_len = 0 + for out in self.outputs[:self.copy_outputs_len]: + copy_output_reg_len += out.ty.reg_len + if copy_input_reg_len != copy_output_reg_len: + raise ValueError(f"invalid copy: copy's input reg len must " + f"match its output reg len: " + f"{copy_input_reg_len} != {copy_output_reg_len}") + self.copy_reg_len = copy_input_reg_len @property def generic(self): @@ -1060,6 +1159,34 @@ class OpProperties(metaclass=InternedMeta): # type: () -> bool return self.generic.has_side_effects + @cached_property + def copy_inputs_len(self): + # type: () -> int + if not self.is_copy: + return 0 + if self.inputs[0].spread_index is None: + return 1 + retval = 0 + for i, inp in enumerate(self.inputs): + if inp.spread_index != i: + break + retval += 1 + return retval + + @cached_property + def copy_outputs_len(self): + # type: () -> int + if not self.is_copy: + return 0 + if self.outputs[0].spread_index is None: + return 1 + retval = 0 + for i, out in enumerate(self.outputs): + if out.spread_index != i: + break + retval += 1 + return retval + IMM_S16 = range(-1 << 15, 1 << 15) @@ -1808,6 +1935,11 @@ class SSAVal(SSAValOrUse): """ return PreRASimState.get_current_debugging_state()[self] + @cached_property + def ssa_val_sub_regs(self): + # type: () -> tuple[SSAValSubReg, ...] + return tuple(SSAValSubReg(self, i) for i in range(self.ty.reg_len)) + @plain_data(frozen=True, unsafe_hash=True, repr=False) @final @@ -2067,6 +2199,8 @@ class Op: except KeyError: raise ValueError(f"SSAVal {inp} not yet assigned when " f"running {self}") + except SimSkipOp: + continue if len(val) != inp.ty.reg_len: raise ValueError( f"value of SSAVal {inp} has wrong number of elements: " @@ -2079,12 +2213,17 @@ class Op: continue raise ValueError(f"SSAVal {out} already assigned before " f"running {self}") - self.kind.sim(self, state) + try: + self.kind.sim(self, state) + except SimSkipOp: + state.on_skip(self) for out in self.outputs: try: val = state[out] except KeyError: raise ValueError(f"running {self} failed to assign to {out}") + except SimSkipOp: + continue if len(val) != out.ty.reg_len: raise ValueError( f"value of SSAVal {out} has wrong number of elements: " @@ -2110,10 +2249,21 @@ class BaseSimState(metaclass=ABCMeta): super().__init__() self.memory = memory # type: dict[int, int] + def _default_memory_value(self): + # type: () -> int + return 0 + + def on_skip(self, op): + # type: (Op) -> None + raise ValueError("skipping instructions not supported") + def load_byte(self, addr): # type: (int) -> int addr &= GPR_VALUE_MASK - return self.memory.get(addr, 0) & 0xFF + try: + return self.memory[addr] & 0xFF + except KeyError: + return self._default_memory_value() def store_byte(self, addr, value): # type: (int, int) -> None @@ -2193,8 +2343,7 @@ class BaseSimState(metaclass=ABCMeta): @plain_data(frozen=True, repr=False) -@final -class PreRASimState(BaseSimState): +class PreRABaseSimState(BaseSimState): __slots__ = "ssa_vals", def __init__(self, ssa_vals, memory): @@ -2228,7 +2377,14 @@ class PreRASimState(BaseSimState): def __getitem__(self, ssa_val): # type: (SSAVal) -> tuple[int, ...] - return self.ssa_vals[ssa_val] + try: + return self.ssa_vals[ssa_val] + except KeyError: + return self._handle_undefined_ssa_val(ssa_val) + + def _handle_undefined_ssa_val(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + raise KeyError("SSAVal has no value set", ssa_val) def __setitem__(self, ssa_val, value): # type: (SSAVal, tuple[int, ...]) -> None @@ -2236,6 +2392,38 @@ class PreRASimState(BaseSimState): raise ValueError("value has wrong len") self.ssa_vals[ssa_val] = value + +class SimSkipOp(Exception): + pass + + +@plain_data(frozen=True, repr=False) +@final +class ConstPropagationState(PreRABaseSimState): + __slots__ = "skipped_ops", + + def __init__(self, ssa_vals, memory, skipped_ops): + # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None + super().__init__(ssa_vals, memory) + self.skipped_ops = skipped_ops + + def _default_memory_value(self): + # type: () -> int + raise SimSkipOp + + def _handle_undefined_ssa_val(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + raise SimSkipOp + + def on_skip(self, op): + # type: (Op) -> None + self.skipped_ops.add(op) + + +@plain_data(frozen=True, repr=False) +class PreRASimState(PreRABaseSimState): + __slots__ = () + __CURRENT_DEBUGGING_STATE = [] # type: list[PreRASimState] @contextmanager diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index f41d18e..a99d849 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -52,7 +52,8 @@ class MergedSSAVal(metaclass=InternedMeta): * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)` * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)` """ - __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set" + __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set", + "first_loc") def __init__(self, fn_analysis, ssa_val_offsets): # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None @@ -93,17 +94,42 @@ class MergedSSAVal(metaclass=InternedMeta): break if disallowed_by_use: continue - # FIXME: add spread consistency check start = loc.start - cur_offset + self.offset loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len) if loc is not None and (loc_set is None or loc in loc_set): yield loc loc_set = LocSet(locs()) assert loc_set is not None, "already checked that self isn't empty" - if loc_set.ty is None: + first_loc = None + for loc in loc_set: + first_loc = loc + break + if first_loc is None: raise BadMergedSSAVal("there are no valid Locs left") + self.first_loc = first_loc assert loc_set.ty == self.ty, "logic error somewhere" self.loc_set = loc_set # type: LocSet + self.__mergable_check() + + def __mergable_check(self): + # type: () -> None + """ checks that nothing is forcing two independent SSAVals + to illegally overlap. This is required to avoid copy merging merging + things that can't be merged. + spread arguments are one of the things that can force two values to + illegally overlap. + """ + # pick an arbitrary Loc, any Loc will do + loc = self.first_loc + ops = sorted(OSet(i.op for i in self.ssa_vals), + key=self.fn_analysis.op_indexes.__getitem__) + vals = {} # type: dict[Loc, tuple[SSAVal, int]] + for op in ops: + for inp in op.input_vals: + pass + # FIXME: finish checking using FnAnalysis.are_always_equal + # also check that two different outputs of the same + # instruction aren't merged @cached_property def __hash(self): -- 2.30.2