register_allocator works!
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
1 from contextlib import contextmanager
2 import sys
3 import unittest
4 from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
5
6 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
7 GPR_SIZE_IN_BYTES,
8 GPR_VALUE_MASK, BaseSimState,
9 Fn, GenAsmState, OpKind,
10 PostRASimState,
11 PreRASimState, SSAVal)
12 from bigint_presentation_code.register_allocator import allocate_registers
13 from bigint_presentation_code.register_allocator_test_util import GraphDumper
14 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
15 simple_mul)
16 from bigint_presentation_code.util import OSet
17
# Zero-argument callable returning a context manager that yields the
# simulator state (a ``BaseSimState``) one sub-test runs against; produced by
# ``get_pre_ra_state_factory`` and ``TestToomCook.get_post_ra_state_factory``.
_StateFactory = Callable[[], ContextManager[BaseSimState]]
19
20
def simple_umul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> Tuple[SSAVal, None]
    """Emit an unsigned x unsigned ``simple_mul`` of ``lhs`` and ``rhs``.

    Adapter matching the ``mul`` callback shape ``Mul`` expects: returns
    ``(product, None)``, where the product ops are named with prefix "mul".
    """
    product = simple_mul(fn=fn, lhs=lhs, lhs_signed=False,
                         rhs=rhs, rhs_signed=False, name="mul")
    return product, None
25
26
def get_pre_ra_state_factory(code):
    # type: (Mul) -> _StateFactory
    """Return a factory of fresh pre-register-allocation simulator states.

    ``code`` is unused; it is accepted so this function can be passed as
    ``get_state_factory`` interchangeably with the post-RA variant.
    Each state starts with empty ``ssa_vals`` and ``memory`` and is made the
    current debugging state while the ``with`` block is active.
    """
    @contextmanager
    def make_state():
        sim_state = PreRASimState(ssa_vals={}, memory={})
        with sim_state.set_as_current_debugging_state():
            yield sim_state

    return make_state
35
36
class Mul:
    """Wrap a multiplication callback in a complete load/multiply/store ``Fn``.

    The function reads a pointer argument (``FuncArgR3``), loads ``lhs`` and
    ``rhs`` with ``SvLd``, calls ``mul`` on them, and stores element 0 of the
    callback's result with ``SvStd``.  Relative to that pointer the buffers
    are packed back to back: destination at ``dest_offset`` (0), then
    ``lhs``, then ``rhs``; sizes are in GPR-sized words.
    """

    # mul(fn, lhs, rhs) -> (product, extra); only ``product`` is stored.
    _MulFn = Callable[[Fn, SSAVal, SSAVal], Tuple[SSAVal, Any]]

    def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
        # type: (_MulFn, int, int) -> None
        super().__init__()

        # ---- memory layout: [dest][lhs][rhs] ----
        self.lhs_size_in_words = lhs_size_in_words
        self.rhs_size_in_words = rhs_size_in_words
        self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words
        self.lhs_size_in_bytes = lhs_size_in_words * GPR_SIZE_IN_BYTES
        self.rhs_size_in_bytes = rhs_size_in_words * GPR_SIZE_IN_BYTES
        self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES
        self.dest_offset = 0
        self.lhs_offset = self.dest_offset + self.dest_size_in_bytes
        self.rhs_offset = self.lhs_offset + self.lhs_size_in_bytes

        # ---- IR; the append order here is the op order in ``fn`` ----
        self.fn = Fn()
        self.ptr_in = self.fn.append_new_op(
            kind=OpKind.FuncArgR3, name="ptr_in").outputs[0]
        self.lhs_setvl, self.load_lhs = self._append_load(
            lhs_size_in_words, self.lhs_offset,
            setvl_name="lhs_setvl", load_name="load_lhs")
        self.rhs_setvl, self.load_rhs = self._append_load(
            rhs_size_in_words, self.rhs_offset,
            setvl_name="rhs_setvl", load_name="load_rhs")
        self.retval = mul(
            self.fn, self.load_lhs.outputs[0], self.load_rhs.outputs[0])
        self.dest_setvl = self.fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[self.dest_size_in_words],
            maxvl=self.dest_size_in_words, name="dest_setvl")
        self.store = self.fn.append_new_op(
            kind=OpKind.SvStd,
            input_vals=[self.retval[0], self.ptr_in,
                        self.dest_setvl.outputs[0]],
            immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
            name="store_dest")

    def _append_load(self, size_in_words, offset, setvl_name, load_name):
        # type: (int, int, str, str) -> Tuple[Any, Any]
        """Append a SetVLI + SvLd pair reading ``size_in_words`` words at
        ``ptr_in + offset``; return ``(setvl_op, load_op)``."""
        setvl = self.fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[size_in_words],
            maxvl=size_in_words, name=setvl_name)
        load = self.fn.append_new_op(
            kind=OpKind.SvLd, immediates=[offset],
            input_vals=[self.ptr_in, setvl.outputs[0]],
            name=load_name, maxvl=size_in_words)
        return setvl, load
80
81
82 class TestToomCook(unittest.TestCase):
83 maxDiff = None
84
def get_post_ra_state_factory(self, code):
    # type: (Mul) -> _StateFactory
    """Register-allocate ``code.fn`` once and return a state factory.

    Allocation runs immediately (debug output to stdout, ``GraphDumper(self)``
    passed as ``dump_graph``).  Every state the factory yields shares that
    SSA-value-to-location map but gets its own empty ``memory`` and
    ``loc_values``.
    """
    loc_map = allocate_registers(
        code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))

    @contextmanager
    def make_state():
        yield PostRASimState(memory={}, loc_values={},
                             ssa_val_to_loc_map=loc_map)

    return make_state
96
def test_toom_2_repr(self):
    """Pin the exact ``repr`` of ``ToomCookInstance.make_toom_2()``.

    Expected: 2x2 parts evaluated at 0, 1 and POINT_AT_INFINITY, with the
    listed lhs/rhs evaluation ops and product interpolation ops.
    """
    TOOM_2 = ToomCookInstance.make_toom_2()
    # To regenerate the expected literal below: print(repr(repr(TOOM_2)))
    self.assertEqual(
        repr(TOOM_2),
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
        "eval_points=(0, 1, POINT_AT_INFINITY), "
        "lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
    )
