1 from contextlib
import contextmanager
4 from typing
import Any
, Callable
, ContextManager
, Iterator
, Tuple
, Iterable
6 from bigint_presentation_code
.compiler_ir
import (GPR_SIZE_IN_BITS
,
8 GPR_VALUE_MASK
, BaseSimState
,
9 Fn
, GenAsmState
, OpKind
,
11 PreRASimState
, SSAVal
)
12 from bigint_presentation_code
.register_allocator
import allocate_registers
13 from bigint_presentation_code
.register_allocator_test_util
import GraphDumper
14 from bigint_presentation_code
.toom_cook
import (ToomCookInstance
, ToomCookMul
,
16 from bigint_presentation_code
.util
import OSet
18 _StateFactory
= Callable
[[], ContextManager
[BaseSimState
]]
def simple_umul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, None]
    """Multiply `lhs` by `rhs` via `simple_mul` with both operands unsigned.

    The call forwards `fn`, `lhs` and `rhs` unchanged, forces
    `lhs_signed` and `rhs_signed` to False, and names the operation "mul".

    Returns a 2-tuple: the value returned by `simple_mul`, followed by
    None as a placeholder second element.
    """
    product = simple_mul(
        fn=fn,
        lhs=lhs,
        lhs_signed=False,
        rhs=rhs,
        rhs_signed=False,
        name="mul",
    )
    return product, None
27 def get_pre_ra_state_factory(code
):
28 # type: (Mul) -> _StateFactory
31 state
= PreRASimState(ssa_vals
={}, memory
={})
32 with state
.set_as_current_debugging_state():
38 _MulFn
= Callable
[[Fn
, SSAVal
, SSAVal
], Tuple
[SSAVal
, Any
]]
40 def __init__(self
, mul
, lhs_size_in_words
, rhs_size_in_words
):
41 # type: (_MulFn, int, int) -> None
45 self
.dest_size_in_words
= lhs_size_in_words
+ rhs_size_in_words
46 self
.dest_size_in_bytes
= self
.dest_size_in_words
* GPR_SIZE_IN_BYTES
47 self
.lhs_size_in_words
= lhs_size_in_words
48 self
.lhs_size_in_bytes
= self
.lhs_size_in_words
* GPR_SIZE_IN_BYTES
49 self
.rhs_size_in_words
= rhs_size_in_words
50 self
.rhs_size_in_bytes
= self
.rhs_size_in_words
* GPR_SIZE_IN_BYTES
51 self
.lhs_offset
= self
.dest_size_in_bytes
+ self
.dest_offset
52 self
.rhs_offset
= self
.lhs_size_in_bytes
+ self
.lhs_offset
53 self
.ptr_in
= fn
.append_new_op(kind
=OpKind
.FuncArgR3
,
54 name
="ptr_in").outputs
[0]
55 self
.lhs_setvl
= fn
.append_new_op(
56 kind
=OpKind
.SetVLI
, immediates
=[lhs_size_in_words
],
57 maxvl
=lhs_size_in_words
, name
="lhs_setvl")
58 self
.load_lhs
= fn
.append_new_op(
59 kind
=OpKind
.SvLd
, immediates
=[self
.lhs_offset
],
60 input_vals
=[self
.ptr_in
, self
.lhs_setvl
.outputs
[0]],
61 name
="load_lhs", maxvl
=lhs_size_in_words
)
62 self
.rhs_setvl
= fn
.append_new_op(
63 kind
=OpKind
.SetVLI
, immediates
=[rhs_size_in_words
],
64 maxvl
=rhs_size_in_words
, name
="rhs_setvl")
65 self
.load_rhs
= fn
.append_new_op(
66 kind
=OpKind
.SvLd
, immediates
=[self
.rhs_offset
],
67 input_vals
=[self
.ptr_in
, self
.rhs_setvl
.outputs
[0]],
68 name
="load_rhs", maxvl
=rhs_size_in_words
)
70 fn
, self
.load_lhs
.outputs
[0], self
.load_rhs
.outputs
[0])
71 self
.dest_setvl
= fn
.append_new_op(
72 kind
=OpKind
.SetVLI
, immediates
=[self
.dest_size_in_words
],
73 maxvl
=self
.dest_size_in_words
, name
="dest_setvl")
74 self
.store
= fn
.append_new_op(
76 input_vals
=[self
.retval
[0], self
.ptr_in
,
77 self
.dest_setvl
.outputs
[0]],
78 immediates
=[self
.dest_offset
], maxvl
=self
.dest_size_in_words
,
82 class TestToomCook(unittest
.TestCase
):
85 def get_post_ra_state_factory(self
, code
):
86 # type: (Mul) -> _StateFactory
87 ssa_val_to_loc_map
= allocate_registers(
88 code
.fn
, debug_out
=sys
.stdout
, dump_graph
=GraphDumper(self
))
93 ssa_val_to_loc_map
=ssa_val_to_loc_map
,
94 memory
={}, loc_values
={})
97 def test_toom_2_repr(self
):
98 TOOM_2
= ToomCookInstance
.make_toom_2()
99 # print(repr(repr(TOOM_2)))
102 "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
103 "eval_points=(0, 1, POINT_AT_INFINITY), "
105 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
107 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
109 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
110 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
111 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
113 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
115 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
117 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
118 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
119 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
121 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
124 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
126 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
127 "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
129 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
131 "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
132 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
135 def test_toom_2_5_repr(self
):
136 TOOM_2_5
= ToomCookInstance
.make_toom_2_5()
137 # print(repr(repr(TOOM_2_5)))
140 "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
141 "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
142 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
145 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
146 "rhs=EvalOpInput(lhs=2, rhs=0, "
147 "poly=EvalOpPoly({2: Fraction(1, 1)})), "
148 "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
149 "rhs=EvalOpInput(lhs=1, rhs=0, "
150 "poly=EvalOpPoly({1: Fraction(1, 1)})), "
152 "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
155 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
156 "rhs=EvalOpInput(lhs=2, rhs=0, "
157 "poly=EvalOpPoly({2: Fraction(1, 1)})), "
158 "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
159 "rhs=EvalOpInput(lhs=1, rhs=0, "
160 "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
161 "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
162 "EvalOpInput(lhs=2, rhs=0, "
163 "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
164 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
165 "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
166 "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
167 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
168 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
170 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
171 "rhs=EvalOpInput(lhs=1, rhs=0, "
172 "poly=EvalOpPoly({1: Fraction(1, 1)})), "
173 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
174 "EvalOpInput(lhs=1, rhs=0, "
175 "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
177 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
178 "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
179 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
180 "rhs=EvalOpInput(lhs=2, rhs=0, "
181 "poly=EvalOpPoly({2: Fraction(1, 1)})), "
182 "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
184 "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
185 "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
187 "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
188 "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
189 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
191 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
192 "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
193 "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
194 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
196 "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
197 "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
200 def test_reversed_toom_2_5_repr(self
):
201 TOOM_2_5
= ToomCookInstance
.make_toom_2_5().reversed()
202 # print(repr(repr(TOOM_2_5)))
205 "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
206 "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
207 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
209 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
211 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
212 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
214 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
216 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
217 "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
218 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
220 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
221 "EvalOpAdd(lhs=EvalOpAdd(lhs="
222 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
224 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
225 "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
226 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
228 "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
229 "EvalOpSub(lhs=EvalOpAdd(lhs="
230 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
232 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
233 "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
234 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
236 "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
237 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
239 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
240 "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
241 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
243 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
244 "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
246 "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
247 "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
249 "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
250 "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
251 "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
253 "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
254 "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
255 "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
256 "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
258 "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
259 "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
def test_simple_mul_192x192_pre_ra_sim(self):
    # Run the shared 192x192 simulation check once per signedness
    # combination, in the order (False, False), (False, True),
    # (True, False), (True, True), passing the module-level
    # get_pre_ra_state_factory as the state-factory provider.
    signedness_cases = [(lhs_signed, rhs_signed)
                        for lhs_signed in (False, True)
                        for rhs_signed in (False, True)]
    for lhs_signed, rhs_signed in signedness_cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed,
            rhs_signed=rhs_signed,
            get_state_factory=get_pre_ra_state_factory)
def test_simple_mul_192x192_post_ra_sim(self):
    # Run the shared 192x192 simulation check once per signedness
    # combination, in the order (False, False), (False, True),
    # (True, False), (True, True), passing this test case's own
    # get_post_ra_state_factory method as the state-factory provider.
    signedness_cases = [(lhs_signed, rhs_signed)
                        for lhs_signed in (False, True)
                        for rhs_signed in (False, True)]
    for lhs_signed, rhs_signed in signedness_cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed,
            rhs_signed=rhs_signed,
            get_state_factory=self.get_post_ra_state_factory)
