TOOM-2 multiplication works for all sizes
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
1 from contextlib import contextmanager
2 import unittest
3 from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
4
5 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
6 GPR_SIZE_IN_BYTES,
7 GPR_VALUE_MASK, BaseSimState,
8 Fn, GenAsmState, OpKind,
9 PostRASimState,
10 PreRASimState, SSAVal)
11 from bigint_presentation_code.register_allocator import allocate_registers
12 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
13 simple_mul)
14 from bigint_presentation_code.util import OSet
15
# A zero-argument callable that returns a context manager yielding a
# `BaseSimState`; the tests below enter it once per simulated case.
_StateFactory = Callable[[], ContextManager[BaseSimState]]
17
18
def simple_umul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, None]
    """Call `simple_mul` with both operands marked unsigned.

    Returns a 2-tuple of the `simple_mul` result and `None`, which is the
    `(SSAVal, Any)` shape that `Mul` expects from its `mul` callback.
    """
    product = simple_mul(
        fn=fn, lhs=lhs, rhs=rhs,
        lhs_signed=False, rhs_signed=False, name="mul")
    return product, None
23
24
def get_pre_ra_state_factory(code):
    # type: (Mul) -> _StateFactory
    """Return a factory of context managers yielding fresh `PreRASimState`s.

    `code` is not used here; the parameter is accepted so this function can
    be passed wherever `get_post_ra_state_factory` can.
    """
    def state_factory():
        # every entry gets its own empty state, registered as the current
        # debugging state for as long as the `with` block is active
        sim_state = PreRASimState(ssa_vals={}, memory={})
        with sim_state.set_as_current_debugging_state():
            yield sim_state

    return contextmanager(state_factory)
33
34
class Mul:
    """Test fixture wrapping an `Fn` that loads two operands, calls a
    caller-supplied multiply builder on them, and appends a store of the
    result.

    Byte offsets (all relative to the `ptr_in` function argument):
    the destination starts at `dest_offset` (0), the lhs operand follows
    the destination, and the rhs operand follows the lhs operand.
    """
    _MulFn = Callable[[Fn, SSAVal, SSAVal], Tuple[SSAVal, Any]]

    def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
        # type: (_MulFn, int, int) -> None
        super().__init__()
        fn = Fn()
        self.fn = fn

        # operand / result sizes
        dest_size_in_words = lhs_size_in_words + rhs_size_in_words
        self.dest_size_in_words = dest_size_in_words
        self.lhs_size_in_words = lhs_size_in_words
        self.rhs_size_in_words = rhs_size_in_words
        self.dest_size_in_bytes = dest_size_in_words * GPR_SIZE_IN_BYTES
        self.lhs_size_in_bytes = lhs_size_in_words * GPR_SIZE_IN_BYTES
        self.rhs_size_in_bytes = rhs_size_in_words * GPR_SIZE_IN_BYTES

        # memory layout: dest, then lhs, then rhs
        self.dest_offset = 0
        self.lhs_offset = self.dest_offset + self.dest_size_in_bytes
        self.rhs_offset = self.lhs_offset + self.lhs_size_in_bytes

        ptr_in_op = fn.append_new_op(kind=OpKind.FuncArgR3, name="ptr_in")
        ptr_in = ptr_in_op.outputs[0]
        self.ptr_in = ptr_in

        def append_load(setvl_name, load_name, offset, size_in_words):
            # appends a SetVLI op followed by the SvLd op that consumes it
            setvl_op = fn.append_new_op(
                kind=OpKind.SetVLI, immediates=[size_in_words],
                maxvl=size_in_words, name=setvl_name)
            load_op = fn.append_new_op(
                kind=OpKind.SvLd, immediates=[offset],
                input_vals=[ptr_in, setvl_op.outputs[0]],
                name=load_name, maxvl=size_in_words)
            return setvl_op, load_op

        self.lhs_setvl, self.load_lhs = append_load(
            "lhs_setvl", "load_lhs", self.lhs_offset, lhs_size_in_words)
        self.rhs_setvl, self.load_rhs = append_load(
            "rhs_setvl", "load_rhs", self.rhs_offset, rhs_size_in_words)

        # the callback appends whatever ops it needs to `fn` and returns
        # a tuple whose first item is the value to store
        lhs_val = self.load_lhs.outputs[0]
        rhs_val = self.load_rhs.outputs[0]
        self.retval = mul(fn, lhs_val, rhs_val)

        self.dest_setvl = fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[dest_size_in_words],
            maxvl=dest_size_in_words, name="dest_setvl")
        store_inputs = [self.retval[0], ptr_in, self.dest_setvl.outputs[0]]
        self.store = fn.append_new_op(
            kind=OpKind.SvStd, input_vals=store_inputs,
            immediates=[self.dest_offset], maxvl=dest_size_in_words,
            name="store_dest")
78
79
def get_post_ra_state_factory(code):
    # type: (Mul) -> _StateFactory
    """Return a factory of context managers yielding `PostRASimState`s.

    Register allocation for `code.fn` is run once, when this function is
    called; every state produced afterwards shares that allocation but
    starts with empty `memory` and `loc_values`.
    """
    loc_map = allocate_registers(code.fn)

    def state_factory():
        sim_state = PostRASimState(
            ssa_val_to_loc_map=loc_map, memory={}, loc_values={})
        yield sim_state

    return contextmanager(state_factory)
90
91
92 class TestToomCook(unittest.TestCase):
93 maxDiff = None
94
def test_toom_2_repr(self):
    # Golden test: repr() of the TOOM-2 instance must match this exact text.
    instance = ToomCookInstance.make_toom_2()
    # to regenerate the expected text: print(repr(repr(instance)))
    expected_repr = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
        "eval_points=(0, 1, POINT_AT_INFINITY), "
        "lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected_repr)
132
def test_toom_2_5_repr(self):
    # Golden test: repr() of the TOOM-2.5 instance must match this exact
    # text.
    instance = ToomCookInstance.make_toom_2_5()
    # to regenerate the expected text: print(repr(repr(instance)))
    expected_repr = (
        "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
        "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
        "prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected_repr)
197
def test_reversed_toom_2_5_repr(self):
    # Golden test: repr() of the TOOM-2.5 instance after .reversed() must
    # match this exact text.
    instance = ToomCookInstance.make_toom_2_5().reversed()
    # to regenerate the expected text: print(repr(repr(instance)))
    expected_repr = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected_repr)
259
def test_simple_mul_192x192_pre_ra_sim(self):
    # Run the shared 192x192 simulation check on pre-register-allocation
    # state for all four signedness combinations
    # (lhs_signed varies slowest).
    signedness_pairs = [(a, b) for a in (False, True) for b in (False, True)]
    for lhs_signed, rhs_signed in signedness_pairs:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed,
            rhs_signed=rhs_signed,
            get_state_factory=get_pre_ra_state_factory)
266
def test_simple_mul_192x192_post_ra_sim(self):
    # Run the shared 192x192 simulation check on post-register-allocation
    # state for all four signedness combinations
    # (lhs_signed varies slowest).
    signedness_pairs = [(a, b) for a in (False, True) for b in (False, True)]
    for lhs_signed, rhs_signed in signedness_pairs:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed,
            rhs_signed=rhs_signed,
            get_state_factory=get_post_ra_state_factory)
273
def tst_simple_mul_192x192_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    # Shared body of the 192x192 -> 384-bit simulation tests.
    #
    # The operands were chosen so that
    #     0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    #   * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    #  == int("0x000074736574206e_6f69746163696c70"
    #         "_69746c756d207469_622d3438333e2d32"
    #         "_3931783239312079_7261727469627261", base=0)
    #  == int.from_bytes(
    #         b"arbitrary 192x192->384-bit multiplication test", 'little')
    lhs_value = 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    rhs_value = 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    prod_value = int.from_bytes(
        b"arbitrary 192x192->384-bit multiplication test", 'little')
    # sanity-check the constants themselves before simulating anything
    self.assertEqual(lhs_value * rhs_value, prod_value)

    def mul(fn, lhs, rhs):
        retval = simple_mul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed,
            rhs=rhs, rhs_signed=rhs_signed, name="mul")
        return retval, None

    code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3)
    state_factory = get_state_factory(code)

    ptr_in = 0x100
    dest_ptr = ptr_in + code.dest_offset
    lhs_ptr = ptr_in + code.lhs_offset
    rhs_ptr = ptr_in + code.rhs_offset

    def store_words(state, base_ptr, value):
        # write the three words of `value`, least-significant word first
        for word_index in range(3):
            shift = GPR_SIZE_IN_BITS * word_index
            word = (value >> shift) & GPR_VALUE_MASK
            state.store(base_ptr + word_index * GPR_SIZE_IN_BYTES, word)

    # a negated operand is only tried when that operand is signed
    sign_cases = [
        (lhs_neg, rhs_neg)
        for lhs_neg in (False, True)
        for rhs_neg in (False, True)
        if (lhs_signed or not lhs_neg) and (rhs_signed or not rhs_neg)
    ]
    for lhs_neg, rhs_neg in sign_cases:
        with self.subTest(lhs_signed=lhs_signed,
                          rhs_signed=rhs_signed,
                          lhs_neg=lhs_neg, rhs_neg=rhs_neg), \
                state_factory() as state:
            state[code.ptr_in] = (ptr_in,)
            # negation is done modulo 2 ** 192
            lhs = (2 ** 192 - lhs_value) if lhs_neg else lhs_value
            rhs = (2 ** 192 - rhs_value) if rhs_neg else rhs_value
            store_words(state, lhs_ptr, lhs)
            store_words(state, rhs_ptr, rhs)
            code.fn.sim(state)
            # the product is negated (modulo 2 ** 384) exactly when one
            # operand was negated
            if lhs_neg != rhs_neg:
                expected = 2 ** 384 - prod_value
            else:
                expected = prod_value
            prod = sum(
                state.load(dest_ptr + GPR_SIZE_IN_BYTES * word_index)
                << (GPR_SIZE_IN_BITS * word_index)
                for word_index in range(6))
            # compare as hex so a failure message is readable
            self.assertEqual(hex(prod), hex(expected))