134
def test_toom_2_5_repr(self):
    """Pin the exact ``repr`` of ``ToomCookInstance.make_toom_2_5()``.

    Expected: 3 lhs parts x 2 rhs parts evaluated at 0, 1, -1 and
    POINT_AT_INFINITY; the product ops include exact divisions by 2.
    """
    TOOM_2_5 = ToomCookInstance.make_toom_2_5()
    # To regenerate the expected literal below: print(repr(repr(TOOM_2_5)))
    self.assertEqual(
        repr(TOOM_2_5),
        "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
        "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
        "prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
199
def test_reversed_toom_2_5_repr(self):
    """Pin the exact ``repr`` of ``make_toom_2_5().reversed()``.

    Compared with the non-reversed expectation, the lhs/rhs part counts and
    the lhs/rhs evaluation ops are swapped, while eval_points and
    prod_eval_ops are unchanged.
    """
    TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
    # To regenerate the expected literal below: print(repr(repr(TOOM_2_5)))
    self.assertEqual(
        repr(TOOM_2_5),
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
261
def test_simple_mul_192x192_pre_ra_sim(self):
    """Simulate the 192x192 multiply before register allocation, once per
    (lhs_signed, rhs_signed) combination."""
    signedness_cases = [(False, False), (False, True),
                        (True, False), (True, True)]
    for lhs_is_signed, rhs_is_signed in signedness_cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_is_signed, rhs_signed=rhs_is_signed,
            get_state_factory=get_pre_ra_state_factory)
268
def test_simple_mul_192x192_post_ra_sim(self):
    """Simulate the 192x192 multiply after register allocation, once per
    (lhs_signed, rhs_signed) combination."""
    flags = (False, True)
    cases = [(a, b) for a in flags for b in flags]
    for lhs_is_signed, rhs_is_signed in cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_is_signed, rhs_signed=rhs_is_signed,
            get_state_factory=self.get_post_ra_state_factory)
275
def tst_simple_mul_192x192_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    """Simulate a 3-word x 3-word ``simple_mul`` and check the product.

    Builds a ``Mul`` around ``simple_mul`` with the requested signedness,
    turns it into a state factory via ``get_state_factory``, then for every
    operand-sign combination the signedness allows: stores both operands to
    simulated memory, runs ``code.fn.sim`` and compares the words read back
    from the destination with the expected two's-complement product.
    """
    # test multiplying:
    #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    # ==
    # int("0x000074736574206e_6f69746163696c70"
    #     "_69746c756d207469_622d3438333e2d32"
    #     "_3931783239312079_7261727469627261", base=0)
    # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
    #                   'little')
    lhs_value = 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    rhs_value = 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    prod_value = int.from_bytes(
        b"arbitrary 192x192->384-bit multiplication test", 'little')
    self.assertEqual(lhs_value * rhs_value, prod_value)
    code = Mul(
        mul=lambda fn, lhs, rhs: (simple_mul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed,
            rhs=rhs, rhs_signed=rhs_signed, name="mul"), None),
        lhs_size_in_words=3, rhs_size_in_words=3)
    # Derive word counts and bit widths from ``code`` and GPR_SIZE_IN_BITS
    # instead of hard-coding 3/6 words and 192/384 bits, so the memory
    # accesses and the two's-complement wrap-around always match the
    # ``Mul`` that was actually built.
    lhs_bits = GPR_SIZE_IN_BITS * code.lhs_size_in_words
    rhs_bits = GPR_SIZE_IN_BITS * code.rhs_size_in_words
    dest_bits = GPR_SIZE_IN_BITS * code.dest_size_in_words
    state_factory = get_state_factory(code)
    ptr_in = 0x100
    dest_ptr = ptr_in + code.dest_offset
    lhs_ptr = ptr_in + code.lhs_offset
    rhs_ptr = ptr_in + code.rhs_offset
    for lhs_neg in False, True:
        for rhs_neg in False, True:
            # negative operands are only exercised for signed inputs
            if lhs_neg and not lhs_signed:
                continue
            if rhs_neg and not rhs_signed:
                continue
            with self.subTest(lhs_signed=lhs_signed,
                              rhs_signed=rhs_signed,
                              lhs_neg=lhs_neg, rhs_neg=rhs_neg):
                with state_factory() as state:
                    # assigns a 1-tuple (note the trailing comma);
                    # presumably one entry per register of the SSA value
                    state[code.ptr_in] = ptr_in,
                    lhs = lhs_value
                    if lhs_neg:
                        # two's-complement negation within lhs_bits
                        lhs = 2 ** lhs_bits - lhs
                    rhs = rhs_value
                    if rhs_neg:
                        rhs = 2 ** rhs_bits - rhs
                    for i in range(code.lhs_size_in_words):
                        v = (lhs >> GPR_SIZE_IN_BITS * i) & GPR_VALUE_MASK
                        state.store(lhs_ptr + i * GPR_SIZE_IN_BYTES, v)
                    for i in range(code.rhs_size_in_words):
                        v = (rhs >> GPR_SIZE_IN_BITS * i) & GPR_VALUE_MASK
                        state.store(rhs_ptr + i * GPR_SIZE_IN_BYTES, v)
                    code.fn.sim(state)
                    expected = prod_value
                    if lhs_neg != rhs_neg:
                        # product is negative: wrap into dest_bits
                        expected = 2 ** dest_bits - expected
                    prod = 0
                    for i in range(code.dest_size_in_words):
                        v = state.load(dest_ptr + GPR_SIZE_IN_BYTES * i)
                        prod += v << (GPR_SIZE_IN_BITS * i)
                    self.assertEqual(hex(prod), hex(expected))
