change SSAVal to link to defining Op
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 10 Oct 2022 23:35:44 +0000 (16:35 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 10 Oct 2022 23:35:44 +0000 (16:35 -0700)
src/bigint_presentation_code/toom_cook.py

index 16ef5d0db9ae5bb26c711577647e11eb16a0686b..6a389c1d17832d887c45046767322f5ed589c645 100644 (file)
@@ -7,7 +7,7 @@ from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING
 from nmutil.plain_data import plain_data
 
 if TYPE_CHECKING:
-    from typing_extensions import final
+    from typing_extensions import final, Self
 else:
     def final(v):
         return v
@@ -83,23 +83,28 @@ class StackSlot(GPROrStackLoc):
         self.offset = offset
 
 
-@plain_data(eq=False)
 class SSAVal(metaclass=ABCMeta):
-    __slots__ = "id",
+    __slots__ = "op", "arg_name", "element_index"
 
-    def __init__(self, id=None):
-        # type: (int | None) -> None
-        if id is None:
-            id = builtins.id(self)
-        self.id = id
+    def __init__(self, op, arg_name, element_index):
+        # type: (Op, str, int) -> None
+        self.op = op
+        """the Op that writes this SSAVal"""
+
+        self.arg_name = arg_name
+        self.element_index = element_index
 
+    @final
     def __eq__(self, rhs):
         if isinstance(rhs, SSAVal):
-            return self.id == rhs.id
+            return (self.op is rhs.op
+                    and self.arg_name == rhs.arg_name
+                    and self.element_index == rhs.element_index)
         return False
 
+    @final
     def __hash__(self):
-        return hash(self.id)
+        return hash((id(self.op), self.arg_name, self.element_index))
 
     def _get_phys_loc(self, phys_loc_in, value_assignments=None):
         # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
@@ -114,16 +119,32 @@ class SSAVal(metaclass=ABCMeta):
         # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
         ...
 
+    @final
+    def __repr__(self):
+        name = self.__class__.__name__
+        op = object.__repr__(self.op)
+        phys_loc = self.get_phys_loc()
+        return (f"{name}(op={op}, arg_name={self.arg_name}, "
+                f"element_index={self.element_index}, phys_loc={phys_loc})")
+
+    @final
+    def like(self, op, arg_name):
+        # type: (Op, str) -> Self
+        """create a new SSAVal based off of self's type.
+        has same signature as VecArg.like.
+        """
+        return self.__class__(op=op, arg_name=arg_name,
+                              element_index=0)
+
 
-@plain_data(eq=False)
 @final
 class SSAGPRVal(SSAVal):
     __slots__ = "phys_loc",
 
-    def __init__(self, phys_loc=None):
-        # type: (GPROrStackLoc | None) -> None
+    def __init__(self, op, arg_name, element_index, phys_loc=None):
+        # type: (Op, str, int, GPROrStackLoc | None) -> None
+        super().__init__(op, arg_name, element_index)
         self.phys_loc = phys_loc
-        super().__init__()
 
     def __len__(self):
         return 1
@@ -166,13 +187,13 @@ class SSAGPRVal(SSAVal):
                 yield reg
 
 
-@plain_data(eq=False)
 @final
 class SSAXERBitVal(SSAVal):
     __slots__ = "phys_loc",
 
-    def __init__(self, phys_loc=None):
-        # type: (XERBit | None) -> None
+    def __init__(self, op, arg_name, element_index, phys_loc=None):
+        # type: (Op, str, int, XERBit | None) -> None
+        super().__init__(op, arg_name, element_index)
         self.phys_loc = phys_loc
 
     def get_phys_loc(self, value_assignments=None):
@@ -183,21 +204,22 @@ class SSAXERBitVal(SSAVal):
         return None
 
 
-@plain_data(eq=False)
 @final
 class SSAMemory(SSAVal):
     __slots__ = "phys_loc",
 
-    def __init__(self, phys_loc=GlobalMem.GlobalMem):
-        # type: (GlobalMem) -> None
+    def __init__(self, op, arg_name, element_index,
+                 phys_loc=GlobalMem.GlobalMem):
+        # type: (Op, str, int, GlobalMem) -> None
+        super().__init__(op, arg_name, element_index)
         self.phys_loc = phys_loc
 
     def get_phys_loc(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None
+        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem
         loc = self._get_phys_loc(self.phys_loc, value_assignments)
         if isinstance(loc, GlobalMem):
             return loc
-        return None
+        return self.phys_loc
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -272,6 +294,22 @@ class VecArg:
         else:
             yield GPR(r[index])
 
+    def like(self, op, arg_name):
+        # type: (Op, str) -> VecArg
+        """create a new VecArg based off of self's type.
+        has same signature as SSAVal.like.
+        """
+        return VecArg(
+            SSAGPRVal(op, arg_name, i) for i in range(len(self.regs)))
+
+
+def vec_or_scalar_arg(element_count, op, arg_name):
+    # type: (int | None, Op, str) -> VecArg | SSAGPRVal
+    if element_count is None:
+        return SSAGPRVal(op, arg_name, 0)
+    else:
+        return VecArg(SSAGPRVal(op, arg_name, i) for i in range(element_count))
+
 
 @final
 @plain_data(unsafe_hash=True, frozen=True)
@@ -341,16 +379,9 @@ class OpCopy(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, dest, src):
-        # type: (VecArg | SSAVal, VecArg | SSAVal) -> None
-        if isinstance(dest, VecArg) and isinstance(src, VecArg):
-            if len(src.regs) != len(dest.regs):
-                raise TypeError(f"source length must match dest "
-                                f"length: {src} doesn't match {dest}")
-        elif type(dest) != type(src):
-            raise TypeError(f"source argument type must match dest "
-                            f"argument type: {src} doesn't match {dest}")
-        self.dest = dest
+    def __init__(self, src):
+        # type: (VecArg | SSAVal) -> None
+        self.dest = src.like(op=self, arg_name="dest")
         self.src = src
 
     def possible_reg_assignments(self, val, value_assignments):
@@ -406,19 +437,16 @@ class OpAddSubE(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT, "CY_out": self.CY_out}
 
-    def __init__(self, RT, RA, RB, CY_in, CY_out, is_sub):
-        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None
-        if len(RA.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {RA} doesn't match {RT}")
-        if len(RB.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {RB} doesn't match {RT}")
-        self.RT = RT
+    def __init__(self, RA, RB, CY_in, is_sub):
+        # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None
+        if len(RA.regs) != len(RB.regs):
+            raise TypeError(f"source lengths must match: "
+                            f"{RA} doesn't match {RB}")
+        self.RT = RA.like(op=self, arg_name="RT")
         self.RA = RA
         self.RB = RB
         self.CY_in = CY_in
-        self.CY_out = CY_out
+        self.CY_out = CY_in.like(op=self, arg_name="CY_out")
         self.is_sub = is_sub
 
     def possible_reg_assignments(self, val, value_assignments):
@@ -469,16 +497,13 @@ class OpBigIntMulDiv(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT, "RS": self.RS}
 
-    def __init__(self, RT, RA, RB, RC, RS, is_div):
-        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None
-        if len(RA.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {RA} doesn't match {RT}")
-        self.RT = RT
+    def __init__(self, RA, RB, RC, is_div):
+        # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None
+        self.RT = RA.like(op=self, arg_name="RT")
         self.RA = RA
         self.RB = RB
         self.RC = RC
-        self.RS = RS
+        self.RS = RC.like(op=self, arg_name="RS")
         self.is_div = is_div
 
     def possible_reg_assignments(self, val, value_assignments):
@@ -546,12 +571,9 @@ class OpBigIntShift(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, RT, inp, sh, kind):
-        # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None
-        if len(inp.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {inp} doesn't match {RT}")
-        self.RT = RT
+    def __init__(self, inp, sh, kind):
+        # type: (VecArg, SSAGPRVal, ShiftKind) -> None
+        self.RT = inp.like(op=self, arg_name="RT")
         self.inp = inp
         self.sh = sh
         self.kind = kind
@@ -604,9 +626,9 @@ class OpLI(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"out": self.out}
 
-    def __init__(self, out, value):
-        # type: (VecArg | SSAGPRVal, int) -> None
-        self.out = out
+    def __init__(self, value, element_count=None):
+        # type: (int, int | None) -> None
+        self.out = vec_or_scalar_arg(element_count, op=self, arg_name="out")
         self.value = value
 
     def possible_reg_assignments(self, val, value_assignments):
@@ -637,9 +659,10 @@ class OpClearCY(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"out": self.out}
 
-    def __init__(self, out):
-        # type: (SSAXERBitVal) -> None
-        self.out = out
+    def __init__(self):
+        # type: () -> None
+        self.out = SSAXERBitVal(op=self, arg_name="out", element_index=0,
+                                phys_loc=XERBit.CY)
 
     def possible_reg_assignments(self, val, value_assignments):
         # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
@@ -665,9 +688,9 @@ class OpLoad(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, RT, RA, offset, mem):
-        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
-        self.RT = RT
+    def __init__(self, RA, offset, mem, element_count=None):
+        # type: (SSAGPRVal, int, SSAMemory, int | None) -> None
+        self.RT = vec_or_scalar_arg(element_count, op=self, arg_name="RT")
         self.RA = RA
         self.offset = offset
         self.mem = mem
@@ -714,13 +737,13 @@ class OpStore(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"mem_out": self.mem_out}
 
-    def __init__(self, RS, RA, offset, mem_in, mem_out):
-        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None
+    def __init__(self, RS, RA, offset, mem_in):
+        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
         self.RS = RS
         self.RA = RA
         self.offset = offset
         self.mem_in = mem_in
-        self.mem_out = mem_out
+        self.mem_out = mem_in.like(op=self, arg_name="mem_out")
 
     def possible_reg_assignments(self, val, value_assignments):
         # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
@@ -757,9 +780,13 @@ class OpFuncArg(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"out": self.out}
 
-    def __init__(self, out):
-        # type: (VecArg | SSAGPRVal) -> None
-        self.out = out
+    def __init__(self, phys_loc):
+        # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None
+        if isinstance(phys_loc, GPROrStackLoc):
+            self.out = SSAGPRVal(self, "out", 0, phys_loc)
+        else:
+            self.out = VecArg(
+                SSAGPRVal(self, "out", i, v) for i, v in enumerate(phys_loc))
 
     def possible_reg_assignments(self, val, value_assignments):
         # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
@@ -789,9 +816,9 @@ class OpInputMem(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"out": self.out}
 
-    def __init__(self, out):
-        # type: (SSAMemory) -> None
-        self.out = out
+    def __init__(self):
+        # type: () -> None
+        self.out = SSAMemory(op=self, arg_name="out", element_index=0)
 
     def possible_reg_assignments(self, val, value_assignments):
         # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
@@ -843,30 +870,30 @@ def op_set_to_list(ops):
 
 @plain_data(unsafe_hash=True, order=True, frozen=True)
 class LiveInterval:
-    __slots__ = "assignment", "last_use"
+    __slots__ = "first_write", "last_use"
 
-    def __init__(self, assignment, last_use=None):
+    def __init__(self, first_write, last_use=None):
         # type: (int, int | None) -> None
         if last_use is None:
-            last_use = assignment
-        if last_use < assignment:
-            raise ValueError("uses must be after assignment")
-        if assignment < 0 or last_use < 0:
+            last_use = first_write
+        if last_use < first_write:
+            raise ValueError("uses must be after first_write")
+        if first_write < 0 or last_use < 0:
             raise ValueError("indexes must be nonnegative")
-        self.assignment = assignment
+        self.first_write = first_write
         self.last_use = last_use
 
     def overlaps(self, other):
         # type: (LiveInterval) -> bool
-        if self.assignment == other.assignment:
+        if self.first_write == other.first_write:
             return True
-        return self.last_use > other.assignment \
-            and other.last_use > self.assignment
+        return self.last_use > other.first_write \
+            and other.last_use > self.first_write
 
     def __add__(self, use):
         # type: (int) -> LiveInterval
         last_use = max(self.last_use, use)
-        return LiveInterval(assignment=self.assignment, last_use=last_use)
+        return LiveInterval(first_write=self.first_write, last_use=last_use)
 
 
 @final