working on adding MergedSSAVal.__mergable_check that ensures copy
author: Jacob Lifshay <programmerjake@gmail.com>
Thu, 1 Dec 2022 08:05:17 +0000 (00:05 -0800)
committer: Jacob Lifshay <programmerjake@gmail.com>
Thu, 1 Dec 2022 08:05:17 +0000 (00:05 -0800)
merging won't merge things that are illegal to merge

src/bigint_presentation_code/_tests/test_compiler_ir.py
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/register_allocator.py

index 904f66d89a05592806bd7c148edc4f715dd02c85..60d8b445a4c64857ab305518599f2a0c0ed7114c 100644 (file)
@@ -253,14 +253,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=1)",
+            "write_stage=OpStage.Early),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvLd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -276,7 +276,7 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=32)",
+            "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvLI, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -287,14 +287,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=32)",
+            "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetCA, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvAddE, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -321,7 +321,7 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
             "tied_input_index=2, spread_index=None, "
-            "write_stage=OpStage.Early)), maxvl=32)",
+            "write_stage=OpStage.Early)), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -337,7 +337,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early)), "
-            "outputs=(), maxvl=32)",
+            "outputs=(), maxvl=32, copy_reg_len=0)",
         ])
 
     def test_pre_ra_insert_copies(self):
@@ -429,7 +429,7 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=1)",
+            "write_stage=OpStage.Early),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.CopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -442,14 +442,14 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
             "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.CopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -462,14 +462,14 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvLd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -485,14 +485,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=32)",
+            "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -508,14 +508,14 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvLI, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -526,14 +526,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Early),), maxvl=32)",
+            "write_stage=OpStage.Early),), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -549,21 +549,21 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.SetCA, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -579,14 +579,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -602,14 +602,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvAddE, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -636,14 +636,14 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
             "tied_input_index=2, spread_index=None, "
-            "write_stage=OpStage.Early)), maxvl=32)",
+            "write_stage=OpStage.Early)), maxvl=32, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -659,14 +659,14 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -682,7 +682,7 @@ class TestCompilerIR(unittest.TestCase):
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=32)",
+            "write_stage=OpStage.Late),), maxvl=32, copy_reg_len=32)",
             "OpProperties(kind=OpKind.CopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -695,14 +695,14 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
-            "write_stage=OpStage.Late),), maxvl=1)",
+            "write_stage=OpStage.Late),), maxvl=1, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
@@ -718,7 +718,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
             "tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early)), "
-            "outputs=(), maxvl=32)",
+            "outputs=(), maxvl=32, copy_reg_len=0)",
         ])
 
     def test_sim(self):
@@ -878,7 +878,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.VL_MAXVL: FBitSet([0])}), "
             "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),"
-            "), maxvl=4)",
+            "), maxvl=4, copy_reg_len=0)",
             "OpProperties(kind=OpKind.SvLI, inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), "
@@ -889,7 +889,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
             "ty=<I64*4>), tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Early),"
-            "), maxvl=4)",
+            "), maxvl=4, copy_reg_len=0)",
             "OpProperties(kind=OpKind.Spread, inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
@@ -916,7 +916,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
             "ty=<I64*4>), tied_input_index=None, spread_index=3, "
             "write_stage=OpStage.Late)"
-            "), maxvl=4)",
+            "), maxvl=4, copy_reg_len=4)",
             "OpProperties(kind=OpKind.Concat, inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
@@ -943,7 +943,7 @@ class TestCompilerIR(unittest.TestCase):
             "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
             "ty=<I64*4>), tied_input_index=None, spread_index=None, "
             "write_stage=OpStage.Late),"
