192x192->384-bit O(n^2) mul works in SSA form, reg-alloc gives incorrect results...
authorJacob Lifshay <programmerjake@gmail.com>
Sun, 23 Oct 2022 07:25:43 +0000 (00:25 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sun, 23 Oct 2022 07:25:43 +0000 (00:25 -0700)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/test_register_allocator.py
src/bigint_presentation_code/test_toom_cook.py
src/bigint_presentation_code/toom_cook.py

index 517e5423e267ed031da12ef08c01af7c7090b39e..17f5f3fde6dacfa74f42a5eebea73976323fd858 100644 (file)
@@ -12,7 +12,7 @@ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
 
 from nmutil.plain_data import fields, plain_data
 
-from bigint_presentation_code.util import OFSet, OSet, final
+from bigint_presentation_code.util import FMap, OFSet, OSet, final
 
 
 class ABCEnumMeta(EnumMeta, ABCMeta):
@@ -41,7 +41,7 @@ class RegLoc(metaclass=ABCMeta):
 GPR_COUNT = 128
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class GPRRange(RegLoc, Sequence["GPRRange"]):
     __slots__ = "start", "length"
@@ -114,6 +114,11 @@ class GPRRange(RegLoc, Sequence["GPRRange"]):
             raise ValueError(f"sub-register offset is out of range: {offset}")
         return GPRRange(self.start + offset, subreg_type.length)
 
+    def __repr__(self):
+        if self.length == 1:
+            return f"<r{self.start}>"
+        return f"<r{self.start}..len={self.length}>"
+
 
 SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
 
@@ -191,7 +196,7 @@ _RegType = TypeVar("_RegType", bound=RegType)
 _RegLoc = TypeVar("_RegLoc", bound=RegLoc)
 
 
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
 @final
 class GPRRangeType(RegType):
     __slots__ = "length",
@@ -230,12 +235,15 @@ class GPRRangeType(RegType):
     def __hash__(self):
         return hash(self.length)
 
+    def __repr__(self):
+        return f"<gpr_ty[{self.length}]>"
+
 
 GPRType = GPRRangeType
 """a length=1 GPRRangeType"""
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class FixedGPRRangeType(RegType):
     __slots__ = "reg",
@@ -254,6 +262,9 @@ class FixedGPRRangeType(RegType):
         # type: () -> int
         return self.reg.length
 
+    def __repr__(self):
+        return f"<fixed({self.reg})>"
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
@@ -411,20 +422,8 @@ class SSAVal(Generic[_RegType]):
     def __hash__(self):
         return hash((id(self.op), self.arg_name))
 
-    def __repr__(self, long=False):
-        if not long:
-            return f"<#{self.op.id}.{self.arg_name}>"
-        fields_list = []
-        for name in fields(self):
-            v = getattr(self, name, None)
-            if v is not None:
-                if name == "op":
-                    v = v.__repr__(just_id=True)
-                else:
-                    v = repr(v)
-                fields_list.append(f"{name}={v}")
-        fields_str = ", ".join(fields_list)
-        return f"SSAVal({fields_str})"
+    def __repr__(self):
+        return f"<#{self.op.id}.{self.arg_name}: {self.ty}>"
 
 
 SSAGPRRange = SSAVal[GPRRangeType]
@@ -459,6 +458,11 @@ class Fn:
         ops = ", ".join(op.__repr__(just_id=True) for op in self.ops)
         return f"<Fn([{ops}])>"
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        for op in self.ops:
+            op.pre_ra_sim(state)
+
 
 class _NotSet:
     """ helper for __repr__ for when fields aren't set """
@@ -641,6 +645,36 @@ class AsmContext:
         return False
 
 
+GPR_SIZE_IN_BYTES = 8
+GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8
+GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
+
+
+@plain_data(frozen=True)
+@final
+class PreRASimState:
+    __slots__ = ("gprs", "VLs", "CAs",
+                 "global_mems", "stack_slots",
+                 "fixed_gprs")
+
+    def __init__(
+        self,
+        gprs,  # type: dict[SSAGPRRange, tuple[int, ...]]
+        VLs,  # type: dict[SSAKnownVL, int]
+        CAs,  # type: dict[SSAVal[CAType], bool]
+        global_mems,  # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
+        stack_slots,  # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
+        fixed_gprs,  # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
+    ):
+        # type: (...) -> None
+        self.gprs = gprs
+        self.VLs = VLs
+        self.CAs = CAs
+        self.global_mems = global_mems
+        self.stack_slots = stack_slots
+        self.fixed_gprs = fixed_gprs
+
+
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 class Op(metaclass=ABCMeta):
     __slots__ = "id", "fn"
@@ -681,15 +715,14 @@ class Op(metaclass=ABCMeta):
             pass
         if not just_id:
             for name in fields(self):
+                if name in ("id", "fn"):
+                    continue
                 v = getattr(self, name, _NOT_SET)
-                if ((outputs is None or name in outputs)
-                        and isinstance(v, SSAVal)):
-                    v = v.__repr__(long=True)
-                elif isinstance(v, Fn):
-                    v = v.__repr__(short=True)
+                if (outputs is not None and name in outputs
+                        and outputs[name] is v):
+                    fields_list.append(repr(v))
                 else:
-                    v = repr(v)
-                fields_list.append(f"{name}={v}")
+                    fields_list.append(f"{name}={v!r}")
         fields_str = ', '.join(fields_list)
         return f"{self.__class__.__name__}({fields_str})"
 
@@ -699,6 +732,12 @@ class Op(metaclass=ABCMeta):
         """get the lines of assembly for this Op"""
         ...
 
+    @abstractmethod
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation"""
+        ...
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -732,6 +771,11 @@ class OpLoadFromStackSlot(Op):
             return [f"sv.ld {dest}, {src.start_byte}(1)"]
         return [f"ld {dest}, {src.start_byte}(1)"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation"""
+        state.gprs[self.dest] = state.stack_slots[self.src]
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -765,6 +809,11 @@ class OpStoreToStackSlot(Op):
             return [f"sv.std {src}, {dest.start_byte}(1)"]
         return [f"std {src}, {dest.start_byte}(1)"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation"""
+        state.stack_slots[self.dest] = state.gprs[self.src]
+
 
 _RegSrcType = TypeVar("_RegSrcType", bound=RegType)
 
@@ -805,6 +854,8 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
         elif src.ty != dest_ty:
             raise ValueError(f"incompatible source and destination "
                              f"types: {src.ty} and {dest_ty}")
+        elif isinstance(src.ty, StackSlotType):
+            raise ValueError("can't use OpCopy on stack slots")
         elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
             length = src.ty.length
         else:
@@ -834,6 +885,37 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
             return [f"or {dest_s}, {src_s}, {src_s}"]
         raise NotImplementedError
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
+                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
+            if isinstance(self.src.ty, GPRRangeType):
+                v = state.gprs[self.src]  # type: ignore
+            else:
+                v = state.fixed_gprs[self.src]  # type: ignore
+            if isinstance(self.dest.ty, GPRRangeType):
+                state.gprs[self.dest] = v  # type: ignore
+            else:
+                state.fixed_gprs[self.dest] = v  # type: ignore
+        elif (isinstance(self.src.ty, FixedGPRRangeType) and
+                isinstance(self.dest.ty, GPRRangeType)):
+            state.gprs[self.dest] = state.fixed_gprs[self.src]  # type: ignore
+        elif (isinstance(self.src.ty, GPRRangeType) and
+                isinstance(self.dest.ty, FixedGPRRangeType)):
+            state.fixed_gprs[self.dest] = state.gprs[self.src]  # type: ignore
+        elif (isinstance(self.src.ty, CAType) and
+                self.src.ty == self.dest.ty):
+            state.CAs[self.dest] = state.CAs[self.src]  # type: ignore
+        elif (isinstance(self.src.ty, KnownVLType) and
+                self.src.ty == self.dest.ty):
+            state.VLs[self.dest] = state.VLs[self.src]  # type: ignore
+        elif (isinstance(self.src.ty, GlobalMemType) and
+                self.src.ty == self.dest.ty):
+            v = state.global_mems[self.src]  # type: ignore
+            state.global_mems[self.dest] = v  # type: ignore
+        else:
+            raise NotImplementedError
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -864,6 +946,13 @@ class OpConcat(Op):
         # type: (AsmContext) -> list[str]
         return []
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        v = []
+        for src in self.sources:
+            v.extend(state.gprs[src])
+        state.gprs[self.dest] = tuple(v)
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -892,7 +981,7 @@ class OpSplit(Op):
         ranges.append(GPRRangeType(src.ty.length - last))
         self.src = src
         self.results = tuple(
-            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
+            SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges))
 
     def get_equality_constraints(self):
         # type: () -> Iterable[EqualityConstraint]
@@ -902,6 +991,13 @@ class OpSplit(Op):
         # type: (AsmContext) -> list[str]
         return []
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        rest = state.gprs[self.src]
+        for dest in reversed(self.results):
+            state.gprs[dest] = rest[-dest.ty.length:]
+            rest = rest[:-dest.ty.length]
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -956,6 +1052,19 @@ class OpBigIntAddSub(Op):
             return [f"sv.{mnemonic} {out}, {RA}, {RB}"]
         return [f"{mnemonic} {out}, {RA}, {RB}"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        carry = state.CAs[self.CA_in]
+        out = []  # type: list[int]
+        for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
+            if self.is_sub:
+                r = r ^ GPR_VALUE_MASK
+            s = l + r + carry
+            carry = s != (s & GPR_VALUE_MASK)
+            out.append(s & GPR_VALUE_MASK)
+        state.CAs[self.CA_out] = carry
+        state.gprs[self.out] = tuple(out)
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1013,6 +1122,29 @@ class OpBigIntMulDiv(Op):
             mnemonic = "divmod2du/mrr"
         return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        carry = state.gprs[self.RC][0]
+        RA = state.gprs[self.RA]
+        RB = state.gprs[self.RB][0]
+        RT = [0] * self.RT.ty.length
+        if self.is_div:
+            for i in reversed(range(self.RT.ty.length)):
+                if carry < RB and RB != 0:
+                    div, mod = divmod((carry << 64) | RA[i], RB)
+                    RT[i] = div & GPR_VALUE_MASK
+                    carry = mod & GPR_VALUE_MASK
+                else:
+                    RT[i] = GPR_VALUE_MASK
+                    carry = 0
+        else:
+            for i in range(self.RT.ty.length):
+                v = RA[i] * RB + carry
+                carry = v >> 64
+                RT[i] = v & GPR_VALUE_MASK
+        state.gprs[self.RS] = carry,
+        state.gprs[self.RT] = tuple(RT)
+
 
 @final
 @unique
@@ -1101,6 +1233,27 @@ class OpBigIntShift(Op):
             RB = ctx.sgpr(self.sh)
             return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        out = [0] * self.out.ty.length
+        carry = state.gprs[self.carry_in][0]
+        sh = state.gprs[self.sh][0] % 64
+        if self.kind is ShiftKind.Sl:
+            inp = carry, *state.gprs[self.inp]
+            for i in reversed(range(self.out.ty.length)):
+                v = inp[i] | (inp[i + 1] << 64)
+                v <<= sh
+                out[i] = (v >> 64) & GPR_VALUE_MASK
+        else:
+            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
+            inp = *state.gprs[self.inp], carry
+            for i in range(self.out.ty.length):
+                v = inp[i] | (inp[i + 1] << 64)
+                v >>= sh
+                out[i] = v & GPR_VALUE_MASK
+        # state.gprs[self._out_padding] is intentionally not written
+        state.gprs[self.out] = tuple(out)
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1150,6 +1303,25 @@ class OpShiftImm(Op):
             return [f"sv.{mnemonic} {out}, {inp}, {args}"]
         return [f"{mnemonic} {out}, {inp}, {args}"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        inp = state.gprs[self.inp][0]
+        if self.kind is ShiftKind.Sl:
+            assert self.ca_out is None
+            out = inp << self.sh
+        elif self.kind is ShiftKind.Sr:
+            assert self.ca_out is None
+            out = inp >> self.sh
+        else:
+            assert self.kind is ShiftKind.Sra
+            assert self.ca_out is not None
+            if inp & (1 << 63):  # sign extend
+                inp -= 1 << 64
+            out = inp >> self.sh
+            ca = inp < 0 and (out << self.sh) != inp
+            state.CAs[self.ca_out] = ca
+        state.gprs[self.out] = out,
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1189,6 +1361,11 @@ class OpLI(Op):
             return [f"sv.addi {out}, 0, {self.value}"]
         return [f"addi {out}, 0, {self.value}"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        value = self.value & GPR_VALUE_MASK
+        state.gprs[self.out] = (value,) * self.out.ty.length
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1215,6 +1392,10 @@ class OpSetCA(Op):
             return ["subfic 0, 0, -1"]
         return ["addic 0, 0, 0"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        state.CAs[self.out] = self.value
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1265,6 +1446,22 @@ class OpLoad(Op):
             return [f"sv.ld {RT}, {self.offset}({RA})"]
         return [f"ld {RT}, {self.offset}({RA})"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        addr = state.gprs[self.RA][0]
+        addr += self.offset
+        RT = [0] * self.RT.ty.length
+        mem = state.global_mems[self.mem]
+        for i in range(self.RT.ty.length):
+            cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
+            if cur_addr % GPR_SIZE_IN_BYTES != 0:
+                raise ValueError(f"can't load from unaligned address: "
+                                 f"{cur_addr:#x}")
+            for j in range(GPR_SIZE_IN_BYTES):
+                byte_val = mem.get(cur_addr + j, 0) & 0xFF
+                RT[i] |= byte_val << (j * 8)
+        state.gprs[self.RT] = tuple(RT)
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1308,6 +1505,21 @@ class OpStore(Op):
             return [f"sv.std {RS}, {self.offset}({RA})"]
         return [f"std {RS}, {self.offset}({RA})"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        mem = dict(state.global_mems[self.mem_in])
+        addr = state.gprs[self.RA][0]
+        addr += self.offset
+        RS = state.gprs[self.RS]
+        for i in range(self.RS.ty.length):
+            cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
+            if cur_addr % GPR_SIZE_IN_BYTES != 0:
+                raise ValueError(f"can't store to unaligned address: "
+                                 f"{cur_addr:#x}")
+            for j in range(GPR_SIZE_IN_BYTES):
+                mem[cur_addr + j] = (RS[i] >> (j * 8)) & 0xFF
+        state.global_mems[self.mem_out] = FMap(mem)
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1331,6 +1543,11 @@ class OpFuncArg(Op):
         # type: (AsmContext) -> list[str]
         return []
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        if self.out not in state.fixed_gprs:
+            state.fixed_gprs[self.out] = (0,) * self.out.ty.length
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1354,6 +1571,11 @@ class OpInputMem(Op):
         # type: (AsmContext) -> list[str]
         return []
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        if self.out not in state.global_mems:
+            state.global_mems[self.out] = FMap()
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -1377,6 +1599,10 @@ class OpSetVLImm(Op):
         # type: (AsmContext) -> list[str]
         return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
 
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        state.VLs[self.out] = self.out.ty.length
+
 
 def op_set_to_list(ops):
     # type: (Iterable[Op]) -> list[Op]
index 65582643e789af865c3e0372413aa62e5a8e97a9..1eff2546e687620023e7a60a1e767d08015e3106 100644 (file)
@@ -56,43 +56,62 @@ class TestRegisterAllocator(unittest.TestCase):
             repr(reg_assignments),
             "AllocationFailed("
             "node=IGNode(#0, merged_reg_set=MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
             "edges={}, reg=None), "
-            "live_intervals=LiveIntervals("
-            "live_intervals={"
-            "MergedRegSet([(<#0.out>, 0)]): "
+            "live_intervals=LiveIntervals(live_intervals={"
+            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
             "LiveInterval(first_write=0, last_use=1), "
-            "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): "
+            "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]): "
             "LiveInterval(first_write=1, last_use=4), "
-            "MergedRegSet([(<#2.out>, 0)]): "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
             "LiveInterval(first_write=2, last_use=3)}, "
             "merged_reg_sets=MergedRegSets(data={"
-            "<#0.out>: MergedRegSet([(<#0.out>, 0)]), "
-            "<#1.out>: MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
-            "<#2.out>: MergedRegSet([(<#2.out>, 0)]), "
-            "<#3.out>: MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
-            "<#4.dest>: MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), "
+            "<#0.out: KnownVLType(length=52)>: "
+            "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
+            "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "<#2.out: KnownVLType(length=64)>: "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
+            "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])}), "
             "reg_sets_live_after={"
-            "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), "
+            "0: OFSet([MergedRegSet(["
+            "(<#0.out: KnownVLType(length=52)>, 0)])]), "
             "1: OFSet([MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
             "2: OFSet([MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
-            "MergedRegSet([(<#2.out>, 0)])]), "
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), "
+            "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
             "3: OFSet([MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)])]), "
             "4: OFSet()}), "
             "interference_graph=InterferenceGraph(nodes={"
-            "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), "
-            "edges={}, reg=None), "
+            "...: IGNode(#0, merged_reg_set=MergedRegSet(["
+            "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
             "...: IGNode(#1, merged_reg_set=MergedRegSet(["
-            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
-            "edges={}, reg=None), "
-            "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), "
-            "edges={}, reg=None)}))"
+            "(<#4.dest: <gpr_ty[116]>>, 0), "
+            "(<#1.out: <gpr_ty[52]>>, 0), "
+            "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
+            "...: IGNode(#2, merged_reg_set=MergedRegSet(["
+            "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
         )
 
     def test_try_alloc_bigint_inc(self):
index 656c8d7c8bc74401f2edbf42ffd2a3aed59f7a23..6fff570f0a485d49fc39d723da5f25a88a8271cb 100644 (file)
@@ -1,10 +1,37 @@
 import unittest
 
-from bigint_presentation_code.toom_cook import ToomCookInstance
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
+                                                  GlobalMem, GPRRange,
+                                                  GPRRangeType, OpCopy,
+                                                  OpFuncArg, OpInputMem,
+                                                  OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
+                                                  generate_assembly)
+from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.util import FMap
+
+
+class SimpleMul192x192:
+    def __init__(self):
+        self.fn = fn = Fn()
+        self.mem_in = mem = OpInputMem(fn).out
+        self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
+        self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
+        self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
+        dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
+        vl = OpSetVLImm(fn, 3).out
+        lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
+        rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
+        retval = simple_mul(fn, lhs, rhs)
+        vl = OpSetVLImm(fn, 6).out
+        self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
+                               mem_in=mem, vl=vl).mem_out
 
 
 class TestToomCook(unittest.TestCase):
-    def test_toom_2(self):
+    maxDiff = None
+
+    def test_toom_2_repr(self):
         TOOM_2 = ToomCookInstance.make_toom_2()
         # print(repr(repr(TOOM_2)))
         self.assertEqual(
@@ -42,7 +69,7 @@ class TestToomCook(unittest.TestCase):
             "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
         )
 
-    def test_toom_2_5(self):
+    def test_toom_2_5_repr(self):
         TOOM_2_5 = ToomCookInstance.make_toom_2_5()
         # print(repr(repr(TOOM_2_5)))
         self.assertEqual(
@@ -107,9 +134,9 @@ class TestToomCook(unittest.TestCase):
             "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
         )
 
-    def test_reversed_toom_2_5(self):
+    def test_reversed_toom_2_5_repr(self):
         TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
-        print(repr(repr(TOOM_2_5)))
+        print(repr(repr(TOOM_2_5)))
         self.assertEqual(
             repr(TOOM_2_5),
             "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
@@ -169,6 +196,183 @@ class TestToomCook(unittest.TestCase):
             "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
         )
 
+    def test_simple_mul_192x192_pre_ra_sim(self):
+        # test multiplying:
+        #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
+        # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
+        # ==
+        # int("0x00074736574206e_6f69746163696c70"
+        #     "_69746c756d207469_622d3438333e2d32"
+        #     "_3931783239312079_7261727469627261", base=0)
+        # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
+        #                   'little')
+        code = SimpleMul192x192()
+        dest_ptr = 0x100
+        state = PreRASimState(
+            gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
+            stack_slots={}, fixed_gprs={
+                code.dest_ptr_in: (dest_ptr,),
+                code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
+                              0x000191acb262e15b),
+                code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
+                              0x4a37c0567bcbab53)
+            })
+        code.fn.pre_ra_sim(state)
+        expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
+        OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
+        expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
+        mem_out = state.global_mems[code.mem_out]
+        out_bytes = bytes(
+            mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
+        self.assertEqual(out_bytes, expected_bytes)
+
+    def test_simple_mul_192x192_ops(self):
+        code = SimpleMul192x192()
+        fn = code.fn
+        self.assertEqual([repr(v) for v in fn.ops], [
+            'OpInputMem(#0, <#0.out: GlobalMemType()>)',
+            'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
+            'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
+            'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
+            'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
+            'vl=None)',
+            'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
+            'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
+            'src=<#2.out: <fixed(<r4..len=3>)>>, '
+            'vl=<#5.out: KnownVLType(length=3)>)',
+            'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
+            'src=<#3.out: <fixed(<r7..len=3>)>>, '
+            'vl=<#5.out: KnownVLType(length=3)>)',
+            'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
+            '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
+            'src=<#7.dest: <gpr_ty[3]>>)',
+            'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
+            'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
+            'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
+            '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
+            'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
+            '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
+            'OpSetCA(#15, <#15.out: CAType()>, value=False)',
+            'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
+            'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
+            'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
+            'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+            'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
+            'is_sub=False, vl=None)',
+            'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
+            '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
+            '<#17.out: <gpr_ty[1]>>))',
+            'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
+            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
+            'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
+            '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
+            'OpSetCA(#21, <#21.out: CAType()>, value=False)',
+            'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
+            'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
+            'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
+            'vl=<#9.out: KnownVLType(length=3)>)',
+            'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
+            'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+            'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
+            'is_sub=False, vl=None)',
+            'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
+            '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
+            '<#23.out: <gpr_ty[1]>>))',
+            'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
+            'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
+            'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
+            'mem_in=<#0.out: GlobalMemType()>, '
+            '<#26.mem_out: GlobalMemType()>, '
+            'vl=<#25.out: KnownVLType(length=6)>)'
+        ])
+
+    # FIXME: register allocator currently allocates wrong registers
+    @unittest.expectedFailure
+    def test_simple_mul_192x192_reg_alloc(self):
+        code = SimpleMul192x192()
+        fn = code.fn
+        assigned_registers = allocate_registers(fn.ops)
+        self.assertEqual(assigned_registers, {
+            fn.ops[13].RS: GPRRange(9),  # type: ignore
+            fn.ops[14].results[0]: GPRRange(6),  # type: ignore
+            fn.ops[14].results[1]: GPRRange(7, length=3),  # type: ignore
+            fn.ops[15].out: XERBit.CA,  # type: ignore
+            fn.ops[16].out: GPRRange(7, length=3),  # type: ignore
+            fn.ops[16].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[17].out: GPRRange(10),  # type: ignore
+            fn.ops[17].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[18].dest: GPRRange(6, length=5),  # type: ignore
+            fn.ops[19].RT: GPRRange(3, length=3),  # type: ignore
+            fn.ops[19].RS: GPRRange(9),  # type: ignore
+            fn.ops[20].results[0]: GPRRange(6, length=2),  # type: ignore
+            fn.ops[20].results[1]: GPRRange(8, length=3),  # type: ignore
+            fn.ops[21].out: XERBit.CA,  # type: ignore
+            fn.ops[22].out: GPRRange(8, length=3),  # type: ignore
+            fn.ops[22].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[23].out: GPRRange(11),  # type: ignore
+            fn.ops[23].CA_out: XERBit.CA,  # type: ignore
+            fn.ops[24].dest: GPRRange(6, length=6),  # type: ignore
+            fn.ops[25].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[26].mem_out: GlobalMem.GlobalMem,  # type: ignore
+            fn.ops[0].out: GlobalMem.GlobalMem,  # type: ignore
+            fn.ops[1].out: GPRRange(3),  # type: ignore
+            fn.ops[2].out: GPRRange(4, length=3),  # type: ignore
+            fn.ops[3].out: GPRRange(7, length=3),  # type: ignore
+            fn.ops[4].dest: GPRRange(12),  # type: ignore
+            fn.ops[5].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[6].dest: GPRRange(17, length=3),  # type: ignore
+            fn.ops[7].dest: GPRRange(14, length=3),  # type: ignore
+            fn.ops[8].results[0]: GPRRange(14),  # type: ignore
+            fn.ops[8].results[1]: GPRRange(15),  # type: ignore
+            fn.ops[8].results[2]: GPRRange(16),  # type: ignore
+            fn.ops[9].out: VL.VL_MAXVL,  # type: ignore
+            fn.ops[10].out: GPRRange(9),  # type: ignore
+            fn.ops[11].RT: GPRRange(6, length=3),  # type: ignore
+            fn.ops[11].RS: GPRRange(9),  # type: ignore
+            fn.ops[12].dest: GPRRange(6, length=4),  # type: ignore
+            fn.ops[13].RT: GPRRange(3, length=3)  # type: ignore
+        })
+        self.fail("register allocator currently allocates wrong registers")
+
+    # FIXME: register allocator currently allocates wrong registers
+    @unittest.expectedFailure
+    def test_simple_mul_192x192_asm(self):
+        code = SimpleMul192x192()
+        asm = generate_assembly(code.fn.ops)
+        self.assertEqual(asm, [
+            'or 12, 3, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *17, *4, *4',
+            'sv.or *14, *7, *7',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'addi 9, 0, 0',
+            'sv.maddedu *6, *17, 14, 9',
+            'sv.maddedu *3, *17, 15, 9',
+            'addic 0, 0, 0',
+            'sv.adde *7, *3, *7',
+            'adde 10, 9, 9',
+            'sv.maddedu *3, *17, 16, 9',
+            'addic 0, 0, 0',
+            'sv.adde *8, *3, *8',
+            'adde 11, 9, 9',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'sv.std *6, 0(12)',
+            'bclr 20, 0, 0'
+        ])
+        self.fail("register allocator currently allocates wrong registers")
+
 
 if __name__ == "__main__":
     unittest.main()
index 9786624d7dc395df7ce17d6b977351f14430f38e..9e3ec74e8842431bb0c3df895c8053e5b51bbd28 100644 (file)
@@ -8,7 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir import Fn, Op
+from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange
 from bigint_presentation_code.matrix import Matrix
 from bigint_presentation_code.util import Literal, OSet, final
 
@@ -438,8 +438,33 @@ class ToomCookInstance:
     # TODO: add make_toom_3
 
 
-def toom_cook_mul(fn, word_count, instances):
-    # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
-    retval = OSet()  # type: OSet[Op]
-    raise NotImplementedError
+def simple_mul(fn, lhs, rhs):
+    # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
+    """ simple O(n^2) big-int unsigned multiply """
+    if lhs.ty.length < rhs.ty.length:
+        lhs, rhs = rhs, lhs
+    # split rhs into elements
+    rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results
+    retval = None
+    vl = OpSetVLImm(fn, lhs.ty.length).out
+    zero = OpLI(fn, 0).out
+    for shift, rhs_word in enumerate(rhs_words):
+        mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero,
+                             is_div=False, vl=vl)
+        if retval is None:
+            retval = OpConcat(fn, [mul.RT, mul.RS]).dest
+        else:
+            first_part, last_part = OpSplit(fn, retval, [shift]).results
+            add = OpBigIntAddSub(
+                fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out,
+                is_sub=False, vl=vl)
+            add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out,
+                                    is_sub=False)
+            retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest
+    assert retval is not None
     return retval
+
+
+def toom_cook_mul(fn, lhs, rhs, instances):
+    # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
+    raise NotImplementedError