add Fn class rather than global for generating op ids
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 15 Oct 2022 00:13:26 +0000 (17:13 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 15 Oct 2022 00:13:26 +0000 (17:13 -0700)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/test_compiler_ir.py
src/bigint_presentation_code/test_register_allocator.py

index bfcf159c25e4fc2cced574fdf019977d0fcd102a..aa37fa48eaa8da590481956ee95a028e5d85e51e 100644 (file)
@@ -399,6 +399,21 @@ class EqualityConstraint:
             raise ValueError("can't constrain an empty list to be equal")
 
 
+@final
+class Fn:
+    __slots__ = "ops",
+
+    def __init__(self):
+        # type: () -> None
+        self.ops = []  # type: list[Op]
+
+    def __repr__(self, short=False):
+        if short:
+            return "<Fn>"
+        ops = ", ".join(op.__repr__(just_id=True) for op in self.ops)
+        return f"<Fn([{ops}])>"
+
+
 class _NotSet:
     """ helper for __repr__ for when fields aren't set """
 
@@ -411,7 +426,7 @@ _NOT_SET = _NotSet()
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 class Op(metaclass=ABCMeta):
-    __slots__ = ()
+    __slots__ = "id", "fn"
 
     @abstractmethod
     def inputs(self):
@@ -433,19 +448,11 @@ class Op(metaclass=ABCMeta):
         if False:
             yield ...
 
-    __NEXT_ID = 0
-
-    @cached_property
-    def id(self):
-        # type: () -> int
-        # use cached_property rather than done in init so id is usable even if
-        # init hasn't run
-        retval = Op.__NEXT_ID
-        Op.__NEXT_ID += 1
-        return retval
-
-    def __init__(self):
-        self.id  # initialize
+    def __init__(self, fn):
+        # type: (Fn) -> None
+        self.id = len(fn.ops)
+        fn.ops.append(self)
+        self.fn = fn
 
     @final
     def __repr__(self, just_id=False):
@@ -461,6 +468,8 @@ class Op(metaclass=ABCMeta):
                 if ((outputs is None or name in outputs)
                         and isinstance(v, SSAVal)):
                     v = v.__repr__(long=True)
+                elif isinstance(v, Fn):
+                    v = v.__repr__(short=True)
                 else:
                     v = repr(v)
                 fields_list.append(f"{name}={v}")
@@ -481,9 +490,9 @@ class OpLoadFromStackSlot(Op):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, src):
-        # type: (SSAVal[GPRRangeType]) -> None
-        super().__init__()
+    def __init__(self, fn, src):
+        # type: (Fn, SSAVal[GPRRangeType]) -> None
+        super().__init__(fn)
         self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
         self.src = src
 
@@ -501,9 +510,9 @@ class OpStoreToStackSlot(Op):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, src):
-        # type: (SSAVal[StackSlotType]) -> None
-        super().__init__()
+    def __init__(self, fn, src):
+        # type: (Fn, SSAVal[StackSlotType]) -> None
+        super().__init__(fn)
         self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
         self.src = src
 
@@ -524,9 +533,9 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, src, dest_ty=None):
-        # type: (SSAVal[_RegSrcType], _RegType | None) -> None
-        super().__init__()
+    def __init__(self, fn, src, dest_ty=None):
+        # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None
+        super().__init__(fn)
         if dest_ty is None:
             dest_ty = cast(_RegType, src.ty)
         if isinstance(src.ty, GPRRangeType) \
@@ -560,9 +569,9 @@ class OpConcat(Op):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, sources):
-        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
-        super().__init__()
+    def __init__(self, fn, sources):
+        # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None
+        super().__init__(fn)
         sources = tuple(sources)
         self.dest = SSAVal(self, "dest", GPRRangeType(
             sum(i.ty.length for i in sources)))
@@ -586,9 +595,9 @@ class OpSplit(Op):
         # type: () -> dict[str, SSAVal]
         return {i.arg_name: i for i in self.results}
 
-    def __init__(self, src, split_indexes):
-        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
-        super().__init__()
+    def __init__(self, fn, src, split_indexes):
+        # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None
+        super().__init__(fn)
         ranges = []  # type: list[GPRRangeType]
         last = 0
         for i in split_indexes:
@@ -620,9 +629,9 @@ class OpAddSubE(Op):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT, "CY_out": self.CY_out}
 
-    def __init__(self, RA, RB, CY_in, is_sub):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
-        super().__init__()
+    def __init__(self, fn, RA, RB, CY_in, is_sub):
+        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+        super().__init__(fn)
         if RA.ty != RB.ty:
             raise TypeError(f"source types must match: "
                             f"{RA} doesn't match {RB}")