276 def tst_simple_mul_192x192_sim(
277 self
, lhs_signed
, # type: bool
278 rhs_signed
, # type: bool
279 get_state_factory
, # type: Callable[[Mul], _StateFactory]
282 # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
283 # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
285 # int("0x000074736574206e_6f69746163696c70"
286 # "_69746c756d207469_622d3438333e2d32"
287 # "_3931783239312079_7261727469627261", base=0)
288 # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
290 lhs_value
= 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
291 rhs_value
= 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
292 prod_value
= int.from_bytes(
293 b
"arbitrary 192x192->384-bit multiplication test", 'little')
294 self
.assertEqual(lhs_value
* rhs_value
, prod_value
)
296 mul
=lambda fn
, lhs
, rhs
: (simple_mul(
297 fn
=fn
, lhs
=lhs
, lhs_signed
=lhs_signed
,
298 rhs
=rhs
, rhs_signed
=rhs_signed
, name
="mul"), None),
299 lhs_size_in_words
=3, rhs_size_in_words
=3)
300 state_factory
= get_state_factory(code
)
302 dest_ptr
= ptr_in
+ code
.dest_offset
303 lhs_ptr
= ptr_in
+ code
.lhs_offset
304 rhs_ptr
= ptr_in
+ code
.rhs_offset
305 for lhs_neg
in False, True:
306 for rhs_neg
in False, True:
307 if lhs_neg
and not lhs_signed
:
309 if rhs_neg
and not rhs_signed
:
311 with self
.subTest(lhs_signed
=lhs_signed
,
312 rhs_signed
=rhs_signed
,
313 lhs_neg
=lhs_neg
, rhs_neg
=rhs_neg
):
314 with
state_factory() as state
:
315 state
[code
.ptr_in
] = ptr_in
,
323 v
= (lhs
>> GPR_SIZE_IN_BITS
* i
) & GPR_VALUE_MASK
324 state
.store(lhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
326 v
= (rhs
>> GPR_SIZE_IN_BITS
* i
) & GPR_VALUE_MASK
327 state
.store(rhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
329 expected
= prod_value
330 if lhs_neg
!= rhs_neg
:
331 expected
= 2 ** 384 - expected
334 v
= state
.load(dest_ptr
+ GPR_SIZE_IN_BYTES
* i
)
335 prod
+= v
<< (GPR_SIZE_IN_BITS
* i
)
336 self
.assertEqual(hex(prod
), hex(expected
))
338 def test_simple_mul_192x192_ops(self
):
339 code
= Mul(mul
=simple_umul
, lhs_size_in_words
=3, rhs_size_in_words
=3)
344 " (<...outputs[0]: <I64>>) <= FuncArgR3\n"
346 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
348 " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
349 " <ptr_in.outputs[0]: <I64>>,\n"
350 " <lhs_setvl.outputs[0]: <VL_MAXVL>>, 0x30)\n"
352 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
354 " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
355 " <ptr_in.outputs[0]: <I64>>,\n"
356 " <rhs_setvl.outputs[0]: <VL_MAXVL>>, 0x48)\n"
358 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
360 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
361 " <...outputs[2]: <I64>>) <= Spread(\n"
362 " <load_rhs.outputs[0]: <I64*3>>,\n"
363 " <mul_rhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
365 " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
367 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
369 " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
371 " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
372 " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
373 " <mul_rhs_spread.outputs[0]: <I64>>,\n"
374 " <mul_zero.outputs[0]: <I64>>,\n"
375 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
376 "mul_0_mul_rt_spread:\n"
377 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
378 " <...outputs[2]: <I64>>) <= Spread(\n"
379 " <mul_0_mul.outputs[0]: <I64*3>>,\n"
380 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
382 " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
383 " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
384 " <mul_rhs_spread.outputs[1]: <I64>>,\n"
385 " <mul_zero.outputs[0]: <I64>>,\n"
386 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
387 "mul_1_mul_rt_spread:\n"
388 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
389 " <...outputs[2]: <I64>>) <= Spread(\n"
390 " <mul_1_mul.outputs[0]: <I64*3>>,\n"
391 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
392 "mul_1_cast_retval_zero:\n"
393 " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
394 "mul_1_cast_pp_zero:\n"
395 " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
397 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x5)\n"
398 "mul_1_retval_concat:\n"
399 " (<...outputs[0]: <I64*5>>) <= Concat(\n"
400 " <mul_0_mul_rt_spread.outputs[1]: <I64>>,\n"
401 " <mul_0_mul_rt_spread.outputs[2]: <I64>>,\n"
402 " <mul_0_mul.outputs[1]: <I64>>,\n"
403 " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
404 " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
405 " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
407 " (<...outputs[0]: <I64*5>>) <= Concat(\n"
408 " <mul_1_mul_rt_spread.outputs[0]: <I64>>,\n"
409 " <mul_1_mul_rt_spread.outputs[1]: <I64>>,\n"
410 " <mul_1_mul_rt_spread.outputs[2]: <I64>>,\n"
411 " <mul_1_mul.outputs[1]: <I64>>,\n"
412 " <mul_1_cast_pp_zero.outputs[0]: <I64>>,\n"
413 " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
415 " (<...outputs[0]: <CA>>) <= ClearCA\n"
417 " (<...outputs[0]: <I64*5>>, <...outputs[1]: <CA>>\n"
418 " ) <= SvAddE(<mul_1_retval_concat.outputs[0]: <I64*5>>,\n"
419 " <mul_1_pp_concat.outputs[0]: <I64*5>>,\n"
420 " <mul_1_clear_ca.outputs[0]: <CA>>,\n"
421 " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
422 "mul_1_sum_spread:\n"
423 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
424 " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>,\n"
425 " <...outputs[4]: <I64>>) <= Spread(\n"
426 " <mul_1_add.outputs[0]: <I64*5>>,\n"
427 " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
429 " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
430 " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
431 " <mul_rhs_spread.outputs[2]: <I64>>,\n"
432 " <mul_zero.outputs[0]: <I64>>,\n"
433 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
434 "mul_2_mul_rt_spread:\n"
435 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
436 " <...outputs[2]: <I64>>) <= Spread(\n"
437 " <mul_2_mul.outputs[0]: <I64*3>>,\n"
438 " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
440 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x4)\n"
441 "mul_2_retval_concat:\n"
442 " (<...outputs[0]: <I64*4>>) <= Concat(\n"
443 " <mul_1_sum_spread.outputs[1]: <I64>>,\n"
444 " <mul_1_sum_spread.outputs[2]: <I64>>,\n"
445 " <mul_1_sum_spread.outputs[3]: <I64>>,\n"
446 " <mul_1_sum_spread.outputs[4]: <I64>>,\n"
447 " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
449 " (<...outputs[0]: <I64*4>>) <= Concat(\n"
450 " <mul_2_mul_rt_spread.outputs[0]: <I64>>,\n"
451 " <mul_2_mul_rt_spread.outputs[1]: <I64>>,\n"
452 " <mul_2_mul_rt_spread.outputs[2]: <I64>>,\n"
453 " <mul_2_mul.outputs[1]: <I64>>,\n"
454 " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
456 " (<...outputs[0]: <CA>>) <= ClearCA\n"
458 " (<...outputs[0]: <I64*4>>, <...outputs[1]: <CA>>\n"
459 " ) <= SvAddE(<mul_2_retval_concat.outputs[0]: <I64*4>>,\n"
460 " <mul_2_pp_concat.outputs[0]: <I64*4>>,\n"
461 " <mul_2_clear_ca.outputs[0]: <CA>>,\n"
462 " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
463 "mul_2_sum_spread:\n"
464 " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
465 " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>) <= Spread(\n"
466 " <mul_2_add.outputs[0]: <I64*4>>,\n"
467 " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
469 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
471 " (<...outputs[0]: <I64*6>>) <= Concat(\n"
472 " <mul_0_mul_rt_spread.outputs[0]: <I64>>,\n"
473 " <mul_1_sum_spread.outputs[0]: <I64>>,\n"
474 " <mul_2_sum_spread.outputs[0]: <I64>>,\n"
475 " <mul_2_sum_spread.outputs[1]: <I64>>,\n"
476 " <mul_2_sum_spread.outputs[2]: <I64>>,\n"
477 " <mul_2_sum_spread.outputs[3]: <I64>>,\n"
478 " <mul_setvl.outputs[0]: <VL_MAXVL>>)\n"
480 " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
482 " SvStd(<mul_concat.outputs[0]: <I64*6>>,\n"
483 " <ptr_in.outputs[0]: <I64>>,\n"
484 " <dest_setvl.outputs[0]: <VL_MAXVL>>, 0x0)"
487 def test_simple_mul_192x192_reg_alloc(self
):
488 code
= Mul(mul
=simple_umul
, lhs_size_in_words
=3, rhs_size_in_words
=3)
490 assigned_registers
= allocate_registers(
491 fn
, debug_out
=sys
.stdout
, dump_graph
=GraphDumper(self
))
492 print(repr(assigned_registers
))
494 repr(assigned_registers
), "{"
495 "<mul_1_cast_pp_zero.outputs[0]: <I64>>: "
496 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
497 "<mul_1_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
498 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
499 "<mul_1_pp_concat.out0.copy.outputs[0]: <I64*5>>: "
500 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
501 "<mul_1_cast_pp_zero.out0.copy.outputs[0]: <I64>>: "
502 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
503 "<mul_1_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
504 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
505 "<mul_1_pp_concat.inp0.copy.outputs[0]: <I64>>: "
506 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
507 "<mul_1_pp_concat.inp1.copy.outputs[0]: <I64>>: "
508 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
509 "<mul_1_pp_concat.inp2.copy.outputs[0]: <I64>>: "
510 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
511 "<mul_1_pp_concat.inp3.copy.outputs[0]: <I64>>: "
512 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
513 "<mul_1_pp_concat.inp4.copy.outputs[0]: <I64>>: "
514 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
515 "<mul_1_add.inp1.copy.outputs[0]: <I64*5>>: "
516 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
517 "<mul_1_mul.outputs[0]: <I64*3>>: "
518 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
519 "<mul_1_mul.out0.copy.outputs[0]: <I64*3>>: "
520 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
521 "<mul_1_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
522 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
523 "<mul_1_mul_rt_spread.outputs[0]: <I64>>: "
524 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
525 "<mul_1_mul_rt_spread.outputs[1]: <I64>>: "
526 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
527 "<mul_1_mul_rt_spread.outputs[2]: <I64>>: "
528 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
529 "<mul_1_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
530 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
531 "<mul_1_pp_concat.outputs[0]: <I64*5>>: "
532 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
533 "<mul_2_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
534 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
535 "<mul_2_pp_concat.outputs[0]: <I64*4>>: "
536 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
537 "<mul_2_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
538 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
539 "<mul_2_pp_concat.out0.copy.outputs[0]: <I64*4>>: "
540 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
541 "<mul_2_add.inp1.copy.outputs[0]: <I64*4>>: "
542 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
543 "<mul_2_mul.outputs[0]: <I64*3>>: "
544 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
545 "<mul_2_mul.out0.copy.outputs[0]: <I64*3>>: "
546 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
547 "<mul_2_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
548 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
549 "<mul_2_mul_rt_spread.outputs[0]: <I64>>: "
550 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
551 "<mul_2_mul_rt_spread.outputs[1]: <I64>>: "
552 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
553 "<mul_2_mul_rt_spread.outputs[2]: <I64>>: "
554 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
555 "<mul_2_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
556 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
557 "<mul_2_pp_concat.inp0.copy.outputs[0]: <I64>>: "
558 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
559 "<mul_2_pp_concat.inp1.copy.outputs[0]: <I64>>: "
560 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
561 "<mul_2_pp_concat.inp2.copy.outputs[0]: <I64>>: "
562 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
563 "<mul_2_pp_concat.inp3.copy.outputs[0]: <I64>>: "
564 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
565 "<mul_zero.outputs[0]: <I64>>: "
566 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
567 "<mul_0_mul.inp2.copy.outputs[0]: <I64>>: "
568 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
569 "<mul_0_mul.outputs[1]: <I64>>: "
570 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
571 "<mul_1_mul.inp2.copy.outputs[0]: <I64>>: "
572 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
573 "<mul_1_mul.outputs[1]: <I64>>: "
574 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
575 "<mul_1_mul.out1.copy.outputs[0]: <I64>>: "
576 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
577 "<mul_2_mul.inp2.copy.outputs[0]: <I64>>: "
578 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
579 "<mul_2_mul.outputs[1]: <I64>>: "
580 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
581 "<mul_2_mul.out1.copy.outputs[0]: <I64>>: "
582 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
583 "<mul_0_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
584 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
585 "<mul_0_mul_rt_spread.outputs[0]: <I64>>: "
586 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
587 "<mul_0_mul_rt_spread.outputs[1]: <I64>>: "
588 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
589 "<mul_0_mul_rt_spread.outputs[2]: <I64>>: "
590 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
591 "<mul_0_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
592 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
593 "<mul_0_mul.outputs[0]: <I64*3>>: "
594 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
595 "<mul_0_mul.out0.copy.outputs[0]: <I64*3>>: "
596 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
597 "<mul_1_retval_concat.outputs[0]: <I64*5>>: "
598 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
599 "<mul_1_add.inp0.copy.outputs[0]: <I64*5>>: "
600 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
601 "<mul_0_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
602 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
603 "<mul_0_mul.out1.copy.outputs[0]: <I64>>: "
604 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
605 "<mul_1_retval_concat.out0.copy.outputs[0]: <I64*5>>: "
606 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
607 "<mul_1_cast_retval_zero.out0.copy.outputs[0]: <I64>>: "
608 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
609 "<mul_1_retval_concat.inp0.copy.outputs[0]: <I64>>: "
610 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
611 "<mul_1_retval_concat.inp1.copy.outputs[0]: <I64>>: "
612 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
613 "<mul_1_retval_concat.inp2.copy.outputs[0]: <I64>>: "
614 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
615 "<mul_1_retval_concat.inp3.copy.outputs[0]: <I64>>: "
616 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
617 "<mul_1_retval_concat.inp4.copy.outputs[0]: <I64>>: "
618 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
619 "<mul_2_sum_spread.out3.copy.outputs[0]: <I64>>: "
620 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
621 "<mul_1_sum_spread.out0.copy.outputs[0]: <I64>>: "
622 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
623 "<mul_0_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
624 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
625 "<mul_concat.out0.copy.outputs[0]: <I64*6>>: "
626 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
627 "<mul_2_add.inp0.copy.outputs[0]: <I64*4>>: "
628 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
629 "<mul_1_add.outputs[0]: <I64*5>>: "
630 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
631 "<mul_1_add.out0.copy.outputs[0]: <I64*5>>: "
632 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
633 "<mul_2_sum_spread.out0.copy.outputs[0]: <I64>>: "
634 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
635 "<mul_concat.inp0.copy.outputs[0]: <I64>>: "
636 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
637 "<mul_concat.inp1.copy.outputs[0]: <I64>>: "
638 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
639 "<mul_concat.inp2.copy.outputs[0]: <I64>>: "
640 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
641 "<mul_concat.inp3.copy.outputs[0]: <I64>>: "
642 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
643 "<mul_concat.inp4.copy.outputs[0]: <I64>>: "
644 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
645 "<mul_concat.inp5.copy.outputs[0]: <I64>>: "
646 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
647 "<mul_1_sum_spread.inp0.copy.outputs[0]: <I64*5>>: "
648 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
649 "<mul_1_sum_spread.outputs[0]: <I64>>: "
650 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
651 "<mul_1_sum_spread.outputs[1]: <I64>>: "
652 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
653 "<mul_1_sum_spread.outputs[2]: <I64>>: "
654 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
655 "<mul_1_sum_spread.outputs[3]: <I64>>: "
656 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
657 "<mul_1_sum_spread.outputs[4]: <I64>>: "
658 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
659 "<mul_1_sum_spread.out2.copy.outputs[0]: <I64>>: "
660 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
661 "<mul_2_retval_concat.outputs[0]: <I64*4>>: "
662 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
663 "<mul_1_sum_spread.out3.copy.outputs[0]: <I64>>: "
664 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
665 "<mul_2_retval_concat.out0.copy.outputs[0]: <I64*4>>: "
666 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
667 "<mul_1_sum_spread.out4.copy.outputs[0]: <I64>>: "
668 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
669 "<mul_1_sum_spread.out1.copy.outputs[0]: <I64>>: "
670 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
671 "<mul_2_retval_concat.inp0.copy.outputs[0]: <I64>>: "
672 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
673 "<mul_2_retval_concat.inp1.copy.outputs[0]: <I64>>: "
674 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
675 "<mul_2_retval_concat.inp2.copy.outputs[0]: <I64>>: "
676 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
677 "<mul_2_retval_concat.inp3.copy.outputs[0]: <I64>>: "
678 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
679 "<mul_2_add.outputs[0]: <I64*4>>: "
680 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
681 "<mul_2_add.out0.copy.outputs[0]: <I64*4>>: "
682 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
683 "<mul_2_sum_spread.inp0.copy.outputs[0]: <I64*4>>: "
684 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
685 "<mul_2_sum_spread.outputs[0]: <I64>>: "
686 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
687 "<mul_2_sum_spread.outputs[1]: <I64>>: "
688 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
689 "<mul_2_sum_spread.outputs[2]: <I64>>: "
690 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
691 "<mul_2_sum_spread.outputs[3]: <I64>>: "
692 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
693 "<mul_2_sum_spread.out1.copy.outputs[0]: <I64>>: "
694 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
695 "<mul_concat.outputs[0]: <I64*6>>: "
696 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
697 "<mul_2_sum_spread.out2.copy.outputs[0]: <I64>>: "
698 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
699 "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
700 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
701 "<mul_1_cast_retval_zero.outputs[0]: <I64>>: "
702 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
703 "<mul_zero.out0.copy.outputs[0]: <I64>>: "
704 "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
705 "<ptr_in.out0.copy.outputs[0]: <I64>>: "
706 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
707 "<store_dest.inp1.copy.outputs[0]: <I64>>: "
708 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
709 "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
710 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
711 "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
712 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
713 "<ptr_in.outputs[0]: <I64>>: "
714 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
715 "<mul_2_mul.inp1.copy.outputs[0]: <I64>>: "
716 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
717 "<load_rhs.outputs[0]: <I64*3>>: "
718 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
719 "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
720 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
721 "<mul_rhs_spread.out2.copy.outputs[0]: <I64>>: "
722 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
723 "<mul_rhs_spread.out1.copy.outputs[0]: <I64>>: "
724 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
725 "<mul_1_mul.inp1.copy.outputs[0]: <I64>>: "
726 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
727 "<mul_rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
728 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
729 "<mul_rhs_spread.outputs[0]: <I64>>: "
730 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
731 "<mul_rhs_spread.outputs[1]: <I64>>: "
732 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
733 "<mul_rhs_spread.outputs[2]: <I64>>: "
734 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
735 "<mul_rhs_spread.out0.copy.outputs[0]: <I64>>: "
736 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
737 "<mul_0_mul.inp1.copy.outputs[0]: <I64>>: "
738 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
739 "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
740 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
741 "<load_lhs.outputs[0]: <I64*3>>: "
742 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
743 "<mul_0_mul.inp0.copy.outputs[0]: <I64*3>>: "
744 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
745 "<mul_1_mul.inp0.copy.outputs[0]: <I64*3>>: "
746 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
747 "<mul_2_mul.inp0.copy.outputs[0]: <I64*3>>: "
748 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
749 "<mul_zero2.outputs[0]: <I64>>: "
750 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
751 "<mul_zero2.out0.copy.outputs[0]: <I64>>: "
752 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
753 "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
754 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
755 "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
756 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
757 "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
758 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
759 "<mul_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
760 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
761 "<mul_concat.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
762 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
763 "<mul_setvl.outputs[0]: <VL_MAXVL>>: "
764 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
765 "<mul_2_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
766 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
767 "<mul_2_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
768 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
769 "<mul_2_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
770 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
771 "<mul_2_clear_ca.outputs[0]: <CA>>: "
772 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
773 "<mul_2_add.outputs[1]: <CA>>: "
774 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
775 "<mul_2_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
776 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
777 "<mul_2_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
778 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
779 "<mul_2_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
780 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
781 "<mul_2_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
782 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
783 "<mul_2_pp_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
784 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
785 "<mul_2_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
786 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
787 "<mul_2_retval_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
788 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
789 "<mul_2_setvl.outputs[0]: <VL_MAXVL>>: "
790 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
791 "<mul_2_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
792 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
793 "<mul_2_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
794 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
795 "<mul_2_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
796 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
797 "<mul_2_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
798 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
799 "<mul_2_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
800 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
801 "<mul_1_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
802 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
803 "<mul_1_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
804 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
805 "<mul_1_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
806 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
807 "<mul_1_clear_ca.outputs[0]: <CA>>: "
808 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
809 "<mul_1_add.outputs[1]: <CA>>: "
810 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
811 "<mul_1_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
812 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
813 "<mul_1_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
814 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
815 "<mul_1_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
816 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
817 "<mul_1_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
818 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
819 "<mul_1_pp_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
820 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
821 "<mul_1_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
822 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
823 "<mul_1_retval_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
824 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
825 "<mul_1_setvl.outputs[0]: <VL_MAXVL>>: "
826 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
827 "<mul_1_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
828 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
829 "<mul_1_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
830 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
831 "<mul_1_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
832 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
833 "<mul_1_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
834 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
835 "<mul_1_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
836 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
837 "<mul_0_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
838 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
839 "<mul_0_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
840 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
841 "<mul_0_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
842 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
843 "<mul_0_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
844 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
845 "<mul_0_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
846 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
847 "<mul_lhs_setvl.outputs[0]: <VL_MAXVL>>: "
848 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
849 "<mul_rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
850 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
851 "<mul_rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
852 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
853 "<mul_rhs_setvl.outputs[0]: <VL_MAXVL>>: "
854 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
855 "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
856 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
857 "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
858 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
859 "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
860 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
861 "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
862 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
863 "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
864 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
865 "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
866 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
# Test that the assembly generated for a 3-word by 3-word `Mul` built on
# `simple_umul` matches the expected instruction list: registers are assigned
# by `allocate_registers`, `fn.gen_asm` fills a `GenAsmState`, and
# `gen_asm_state.output` is printed and then compared with `assertEqual`.
# NOTE(review): the assignment of `fn` and parts of the expected list are not
# present in the lines shown here -- presumably `fn` comes from `code`;
# confirm against the complete file.
869 def test_simple_mul_192x192_asm(self
):
870 code
= Mul(mul
=simple_umul
, lhs_size_in_words
=3, rhs_size_in_words
=3)
872 assigned_registers
= allocate_registers(
873 fn
, debug_out
=sys
.stdout
, dump_graph
=GraphDumper(self
))
874 gen_asm_state
= GenAsmState(assigned_registers
)
875 fn
.gen_asm(gen_asm_state
)
876 print(gen_asm_state
.output
)
877 self
.assertEqual(gen_asm_state
.output
, [
879 'setvl 0, 0, 3, 0, 1, 1',
880 'setvl 0, 0, 3, 0, 1, 1',
882 'setvl 0, 0, 3, 0, 1, 1',
883 'setvl 0, 0, 3, 0, 1, 1',
884 'setvl 0, 0, 3, 0, 1, 1',
886 'setvl 0, 0, 3, 0, 1, 1',
887 'setvl 0, 0, 3, 0, 1, 1',
888 'setvl 0, 0, 3, 0, 1, 1',
889 'setvl 0, 0, 3, 0, 1, 1',
892 'setvl 0, 0, 3, 0, 1, 1',
894 'setvl 0, 0, 3, 0, 1, 1',
896 'setvl 0, 0, 3, 0, 1, 1',
897 'sv.maddedu *14, *24, 20, 6',
898 'setvl 0, 0, 3, 0, 1, 1',
900 'setvl 0, 0, 3, 0, 1, 1',
901 'setvl 0, 0, 3, 0, 1, 1',
902 'setvl 0, 0, 3, 0, 1, 1',
904 'setvl 0, 0, 3, 0, 1, 1',
905 'sv.maddedu *3, *24, 21, 6',
906 'setvl 0, 0, 3, 0, 1, 1',
907 'setvl 0, 0, 3, 0, 1, 1',
908 'setvl 0, 0, 3, 0, 1, 1',
912 'setvl 0, 0, 5, 0, 1, 1',
914 'setvl 0, 0, 5, 0, 1, 1',
915 'setvl 0, 0, 5, 0, 1, 1',
916 'setvl 0, 0, 5, 0, 1, 1',
917 'setvl 0, 0, 5, 0, 1, 1',
919 'setvl 0, 0, 5, 0, 1, 1',
920 'setvl 0, 0, 5, 0, 1, 1',
921 'setvl 0, 0, 5, 0, 1, 1',
922 'sv.adde *15, *15, *3',
923 'setvl 0, 0, 5, 0, 1, 1',
924 'setvl 0, 0, 5, 0, 1, 1',
925 'setvl 0, 0, 5, 0, 1, 1',
926 'setvl 0, 0, 3, 0, 1, 1',
928 'setvl 0, 0, 3, 0, 1, 1',
929 'sv.maddedu *3, *24, 22, 6',
930 'setvl 0, 0, 3, 0, 1, 1',
931 'setvl 0, 0, 3, 0, 1, 1',
932 'setvl 0, 0, 3, 0, 1, 1',
933 'setvl 0, 0, 4, 0, 1, 1',
934 'setvl 0, 0, 4, 0, 1, 1',
935 'setvl 0, 0, 4, 0, 1, 1',
936 'setvl 0, 0, 4, 0, 1, 1',
937 'setvl 0, 0, 4, 0, 1, 1',
939 'setvl 0, 0, 4, 0, 1, 1',
940 'setvl 0, 0, 4, 0, 1, 1',
941 'setvl 0, 0, 4, 0, 1, 1',
942 'sv.adde *16, *16, *3',
943 'setvl 0, 0, 4, 0, 1, 1',
944 'setvl 0, 0, 4, 0, 1, 1',
945 'setvl 0, 0, 4, 0, 1, 1',
946 'setvl 0, 0, 6, 0, 1, 1',
947 'setvl 0, 0, 6, 0, 1, 1',
948 'setvl 0, 0, 6, 0, 1, 1',
949 'setvl 0, 0, 6, 0, 1, 1',
950 'setvl 0, 0, 6, 0, 1, 1',
951 'setvl 0, 0, 6, 0, 1, 1',
# Build the `Mul` harness for a 4-word by 4-word multiply whose multiplier
# callback (`mul`, defined below) constructs a `ToomCookMul`, forwarding this
# method's `lhs_signed` / `rhs_signed` flags and an `instances` value to it.
# NOTE(review): the definition of `instances` and the return statement of the
# nested `mul` are not present in the lines shown here -- presumably
# `instances` is built from `TOOM_2` and `mul` returns the pair described by
# its type comment; confirm against the complete file.
955 def toom_2_mul_256x256(self
, lhs_signed
, rhs_signed
):
956 # type: (bool, bool) -> Mul
957 TOOM_2
= ToomCookInstance
.make_toom_2()
960 def mul(fn
, lhs
, rhs
):
961 # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
962 v
= ToomCookMul(fn
=fn
, lhs
=lhs
, lhs_signed
=lhs_signed
, rhs
=rhs
,
963 rhs_signed
=rhs_signed
, instances
=instances
)
965 return Mul(mul
=mul
, lhs_size_in_words
=4, rhs_size_in_words
=4)
# Generator yielding `(lhs_value, rhs_value, prod_value)` triples for the
# 256x256 multiply tests.
# Operand candidates are the single-bit values `1 << i` for every `i` in
# `shifts`, zero, and one fixed constant per side; the product of the two
# fixed constants is asserted to equal the little-endian integer encoding of
# the ASCII message in `prod_value_in`.
# For every lhs/rhs pairing, each operand is reduced modulo 2**256, treated as
# negative when its `*_signed` flag is set and bit 255 is set, multiplied, and
# then the operands are yielded modulo 2**256 with the product modulo 2**512.
# NOTE(review): the statements guarding the two `.extend(...)` negation steps
# and the `def` line of the `key` sort function are not present in the lines
# shown here; confirm against the complete file.
967 def make_256x256_mul_test_cases(self
, lhs_signed
, rhs_signed
):
968 # type: (bool, bool) -> Iterator[tuple[int, int, int]]
969 # test multiplying `+-1 << n` and:
970 # 0xc162321a5eaad80b_4b86bb0efdfb93c0_a789ff04cc11b157_eaa08e29fb197621
972 # 0x3138710167583371_998af336a8fac64d_e6da3737090787fe_85ba09ea701f4af2
975 # "252e6e6f69746163_696c7069746c754d_"
976 # "2061627573746172_614b202d20322d4d_"
977 # "4f4f5420676e6973_75206c756d20746e_"
978 # "6967696220746962_2d36353278363532", base=0)
979 # == int.from_bytes(b'256x256-bit bigint mul using TOOM-2 '
980 # b'- Karatsuba Multiplication.%', 'little')
981 lhs_value_in
= (0xc162321a5eaad80b_4b86bb0efdfb93c0 << 128) \
982 |
0xa789ff04cc11b157_eaa08e29fb197621
983 rhs_value_in
= (0x3138710167583371_998af336a8fac64d << 128) \
984 |
0xe6da3737090787fe_85ba09ea701f4af2
985 prod_value_in
= int.from_bytes(
986 b
'256x256-bit bigint mul using TOOM-2 '
987 b
'- Karatsuba Multiplication.%', 'little')
988 self
.assertEqual(lhs_value_in
* rhs_value_in
, prod_value_in
)
989 shifts
= [*range(0, 256, 16), *range(15, 256, 16)]
990 lhs_values
= [1 << i
for i
in shifts
] + [0, lhs_value_in
]
991 rhs_values
= [1 << i
for i
in shifts
] + [0, rhs_value_in
]
993 lhs_values
.extend([-i
for i
in lhs_values
])
995 rhs_values
.extend([-i
for i
in rhs_values
])
998 # type: (int) -> tuple[bool, int]
999 return abs(v
) in (lhs_value_in
, rhs_value_in
), v
% (1 << 256)
1001 lhs_values
.sort(key
=key
)
1002 rhs_values
.sort(key
=key
)
1003 for lhs_value
in lhs_values
:
1004 for rhs_value
in rhs_values
:
1005 lhs_value
%= 1 << 256
1006 rhs_value
%= 1 << 256
1007 if lhs_value
>> 255 != 0 and lhs_signed
:
1008 lhs_value
-= 1 << 256
1009 if rhs_value
>> 255 != 0 and rhs_signed
:
1010 rhs_value
-= 1 << 256
1011 prod_value
= lhs_value
* rhs_value
1012 lhs_value
%= 1 << 256
1013 rhs_value
%= 1 << 256
1014 prod_value
%= 1 << 512
1015 yield lhs_value
, rhs_value
, prod_value
# Shared driver for the 256x256 TOOM-2 simulation tests.
# Builds the multiply with `toom_2_mul_256x256`, prints `code.retval[1]` and
# the op listing, and obtains a state factory from `get_state_factory(code)`.
# For every triple from `make_256x256_mul_test_cases` it opens a `subTest`,
# enters a fresh state, sets `state[code.ptr_in]`, stores operand words at
# `lhs_ptr` / `rhs_ptr` (word `i` at byte offset `i * GPR_SIZE_IN_BYTES`),
# reassembles `prod` from the words loaded at `dest_ptr`, and asserts that
# `hex(prod)` equals `hex(prod_value)`.
# NOTE(review): the definition of `ptr_in`, the loop headers that bind `i`,
# the initialisation of `prod` and the step that actually executes the code
# are not present in the lines shown here; confirm against the complete file.
1017 def tst_toom_2_mul_256x256_sim(
1018 self
, lhs_signed
, # type: bool
1019 rhs_signed
, # type: bool
1020 get_state_factory
, # type: Callable[[Mul], _StateFactory]
1022 code
= self
.toom_2_mul_256x256(
1023 lhs_signed
=lhs_signed
, rhs_signed
=rhs_signed
)
1024 print(code
.retval
[1])
1025 print(code
.fn
.ops_to_str())
1026 state_factory
= get_state_factory(code
)
1028 dest_ptr
= ptr_in
+ code
.dest_offset
1029 lhs_ptr
= ptr_in
+ code
.lhs_offset
1030 rhs_ptr
= ptr_in
+ code
.rhs_offset
1031 values
= self
.make_256x256_mul_test_cases(
1032 lhs_signed
=lhs_signed
, rhs_signed
=rhs_signed
)
1033 for lhs_value
, rhs_value
, prod_value
in values
:
1034 with self
.subTest(lhs_signed
=lhs_signed
, rhs_signed
=rhs_signed
,
1035 lhs_value
=hex(lhs_value
),
1036 rhs_value
=hex(rhs_value
),
1037 prod_value
=hex(prod_value
)):
1038 with
state_factory() as state
:
1039 state
[code
.ptr_in
] = ptr_in
,
1041 v
= lhs_value
>> GPR_SIZE_IN_BITS
* i
1043 state
.store(lhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
1045 v
= rhs_value
>> GPR_SIZE_IN_BITS
* i
1047 state
.store(rhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
1051 v
= state
.load(dest_ptr
+ GPR_SIZE_IN_BYTES
* i
)
1052 prod
+= v
<< (GPR_SIZE_IN_BITS
* i
)
1053 self
.assertEqual(hex(prod
), hex(prod_value
),
1054 f
"failed: state={state}")
def test_toom_2_mul_256x256_pre_ra_sim(self):
    # Run the shared 256x256 simulation driver once per signedness
    # combination, in the order (lhs, rhs) = (False, False), (False, True),
    # (True, False), (True, True), handing it `get_pre_ra_state_factory`
    # as the state-factory getter each time.
    signedness_pairs = [
        (lhs_flag, rhs_flag)
        for lhs_flag in (False, True)
        for rhs_flag in (False, True)
    ]
    for lhs_flag, rhs_flag in signedness_pairs:
        self.tst_toom_2_mul_256x256_sim(
            lhs_signed=lhs_flag,
            rhs_signed=rhs_flag,
            get_state_factory=get_pre_ra_state_factory,
        )
def test_toom_2_mul_256x256_uu_post_ra_sim(self):
    # Both operands unsigned; the driver receives this test case's own
    # `get_post_ra_state_factory` as the state-factory getter.
    run_sim = self.tst_toom_2_mul_256x256_sim
    run_sim(
        lhs_signed=False,
        rhs_signed=False,
        get_state_factory=self.get_post_ra_state_factory,
    )
def test_toom_2_mul_256x256_su_post_ra_sim(self):
    # Signed lhs, unsigned rhs; the driver receives this test case's own
    # `get_post_ra_state_factory` as the state-factory getter.
    run_sim = self.tst_toom_2_mul_256x256_sim
    run_sim(
        lhs_signed=True,
        rhs_signed=False,
        get_state_factory=self.get_post_ra_state_factory,
    )
def test_toom_2_mul_256x256_us_post_ra_sim(self):
    # Unsigned lhs, signed rhs; the driver receives this test case's own
    # `get_post_ra_state_factory` as the state-factory getter.
    run_sim = self.tst_toom_2_mul_256x256_sim
    run_sim(
        lhs_signed=False,
        rhs_signed=True,
        get_state_factory=self.get_post_ra_state_factory,
    )
def test_toom_2_mul_256x256_ss_post_ra_sim(self):
    # Both operands signed; the driver receives this test case's own
    # `get_post_ra_state_factory` as the state-factory getter.
    run_sim = self.tst_toom_2_mul_256x256_sim
    run_sim(
        lhs_signed=True,
        rhs_signed=True,
        get_state_factory=self.get_post_ra_state_factory,
    )
# Test that the assembly generated for the unsigned-by-unsigned multiply built
# by `toom_2_mul_256x256` matches the expected instruction list: registers are
# assigned by `allocate_registers`, `fn.gen_asm` fills a `GenAsmState`, and
# `gen_asm_state.output` is printed and then compared with `assertEqual`.
# NOTE(review): the assignment of `fn` and parts of the expected list are not
# present in the lines shown here -- presumably `fn` comes from `code`;
# confirm against the complete file.
1083 def test_toom_2_mul_256x256_asm(self
):
1084 code
= self
.toom_2_mul_256x256(lhs_signed
=False, rhs_signed
=False)
1086 assigned_registers
= allocate_registers(
1087 fn
, debug_out
=sys
.stdout
, dump_graph
=GraphDumper(self
))
1088 gen_asm_state
= GenAsmState(assigned_registers
)
1089 fn
.gen_asm(gen_asm_state
)
1090 print(gen_asm_state
.output
)
1091 self
.assertEqual(gen_asm_state
.output
, [
1093 'setvl 0, 0, 4, 0, 1, 1',
1095 'setvl 0, 0, 4, 0, 1, 1',
1097 'setvl 0, 0, 4, 0, 1, 1',
1098 'setvl 0, 0, 4, 0, 1, 1',
1100 'setvl 0, 0, 4, 0, 1, 1',
1102 'setvl 0, 0, 4, 0, 1, 1',
1103 'sv.or *3, *14, *14',
1104 'setvl 0, 0, 4, 0, 1, 1',
1105 'setvl 0, 0, 4, 0, 1, 1',
1106 'sv.or *14, *39, *39',
1107 'setvl 0, 0, 4, 0, 1, 1',
1109 'setvl 0, 0, 3, 0, 1, 1',
1111 'setvl 0, 0, 3, 0, 1, 1',
1112 'setvl 0, 0, 3, 0, 1, 1',
1113 'sv.or *30, *14, *14',
1114 'setvl 0, 0, 3, 0, 1, 1',
1115 'setvl 0, 0, 3, 0, 1, 1',
1117 'setvl 0, 0, 3, 0, 1, 1',
1118 'setvl 0, 0, 3, 0, 1, 1',
1119 'sv.or *14, *30, *30',
1120 'setvl 0, 0, 3, 0, 1, 1',
1123 'setvl 0, 0, 4, 0, 1, 1',
1126 'setvl 0, 0, 4, 0, 1, 1',
1127 'setvl 0, 0, 4, 0, 1, 1',
1128 'sv.or *30, *14, *14',
1129 'setvl 0, 0, 1, 0, 1, 1',
1131 'setvl 0, 0, 1, 0, 1, 1',
1132 'sv.or *7, *17, *17',
1134 'setvl 0, 0, 4, 0, 1, 1',
1137 'setvl 0, 0, 4, 0, 1, 1',
1138 'setvl 0, 0, 4, 0, 1, 1',
1139 'sv.or *17, *7, *7',
1140 'setvl 0, 0, 4, 0, 1, 1',
1142 'setvl 0, 0, 4, 0, 1, 1',
1143 'sv.or *39, *30, *30',
1144 'setvl 0, 0, 4, 0, 1, 1',
1145 'setvl 0, 0, 4, 0, 1, 1',
1146 'sv.adde *34, *39, *17',
1147 'setvl 0, 0, 4, 0, 1, 1',
1148 'sv.or *44, *34, *34',
1149 'setvl 0, 0, 4, 0, 1, 1',
1150 'setvl 0, 0, 4, 0, 1, 1',
1151 'sv.or *14, *3, *3',
1152 'setvl 0, 0, 4, 0, 1, 1',
1155 'setvl 0, 0, 3, 0, 1, 1',
1157 'setvl 0, 0, 3, 0, 1, 1',
1158 'setvl 0, 0, 3, 0, 1, 1',
1159 'setvl 0, 0, 3, 0, 1, 1',
1160 'setvl 0, 0, 3, 0, 1, 1',
1161 'sv.or *17, *10, *10',
1163 'setvl 0, 0, 3, 0, 1, 1',
1164 'setvl 0, 0, 3, 0, 1, 1',
1165 'sv.or *7, *14, *14',
1166 'setvl 0, 0, 3, 0, 1, 1',
1171 'setvl 0, 0, 4, 0, 1, 1',
1173 'setvl 0, 0, 4, 0, 1, 1',
1174 'setvl 0, 0, 4, 0, 1, 1',
1175 'setvl 0, 0, 1, 0, 1, 1',
1177 'setvl 0, 0, 1, 0, 1, 1',
1179 'setvl 0, 0, 4, 0, 1, 1',
1183 'setvl 0, 0, 4, 0, 1, 1',
1184 'setvl 0, 0, 4, 0, 1, 1',
1185 'setvl 0, 0, 4, 0, 1, 1',
1187 'setvl 0, 0, 4, 0, 1, 1',
1188 'setvl 0, 0, 4, 0, 1, 1',
1189 'setvl 0, 0, 4, 0, 1, 1',
1190 'sv.adde *7, *3, *17',
1191 'setvl 0, 0, 4, 0, 1, 1',
1192 'sv.or *38, *7, *7',
1193 'setvl 0, 0, 4, 0, 1, 1',
1194 'setvl 0, 0, 4, 0, 1, 1',
1195 'setvl 0, 0, 4, 0, 1, 1',
1198 'setvl 0, 0, 4, 0, 1, 1',
1200 'setvl 0, 0, 4, 0, 1, 1',
1201 'sv.or *14, *30, *30',
1204 'setvl 0, 0, 4, 0, 1, 1',
1205 'sv.maddedu *22, *14, 7, 21',
1206 'setvl 0, 0, 4, 0, 1, 1',
1208 'setvl 0, 0, 4, 0, 1, 1',
1209 'setvl 0, 0, 4, 0, 1, 1',
1210 'setvl 0, 0, 4, 0, 1, 1',
1211 'sv.or *14, *30, *30',
1214 'setvl 0, 0, 4, 0, 1, 1',
1215 'sv.maddedu *7, *14, 34, 21',
1216 'setvl 0, 0, 4, 0, 1, 1',
1218 'setvl 0, 0, 4, 0, 1, 1',
1219 'setvl 0, 0, 4, 0, 1, 1',
1224 'setvl 0, 0, 6, 0, 1, 1',
1226 'setvl 0, 0, 6, 0, 1, 1',
1227 'setvl 0, 0, 6, 0, 1, 1',
1228 'setvl 0, 0, 6, 0, 1, 1',
1229 'setvl 0, 0, 6, 0, 1, 1',
1231 'setvl 0, 0, 6, 0, 1, 1',
1232 'setvl 0, 0, 6, 0, 1, 1',
1233 'setvl 0, 0, 6, 0, 1, 1',
1234 'sv.adde *23, *23, *7',
1235 'setvl 0, 0, 6, 0, 1, 1',
1236 'setvl 0, 0, 6, 0, 1, 1',
1237 'setvl 0, 0, 6, 0, 1, 1',
1238 'setvl 0, 0, 4, 0, 1, 1',
1239 'sv.or *14, *30, *30',
1242 'setvl 0, 0, 4, 0, 1, 1',
1243 'sv.maddedu *7, *14, 34, 21',
1244 'setvl 0, 0, 4, 0, 1, 1',
1246 'setvl 0, 0, 4, 0, 1, 1',
1247 'setvl 0, 0, 4, 0, 1, 1',
1252 'setvl 0, 0, 6, 0, 1, 1',
1253 'setvl 0, 0, 6, 0, 1, 1',
1254 'setvl 0, 0, 6, 0, 1, 1',
1255 'setvl 0, 0, 6, 0, 1, 1',
1256 'setvl 0, 0, 6, 0, 1, 1',
1258 'setvl 0, 0, 6, 0, 1, 1',
1259 'setvl 0, 0, 6, 0, 1, 1',
1260 'setvl 0, 0, 6, 0, 1, 1',
1261 'sv.adde *24, *24, *7',
1262 'setvl 0, 0, 6, 0, 1, 1',
1263 'setvl 0, 0, 6, 0, 1, 1',
1264 'setvl 0, 0, 6, 0, 1, 1',
1265 'setvl 0, 0, 4, 0, 1, 1',
1266 'sv.or *14, *30, *30',
1269 'setvl 0, 0, 4, 0, 1, 1',
1270 'sv.maddedu *3, *14, 8, 21',
1271 'setvl 0, 0, 4, 0, 1, 1',
1273 'setvl 0, 0, 4, 0, 1, 1',
1274 'setvl 0, 0, 4, 0, 1, 1',
1275 'setvl 0, 0, 5, 0, 1, 1',
1276 'setvl 0, 0, 5, 0, 1, 1',
1277 'setvl 0, 0, 5, 0, 1, 1',
1278 'setvl 0, 0, 5, 0, 1, 1',
1279 'setvl 0, 0, 5, 0, 1, 1',
1281 'setvl 0, 0, 5, 0, 1, 1',
1282 'setvl 0, 0, 5, 0, 1, 1',
1283 'setvl 0, 0, 5, 0, 1, 1',
1284 'sv.adde *25, *25, *3',
1285 'setvl 0, 0, 5, 0, 1, 1',
1286 'setvl 0, 0, 5, 0, 1, 1',
1287 'setvl 0, 0, 5, 0, 1, 1',
1288 'setvl 0, 0, 8, 0, 1, 1',
1289 'setvl 0, 0, 8, 0, 1, 1',
1290 'setvl 0, 0, 8, 0, 1, 1',
1291 'setvl 0, 0, 4, 0, 1, 1',
1292 'setvl 0, 0, 4, 0, 1, 1',
1293 'sv.or *7, *38, *38',
1294 'setvl 0, 0, 4, 0, 1, 1',
1301 'setvl 0, 0, 4, 0, 1, 1',
1303 'setvl 0, 0, 4, 0, 1, 1',
1304 'sv.or *34, *44, *44',
1307 'setvl 0, 0, 4, 0, 1, 1',
1308 'sv.maddedu *14, *34, 9, 7',
1309 'setvl 0, 0, 4, 0, 1, 1',
1311 'setvl 0, 0, 4, 0, 1, 1',
1312 'setvl 0, 0, 4, 0, 1, 1',
1314 'setvl 0, 0, 4, 0, 1, 1',
1317 'setvl 0, 0, 4, 0, 1, 1',
1318 'sv.maddedu *3, *44, 9, 7',
1319 'setvl 0, 0, 4, 0, 1, 1',
1320 'setvl 0, 0, 4, 0, 1, 1',
1321 'setvl 0, 0, 4, 0, 1, 1',
1326 'setvl 0, 0, 6, 0, 1, 1',
1328 'setvl 0, 0, 6, 0, 1, 1',
1329 'setvl 0, 0, 6, 0, 1, 1',
1330 'setvl 0, 0, 6, 0, 1, 1',
1331 'setvl 0, 0, 6, 0, 1, 1',
1333 'setvl 0, 0, 6, 0, 1, 1',
1334 'setvl 0, 0, 6, 0, 1, 1',
1335 'setvl 0, 0, 6, 0, 1, 1',
1336 'sv.adde *15, *15, *3',
1337 'setvl 0, 0, 6, 0, 1, 1',
1338 'setvl 0, 0, 6, 0, 1, 1',
1339 'setvl 0, 0, 6, 0, 1, 1',
1341 'setvl 0, 0, 4, 0, 1, 1',
1344 'setvl 0, 0, 4, 0, 1, 1',
1345 'sv.maddedu *3, *44, 9, 7',
1346 'setvl 0, 0, 4, 0, 1, 1',
1347 'setvl 0, 0, 4, 0, 1, 1',
1348 'setvl 0, 0, 4, 0, 1, 1',
1352 'setvl 0, 0, 6, 0, 1, 1',
1353 'setvl 0, 0, 6, 0, 1, 1',
1354 'setvl 0, 0, 6, 0, 1, 1',
1355 'setvl 0, 0, 6, 0, 1, 1',
1356 'setvl 0, 0, 6, 0, 1, 1',
1358 'setvl 0, 0, 6, 0, 1, 1',
1359 'setvl 0, 0, 6, 0, 1, 1',
1360 'setvl 0, 0, 6, 0, 1, 1',
1361 'sv.adde *16, *16, *3',
1362 'setvl 0, 0, 6, 0, 1, 1',
1363 'setvl 0, 0, 6, 0, 1, 1',
1364 'setvl 0, 0, 6, 0, 1, 1',
1366 'setvl 0, 0, 4, 0, 1, 1',
1369 'setvl 0, 0, 4, 0, 1, 1',
1370 'sv.maddedu *38, *44, 9, 7',
1371 'setvl 0, 0, 4, 0, 1, 1',
1373 'setvl 0, 0, 4, 0, 1, 1',
1374 'setvl 0, 0, 4, 0, 1, 1',
1375 'setvl 0, 0, 5, 0, 1, 1',
1376 'setvl 0, 0, 5, 0, 1, 1',
1377 'setvl 0, 0, 5, 0, 1, 1',
1378 'setvl 0, 0, 5, 0, 1, 1',
1379 'setvl 0, 0, 5, 0, 1, 1',
1381 'setvl 0, 0, 5, 0, 1, 1',
1382 'setvl 0, 0, 5, 0, 1, 1',
1383 'setvl 0, 0, 5, 0, 1, 1',
1384 'sv.adde *17, *17, *38',
1385 'setvl 0, 0, 5, 0, 1, 1',
1386 'setvl 0, 0, 5, 0, 1, 1',
1387 'setvl 0, 0, 5, 0, 1, 1',
1391 'setvl 0, 0, 8, 0, 1, 1',
1398 'setvl 0, 0, 8, 0, 1, 1',
1399 'setvl 0, 0, 8, 0, 1, 1',
1400 'sv.or *30, *14, *14',
1401 'setvl 0, 0, 1, 0, 1, 1',
1403 'setvl 0, 0, 1, 0, 1, 1',
1406 'setvl 0, 0, 1, 0, 1, 1',
1411 'setvl 0, 0, 1, 0, 1, 1',
1412 'sv.maddedu *3, *42, 10, 4',
1413 'setvl 0, 0, 1, 0, 1, 1',
1414 'setvl 0, 0, 2, 0, 1, 1',
1415 'setvl 0, 0, 2, 0, 1, 1',
1416 'setvl 0, 0, 2, 0, 1, 1',
1417 'setvl 0, 0, 8, 0, 1, 1',
1418 'setvl 0, 0, 8, 0, 1, 1',
1419 'setvl 0, 0, 8, 0, 1, 1',
1420 'setvl 0, 0, 7, 0, 1, 1',
1421 'setvl 0, 0, 7, 0, 1, 1',
1422 'setvl 0, 0, 7, 0, 1, 1',
1423 'setvl 0, 0, 8, 0, 1, 1',
1424 'setvl 0, 0, 8, 0, 1, 1',
1425 'sv.or *14, *30, *30',
1426 'setvl 0, 0, 8, 0, 1, 1',
1433 'setvl 0, 0, 7, 0, 1, 1',
1439 'setvl 0, 0, 7, 0, 1, 1',
1440 'setvl 0, 0, 7, 0, 1, 1',
1441 'sv.or *30, *14, *14',
1442 'setvl 0, 0, 7, 0, 1, 1',
1444 'setvl 0, 0, 7, 0, 1, 1',
1445 'setvl 0, 0, 7, 0, 1, 1',
1446 'sv.or *14, *30, *30',
1447 'setvl 0, 0, 7, 0, 1, 1',
1448 'sv.subfe *30, *22, *14',
1449 'setvl 0, 0, 7, 0, 1, 1',
1450 'setvl 0, 0, 2, 0, 1, 1',
1451 'setvl 0, 0, 2, 0, 1, 1',
1452 'setvl 0, 0, 2, 0, 1, 1',
1454 'setvl 0, 0, 7, 0, 1, 1',
1460 'setvl 0, 0, 7, 0, 1, 1',
1461 'setvl 0, 0, 7, 0, 1, 1',
1462 'setvl 0, 0, 7, 0, 1, 1',
1464 'setvl 0, 0, 7, 0, 1, 1',
1465 'setvl 0, 0, 7, 0, 1, 1',
1466 'setvl 0, 0, 7, 0, 1, 1',
1467 'sv.subfe *14, *3, *30',
1468 'setvl 0, 0, 7, 0, 1, 1',
1469 'setvl 0, 0, 7, 0, 1, 1',
1470 'setvl 0, 0, 7, 0, 1, 1',
1471 'setvl 0, 0, 7, 0, 1, 1',
1472 'setvl 0, 0, 7, 0, 1, 1',
1473 'setvl 0, 0, 7, 0, 1, 1',
1474 'setvl 0, 0, 7, 0, 1, 1',
1475 'setvl 0, 0, 2, 0, 1, 1',
1476 'setvl 0, 0, 2, 0, 1, 1',
1477 'setvl 0, 0, 2, 0, 1, 1',
1481 'setvl 0, 0, 5, 0, 1, 1',
1482 'setvl 0, 0, 5, 0, 1, 1',
1483 'setvl 0, 0, 5, 0, 1, 1',
1484 'setvl 0, 0, 5, 0, 1, 1',
1485 'setvl 0, 0, 5, 0, 1, 1',
1487 'setvl 0, 0, 5, 0, 1, 1',
1488 'setvl 0, 0, 5, 0, 1, 1',
1489 'setvl 0, 0, 5, 0, 1, 1',
1490 'sv.adde *25, *25, *14',
1491 'setvl 0, 0, 5, 0, 1, 1',
1492 'setvl 0, 0, 5, 0, 1, 1',
1493 'setvl 0, 0, 5, 0, 1, 1',
1494 'setvl 0, 0, 2, 0, 1, 1',
1495 'setvl 0, 0, 2, 0, 1, 1',
1496 'setvl 0, 0, 2, 0, 1, 1',
1497 'setvl 0, 0, 2, 0, 1, 1',
1498 'setvl 0, 0, 2, 0, 1, 1',
1500 'setvl 0, 0, 2, 0, 1, 1',
1501 'setvl 0, 0, 2, 0, 1, 1',
1502 'setvl 0, 0, 2, 0, 1, 1',
1503 'sv.adde *28, *28, *3',
1504 'setvl 0, 0, 2, 0, 1, 1',
1505 'setvl 0, 0, 2, 0, 1, 1',
1506 'setvl 0, 0, 2, 0, 1, 1',
1507 'setvl 0, 0, 8, 0, 1, 1',
1508 'setvl 0, 0, 8, 0, 1, 1',
1509 'setvl 0, 0, 8, 0, 1, 1',
1510 'setvl 0, 0, 8, 0, 1, 1',
1511 'setvl 0, 0, 8, 0, 1, 1',
1513 'setvl 0, 0, 8, 0, 1, 1',
# Shared driver: simulate `code` on every `(lhs_value, rhs_value)` pair in
# `test_cases` and check the product read back from memory.
# Each operand is first reduced modulo 2**(its size in bits); when the
# matching `*_signed` flag is set and the operand's top bit is set it is
# treated as negative while computing the expected product, which is then
# reduced modulo 2**(lhs bits + rhs bits).
# Inside a `subTest` it enters a fresh state from `get_state_factory(code)`,
# sets `state[code.ptr_in]`, stores the operand words at `lhs_ptr` /
# `rhs_ptr`, reassembles `prod` from `code.dest_size_in_words` words loaded at
# `dest_ptr`, and asserts that `hex(prod)` equals `hex(prod_value)`.
# NOTE(review): the definition of `ptr_in`, the initialisation of `prod` and
# the step that actually executes the code are not present in the lines shown
# here; confirm against the complete file.
1517 def tst_toom_mul_sim(
1518 self
, code
, # type: Mul
1519 lhs_signed
, # type: bool
1520 rhs_signed
, # type: bool
1521 get_state_factory
, # type: Callable[[Mul], _StateFactory]
1522 test_cases
, # type: Iterable[tuple[int, int]]
1524 print(code
.retval
[1])
1525 print(code
.fn
.ops_to_str())
1526 state_factory
= get_state_factory(code
)
1528 dest_ptr
= ptr_in
+ code
.dest_offset
1529 lhs_ptr
= ptr_in
+ code
.lhs_offset
1530 rhs_ptr
= ptr_in
+ code
.rhs_offset
1531 lhs_size_in_bits
= code
.lhs_size_in_words
* GPR_SIZE_IN_BITS
1532 rhs_size_in_bits
= code
.rhs_size_in_words
* GPR_SIZE_IN_BITS
1533 for lhs_value
, rhs_value
in test_cases
:
1534 lhs_value
%= 1 << lhs_size_in_bits
1535 rhs_value
%= 1 << rhs_size_in_bits
1536 if lhs_signed
and lhs_value
>> (lhs_size_in_bits
- 1):
1537 lhs_value
-= 1 << lhs_size_in_bits
1538 if rhs_signed
and rhs_value
>> (rhs_size_in_bits
- 1):
1539 rhs_value
-= 1 << rhs_size_in_bits
1540 prod_value
= lhs_value
* rhs_value
1541 lhs_value
%= 1 << lhs_size_in_bits
1542 rhs_value
%= 1 << rhs_size_in_bits
1543 prod_value
%= 1 << (lhs_size_in_bits
+ rhs_size_in_bits
)
1544 with self
.subTest(lhs_signed
=lhs_signed
, rhs_signed
=rhs_signed
,
1545 lhs_value
=hex(lhs_value
),
1546 rhs_value
=hex(rhs_value
),
1547 prod_value
=hex(prod_value
)):
1548 with
state_factory() as state
:
1549 state
[code
.ptr_in
] = ptr_in
,
1550 for i
in range(code
.lhs_size_in_words
):
1551 v
= lhs_value
>> GPR_SIZE_IN_BITS
* i
1553 state
.store(lhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
1554 for i
in range(code
.rhs_size_in_words
):
1555 v
= rhs_value
>> GPR_SIZE_IN_BITS
* i
1557 state
.store(rhs_ptr
+ i
* GPR_SIZE_IN_BYTES
, v
)
1560 for i
in range(code
.dest_size_in_words
):
1561 v
= state
.load(dest_ptr
+ GPR_SIZE_IN_BYTES
* i
)
1562 prod
+= v
<< (GPR_SIZE_IN_BITS
* i
)
1563 self
.assertEqual(hex(prod
), hex(prod_value
),
1564 f
"failed: state={state}")
# Shared driver: run `tst_toom_mul_sim` with `get_pre_ra_state_factory` for
# every (lhs, rhs) pairing of the operand sizes collected in `sizes_in_words`
# (values of the form `1 << i` and `3 << i`, kept only when between 1 and 16
# words inclusive, in sorted order).
# For each size pair, inside a `subTest`, it builds a fixed list of operand
# pairs: (-1, -1), the four combinations of the `(0x80 << 2048) // 0xFF` and
# `(0x40 << 2048) // 0xFF` patterns, and the combinations of 1 with each
# operand's top-bit-only value.
# NOTE(review): the body of the nested `mul`, the loop that binds `i`, and the
# first argument passed to `tst_toom_mul_sim` are not fully present in the
# lines shown here; confirm against the complete file.
1566 def tst_toom_mul_all_sizes_pre_ra_sim(self
, instances
, lhs_signed
, rhs_signed
):
1567 # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
1568 def mul(fn
, lhs
, rhs
):
1569 # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
1571 fn
=fn
, lhs
=lhs
, lhs_signed
=lhs_signed
, rhs
=rhs
,
1572 rhs_signed
=rhs_signed
, instances
=instances
)
1574 sizes_in_words
= OSet() # type: OSet[int]
1576 sizes_in_words
.add(1 << i
)
1577 sizes_in_words
.add(3 << i
)
1578 sizes_in_words
= OSet(
1579 i
for i
in sorted(sizes_in_words
) if 1 <= i
<= 16)
1580 for lhs_size_in_words
in sizes_in_words
:
1581 for rhs_size_in_words
in sizes_in_words
:
1582 lhs_size_in_bits
= GPR_SIZE_IN_BITS
* lhs_size_in_words
1583 rhs_size_in_bits
= GPR_SIZE_IN_BITS
* rhs_size_in_words
1584 with self
.subTest(lhs_size_in_words
=lhs_size_in_words
,
1585 rhs_size_in_words
=rhs_size_in_words
,
1586 lhs_signed
=lhs_signed
,
1587 rhs_signed
=rhs_signed
):
1588 test_cases
= [] # type: list[tuple[int, int]]
1589 test_cases
.append((-1, -1))
1590 test_cases
.append(((0x80 << 2048) // 0xFF,
1591 (0x80 << 2048) // 0xFF))
1592 test_cases
.append(((0x40 << 2048) // 0xFF,
1593 (0x80 << 2048) // 0xFF))
1594 test_cases
.append(((0x80 << 2048) // 0xFF,
1595 (0x40 << 2048) // 0xFF))
1596 test_cases
.append(((0x40 << 2048) // 0xFF,
1597 (0x40 << 2048) // 0xFF))
1598 test_cases
.append((1 << (lhs_size_in_bits
- 1),
1599 1 << (rhs_size_in_bits
- 1)))
1600 test_cases
.append((1, 1 << (rhs_size_in_bits
- 1)))
1601 test_cases
.append((1 << (lhs_size_in_bits
- 1), 1))
1602 test_cases
.append((1, 1))
1603 self
.tst_toom_mul_sim(
1605 lhs_size_in_words
=lhs_size_in_words
,
1606 rhs_size_in_words
=rhs_size_in_words
),
1607 lhs_signed
=lhs_signed
, rhs_signed
=rhs_signed
,
1608 get_state_factory
=get_pre_ra_state_factory
,
1609 test_cases
=test_cases
)
def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
    # One TOOM-2 instance, both operands unsigned.
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances,
        lhs_signed=False,
        rhs_signed=False,
    )
def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
    # One TOOM-2 instance, unsigned lhs and signed rhs.
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances,
        lhs_signed=False,
        rhs_signed=True,
    )
def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
    # One TOOM-2 instance, signed lhs and unsigned rhs.
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances,
        lhs_signed=True,
        rhs_signed=False,
    )
def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
    # One TOOM-2 instance, both operands signed.
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances,
        lhs_signed=True,
        rhs_signed=True,
    )
def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
    # The same TOOM-2 instance repeated four times, both operands unsigned.
    toom_2 = ToomCookInstance.make_toom_2()
    four_levels = (toom_2,) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_levels,
        lhs_signed=False,
        rhs_signed=False,
    )
def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
    # The same TOOM-2 instance repeated four times, unsigned lhs and
    # signed rhs.
    toom_2 = ToomCookInstance.make_toom_2()
    four_levels = (toom_2,) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_levels,
        lhs_signed=False,
        rhs_signed=True,
    )
def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
    # The same TOOM-2 instance repeated four times, signed lhs and
    # unsigned rhs.
    toom_2 = ToomCookInstance.make_toom_2()
    four_levels = (toom_2,) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_levels,
        lhs_signed=True,
        rhs_signed=False,
    )
def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
    # The same TOOM-2 instance repeated four times, both operands signed.
    toom_2 = ToomCookInstance.make_toom_2()
    four_levels = (toom_2,) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_levels,
        lhs_signed=True,
        rhs_signed=True,
    )
1656 if __name__
== "__main__":