335
def test_simple_mul_192x192_ops(self):
    # Golden test: the textual op dump of the unsigned 3x3-word multiply
    # must match this exact text.
    expected_ops = (
        "ptr_in:\n"
        " (<...outputs[0]: <I64>>) <= FuncArgR3\n"
        "lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_lhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <lhs_setvl.outputs[0]: <VL_MAXVL>>, 0x30)\n"
        "rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_rhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <rhs_setvl.outputs[0]: <VL_MAXVL>>, 0x48)\n"
        "mul_rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_rhs_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <load_rhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_zero2:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_0_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[0]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_0_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_0_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[1]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_1_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_cast_retval_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_cast_pp_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x5)\n"
        "mul_1_retval_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_0_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_0_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_pp_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_1_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_1_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_pp_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_1_add:\n"
        " (<...outputs[0]: <I64*5>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_1_retval_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_pp_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>,\n"
        " <...outputs[4]: <I64>>) <= Spread(\n"
        " <mul_1_add.outputs[0]: <I64*5>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[2]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_2_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x4)\n"
        "mul_2_retval_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_1_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[4]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_pp_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_2_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_2_mul.outputs[1]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_2_add:\n"
        " (<...outputs[0]: <I64*4>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_2_retval_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_pp_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>) <= Spread(\n"
        " <mul_2_add.outputs[0]: <I64*4>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "mul_concat:\n"
        " (<...outputs[0]: <I64*6>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "dest_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "store_dest:\n"
        " SvStd(<mul_concat.outputs[0]: <I64*6>>,\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <dest_setvl.outputs[0]: <VL_MAXVL>>, 0x0)"
    )
    code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    actual_ops = code.fn.ops_to_str()
    self.assertEqual(actual_ops, expected_ops)