@@ -652,9 +661,9 @@ class OpBigIntMulDiv(Op):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT, "RS": self.RS}
 
-    def __init__(self, RA, RB, RC, is_div):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
-        super().__init__()
+    def __init__(self, fn, RA, RB, RC, is_div):
+        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+        super().__init__(fn)
         self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
         self.RB = RB
@@ -697,9 +706,9 @@ class OpBigIntShift(Op):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, inp, sh, kind):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
-        super().__init__()
+    def __init__(self, fn, inp, sh, kind):
+        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+        super().__init__(fn)
         self.RT = SSAVal(self, "RT", inp.ty)
         self.inp = inp
         self.sh = sh
@@ -724,9 +733,9 @@ class OpLI(Op):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, value, length=1):
-        # type: (int, int) -> None
-        super().__init__()
+    def __init__(self, fn, value, length=1):
+        # type: (Fn, int, int) -> None
+        super().__init__(fn)
         self.out = SSAVal(self, "out", GPRRangeType(length))
         self.value = value
 
@@ -744,9 +753,9 @@ class OpClearCY(Op):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self):
-        # type: () -> None
-        super().__init__()
+    def __init__(self, fn):
+        # type: (Fn) -> None
+        super().__init__(fn)
         self.out = SSAVal(self, "out", CYType())
 
 
@@ -763,9 +772,9 @@ class OpLoad(Op):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, RA, offset, mem, length=1):
-        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
-        super().__init__()
+    def __init__(self, fn, RA, offset, mem, length=1):
+        # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+        super().__init__(fn)
         self.RT = SSAVal(self, "RT", GPRRangeType(length))
         self.RA = RA
         self.offset = offset
@@ -790,9 +799,9 @@ class OpStore(Op):
         # type: () -> dict[str, SSAVal]
         return {"mem_out": self.mem_out}
 
-    def __init__(self, RS, RA, offset, mem_in):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
-        super().__init__()
+    def __init__(self, fn, RS, RA, offset, mem_in):
+        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+        super().__init__(fn)
         self.RS = RS
         self.RA = RA
         self.offset = offset
@@ -813,9 +822,9 @@ class OpFuncArg(Op):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, ty):
-        # type: (FixedGPRRangeType) -> None
-        super().__init__()
+    def __init__(self, fn, ty):
+        # type: (Fn, FixedGPRRangeType) -> None
+        super().__init__(fn)
         self.out = SSAVal(self, "out", ty)
 
 
@@ -832,9 +841,9 @@ class OpInputMem(Op):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self):
-        # type: () -> None
-        super().__init__()
+    def __init__(self, fn):
+        # type: (Fn) -> None
+        super().__init__(fn)
         self.out = SSAVal(self, "out", GlobalMemType())
 
 
index 1f30547806c5203a9ab8d2855d6c70122eeaca22..ff52641fa273a654f80f02865b41517c419b1f33 100644 (file)
@@ -1,6 +1,6 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType,
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType,
                                                   Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore,
                                                   op_set_to_list)
 
@@ -9,35 +9,25 @@ class TestCompilerIR(unittest.TestCase):
     maxDiff = None
 
     def test_op_set_to_list(self):
-        ops = []  # type: list[Op]
-        op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
-        ops.append(op0)
-        op1 = OpCopy(op0.out, GPRType())
-        ops.append(op1)
+        fn = Fn()
+        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+        op1 = OpCopy(fn, op0.out, GPRType())
         arg = op1.dest
-        op2 = OpInputMem()
-        ops.append(op2)
+        op2 = OpInputMem(fn)
         mem = op2.out
-        op3 = OpLoad(arg, offset=0, mem=mem, length=32)
-        ops.append(op3)
+        op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
         a = op3.RT
-        op4 = OpLI(1)
-        ops.append(op4)
+        op4 = OpLI(fn, 1)
         b_0 = op4.out
-        op5 = OpLI(0, length=31)
-        ops.append(op5)
+        op5 = OpLI(fn, 0, length=31)
         b_rest = op5.out
-        op6 = OpConcat([b_0, b_rest])
-        ops.append(op6)
+        op6 = OpConcat(fn, [b_0, b_rest])
         b = op6.dest
-        op7 = OpClearCY()
-        ops.append(op7)
+        op7 = OpClearCY(fn)
         cy = op7.out
-        op8 = OpAddSubE(a, b, cy, is_sub=False)
-        ops.append(op8)
+        op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
         s = op8.RT
-        op9 = OpStore(s, arg, offset=0, mem_in=mem)
-        ops.append(op9)
+        op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
         mem = op9.mem_out
 
         expected_ops = [
@@ -53,7 +43,7 @@ class TestCompilerIR(unittest.TestCase):
             op9,  # OpStore(s, arg, offset=0, mem_in=mem)
         ]
 
