test_op_set_to_list works
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 08:05:57 +0000 (01:05 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 08:05:57 +0000 (01:05 -0700)
setup.py
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/test_compiler_ir.py

index 36d91f44a9c231cf00429a520d481affb91caea3..8db62c9b410c7925078fb794c89ae4b08642f777 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -5,9 +5,19 @@ README = Path(__file__).with_name('README.md').read_text("UTF-8")
 
 version = '0.0.1'
 
+cprop = "git+https://git.libre-soc.org/git/cached-property.git@1.5.2" \
+        "#egg=cached-property-1.5.2"
+
 install_requires = [
     "libresoc-nmutil",
     'libresoc-openpower-isa',
+    # git url needed for having `pip3 install -e .` install from libre-soc git
+    'cached-property@'+cprop,
+]
+
+# git url needed for having `setup.py develop` install from libre-soc git
+dependency_links = [
+    cprop,
 ]
 
 setup(
@@ -30,4 +40,5 @@ setup(
     include_package_data=True,
     zip_safe=False,
     install_requires=install_requires,
+    dependency_links=dependency_links,
 )
index f4f09f7cb613806c63e233b66c7bab0ecc252f5c..24b86494e37c71f10c9491e697efb5f0823e37b8 100644 (file)
@@ -7,9 +7,10 @@ from collections import defaultdict
 from enum import Enum, EnumMeta, unique
 from functools import lru_cache
 from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
-                    TypeVar)
+                    TypeVar, cast)
 
-from nmutil.plain_data import plain_data
+from cached_property import cached_property
+from nmutil.plain_data import fields, plain_data
 
 if TYPE_CHECKING:
     from typing_extensions import final
@@ -196,7 +197,7 @@ class RegType(metaclass=ABCMeta):
         return ...
 
 
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+_RegType = TypeVar("_RegType", bound=RegType)
 
 
 @plain_data(frozen=True, eq=False)
@@ -359,13 +360,13 @@ class StackSlotType(RegType):
         return hash(self.length_in_slots)
 
 
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
 @final
-class SSAVal(Generic[_RegT_co]):
-    __slots__ = "op", "arg_name", "ty", "arg_index"
+class SSAVal(Generic[_RegType]):
+    __slots__ = "op", "arg_name", "ty",
 
     def __init__(self, op, arg_name, ty):
-        # type: (Op, str, _RegT_co) -> None
+        # type: (Op, str, _RegType) -> None
         self.op = op
         """the Op that writes this SSAVal"""
 
@@ -383,6 +384,19 @@ class SSAVal(Generic[_RegT_co]):
     def __hash__(self):
         return hash((id(self.op), self.arg_name))
 
+    def __repr__(self):
+        fields_list = []
+        for name in fields(self):
+            v = getattr(self, name, None)
+            if v is not None:
+                if name == "op":
+                    v = v.__repr__(just_id=True)
+                else:
+                    v = repr(v)
+                fields_list.append(f"{name}={v}")
+        fields_str = ", ".join(fields_list)
+        return f"SSAVal({fields_str})"
+
 
 @final
 @plain_data(unsafe_hash=True, frozen=True)
@@ -397,7 +411,17 @@ class EqualityConstraint:
             raise ValueError("can't constrain an empty list to be equal")
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+class _NotSet:
+    """ helper for __repr__ for when fields aren't set """
+
+    def __repr__(self):
+        return "<not set>"
+
+
+_NOT_SET = _NotSet()
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 class Op(metaclass=ABCMeta):
     __slots__ = ()
 
@@ -421,11 +445,26 @@ class Op(metaclass=ABCMeta):
         if False:
             yield ...
 
-    def __init__(self):
-        pass
+    __NEXT_ID = 0
 
+    @cached_property
+    def id(self):
+        retval = Op.__NEXT_ID
+        Op.__NEXT_ID += 1
+        return retval
 