337
def test_simple_mul_192x192_ops(self):
    """Pin the exact IR text ``Mul`` + ``simple_umul`` emit for 3x3 words.

    The SvLd immediates 0x30 and 0x48 are ``code.lhs_offset`` and
    ``code.rhs_offset``; the final SvStd writes the 6-word product at
    ``code.dest_offset`` (0x0).

    NOTE(review): every continuation line of the expected text below starts
    with a single space; confirm that matches ``Fn.ops_to_str()``'s actual
    indentation.
    """
    code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    fn = code.fn
    self.assertEqual(
        fn.ops_to_str(),
        "ptr_in:\n"
        " (<...outputs[0]: <I64>>) <= FuncArgR3\n"
        "lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_lhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <lhs_setvl.outputs[0]: <VL_MAXVL>>, 0x30)\n"
        "rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_rhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <rhs_setvl.outputs[0]: <VL_MAXVL>>, 0x48)\n"
        "mul_rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_rhs_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <load_rhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_zero2:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_0_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[0]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_0_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_0_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[1]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_1_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_cast_retval_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_cast_pp_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x5)\n"
        "mul_1_retval_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_0_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_0_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_pp_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_1_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_1_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_pp_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_1_add:\n"
        " (<...outputs[0]: <I64*5>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_1_retval_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_pp_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>,\n"
        " <...outputs[4]: <I64>>) <= Spread(\n"
        " <mul_1_add.outputs[0]: <I64*5>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[2]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_2_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x4)\n"
        "mul_2_retval_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_1_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[4]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_pp_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_2_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_2_mul.outputs[1]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_2_add:\n"
        " (<...outputs[0]: <I64*4>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_2_retval_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_pp_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>) <= Spread(\n"
        " <mul_2_add.outputs[0]: <I64*4>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "mul_concat:\n"
        " (<...outputs[0]: <I64*6>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "dest_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "store_dest:\n"
        " SvStd(<mul_concat.outputs[0]: <I64*6>>,\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <dest_setvl.outputs[0]: <VL_MAXVL>>, 0x0)"
    )
