From 9cb1a00725d0d357407d59b2d65911cb566a072c Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 1 Dec 2022 21:39:02 -0800 Subject: [PATCH] finished __mergable_check --- .../_tests/test_compiler_ir.py | 78 +++++++++++++++++++ src/bigint_presentation_code/compiler_ir.py | 16 ++-- .../register_allocator.py | 47 ++++++++--- 3 files changed, 125 insertions(+), 16 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py index 60d8b44..75e8349 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -196,6 +196,54 @@ class TestCompilerIR(unittest.TestCase): self.assertEqual( repr(fn_analysis.all_program_points), "") + self.assertEqual(repr(fn_analysis.copies), "FMap({})") + self.assertEqual( + repr(fn_analysis.const_ssa_vals), + "FMap({" + ">: (32,), " + ">: (" + "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, " + "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), " + ">: (1,)})" + ) + self.assertEqual( + repr(fn_analysis.const_ssa_val_sub_regs), + "FMap({" + ">[0]: 32, " + ">[0]: 0, " + ">[1]: 0, " + ">[2]: 0, " + ">[3]: 0, " + ">[4]: 0, " + ">[5]: 0, " + ">[6]: 0, " + ">[7]: 0, " + ">[8]: 0, " + ">[9]: 0, " + ">[10]: 0, " + ">[11]: 0, " + ">[12]: 0, " + ">[13]: 0, " + ">[14]: 0, " + ">[15]: 0, " + ">[16]: 0, " + ">[17]: 0, " + ">[18]: 0, " + ">[19]: 0, " + ">[20]: 0, " + ">[21]: 0, " + ">[22]: 0, " + ">[23]: 0, " + ">[24]: 0, " + ">[25]: 0, " + ">[26]: 0, " + ">[27]: 0, " + ">[28]: 0, " + ">[29]: 0, " + ">[30]: 0, " + ">[31]: 0, " + ">[0]: 1})" + ) def test_repr(self): fn, _arg = self.make_add_fn() @@ -962,6 +1010,36 @@ class TestCompilerIR(unittest.TestCase): " >, >,\n" " >)" ) + fn_analysis = FnAnalysis(fn) + self.assertEqual( + repr(fn_analysis.copies), + "FMap({" + ">[0]: >[0], " + ">[0]: >[1], " + ">[0]: >[2], " + ">[0]: >[3], " + ">[0]: >[3], " + ">[1]: >[2], " + ">[2]: >[1], " + ">[3]: >[0]})" + ) + self.assertEqual( + repr(fn_analysis.const_ssa_val_sub_regs), + "FMap({" + ">[0]: 4, " + ">[0]: 0, " + ">[1]: 0, " + ">[2]: 0, " + ">[3]: 0, " + ">[0]: 0, " + ">[0]: 0, " + ">[0]: 0, " + ">[0]: 0, " + ">[0]: 0, " + ">[1]: 0, " + ">[2]: 0, " + ">[3]: 0})" + ) if __name__ == "__main__": diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index e3a4ed5..12a32f1 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -311,7 +311,7 @@ class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta): return f"" -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(frozen=True, unsafe_hash=True, repr=False) @final class SSAValSubReg(metaclass=InternedMeta): __slots__ = "ssa_val", "reg_idx" @@ -323,6 +323,10 @@ class SSAValSubReg(metaclass=InternedMeta): self.ssa_val = ssa_val self.reg_idx = reg_idx + def __repr__(self): + # type: () -> str + return f"{self.ssa_val}[{self.reg_idx}]" + @plain_data(frozen=True, eq=False, repr=False) @final @@ -448,7 +452,7 @@ class FnAnalysis: retval[SSAValSubReg(ssa_val, reg_idx)] = v return FMap(retval) - def are_always_equal(self, a, b): + def is_always_equal(self, a, b): # type: (SSAValSubReg, SSAValSubReg) -> bool """check if a and b are known to be always equal to each other. This means they can be allocated to the same location if other @@ -2338,7 +2342,7 @@ class BaseSimState(metaclass=ABCMeta): @abstractmethod def __setitem__(self, ssa_val, value): - # type: (SSAVal, tuple[int, ...]) -> None + # type: (SSAVal, Iterable[int]) -> None ... @@ -2387,7 +2391,8 @@ class PreRABaseSimState(BaseSimState): raise KeyError("SSAVal has no value set", ssa_val) def __setitem__(self, ssa_val, value): - # type: (SSAVal, tuple[int, ...]) -> None + # type: (SSAVal, Iterable[int]) -> None + value = tuple(map(int, value)) if len(value) != ssa_val.ty.reg_len: raise ValueError("value has wrong len") self.ssa_vals[ssa_val] = value @@ -2496,7 +2501,8 @@ class PostRASimState(BaseSimState): return tuple(retval) def __setitem__(self, ssa_val, value): - # type: (SSAVal, tuple[int, ...]) -> None + # type: (SSAVal, Iterable[int]) -> None + value = tuple(map(int, value)) if len(value) != ssa_val.ty.reg_len: raise ValueError("value has wrong len") loc = self.ssa_val_to_loc_map[ssa_val] diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index a99d849..73381e4 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -12,8 +12,8 @@ from cached_property import cached_property from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc, - LocSet, ProgramRange, SSAVal, - Ty) + LocSet, Op, ProgramRange, + SSAVal, SSAValSubReg, Ty) from bigint_presentation_code.type_util import final from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet @@ -119,17 +119,42 @@ class MergedSSAVal(metaclass=InternedMeta): spread arguments are one of the things that can force two values to illegally overlap. """ - # pick an arbitrary Loc, any Loc will do - loc = self.first_loc - ops = sorted(OSet(i.op for i in self.ssa_vals), - key=self.fn_analysis.op_indexes.__getitem__) - vals = {} # type: dict[Loc, tuple[SSAVal, int]] + ops = OSet() # type: Iterable[Op] + for ssa_val in self.ssa_vals: + ops.add(ssa_val.op) + for use in self.fn_analysis.uses[ssa_val]: + ops.add(use.op) + ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__) + vals = {} # type: dict[int, SSAValSubReg] for op in ops: for inp in op.input_vals: - pass - # FIXME: finish checking using FnAnalysis.are_always_equal - # also check that two different outputs of the same - # instruction aren't merged + try: + ssa_val_offset = self.ssa_val_offsets[inp] + except KeyError: + continue + for orig_reg in inp.ssa_val_sub_regs: + reg_offset = ssa_val_offset + orig_reg.reg_idx + replaced_reg = vals[reg_offset] + if not self.fn_analysis.is_always_equal( + orig_reg, replaced_reg): + raise BadMergedSSAVal( + f"attempting to merge values that aren't known to " + f"be always equal: {orig_reg} != {replaced_reg}") + output_offsets = dict.fromkeys(range( + self.offset, self.offset + self.ty.reg_len)) + for out in op.outputs: + try: + ssa_val_offset = self.ssa_val_offsets[out] + except KeyError: + continue + for reg in out.ssa_val_sub_regs: + reg_offset = ssa_val_offset + reg.reg_idx + try: + del output_offsets[reg_offset] + except KeyError: + raise BadMergedSSAVal("attempted to merge two outputs " + "of the same instruction") + vals[reg_offset] = reg @cached_property def __hash(self): -- 2.30.2