-@plain_data(unsafe_hash=True, frozen=True)
+    @final
+    def __repr__(self, just_id=False):
+        fields_list = [f"#{self.id}"]
+        if not just_id:
+            for name in fields(self):
+                v = getattr(self, name, _NOT_SET)
+                fields_list.append(f"{name}={v!r}")
+        fields_str = ', '.join(fields_list)
+        return f"{self.__class__.__name__}({fields_str})"
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLoadFromStackSlot(Op):
     __slots__ = "dest", "src"
@@ -444,7 +483,7 @@ class OpLoadFromStackSlot(Op):
         self.src = src
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpStoreToStackSlot(Op):
     __slots__ = "dest", "src"
@@ -463,9 +502,12 @@ class OpStoreToStackSlot(Op):
         self.src = src
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
-class OpCopy(Op, Generic[_RegT_co]):
+class OpCopy(Op, Generic[_RegSrcType, _RegType]):
     __slots__ = "dest", "src"
 
     def inputs(self):
@@ -477,9 +519,9 @@ class OpCopy(Op, Generic[_RegT_co]):
         return {"dest": self.dest}
 
     def __init__(self, src, dest_ty=None):
-        # type: (SSAVal[_RegT_co], _RegT_co | None) -> None
+        # type: (SSAVal[_RegSrcType], _RegType | None) -> None
         if dest_ty is None:
-            dest_ty = src.ty
+            dest_ty = cast(_RegType, src.ty)
         if isinstance(src.ty, GPRRangeType) \
                 and isinstance(dest_ty, GPRRangeType):
             if src.ty.length != dest_ty.length:
@@ -489,11 +531,11 @@ class OpCopy(Op, Generic[_RegT_co]):
             raise ValueError(f"incompatible source and destination "
                              f"types: {src.ty} and {dest_ty}")
 
-        self.dest = SSAVal(self, "dest", dest_ty)
+        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
         self.src = src
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpConcat(Op):
     __slots__ = "dest", "sources"
@@ -518,7 +560,7 @@ class OpConcat(Op):
         yield EqualityConstraint([self.dest], [*self.sources])
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpSplit(Op):
     __slots__ = "results", "src"
@@ -551,7 +593,7 @@ class OpSplit(Op):
         yield EqualityConstraint([*self.results], [self.src])
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpAddSubE(Op):
     __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
@@ -582,7 +624,7 @@ class OpAddSubE(Op):
         yield self.RT, self.RB
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpBigIntMulDiv(Op):
     __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
@@ -626,7 +668,7 @@ class ShiftKind(Enum):
     Sra = "sra"
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpBigIntShift(Op):
     __slots__ = "RT", "inp", "sh", "kind"
@@ -652,7 +694,7 @@ class OpBigIntShift(Op):
         yield self.RT, self.sh
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLI(Op):
     __slots__ = "out", "value"
@@ -671,7 +713,7 @@ class OpLI(Op):
         self.value = value
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpClearCY(Op):
     __slots__ = "out",
@@ -689,7 +731,7 @@ class OpClearCY(Op):
         self.out = SSAVal(self, "out", CYType())
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLoad(Op):
     __slots__ = "RT", "RA", "offset", "mem"
@@ -715,7 +757,7 @@ class OpLoad(Op):
             yield self.RT, self.RA
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpStore(Op):
     __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
@@ -737,7 +779,7 @@ class OpStore(Op):
         self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpFuncArg(Op):
     __slots__ = "out",
@@ -755,7 +797,7 @@ class OpFuncArg(Op):
         self.out = SSAVal(self, "out", ty)
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpInputMem(Op):
     __slots__ = "out",
@@ -775,33 +817,34 @@ class OpInputMem(Op):
 
 def op_set_to_list(ops):
     # type: (Iterable[Op]) -> list[Op]
-    worklists = [set()]  # type: list[set[Op]]
-    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
+    worklists = [{}]  # type: list[dict[Op, None]]
+    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
     ops_to_pending_input_count_map = {}  # type: dict[Op, int]
     for op in ops:
         input_count = 0
         for val in op.inputs().values():
             input_count += 1