486
487 def test_simple_mul_192x192_reg_alloc(self):
488 code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
489 fn = code.fn
490 assigned_registers = allocate_registers(
491 fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
492 print(repr(assigned_registers))
493 self.assertEqual(
494 repr(assigned_registers), "{"
495 "<mul_1_cast_pp_zero.outputs[0]: <I64>>: "
496 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
497 "<mul_1_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
498 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
499 "<mul_1_pp_concat.out0.copy.outputs[0]: <I64*5>>: "
500 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
501 "<mul_1_cast_pp_zero.out0.copy.outputs[0]: <I64>>: "
502 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
503 "<mul_1_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
504 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
505 "<mul_1_pp_concat.inp0.copy.outputs[0]: <I64>>: "
506 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
507 "<mul_1_pp_concat.inp1.copy.outputs[0]: <I64>>: "
508 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
509 "<mul_1_pp_concat.inp2.copy.outputs[0]: <I64>>: "
510 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
511 "<mul_1_pp_concat.inp3.copy.outputs[0]: <I64>>: "
512 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
513 "<mul_1_pp_concat.inp4.copy.outputs[0]: <I64>>: "
514 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
515 "<mul_1_add.inp1.copy.outputs[0]: <I64*5>>: "
516 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
517 "<mul_1_mul.outputs[0]: <I64*3>>: "
518 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
519 "<mul_1_mul.out0.copy.outputs[0]: <I64*3>>: "
520 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
521 "<mul_1_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
522 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
523 "<mul_1_mul_rt_spread.outputs[0]: <I64>>: "
524 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
525 "<mul_1_mul_rt_spread.outputs[1]: <I64>>: "
526 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
527 "<mul_1_mul_rt_spread.outputs[2]: <I64>>: "
528 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
529 "<mul_1_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
530 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
531 "<mul_1_pp_concat.outputs[0]: <I64*5>>: "
532 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
533 "<mul_2_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
534 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
535 "<mul_2_pp_concat.outputs[0]: <I64*4>>: "
536 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
537 "<mul_2_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
538 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
539 "<mul_2_pp_concat.out0.copy.outputs[0]: <I64*4>>: "
540 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
541 "<mul_2_add.inp1.copy.outputs[0]: <I64*4>>: "
542 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
543 "<mul_2_mul.outputs[0]: <I64*3>>: "
544 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
545 "<mul_2_mul.out0.copy.outputs[0]: <I64*3>>: "
546 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
547 "<mul_2_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
548 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
549 "<mul_2_mul_rt_spread.outputs[0]: <I64>>: "
550 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
551 "<mul_2_mul_rt_spread.outputs[1]: <I64>>: "
552 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
553 "<mul_2_mul_rt_spread.outputs[2]: <I64>>: "
554 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
555 "<mul_2_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
556 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
557 "<mul_2_pp_concat.inp0.copy.outputs[0]: <I64>>: "
558 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
559 "<mul_2_pp_concat.inp1.copy.outputs[0]: <I64>>: "
560 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
561 "<mul_2_pp_concat.inp2.copy.outputs[0]: <I64>>: "
562 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
563 "<mul_2_pp_concat.inp3.copy.outputs[0]: <I64>>: "
564 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
565 "<mul_zero.outputs[0]: <I64>>: "
566 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
567 "<mul_0_mul.inp2.copy.outputs[0]: <I64>>: "
568 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
569 "<mul_0_mul.outputs[1]: <I64>>: "
570 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
571 "<mul_1_mul.inp2.copy.outputs[0]: <I64>>: "
572 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
573 "<mul_1_mul.outputs[1]: <I64>>: "
574 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
575 "<mul_1_mul.out1.copy.outputs[0]: <I64>>: "
576 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
577 "<mul_2_mul.inp2.copy.outputs[0]: <I64>>: "
578 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
579 "<mul_2_mul.outputs[1]: <I64>>: "
580 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
581 "<mul_2_mul.out1.copy.outputs[0]: <I64>>: "
582 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
583 "<mul_0_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
584 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
585 "<mul_0_mul_rt_spread.outputs[0]: <I64>>: "
586 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
587 "<mul_0_mul_rt_spread.outputs[1]: <I64>>: "
588 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
589 "<mul_0_mul_rt_spread.outputs[2]: <I64>>: "
590 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
591 "<mul_0_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
592 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
593 "<mul_0_mul.outputs[0]: <I64*3>>: "
594 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
595 "<mul_0_mul.out0.copy.outputs[0]: <I64*3>>: "
596 "Loc(kind=LocKind.GPR, start=14, reg_len=3), "
597 "<mul_1_retval_concat.outputs[0]: <I64*5>>: "
598 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
599 "<mul_1_add.inp0.copy.outputs[0]: <I64*5>>: "
600 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
601 "<mul_0_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
602 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
603 "<mul_0_mul.out1.copy.outputs[0]: <I64>>: "
604 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
605 "<mul_1_retval_concat.out0.copy.outputs[0]: <I64*5>>: "
606 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
607 "<mul_1_cast_retval_zero.out0.copy.outputs[0]: <I64>>: "
608 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
609 "<mul_1_retval_concat.inp0.copy.outputs[0]: <I64>>: "
610 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
611 "<mul_1_retval_concat.inp1.copy.outputs[0]: <I64>>: "
612 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
613 "<mul_1_retval_concat.inp2.copy.outputs[0]: <I64>>: "
614 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
615 "<mul_1_retval_concat.inp3.copy.outputs[0]: <I64>>: "
616 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
617 "<mul_1_retval_concat.inp4.copy.outputs[0]: <I64>>: "
618 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
619 "<mul_2_sum_spread.out3.copy.outputs[0]: <I64>>: "
620 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
621 "<mul_1_sum_spread.out0.copy.outputs[0]: <I64>>: "
622 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
623 "<mul_0_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
624 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
625 "<mul_concat.out0.copy.outputs[0]: <I64*6>>: "
626 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
627 "<mul_2_add.inp0.copy.outputs[0]: <I64*4>>: "
628 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
629 "<mul_1_add.outputs[0]: <I64*5>>: "
630 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
631 "<mul_1_add.out0.copy.outputs[0]: <I64*5>>: "
632 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
633 "<mul_2_sum_spread.out0.copy.outputs[0]: <I64>>: "
634 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
635 "<mul_concat.inp0.copy.outputs[0]: <I64>>: "
636 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
637 "<mul_concat.inp1.copy.outputs[0]: <I64>>: "
638 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
639 "<mul_concat.inp2.copy.outputs[0]: <I64>>: "
640 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
641 "<mul_concat.inp3.copy.outputs[0]: <I64>>: "
642 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
643 "<mul_concat.inp4.copy.outputs[0]: <I64>>: "
644 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
645 "<mul_concat.inp5.copy.outputs[0]: <I64>>: "
646 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
647 "<mul_1_sum_spread.inp0.copy.outputs[0]: <I64*5>>: "
648 "Loc(kind=LocKind.GPR, start=15, reg_len=5), "
649 "<mul_1_sum_spread.outputs[0]: <I64>>: "
650 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
651 "<mul_1_sum_spread.outputs[1]: <I64>>: "
652 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
653 "<mul_1_sum_spread.outputs[2]: <I64>>: "
654 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
655 "<mul_1_sum_spread.outputs[3]: <I64>>: "
656 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
657 "<mul_1_sum_spread.outputs[4]: <I64>>: "
658 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
659 "<mul_1_sum_spread.out2.copy.outputs[0]: <I64>>: "
660 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
661 "<mul_2_retval_concat.outputs[0]: <I64*4>>: "
662 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
663 "<mul_1_sum_spread.out3.copy.outputs[0]: <I64>>: "
664 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
665 "<mul_2_retval_concat.out0.copy.outputs[0]: <I64*4>>: "
666 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
667 "<mul_1_sum_spread.out4.copy.outputs[0]: <I64>>: "
668 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
669 "<mul_1_sum_spread.out1.copy.outputs[0]: <I64>>: "
670 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
671 "<mul_2_retval_concat.inp0.copy.outputs[0]: <I64>>: "
672 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
673 "<mul_2_retval_concat.inp1.copy.outputs[0]: <I64>>: "
674 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
675 "<mul_2_retval_concat.inp2.copy.outputs[0]: <I64>>: "
676 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
677 "<mul_2_retval_concat.inp3.copy.outputs[0]: <I64>>: "
678 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
679 "<mul_2_add.outputs[0]: <I64*4>>: "
680 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
681 "<mul_2_add.out0.copy.outputs[0]: <I64*4>>: "
682 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
683 "<mul_2_sum_spread.inp0.copy.outputs[0]: <I64*4>>: "
684 "Loc(kind=LocKind.GPR, start=16, reg_len=4), "
685 "<mul_2_sum_spread.outputs[0]: <I64>>: "
686 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
687 "<mul_2_sum_spread.outputs[1]: <I64>>: "
688 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
689 "<mul_2_sum_spread.outputs[2]: <I64>>: "
690 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
691 "<mul_2_sum_spread.outputs[3]: <I64>>: "
692 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
693 "<mul_2_sum_spread.out1.copy.outputs[0]: <I64>>: "
694 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
695 "<mul_concat.outputs[0]: <I64*6>>: "
696 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
697 "<mul_2_sum_spread.out2.copy.outputs[0]: <I64>>: "
698 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
699 "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
700 "Loc(kind=LocKind.GPR, start=14, reg_len=6), "
701 "<mul_1_cast_retval_zero.outputs[0]: <I64>>: "
702 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
703 "<mul_zero.out0.copy.outputs[0]: <I64>>: "
704 "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
705 "<ptr_in.out0.copy.outputs[0]: <I64>>: "
706 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
707 "<store_dest.inp1.copy.outputs[0]: <I64>>: "
708 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
709 "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
710 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
711 "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
712 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
713 "<ptr_in.outputs[0]: <I64>>: "
714 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
715 "<mul_2_mul.inp1.copy.outputs[0]: <I64>>: "
716 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
717 "<load_rhs.outputs[0]: <I64*3>>: "
718 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
719 "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
720 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
721 "<mul_rhs_spread.out2.copy.outputs[0]: <I64>>: "
722 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
723 "<mul_rhs_spread.out1.copy.outputs[0]: <I64>>: "
724 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
725 "<mul_1_mul.inp1.copy.outputs[0]: <I64>>: "
726 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
727 "<mul_rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
728 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
729 "<mul_rhs_spread.outputs[0]: <I64>>: "
730 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
731 "<mul_rhs_spread.outputs[1]: <I64>>: "
732 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
733 "<mul_rhs_spread.outputs[2]: <I64>>: "
734 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
735 "<mul_rhs_spread.out0.copy.outputs[0]: <I64>>: "
736 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
737 "<mul_0_mul.inp1.copy.outputs[0]: <I64>>: "
738 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
739 "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
740 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
741 "<load_lhs.outputs[0]: <I64*3>>: "
742 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
743 "<mul_0_mul.inp0.copy.outputs[0]: <I64*3>>: "
744 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
745 "<mul_1_mul.inp0.copy.outputs[0]: <I64*3>>: "
746 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
747 "<mul_2_mul.inp0.copy.outputs[0]: <I64*3>>: "
748 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
749 "<mul_zero2.outputs[0]: <I64>>: "
750 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
751 "<mul_zero2.out0.copy.outputs[0]: <I64>>: "
752 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
753 "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
754 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
755 "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
756 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
757 "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
758 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
759 "<mul_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
760 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
761 "<mul_concat.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
762 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
763 "<mul_setvl.outputs[0]: <VL_MAXVL>>: "
764 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
765 "<mul_2_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
766 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
767 "<mul_2_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
768 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
769 "<mul_2_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
770 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
771 "<mul_2_clear_ca.outputs[0]: <CA>>: "
772 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
773 "<mul_2_add.outputs[1]: <CA>>: "
774 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
775 "<mul_2_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
776 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
777 "<mul_2_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
778 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
779 "<mul_2_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
780 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
781 "<mul_2_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
782 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
783 "<mul_2_pp_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
784 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
785 "<mul_2_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
786 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
787 "<mul_2_retval_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
788 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
789 "<mul_2_setvl.outputs[0]: <VL_MAXVL>>: "
790 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
791 "<mul_2_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
792 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
793 "<mul_2_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
794 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
795 "<mul_2_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
796 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
797 "<mul_2_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
798 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
799 "<mul_2_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
800 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
801 "<mul_1_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
802 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
803 "<mul_1_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
804 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
805 "<mul_1_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
806 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
807 "<mul_1_clear_ca.outputs[0]: <CA>>: "
808 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
809 "<mul_1_add.outputs[1]: <CA>>: "
810 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
811 "<mul_1_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
812 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
813 "<mul_1_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
814 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
815 "<mul_1_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
816 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
817 "<mul_1_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
818 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
819 "<mul_1_pp_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
820 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
821 "<mul_1_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
822 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
823 "<mul_1_retval_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
824 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
825 "<mul_1_setvl.outputs[0]: <VL_MAXVL>>: "
826 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
827 "<mul_1_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
828 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
829 "<mul_1_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
830 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
831 "<mul_1_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
832 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
833 "<mul_1_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
834 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
835 "<mul_1_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
836 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
837 "<mul_0_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
838 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
839 "<mul_0_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
840 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
841 "<mul_0_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
842 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
843 "<mul_0_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
844 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
845 "<mul_0_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
846 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
847 "<mul_lhs_setvl.outputs[0]: <VL_MAXVL>>: "
848 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
849 "<mul_rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
850 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
851 "<mul_rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
852 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
853 "<mul_rhs_setvl.outputs[0]: <VL_MAXVL>>: "
854 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
855 "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
856 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
857 "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
858 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
859 "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
860 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
861 "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
862 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
863 "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
864 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
865 "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
866 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
867 "}")
868
def test_simple_mul_192x192_asm(self):
    """Check the exact assembly emitted for a 3-word by 3-word ``simple_umul``.

    Builds the multiply, runs the register allocator (debug output to
    stdout, graph dumps through ``GraphDumper``), generates assembly and
    compares the emitted lines against the expected listing.
    """
    mul_code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    registers = allocate_registers(
        mul_code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
    asm_state = GenAsmState(registers)
    mul_code.fn.gen_asm(asm_state)
    print(asm_state.output)
    expected_asm = [
        'or 10, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *24, 48(10)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *20, 72(10)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 6, 0, 0',
        'or 9, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 9, 9',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *14, *24, 20, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 17, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 9, 9',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *24, 21, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 8, 0, 0',
        'or 18, 8, 8',
        'addi 7, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 19, 18, 18',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *15, *15, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 9, 9',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *24, 22, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *16, *16, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.std *14, 0(10)'
    ]
    self.assertEqual(asm_state.output, expected_asm)
954
def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Mul
    """Build a 4-word by 4-word ``Mul`` whose multiply is a ``ToomCookMul``.

    The ``ToomCookMul`` is configured with the one-element instance tuple
    ``(ToomCookInstance.make_toom_2(),)`` and the requested signedness for
    each operand. The multiply callback returns ``(retval, ToomCookMul)``.
    """
    toom_instances = (ToomCookInstance.make_toom_2(),)

    def mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        toom_mul = ToomCookMul(fn=fn, lhs=lhs, lhs_signed=lhs_signed,
                               rhs=rhs, rhs_signed=rhs_signed,
                               instances=toom_instances)
        return toom_mul.retval, toom_mul

    return Mul(mul=mul, lhs_size_in_words=4, rhs_size_in_words=4)
966
def make_256x256_mul_test_cases(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Iterator[tuple[int, int, int]]
    """Yield ``(lhs, rhs, expected_product)`` triples for 256x256-bit tests.

    Operand candidates are:
    - ``1 << n`` for ``n`` in ``0, 16, ..., 240`` and ``15, 31, ..., 255``;
    - ``0``;
    - one large fixed operand per side. The product of the two fixed operands
      encodes an ASCII message; this is checked with ``self.assertEqual``
      when iteration starts.

    When a side is signed, the negations of its candidates are added too.

    Candidates are ordered by their unsigned 256-bit bit pattern, except that
    the two large fixed operands (and their negations) sort last. ``lhs``
    varies in the outer loop and ``rhs`` in the inner loop.

    Every yielded value is an unsigned bit pattern:
    - ``lhs`` and ``rhs`` are reduced modulo ``2**256``;
    - ``expected_product`` is the product of the operands, each read as two's
      complement on a signed side, reduced modulo ``2**512``.
    """
    # test multiplying `+-1 << n` and:
    # 0xc162321a5eaad80b_4b86bb0efdfb93c0_a789ff04cc11b157_eaa08e29fb197621
    # *
    # 0x3138710167583371_998af336a8fac64d_e6da3737090787fe_85ba09ea701f4af2
    # ==
    # int("0x"
    #     "252e6e6f69746163_696c7069746c754d_"
    #     "2061627573746172_614b202d20322d4d_"
    #     "4f4f5420676e6973_75206c756d20746e_"
    #     "6967696220746962_2d36353278363532", base=0)
    # == int.from_bytes(b'256x256-bit bigint mul using TOOM-2 '
    #                   b'- Karatsuba Multiplication.%', 'little')
    lhs_value_in = (0xc162321a5eaad80b_4b86bb0efdfb93c0 << 128) \
        | 0xa789ff04cc11b157_eaa08e29fb197621
    rhs_value_in = (0x3138710167583371_998af336a8fac64d << 128) \
        | 0xe6da3737090787fe_85ba09ea701f4af2
    prod_value_in = int.from_bytes(
        b'256x256-bit bigint mul using TOOM-2 '
        b'- Karatsuba Multiplication.%', 'little')
    self.assertEqual(lhs_value_in * rhs_value_in, prod_value_in)
    shifts = [*range(0, 256, 16), *range(15, 256, 16)]
    lhs_values = [1 << i for i in shifts] + [0, lhs_value_in]
    rhs_values = [1 << i for i in shifts] + [0, rhs_value_in]
    if lhs_signed:
        lhs_values.extend([-i for i in lhs_values])
    if rhs_signed:
        rhs_values.extend([-i for i in rhs_values])

    def key(v):
        # type: (int) -> tuple[bool, int]
        # The big fixed operands sort last; everything else by bit pattern.
        return abs(v) in (lhs_value_in, rhs_value_in), v % (1 << 256)

    lhs_values.sort(key=key)
    rhs_values.sort(key=key)

    modulus_256 = 1 << 256

    def to_operand(value, signed):
        # type: (int, bool) -> int
        # Truncate to 256 bits; when signed, treat bit 255 as the sign bit.
        value %= modulus_256
        if signed and value >> 255 != 0:
            value -= modulus_256
        return value

    # Normalize each operand once, outside the nested loop, so no loop
    # variable is mutated while it is still being iterated over.
    lhs_operands = [to_operand(v, lhs_signed) for v in lhs_values]
    rhs_operands = [to_operand(v, rhs_signed) for v in rhs_values]
    for lhs_operand in lhs_operands:
        for rhs_operand in rhs_operands:
            yield (lhs_operand % modulus_256,
                   rhs_operand % modulus_256,
                   (lhs_operand * rhs_operand) % (1 << 512))
1016
def tst_toom_2_mul_256x256_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    """Simulate the 256x256-bit Toom-2 multiply on every generated case.

    The multiply comes from ``self.toom_2_mul_256x256`` and the operand pairs
    from ``self.make_256x256_mul_test_cases``. Simulation and checking are
    delegated to ``self.tst_toom_mul_sim``, so the number of words stored
    and read back is taken from the ``Mul`` instance instead of being
    hard-coded here.
    """
    code = self.toom_2_mul_256x256(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    cases = self.make_256x256_mul_test_cases(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    # tst_toom_mul_sim recomputes the expected product from the operands
    # (truncate, then two's complement on signed sides), so only the
    # operand pair of each generated triple is passed on.
    self.tst_toom_mul_sim(
        code=code, lhs_signed=lhs_signed, rhs_signed=rhs_signed,
        get_state_factory=get_state_factory,
        test_cases=((lhs, rhs) for lhs, rhs, _ in cases))
1055
def test_toom_2_mul_256x256_pre_ra_sim(self):
    """Pre-RA simulation of the 256x256 Toom-2 multiply for every operand
    signedness combination, in the order uu, us, su, ss."""
    signedness_combos = (
        (False, False),
        (False, True),
        (True, False),
        (True, True),
    )
    for lhs_is_signed, rhs_is_signed in signedness_combos:
        self.tst_toom_2_mul_256x256_sim(
            lhs_signed=lhs_is_signed, rhs_signed=rhs_is_signed,
            get_state_factory=get_pre_ra_state_factory)
1062
def test_toom_2_mul_256x256_uu_post_ra_sim(self):
    """Post-RA simulation of the 256x256 Toom-2 multiply: lhs unsigned,
    rhs unsigned."""
    state_factory = self.get_post_ra_state_factory
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=state_factory, lhs_signed=False, rhs_signed=False)
1067
def test_toom_2_mul_256x256_su_post_ra_sim(self):
    """Post-RA simulation of the 256x256 Toom-2 multiply: lhs signed,
    rhs unsigned."""
    state_factory = self.get_post_ra_state_factory
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=state_factory, lhs_signed=True, rhs_signed=False)
1072
def test_toom_2_mul_256x256_us_post_ra_sim(self):
    """Post-RA simulation of the 256x256 Toom-2 multiply: lhs unsigned,
    rhs signed."""
    state_factory = self.get_post_ra_state_factory
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=state_factory, lhs_signed=False, rhs_signed=True)
1077
def test_toom_2_mul_256x256_ss_post_ra_sim(self):
    """Post-RA simulation of the 256x256 Toom-2 multiply: lhs signed,
    rhs signed."""
    state_factory = self.get_post_ra_state_factory
    self.tst_toom_2_mul_256x256_sim(
        get_state_factory=state_factory, lhs_signed=True, rhs_signed=True)
1082
def test_toom_2_mul_256x256_asm(self):
    """Check the exact assembly emitted for the unsigned 256x256-bit Toom-2
    multiply.

    Runs the register allocator (debug output to stdout, graph dumps
    through ``GraphDumper``), generates assembly and compares the emitted
    lines against the expected listing.
    """
    mul_code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
    registers = allocate_registers(
        mul_code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
    asm_state = GenAsmState(registers)
    mul_code.fn.gen_asm(asm_state)
    print(asm_state.output)
    expected_asm = [
        'or 49, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 49, 49',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.ld *39, 64(3)',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 49, 49',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.ld *14, 96(3)',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *3, *14, *14',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *39, *39',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 32, 16, 16',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 16, 32, 32',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *30, *14, *14',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 48, 17, 17',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *14, *30, *30',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 41, 16, 16',
        'addi 7, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 16, 41, 41',
        'or 17, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *30, *14, *14',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 17, 48, 48',
        'setvl 0, 0, 1, 0, 1, 1',
        'sv.or *7, *17, *17',
        'addi 8, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 9, 8, 8',
        'or 10, 8, 8',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *17, *7, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *39, *30, *30',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *34, *39, *17',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *44, *34, *34',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 14, 14',
        'or 10, 17, 17',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 14, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *17, *10, *10',
        'or 43, 17, 17',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *7, *14, *14',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 3, 7, 7',
        'or 4, 8, 8',
        'or 5, 9, 9',
        'addi 7, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 6, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 17, 43, 43',
        'setvl 0, 0, 1, 0, 1, 1',
        'addi 7, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 18, 7, 7',
        'or 19, 7, 7',
        'or 20, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *7, *3, *17',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *38, *7, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 7, 0, 0',
        'or 35, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 7, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *30, *30',
        'or 7, 3, 3',
        'or 21, 35, 35',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *22, *14, 7, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 26, 21, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *30, *30',
        'or 34, 4, 4',
        'or 21, 35, 35',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *7, *14, 34, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 11, 21, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 14, 0, 0',
        'or 27, 14, 14',
        'addi 14, 0, 0',
        'or 12, 14, 14',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 28, 27, 27',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *23, *23, *7',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *30, *30',
        'or 34, 5, 5',
        'or 21, 35, 35',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *7, *14, 34, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 11, 21, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 14, 0, 0',
        'or 29, 14, 14',
        'addi 14, 0, 0',
        'or 12, 14, 14',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *24, *24, *7',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *30, *30',
        'or 8, 6, 6',
        'or 21, 35, 35',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 7, 21, 21',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *25, *25, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *7, *38, *38',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 38, 7, 7',
        'or 39, 8, 8',
        'or 40, 9, 9',
        'or 41, 10, 10',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *34, *44, *44',
        'or 9, 38, 38',
        'or 7, 10, 10',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *14, *34, 9, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 18, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 30, 14, 14',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 9, 39, 39',
        'or 7, 10, 10',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *44, 9, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 9, 0, 0',
        'or 19, 9, 9',
        'addi 9, 0, 0',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 20, 19, 19',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *15, *15, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 31, 15, 15',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 9, 40, 40',
        'or 7, 10, 10',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *44, 9, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 9, 0, 0',
        'or 21, 9, 9',
        'addi 8, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *16, *16, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 32, 16, 16',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 9, 41, 41',
        'or 7, 10, 10',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *38, *44, 9, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 42, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *17, *17, *38',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 35, 19, 19',
        'or 36, 20, 20',
        'or 37, 21, 21',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 14, 30, 30',
        'or 15, 31, 31',
        'or 16, 32, 32',
        'or 19, 35, 35',
        'or 20, 36, 36',
        'or 21, 37, 37',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *30, *14, *14',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 10, 43, 43',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 43, 10, 10',
        'addi 3, 0, 0',
        'setvl 0, 0, 1, 0, 1, 1',
        'addi 4, 0, 0',
        'or 42, 48, 48',
        'or 10, 43, 43',
        'or 4, 3, 3',
        'setvl 0, 0, 1, 0, 1, 1',
        'sv.maddedu *3, *42, 10, 4',
        'setvl 0, 0, 1, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *14, *30, *30',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 30, 14, 14',
        'or 31, 15, 15',
        'or 32, 16, 16',
        'or 35, 19, 19',
        'or 36, 20, 20',
        'or 37, 21, 21',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 14, 30, 30',
        'or 15, 31, 31',
        'or 16, 32, 32',
        'or 19, 35, 35',
        'or 20, 36, 36',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *30, *14, *14',
        'setvl 0, 0, 7, 0, 1, 1',
        'subfc 0, 0, 0',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *14, *30, *30',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.subfe *30, *22, *14',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'addi 10, 0, 0',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 5, 10, 10',
        'or 6, 10, 10',
        'or 7, 10, 10',
        'or 8, 10, 10',
        'or 9, 10, 10',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'subfc 0, 0, 0',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.subfe *14, *3, *30',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'addi 10, 0, 0',
        'addi 10, 0, 0',
        'or 29, 10, 10',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *25, *25, *14',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.adde *28, *28, *3',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 3, 49, 49',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.std *22, 0(3)',
    ]
    self.assertEqual(asm_state.output, expected_asm)
1516
def tst_toom_mul_sim(
    self, code,  # type: Mul
    lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
    test_cases,  # type: Iterable[tuple[int, int]]
):
    """Simulate ``code.fn`` on every ``(lhs, rhs)`` pair in ``test_cases``.

    Operand normalization:
    - each operand is truncated to its width in bits;
    - on a signed side, an operand with its top bit set is reinterpreted as
      negative;
    - the expected product of the two operands is reduced modulo
      ``2**(lhs_bits + rhs_bits)``.

    For each case, inside a ``subTest``:
    - a state is opened from the factory returned by
      ``get_state_factory(code)``;
    - ``code.ptr_in`` is set to the base address ``0x100``;
    - the operand words are stored at the lhs/rhs offsets;
    - the function is simulated;
    - the words read back from the destination offset are compared with the
      expected product.
    """
    print(code.retval[1])
    print(code.fn.ops_to_str())
    make_state = get_state_factory(code)
    base_addr = 0x100
    dest_addr = base_addr + code.dest_offset
    lhs_addr = base_addr + code.lhs_offset
    rhs_addr = base_addr + code.rhs_offset
    lhs_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS
    rhs_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS

    def wrap(value, bits, signed):
        # type: (int, int, bool) -> int
        # Truncate to ``bits`` bits, then apply two's complement if signed.
        value %= 1 << bits
        if signed and value >> (bits - 1):
            value -= 1 << bits
        return value

    for raw_lhs, raw_rhs in test_cases:
        lhs_operand = wrap(raw_lhs, lhs_bits, lhs_signed)
        rhs_operand = wrap(raw_rhs, rhs_bits, rhs_signed)
        expected = (lhs_operand * rhs_operand) % (1 << (lhs_bits + rhs_bits))
        # Memory holds the unsigned bit patterns of the operands.
        lhs_pattern = lhs_operand % (1 << lhs_bits)
        rhs_pattern = rhs_operand % (1 << rhs_bits)
        with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                          lhs_value=hex(lhs_pattern),
                          rhs_value=hex(rhs_pattern),
                          prod_value=hex(expected)):
            with make_state() as state:
                state[code.ptr_in] = (base_addr,)
                for word_idx in range(code.lhs_size_in_words):
                    word = (lhs_pattern >> (GPR_SIZE_IN_BITS * word_idx)) \
                        & GPR_VALUE_MASK
                    state.store(lhs_addr + word_idx * GPR_SIZE_IN_BYTES, word)
                for word_idx in range(code.rhs_size_in_words):
                    word = (rhs_pattern >> (GPR_SIZE_IN_BITS * word_idx)) \
                        & GPR_VALUE_MASK
                    state.store(rhs_addr + word_idx * GPR_SIZE_IN_BYTES, word)
                code.fn.sim(state)
                actual = 0
                for word_idx in range(code.dest_size_in_words):
                    loaded = state.load(
                        dest_addr + GPR_SIZE_IN_BYTES * word_idx)
                    actual += loaded << (GPR_SIZE_IN_BITS * word_idx)
                self.assertEqual(hex(actual), hex(expected),
                                 f"failed: state={state}")
1565
def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
    # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
    """Run ``tst_toom_mul_sim`` before register allocation for every pair of
    operand sizes in 1, 2, 3, 4, 6, 8, 12 and 16 words.

    Each multiply is a ``ToomCookMul`` over ``instances`` with the given
    signedness. Each size pair is wrapped in its own ``subTest``.
    """
    def mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        toom_mul = ToomCookMul(fn=fn, lhs=lhs, lhs_signed=lhs_signed,
                               rhs=rhs, rhs_signed=rhs_signed,
                               instances=instances)
        return toom_mul.retval, toom_mul

    # Sizes of the form 1 << k or 3 << k (k < 6), kept only within 1..16.
    candidate_sizes = set()  # type: set[int]
    for shift in range(6):
        candidate_sizes.update((1 << shift, 3 << shift))
    sizes_in_words = OSet(
        size for size in sorted(candidate_sizes) if 1 <= size <= 16)

    # Repeating 0x80... / 0x40... byte patterns, 2048 bits wide; they are
    # truncated to the operand widths by tst_toom_mul_sim.
    pattern_80 = (0x80 << 2048) // 0xFF
    pattern_40 = (0x40 << 2048) // 0xFF
    for lhs_size_in_words in sizes_in_words:
        lhs_msb = 1 << (GPR_SIZE_IN_BITS * lhs_size_in_words - 1)
        for rhs_size_in_words in sizes_in_words:
            rhs_msb = 1 << (GPR_SIZE_IN_BITS * rhs_size_in_words - 1)
            with self.subTest(lhs_size_in_words=lhs_size_in_words,
                              rhs_size_in_words=rhs_size_in_words,
                              lhs_signed=lhs_signed,
                              rhs_signed=rhs_signed):
                test_cases = [
                    (-1, -1),
                    (pattern_80, pattern_80),
                    (pattern_40, pattern_80),
                    (pattern_80, pattern_40),
                    (pattern_40, pattern_40),
                    (lhs_msb, rhs_msb),
                    (1, rhs_msb),
                    (lhs_msb, 1),
                    (1, 1),
                ]  # type: list[tuple[int, int]]
                code = Mul(mul=mul,
                           lhs_size_in_words=lhs_size_in_words,
                           rhs_size_in_words=rhs_size_in_words)
                self.tst_toom_mul_sim(
                    code=code,
                    lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                    get_state_factory=get_pre_ra_state_factory,
                    test_cases=test_cases)
1610
def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with a single Toom-2 instance;
    lhs unsigned, rhs unsigned."""
    single_toom_2 = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        single_toom_2, lhs_signed=False, rhs_signed=False)
1615
def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with a single Toom-2 instance;
    lhs unsigned, rhs signed."""
    single_toom_2 = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        single_toom_2, lhs_signed=False, rhs_signed=True)
1620
def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with a single Toom-2 instance;
    lhs signed, rhs unsigned."""
    single_toom_2 = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        single_toom_2, lhs_signed=True, rhs_signed=False)
1625
def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with a single Toom-2 instance;
    lhs signed, rhs signed."""
    single_toom_2 = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        single_toom_2, lhs_signed=True, rhs_signed=True)
1630
def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with the same Toom-2 instance repeated
    four times in ``instances``; lhs unsigned, rhs unsigned."""
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=False, rhs_signed=False)
1636
def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with the same Toom-2 instance repeated
    four times in ``instances``; lhs unsigned, rhs signed."""
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=False, rhs_signed=True)
1642
def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with the same Toom-2 instance repeated
    four times in ``instances``; lhs signed, rhs unsigned."""
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=True, rhs_signed=False)
1648
def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
    """All-sizes pre-RA simulation with the same Toom-2 instance repeated
    four times in ``instances``; lhs signed, rhs signed."""
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=True, rhs_signed=True)
1654
1655
if __name__ == "__main__":
    # Executed as a script: run the test cases defined in this module.
    unittest.main(module="__main__")