484
485 def test_simple_mul_192x192_reg_alloc(self):
486 code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
487 fn = code.fn
488 assigned_registers = allocate_registers(fn)
489 self.assertEqual(
490 repr(assigned_registers), "{"
491 "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
492 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
493 "<store_dest.inp1.copy.outputs[0]: <I64>>: "
494 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
495 "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
496 "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
497 "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
498 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
499 "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
500 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
501 "<mul_concat.out0.copy.outputs[0]: <I64*6>>: "
502 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
503 "<mul_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
504 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
505 "<mul_concat.outputs[0]: <I64*6>>: "
506 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
507 "<mul_concat.inp0.copy.outputs[0]: <I64>>: "
508 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
509 "<mul_concat.inp1.copy.outputs[0]: <I64>>: "
510 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
511 "<mul_concat.inp2.copy.outputs[0]: <I64>>: "
512 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
513 "<mul_concat.inp3.copy.outputs[0]: <I64>>: "
514 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
515 "<mul_concat.inp4.copy.outputs[0]: <I64>>: "
516 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
517 "<mul_concat.inp5.copy.outputs[0]: <I64>>: "
518 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
519 "<mul_concat.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
520 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
521 "<mul_setvl.outputs[0]: <VL_MAXVL>>: "
522 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
523 "<mul_2_sum_spread.out3.copy.outputs[0]: <I64>>: "
524 "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
525 "<mul_2_sum_spread.out2.copy.outputs[0]: <I64>>: "
526 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
527 "<mul_2_sum_spread.out1.copy.outputs[0]: <I64>>: "
528 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
529 "<mul_2_sum_spread.out0.copy.outputs[0]: <I64>>: "
530 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
531 "<mul_2_sum_spread.outputs[0]: <I64>>: "
532 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
533 "<mul_2_sum_spread.outputs[1]: <I64>>: "
534 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
535 "<mul_2_sum_spread.outputs[2]: <I64>>: "
536 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
537 "<mul_2_sum_spread.outputs[3]: <I64>>: "
538 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
539 "<mul_2_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
540 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
541 "<mul_2_sum_spread.inp0.copy.outputs[0]: <I64*4>>: "
542 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
543 "<mul_2_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
544 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
545 "<mul_2_add.out0.copy.outputs[0]: <I64*4>>: "
546 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
547 "<mul_2_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
548 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
549 "<mul_2_clear_ca.outputs[0]: <CA>>: "
550 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
551 "<mul_2_add.outputs[1]: <CA>>: "
552 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
553 "<mul_2_add.outputs[0]: <I64*4>>: "
554 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
555 "<mul_2_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
556 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
557 "<mul_2_add.inp1.copy.outputs[0]: <I64*4>>: "
558 "Loc(kind=LocKind.GPR, start=7, reg_len=4), "
559 "<mul_2_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
560 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
561 "<mul_2_add.inp0.copy.outputs[0]: <I64*4>>: "
562 "Loc(kind=LocKind.GPR, start=14, reg_len=4), "
563 "<mul_2_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
564 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
565 "<mul_2_pp_concat.out0.copy.outputs[0]: <I64*4>>: "
566 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
567 "<mul_2_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
568 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
569 "<mul_2_pp_concat.outputs[0]: <I64*4>>: "
570 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
571 "<mul_2_pp_concat.inp0.copy.outputs[0]: <I64>>: "
572 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
573 "<mul_2_pp_concat.inp1.copy.outputs[0]: <I64>>: "
574 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
575 "<mul_2_pp_concat.inp2.copy.outputs[0]: <I64>>: "
576 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
577 "<mul_2_pp_concat.inp3.copy.outputs[0]: <I64>>: "
578 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
579 "<mul_2_pp_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
580 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
581 "<mul_2_retval_concat.out0.copy.outputs[0]: <I64*4>>: "
582 "Loc(kind=LocKind.GPR, start=7, reg_len=4), "
583 "<mul_2_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
584 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
585 "<mul_2_retval_concat.outputs[0]: <I64*4>>: "
586 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
587 "<mul_2_retval_concat.inp0.copy.outputs[0]: <I64>>: "
588 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
589 "<mul_2_retval_concat.inp1.copy.outputs[0]: <I64>>: "
590 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
591 "<mul_2_retval_concat.inp2.copy.outputs[0]: <I64>>: "
592 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
593 "<mul_2_retval_concat.inp3.copy.outputs[0]: <I64>>: "
594 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
595 "<mul_2_retval_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
596 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
597 "<mul_2_setvl.outputs[0]: <VL_MAXVL>>: "
598 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
599 "<mul_2_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
600 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
601 "<mul_2_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
602 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
603 "<mul_2_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
604 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
605 "<mul_2_mul_rt_spread.outputs[0]: <I64>>: "
606 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
607 "<mul_2_mul_rt_spread.outputs[1]: <I64>>: "
608 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
609 "<mul_2_mul_rt_spread.outputs[2]: <I64>>: "
610 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
611 "<mul_2_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
612 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
613 "<mul_2_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
614 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
615 "<mul_2_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
616 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
617 "<mul_2_mul.out1.copy.outputs[0]: <I64>>: "
618 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
619 "<mul_2_mul.out0.copy.outputs[0]: <I64*3>>: "
620 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
621 "<mul_2_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
622 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
623 "<mul_2_mul.inp2.copy.outputs[0]: <I64>>: "
624 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
625 "<mul_2_mul.outputs[1]: <I64>>: "
626 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
627 "<mul_2_mul.outputs[0]: <I64*3>>: "
628 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
629 "<mul_2_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
630 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
631 "<mul_2_mul.inp1.copy.outputs[0]: <I64>>: "
632 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
633 "<mul_2_mul.inp0.copy.outputs[0]: <I64*3>>: "
634 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
635 "<mul_2_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
636 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
637 "<mul_1_sum_spread.out4.copy.outputs[0]: <I64>>: "
638 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
639 "<mul_1_sum_spread.out3.copy.outputs[0]: <I64>>: "
640 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
641 "<mul_1_sum_spread.out2.copy.outputs[0]: <I64>>: "
642 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
643 "<mul_1_sum_spread.out1.copy.outputs[0]: <I64>>: "
644 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
645 "<mul_1_sum_spread.out0.copy.outputs[0]: <I64>>: "
646 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
647 "<mul_1_sum_spread.outputs[0]: <I64>>: "
648 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
649 "<mul_1_sum_spread.outputs[1]: <I64>>: "
650 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
651 "<mul_1_sum_spread.outputs[2]: <I64>>: "
652 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
653 "<mul_1_sum_spread.outputs[3]: <I64>>: "
654 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
655 "<mul_1_sum_spread.outputs[4]: <I64>>: "
656 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
657 "<mul_1_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
658 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
659 "<mul_1_sum_spread.inp0.copy.outputs[0]: <I64*5>>: "
660 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
661 "<mul_1_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
662 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
663 "<mul_1_add.out0.copy.outputs[0]: <I64*5>>: "
664 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
665 "<mul_1_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
666 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
667 "<mul_1_clear_ca.outputs[0]: <CA>>: "
668 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
669 "<mul_1_add.outputs[1]: <CA>>: "
670 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
671 "<mul_1_add.outputs[0]: <I64*5>>: "
672 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
673 "<mul_1_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
674 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
675 "<mul_1_add.inp1.copy.outputs[0]: <I64*5>>: "
676 "Loc(kind=LocKind.GPR, start=8, reg_len=5), "
677 "<mul_1_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
678 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
679 "<mul_1_add.inp0.copy.outputs[0]: <I64*5>>: "
680 "Loc(kind=LocKind.GPR, start=14, reg_len=5), "
681 "<mul_1_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
682 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
683 "<mul_1_pp_concat.out0.copy.outputs[0]: <I64*5>>: "
684 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
685 "<mul_1_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
686 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
687 "<mul_1_pp_concat.outputs[0]: <I64*5>>: "
688 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
689 "<mul_1_pp_concat.inp0.copy.outputs[0]: <I64>>: "
690 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
691 "<mul_1_pp_concat.inp1.copy.outputs[0]: <I64>>: "
692 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
693 "<mul_1_pp_concat.inp2.copy.outputs[0]: <I64>>: "
694 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
695 "<mul_1_pp_concat.inp3.copy.outputs[0]: <I64>>: "
696 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
697 "<mul_1_pp_concat.inp4.copy.outputs[0]: <I64>>: "
698 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
699 "<mul_1_pp_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
700 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
701 "<mul_1_retval_concat.out0.copy.outputs[0]: <I64*5>>: "
702 "Loc(kind=LocKind.GPR, start=8, reg_len=5), "
703 "<mul_1_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
704 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
705 "<mul_1_retval_concat.outputs[0]: <I64*5>>: "
706 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
707 "<mul_1_retval_concat.inp0.copy.outputs[0]: <I64>>: "
708 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
709 "<mul_1_retval_concat.inp1.copy.outputs[0]: <I64>>: "
710 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
711 "<mul_1_retval_concat.inp2.copy.outputs[0]: <I64>>: "
712 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
713 "<mul_1_retval_concat.inp3.copy.outputs[0]: <I64>>: "
714 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
715 "<mul_1_retval_concat.inp4.copy.outputs[0]: <I64>>: "
716 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
717 "<mul_1_retval_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
718 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
719 "<mul_1_setvl.outputs[0]: <VL_MAXVL>>: "
720 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
721 "<mul_1_cast_pp_zero.out0.copy.outputs[0]: <I64>>: "
722 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
723 "<mul_1_cast_pp_zero.outputs[0]: <I64>>: "
724 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
725 "<mul_1_cast_retval_zero.out0.copy.outputs[0]: <I64>>: "
726 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
727 "<mul_1_cast_retval_zero.outputs[0]: <I64>>: "
728 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
729 "<mul_1_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
730 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
731 "<mul_1_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
732 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
733 "<mul_1_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
734 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
735 "<mul_1_mul_rt_spread.outputs[0]: <I64>>: "
736 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
737 "<mul_1_mul_rt_spread.outputs[1]: <I64>>: "
738 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
739 "<mul_1_mul_rt_spread.outputs[2]: <I64>>: "
740 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
741 "<mul_1_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
742 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
743 "<mul_1_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
744 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
745 "<mul_1_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
746 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
747 "<mul_1_mul.out1.copy.outputs[0]: <I64>>: "
748 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
749 "<mul_1_mul.out0.copy.outputs[0]: <I64*3>>: "
750 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
751 "<mul_1_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
752 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
753 "<mul_1_mul.inp2.copy.outputs[0]: <I64>>: "
754 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
755 "<mul_1_mul.outputs[1]: <I64>>: "
756 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
757 "<mul_1_mul.outputs[0]: <I64*3>>: "
758 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
759 "<mul_1_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
760 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
761 "<mul_1_mul.inp1.copy.outputs[0]: <I64>>: "
762 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
763 "<mul_1_mul.inp0.copy.outputs[0]: <I64*3>>: "
764 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
765 "<mul_1_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
766 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
767 "<mul_0_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
768 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
769 "<mul_0_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
770 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
771 "<mul_0_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
772 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
773 "<mul_0_mul_rt_spread.outputs[0]: <I64>>: "
774 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
775 "<mul_0_mul_rt_spread.outputs[1]: <I64>>: "
776 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
777 "<mul_0_mul_rt_spread.outputs[2]: <I64>>: "
778 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
779 "<mul_0_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
780 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
781 "<mul_0_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
782 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
783 "<mul_0_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
784 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
785 "<mul_0_mul.out1.copy.outputs[0]: <I64>>: "
786 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
787 "<mul_0_mul.out0.copy.outputs[0]: <I64*3>>: "
788 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
789 "<mul_0_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
790 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
791 "<mul_0_mul.inp2.copy.outputs[0]: <I64>>: "
792 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
793 "<mul_0_mul.outputs[1]: <I64>>: "
794 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
795 "<mul_0_mul.outputs[0]: <I64*3>>: "
796 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
797 "<mul_0_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
798 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
799 "<mul_0_mul.inp1.copy.outputs[0]: <I64>>: "
800 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
801 "<mul_0_mul.inp0.copy.outputs[0]: <I64*3>>: "
802 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
803 "<mul_0_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
804 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
805 "<mul_zero2.out0.copy.outputs[0]: <I64>>: "
806 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
807 "<mul_zero2.outputs[0]: <I64>>: "
808 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
809 "<mul_lhs_setvl.outputs[0]: <VL_MAXVL>>: "
810 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
811 "<mul_zero.out0.copy.outputs[0]: <I64>>: "
812 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
813 "<mul_zero.outputs[0]: <I64>>: "
814 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
815 "<mul_rhs_spread.out2.copy.outputs[0]: <I64>>: "
816 "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
817 "<mul_rhs_spread.out1.copy.outputs[0]: <I64>>: "
818 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
819 "<mul_rhs_spread.out0.copy.outputs[0]: <I64>>: "
820 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
821 "<mul_rhs_spread.outputs[0]: <I64>>: "
822 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
823 "<mul_rhs_spread.outputs[1]: <I64>>: "
824 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
825 "<mul_rhs_spread.outputs[2]: <I64>>: "
826 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
827 "<mul_rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
828 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
829 "<mul_rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
830 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
831 "<mul_rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
832 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
833 "<mul_rhs_setvl.outputs[0]: <VL_MAXVL>>: "
834 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
835 "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
836 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
837 "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
838 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
839 "<load_rhs.outputs[0]: <I64*3>>: "
840 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
841 "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
842 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
843 "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
844 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
845 "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
846 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
847 "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
848 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
849 "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
850 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
851 "<load_lhs.outputs[0]: <I64*3>>: "
852 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
853 "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
854 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
855 "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
856 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
857 "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
858 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
859 "<ptr_in.out0.copy.outputs[0]: <I64>>: "
860 "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
861 "<ptr_in.outputs[0]: <I64>>: "
862 "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
863 "}")
864
def test_simple_mul_192x192_asm(self):
    """Golden-output check of the assembly for a 3-word by 3-word
    unsigned `simple_mul`.

    Builds the multiplier, runs the register allocator, generates
    assembly and compares the emitted instruction list against the
    exact expected listing below.
    """
    expected_asm = [
        'or 27, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 27, 27',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *3, 48(6)',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *24, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 27, 27',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *3, 72(6)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 14, 6, 6',
        'or 23, 7, 7',
        'addi 3, 0, 0',
        'or 22, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *24, *24',
        'or 7, 4, 4',
        'or 6, 22, 22',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *8, 7, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 19, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 21, 3, 3',
        'or 12, 4, 4',
        'or 11, 5, 5',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *24, *24',
        'or 7, 14, 14',
        'or 6, 22, 22',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *8, 7, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 18, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 17, 3, 3',
        'or 16, 4, 4',
        'or 15, 5, 5',
        'addi 3, 0, 0',
        'or 8, 3, 3',
        'addi 3, 0, 0',
        'or 14, 3, 3',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 3, 12, 12',
        'or 4, 11, 11',
        'or 5, 19, 19',
        'or 6, 8, 8',
        'or 7, 8, 8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'or 3, 17, 17',
        'or 4, 16, 16',
        'or 5, 15, 15',
        'or 6, 18, 18',
        'or 7, 14, 14',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *14, *8, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *3, *14, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 20, 3, 3',
        'or 19, 4, 4',
        'or 18, 5, 5',
        'or 17, 6, 6',
        'or 16, 7, 7',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *24, *24',
        'or 7, 23, 23',
        'or 6, 22, 22',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *8, 7, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 15, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 14, 3, 3',
        'or 12, 4, 4',
        'or 11, 5, 5',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 19, 19',
        'or 4, 18, 18',
        'or 5, 17, 17',
        'or 6, 16, 16',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *7, *3, *3',
        'or 3, 14, 14',
        'or 4, 12, 12',
        'or 5, 11, 11',
        'or 6, 15, 15',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *7, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *7, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *3, *14, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 12, 3, 3',
        'or 11, 4, 4',
        'or 10, 5, 5',
        'or 9, 6, 6',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 21, 21',
        'or 4, 20, 20',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 10, 10',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'or 3, 27, 27',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.std *4, 0(3)'
    ]

    mul_192x192 = Mul(
        mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    # register allocation must run first: GenAsmState is built from
    # the allocator's assignment
    asm_state = GenAsmState(allocate_registers(mul_192x192.fn))
    mul_192x192.fn.gen_asm(asm_state)
    self.assertEqual(asm_state.output, expected_asm)
1013
def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Mul
    """Build a 4-word by 4-word `Mul` whose multiplier is a
    `ToomCookMul` restricted to the single TOOM-2 instance.

    `lhs_signed` / `rhs_signed` are forwarded unchanged to
    `ToomCookMul`.
    """
    toom_2_only = (ToomCookInstance.make_toom_2(),)

    def toom_2_mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        # return the ToomCookMul object alongside the product so the
        # caller can get at it later
        toom_cook = ToomCookMul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed,
            rhs=rhs, rhs_signed=rhs_signed, instances=toom_2_only)
        return toom_cook.retval, toom_cook

    return Mul(mul=toom_2_mul, lhs_size_in_words=4, rhs_size_in_words=4)
