finished __mergable_check
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 2 Dec 2022 05:39:02 +0000 (21:39 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 2 Dec 2022 05:39:02 +0000 (21:39 -0800)
src/bigint_presentation_code/_tests/test_compiler_ir.py
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/register_allocator.py

index 60d8b445a4c64857ab305518599f2a0c0ed7114c..75e8349b79e7c8c0a01f72971e32a5dd79c9ca5c 100644 (file)
@@ -196,6 +196,54 @@ class TestCompilerIR(unittest.TestCase):
         self.assertEqual(
             repr(fn_analysis.all_program_points),
             "<range:ops[0]:Early..ops[7]:Early>")
+        self.assertEqual(repr(fn_analysis.copies), "FMap({})")
+        self.assertEqual(
+            repr(fn_analysis.const_ssa_vals),
+            "FMap({"
+            "<vl.outputs[0]: <VL_MAXVL>>: (32,), "
+            "<li.outputs[0]: <I64*32>>: ("
+            "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "
+            "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), "
+            "<ca.outputs[0]: <CA>>: (1,)})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.const_ssa_val_sub_regs),
+            "FMap({"
+            "<vl.outputs[0]: <VL_MAXVL>>[0]: 32, "
+            "<li.outputs[0]: <I64*32>>[0]: 0, "
+            "<li.outputs[0]: <I64*32>>[1]: 0, "
+            "<li.outputs[0]: <I64*32>>[2]: 0, "
+            "<li.outputs[0]: <I64*32>>[3]: 0, "
+            "<li.outputs[0]: <I64*32>>[4]: 0, "
+            "<li.outputs[0]: <I64*32>>[5]: 0, "
+            "<li.outputs[0]: <I64*32>>[6]: 0, "
+            "<li.outputs[0]: <I64*32>>[7]: 0, "
+            "<li.outputs[0]: <I64*32>>[8]: 0, "
+            "<li.outputs[0]: <I64*32>>[9]: 0, "
+            "<li.outputs[0]: <I64*32>>[10]: 0, "
+            "<li.outputs[0]: <I64*32>>[11]: 0, "
+            "<li.outputs[0]: <I64*32>>[12]: 0, "
+            "<li.outputs[0]: <I64*32>>[13]: 0, "
+            "<li.outputs[0]: <I64*32>>[14]: 0, "
+            "<li.outputs[0]: <I64*32>>[15]: 0, "
+            "<li.outputs[0]: <I64*32>>[16]: 0, "
+            "<li.outputs[0]: <I64*32>>[17]: 0, "
+            "<li.outputs[0]: <I64*32>>[18]: 0, "
+            "<li.outputs[0]: <I64*32>>[19]: 0, "
+            "<li.outputs[0]: <I64*32>>[20]: 0, "
+            "<li.outputs[0]: <I64*32>>[21]: 0, "
+            "<li.outputs[0]: <I64*32>>[22]: 0, "
+            "<li.outputs[0]: <I64*32>>[23]: 0, "
+            "<li.outputs[0]: <I64*32>>[24]: 0, "
+            "<li.outputs[0]: <I64*32>>[25]: 0, "
+            "<li.outputs[0]: <I64*32>>[26]: 0, "
+            "<li.outputs[0]: <I64*32>>[27]: 0, "
+            "<li.outputs[0]: <I64*32>>[28]: 0, "
+            "<li.outputs[0]: <I64*32>>[29]: 0, "
+            "<li.outputs[0]: <I64*32>>[30]: 0, "
+            "<li.outputs[0]: <I64*32>>[31]: 0, "
+            "<ca.outputs[0]: <CA>>[0]: 1})"
+        )
 
     def test_repr(self):
         fn, _arg = self.make_add_fn()
@@ -962,6 +1010,36 @@ class TestCompilerIR(unittest.TestCase):
             "    <spread.outputs[1]: <I64>>, <spread.outputs[0]: <I64>>,\n"
             "    <vl.outputs[0]: <VL_MAXVL>>)"
         )
+        fn_analysis = FnAnalysis(fn)
+        self.assertEqual(
+            repr(fn_analysis.copies),
+            "FMap({"
+            "<spread.outputs[0]: <I64>>[0]: <li.outputs[0]: <I64*4>>[0], "
+            "<spread.outputs[1]: <I64>>[0]: <li.outputs[0]: <I64*4>>[1], "
+            "<spread.outputs[2]: <I64>>[0]: <li.outputs[0]: <I64*4>>[2], "
+            "<spread.outputs[3]: <I64>>[0]: <li.outputs[0]: <I64*4>>[3], "
+            "<concat.outputs[0]: <I64*4>>[0]: <li.outputs[0]: <I64*4>>[3], "
+            "<concat.outputs[0]: <I64*4>>[1]: <li.outputs[0]: <I64*4>>[2], "
+            "<concat.outputs[0]: <I64*4>>[2]: <li.outputs[0]: <I64*4>>[1], "
+            "<concat.outputs[0]: <I64*4>>[3]: <li.outputs[0]: <I64*4>>[0]})"
+        )
+        self.assertEqual(
+            repr(fn_analysis.const_ssa_val_sub_regs),
+            "FMap({"
+            "<vl.outputs[0]: <VL_MAXVL>>[0]: 4, "
+            "<li.outputs[0]: <I64*4>>[0]: 0, "
+            "<li.outputs[0]: <I64*4>>[1]: 0, "
+            "<li.outputs[0]: <I64*4>>[2]: 0, "
+            "<li.outputs[0]: <I64*4>>[3]: 0, "
+            "<spread.outputs[0]: <I64>>[0]: 0, "
+            "<spread.outputs[1]: <I64>>[0]: 0, "
+            "<spread.outputs[2]: <I64>>[0]: 0, "
+            "<spread.outputs[3]: <I64>>[0]: 0, "
+            "<concat.outputs[0]: <I64*4>>[0]: 0, "
+            "<concat.outputs[0]: <I64*4>>[1]: 0, "
+            "<concat.outputs[0]: <I64*4>>[2]: 0, "
+            "<concat.outputs[0]: <I64*4>>[3]: 0})"
+        )
 
 
 if __name__ == "__main__":