-        ops = op_set_to_list(reversed(ops))
+        ops = op_set_to_list(fn.ops[::-1])
         if ops != expected_ops:
             self.assertEqual(repr(ops), repr(expected_ops))
 
index 8ba74ebbf87d00439a9e2077eb3d83f4e63e9cbf..bdc193889f8160ae17758be0d25cc2eea37ecf35 100644 (file)
@@ -1,6 +1,6 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange,
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange,
                                                   GPRType, GlobalMem, Op, OpAddSubE,
                                                   OpClearCY, OpConcat, OpCopy,
                                                   OpFuncArg, OpInputMem, OpLI,
@@ -14,9 +14,10 @@ class TestMergedRegSet(unittest.TestCase):
     maxDiff = None
 
     def test_from_equality_constraint(self):
-        op0 = OpLI(0, length=1)
-        op1 = OpLI(0, length=2)
-        op2 = OpLI(0, length=3)
+        fn = Fn()
+        op0 = OpLI(fn, 0, length=1)
+        op1 = OpLI(fn, 0, length=2)
+        op2 = OpLI(fn, 0, length=3)
         self.assertEqual(MergedRegSet.from_equality_constraint([
             op0.out,
             op1.out,
@@ -41,15 +42,12 @@ class TestRegisterAllocator(unittest.TestCase):
     maxDiff = None
 
     def test_try_alloc_fail(self):
-        ops = []  # type: list[Op]
-        op0 = OpLI(0, length=52)
-        ops.append(op0)
-        op1 = OpLI(0, length=64)
-        ops.append(op1)
-        op2 = OpConcat([op0.out, op1.out])
-        ops.append(op2)
-
-        reg_assignments = try_allocate_registers_without_spilling(ops)
+        fn = Fn()
+        op0 = OpLI(fn, 0, length=52)
+        op1 = OpLI(fn, 0, length=64)
+        op2 = OpConcat(fn, [op0.out, op1.out])
+
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
         self.assertEqual(
             repr(reg_assignments),
             "AllocationFailed("
@@ -81,38 +79,28 @@ class TestRegisterAllocator(unittest.TestCase):
         )
 
     def test_try_alloc_bigint_inc(self):
-        ops = []  # type: list[Op]
-        op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
-        ops.append(op0)
-        op1 = OpCopy(op0.out, GPRType())
-        ops.append(op1)
+        fn = Fn()
+        op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+        op1 = OpCopy(fn, op0.out, GPRType())
         arg = op1.dest
-        op2 = OpInputMem()
-        ops.append(op2)
+        op2 = OpInputMem(fn)
         mem = op2.out
-        op3 = OpLoad(arg, offset=0, mem=mem, length=32)
-        ops.append(op3)
+        op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
         a = op3.RT
-        op4 = OpLI(1)
-        ops.append(op4)
+        op4 = OpLI(fn, 1)
         b_0 = op4.out
-        op5 = OpLI(0, length=31)
-        ops.append(op5)
+        op5 = OpLI(fn, 0, length=31)
         b_rest = op5.out
-        op6 = OpConcat([b_0, b_rest])
-        ops.append(op6)
+        op6 = OpConcat(fn, [b_0, b_rest])
         b = op6.dest
-        op7 = OpClearCY()
-        ops.append(op7)
+        op7 = OpClearCY(fn)
         cy = op7.out
-        op8 = OpAddSubE(a, b, cy, is_sub=False)
-        ops.append(op8)
+        op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
         s = op8.RT
-        op9 = OpStore(s, arg, offset=0, mem_in=mem)
-        ops.append(op9)
+        op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
         mem = op9.mem_out
 
-        reg_assignments = try_allocate_registers_without_spilling(ops)
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
 
         expected_reg_assignments = {
             op0.out: GPRRange(start=3, length=1),
@@ -132,12 +120,11 @@ class TestRegisterAllocator(unittest.TestCase):
 
     def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
         # type: (list[GPRRange], GPRRange) -> None
-        li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)]
-        ops = [*li_ops]  # type: list[Op]
-        concat = OpConcat([i.out for i in li_ops])
-        ops.append(concat)
+        fn = Fn()
+        li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)]
+        concat = OpConcat(fn, [i.out for i in li_ops])
 
-        reg_assignments = try_allocate_registers_without_spilling(ops)
+        reg_assignments = try_allocate_registers_without_spilling(fn.ops)
 
         expected_reg_assignments = {concat.dest: expected_dest_reg}
         for li_op, reg in zip(li_ops, expected_regs):