1025
def make_256x256_mul_test_cases(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Iterator[tuple[int, int, int]]
    """Yield `(lhs, rhs, product)` test vectors for a 256x256-bit multiply.

    Operands are `1 << n` for a spread of bit positions, zero, and one
    fixed "interesting" value per side; negated copies are added for
    each side that is signed.  All three yielded values are unsigned
    bit patterns: the operands are reduced modulo `1 << 256` and the
    product modulo `1 << 512`, with the product computed from the
    operands interpreted as signed where requested.
    """
    # The fixed operands are chosen so that:
    # 0xc162321a5eaad80b_4b86bb0efdfb93c0_a789ff04cc11b157_eaa08e29fb197621
    # *
    # 0x3138710167583371_998af336a8fac64d_e6da3737090787fe_85ba09ea701f4af2
    # ==
    # int("0x"
    #     "252e6e6f69746163_696c7069746c754d_"
    #     "2061627573746172_614b202d20322d4d_"
    #     "4f4f5420676e6973_75206c756d20746e_"
    #     "6967696220746962_2d36353278363532", base=0)
    # == int.from_bytes(b'256x256-bit bigint mul using TOOM-2 '
    #                   b'- Karatsuba Multiplication.%', 'little')
    operand_modulus = 1 << 256
    product_modulus = 1 << 512
    sign_bit = 1 << 255

    lhs_special = (0xc162321a5eaad80b_4b86bb0efdfb93c0 << 128) \
        | 0xa789ff04cc11b157_eaa08e29fb197621
    rhs_special = (0x3138710167583371_998af336a8fac64d << 128) \
        | 0xe6da3737090787fe_85ba09ea701f4af2
    special_product = int.from_bytes(
        b'256x256-bit bigint mul using TOOM-2 '
        b'- Karatsuba Multiplication.%', 'little')
    # sanity-check the hand-picked constants before using them
    self.assertEqual(lhs_special * rhs_special, special_product)

    bit_positions = list(range(0, 256, 16)) + list(range(15, 256, 16))
    powers_of_two = [1 << position for position in bit_positions]

    def sort_key(value):
        # type: (int) -> tuple[bool, int]
        # the special operands sort last; everything else is ordered
        # by its unsigned 256-bit pattern
        is_special = abs(value) in (lhs_special, rhs_special)
        return is_special, value % operand_modulus

    def candidates(special, signed):
        # type: (int, bool) -> list[int]
        values = powers_of_two + [0, special]
        if signed:
            values += [-value for value in values]
        return sorted(values, key=sort_key)

    def interpret(value, signed):
        # type: (int, bool) -> int
        # wrap to 256 bits, then reinterpret as two's complement if
        # the operand is signed and its top bit is set
        value %= operand_modulus
        if signed and value >= sign_bit:
            value -= operand_modulus
        return value

    lhs_candidates = candidates(lhs_special, lhs_signed)
    rhs_candidates = candidates(rhs_special, rhs_signed)
    for lhs_candidate in lhs_candidates:
        lhs_operand = interpret(lhs_candidate, lhs_signed)
        for rhs_candidate in rhs_candidates:
            rhs_operand = interpret(rhs_candidate, rhs_signed)
            yield (lhs_operand % operand_modulus,
                   rhs_operand % operand_modulus,
                   (lhs_operand * rhs_operand) % product_modulus)
1075
def tst_toom_2_mul_256x256_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    """Simulate the 256x256-bit TOOM-2 multiplier on every test vector.

    For each `(lhs, rhs, product)` from `make_256x256_mul_test_cases`,
    the operands are written word by word into the simulator state's
    memory, the function is simulated, and the 8-word result read back
    from memory is compared with the expected product.
    `get_state_factory` selects which simulator state is used.
    """
    code = self.toom_2_mul_256x256(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    # debugging output: the multiplier object and the generated ops
    print(code.retval[1])
    print(code.fn.ops_to_str())
    make_state = get_state_factory(code)

    base_ptr = 0x100
    dest_ptr = base_ptr + code.dest_offset
    lhs_ptr = base_ptr + code.lhs_offset
    rhs_ptr = base_ptr + code.rhs_offset

    cases = self.make_256x256_mul_test_cases(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    for lhs_value, rhs_value, prod_value in cases:
        with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                          lhs_value=hex(lhs_value),
                          rhs_value=hex(rhs_value),
                          prod_value=hex(prod_value)):
            with make_state() as state:
                state[code.ptr_in] = (base_ptr,)
                # write lhs first, then rhs, least-significant word first
                operands = ((lhs_ptr, lhs_value), (rhs_ptr, rhs_value))
                for operand_ptr, operand in operands:
                    for word in range(4):
                        limb = operand >> GPR_SIZE_IN_BITS * word
                        limb &= GPR_VALUE_MASK
                        state.store(
                            operand_ptr + word * GPR_SIZE_IN_BYTES, limb)
                code.fn.sim(state)
                # reassemble the product from its 8 words
                prod = sum(
                    state.load(dest_ptr + GPR_SIZE_IN_BYTES * word)
                    << (GPR_SIZE_IN_BITS * word)
                    for word in range(8))
                self.assertEqual(hex(prod), hex(prod_value),
                                 f"failed: state={state}")
1114
def test_toom_2_mul_256x256_pre_ra_sim(self):
    """Run the 256x256 TOOM-2 simulation with the state from
    `get_pre_ra_state_factory`, for all four signedness combinations."""
    signedness = (False, True)
    combinations = [(lhs, rhs) for lhs in signedness for rhs in signedness]
    for lhs_signed, rhs_signed in combinations:
        self.tst_toom_2_mul_256x256_sim(
            lhs_signed=lhs_signed, rhs_signed=rhs_signed,
            get_state_factory=get_pre_ra_state_factory)
1121
def test_toom_2_mul_256x256_uu_post_ra_sim(self):
    """Simulate unsigned * unsigned using the state from
    `get_post_ra_state_factory`."""
    signedness = dict(lhs_signed=False, rhs_signed=False)
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=get_post_ra_state_factory, **signedness)
1126
def test_toom_2_mul_256x256_su_post_ra_sim(self):
    """Simulate signed * unsigned using the state from
    `get_post_ra_state_factory`."""
    signedness = dict(lhs_signed=True, rhs_signed=False)
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=get_post_ra_state_factory, **signedness)
1131
def test_toom_2_mul_256x256_us_post_ra_sim(self):
    """Simulate unsigned * signed using the state from
    `get_post_ra_state_factory`."""
    signedness = dict(lhs_signed=False, rhs_signed=True)
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=get_post_ra_state_factory, **signedness)
1136
def test_toom_2_mul_256x256_ss_post_ra_sim(self):
    """Simulate signed * signed using the state from
    `get_post_ra_state_factory`."""
    signedness = dict(lhs_signed=True, rhs_signed=True)
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=get_post_ra_state_factory, **signedness)
1141
1142 def test_toom_2_mul_256x256_asm(self):
1143 code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
1144 fn = code.fn
1145 assigned_registers = allocate_registers(fn)
1146 gen_asm_state = GenAsmState(assigned_registers)
1147 fn.gen_asm(gen_asm_state)
1148 self.assertEqual(gen_asm_state.output, [
1149 'or 42, 3, 3',
1150 'setvl 0, 0, 4, 0, 1, 1',
1151 'or 7, 42, 42',
1152 'setvl 0, 0, 4, 0, 1, 1',
1153 'sv.ld *3, 64(7)',
1154 'setvl 0, 0, 4, 0, 1, 1',
1155 'sv.or *8, *3, *3',
1156 'setvl 0, 0, 4, 0, 1, 1',
1157 'or 7, 42, 42',
1158 'setvl 0, 0, 4, 0, 1, 1',
1159 'sv.ld *3, 96(7)',
1160 'setvl 0, 0, 4, 0, 1, 1',
1161 'sv.or *14, *3, *3',
1162 'setvl 0, 0, 4, 0, 1, 1',
1163 'setvl 0, 0, 4, 0, 1, 1',
1164 'sv.or *3, *8, *8',
1165 'setvl 0, 0, 4, 0, 1, 1',
1166 'sv.or *9, *3, *3',
1167 'or 3, 9, 9',
1168 'or 5, 10, 10',
1169 'or 8, 11, 11',
1170 'or 7, 12, 12',
1171 'setvl 0, 0, 2, 0, 1, 1',
1172 'or 4, 5, 5',
1173 'setvl 0, 0, 2, 0, 1, 1',
1174 'setvl 0, 0, 2, 0, 1, 1',
1175 'sv.or *5, *3, *3',
1176 'setvl 0, 0, 2, 0, 1, 1',
1177 'or 3, 8, 8',
1178 'or 4, 7, 7',
1179 'setvl 0, 0, 2, 0, 1, 1',
1180 'setvl 0, 0, 2, 0, 1, 1',
1181 'sv.or *8, *3, *3',
1182 'setvl 0, 0, 2, 0, 1, 1',
1183 'setvl 0, 0, 2, 0, 1, 1',
1184 'sv.or *3, *5, *5',
1185 'setvl 0, 0, 2, 0, 1, 1',
1186 'sv.or *5, *3, *3',
1187 'or 4, 5, 5',
1188 'or 7, 6, 6',
1189 'addi 3, 0, 0',
1190 'or 6, 3, 3',
1191 'setvl 0, 0, 3, 0, 1, 1',
1192 'or 3, 4, 4',
1193 'or 4, 7, 7',
1194 'or 5, 6, 6',
1195 'setvl 0, 0, 3, 0, 1, 1',
1196 'setvl 0, 0, 3, 0, 1, 1',
1197 'sv.or *24, *3, *3',
1198 'setvl 0, 0, 2, 0, 1, 1',
1199 'setvl 0, 0, 2, 0, 1, 1',
1200 'sv.or *3, *8, *8',
1201 'setvl 0, 0, 2, 0, 1, 1',
1202 'sv.or *5, *3, *3',
1203 'or 4, 5, 5',
1204 'or 7, 6, 6',
1205 'addi 3, 0, 0',
1206 'or 6, 3, 3',
1207 'setvl 0, 0, 3, 0, 1, 1',
1208 'or 3, 4, 4',
1209 'or 4, 7, 7',
1210 'or 5, 6, 6',
1211 'setvl 0, 0, 3, 0, 1, 1',
1212 'setvl 0, 0, 3, 0, 1, 1',
1213 'sv.or *30, *3, *3',
1214 'setvl 0, 0, 3, 0, 1, 1',
1215 'addic 0, 0, 0',
1216 'setvl 0, 0, 3, 0, 1, 1',
1217 'sv.or *9, *24, *24',
1218 'setvl 0, 0, 3, 0, 1, 1',
1219 'sv.or *6, *30, *30',
1220 'setvl 0, 0, 3, 0, 1, 1',
1221 'sv.adde *3, *9, *6',
1222 'setvl 0, 0, 3, 0, 1, 1',
1223 'sv.or *39, *3, *3',
1224 'setvl 0, 0, 4, 0, 1, 1',
1225 'setvl 0, 0, 4, 0, 1, 1',
1226 'sv.or *3, *14, *14',
1227 'setvl 0, 0, 4, 0, 1, 1',
1228 'sv.or *9, *3, *3',
1229 'or 3, 9, 9',
1230 'or 5, 10, 10',
1231 'or 8, 11, 11',
1232 'or 7, 12, 12',
1233 'setvl 0, 0, 2, 0, 1, 1',
1234 'or 4, 5, 5',
1235 'setvl 0, 0, 2, 0, 1, 1',
1236 'setvl 0, 0, 2, 0, 1, 1',
1237 'sv.or *5, *3, *3',
1238 'setvl 0, 0, 2, 0, 1, 1',
1239 'or 3, 8, 8',
1240 'or 4, 7, 7',
1241 'setvl 0, 0, 2, 0, 1, 1',
1242 'setvl 0, 0, 2, 0, 1, 1',
1243 'sv.or *8, *3, *3',
1244 'setvl 0, 0, 2, 0, 1, 1',
1245 'setvl 0, 0, 2, 0, 1, 1',
1246 'sv.or *3, *5, *5',
1247 'setvl 0, 0, 2, 0, 1, 1',
1248 'sv.or *5, *3, *3',
1249 'or 4, 5, 5',
1250 'or 7, 6, 6',
1251 'addi 3, 0, 0',
1252 'or 6, 3, 3',
1253 'setvl 0, 0, 3, 0, 1, 1',
1254 'or 3, 4, 4',
1255 'or 4, 7, 7',
1256 'or 5, 6, 6',
1257 'setvl 0, 0, 3, 0, 1, 1',
1258 'setvl 0, 0, 3, 0, 1, 1',
1259 'sv.or *14, *3, *3',
1260 'setvl 0, 0, 2, 0, 1, 1',
1261 'setvl 0, 0, 2, 0, 1, 1',
1262 'sv.or *3, *8, *8',
1263 'setvl 0, 0, 2, 0, 1, 1',
1264 'sv.or *5, *3, *3',
1265 'or 4, 5, 5',
1266 'or 7, 6, 6',
1267 'addi 3, 0, 0',
1268 'or 6, 3, 3',
1269 'setvl 0, 0, 3, 0, 1, 1',
1270 'or 3, 4, 4',
1271 'or 4, 7, 7',
1272 'or 5, 6, 6',
1273 'setvl 0, 0, 3, 0, 1, 1',
1274 'setvl 0, 0, 3, 0, 1, 1',
1275 'sv.or *33, *3, *3',
1276 'setvl 0, 0, 3, 0, 1, 1',
1277 'addic 0, 0, 0',
1278 'setvl 0, 0, 3, 0, 1, 1',
1279 'sv.or *9, *14, *14',
1280 'setvl 0, 0, 3, 0, 1, 1',
1281 'sv.or *6, *33, *33',
1282 'setvl 0, 0, 3, 0, 1, 1',
1283 'sv.adde *3, *9, *6',
1284 'setvl 0, 0, 3, 0, 1, 1',
1285 'sv.or *36, *3, *3',
1286 'setvl 0, 0, 3, 0, 1, 1',
1287 'setvl 0, 0, 3, 0, 1, 1',
1288 'sv.or *3, *14, *14',
1289 'setvl 0, 0, 3, 0, 1, 1',
1290 'sv.or/mrr *5, *3, *3',
1291 'or 4, 5, 5',
1292 'or 14, 6, 6',
1293 'or 23, 7, 7',
1294 'addi 3, 0, 0',
1295 'or 22, 3, 3',
1296 'setvl 0, 0, 3, 0, 1, 1',
1297 'addi 3, 0, 0',
1298 'setvl 0, 0, 3, 0, 1, 1',
1299 'sv.or *8, *24, *24',
1300 'or 7, 4, 4',
1301 'or 6, 22, 22',
1302 'setvl 0, 0, 3, 0, 1, 1',
1303 'sv.maddedu *3, *8, 7, 6',
1304 'setvl 0, 0, 3, 0, 1, 1',
1305 'or 19, 6, 6',
1306 'setvl 0, 0, 3, 0, 1, 1',
1307 'setvl 0, 0, 3, 0, 1, 1',
1308 'or 21, 3, 3',
1309 'or 12, 4, 4',
1310 'or 11, 5, 5',
1311 'setvl 0, 0, 3, 0, 1, 1',
1312 'sv.or *8, *24, *24',
1313 'or 7, 14, 14',
1314 'or 6, 22, 22',
1315 'setvl 0, 0, 3, 0, 1, 1',
1316 'sv.maddedu *3, *8, 7, 6',
1317 'setvl 0, 0, 3, 0, 1, 1',
1318 'or 18, 6, 6',
1319 'setvl 0, 0, 3, 0, 1, 1',
1320 'setvl 0, 0, 3, 0, 1, 1',
1321 'or 17, 3, 3',
1322 'or 16, 4, 4',
1323 'or 15, 5, 5',
1324 'addi 3, 0, 0',
1325 'or 8, 3, 3',
1326 'addi 3, 0, 0',
1327 'or 14, 3, 3',
1328 'setvl 0, 0, 5, 0, 1, 1',
1329 'or 3, 12, 12',
1330 'or 4, 11, 11',
1331 'or 5, 19, 19',
1332 'or 6, 8, 8',
1333 'or 7, 8, 8',
1334 'setvl 0, 0, 5, 0, 1, 1',
1335 'setvl 0, 0, 5, 0, 1, 1',
1336 'sv.or *8, *3, *3',
1337 'or 3, 17, 17',
1338 'or 4, 16, 16',
1339 'or 5, 15, 15',
1340 'or 6, 18, 18',
1341 'or 7, 14, 14',
1342 'setvl 0, 0, 5, 0, 1, 1',
1343 'setvl 0, 0, 5, 0, 1, 1',
1344 'addic 0, 0, 0',
1345 'setvl 0, 0, 5, 0, 1, 1',
1346 'sv.or *14, *8, *8',
1347 'setvl 0, 0, 5, 0, 1, 1',
1348 'sv.or *8, *3, *3',
1349 'setvl 0, 0, 5, 0, 1, 1',
1350 'sv.adde *3, *14, *8',
1351 'setvl 0, 0, 5, 0, 1, 1',
1352 'setvl 0, 0, 5, 0, 1, 1',
1353 'setvl 0, 0, 5, 0, 1, 1',
1354 'or 20, 3, 3',
1355 'or 19, 4, 4',
1356 'or 18, 5, 5',
1357 'or 17, 6, 6',
1358 'or 16, 7, 7',
1359 'setvl 0, 0, 3, 0, 1, 1',
1360 'sv.or *8, *24, *24',
1361 'or 7, 23, 23',
1362 'or 6, 22, 22',
1363 'setvl 0, 0, 3, 0, 1, 1',
1364 'sv.maddedu *3, *8, 7, 6',
1365 'setvl 0, 0, 3, 0, 1, 1',
1366 'or 15, 6, 6',
1367 'setvl 0, 0, 3, 0, 1, 1',
1368 'setvl 0, 0, 3, 0, 1, 1',
1369 'or 14, 3, 3',
1370 'or 12, 4, 4',
1371 'or 11, 5, 5',
1372 'setvl 0, 0, 4, 0, 1, 1',
1373 'or 3, 19, 19',
1374 'or 4, 18, 18',
1375 'or 5, 17, 17',
1376 'or 6, 16, 16',
1377 'setvl 0, 0, 4, 0, 1, 1',
1378 'setvl 0, 0, 4, 0, 1, 1',
1379 'sv.or *7, *3, *3',
1380 'or 3, 14, 14',
1381 'or 4, 12, 12',
1382 'or 5, 11, 11',
1383 'or 6, 15, 15',
1384 'setvl 0, 0, 4, 0, 1, 1',
1385 'setvl 0, 0, 4, 0, 1, 1',
1386 'addic 0, 0, 0',
1387 'setvl 0, 0, 4, 0, 1, 1',
1388 'sv.or *14, *7, *7',
1389 'setvl 0, 0, 4, 0, 1, 1',
1390 'sv.or *7, *3, *3',
1391 'setvl 0, 0, 4, 0, 1, 1',
1392 'sv.adde *3, *14, *7',
1393 'setvl 0, 0, 4, 0, 1, 1',
1394 'setvl 0, 0, 4, 0, 1, 1',
1395 'setvl 0, 0, 4, 0, 1, 1',
1396 'or 12, 3, 3',
1397 'or 11, 4, 4',
1398 'or 10, 5, 5',
1399 'or 9, 6, 6',
1400 'setvl 0, 0, 6, 0, 1, 1',
1401 'or 3, 21, 21',
1402 'or 4, 20, 20',
1403 'or 5, 12, 12',
1404 'or 6, 11, 11',
1405 'or 7, 10, 10',
1406 'or 8, 9, 9',
1407 'setvl 0, 0, 6, 0, 1, 1',
1408 'setvl 0, 0, 6, 0, 1, 1',
1409 'sv.or *24, *3, *3',
1410 'setvl 0, 0, 3, 0, 1, 1',
1411 'setvl 0, 0, 3, 0, 1, 1',
1412 'sv.or *3, *36, *36',
1413 'setvl 0, 0, 3, 0, 1, 1',
1414 'sv.or/mrr *5, *3, *3',
1415 'or 4, 5, 5',
1416 'or 14, 6, 6',
1417 'or 23, 7, 7',
1418 'addi 3, 0, 0',
1419 'or 22, 3, 3',
1420 'setvl 0, 0, 3, 0, 1, 1',
1421 'addi 3, 0, 0',
1422 'setvl 0, 0, 3, 0, 1, 1',
1423 'sv.or *8, *39, *39',
1424 'or 7, 4, 4',
1425 'or 6, 22, 22',
1426 'setvl 0, 0, 3, 0, 1, 1',
1427 'sv.maddedu *3, *8, 7, 6',
1428 'setvl 0, 0, 3, 0, 1, 1',
1429 'or 19, 6, 6',
1430 'setvl 0, 0, 3, 0, 1, 1',
1431 'setvl 0, 0, 3, 0, 1, 1',
1432 'or 21, 3, 3',
1433 'or 12, 4, 4',
1434 'or 11, 5, 5',
1435 'setvl 0, 0, 3, 0, 1, 1',
1436 'sv.or *8, *39, *39',
1437 'or 7, 14, 14',
1438 'or 6, 22, 22',
1439 'setvl 0, 0, 3, 0, 1, 1',
1440 'sv.maddedu *3, *8, 7, 6',
1441 'setvl 0, 0, 3, 0, 1, 1',
1442 'or 18, 6, 6',
1443 'setvl 0, 0, 3, 0, 1, 1',
1444 'setvl 0, 0, 3, 0, 1, 1',
1445 'or 17, 3, 3',
1446 'or 16, 4, 4',
1447 'or 15, 5, 5',
1448 'addi 3, 0, 0',
1449 'or 8, 3, 3',
1450 'addi 3, 0, 0',
1451 'or 14, 3, 3',
1452 'setvl 0, 0, 5, 0, 1, 1',
1453 'or 3, 12, 12',
1454 'or 4, 11, 11',
1455 'or 5, 19, 19',
1456 'or 6, 8, 8',
1457 'or 7, 8, 8',
1458 'setvl 0, 0, 5, 0, 1, 1',
1459 'setvl 0, 0, 5, 0, 1, 1',
1460 'sv.or *8, *3, *3',
1461 'or 3, 17, 17',
1462 'or 4, 16, 16',
1463 'or 5, 15, 15',
1464 'or 6, 18, 18',
1465 'or 7, 14, 14',
1466 'setvl 0, 0, 5, 0, 1, 1',
1467 'setvl 0, 0, 5, 0, 1, 1',
1468 'addic 0, 0, 0',
1469 'setvl 0, 0, 5, 0, 1, 1',
1470 'sv.or *14, *8, *8',
1471 'setvl 0, 0, 5, 0, 1, 1',
1472 'sv.or *8, *3, *3',
1473 'setvl 0, 0, 5, 0, 1, 1',
1474 'sv.adde *3, *14, *8',
1475 'setvl 0, 0, 5, 0, 1, 1',
1476 'setvl 0, 0, 5, 0, 1, 1',
1477 'setvl 0, 0, 5, 0, 1, 1',
1478 'or 20, 3, 3',
1479 'or 19, 4, 4',
1480 'or 18, 5, 5',
1481 'or 17, 6, 6',
1482 'or 16, 7, 7',
1483 'setvl 0, 0, 3, 0, 1, 1',
1484 'sv.or *8, *39, *39',
1485 'or 7, 23, 23',
1486 'or 6, 22, 22',
1487 'setvl 0, 0, 3, 0, 1, 1',
1488 'sv.maddedu *3, *8, 7, 6',
1489 'setvl 0, 0, 3, 0, 1, 1',
1490 'or 15, 6, 6',
1491 'setvl 0, 0, 3, 0, 1, 1',
1492 'setvl 0, 0, 3, 0, 1, 1',
1493 'or 14, 3, 3',
1494 'or 12, 4, 4',
1495 'or 11, 5, 5',
1496 'setvl 0, 0, 4, 0, 1, 1',
1497 'or 3, 19, 19',
1498 'or 4, 18, 18',
1499 'or 5, 17, 17',
1500 'or 6, 16, 16',
1501 'setvl 0, 0, 4, 0, 1, 1',
1502 'setvl 0, 0, 4, 0, 1, 1',
1503 'sv.or *7, *3, *3',
1504 'or 3, 14, 14',
1505 'or 4, 12, 12',
1506 'or 5, 11, 11',
1507 'or 6, 15, 15',
1508 'setvl 0, 0, 4, 0, 1, 1',
1509 'setvl 0, 0, 4, 0, 1, 1',
1510 'addic 0, 0, 0',
1511 'setvl 0, 0, 4, 0, 1, 1',
1512 'sv.or *14, *7, *7',
1513 'setvl 0, 0, 4, 0, 1, 1',
1514 'sv.or *7, *3, *3',
1515 'setvl 0, 0, 4, 0, 1, 1',
1516 'sv.adde *3, *14, *7',
1517 'setvl 0, 0, 4, 0, 1, 1',
1518 'setvl 0, 0, 4, 0, 1, 1',
1519 'setvl 0, 0, 4, 0, 1, 1',
1520 'or 12, 3, 3',
1521 'or 11, 4, 4',
1522 'or 10, 5, 5',
1523 'or 9, 6, 6',
1524 'setvl 0, 0, 6, 0, 1, 1',
1525 'or 3, 21, 21',
1526 'or 4, 20, 20',
1527 'or 5, 12, 12',
1528 'or 6, 11, 11',
1529 'or 7, 10, 10',
1530 'or 8, 9, 9',
1531 'setvl 0, 0, 6, 0, 1, 1',
1532 'setvl 0, 0, 6, 0, 1, 1',
1533 'sv.or *36, *3, *3',
1534 'setvl 0, 0, 3, 0, 1, 1',
1535 'setvl 0, 0, 3, 0, 1, 1',
1536 'sv.or *3, *33, *33',
1537 'setvl 0, 0, 3, 0, 1, 1',
1538 'sv.or/mrr *5, *3, *3',
1539 'or 4, 5, 5',
1540 'or 14, 6, 6',
1541 'or 23, 7, 7',
1542 'addi 3, 0, 0',
1543 'or 22, 3, 3',
1544 'setvl 0, 0, 3, 0, 1, 1',
1545 'addi 3, 0, 0',
1546 'setvl 0, 0, 3, 0, 1, 1',
1547 'sv.or *8, *30, *30',
1548 'or 7, 4, 4',
1549 'or 6, 22, 22',
1550 'setvl 0, 0, 3, 0, 1, 1',
1551 'sv.maddedu *3, *8, 7, 6',
1552 'setvl 0, 0, 3, 0, 1, 1',
1553 'or 19, 6, 6',
1554 'setvl 0, 0, 3, 0, 1, 1',
1555 'setvl 0, 0, 3, 0, 1, 1',
1556 'or 21, 3, 3',
1557 'or 12, 4, 4',
1558 'or 11, 5, 5',
1559 'setvl 0, 0, 3, 0, 1, 1',
1560 'sv.or *8, *30, *30',
1561 'or 7, 14, 14',
1562 'or 6, 22, 22',
1563 'setvl 0, 0, 3, 0, 1, 1',
1564 'sv.maddedu *3, *8, 7, 6',
1565 'setvl 0, 0, 3, 0, 1, 1',
1566 'or 18, 6, 6',
1567 'setvl 0, 0, 3, 0, 1, 1',
1568 'setvl 0, 0, 3, 0, 1, 1',
1569 'or 17, 3, 3',
1570 'or 16, 4, 4',
1571 'or 15, 5, 5',
1572 'addi 3, 0, 0',
1573 'or 8, 3, 3',
1574 'addi 3, 0, 0',
1575 'or 14, 3, 3',
1576 'setvl 0, 0, 5, 0, 1, 1',
1577 'or 3, 12, 12',
1578 'or 4, 11, 11',
1579 'or 5, 19, 19',
1580 'or 6, 8, 8',
1581 'or 7, 8, 8',
1582 'setvl 0, 0, 5, 0, 1, 1',
1583 'setvl 0, 0, 5, 0, 1, 1',
1584 'sv.or *8, *3, *3',
1585 'or 3, 17, 17',
1586 'or 4, 16, 16',
1587 'or 5, 15, 15',
1588 'or 6, 18, 18',
1589 'or 7, 14, 14',
1590 'setvl 0, 0, 5, 0, 1, 1',
1591 'setvl 0, 0, 5, 0, 1, 1',
1592 'addic 0, 0, 0',
1593 'setvl 0, 0, 5, 0, 1, 1',
1594 'sv.or *14, *8, *8',
1595 'setvl 0, 0, 5, 0, 1, 1',
1596 'sv.or *8, *3, *3',
1597 'setvl 0, 0, 5, 0, 1, 1',
1598 'sv.adde *3, *14, *8',
1599 'setvl 0, 0, 5, 0, 1, 1',
1600 'setvl 0, 0, 5, 0, 1, 1',
1601 'setvl 0, 0, 5, 0, 1, 1',
1602 'or 20, 3, 3',
1603 'or 19, 4, 4',
1604 'or 18, 5, 5',
1605 'or 17, 6, 6',
1606 'or 16, 7, 7',
1607 'setvl 0, 0, 3, 0, 1, 1',
1608 'sv.or *8, *30, *30',
1609 'or 7, 23, 23',
1610 'or 6, 22, 22',
1611 'setvl 0, 0, 3, 0, 1, 1',
1612 'sv.maddedu *3, *8, 7, 6',
1613 'setvl 0, 0, 3, 0, 1, 1',
1614 'or 15, 6, 6',
1615 'setvl 0, 0, 3, 0, 1, 1',
1616 'setvl 0, 0, 3, 0, 1, 1',
1617 'or 14, 3, 3',
1618 'or 12, 4, 4',
1619 'or 11, 5, 5',
1620 'setvl 0, 0, 4, 0, 1, 1',
1621 'or 3, 19, 19',
1622 'or 4, 18, 18',
1623 'or 5, 17, 17',
1624 'or 6, 16, 16',
1625 'setvl 0, 0, 4, 0, 1, 1',
1626 'setvl 0, 0, 4, 0, 1, 1',
1627 'sv.or *7, *3, *3',
1628 'or 3, 14, 14',
1629 'or 4, 12, 12',
1630 'or 5, 11, 11',
1631 'or 6, 15, 15',
1632 'setvl 0, 0, 4, 0, 1, 1',
1633 'setvl 0, 0, 4, 0, 1, 1',
1634 'addic 0, 0, 0',
1635 'setvl 0, 0, 4, 0, 1, 1',
1636 'sv.or *14, *7, *7',
1637 'setvl 0, 0, 4, 0, 1, 1',
1638 'sv.or *7, *3, *3',
1639 'setvl 0, 0, 4, 0, 1, 1',
1640 'sv.adde *3, *14, *7',
1641 'setvl 0, 0, 4, 0, 1, 1',
1642 'setvl 0, 0, 4, 0, 1, 1',
1643 'setvl 0, 0, 4, 0, 1, 1',
1644 'or 12, 3, 3',
1645 'or 11, 4, 4',
1646 'or 10, 5, 5',
1647 'or 9, 6, 6',
1648 'setvl 0, 0, 6, 0, 1, 1',
1649 'or 3, 21, 21',
1650 'or 4, 20, 20',
1651 'or 5, 12, 12',
1652 'or 6, 11, 11',
1653 'or 7, 10, 10',
1654 'or 8, 9, 9',
1655 'setvl 0, 0, 6, 0, 1, 1',
1656 'setvl 0, 0, 6, 0, 1, 1',
1657 'sv.or *30, *3, *3',
1658 'setvl 0, 0, 6, 0, 1, 1',
1659 'setvl 0, 0, 6, 0, 1, 1',
1660 'sv.or *3, *24, *24',
1661 'setvl 0, 0, 6, 0, 1, 1',
1662 'sv.or *14, *3, *3',
1663 'or 4, 14, 14',
1664 'or 11, 15, 15',
1665 'or 10, 16, 16',
1666 'or 9, 17, 17',
1667 'or 8, 18, 18',
1668 'or 3, 19, 19',
1669 'setvl 0, 0, 5, 0, 1, 1',
1670 'or 3, 4, 4',
1671 'or 4, 11, 11',
1672 'or 5, 10, 10',
1673 'or 6, 9, 9',
1674 'or 7, 8, 8',
1675 'setvl 0, 0, 5, 0, 1, 1',
1676 'setvl 0, 0, 5, 0, 1, 1',
1677 'sv.or *25, *3, *3',
1678 'setvl 0, 0, 6, 0, 1, 1',
1679 'setvl 0, 0, 6, 0, 1, 1',
1680 'sv.or *3, *36, *36',
1681 'setvl 0, 0, 6, 0, 1, 1',
1682 'sv.or *14, *3, *3',
1683 'or 4, 14, 14',
1684 'or 11, 15, 15',
1685 'or 10, 16, 16',
1686 'or 9, 17, 17',
1687 'or 8, 18, 18',
1688 'or 3, 19, 19',
1689 'setvl 0, 0, 5, 0, 1, 1',
1690 'or 3, 4, 4',
1691 'or 4, 11, 11',
1692 'or 5, 10, 10',
1693 'or 6, 9, 9',
1694 'or 7, 8, 8',
1695 'setvl 0, 0, 5, 0, 1, 1',
1696 'setvl 0, 0, 5, 0, 1, 1',
1697 'setvl 0, 0, 5, 0, 1, 1',
1698 'subfc 0, 0, 0',
1699 'setvl 0, 0, 5, 0, 1, 1',
1700 'sv.or *14, *25, *25',
1701 'setvl 0, 0, 5, 0, 1, 1',
1702 'sv.or *8, *3, *3',
1703 'setvl 0, 0, 5, 0, 1, 1',
1704 'sv.subfe *3, *14, *8',
1705 'setvl 0, 0, 5, 0, 1, 1',
1706 'sv.or *20, *3, *3',
1707 'setvl 0, 0, 6, 0, 1, 1',
1708 'setvl 0, 0, 6, 0, 1, 1',
1709 'sv.or *3, *30, *30',
1710 'setvl 0, 0, 6, 0, 1, 1',
1711 'sv.or *14, *3, *3',
1712 'or 4, 14, 14',
1713 'or 11, 15, 15',
1714 'or 10, 16, 16',
1715 'or 9, 17, 17',
1716 'or 8, 18, 18',
1717 'or 3, 19, 19',
1718 'setvl 0, 0, 5, 0, 1, 1',
1719 'or 3, 4, 4',
1720 'or 4, 11, 11',
1721 'or 5, 10, 10',
1722 'or 6, 9, 9',
1723 'or 7, 8, 8',
1724 'setvl 0, 0, 5, 0, 1, 1',
1725 'setvl 0, 0, 5, 0, 1, 1',
1726 'sv.or *30, *3, *3',
1727 'setvl 0, 0, 5, 0, 1, 1',
1728 'subfc 0, 0, 0',
1729 'setvl 0, 0, 5, 0, 1, 1',
1730 'sv.or *14, *30, *30',
1731 'setvl 0, 0, 5, 0, 1, 1',
1732 'sv.or *8, *20, *20',
1733 'setvl 0, 0, 5, 0, 1, 1',
1734 'sv.subfe *3, *14, *8',
1735 'setvl 0, 0, 5, 0, 1, 1',
1736 'sv.or *16, *3, *3',
1737 'setvl 0, 0, 5, 0, 1, 1',
1738 'setvl 0, 0, 5, 0, 1, 1',
1739 'sv.or *3, *25, *25',
1740 'setvl 0, 0, 5, 0, 1, 1',
1741 'or 29, 3, 3',
1742 'or 28, 4, 4',
1743 'or 8, 5, 5',
1744 'or 15, 6, 6',
1745 'or 14, 7, 7',
1746 'setvl 0, 0, 5, 0, 1, 1',
1747 'setvl 0, 0, 5, 0, 1, 1',
1748 'sv.or *3, *16, *16',
1749 'setvl 0, 0, 5, 0, 1, 1',
1750 'or 24, 3, 3',
1751 'or 23, 4, 4',
1752 'or 22, 5, 5',
1753 'or 21, 6, 6',
1754 'or 20, 7, 7',
1755 'setvl 0, 0, 5, 0, 1, 1',
1756 'setvl 0, 0, 5, 0, 1, 1',
1757 'sv.or *3, *30, *30',
1758 'setvl 0, 0, 5, 0, 1, 1',
1759 'or 27, 3, 3',
1760 'or 26, 4, 4',
1761 'or 12, 5, 5',
1762 'or 11, 6, 6',
1763 'or 3, 7, 7',
1764 'addi 3, 0, 0',
1765 'addi 3, 0, 0',
1766 'or 10, 3, 3',
1767 'or 3, 20, 20',
1768 'sradi 3, 3, 63',
1769 'or 9, 3, 3',
1770 'setvl 0, 0, 6, 0, 1, 1',
1771 'or 3, 8, 8',
1772 'or 4, 15, 15',
1773 'or 5, 14, 14',
1774 'or 6, 10, 10',
1775 'or 7, 10, 10',
1776 'or 8, 10, 10',
1777 'setvl 0, 0, 6, 0, 1, 1',
1778 'setvl 0, 0, 6, 0, 1, 1',
1779 'sv.or *14, *3, *3',
1780 'or 3, 24, 24',
1781 'or 4, 23, 23',
1782 'or 5, 22, 22',
1783 'or 6, 21, 21',
1784 'or 7, 20, 20',
1785 'or 8, 9, 9',
1786 'setvl 0, 0, 6, 0, 1, 1',
1787 'setvl 0, 0, 6, 0, 1, 1',
1788 'addic 0, 0, 0',
1789 'setvl 0, 0, 6, 0, 1, 1',
1790 'sv.or *20, *14, *14',
1791 'setvl 0, 0, 6, 0, 1, 1',
1792 'sv.or *14, *3, *3',
1793 'setvl 0, 0, 6, 0, 1, 1',
1794 'sv.adde *3, *20, *14',
1795 'setvl 0, 0, 6, 0, 1, 1',
1796 'setvl 0, 0, 6, 0, 1, 1',
1797 'setvl 0, 0, 6, 0, 1, 1',
1798 'sv.or *20, *3, *3',
1799 'or 19, 20, 20',
1800 'or 18, 21, 21',
1801 'or 3, 22, 22',
1802 'or 9, 23, 23',
1803 'or 8, 24, 24',
1804 'or 7, 25, 25',
1805 'setvl 0, 0, 4, 0, 1, 1',
1806 'or 4, 9, 9',
1807 'or 5, 8, 8',
1808 'or 6, 7, 7',
1809 'setvl 0, 0, 4, 0, 1, 1',
1810 'setvl 0, 0, 4, 0, 1, 1',
1811 'sv.or *7, *3, *3',
1812 'or 3, 27, 27',
1813 'or 4, 26, 26',
1814 'or 5, 12, 12',
1815 'or 6, 11, 11',
1816 'setvl 0, 0, 4, 0, 1, 1',
1817 'setvl 0, 0, 4, 0, 1, 1',
1818 'addic 0, 0, 0',
1819 'setvl 0, 0, 4, 0, 1, 1',
1820 'sv.or *14, *7, *7',
1821 'setvl 0, 0, 4, 0, 1, 1',
1822 'sv.or *7, *3, *3',
1823 'setvl 0, 0, 4, 0, 1, 1',
1824 'sv.adde *3, *14, *7',
1825 'setvl 0, 0, 4, 0, 1, 1',
1826 'setvl 0, 0, 4, 0, 1, 1',
1827 'setvl 0, 0, 4, 0, 1, 1',
1828 'or 15, 3, 3',
1829 'or 14, 4, 4',
1830 'or 12, 5, 5',
1831 'or 11, 6, 6',
1832 'setvl 0, 0, 8, 0, 1, 1',
1833 'or 3, 29, 29',
1834 'or 4, 28, 28',
1835 'or 5, 19, 19',
1836 'or 6, 18, 18',
1837 'or 7, 15, 15',
1838 'or 8, 14, 14',
1839 'or 9, 12, 12',
1840 'or 10, 11, 11',
1841 'setvl 0, 0, 8, 0, 1, 1',
1842 'setvl 0, 0, 8, 0, 1, 1',
1843 'setvl 0, 0, 8, 0, 1, 1',
1844 'setvl 0, 0, 8, 0, 1, 1',
1845 'sv.or/mrr *4, *3, *3',
1846 'or 3, 42, 42',
1847 'setvl 0, 0, 8, 0, 1, 1',
1848 'sv.std *4, 0(3)'
1849 ])
1850
def tst_toom_mul_sim(
    self, code,  # type: Mul
    lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
    test_cases,  # type: Iterable[tuple[int, int]]
):
    """Simulate `code` for each test case and check the stored product.

    Every `(lhs, rhs)` pair in `test_cases` is wrapped to the operand
    widths of `code`, written to simulated memory one word at a time,
    `code.fn` is simulated, and the words read back from the destination
    buffer are compared against the product computed with Python ints.

    code: the multiplication under test, plus its buffer offsets/sizes.
    lhs_signed: treat the wrapped lhs operand as two's complement.
    rhs_signed: treat the wrapped rhs operand as two's complement.
    get_state_factory: called once with `code`; its result is called once
        per test case and must return a context manager yielding the
        simulator state to use.
    test_cases: operand pairs; each value is reduced modulo its operand
        width before use.
    """
    print(code.retval[1])
    print(code.fn.ops_to_str())
    state_factory = get_state_factory(code)

    # the dest/lhs/rhs buffers all sit at offsets from one base pointer
    ptr_in = 0x100
    dest_ptr = ptr_in + code.dest_offset
    lhs_ptr = ptr_in + code.lhs_offset
    rhs_ptr = ptr_in + code.rhs_offset

    lhs_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS
    rhs_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS
    lhs_modulus = 1 << lhs_bits
    rhs_modulus = 1 << rhs_bits
    prod_modulus = 1 << (lhs_bits + rhs_bits)

    def store_words(state, base_ptr, value, size_in_words):
        # write `value` as consecutive GPR-sized words, lowest word first
        for word in range(size_in_words):
            state.store(
                base_ptr + word * GPR_SIZE_IN_BYTES,
                (value >> (GPR_SIZE_IN_BITS * word)) & GPR_VALUE_MASK)

    for lhs_raw, rhs_raw in test_cases:
        # unsigned bit patterns that get written to memory
        lhs_value = lhs_raw % lhs_modulus
        rhs_value = rhs_raw % rhs_modulus

        # numeric values those bit patterns stand for
        lhs_factor = lhs_value
        if lhs_signed and lhs_value >> (lhs_bits - 1):
            lhs_factor = lhs_value - lhs_modulus
        rhs_factor = rhs_value
        if rhs_signed and rhs_value >> (rhs_bits - 1):
            rhs_factor = rhs_value - rhs_modulus

        # expected result, wrapped to the width of the destination
        prod_value = (lhs_factor * rhs_factor) % prod_modulus

        with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                          lhs_value=hex(lhs_value),
                          rhs_value=hex(rhs_value),
                          prod_value=hex(prod_value)):
            with state_factory() as state:
                # state values are tuples, hence the one-element tuple
                state[code.ptr_in] = (ptr_in,)
                store_words(state, lhs_ptr, lhs_value,
                            code.lhs_size_in_words)
                store_words(state, rhs_ptr, rhs_value,
                            code.rhs_size_in_words)
                code.fn.sim(state)
                prod = sum(
                    state.load(dest_ptr + GPR_SIZE_IN_BYTES * word)
                    << (GPR_SIZE_IN_BITS * word)
                    for word in range(code.dest_size_in_words))
                self.assertEqual(hex(prod), hex(prod_value),
                                 f"failed: state={state}")