-            "), maxvl=4)",
+            "), maxvl=4, copy_reg_len=4)",
         ])
         self.assertEqual(
             fn.ops_to_str(),
index 69f8d8ec066112fa7406dc4c0c5bb3ac89a5fb2a..e3a4ed5bb4fcfc28bbc65c5a687954a8d5fe80df 100644 (file)
@@ -311,6 +311,19 @@ class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
         return f"<range:{start}..{stop}>"
 
 
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class SSAValSubReg(metaclass=InternedMeta):
+    __slots__ = "ssa_val", "reg_idx"
+
+    def __init__(self, ssa_val, reg_idx):
+        # type: (SSAVal, int) -> None
+        if reg_idx < 0 or reg_idx >= ssa_val.ty.reg_len:
+            raise ValueError("reg_idx out of range")
+        self.ssa_val = ssa_val
+        self.reg_idx = reg_idx
+
+
 @plain_data(frozen=True, eq=False, repr=False)
 @final
 class FnAnalysis:
@@ -354,6 +367,9 @@ class FnAnalysis:
                 live_at[program_point].add(ssa_val)
         self.live_ranges = FMap(live_ranges)
         self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
+        self.copies  # initialize
+        self.const_ssa_vals  # initialize
+        self.const_ssa_val_sub_regs  # initialize
 
     def __get_def_program_range(self, ssa_val):
         # type: (SSAVal) -> ProgramRange
@@ -387,6 +403,78 @@ class FnAnalysis:
         # type: () -> str
         return "<FnAnalysis>"
 
+    @cached_property
+    def copies(self):
+        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
+        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
+        a copy of, looking through all layers of copies. The map excludes all
+        SSAValSubRegs that aren't copies of other SSAValSubRegs.
+        """
+        retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
+        for op in self.op_indexes.keys():
+            if not op.properties.is_copy:
+                continue
+            copy_reg_len = op.properties.copy_reg_len
+            copy_inputs = []  # type: list[SSAValSubReg]
+            for inp in op.input_vals[:op.properties.copy_inputs_len]:
+                for inp_sub_reg in inp.ssa_val_sub_regs:
+                    # propagate copies of copies
+                    inp_sub_reg = retval.get(inp_sub_reg, inp_sub_reg)
+                    copy_inputs.append(inp_sub_reg)
+            assert len(copy_inputs) == copy_reg_len, "logic error"
+            copy_outputs = []  # type: list[SSAValSubReg]
+            for out in op.outputs[:op.properties.copy_outputs_len]:
+                copy_outputs.extend(out.ssa_val_sub_regs)
+            assert len(copy_outputs) == copy_reg_len, "logic error"
+            for inp, out in zip(copy_inputs, copy_outputs):
+                retval[out] = inp
+        return FMap(retval)
+
+    @cached_property
+    def const_ssa_vals(self):
+        # type: () -> FMap[SSAVal, tuple[int, ...]]
+        state = ConstPropagationState(
+            ssa_vals={}, memory={}, skipped_ops=OSet())
+        self.fn.sim(state)
+        return FMap(state.ssa_vals)
+
+    @cached_property
+    def const_ssa_val_sub_regs(self):
+        # type: () -> FMap[SSAValSubReg, int]
+        retval = {}  # type: dict[SSAValSubReg, int]
+        for ssa_val, const_val in self.const_ssa_vals.items():
+            assert ssa_val.ty.reg_len == len(const_val), "logic error"
+            for reg_idx, v in enumerate(const_val):
+                retval[SSAValSubReg(ssa_val, reg_idx)] = v
+        return FMap(retval)
+
+    def are_always_equal(self, a, b):
+        # type: (SSAValSubReg, SSAValSubReg) -> bool
+        """check if a and b are known to be always equal to each other.
+        This means they can be allocated to the same location if other
+        constraints don't prevent that.
+
+        this can happen for a number of reasons, such as:
+        * a and b are copies of the same thing
+        * a and b are known to be constants and they have the same value
+        """
+        if a.ssa_val.base_ty != b.ssa_val.base_ty:
+            return False  # can't be equal, they have different types
+        # look through copies
+        a = self.copies.get(a, a)
+        b = self.copies.get(b, b)
+        if a == b:
+            return True
+        # check if they have the same constant value
+        try:
+            a_const_val = self.const_ssa_val_sub_regs[a]
+            b_const_val = self.const_ssa_val_sub_regs[b]
+            if a_const_val == b_const_val:
+                return True
+        except KeyError:
+            pass
+        return False
+
 
 @unique
 @final
@@ -1015,7 +1103,7 @@ class GenericOpProperties(metaclass=InternedMeta):
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class OpProperties(metaclass=InternedMeta):
-    __slots__ = "kind", "inputs", "outputs", "maxvl"
+    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
 
     def __init__(self, kind, maxvl):
         # type: (OpKind, int) -> None
@@ -1029,6 +1117,17 @@ class OpProperties(metaclass=InternedMeta):
             outputs.extend(out.instantiate(maxvl=maxvl))
         self.outputs = tuple(outputs)  # type: tuple[OperandDesc, ...]
         self.maxvl = maxvl  # type: int
+        copy_input_reg_len = 0
+        for inp in self.inputs[:self.copy_inputs_len]:
+            copy_input_reg_len += inp.ty.reg_len
+        copy_output_reg_len = 0
+        for out in self.outputs[:self.copy_outputs_len]:
+            copy_output_reg_len += out.ty.reg_len
+        if copy_input_reg_len != copy_output_reg_len:
+            raise ValueError(f"invalid copy: copy's input reg len must "
+                             f"match its output reg len: "
+                             f"{copy_input_reg_len} != {copy_output_reg_len}")
+        self.copy_reg_len = copy_input_reg_len
 
     @property
     def generic(self):
@@ -1060,6 +1159,34 @@ class OpProperties(metaclass=InternedMeta):
         # type: () -> bool
         return self.generic.has_side_effects
 
+    @cached_property
+    def copy_inputs_len(self):
+        # type: () -> int
+        if not self.is_copy:
+            return 0
+        if self.inputs[0].spread_index is None:
+            return 1
+        retval = 0
+        for i, inp in enumerate(self.inputs):
+            if inp.spread_index != i:
+                break
+            retval += 1
+        return retval
+
+    @cached_property
+    def copy_outputs_len(self):
+        # type: () -> int
+        if not self.is_copy:
+            return 0
+        if self.outputs[0].spread_index is None:
+            return 1
+        retval = 0
+        for i, out in enumerate(self.outputs):
+            if out.spread_index != i:
+                break
+            retval += 1
+        return retval
+
 
 IMM_S16 = range(-1 << 15, 1 << 15)
 
@@ -1808,6 +1935,11 @@ class SSAVal(SSAValOrUse):
         """
         return PreRASimState.get_current_debugging_state()[self]
 
+    @cached_property
+    def ssa_val_sub_regs(self):
+        # type: () -> tuple[SSAValSubReg, ...]
+        return tuple(SSAValSubReg(self, i) for i in range(self.ty.reg_len))
+
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
@@ -2067,6 +2199,8 @@ class Op:
             except KeyError:
                 raise ValueError(f"SSAVal {inp} not yet assigned when "
                                  f"running {self}")
+            except SimSkipOp:
+                continue
             if len(val) != inp.ty.reg_len:
                 raise ValueError(
                     f"value of SSAVal {inp} has wrong number of elements: "
@@ -2079,12 +2213,17 @@ class Op:
                         continue
                     raise ValueError(f"SSAVal {out} already assigned before "
                                      f"running {self}")
-        self.kind.sim(self, state)
+        try:
+            self.kind.sim(self, state)
+        except SimSkipOp:
+            state.on_skip(self)
         for out in self.outputs:
             try:
                 val = state[out]
             except KeyError:
                 raise ValueError(f"running {self} failed to assign to {out}")
+            except SimSkipOp:
+                continue
             if len(val) != out.ty.reg_len:
                 raise ValueError(
                     f"value of SSAVal {out} has wrong number of elements: "
@@ -2110,10 +2249,21 @@ class BaseSimState(metaclass=ABCMeta):
         super().__init__()
         self.memory = memory  # type: dict[int, int]
 
+    def _default_memory_value(self):
+        # type: () -> int
+        return 0
+
+    def on_skip(self, op):
+        # type: (Op) -> None
+        raise ValueError("skipping instructions not supported")
+
     def load_byte(self, addr):
         # type: (int) -> int
         addr &= GPR_VALUE_MASK
-        return self.memory.get(addr, 0) & 0xFF
+        try:
+            return self.memory[addr] & 0xFF
+        except KeyError:
+            return self._default_memory_value()
 
     def store_byte(self, addr, value):
         # type: (int, int) -> None
@@ -2193,8 +2343,7 @@ class BaseSimState(metaclass=ABCMeta):
 
 
 @plain_data(frozen=True, repr=False)
-@final
-class PreRASimState(BaseSimState):
+class PreRABaseSimState(BaseSimState):
     __slots__ = "ssa_vals",
 
     def __init__(self, ssa_vals, memory):
@@ -2228,7 +2377,14 @@ class PreRASimState(BaseSimState):
 
     def __getitem__(self, ssa_val):
         # type: (SSAVal) -> tuple[int, ...]
-        return self.ssa_vals[ssa_val]
+        try:
+            return self.ssa_vals[ssa_val]
+        except KeyError:
+            return self._handle_undefined_ssa_val(ssa_val)
+
+    def _handle_undefined_ssa_val(self, ssa_val):
+        # type: (SSAVal) -> tuple[int, ...]
+        raise KeyError("SSAVal has no value set", ssa_val)
 
     def __setitem__(self, ssa_val, value):
         # type: (SSAVal, tuple[int, ...]) -> None
@@ -2236,6 +2392,38 @@ class PreRASimState(BaseSimState):
             raise ValueError("value has wrong len")
         self.ssa_vals[ssa_val] = value
 
+
+class SimSkipOp(Exception):
+    pass
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class ConstPropagationState(PreRABaseSimState):
+    __slots__ = "skipped_ops",
+
+    def __init__(self, ssa_vals, memory, skipped_ops):
+        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
+        super().__init__(ssa_vals, memory)
+        self.skipped_ops = skipped_ops
+
+    def _default_memory_value(self):
+        # type: () -> int
+        raise SimSkipOp
+
+    def _handle_undefined_ssa_val(self, ssa_val):
+        # type: (SSAVal) -> tuple[int, ...]
+        raise SimSkipOp
+
+    def on_skip(self, op):
+        # type: (Op) -> None
+        self.skipped_ops.add(op)
+
+
+@plain_data(frozen=True, repr=False)
+class PreRASimState(PreRABaseSimState):
+    __slots__ = ()
+
     __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]
 
     @contextmanager
index f41d18e7fc96cc8b92638464765fda622866b836..a99d84959a943951faee6323ac9eaf31d6ce7d39 100644 (file)
@@ -52,7 +52,8 @@ class MergedSSAVal(metaclass=InternedMeta):
     * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
     * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
     """
-    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
+    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
+                 "first_loc")
 
     def __init__(self, fn_analysis, ssa_val_offsets):
         # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
@@ -93,17 +94,42 @@ class MergedSSAVal(metaclass=InternedMeta):
                             break
                     if disallowed_by_use:
                         continue
-                    # FIXME: add spread consistency check
                     start = loc.start - cur_offset + self.offset
                     loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                     if loc is not None and (loc_set is None or loc in loc_set):
                         yield loc
             loc_set = LocSet(locs())
         assert loc_set is not None, "already checked that self isn't empty"
-        if loc_set.ty is None:
+        first_loc = None
+        for loc in loc_set:
+            first_loc = loc
+            break
+        if first_loc is None:
             raise BadMergedSSAVal("there are no valid Locs left")
+        self.first_loc = first_loc
         assert loc_set.ty == self.ty, "logic error somewhere"
         self.loc_set = loc_set  # type: LocSet
+        self.__mergable_check()
+
+    def __mergable_check(self):
+        # type: () -> None
+        """ checks that nothing is forcing two independent SSAVals
+        to illegally overlap. This is required to keep copy merging from
+        merging things that can't be merged.
+        spread arguments are one of the things that can force two values to
+        illegally overlap.
+        """
+        # pick an arbitrary Loc, any Loc will do
+        loc = self.first_loc
+        ops = sorted(OSet(i.op for i in self.ssa_vals),
+                     key=self.fn_analysis.op_indexes.__getitem__)
+        vals = {}  # type: dict[Loc, tuple[SSAVal, int]]
+        for op in ops:
+            for inp in op.input_vals:
+                pass
+                # FIXME: finish checking using FnAnalysis.are_always_equal
+                # also check that two different outputs of the same
+                # instruction aren't merged
 
     @cached_property
     def __hash(self):