From e05169d3d027e59daea56b7f982eca1711b12bae Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 10 Oct 2022 16:35:44 -0700 Subject: [PATCH] change SSAVal to link to defining Op --- src/bigint_presentation_code/toom_cook.py | 195 ++++++++++++---------- 1 file changed, 111 insertions(+), 84 deletions(-) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 16ef5d0..6a389c1 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -7,7 +7,7 @@ from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING from nmutil.plain_data import plain_data if TYPE_CHECKING: - from typing_extensions import final + from typing_extensions import final, Self else: def final(v): return v @@ -83,23 +83,28 @@ class StackSlot(GPROrStackLoc): self.offset = offset -@plain_data(eq=False) class SSAVal(metaclass=ABCMeta): - __slots__ = "id", + __slots__ = "op", "arg_name", "element_index" - def __init__(self, id=None): - # type: (int | None) -> None - if id is None: - id = builtins.id(self) - self.id = id + def __init__(self, op, arg_name, element_index): + # type: (Op, str, int) -> None + self.op = op + """the Op that writes this SSAVal""" + + self.arg_name = arg_name + self.element_index = element_index + @final def __eq__(self, rhs): if isinstance(rhs, SSAVal): - return self.id == rhs.id + return (self.op is rhs.op + and self.arg_name == rhs.arg_name + and self.element_index == rhs.element_index) return False + @final def __hash__(self): - return hash(self.id) + return hash((id(self.op), self.arg_name, self.element_index)) def _get_phys_loc(self, phys_loc_in, value_assignments=None): # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None @@ -114,16 +119,32 @@ class SSAVal(metaclass=ABCMeta): # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None ... + @final + def __repr__(self): + name = self.__class__.__name__ + op = object.__repr__(self.op) + phys_loc = self.get_phys_loc() + return (f"{name}(op={op}, arg_name={self.arg_name}, " + f"element_index={self.element_index}, phys_loc={phys_loc})") + + @final + def like(self, op, arg_name): + # type: (Op, str) -> Self + """create a new SSAVal based off of self's type. + has same signature as VecArg.like. + """ + return self.__class__(op=op, arg_name=arg_name, + element_index=0) + -@plain_data(eq=False) @final class SSAGPRVal(SSAVal): __slots__ = "phys_loc", - def __init__(self, phys_loc=None): - # type: (GPROrStackLoc | None) -> None + def __init__(self, op, arg_name, element_index, phys_loc=None): + # type: (Op, str, int, GPROrStackLoc | None) -> None + super().__init__(op, arg_name, element_index) self.phys_loc = phys_loc - super().__init__() def __len__(self): return 1 @@ -166,13 +187,13 @@ class SSAGPRVal(SSAVal): yield reg -@plain_data(eq=False) @final class SSAXERBitVal(SSAVal): __slots__ = "phys_loc", - def __init__(self, phys_loc=None): - # type: (XERBit | None) -> None + def __init__(self, op, arg_name, element_index, phys_loc=None): + # type: (Op, str, int, XERBit | None) -> None + super().__init__(op, arg_name, element_index) self.phys_loc = phys_loc def get_phys_loc(self, value_assignments=None): @@ -183,21 +204,22 @@ class SSAXERBitVal(SSAVal): return None -@plain_data(eq=False) @final class SSAMemory(SSAVal): __slots__ = "phys_loc", - def __init__(self, phys_loc=GlobalMem.GlobalMem): - # type: (GlobalMem) -> None + def __init__(self, op, arg_name, element_index, + phys_loc=GlobalMem.GlobalMem): + # type: (Op, str, int, GlobalMem) -> None + super().__init__(op, arg_name, element_index) self.phys_loc = phys_loc def get_phys_loc(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None + # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem loc = self._get_phys_loc(self.phys_loc, value_assignments) if isinstance(loc, GlobalMem): return loc - return None + return self.phys_loc @plain_data(unsafe_hash=True, frozen=True) @@ -272,6 +294,22 @@ class VecArg: else: yield GPR(r[index]) + def like(self, op, arg_name): + # type: (Op, str) -> VecArg + """create a new VecArg based off of self's type. + has same signature as SSAVal.like. + """ + return VecArg( + SSAGPRVal(op, arg_name, i) for i in range(len(self.regs))) + + +def vec_or_scalar_arg(element_count, op, arg_name): + # type: (int | None, Op, str) -> VecArg | SSAGPRVal + if element_count is None: + return SSAGPRVal(op, arg_name, 0) + else: + return VecArg(SSAGPRVal(op, arg_name, i) for i in range(element_count)) + @final @plain_data(unsafe_hash=True, frozen=True) @@ -341,16 +379,9 @@ class OpCopy(Op): # type: () -> dict[str, VecArg | SSAVal] return {"dest": self.dest} - def __init__(self, dest, src): - # type: (VecArg | SSAVal, VecArg | SSAVal) -> None - if isinstance(dest, VecArg) and isinstance(src, VecArg): - if len(src.regs) != len(dest.regs): - raise TypeError(f"source length must match dest " - f"length: {src} doesn't match {dest}") - elif type(dest) != type(src): - raise TypeError(f"source argument type must match dest " - f"argument type: {src} doesn't match {dest}") - self.dest = dest + def __init__(self, src): + # type: (VecArg | SSAVal) -> None + self.dest = src.like(op=self, arg_name="dest") self.src = src def possible_reg_assignments(self, val, value_assignments): @@ -406,19 +437,16 @@ class OpAddSubE(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT, "CY_out": self.CY_out} - def __init__(self, RT, RA, RB, CY_in, CY_out, is_sub): - # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None - if len(RA.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {RA} doesn't match {RT}") - if len(RB.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {RB} doesn't match {RT}") - self.RT = RT + def __init__(self, RA, RB, CY_in, is_sub): + # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None + if len(RA.regs) != len(RB.regs): + raise TypeError(f"source lengths must match: " + f"{RA} doesn't match {RB}") + self.RT = RA.like(op=self, arg_name="RT") self.RA = RA self.RB = RB self.CY_in = CY_in - self.CY_out = CY_out + self.CY_out = CY_in.like(op=self, arg_name="CY_out") self.is_sub = is_sub def possible_reg_assignments(self, val, value_assignments): @@ -469,16 +497,13 @@ class OpBigIntMulDiv(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT, "RS": self.RS} - def __init__(self, RT, RA, RB, RC, RS, is_div): - # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None - if len(RA.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {RA} doesn't match {RT}") - self.RT = RT + def __init__(self, RA, RB, RC, is_div): + # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None + self.RT = RA.like(op=self, arg_name="RT") self.RA = RA self.RB = RB self.RC = RC - self.RS = RS + self.RS = RC.like(op=self, arg_name="RS") self.is_div = is_div def possible_reg_assignments(self, val, value_assignments): @@ -546,12 +571,9 @@ class OpBigIntShift(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT} - def __init__(self, RT, inp, sh, kind): - # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None - if len(inp.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {inp} doesn't match {RT}") - self.RT = RT + def __init__(self, inp, sh, kind): + # type: (VecArg, SSAGPRVal, ShiftKind) -> None + self.RT = inp.like(op=self, arg_name="RT") self.inp = inp self.sh = sh self.kind = kind @@ -604,9 +626,9 @@ class OpLI(Op): # type: () -> dict[str, VecArg | SSAVal] return {"out": self.out} - def __init__(self, out, value): - # type: (VecArg | SSAGPRVal, int) -> None - self.out = out + def __init__(self, value, element_count=None): + # type: (int, int | None) -> None + self.out = vec_or_scalar_arg(element_count, op=self, arg_name="out") self.value = value def possible_reg_assignments(self, val, value_assignments): @@ -637,9 +659,10 @@ class OpClearCY(Op): # type: () -> dict[str, VecArg | SSAVal] return {"out": self.out} - def __init__(self, out): - # type: (SSAXERBitVal) -> None - self.out = out + def __init__(self): + # type: () -> None + self.out = SSAXERBitVal(op=self, arg_name="out", element_index=0, + phys_loc=XERBit.CY) def possible_reg_assignments(self, val, value_assignments): # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] @@ -665,9 +688,9 @@ class OpLoad(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT} - def __init__(self, RT, RA, offset, mem): - # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None - self.RT = RT + def __init__(self, RA, offset, mem, element_count=None): + # type: (SSAGPRVal, int, SSAMemory, int | None) -> None + self.RT = vec_or_scalar_arg(element_count, op=self, arg_name="RT") self.RA = RA self.offset = offset self.mem = mem @@ -714,13 +737,13 @@ class OpStore(Op): # type: () -> dict[str, VecArg | SSAVal] return {"mem_out": self.mem_out} - def __init__(self, RS, RA, offset, mem_in, mem_out): - # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None + def __init__(self, RS, RA, offset, mem_in): + # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None self.RS = RS self.RA = RA self.offset = offset self.mem_in = mem_in - self.mem_out = mem_out + self.mem_out = mem_in.like(op=self, arg_name="mem_out") def possible_reg_assignments(self, val, value_assignments): # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] @@ -757,9 +780,13 @@ class OpFuncArg(Op): # type: () -> dict[str, VecArg | SSAVal] return {"out": self.out} - def __init__(self, out): - # type: (VecArg | SSAGPRVal) -> None - self.out = out + def __init__(self, phys_loc): + # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None + if isinstance(phys_loc, GPROrStackLoc): + self.out = SSAGPRVal(self, "out", 0, phys_loc) + else: + self.out = VecArg( + SSAGPRVal(self, "out", i, v) for i, v in enumerate(phys_loc)) def possible_reg_assignments(self, val, value_assignments): # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] @@ -789,9 +816,9 @@ class OpInputMem(Op): # type: () -> dict[str, VecArg | SSAVal] return {"out": self.out} - def __init__(self, out): - # type: (SSAMemory) -> None - self.out = out + def __init__(self): + # type: () -> None + self.out = SSAMemory(op=self, arg_name="out", element_index=0) def possible_reg_assignments(self, val, value_assignments): # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] @@ -843,30 +870,30 @@ def op_set_to_list(ops): @plain_data(unsafe_hash=True, order=True, frozen=True) class LiveInterval: - __slots__ = "assignment", "last_use" + __slots__ = "first_write", "last_use" - def __init__(self, assignment, last_use=None): + def __init__(self, first_write, last_use=None): # type: (int, int | None) -> None if last_use is None: - last_use = assignment - if last_use < assignment: - raise ValueError("uses must be after assignment") - if assignment < 0 or last_use < 0: + last_use = first_write + if last_use < first_write: + raise ValueError("uses must be after first_write") + if first_write < 0 or last_use < 0: raise ValueError("indexes must be nonnegative") - self.assignment = assignment + self.first_write = first_write self.last_use = last_use def overlaps(self, other): # type: (LiveInterval) -> bool - if self.assignment == other.assignment: + if self.first_write == other.first_write: return True - return self.last_use > other.assignment \ - and other.last_use > self.assignment + return self.last_use > other.first_write \ + and other.last_use > self.first_write def __add__(self, use): # type: (int) -> LiveInterval last_use = max(self.last_use, use) - return LiveInterval(assignment=self.assignment, last_use=last_use) + return LiveInterval(first_write=self.first_write, last_use=last_use) @final -- 2.30.2