index e3a4ed5bb4fcfc28bbc65c5a687954a8d5fe80df..12a32f1b6ad28bcf2258e804cc0958f1507047a4 100644 (file)
@@ -311,7 +311,7 @@ class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
         return f"<range:{start}..{stop}>"
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAValSubReg(metaclass=InternedMeta):
     __slots__ = "ssa_val", "reg_idx"
@@ -323,6 +323,10 @@ class SSAValSubReg(metaclass=InternedMeta):
         self.ssa_val = ssa_val
         self.reg_idx = reg_idx
 
+    def __repr__(self):
+        # type: () -> str
+        return f"{self.ssa_val}[{self.reg_idx}]"
+
 
 @plain_data(frozen=True, eq=False, repr=False)
 @final
@@ -448,7 +452,7 @@ class FnAnalysis:
                 retval[SSAValSubReg(ssa_val, reg_idx)] = v
         return FMap(retval)
 
-    def are_always_equal(self, a, b):
+    def is_always_equal(self, a, b):
         # type: (SSAValSubReg, SSAValSubReg) -> bool
         """check if a and b are known to be always equal to each other.
         This means they can be allocated to the same location if other
@@ -2338,7 +2342,7 @@ class BaseSimState(metaclass=ABCMeta):
 
     @abstractmethod
     def __setitem__(self, ssa_val, value):
-        # type: (SSAVal, tuple[int, ...]) -> None
+        # type: (SSAVal, Iterable[int]) -> None
         ...
 
 
@@ -2387,7 +2391,8 @@ class PreRABaseSimState(BaseSimState):
         raise KeyError("SSAVal has no value set", ssa_val)
 
     def __setitem__(self, ssa_val, value):
-        # type: (SSAVal, tuple[int, ...]) -> None
+        # type: (SSAVal, Iterable[int]) -> None
+        value = tuple(map(int, value))
         if len(value) != ssa_val.ty.reg_len:
             raise ValueError("value has wrong len")
         self.ssa_vals[ssa_val] = value
@@ -2496,7 +2501,8 @@ class PostRASimState(BaseSimState):
         return tuple(retval)
 
     def __setitem__(self, ssa_val, value):
-        # type: (SSAVal, tuple[int, ...]) -> None
+        # type: (SSAVal, Iterable[int]) -> None
+        value = tuple(map(int, value))
         if len(value) != ssa_val.ty.reg_len:
             raise ValueError("value has wrong len")
         loc = self.ssa_val_to_loc_map[ssa_val]
index a99d84959a943951faee6323ac9eaf31d6ce7d39..73381e41493bb0d3c1f6f76168ba250a0554c914 100644 (file)
@@ -12,8 +12,8 @@ from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
-                                                  LocSet, ProgramRange, SSAVal,
-                                                  Ty)
+                                                  LocSet, Op, ProgramRange,
+                                                  SSAVal, SSAValSubReg, Ty)
 from bigint_presentation_code.type_util import final
 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
 
@@ -119,17 +119,42 @@ class MergedSSAVal(metaclass=InternedMeta):
         spread arguments are one of the things that can force two values to
         illegally overlap.
         """
-        # pick an arbitrary Loc, any Loc will do
-        loc = self.first_loc
-        ops = sorted(OSet(i.op for i in self.ssa_vals),
-                     key=self.fn_analysis.op_indexes.__getitem__)
-        vals = {}  # type: dict[Loc, tuple[SSAVal, int]]
+        ops = OSet()  # type: Iterable[Op]
+        for ssa_val in self.ssa_vals:
+            ops.add(ssa_val.op)
+            for use in self.fn_analysis.uses[ssa_val]:
+                ops.add(use.op)
+        ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
+        vals = {}  # type: dict[int, SSAValSubReg]
         for op in ops:
             for inp in op.input_vals:
-                pass
-                # FIXME: finish checking using FnAnalysis.are_always_equal
-                # also check that two different outputs of the same
-                # instruction aren't merged
+                try:
+                    ssa_val_offset = self.ssa_val_offsets[inp]
+                except KeyError:
+                    continue
+                for orig_reg in inp.ssa_val_sub_regs:
+                    reg_offset = ssa_val_offset + orig_reg.reg_idx
+                    replaced_reg = vals[reg_offset]
+                    if not self.fn_analysis.is_always_equal(
+                            orig_reg, replaced_reg):
+                        raise BadMergedSSAVal(
+                            f"attempting to merge values that aren't known to "
+                            f"be always equal: {orig_reg} != {replaced_reg}")
+            output_offsets = dict.fromkeys(range(
+                self.offset, self.offset + self.ty.reg_len))
+            for out in op.outputs:
+                try:
+                    ssa_val_offset = self.ssa_val_offsets[out]
+                except KeyError:
+                    continue
+                for reg in out.ssa_val_sub_regs:
+                    reg_offset = ssa_val_offset + reg.reg_idx
+                    try:
+                        del output_offsets[reg_offset]
+                    except KeyError:
+                        raise BadMergedSSAVal("attempted to merge two outputs "
+                                              "of the same instruction")
+                    vals[reg_offset] = reg
 
     @cached_property
     def __hash(self):