-            input_vals_to_ops_map[val].add(op)
+            inps_to_ops_map[val][op] = None
         while len(worklists) <= input_count:
-            worklists.append(set())
+            worklists.append({})
         ops_to_pending_input_count_map[op] = input_count
-        worklists[input_count].add(op)
+        worklists[input_count][op] = None
     retval = []  # type: list[Op]
     ready_vals = set()  # type: set[SSAVal]
     while len(worklists[0]) != 0:
-        writing_op = worklists[0].pop()
+        writing_op = next(iter(worklists[0]))
+        del worklists[0][writing_op]
         retval.append(writing_op)
         for val in writing_op.outputs().values():
             if val in ready_vals:
                 raise ValueError(f"multiple instructions must not write "
                                  f"to the same SSA value: {val}")
             ready_vals.add(val)
-            for reading_op in input_vals_to_ops_map[val]:
+            for reading_op in inps_to_ops_map[val]:
                 pending = ops_to_pending_input_count_map[reading_op]
-                worklists[pending].remove(reading_op)
+                del worklists[pending][reading_op]
                 pending -= 1
-                worklists[pending].add(reading_op)
+                worklists[pending][reading_op] = None
                 ops_to_pending_input_count_map[reading_op] = pending
     for worklist in worklists:
         for op in worklist:
index 297e3e5391c9e1a8821430e8a1bd789907ad30c3..75b342243fa95cd7223647b6204f0d33d7e36cf0 100644 (file)
@@ -20,7 +20,7 @@ else:
         return v
 
 
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+_RegType = TypeVar("_RegType", bound=RegType)
 
 
 @plain_data(unsafe_hash=True, order=True, frozen=True)
@@ -59,10 +59,10 @@ class LiveInterval:
 
 
 @final
-class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
+class MergedRegSet(Mapping[SSAVal[_RegType], int]):
     def __init__(self, reg_set):
-        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
-        self.__items = {}  # type: dict[SSAVal[_RegT_co], int]
+        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
+        self.__items = {}  # type: dict[SSAVal[_RegType], int]
         if isinstance(reg_set, SSAVal):
             reg_set = [(reg_set, 0)]
         for ssa_val, offset in reg_set:
@@ -107,7 +107,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
 
     @staticmethod
     def from_equality_constraint(constraint_sequence):
-        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
+        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
         if len(constraint_sequence) == 1:
             # any type allowed with len = 1
             return MergedRegSet(constraint_sequence[0])
@@ -138,22 +138,22 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
         return range(self.__start, self.__stop)
 
     def offset_by(self, amount):
-        # type: (int) -> MergedRegSet[_RegT_co]
+        # type: (int) -> MergedRegSet[_RegType]
         return MergedRegSet((k, v + amount) for k, v in self.items())
 
     def normalized(self):
-        # type: () -> MergedRegSet[_RegT_co]
+        # type: () -> MergedRegSet[_RegType]
         return self.offset_by(-self.start)
 
     def with_offset_to_match(self, target):
-        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
+        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
         for ssa_val, offset in self.items():
             if ssa_val in target:
                 return self.offset_by(target[ssa_val] - offset)
         raise ValueError("can't change offset to match unrelated MergedRegSet")
 
     def __getitem__(self, item):
-        # type: (SSAVal[_RegT_co]) -> int
+        # type: (SSAVal[_RegType]) -> int
         return self.__items[item]
 
     def __iter__(self):
@@ -170,10 +170,10 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
 
 
 @final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
     def __init__(self, ops):
         # type: (Iterable[Op]) -> None
-        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegT_co]]
+        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
         for op in ops:
             for val in (*op.inputs().values(), *op.outputs().values()):
                 if val not in merged_sets:
@@ -204,11 +204,11 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
 
 
 @final
-class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
     def __init__(self, ops):
         # type: (list[Op]) -> None
         self.__merged_reg_sets = MergedRegSets(ops)
-        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
+        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
         for op_idx, op in enumerate(ops):
             for val in op.inputs().values():
                 live_intervals[self.__merged_reg_sets[val]] += op_idx
