BaseSimState, Fn,
GenAsmState, OpKind,
PostRASimState,
- PreRASimState)
+ PreRASimState, SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
-from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul,
+ toom_cook_mul)
-class SimpleMul192x192:
- def __init__(self):
+def simple_umul(fn, lhs, rhs):
+ # type: (Fn, SSAVal, SSAVal) -> SSAVal
+ return simple_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+ rhs_signed=False, name="simple_umul")
+
+
+class Mul:
+ def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
+ # type: (Callable[[Fn, SSAVal, SSAVal], SSAVal], int, int) -> None
super().__init__()
self.fn = fn = Fn()
self.dest_offset = 0
- self.lhs_offset = 48 + self.dest_offset
- self.rhs_offset = 24 + self.lhs_offset
+ self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words
+ self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES
+ self.lhs_size_in_words = lhs_size_in_words
+ self.lhs_size_in_bytes = self.lhs_size_in_words * GPR_SIZE_IN_BYTES
+ self.rhs_size_in_words = rhs_size_in_words
+ self.rhs_size_in_bytes = self.rhs_size_in_words * GPR_SIZE_IN_BYTES
+ self.lhs_offset = self.dest_size_in_bytes + self.dest_offset
+ self.rhs_offset = self.lhs_size_in_bytes + self.lhs_offset
self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3,
name="ptr_in").outputs[0]
- setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3],
- maxvl=3, name="setvl3")
+ lhs_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[lhs_size_in_words],
+ maxvl=lhs_size_in_words, name="lhs_setvl")
load_lhs = fn.append_new_op(
kind=OpKind.SvLd, immediates=[self.lhs_offset],
- input_vals=[self.ptr_in, setvl3.outputs[0]],
- name="load_lhs", maxvl=3)
+ input_vals=[self.ptr_in, lhs_setvl.outputs[0]],
+ name="load_lhs", maxvl=lhs_size_in_words)
+ rhs_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[rhs_size_in_words],
+ maxvl=rhs_size_in_words, name="rhs_setvl")
load_rhs = fn.append_new_op(
kind=OpKind.SvLd, immediates=[self.rhs_offset],
- input_vals=[self.ptr_in, setvl3.outputs[0]],
+ input_vals=[self.ptr_in, rhs_setvl.outputs[0]],
name="load_rhs", maxvl=3)
- retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
- setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6],
- maxvl=6, name="setvl6")
+ retval = mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
+ dest_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[self.dest_size_in_words],
+ maxvl=self.dest_size_in_words, name="dest_setvl")
fn.append_new_op(
kind=OpKind.SvStd,
- input_vals=[retval, self.ptr_in, setvl6.outputs[0]],
- immediates=[self.dest_offset], maxvl=6, name="store_dest")
+ input_vals=[retval, self.ptr_in, dest_setvl.outputs[0]],
+ immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
+ name="store_dest")
class TestToomCook(unittest.TestCase):
)
def test_simple_mul_192x192_pre_ra_sim(self):
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+
def create_sim_state(code):
- # type: (SimpleMul192x192) -> BaseSimState
+ # type: (Mul) -> BaseSimState
return PreRASimState(ssa_vals={}, memory={})
self.tst_simple_mul_192x192_sim(create_sim_state)
def test_simple_mul_192x192_post_ra_sim(self):
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+
def create_sim_state(code):
- # type: (SimpleMul192x192) -> BaseSimState
+ # type: (Mul) -> BaseSimState
ssa_val_to_loc_map = allocate_registers(code.fn)
return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map,
memory={}, loc_values={})
self.tst_simple_mul_192x192_sim(create_sim_state)
def tst_simple_mul_192x192_sim(self, create_sim_state):
- # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None
+ # type: (Callable[[Mul], BaseSimState]) -> None
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
# test multiplying:
# 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
# * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
# "_3931783239312079_7261727469627261", base=0)
# == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
# 'little')
- code = SimpleMul192x192()
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
state = create_sim_state(code)
ptr_in = 0x100
dest_ptr = ptr_in + code.dest_offset
self.assertEqual(out_bytes, expected_bytes)
def test_simple_mul_192x192_ops(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
self.assertEqual([repr(v) for v in fn.ops], [
"Op(kind=OpKind.FuncArgR3, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<setvl3.outputs[0]: <VL_MAXVL>>,), "
- "name='setvl3')",
+ "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='lhs_setvl')",
"Op(kind=OpKind.SvLd, "
"input_vals=[<ptr_in.outputs[0]: <I64>>, "
- "<setvl3.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<load_lhs.input_uses[0]: <I64>>, "
"<load_lhs.input_uses[1]: <VL_MAXVL>>), immediates=[48], "
"outputs=(<load_lhs.outputs[0]: <I64*3>>,), "
"name='load_lhs')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), immediates=[3], "
+ "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='rhs_setvl')",
"Op(kind=OpKind.SvLd, "
"input_vals=[<ptr_in.outputs[0]: <I64>>, "
- "<setvl3.outputs[0]: <VL_MAXVL>>], "
+ "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<load_rhs.input_uses[0]: <I64>>, "
"<load_rhs.input_uses[1]: <VL_MAXVL>>), immediates=[72], "
"outputs=(<load_rhs.outputs[0]: <I64*3>>,), "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='rhs_setvl')",
+ "outputs=(<rhs_setvl2.outputs[0]: <VL_MAXVL>>,), "
+ "name='rhs_setvl2')",
"Op(kind=OpKind.Spread, "
"input_vals=[<load_rhs.outputs[0]: <I64*3>>, "
- "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<rhs_setvl2.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<rhs_spread.input_uses[0]: <I64*3>>, "
"<rhs_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<rhs_spread.outputs[0]: <I64>>, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='lhs_setvl')",
+ "outputs=(<lhs_setvl3.outputs[0]: <VL_MAXVL>>,), "
+ "name='lhs_setvl3')",
"Op(kind=OpKind.LI, "
"input_vals=[], "
"input_uses=(), immediates=[0], "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[0]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul0.input_uses[0]: <I64*3>>, "
"<mul0.input_uses[1]: <I64>>, "
"<mul0.input_uses[2]: <I64>>, "
"name='mul0')",
"Op(kind=OpKind.Spread, "
"input_vals=[<mul0.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul0_rt_spread.input_uses[0]: <I64*3>>, "
"<mul0_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<mul0_rt_spread.outputs[0]: <I64>>, "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[1]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul1.input_uses[0]: <I64*3>>, "
"<mul1.input_uses[1]: <I64>>, "
"<mul1.input_uses[2]: <I64>>, "
"input_vals=[<mul0_rt_spread.outputs[1]: <I64>>, "
"<mul0_rt_spread.outputs[2]: <I64>>, "
"<mul0.outputs[1]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1_rb_concat.input_uses[0]: <I64>>, "
"<add1_rb_concat.input_uses[1]: <I64>>, "
"<add1_rb_concat.input_uses[2]: <I64>>, "
"input_vals=[<mul1.outputs[0]: <I64*3>>, "
"<add1_rb_concat.outputs[0]: <I64*3>>, "
"<clear_ca1.outputs[0]: <CA>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1.input_uses[0]: <I64*3>>, "
"<add1.input_uses[1]: <I64*3>>, "
"<add1.input_uses[2]: <CA>>, "
"name='add1')",
"Op(kind=OpKind.Spread, "
"input_vals=[<add1.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1_rt_spread.input_uses[0]: <I64*3>>, "
"<add1_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<add1_rt_spread.outputs[0]: <I64>>, "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[2]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul2.input_uses[0]: <I64*3>>, "
"<mul2.input_uses[1]: <I64>>, "
"<mul2.input_uses[2]: <I64>>, "
"input_vals=[<add1_rt_spread.outputs[1]: <I64>>, "
"<add1_rt_spread.outputs[2]: <I64>>, "
"<add_hi1.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2_rb_concat.input_uses[0]: <I64>>, "
"<add2_rb_concat.input_uses[1]: <I64>>, "
"<add2_rb_concat.input_uses[2]: <I64>>, "
"input_vals=[<mul2.outputs[0]: <I64*3>>, "
"<add2_rb_concat.outputs[0]: <I64*3>>, "
"<clear_ca2.outputs[0]: <CA>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2.input_uses[0]: <I64*3>>, "
"<add2.input_uses[1]: <I64*3>>, "
"<add2.input_uses[2]: <CA>>, "
"name='add2')",
"Op(kind=OpKind.Spread, "
"input_vals=[<add2.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2_rt_spread.input_uses[0]: <I64*3>>, "
"<add2_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<add2_rt_spread.outputs[0]: <I64>>, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[6], "
- "outputs=(<setvl6.outputs[0]: <VL_MAXVL>>,), "
- "name='setvl6')",
+ "outputs=(<dest_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='dest_setvl')",
"Op(kind=OpKind.SvStd, "
"input_vals=[<concat_retval.outputs[0]: <I64*6>>, "
"<ptr_in.outputs[0]: <I64>>, "
- "<setvl6.outputs[0]: <VL_MAXVL>>], "
+ "<dest_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<store_dest.input_uses[0]: <I64*6>>, "
"<store_dest.input_uses[1]: <I64>>, "
"<store_dest.input_uses[2]: <VL_MAXVL>>), immediates=[0], "
])
def test_simple_mul_192x192_reg_alloc(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
assigned_registers = allocate_registers(fn)
self.assertEqual(
"Loc(kind=LocKind.GPR, start=4, reg_len=6), "
"<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<setvl6.outputs[0]: <VL_MAXVL>>: "
+ "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=6), "
"Loc(kind=LocKind.GPR, start=18, reg_len=1), "
"<zero.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<rhs_spread.out2.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=19, reg_len=1), "
"Loc(kind=LocKind.GPR, start=3, reg_len=3), "
"<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "<rhs_setvl2.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=3), "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_rhs.inp0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
"Loc(kind=LocKind.GPR, start=20, reg_len=3), "
"<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_lhs.inp0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=6, reg_len=1), "
- "<setvl3.outputs[0]: <VL_MAXVL>>: "
+ "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<ptr_in.out0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=23, reg_len=1), "
"}")
def test_simple_mul_192x192_asm(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
assigned_registers = allocate_registers(fn)
gen_asm_state = GenAsmState(assigned_registers)
'sv.ld *3, 48(6)',
'setvl 0, 0, 3, 0, 1, 1',
'sv.or *20, *3, *3',
+ 'setvl 0, 0, 3, 0, 1, 1',
'or 6, 23, 23',
'setvl 0, 0, 3, 0, 1, 1',
'sv.ld *3, 72(6)',
'sv.std *4, 0(3)'
])
+ def test_toom_2_mul_256x256_asm(self):
+ self.skipTest("WIP") # FIXME: finish
+ TOOM_2 = ToomCookInstance.make_toom_2()
+ instances = TOOM_2, TOOM_2
+
+ def mul(fn, lhs, rhs):
+ # type: (Fn, SSAVal, SSAVal) -> SSAVal
+ return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+ rhs_signed=False, instances=instances)
+ code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3)
+ fn = code.fn
+ assigned_registers = allocate_registers(fn)
+ gen_asm_state = GenAsmState(assigned_registers)
+ fn.gen_asm(gen_asm_state)
+ self.assertEqual(gen_asm_state.output, [
+ ])
+
if __name__ == "__main__":
unittest.main()