1899
def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
    # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
    """Run the pre-register-allocation simulation for many operand sizes.

    Builds a `Mul` around `ToomCookMul` (configured with `instances` and
    the given signedness) for every lhs/rhs size combination drawn from
    1, 2, 3, 4, 6, 8, 12 and 16 words, and checks each one against a fixed
    set of operand patterns via `tst_toom_mul_sim`.
    """
    def toom_cook_mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        toom_cook = ToomCookMul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
            rhs_signed=rhs_signed, instances=instances)
        return toom_cook.retval, toom_cook

    # sizes of the form 2**n and 3 * 2**n, limited to 1..16 words
    sizes_in_words = OSet(sorted({
        size
        for shift in range(6)
        for size in (1 << shift, 3 << shift)
        if 1 <= size <= 16
    }))  # type: OSet[int]

    # repeating 0x80 / 0x40 byte patterns, 2048 bits wide; they get
    # truncated to the operand width by tst_toom_mul_sim
    all_0x80_bytes = (0x80 << 2048) // 0xFF
    all_0x40_bytes = (0x40 << 2048) // 0xFF

    for lhs_size_in_words in sizes_in_words:
        for rhs_size_in_words in sizes_in_words:
            lhs_top_bit = 1 << (GPR_SIZE_IN_BITS * lhs_size_in_words - 1)
            rhs_top_bit = 1 << (GPR_SIZE_IN_BITS * rhs_size_in_words - 1)
            with self.subTest(lhs_size_in_words=lhs_size_in_words,
                              rhs_size_in_words=rhs_size_in_words,
                              lhs_signed=lhs_signed,
                              rhs_signed=rhs_signed):
                test_cases = [
                    (-1, -1),
                    (all_0x80_bytes, all_0x80_bytes),
                    (all_0x40_bytes, all_0x80_bytes),
                    (all_0x80_bytes, all_0x40_bytes),
                    (all_0x40_bytes, all_0x40_bytes),
                    (lhs_top_bit, rhs_top_bit),
                    (1, rhs_top_bit),
                    (lhs_top_bit, 1),
                    (1, 1),
                ]  # type: list[tuple[int, int]]
                self.tst_toom_mul_sim(
                    code=Mul(mul=toom_cook_mul,
                             lhs_size_in_words=lhs_size_in_words,
                             rhs_size_in_words=rhs_size_in_words),
                    lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                    get_state_factory=get_pre_ra_state_factory,
                    test_cases=test_cases)