@@ -219,7 +219,7 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
                 else:
                     live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
-        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
+        live_after = []  # type: list[set[MergedRegSet[_RegType]]]
         live_after += (set() for _ in ops)
         for reg_set, live_interval in self.__live_intervals.items():
             for i in live_interval.live_after_op_range:
@@ -231,14 +231,14 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
         return self.__merged_reg_sets
 
     def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
+        # type: (MergedRegSet[_RegType]) -> LiveInterval
         return self.__live_intervals[key]
 
     def __iter__(self):
         return iter(self.__live_intervals)
 
     def reg_sets_live_after(self, op_index):
-        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+        # type: (int) -> frozenset[MergedRegSet[_RegType]]
         return self.__live_after[op_index]
 
     def __repr__(self):
@@ -249,12 +249,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
 
 
 @final
-class IGNode(Generic[_RegT_co]):
+class IGNode(Generic[_RegType]):
     """ interference graph node """
     __slots__ = "merged_reg_set", "edges", "reg"
 
     def __init__(self, merged_reg_set, edges=(), reg=None):
-        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
+        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
         self.merged_reg_set = merged_reg_set
         self.edges = set(edges)
         self.reg = reg
@@ -300,13 +300,13 @@ class IGNode(Generic[_RegT_co]):
 
 
 @final
-class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
+class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
     def __init__(self, merged_reg_sets):
-        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
+        # type: (Iterable[MergedRegSet[_RegType]]) -> None
         self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
 
     def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> IGNode
+        # type: (MergedRegSet[_RegType]) -> IGNode
         return self.__nodes[key]
 
     def __iter__(self):
index 231f2fb93182d9cd59ac0410abb1426c92e8a08a..26d5272a040fb178d3cd92c5a33e0c0b6bed0430 100644 (file)
@@ -1,10 +1,61 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import Op, op_set_to_list
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType,
+                                                  Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore,
+                                                  op_set_to_list)
 
 
 class TestCompilerIR(unittest.TestCase):
-    pass  # no tests yet, just testing importing
+    maxDiff = None
+
+    def test_op_set_to_list(self):
+        ops = []  # list[Op]
+        op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+        ops.append(op0)
+        op1 = OpCopy(op0.out, GPRType())
+        ops.append(op1)
+        arg = op1.dest
+        op2 = OpInputMem()
+        ops.append(op2)
+        mem = op2.out
+        op3 = OpLoad(arg, offset=0, mem=mem, length=32)
+        ops.append(op3)
+        a = op3.RT
+        op4 = OpLI(1)
+        ops.append(op4)
+        b_0 = op4.out
+        op5 = OpLI(0, length=31)
+        ops.append(op5)
+        b_rest = op5.out
+        op6 = OpConcat([b_0, b_rest])
+        ops.append(op6)
+        b = op6.dest
+        op7 = OpClearCY()
+        ops.append(op7)
+        cy = op7.out
+        op8 = OpAddSubE(a, b, cy, is_sub=False)
+        ops.append(op8)
+        s = op8.RT
+        op9 = OpStore(s, arg, offset=0, mem_in=mem)
+        ops.append(op9)
+        mem = op9.mem_out
+
+        expected_ops = [
+            op7,  # OpClearCY()
+            op5,  # OpLI(0, length=31)
+            op4,  # OpLI(1)
+            op2,  # OpInputMem()
+            op0,  # OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+            op6,  # OpConcat([b_0, b_rest])
+            op1,  # OpCopy(op0.out, GPRType())
+            op3,  # OpLoad(arg, offset=0, mem=mem, length=32)
+            op8,  # OpAddSubE(a, b, cy, is_sub=False)
+            op9,  # OpStore(s, arg, offset=0, mem_in=mem)
+        ]
+
+        ops = op_set_to_list(reversed(ops))
+        if ops != expected_ops:
+            self.assertEqual(repr(ops), repr(expected_ops))
 
 
 if __name__ == "__main__":