from nmutil.plain_data import fields, plain_data
-from bigint_presentation_code.util import OFSet, OSet, final
+from bigint_presentation_code.util import FMap, OFSet, OSet, final
class ABCEnumMeta(EnumMeta, ABCMeta):
GPR_COUNT = 128
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
__slots__ = "start", "length"
raise ValueError(f"sub-register offset is out of range: {offset}")
return GPRRange(self.start + offset, subreg_type.length)
+    def __repr__(self):
+        """compact debug form: `<rN>` for a single register, otherwise
+        `<rN..len=M>` where N is `start` and M is `length`."""
+        suffix = "" if self.length == 1 else f"..len={self.length}"
+        return f"<r{self.start}{suffix}>"
+
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
@final
class GPRRangeType(RegType):
__slots__ = "length",
def __hash__(self):
return hash(self.length)
+ def __repr__(self):
+ """compact debug form: `<gpr_ty[N]>` where N is `length`"""
+ return f"<gpr_ty[{self.length}]>"
+
GPRType = GPRRangeType
"""a length=1 GPRRangeType"""
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class FixedGPRRangeType(RegType):
__slots__ = "reg",
# type: () -> int
return self.reg.length
+ def __repr__(self):
+ """compact debug form: `<fixed(REG)>` where REG is the formatted `reg`"""
+ return f"<fixed({self.reg})>"
+
@plain_data(frozen=True, unsafe_hash=True)
@final
def __hash__(self):
return hash((id(self.op), self.arg_name))
- def __repr__(self, long=False):
- if not long:
- return f"<#{self.op.id}.{self.arg_name}>"
- fields_list = []
- for name in fields(self):
- v = getattr(self, name, None)
- if v is not None:
- if name == "op":
- v = v.__repr__(just_id=True)
- else:
- v = repr(v)
- fields_list.append(f"{name}={v}")
- fields_str = ", ".join(fields_list)
- return f"SSAVal({fields_str})"
+ def __repr__(self):
+ """compact debug form: `<#OP_ID.ARG_NAME: TYPE>`"""
+ return f"<#{self.op.id}.{self.arg_name}: {self.ty}>"
SSAGPRRange = SSAVal[GPRRangeType]
ops = ", ".join(op.__repr__(just_id=True) for op in self.ops)
return f"<Fn([{ops}])>"
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate every op in `self.ops`, in order, by calling each op's
+ `pre_ra_sim` with the same `state`"""
+ for op in self.ops:
+ op.pre_ra_sim(state)
+
class _NotSet:
""" helper for __repr__ for when fields aren't set """
return False
+# size of one general-purpose register (GPR) as used by the pre-RA simulation
+GPR_SIZE_IN_BYTES = 8
+GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8
+# mask selecting the low GPR_SIZE_IN_BITS bits of a Python int
+GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
+
+
+@plain_data(frozen=True)
+@final
+class PreRASimState:
+ """simulated values of SSA variables before register allocation.
+
+ every field is a dict keyed by SSA value; the ops' `pre_ra_sim` methods
+ read their inputs from these dicts and store their outputs into them.
+ register-like values are tuples of ints, one int per register.
+ """
+ __slots__ = ("gprs", "VLs", "CAs",
+ "global_mems", "stack_slots",
+ "fixed_gprs")
+
+ def __init__(
+ self,
+ gprs, # type: dict[SSAGPRRange, tuple[int, ...]]
+ VLs, # type: dict[SSAKnownVL, int]
+ CAs, # type: dict[SSAVal[CAType], bool]
+ global_mems, # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
+ stack_slots, # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
+ fixed_gprs, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
+ ):
+ # type: (...) -> None
+ # the dicts are stored as passed in (not copied)
+ self.gprs = gprs
+ self.VLs = VLs
+ self.CAs = CAs
+ self.global_mems = global_mems
+ self.stack_slots = stack_slots
+ self.fixed_gprs = fixed_gprs
+
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
__slots__ = "id", "fn"
pass
if not just_id:
for name in fields(self):
+ if name in ("id", "fn"):
+ continue
v = getattr(self, name, _NOT_SET)
- if ((outputs is None or name in outputs)
- and isinstance(v, SSAVal)):
- v = v.__repr__(long=True)
- elif isinstance(v, Fn):
- v = v.__repr__(short=True)
+ if (outputs is not None and name in outputs
+ and outputs[name] is v):
+ fields_list.append(repr(v))
else:
- v = repr(v)
- fields_list.append(f"{name}={v}")
+ fields_list.append(f"{name}={v!r}")
fields_str = ', '.join(fields_list)
return f"{self.__class__.__name__}({fields_str})"
"""get the lines of assembly for this Op"""
...
+ @abstractmethod
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation"""
+ ...
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.ld {dest}, {src.start_byte}(1)"]
return [f"ld {dest}, {src.start_byte}(1)"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation"""
+ state.gprs[self.dest] = state.stack_slots[self.src]
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.std {src}, {dest.start_byte}(1)"]
return [f"std {src}, {dest.start_byte}(1)"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation"""
+ state.stack_slots[self.dest] = state.gprs[self.src]
+
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
elif src.ty != dest_ty:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
+ elif isinstance(src.ty, StackSlotType):
+ raise ValueError("can't use OpCopy on stack slots")
elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
length = src.ty.length
else:
return [f"or {dest_s}, {src_s}, {src_s}"]
raise NotImplementedError
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation.
+
+        copies the simulated value of `src` to `dest`, reading from and
+        writing to whichever dict in `state` matches each value's type.
+
+        raises NotImplementedError for any unsupported type combination.
+        """
+        gpr_types = GPRRangeType, FixedGPRRangeType
+        src_ty = self.src.ty
+        dest_ty = self.dest.ty
+        if isinstance(src_ty, gpr_types) and isinstance(dest_ty, gpr_types):
+            # all four GPR / fixed-GPR combinations are handled here, so
+            # no separate fixed<->non-fixed branches are needed.
+            if isinstance(src_ty, GPRRangeType):
+                v = state.gprs[self.src]  # type: ignore
+            else:
+                v = state.fixed_gprs[self.src]  # type: ignore
+            if isinstance(dest_ty, GPRRangeType):
+                state.gprs[self.dest] = v  # type: ignore
+            else:
+                state.fixed_gprs[self.dest] = v  # type: ignore
+        elif isinstance(src_ty, CAType) and src_ty == dest_ty:
+            state.CAs[self.dest] = state.CAs[self.src]  # type: ignore
+        elif isinstance(src_ty, KnownVLType) and src_ty == dest_ty:
+            state.VLs[self.dest] = state.VLs[self.src]  # type: ignore
+        elif isinstance(src_ty, GlobalMemType) and src_ty == dest_ty:
+            v = state.global_mems[self.src]  # type: ignore
+            state.global_mems[self.dest] = v  # type: ignore
+        else:
+            raise NotImplementedError
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
# type: (AsmContext) -> list[str]
return []
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: `dest` gets the words of
+ all `sources` joined end to end, in `sources` order"""
+ v = []
+ for src in self.sources:
+ v.extend(state.gprs[src])
+ state.gprs[self.dest] = tuple(v)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
ranges.append(GPRRangeType(src.ty.length - last))
self.src = src
self.results = tuple(
- SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
+ SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges))
def get_equality_constraints(self):
# type: () -> Iterable[EqualityConstraint]
# type: (AsmContext) -> list[str]
return []
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: cuts the words of `src`
+ into consecutive pieces, one per entry of `results`, each as long as
+ that result's type"""
+ rest = state.gprs[self.src]
+ # peel pieces off the end, walking `results` from last to first
+ for dest in reversed(self.results):
+ state.gprs[dest] = rest[-dest.ty.length:]
+ rest = rest[:-dest.ty.length]
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.{mnemonic} {out}, {RA}, {RB}"]
return [f"{mnemonic} {out}, {RA}, {RB}"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: adds `lhs` and `rhs` word
+ by word starting at index 0, propagating the carry from `CA_in`, and
+ stores the result words in `out` and the final carry in `CA_out`"""
+ carry = state.CAs[self.CA_in]
+ out = [] # type: list[int]
+ for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
+ if self.is_sub:
+ # for subtraction, `r` is replaced by its bitwise complement
+ # (within GPR width) before being added
+ r = r ^ GPR_VALUE_MASK
+ s = l + r + carry
+ # carry out is set when the sum doesn't fit in one GPR
+ carry = s != (s & GPR_VALUE_MASK)
+ out.append(s & GPR_VALUE_MASK)
+ state.CAs[self.CA_out] = carry
+ state.gprs[self.out] = tuple(out)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
mnemonic = "divmod2du/mrr"
return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation.
+
+        when `is_div` is set, divides the multi-word value `RA` by the
+        single word `RB`, going from the highest-index word down with the
+        first word of `RC` as the incoming remainder. otherwise multiplies
+        `RA` by `RB`, going from index 0 up with the first word of `RC` as
+        the incoming carry. the result words go to `RT` and the final
+        remainder/carry word goes to `RS`.
+        """
+        length = self.RT.ty.length
+        carry = state.gprs[self.RC][0]
+        RA = state.gprs[self.RA]
+        RB = state.gprs[self.RB][0]
+        RT = [0] * length
+        if self.is_div:
+            for i in reversed(range(length)):
+                if carry < RB and RB != 0:
+                    dividend = (carry << GPR_SIZE_IN_BITS) | RA[i]
+                    div, mod = divmod(dividend, RB)
+                    RT[i] = div & GPR_VALUE_MASK
+                    carry = mod & GPR_VALUE_MASK
+                else:
+                    # the quotient wouldn't fit in one word (or RB is
+                    # zero): produce all-ones and clear the remainder
+                    RT[i] = GPR_VALUE_MASK
+                    carry = 0
+        else:
+            for i in range(length):
+                v = RA[i] * RB + carry
+                carry = v >> GPR_SIZE_IN_BITS
+                RT[i] = v & GPR_VALUE_MASK
+        state.gprs[self.RS] = carry,
+        state.gprs[self.RT] = tuple(RT)
+
@final
@unique
RB = ctx.sgpr(self.sh)
return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: shifts the multi-word value
+ `inp` by the first word of `sh` modulo 64, using the first word of
+ `carry_in` as the extra word that bits are shifted in from"""
+ out = [0] * self.out.ty.length
+ carry = state.gprs[self.carry_in][0]
+ sh = state.gprs[self.sh][0] % 64
+ if self.kind is ShiftKind.Sl:
+ # `carry` is placed before word 0, so `inp[i + 1]` is input word `i`
+ inp = carry, *state.gprs[self.inp]
+ for i in reversed(range(self.out.ty.length)):
+ v = inp[i] | (inp[i + 1] << 64)
+ v <<= sh
+ # keep the upper word of the shifted two-word value
+ out[i] = (v >> 64) & GPR_VALUE_MASK
+ else:
+ # Sr and Sra share this branch
+ assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
+ # `carry` is placed after the last input word
+ inp = *state.gprs[self.inp], carry
+ for i in range(self.out.ty.length):
+ v = inp[i] | (inp[i + 1] << 64)
+ v >>= sh
+ # keep the lower word of the shifted two-word value
+ out[i] = v & GPR_VALUE_MASK
+ # state.gprs[self._out_padding] is intentionally not written
+ state.gprs[self.out] = tuple(out)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.{mnemonic} {out}, {inp}, {args}"]
return [f"{mnemonic} {out}, {inp}, {args}"]
+    def pre_ra_sim(self, state):
+        # type: (PreRASimState) -> None
+        """simulate op before register allocation.
+
+        shifts the first word of `inp` by the immediate `sh` according to
+        `kind` and stores the single-word result in `out`. for `Sra` the
+        input is treated as signed and `ca_out` is set when the input is
+        negative and non-zero bits were shifted out.
+
+        the stored result is always reduced to GPR width, so a left shift
+        can't leave bits above the register and an arithmetic right shift
+        of a negative input is stored as an unsigned word.
+        """
+        inp = state.gprs[self.inp][0]
+        if self.kind is ShiftKind.Sl:
+            assert self.ca_out is None
+            out = inp << self.sh
+        elif self.kind is ShiftKind.Sr:
+            assert self.ca_out is None
+            out = inp >> self.sh
+        else:
+            assert self.kind is ShiftKind.Sra
+            assert self.ca_out is not None
+            if inp & (1 << (GPR_SIZE_IN_BITS - 1)):  # sign extend
+                inp -= 1 << GPR_SIZE_IN_BITS
+            out = inp >> self.sh
+            # the carry must be computed from the signed, unmasked result
+            ca = inp < 0 and (out << self.sh) != inp
+            state.CAs[self.ca_out] = ca
+        state.gprs[self.out] = (out & GPR_VALUE_MASK,)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.addi {out}, 0, {self.value}"]
return [f"addi {out}, 0, {self.value}"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: every word of `out` is set
+ to `value` reduced to GPR width"""
+ value = self.value & GPR_VALUE_MASK
+ state.gprs[self.out] = (value,) * self.out.ty.length
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return ["subfic 0, 0, -1"]
return ["addic 0, 0, 0"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ state.CAs[self.out] = self.value
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.ld {RT}, {self.offset}({RA})"]
return [f"ld {RT}, {self.offset}({RA})"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: loads consecutive words
+ into `RT` from the byte-addressed memory `mem`, starting at the first
+ word of `RA` plus `offset`.
+
+ raises ValueError if a word's address isn't a multiple of the GPR size.
+ """
+ addr = state.gprs[self.RA][0]
+ addr += self.offset
+ RT = [0] * self.RT.ty.length
+ mem = state.global_mems[self.mem]
+ for i in range(self.RT.ty.length):
+ # addresses wrap around at GPR width
+ cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
+ if cur_addr % GPR_SIZE_IN_BYTES != 0:
+ raise ValueError(f"can't load from unaligned address: "
+ f"{cur_addr:#x}")
+ # assemble the word little-endian; bytes missing from `mem` read as 0
+ for j in range(GPR_SIZE_IN_BYTES):
+ byte_val = mem.get(cur_addr + j, 0) & 0xFF
+ RT[i] |= byte_val << (j * 8)
+ state.gprs[self.RT] = tuple(RT)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return [f"sv.std {RS}, {self.offset}({RA})"]
return [f"std {RS}, {self.offset}({RA})"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: stores the words of `RS`
+ at consecutive addresses starting at the first word of `RA` plus
+ `offset`, producing `mem_out` as an updated copy of `mem_in`.
+
+ raises ValueError if a word's address isn't a multiple of the GPR size.
+ """
+ # work on a mutable copy so the value of `mem_in` isn't modified
+ mem = dict(state.global_mems[self.mem_in])
+ addr = state.gprs[self.RA][0]
+ addr += self.offset
+ RS = state.gprs[self.RS]
+ for i in range(self.RS.ty.length):
+ # addresses wrap around at GPR width
+ cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
+ if cur_addr % GPR_SIZE_IN_BYTES != 0:
+ raise ValueError(f"can't store to unaligned address: "
+ f"{cur_addr:#x}")
+ # write the word out little-endian, one byte per address
+ for j in range(GPR_SIZE_IN_BYTES):
+ mem[cur_addr + j] = (RS[i] >> (j * 8)) & 0xFF
+ state.global_mems[self.mem_out] = FMap(mem)
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
# type: (AsmContext) -> list[str]
return []
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: keeps any value already
+ present in `state.fixed_gprs` for `out`, otherwise fills it with zeros"""
+ if self.out not in state.fixed_gprs:
+ state.fixed_gprs[self.out] = (0,) * self.out.ty.length
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
# type: (AsmContext) -> list[str]
return []
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ """simulate op before register allocation: keeps any memory already
+ present in `state.global_mems` for `out`, otherwise uses an empty FMap"""
+ if self.out not in state.global_mems:
+ state.global_mems[self.out] = FMap()
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
# type: (AsmContext) -> list[str]
return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
+ def pre_ra_sim(self, state):
+ # type: (PreRASimState) -> None
+ state.VLs[self.out] = self.out.ty.length
+
def op_set_to_list(ops):
# type: (Iterable[Op]) -> list[Op]
repr(reg_assignments),
"AllocationFailed("
"node=IGNode(#0, merged_reg_set=MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
"edges={}, reg=None), "
- "live_intervals=LiveIntervals("
- "live_intervals={"
- "MergedRegSet([(<#0.out>, 0)]): "
+ "live_intervals=LiveIntervals(live_intervals={"
+ "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
"LiveInterval(first_write=0, last_use=1), "
- "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): "
+ "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]): "
"LiveInterval(first_write=1, last_use=4), "
- "MergedRegSet([(<#2.out>, 0)]): "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
"LiveInterval(first_write=2, last_use=3)}, "
"merged_reg_sets=MergedRegSets(data={"
- "<#0.out>: MergedRegSet([(<#0.out>, 0)]), "
- "<#1.out>: MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
- "<#2.out>: MergedRegSet([(<#2.out>, 0)]), "
- "<#3.out>: MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
- "<#4.dest>: MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), "
+ "<#0.out: KnownVLType(length=52)>: "
+ "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
+ "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "<#2.out: KnownVLType(length=64)>: "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
+ "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])}), "
"reg_sets_live_after={"
- "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), "
+ "0: OFSet([MergedRegSet(["
+ "(<#0.out: KnownVLType(length=52)>, 0)])]), "
"1: OFSet([MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])]), "
"2: OFSet([MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
- "MergedRegSet([(<#2.out>, 0)])]), "
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
"3: OFSet([MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])]), "
"4: OFSet()}), "
"interference_graph=InterferenceGraph(nodes={"
- "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), "
- "edges={}, reg=None), "
+ "...: IGNode(#0, merged_reg_set=MergedRegSet(["
+ "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
"...: IGNode(#1, merged_reg_set=MergedRegSet(["
- "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
- "edges={}, reg=None), "
- "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), "
- "edges={}, reg=None)}))"
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
+ "...: IGNode(#2, merged_reg_set=MergedRegSet(["
+ "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
)
def test_try_alloc_bigint_inc(self):
import unittest
-from bigint_presentation_code.toom_cook import ToomCookInstance
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
+ GlobalMem, GPRRange,
+ GPRRangeType, OpCopy,
+ OpFuncArg, OpInputMem,
+ OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
+ generate_assembly)
+from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.util import FMap
+
+
+class SimpleMul192x192:
+ """builds a `Fn` that multiplies two 3-word inputs with `simple_mul` and
+ stores the 6-word product through a destination pointer.
+
+ the ops are created in a fixed order; the tests below assert on the
+ resulting op ids, so don't reorder the statements.
+ """
+ def __init__(self):
+ self.fn = fn = Fn()
+ self.mem_in = mem = OpInputMem(fn).out
+ # function arguments live in fixed registers: r3, r4..r6 and r7..r9
+ self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
+ self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
+ self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
+ # copy the arguments out of the fixed registers into non-fixed ranges
+ dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
+ vl = OpSetVLImm(fn, 3).out
+ lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
+ rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
+ retval = simple_mul(fn, lhs, rhs)
+ vl = OpSetVLImm(fn, 6).out
+ self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
+ mem_in=mem, vl=vl).mem_out
class TestToomCook(unittest.TestCase):
- def test_toom_2(self):
+ maxDiff = None
+
+ def test_toom_2_repr(self):
TOOM_2 = ToomCookInstance.make_toom_2()
# print(repr(repr(TOOM_2)))
self.assertEqual(
"EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
)
- def test_toom_2_5(self):
+ def test_toom_2_5_repr(self):
TOOM_2_5 = ToomCookInstance.make_toom_2_5()
# print(repr(repr(TOOM_2_5)))
self.assertEqual(
"EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
)
- def test_reversed_toom_2_5(self):
+ def test_reversed_toom_2_5_repr(self):
TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
- print(repr(repr(TOOM_2_5)))
+ # print(repr(repr(TOOM_2_5)))
self.assertEqual(
repr(TOOM_2_5),
"ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
"EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
)
+ def test_simple_mul_192x192_pre_ra_sim(self):
+ """run the pre-register-allocation simulation of SimpleMul192x192 and
+ check the bytes stored to memory against the expected product"""
+ # test multiplying:
+ # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
+ # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
+ # ==
+ # int("0x000074736574206e_6f69746163696c70"
+ # "_69746c756d207469_622d3438333e2d32"
+ # "_3931783239312079_7261727469627261", base=0)
+ # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
+ # 'little')
+ code = SimpleMul192x192()
+ dest_ptr = 0x100
+ # inputs are given as tuples of words, least-significant word first
+ state = PreRASimState(
+ gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
+ stack_slots={}, fixed_gprs={
+ code.dest_ptr_in: (dest_ptr,),
+ code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
+ 0x000191acb262e15b),
+ code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
+ 0x4a37c0567bcbab53)
+ })
+ code.fn.pre_ra_sim(state)
+ expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
+ OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
+ # pad with zero bytes up to the full 6-word output size
+ expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
+ mem_out = state.global_mems[code.mem_out]
+ # bytes that were never stored read as 0
+ out_bytes = bytes(
+ mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
+ self.assertEqual(out_bytes, expected_bytes)
+
+ def test_simple_mul_192x192_ops(self):
+ code = SimpleMul192x192()
+ fn = code.fn
+ self.assertEqual([repr(v) for v in fn.ops], [
+ 'OpInputMem(#0, <#0.out: GlobalMemType()>)',
+ 'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
+ 'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
+ 'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
+ 'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
+ 'vl=None)',
+ 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
+ 'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
+ 'src=<#2.out: <fixed(<r4..len=3>)>>, '
+ 'vl=<#5.out: KnownVLType(length=3)>)',
+ 'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
+ 'src=<#3.out: <fixed(<r7..len=3>)>>, '
+ 'vl=<#5.out: KnownVLType(length=3)>)',
+ 'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
+ '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
+ 'src=<#7.dest: <gpr_ty[3]>>)',
+ 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
+ 'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
+ 'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
+ '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
+ 'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
+ '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
+ 'OpSetCA(#15, <#15.out: CAType()>, value=False)',
+ 'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
+ 'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
+ 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
+ 'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+ 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
+ 'is_sub=False, vl=None)',
+ 'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
+ '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
+ '<#17.out: <gpr_ty[1]>>))',
+ 'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
+ '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
+ 'OpSetCA(#21, <#21.out: CAType()>, value=False)',
+ 'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
+ 'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
+ 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
+ 'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+ 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
+ 'is_sub=False, vl=None)',
+ 'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
+ '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
+ '<#23.out: <gpr_ty[1]>>))',
+ 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
+ 'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
+ 'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
+ 'mem_in=<#0.out: GlobalMemType()>, '
+ '<#26.mem_out: GlobalMemType()>, '
+ 'vl=<#25.out: KnownVLType(length=6)>)'
+ ])
+
+ # FIXME: register allocator currently allocates wrong registers
+ @unittest.expectedFailure
+ def test_simple_mul_192x192_reg_alloc(self):
+ code = SimpleMul192x192()
+ fn = code.fn
+ assigned_registers = allocate_registers(fn.ops)
+ self.assertEqual(assigned_registers, {
+ fn.ops[13].RS: GPRRange(9), # type: ignore
+ fn.ops[14].results[0]: GPRRange(6), # type: ignore
+ fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore
+ fn.ops[15].out: XERBit.CA, # type: ignore
+ fn.ops[16].out: GPRRange(7, length=3), # type: ignore
+ fn.ops[16].CA_out: XERBit.CA, # type: ignore
+ fn.ops[17].out: GPRRange(10), # type: ignore
+ fn.ops[17].CA_out: XERBit.CA, # type: ignore
+ fn.ops[18].dest: GPRRange(6, length=5), # type: ignore
+ fn.ops[19].RT: GPRRange(3, length=3), # type: ignore
+ fn.ops[19].RS: GPRRange(9), # type: ignore
+ fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore
+ fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore
+ fn.ops[21].out: XERBit.CA, # type: ignore
+ fn.ops[22].out: GPRRange(8, length=3), # type: ignore
+ fn.ops[22].CA_out: XERBit.CA, # type: ignore
+ fn.ops[23].out: GPRRange(11), # type: ignore
+ fn.ops[23].CA_out: XERBit.CA, # type: ignore
+ fn.ops[24].dest: GPRRange(6, length=6), # type: ignore
+ fn.ops[25].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore
+ fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore
+ fn.ops[1].out: GPRRange(3), # type: ignore
+ fn.ops[2].out: GPRRange(4, length=3), # type: ignore
+ fn.ops[3].out: GPRRange(7, length=3), # type: ignore
+ fn.ops[4].dest: GPRRange(12), # type: ignore
+ fn.ops[5].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[6].dest: GPRRange(17, length=3), # type: ignore
+ fn.ops[7].dest: GPRRange(14, length=3), # type: ignore
+ fn.ops[8].results[0]: GPRRange(14), # type: ignore
+ fn.ops[8].results[1]: GPRRange(15), # type: ignore
+ fn.ops[8].results[2]: GPRRange(16), # type: ignore
+ fn.ops[9].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[10].out: GPRRange(9), # type: ignore
+ fn.ops[11].RT: GPRRange(6, length=3), # type: ignore
+ fn.ops[11].RS: GPRRange(9), # type: ignore
+ fn.ops[12].dest: GPRRange(6, length=4), # type: ignore
+ fn.ops[13].RT: GPRRange(3, length=3) # type: ignore
+ })
+ self.fail("register allocator currently allocates wrong registers")
+
+ # FIXME: register allocator currently allocates wrong registers
+ @unittest.expectedFailure
+ def test_simple_mul_192x192_asm(self):
+ code = SimpleMul192x192()
+ asm = generate_assembly(code.fn.ops)
+ self.assertEqual(asm, [
+ 'or 12, 3, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *17, *4, *4',
+ 'sv.or *14, *7, *7',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'addi 9, 0, 0',
+ 'sv.maddedu *6, *17, 14, 9',
+ 'sv.maddedu *3, *17, 15, 9',
+ 'addic 0, 0, 0',
+ 'sv.adde *7, *3, *7',
+ 'adde 10, 9, 9',
+ 'sv.maddedu *3, *17, 16, 9',
+ 'addic 0, 0, 0',
+ 'sv.adde *8, *3, *8',
+ 'adde 11, 9, 9',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'sv.std *6, 0(12)',
+ 'bclr 20, 0, 0'
+ ])
+ self.fail("register allocator currently allocates wrong registers")
+
if __name__ == "__main__":
unittest.main()
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir import Fn, Op
+from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange
from bigint_presentation_code.matrix import Matrix
from bigint_presentation_code.util import Literal, OSet, final
# TODO: add make_toom_3
-def toom_cook_mul(fn, word_count, instances):
- # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
- retval = OSet() # type: OSet[Op]
- raise NotImplementedError
+def simple_mul(fn, lhs, rhs):
+ # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
+ """ simple O(n^2) big-int unsigned multiply
+
+ appends the ops computing `lhs * rhs` to `fn` and returns the SSA value
+ holding the product. each word of the shorter operand is multiplied by
+ the whole longer operand and the partial product is added into the
+ running result at that word's offset.
+ """
+ # make `lhs` the longer operand so the per-word loop runs over the shorter
+ if lhs.ty.length < rhs.ty.length:
+ lhs, rhs = rhs, lhs
+ # split rhs into elements
+ rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results
+ retval = None
+ vl = OpSetVLImm(fn, lhs.ty.length).out
+ zero = OpLI(fn, 0).out
+ for shift, rhs_word in enumerate(rhs_words):
+ mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero,
+ is_div=False, vl=vl)
+ if retval is None:
+ # first partial product: low words in RT followed by the word in RS
+ retval = OpConcat(fn, [mul.RT, mul.RS]).dest
+ else:
+ # words below `shift` are already final; add into the rest
+ first_part, last_part = OpSplit(fn, retval, [shift]).results
+ add = OpBigIntAddSub(
+ fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out,
+ is_sub=False, vl=vl)
+ # new top word: the RS word of the product plus the carry out of `add`
+ add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out,
+ is_sub=False)
+ retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest
+ assert retval is not None
 return retval
+
+
+def toom_cook_mul(fn, lhs, rhs, instances):
+ # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
+ """placeholder: not implemented yet, always raises NotImplementedError"""
+ raise NotImplementedError