1944
def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
    """Unsigned * unsigned over all sizes with a single TOOM-2 instance."""
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances=(ToomCookInstance.make_toom_2(),),
        lhs_signed=False, rhs_signed=False)
1949
def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
    """Unsigned * signed over all sizes with a single TOOM-2 instance."""
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances=(ToomCookInstance.make_toom_2(),),
        lhs_signed=False, rhs_signed=True)
1954
def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
    """Signed * unsigned over all sizes with a single TOOM-2 instance."""
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances=(ToomCookInstance.make_toom_2(),),
        lhs_signed=True, rhs_signed=False)
1959
def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
    """Signed * signed over all sizes with a single TOOM-2 instance."""
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances=(ToomCookInstance.make_toom_2(),),
        lhs_signed=True, rhs_signed=True)
1964
def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
    """Unsigned * unsigned over all sizes with four TOOM-2 instances."""
    # the same instance object repeated four times
    four_toom_2 = (ToomCookInstance.make_toom_2(),) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_toom_2, lhs_signed=False, rhs_signed=False)
1970
def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
    """Unsigned * signed over all sizes with four TOOM-2 instances."""
    # the same instance object repeated four times
    four_toom_2 = (ToomCookInstance.make_toom_2(),) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_toom_2, lhs_signed=False, rhs_signed=True)
1976
def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
    """Signed * unsigned over all sizes with four TOOM-2 instances."""
    # the same instance object repeated four times
    four_toom_2 = (ToomCookInstance.make_toom_2(),) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_toom_2, lhs_signed=True, rhs_signed=False)
1982
def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
    """Signed * signed over all sizes with four TOOM-2 instances."""
    # the same instance object repeated four times
    four_toom_2 = (ToomCookInstance.make_toom_2(),) * 4
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        four_toom_2, lhs_signed=True, rhs_signed=True)
1988
1989
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()