copy-merging works afaict! -- some tests still broken: out-of-date
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
1 from contextlib import contextmanager
2 import sys
3 import unittest
4 from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
5
6 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
7 GPR_SIZE_IN_BYTES,
8 GPR_VALUE_MASK, BaseSimState,
9 Fn, GenAsmState, OpKind,
10 PostRASimState,
11 PreRASimState, SSAVal)
12 from bigint_presentation_code.register_allocator import allocate_registers
13 from bigint_presentation_code.register_allocator_test_util import GraphDumper
14 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
15 simple_mul)
16 from bigint_presentation_code.util import OSet
17
18 _StateFactory = Callable[[], ContextManager[BaseSimState]]
19
20
def simple_umul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, None]
    """Call `simple_mul` with both operands marked as unsigned.

    Returns a 2-tuple: whatever `simple_mul` returned, followed by `None`,
    which is the `(value, extra)` shape `Mul` expects from its `mul` callback.
    """
    product = simple_mul(
        fn=fn, lhs=lhs, lhs_signed=False,
        rhs=rhs, rhs_signed=False, name="mul")
    return product, None
25
26
def get_pre_ra_state_factory(code):
    # type: (Mul) -> _StateFactory
    """Return a zero-argument factory of pre-register-allocation sim states.

    Each call of the returned factory gives a context manager that creates a
    fresh, empty `PreRASimState`, installs it as the current debugging state
    for the duration of the `with` block, and yields it.

    `code` is not used here; the parameter exists so this function has the
    same call shape as `TestToomCook.get_post_ra_state_factory`.
    """
    def state_factory():
        fresh_state = PreRASimState(ssa_vals={}, memory={})
        with fresh_state.set_as_current_debugging_state():
            yield fresh_state

    # same effect as decorating `state_factory` with @contextmanager
    return contextmanager(state_factory)
35
36
class Mul:
    """Test fixture that builds an `Fn` around a caller-supplied multiply.

    The function appended to `self.fn` consists of, in order: a `FuncArgR3`
    op (the pointer argument), `SetVLI` + `SvLd` for the lhs operand,
    `SetVLI` + `SvLd` for the rhs operand, whatever ops the `mul` callback
    appends, and finally `SetVLI` + `SvStd` for the result.

    Byte offsets used as the load/store immediates: the destination is at
    `dest_offset` (0), the lhs operand directly follows the destination, and
    the rhs operand directly follows the lhs operand.
    """
    _MulFn = Callable[[Fn, SSAVal, SSAVal], Tuple[SSAVal, Any]]

    def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
        # type: (_MulFn, int, int) -> None
        """Create the `Fn` and append all ops to it.

        `mul` is called as `mul(fn, lhs, rhs)` with the first outputs of the
        two load ops. Its return value is kept in `self.retval`, and element
        0 of that value is what the final store op writes out.
        """
        super().__init__()
        func = Fn()
        self.fn = func

        # sizes and byte offsets of the three memory regions
        total_words = lhs_size_in_words + rhs_size_in_words
        self.dest_offset = 0
        self.dest_size_in_words = total_words
        self.dest_size_in_bytes = total_words * GPR_SIZE_IN_BYTES
        self.lhs_size_in_words = lhs_size_in_words
        self.lhs_size_in_bytes = lhs_size_in_words * GPR_SIZE_IN_BYTES
        self.rhs_size_in_words = rhs_size_in_words
        self.rhs_size_in_bytes = rhs_size_in_words * GPR_SIZE_IN_BYTES
        self.lhs_offset = self.dest_offset + self.dest_size_in_bytes
        self.rhs_offset = self.lhs_offset + self.lhs_size_in_bytes

        def append_setvl(word_count, op_name):
            # SetVLI with the vector length given as both immediate and maxvl
            return func.append_new_op(
                kind=OpKind.SetVLI, immediates=[word_count],
                maxvl=word_count, name=op_name)

        def append_load(byte_offset, setvl_op, word_count, op_name):
            # SvLd from ptr_in + byte_offset using the given SetVLI's output
            return func.append_new_op(
                kind=OpKind.SvLd, immediates=[byte_offset],
                input_vals=[self.ptr_in, setvl_op.outputs[0]],
                name=op_name, maxvl=word_count)

        # NOTE: the order of the appends below is the order of ops in the Fn
        ptr_arg_op = func.append_new_op(kind=OpKind.FuncArgR3, name="ptr_in")
        self.ptr_in = ptr_arg_op.outputs[0]

        self.lhs_setvl = append_setvl(lhs_size_in_words, "lhs_setvl")
        self.load_lhs = append_load(
            self.lhs_offset, self.lhs_setvl, lhs_size_in_words, "load_lhs")

        self.rhs_setvl = append_setvl(rhs_size_in_words, "rhs_setvl")
        self.load_rhs = append_load(
            self.rhs_offset, self.rhs_setvl, rhs_size_in_words, "load_rhs")

        lhs_val = self.load_lhs.outputs[0]
        rhs_val = self.load_rhs.outputs[0]
        self.retval = mul(func, lhs_val, rhs_val)

        self.dest_setvl = append_setvl(self.dest_size_in_words, "dest_setvl")
        self.store = func.append_new_op(
            kind=OpKind.SvStd,
            input_vals=[self.retval[0], self.ptr_in,
                        self.dest_setvl.outputs[0]],
            immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
            name="store_dest")
80
81
82 class TestToomCook(unittest.TestCase):
83 maxDiff = None
84
def get_post_ra_state_factory(self, code):
    # type: (Mul) -> _StateFactory
    """Return a zero-argument factory of post-register-allocation sim states.

    Register allocation for `code.fn` is run once, here; every context
    manager produced by the returned factory yields a new `PostRASimState`
    that shares the resulting SSA-value-to-location map but starts with
    empty memory and empty location values.
    """
    loc_map = allocate_registers(
        code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))

    def state_factory():
        yield PostRASimState(
            ssa_val_to_loc_map=loc_map, memory={}, loc_values={})

    # same effect as decorating `state_factory` with @contextmanager
    return contextmanager(state_factory)
96
def test_toom_2_repr(self):
    # Pins the exact repr() text of the Toom-2 instance, so that any change
    # to its evaluation points or to the lhs/rhs/product op trees shows up.
    instance = ToomCookInstance.make_toom_2()
    # handy when the expected text needs regenerating:
    # print(repr(repr(instance)))
    expected = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
        "eval_points=(0, 1, POINT_AT_INFINITY), "
        "lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected)
134
def test_toom_2_5_repr(self):
    # Pins the exact repr() text of the Toom-2.5 instance (3 lhs parts,
    # 2 rhs parts, evaluation points 0, 1, -1 and infinity).
    instance = ToomCookInstance.make_toom_2_5()
    # handy when the expected text needs regenerating:
    # print(repr(repr(instance)))
    expected = (
        "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
        "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
        "prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected)
199
def test_reversed_toom_2_5_repr(self):
    # Pins the exact repr() text of the Toom-2.5 instance after .reversed():
    # the expected text has the lhs/rhs part counts and the lhs/rhs op
    # trees swapped relative to the non-reversed instance.
    instance = ToomCookInstance.make_toom_2_5().reversed()
    # handy when the expected text needs regenerating:
    # print(repr(repr(instance)))
    expected = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    self.assertEqual(repr(instance), expected)
261
def test_simple_mul_192x192_pre_ra_sim(self):
    # Run the 192x192 simulation check for every signedness combination,
    # simulating the function before register allocation.
    signedness_cases = [
        (False, False), (False, True),
        (True, False), (True, True),
    ]
    for lhs_signed, rhs_signed in signedness_cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed, rhs_signed=rhs_signed,
            get_state_factory=get_pre_ra_state_factory)
268
def test_simple_mul_192x192_post_ra_sim(self):
    # Run the 192x192 simulation check for every signedness combination,
    # simulating the function after register allocation.
    signedness_cases = [
        (False, False), (False, True),
        (True, False), (True, True),
    ]
    for lhs_signed, rhs_signed in signedness_cases:
        self.tst_simple_mul_192x192_sim(
            lhs_signed=lhs_signed, rhs_signed=rhs_signed,
            get_state_factory=self.get_post_ra_state_factory)
275
def tst_simple_mul_192x192_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    """Simulate a 3-word by 3-word `simple_mul` and check the stored product.

    Builds a `Mul` fixture around `simple_mul` with the requested operand
    signedness, then for each allowed combination of negated operands (an
    operand is only negated when it is signed) writes the operands to
    simulated memory, runs the function, reads back six words and compares
    them with the expected product modulo 2 ** 384.
    """
    # Known-answer operands:
    #     0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    #   * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    #  == int("0x000074736574206e_6f69746163696c70"
    #         "_69746c756d207469_622d3438333e2d32"
    #         "_3931783239312079_7261727469627261", base=0)
    #  == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
    #                    'little')
    lhs_value = 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    rhs_value = 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    prod_value = int.from_bytes(
        b"arbitrary 192x192->384-bit multiplication test", 'little')
    # sanity-check the test data itself before touching the simulator
    self.assertEqual(lhs_value * rhs_value, prod_value)

    def mul(fn, lhs, rhs):
        # callback in the (value, extra) shape that Mul expects
        result = simple_mul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed,
            rhs=rhs, rhs_signed=rhs_signed, name="mul")
        return result, None

    code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3)
    state_factory = get_state_factory(code)
    ptr_in = 0x100
    dest_ptr = ptr_in + code.dest_offset
    lhs_ptr = ptr_in + code.lhs_offset
    rhs_ptr = ptr_in + code.rhs_offset

    # an operand may only be negative when it is treated as signed
    sign_cases = [
        (lhs_neg, rhs_neg)
        for lhs_neg in (False, True)
        for rhs_neg in (False, True)
        if (lhs_signed or not lhs_neg) and (rhs_signed or not rhs_neg)
    ]
    for lhs_neg, rhs_neg in sign_cases:
        with self.subTest(lhs_signed=lhs_signed,
                          rhs_signed=rhs_signed,
                          lhs_neg=lhs_neg, rhs_neg=rhs_neg):
            with state_factory() as state:
                # values are assigned as tuples, here a 1-tuple
                state[code.ptr_in] = (ptr_in,)
                # negative operands are written as 2 ** 192 - magnitude
                lhs = (2 ** 192 - lhs_value) if lhs_neg else lhs_value
                rhs = (2 ** 192 - rhs_value) if rhs_neg else rhs_value
                # write lhs first, then rhs, one GPR-sized word at a time,
                # least-significant word at the lowest address
                for base_ptr, operand in (lhs_ptr, lhs), (rhs_ptr, rhs):
                    for word in range(3):
                        shift = GPR_SIZE_IN_BITS * word
                        v = (operand >> shift) & GPR_VALUE_MASK
                        state.store(base_ptr + word * GPR_SIZE_IN_BYTES, v)
                code.fn.sim(state)
                # the product is negative iff exactly one operand is
                if lhs_neg != rhs_neg:
                    expected = 2 ** 384 - prod_value
                else:
                    expected = prod_value
                # reassemble the six destination words into one integer
                prod = sum(
                    state.load(dest_ptr + GPR_SIZE_IN_BYTES * word)
                    << (GPR_SIZE_IN_BITS * word)
                    for word in range(6))
                # compare as hex so failure messages are readable
                self.assertEqual(hex(prod), hex(expected))
337
def test_simple_mul_192x192_ops(self):
    # Builds the unsigned 3-word by 3-word multiply fixture and pins the
    # exact text returned by Fn.ops_to_str() for it, i.e. the full list of
    # ops, in order, from the pointer argument through to the final store.
    #
    # NOTE(review): in this copy of the file every run of spaces inside the
    # expected-output literals below appears to have been collapsed to a
    # single space (the file's own indentation was lost the same way).
    # Confirm the literals against the real ops_to_str() output before
    # relying on them; they are reproduced here exactly as found.
    code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    fn = code.fn
    self.assertEqual(
        fn.ops_to_str(),
        "ptr_in:\n"
        " (<...outputs[0]: <I64>>) <= FuncArgR3\n"
        "lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_lhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <lhs_setvl.outputs[0]: <VL_MAXVL>>, 0x30)\n"
        "rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "load_rhs:\n"
        " (<...outputs[0]: <I64*3>>) <= SvLd(\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <rhs_setvl.outputs[0]: <VL_MAXVL>>, 0x48)\n"
        "mul_rhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_rhs_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <load_rhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_lhs_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x3)\n"
        "mul_zero2:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_0_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[0]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_0_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_0_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[1]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_1_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_cast_retval_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_cast_pp_zero:\n"
        " (<...outputs[0]: <I64>>) <= LI(0x0)\n"
        "mul_1_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x5)\n"
        "mul_1_retval_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_0_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_0_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_cast_retval_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_pp_concat:\n"
        " (<...outputs[0]: <I64*5>>) <= Concat(\n"
        " <mul_1_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_1_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_1_mul.outputs[1]: <I64>>,\n"
        " <mul_1_cast_pp_zero.outputs[0]: <I64>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_1_add:\n"
        " (<...outputs[0]: <I64*5>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_1_retval_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_pp_concat.outputs[0]: <I64*5>>,\n"
        " <mul_1_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_1_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>,\n"
        " <...outputs[4]: <I64>>) <= Spread(\n"
        " <mul_1_add.outputs[0]: <I64*5>>,\n"
        " <mul_1_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul:\n"
        " (<...outputs[0]: <I64*3>>, <...outputs[1]: <I64>>\n"
        " ) <= SvMAddEDU(<load_lhs.outputs[0]: <I64*3>>,\n"
        " <mul_rhs_spread.outputs[2]: <I64>>,\n"
        " <mul_zero.outputs[0]: <I64>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_mul_rt_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>) <= Spread(\n"
        " <mul_2_mul.outputs[0]: <I64*3>>,\n"
        " <mul_lhs_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x4)\n"
        "mul_2_retval_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_1_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[4]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_pp_concat:\n"
        " (<...outputs[0]: <I64*4>>) <= Concat(\n"
        " <mul_2_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[1]: <I64>>,\n"
        " <mul_2_mul_rt_spread.outputs[2]: <I64>>,\n"
        " <mul_2_mul.outputs[1]: <I64>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_clear_ca:\n"
        " (<...outputs[0]: <CA>>) <= ClearCA\n"
        "mul_2_add:\n"
        " (<...outputs[0]: <I64*4>>, <...outputs[1]: <CA>>\n"
        " ) <= SvAddE(<mul_2_retval_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_pp_concat.outputs[0]: <I64*4>>,\n"
        " <mul_2_clear_ca.outputs[0]: <CA>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_2_sum_spread:\n"
        " (<...outputs[0]: <I64>>, <...outputs[1]: <I64>>,\n"
        " <...outputs[2]: <I64>>, <...outputs[3]: <I64>>) <= Spread(\n"
        " <mul_2_add.outputs[0]: <I64*4>>,\n"
        " <mul_2_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "mul_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "mul_concat:\n"
        " (<...outputs[0]: <I64*6>>) <= Concat(\n"
        " <mul_0_mul_rt_spread.outputs[0]: <I64>>,\n"
        " <mul_1_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[0]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[1]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[2]: <I64>>,\n"
        " <mul_2_sum_spread.outputs[3]: <I64>>,\n"
        " <mul_setvl.outputs[0]: <VL_MAXVL>>)\n"
        "dest_setvl:\n"
        " (<...outputs[0]: <VL_MAXVL>>) <= SetVLI(0x6)\n"
        "store_dest:\n"
        " SvStd(<mul_concat.outputs[0]: <I64*6>>,\n"
        " <ptr_in.outputs[0]: <I64>>,\n"
        " <dest_setvl.outputs[0]: <VL_MAXVL>>, 0x0)"
    )
486
487 def test_simple_mul_192x192_reg_alloc(self):
488 code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
489 fn = code.fn
490 assigned_registers = allocate_registers(
491 fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
492 self.assertEqual(
493 repr(assigned_registers), "{"
494 "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
495 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
496 "<store_dest.inp1.copy.outputs[0]: <I64>>: "
497 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
498 "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
499 "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
500 "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
501 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
502 "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
503 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
504 "<mul_concat.out0.copy.outputs[0]: <I64*6>>: "
505 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
506 "<mul_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
507 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
508 "<mul_concat.outputs[0]: <I64*6>>: "
509 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
510 "<mul_concat.inp0.copy.outputs[0]: <I64>>: "
511 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
512 "<mul_concat.inp1.copy.outputs[0]: <I64>>: "
513 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
514 "<mul_concat.inp2.copy.outputs[0]: <I64>>: "
515 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
516 "<mul_concat.inp3.copy.outputs[0]: <I64>>: "
517 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
518 "<mul_concat.inp4.copy.outputs[0]: <I64>>: "
519 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
520 "<mul_concat.inp5.copy.outputs[0]: <I64>>: "
521 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
522 "<mul_concat.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
523 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
524 "<mul_setvl.outputs[0]: <VL_MAXVL>>: "
525 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
526 "<mul_2_sum_spread.out3.copy.outputs[0]: <I64>>: "
527 "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
528 "<mul_2_sum_spread.out2.copy.outputs[0]: <I64>>: "
529 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
530 "<mul_2_sum_spread.out1.copy.outputs[0]: <I64>>: "
531 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
532 "<mul_2_sum_spread.out0.copy.outputs[0]: <I64>>: "
533 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
534 "<mul_2_sum_spread.outputs[0]: <I64>>: "
535 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
536 "<mul_2_sum_spread.outputs[1]: <I64>>: "
537 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
538 "<mul_2_sum_spread.outputs[2]: <I64>>: "
539 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
540 "<mul_2_sum_spread.outputs[3]: <I64>>: "
541 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
542 "<mul_2_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
543 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
544 "<mul_2_sum_spread.inp0.copy.outputs[0]: <I64*4>>: "
545 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
546 "<mul_2_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
547 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
548 "<mul_2_add.out0.copy.outputs[0]: <I64*4>>: "
549 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
550 "<mul_2_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
551 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
552 "<mul_2_clear_ca.outputs[0]: <CA>>: "
553 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
554 "<mul_2_add.outputs[1]: <CA>>: "
555 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
556 "<mul_2_add.outputs[0]: <I64*4>>: "
557 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
558 "<mul_2_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
559 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
560 "<mul_2_add.inp1.copy.outputs[0]: <I64*4>>: "
561 "Loc(kind=LocKind.GPR, start=7, reg_len=4), "
562 "<mul_2_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
563 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
564 "<mul_2_add.inp0.copy.outputs[0]: <I64*4>>: "
565 "Loc(kind=LocKind.GPR, start=14, reg_len=4), "
566 "<mul_2_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
567 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
568 "<mul_2_pp_concat.out0.copy.outputs[0]: <I64*4>>: "
569 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
570 "<mul_2_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
571 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
572 "<mul_2_pp_concat.outputs[0]: <I64*4>>: "
573 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
574 "<mul_2_pp_concat.inp0.copy.outputs[0]: <I64>>: "
575 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
576 "<mul_2_pp_concat.inp1.copy.outputs[0]: <I64>>: "
577 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
578 "<mul_2_pp_concat.inp2.copy.outputs[0]: <I64>>: "
579 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
580 "<mul_2_pp_concat.inp3.copy.outputs[0]: <I64>>: "
581 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
582 "<mul_2_pp_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
583 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
584 "<mul_2_retval_concat.out0.copy.outputs[0]: <I64*4>>: "
585 "Loc(kind=LocKind.GPR, start=7, reg_len=4), "
586 "<mul_2_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
587 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
588 "<mul_2_retval_concat.outputs[0]: <I64*4>>: "
589 "Loc(kind=LocKind.GPR, start=3, reg_len=4), "
590 "<mul_2_retval_concat.inp0.copy.outputs[0]: <I64>>: "
591 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
592 "<mul_2_retval_concat.inp1.copy.outputs[0]: <I64>>: "
593 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
594 "<mul_2_retval_concat.inp2.copy.outputs[0]: <I64>>: "
595 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
596 "<mul_2_retval_concat.inp3.copy.outputs[0]: <I64>>: "
597 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
598 "<mul_2_retval_concat.inp4.setvl.outputs[0]: <VL_MAXVL>>: "
599 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
600 "<mul_2_setvl.outputs[0]: <VL_MAXVL>>: "
601 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
602 "<mul_2_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
603 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
604 "<mul_2_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
605 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
606 "<mul_2_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
607 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
608 "<mul_2_mul_rt_spread.outputs[0]: <I64>>: "
609 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
610 "<mul_2_mul_rt_spread.outputs[1]: <I64>>: "
611 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
612 "<mul_2_mul_rt_spread.outputs[2]: <I64>>: "
613 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
614 "<mul_2_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
615 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
616 "<mul_2_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
617 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
618 "<mul_2_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
619 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
620 "<mul_2_mul.out1.copy.outputs[0]: <I64>>: "
621 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
622 "<mul_2_mul.out0.copy.outputs[0]: <I64*3>>: "
623 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
624 "<mul_2_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
625 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
626 "<mul_2_mul.inp2.copy.outputs[0]: <I64>>: "
627 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
628 "<mul_2_mul.outputs[1]: <I64>>: "
629 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
630 "<mul_2_mul.outputs[0]: <I64*3>>: "
631 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
632 "<mul_2_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
633 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
634 "<mul_2_mul.inp1.copy.outputs[0]: <I64>>: "
635 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
636 "<mul_2_mul.inp0.copy.outputs[0]: <I64*3>>: "
637 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
638 "<mul_2_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
639 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
640 "<mul_1_sum_spread.out4.copy.outputs[0]: <I64>>: "
641 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
642 "<mul_1_sum_spread.out3.copy.outputs[0]: <I64>>: "
643 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
644 "<mul_1_sum_spread.out2.copy.outputs[0]: <I64>>: "
645 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
646 "<mul_1_sum_spread.out1.copy.outputs[0]: <I64>>: "
647 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
648 "<mul_1_sum_spread.out0.copy.outputs[0]: <I64>>: "
649 "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
650 "<mul_1_sum_spread.outputs[0]: <I64>>: "
651 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
652 "<mul_1_sum_spread.outputs[1]: <I64>>: "
653 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
654 "<mul_1_sum_spread.outputs[2]: <I64>>: "
655 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
656 "<mul_1_sum_spread.outputs[3]: <I64>>: "
657 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
658 "<mul_1_sum_spread.outputs[4]: <I64>>: "
659 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
660 "<mul_1_sum_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
661 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
662 "<mul_1_sum_spread.inp0.copy.outputs[0]: <I64*5>>: "
663 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
664 "<mul_1_sum_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
665 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
666 "<mul_1_add.out0.copy.outputs[0]: <I64*5>>: "
667 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
668 "<mul_1_add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
669 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
670 "<mul_1_clear_ca.outputs[0]: <CA>>: "
671 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
672 "<mul_1_add.outputs[1]: <CA>>: "
673 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
674 "<mul_1_add.outputs[0]: <I64*5>>: "
675 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
676 "<mul_1_add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
677 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
678 "<mul_1_add.inp1.copy.outputs[0]: <I64*5>>: "
679 "Loc(kind=LocKind.GPR, start=8, reg_len=5), "
680 "<mul_1_add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
681 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
682 "<mul_1_add.inp0.copy.outputs[0]: <I64*5>>: "
683 "Loc(kind=LocKind.GPR, start=14, reg_len=5), "
684 "<mul_1_add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
685 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
686 "<mul_1_pp_concat.out0.copy.outputs[0]: <I64*5>>: "
687 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
688 "<mul_1_pp_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
689 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
690 "<mul_1_pp_concat.outputs[0]: <I64*5>>: "
691 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
692 "<mul_1_pp_concat.inp0.copy.outputs[0]: <I64>>: "
693 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
694 "<mul_1_pp_concat.inp1.copy.outputs[0]: <I64>>: "
695 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
696 "<mul_1_pp_concat.inp2.copy.outputs[0]: <I64>>: "
697 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
698 "<mul_1_pp_concat.inp3.copy.outputs[0]: <I64>>: "
699 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
700 "<mul_1_pp_concat.inp4.copy.outputs[0]: <I64>>: "
701 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
702 "<mul_1_pp_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
703 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
704 "<mul_1_retval_concat.out0.copy.outputs[0]: <I64*5>>: "
705 "Loc(kind=LocKind.GPR, start=8, reg_len=5), "
706 "<mul_1_retval_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
707 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
708 "<mul_1_retval_concat.outputs[0]: <I64*5>>: "
709 "Loc(kind=LocKind.GPR, start=3, reg_len=5), "
710 "<mul_1_retval_concat.inp0.copy.outputs[0]: <I64>>: "
711 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
712 "<mul_1_retval_concat.inp1.copy.outputs[0]: <I64>>: "
713 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
714 "<mul_1_retval_concat.inp2.copy.outputs[0]: <I64>>: "
715 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
716 "<mul_1_retval_concat.inp3.copy.outputs[0]: <I64>>: "
717 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
718 "<mul_1_retval_concat.inp4.copy.outputs[0]: <I64>>: "
719 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
720 "<mul_1_retval_concat.inp5.setvl.outputs[0]: <VL_MAXVL>>: "
721 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
722 "<mul_1_setvl.outputs[0]: <VL_MAXVL>>: "
723 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
724 "<mul_1_cast_pp_zero.out0.copy.outputs[0]: <I64>>: "
725 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
726 "<mul_1_cast_pp_zero.outputs[0]: <I64>>: "
727 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
728 "<mul_1_cast_retval_zero.out0.copy.outputs[0]: <I64>>: "
729 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
730 "<mul_1_cast_retval_zero.outputs[0]: <I64>>: "
731 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
732 "<mul_1_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
733 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
734 "<mul_1_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
735 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
736 "<mul_1_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
737 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
738 "<mul_1_mul_rt_spread.outputs[0]: <I64>>: "
739 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
740 "<mul_1_mul_rt_spread.outputs[1]: <I64>>: "
741 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
742 "<mul_1_mul_rt_spread.outputs[2]: <I64>>: "
743 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
744 "<mul_1_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
745 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
746 "<mul_1_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
747 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
748 "<mul_1_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
749 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
750 "<mul_1_mul.out1.copy.outputs[0]: <I64>>: "
751 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
752 "<mul_1_mul.out0.copy.outputs[0]: <I64*3>>: "
753 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
754 "<mul_1_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
755 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
756 "<mul_1_mul.inp2.copy.outputs[0]: <I64>>: "
757 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
758 "<mul_1_mul.outputs[1]: <I64>>: "
759 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
760 "<mul_1_mul.outputs[0]: <I64*3>>: "
761 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
762 "<mul_1_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
763 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
764 "<mul_1_mul.inp1.copy.outputs[0]: <I64>>: "
765 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
766 "<mul_1_mul.inp0.copy.outputs[0]: <I64*3>>: "
767 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
768 "<mul_1_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
769 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
770 "<mul_0_mul_rt_spread.out2.copy.outputs[0]: <I64>>: "
771 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
772 "<mul_0_mul_rt_spread.out1.copy.outputs[0]: <I64>>: "
773 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
774 "<mul_0_mul_rt_spread.out0.copy.outputs[0]: <I64>>: "
775 "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
776 "<mul_0_mul_rt_spread.outputs[0]: <I64>>: "
777 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
778 "<mul_0_mul_rt_spread.outputs[1]: <I64>>: "
779 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
780 "<mul_0_mul_rt_spread.outputs[2]: <I64>>: "
781 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
782 "<mul_0_mul_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
783 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
784 "<mul_0_mul_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
785 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
786 "<mul_0_mul_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
787 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
788 "<mul_0_mul.out1.copy.outputs[0]: <I64>>: "
789 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
790 "<mul_0_mul.out0.copy.outputs[0]: <I64*3>>: "
791 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
792 "<mul_0_mul.out0.setvl.outputs[0]: <VL_MAXVL>>: "
793 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
794 "<mul_0_mul.inp2.copy.outputs[0]: <I64>>: "
795 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
796 "<mul_0_mul.outputs[1]: <I64>>: "
797 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
798 "<mul_0_mul.outputs[0]: <I64*3>>: "
799 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
800 "<mul_0_mul.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
801 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
802 "<mul_0_mul.inp1.copy.outputs[0]: <I64>>: "
803 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
804 "<mul_0_mul.inp0.copy.outputs[0]: <I64*3>>: "
805 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
806 "<mul_0_mul.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
807 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
808 "<mul_zero2.out0.copy.outputs[0]: <I64>>: "
809 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
810 "<mul_zero2.outputs[0]: <I64>>: "
811 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
812 "<mul_lhs_setvl.outputs[0]: <VL_MAXVL>>: "
813 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
814 "<mul_zero.out0.copy.outputs[0]: <I64>>: "
815 "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
816 "<mul_zero.outputs[0]: <I64>>: "
817 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
818 "<mul_rhs_spread.out2.copy.outputs[0]: <I64>>: "
819 "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
820 "<mul_rhs_spread.out1.copy.outputs[0]: <I64>>: "
821 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
822 "<mul_rhs_spread.out0.copy.outputs[0]: <I64>>: "
823 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
824 "<mul_rhs_spread.outputs[0]: <I64>>: "
825 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
826 "<mul_rhs_spread.outputs[1]: <I64>>: "
827 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
828 "<mul_rhs_spread.outputs[2]: <I64>>: "
829 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
830 "<mul_rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
831 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
832 "<mul_rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
833 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
834 "<mul_rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
835 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
836 "<mul_rhs_setvl.outputs[0]: <VL_MAXVL>>: "
837 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
838 "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
839 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
840 "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
841 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
842 "<load_rhs.outputs[0]: <I64*3>>: "
843 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
844 "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
845 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
846 "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
847 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
848 "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
849 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
850 "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
851 "Loc(kind=LocKind.GPR, start=24, reg_len=3), "
852 "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
853 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
854 "<load_lhs.outputs[0]: <I64*3>>: "
855 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
856 "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
857 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
858 "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
859 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
860 "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
861 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
862 "<ptr_in.out0.copy.outputs[0]: <I64>>: "
863 "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
864 "<ptr_in.outputs[0]: <I64>>: "
865 "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
866 "}")
867
def test_simple_mul_192x192_asm(self):
    """Check the assembly generated for the simple 3-word by 3-word multiply.

    Builds the multiply with `simple_umul`, runs register allocation
    (passing `sys.stdout` as `debug_out` and a `GraphDumper` for this
    test case), generates assembly and compares the emitted lines against
    the recorded listing below.
    """
    mul_code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    mul_fn = mul_code.fn
    reg_assignment = allocate_registers(
        mul_fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
    asm_state = GenAsmState(reg_assignment)
    mul_fn.gen_asm(asm_state)
    # recorded listing: must match the generator's output line for line
    expected_asm = [
        'or 9, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *20, 48(9)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *10, 72(9)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 6, 0, 0',
        'or 8, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 8, 8',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *14, *20, 10, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 17, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 8, 8',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *20, 11, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 18, 0, 0',
        'addi 7, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 19, 18, 18',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *15, *15, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 8, 8',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *20, 12, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *16, *16, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.std *14, 0(9)',
    ]
    self.assertEqual(asm_state.output, expected_asm)
951
def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Mul
    """Build a 4-word by 4-word `Mul` whose product comes from `ToomCookMul`.

    The multiply is configured with a single TOOM-2 instance and the
    requested signedness for each operand.
    """
    # one-element tuple: TOOM-2 is the only instance handed to ToomCookMul
    toom_instances = (ToomCookInstance.make_toom_2(),)

    def mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        toom_mul = ToomCookMul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed,
            rhs=rhs, rhs_signed=rhs_signed,
            instances=toom_instances)
        # return the product value together with the ToomCookMul object
        return toom_mul.retval, toom_mul

    return Mul(mul=mul, lhs_size_in_words=4, rhs_size_in_words=4)
963
def make_256x256_mul_test_cases(self, lhs_signed, rhs_signed):
    # type: (bool, bool) -> Iterator[tuple[int, int, int]]
    """Yield `(lhs, rhs, product)` vectors for a 256x256-bit multiply.

    Each operand list holds `1 << n` for a spread of bit positions, zero
    and one fixed 256-bit constant; when an operand is signed, the
    negation of every entry is added too. `lhs` and `rhs` are yielded
    reduced modulo `2**256` and `product` modulo `2**512`, where the
    product is computed from the operands reinterpreted as signed values
    whenever the corresponding `*_signed` flag is set.
    """
    # test multiplying `+-1 << n` and:
    # 0xc162321a5eaad80b_4b86bb0efdfb93c0_a789ff04cc11b157_eaa08e29fb197621
    # *
    # 0x3138710167583371_998af336a8fac64d_e6da3737090787fe_85ba09ea701f4af2
    # ==
    # int("0x"
    #     "252e6e6f69746163_696c7069746c754d_"
    #     "2061627573746172_614b202d20322d4d_"
    #     "4f4f5420676e6973_75206c756d20746e_"
    #     "6967696220746962_2d36353278363532", base=0)
    # == int.from_bytes(b'256x256-bit bigint mul using TOOM-2 '
    #                   b'- Karatsuba Multiplication.%', 'little')
    operand_modulus = 1 << 256
    product_modulus = 1 << 512
    lhs_value_in = (0xc162321a5eaad80b_4b86bb0efdfb93c0 << 128) \
        | 0xa789ff04cc11b157_eaa08e29fb197621
    rhs_value_in = (0x3138710167583371_998af336a8fac64d << 128) \
        | 0xe6da3737090787fe_85ba09ea701f4af2
    prod_value_in = int.from_bytes(
        b'256x256-bit bigint mul using TOOM-2 '
        b'- Karatsuba Multiplication.%', 'little')
    # sanity-check the constants before generating anything from them
    self.assertEqual(lhs_value_in * rhs_value_in, prod_value_in)

    bit_positions = list(range(0, 256, 16)) + list(range(15, 256, 16))
    powers_of_two = [1 << n for n in bit_positions]
    featured_values = (lhs_value_in, rhs_value_in)

    def sort_key(v):
        # type: (int) -> tuple[bool, int]
        # the fixed constants sort last; the rest by their 256-bit pattern
        return abs(v) in featured_values, v % operand_modulus

    def operand_candidates(featured, signed):
        # type: (int, bool) -> list[int]
        candidates = powers_of_two + [0, featured]
        if signed:
            candidates.extend([-v for v in candidates])
        return sorted(candidates, key=sort_key)

    def interpret(v, signed):
        # type: (int, bool) -> int
        # reduce to 256 bits, then reinterpret as two's complement if signed
        v %= operand_modulus
        if signed and v >> 255 != 0:
            v -= operand_modulus
        return v

    lhs_values = operand_candidates(lhs_value_in, lhs_signed)
    rhs_values = operand_candidates(rhs_value_in, rhs_signed)
    for lhs_raw in lhs_values:
        lhs_operand = interpret(lhs_raw, lhs_signed)
        for rhs_raw in rhs_values:
            rhs_operand = interpret(rhs_raw, rhs_signed)
            yield (lhs_operand % operand_modulus,
                   rhs_operand % operand_modulus,
                   (lhs_operand * rhs_operand) % product_modulus)
1013
def tst_toom_2_mul_256x256_sim(
    self, lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
):
    """Simulate the TOOM-2 256x256 multiply on every generated test vector.

    Prints `code.retval[1]` and the function's op listing, then for each
    `(lhs, rhs, product)` vector opens a fresh simulator state, stores
    both operands as four GPR-sized words each (least-significant word
    first), runs the simulation and checks that the eight words read back
    from the destination form the expected product. Each vector runs in
    its own sub-test.
    """
    code = self.toom_2_mul_256x256(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    print(code.retval[1])
    print(code.fn.ops_to_str())
    open_state = get_state_factory(code)
    base_addr = 0x100
    dest_addr = base_addr + code.dest_offset
    lhs_addr = base_addr + code.lhs_offset
    rhs_addr = base_addr + code.rhs_offset
    test_vectors = self.make_256x256_mul_test_cases(
        lhs_signed=lhs_signed, rhs_signed=rhs_signed)
    for lhs_value, rhs_value, prod_value in test_vectors:
        with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                          lhs_value=hex(lhs_value),
                          rhs_value=hex(rhs_value),
                          prod_value=hex(prod_value)), \
                open_state() as state:
            # the pointer input is assigned as a one-element tuple
            state[code.ptr_in] = (base_addr,)
            # write lhs first, then rhs, one GPR-sized word at a time
            for addr, operand in ((lhs_addr, lhs_value),
                                  (rhs_addr, rhs_value)):
                for word in range(4):
                    state.store(
                        addr + word * GPR_SIZE_IN_BYTES,
                        (operand >> GPR_SIZE_IN_BITS * word)
                        & GPR_VALUE_MASK)
            code.fn.sim(state)
            # reassemble the product from the destination words
            result = sum(
                state.load(dest_addr + GPR_SIZE_IN_BYTES * word)
                << (GPR_SIZE_IN_BITS * word)
                for word in range(8))
            self.assertEqual(hex(result), hex(prod_value),
                             f"failed: state={state}")
1052
def test_toom_2_mul_256x256_pre_ra_sim(self):
    """Run the TOOM-2 256x256 simulation check with the pre-RA state factory.

    All four operand-signedness combinations are exercised, with the lhs
    flag varying slowest.
    """
    signedness_combos = [(lhs, rhs)
                         for lhs in (False, True)
                         for rhs in (False, True)]
    for lhs_signed, rhs_signed in signedness_combos:
        self.tst_toom_2_mul_256x256_sim(
            lhs_signed=lhs_signed, rhs_signed=rhs_signed,
            get_state_factory=get_pre_ra_state_factory)
1059
1060 def test_toom_2_mul_256x256_uu_post_ra_sim(self):
1061 self.tst_toom_2_mul_256x256_sim(
1062 lhs_signed=False, rhs_signed=False,
1063 get_state_factory=self.get_post_ra_state_factory)
1064
1065 def test_toom_2_mul_256x256_su_post_ra_sim(self):
1066 self.tst_toom_2_mul_256x256_sim(
1067 lhs_signed=True, rhs_signed=False,
1068 get_state_factory=self.get_post_ra_state_factory)
1069
1070 def test_toom_2_mul_256x256_us_post_ra_sim(self):
1071 self.tst_toom_2_mul_256x256_sim(
1072 lhs_signed=False, rhs_signed=True,
1073 get_state_factory=self.get_post_ra_state_factory)
1074
1075 def test_toom_2_mul_256x256_ss_post_ra_sim(self):
1076 self.tst_toom_2_mul_256x256_sim(
1077 lhs_signed=True, rhs_signed=True,
1078 get_state_factory=self.get_post_ra_state_factory)
1079
def test_toom_2_mul_256x256_asm(self):
    """Check the assembly generated for the unsigned TOOM-2 256x256 multiply.

    Builds the multiply with both operands unsigned, runs register
    allocation (passing `sys.stdout` as `debug_out` and a `GraphDumper`
    for this test case), generates assembly and compares the emitted
    lines against the recorded listing below.
    """
    mul_code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
    mul_fn = mul_code.fn
    reg_assignment = allocate_registers(
        mul_fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
    asm_state = GenAsmState(reg_assignment)
    mul_fn.gen_asm(asm_state)
    # recorded listing: must match the generator's output line for line
    expected_asm = [
        'or 45, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 7, 45, 45',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.ld *3, 64(7)',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *8, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 7, 45, 45',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.ld *3, 96(7)',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *18, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *3, *8, *8',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *9, *3, *3',
        'or 3, 9, 9',
        'or 8, 10, 10',
        'or 6, 11, 11',
        'or 7, 12, 12',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 4, 8, 8',
        'or 5, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 3, 7, 7',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 46, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *3, *4, *4',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 9, 6, 6',
        'or 8, 7, 7',
        'addi 3, 0, 0',
        'or 7, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 9, 9',
        'or 5, 8, 8',
        'or 6, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *35, *3, *3',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 3, 46, 46',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 4, 3, 3',
        'addi 3, 0, 0',
        'or 7, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 7, 7',
        'or 5, 7, 7',
        'or 6, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *35, *35',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *7, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *3, *14, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *40, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *3, *18, *18',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *9, *3, *3',
        'or 3, 9, 9',
        'or 8, 10, 10',
        'or 6, 11, 11',
        'or 7, 12, 12',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 4, 8, 8',
        'or 5, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 3, 7, 7',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 47, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *3, *4, *4',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 9, 6, 6',
        'or 8, 7, 7',
        'addi 3, 0, 0',
        'or 7, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 9, 9',
        'or 5, 8, 8',
        'or 6, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *18, *3, *3',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 3, 47, 47',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 4, 3, 3',
        'addi 3, 0, 0',
        'or 7, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 7, 7',
        'or 5, 7, 7',
        'or 6, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *18, *18',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *7, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.adde *3, *14, *7',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *29, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *3, *18, *18',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 9, 6, 6',
        'or 34, 7, 7',
        'or 33, 8, 8',
        'addi 3, 0, 0',
        'or 28, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *35, *35',
        'or 8, 4, 4',
        'or 7, 28, 28',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 24, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 27, 3, 3',
        'or 23, 4, 4',
        'or 19, 5, 5',
        'or 18, 6, 6',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *35, *35',
        'or 8, 9, 9',
        'or 7, 28, 28',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 22, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 21, 3, 3',
        'or 20, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'addi 3, 0, 0',
        'or 9, 3, 3',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 23, 23',
        'or 4, 19, 19',
        'or 5, 18, 18',
        'or 6, 24, 24',
        'or 7, 10, 10',
        'or 8, 10, 10',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'or 3, 21, 21',
        'or 4, 20, 20',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 22, 22',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *20, *14, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *3, *20, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 26, 3, 3',
        'or 25, 4, 4',
        'or 24, 5, 5',
        'or 23, 6, 6',
        'or 19, 7, 7',
        'or 18, 8, 8',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *35, *35',
        'or 8, 34, 34',
        'or 7, 28, 28',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 22, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 21, 3, 3',
        'or 20, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'addi 3, 0, 0',
        'or 9, 3, 3',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 25, 25',
        'or 4, 24, 24',
        'or 5, 23, 23',
        'or 6, 19, 19',
        'or 7, 18, 18',
        'or 8, 10, 10',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'or 3, 21, 21',
        'or 4, 20, 20',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 22, 22',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *20, *14, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *3, *20, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 20, 3, 3',
        'or 19, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'or 10, 7, 7',
        'or 9, 8, 8',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *35, *35',
        'or 8, 33, 33',
        'or 7, 28, 28',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 18, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 17, 3, 3',
        'or 16, 4, 4',
        'or 15, 5, 5',
        'or 14, 6, 6',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 3, 19, 19',
        'or 4, 12, 12',
        'or 5, 11, 11',
        'or 6, 10, 10',
        'or 7, 9, 9',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'or 3, 17, 17',
        'or 4, 16, 16',
        'or 5, 15, 15',
        'or 6, 14, 14',
        'or 7, 18, 18',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *14, *8, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *3, *14, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 16, 3, 3',
        'or 15, 4, 4',
        'or 14, 5, 5',
        'or 12, 6, 6',
        'or 11, 7, 7',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 3, 27, 27',
        'or 4, 26, 26',
        'or 5, 20, 20',
        'or 6, 16, 16',
        'or 7, 15, 15',
        'or 8, 14, 14',
        'or 9, 12, 12',
        'or 10, 11, 11',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *21, *3, *3',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *3, *29, *29',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 9, 6, 6',
        'or 39, 7, 7',
        'or 38, 8, 8',
        'addi 3, 0, 0',
        'or 37, 3, 3',
        'setvl 0, 0, 4, 0, 1, 1',
        'addi 3, 0, 0',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *40, *40',
        'or 8, 4, 4',
        'or 7, 37, 37',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 32, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 36, 3, 3',
        'or 31, 4, 4',
        'or 19, 5, 5',
        'or 18, 6, 6',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *40, *40',
        'or 8, 9, 9',
        'or 7, 37, 37',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 30, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 29, 3, 3',
        'or 20, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'addi 3, 0, 0',
        'or 9, 3, 3',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 31, 31',
        'or 4, 19, 19',
        'or 5, 18, 18',
        'or 6, 32, 32',
        'or 7, 10, 10',
        'or 8, 10, 10',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'or 3, 29, 29',
        'or 4, 20, 20',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 30, 30',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *29, *14, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *3, *29, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 35, 3, 3',
        'or 33, 4, 4',
        'or 32, 5, 5',
        'or 31, 6, 6',
        'or 19, 7, 7',
        'or 18, 8, 8',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *40, *40',
        'or 8, 39, 39',
        'or 7, 37, 37',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 30, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 29, 3, 3',
        'or 20, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'addi 3, 0, 0',
        'or 9, 3, 3',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 33, 33',
        'or 4, 32, 32',
        'or 5, 31, 31',
        'or 6, 19, 19',
        'or 7, 18, 18',
        'or 8, 10, 10',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'or 3, 29, 29',
        'or 4, 20, 20',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 30, 30',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *29, *14, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.adde *3, *29, *14',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 20, 3, 3',
        'or 19, 4, 4',
        'or 12, 5, 5',
        'or 11, 6, 6',
        'or 10, 7, 7',
        'or 9, 8, 8',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.or *14, *40, *40',
        'or 8, 38, 38',
        'or 7, 37, 37',
        'setvl 0, 0, 4, 0, 1, 1',
        'sv.maddedu *3, *14, 8, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 18, 7, 7',
        'setvl 0, 0, 4, 0, 1, 1',
        'setvl 0, 0, 4, 0, 1, 1',
        'or 17, 3, 3',
        'or 16, 4, 4',
        'or 15, 5, 5',
        'or 14, 6, 6',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 3, 19, 19',
        'or 4, 12, 12',
        'or 5, 11, 11',
        'or 6, 10, 10',
        'or 7, 9, 9',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'or 3, 17, 17',
        'or 4, 16, 16',
        'or 5, 15, 15',
        'or 6, 14, 14',
        'or 7, 18, 18',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *14, *8, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *3, *14, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 16, 3, 3',
        'or 15, 4, 4',
        'or 14, 5, 5',
        'or 12, 6, 6',
        'or 11, 7, 7',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 3, 36, 36',
        'or 4, 35, 35',
        'or 5, 20, 20',
        'or 6, 16, 16',
        'or 7, 15, 15',
        'or 8, 14, 14',
        'or 9, 12, 12',
        'or 10, 11, 11',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *37, *3, *3',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 3, 47, 47',
        'setvl 0, 0, 1, 0, 1, 1',
        'or 5, 3, 3',
        'addi 3, 0, 0',
        'or 4, 3, 3',
        'setvl 0, 0, 1, 0, 1, 1',
        'addi 3, 0, 0',
        'or 6, 46, 46',
        'setvl 0, 0, 1, 0, 1, 1',
        'sv.maddedu *3, *6, 5, 4',
        'or 5, 4, 4',
        'setvl 0, 0, 1, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'or 4, 5, 5',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *35, *3, *3',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *3, *21, *21',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *17, *3, *3',
        'or 4, 17, 17',
        'or 16, 18, 18',
        'or 15, 19, 19',
        'or 14, 20, 20',
        'or 12, 21, 21',
        'or 11, 22, 22',
        'or 10, 23, 23',
        'or 3, 24, 24',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 16, 16',
        'or 5, 15, 15',
        'or 6, 14, 14',
        'or 7, 12, 12',
        'or 8, 11, 11',
        'or 9, 10, 10',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *28, *3, *3',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *3, *37, *37',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or *17, *3, *3',
        'or 4, 17, 17',
        'or 16, 18, 18',
        'or 15, 19, 19',
        'or 14, 20, 20',
        'or 12, 21, 21',
        'or 11, 22, 22',
        'or 10, 23, 23',
        'or 3, 24, 24',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 16, 16',
        'or 5, 15, 15',
        'or 6, 14, 14',
        'or 7, 12, 12',
        'or 8, 11, 11',
        'or 9, 10, 10',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'subfc 0, 0, 0',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *21, *28, *28',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.subfe *3, *21, *14',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *3, *35, *35',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *5, *3, *3',
        'or 4, 5, 5',
        'or 11, 6, 6',
        'addi 3, 0, 0',
        'or 10, 3, 3',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 3, 4, 4',
        'or 4, 11, 11',
        'or 5, 10, 10',
        'or 6, 10, 10',
        'or 7, 10, 10',
        'or 8, 10, 10',
        'or 9, 10, 10',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'subfc 0, 0, 0',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *21, *3, *3',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.subfe *3, *21, *14',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *14, *3, *3',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *3, *28, *28',
        'setvl 0, 0, 7, 0, 1, 1',
        'or 25, 3, 3',
        'or 24, 4, 4',
        'or 23, 5, 5',
        'or 22, 6, 6',
        'or 21, 7, 7',
        'or 12, 8, 8',
        'or 11, 9, 9',
        'setvl 0, 0, 7, 0, 1, 1',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or *3, *14, *14',
        'setvl 0, 0, 7, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'or 18, 4, 4',
        'or 17, 5, 5',
        'or 16, 6, 6',
        'or 15, 7, 7',
        'or 14, 8, 8',
        'or 3, 9, 9',
        'or 3, 10, 10',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *3, *35, *35',
        'setvl 0, 0, 2, 0, 1, 1',
        'or 20, 3, 3',
        'or 19, 4, 4',
        'addi 3, 0, 0',
        'addi 3, 0, 0',
        'or 8, 3, 3',
        'setvl 0, 0, 5, 0, 1, 1',
        'or 3, 22, 22',
        'or 4, 21, 21',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 8, 8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'or 3, 18, 18',
        'or 4, 17, 17',
        'or 5, 16, 16',
        'or 6, 15, 15',
        'or 7, 14, 14',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *14, *8, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or *8, *3, *3',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.adde *3, *14, *8',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'setvl 0, 0, 5, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'or 16, 4, 4',
        'or 15, 5, 5',
        'or 14, 6, 6',
        'or 3, 7, 7',
        'or 5, 8, 8',
        'setvl 0, 0, 2, 0, 1, 1',
        'or 4, 5, 5',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *5, *3, *3',
        'or 3, 20, 20',
        'or 4, 19, 19',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *7, *5, *5',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.or *5, *3, *3',
        'setvl 0, 0, 2, 0, 1, 1',
        'sv.adde *3, *7, *5',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'setvl 0, 0, 2, 0, 1, 1',
        'or 12, 3, 3',
        'or 11, 4, 4',
        'setvl 0, 0, 8, 0, 1, 1',
        'or 3, 25, 25',
        'or 4, 24, 24',
        'or 5, 23, 23',
        'or 6, 16, 16',
        'or 7, 15, 15',
        'or 8, 14, 14',
        'or 9, 12, 12',
        'or 10, 11, 11',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'or 3, 45, 45',
        'setvl 0, 0, 8, 0, 1, 1',
        'sv.std *4, 0(3)',
    ]
    self.assertEqual(asm_state.output, expected_asm)
1803
def tst_toom_mul_sim(
    self, code,  # type: Mul
    lhs_signed,  # type: bool
    rhs_signed,  # type: bool
    get_state_factory,  # type: Callable[[Mul], _StateFactory]
    test_cases,  # type: Iterable[tuple[int, int]]
):
    """Simulate `code.fn` for every test case and check the stored product.

    For each `(lhs, rhs)` pair the operands are reduced to the operand
    widths given by `code`, optionally reinterpreted as two's-complement
    signed values, and multiplied with Python integers to get the expected
    product (reduced to the combined width). The operands are then written
    word by word into the simulator state, `code.fn` is simulated, and the
    destination words are read back and compared against the expectation.

    `get_state_factory(code)` must return a callable producing a context
    manager that yields the simulator state to use.
    """
    print(code.retval[1])
    print(code.fn.ops_to_str())
    make_state = get_state_factory(code)

    # all buffers live at fixed offsets from one arbitrary base address
    base_addr = 0x100
    dest_addr = base_addr + code.dest_offset
    lhs_addr = base_addr + code.lhs_offset
    rhs_addr = base_addr + code.rhs_offset
    lhs_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS
    rhs_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS

    def as_operand(raw, bits, signed):
        # type: (int, int, bool) -> int
        # wrap to `bits` bits, then treat a set top bit as negative when
        # the operand is signed
        wrapped = raw % (1 << bits)
        if signed and wrapped >> (bits - 1):
            return wrapped - (1 << bits)
        return wrapped

    for lhs_raw, rhs_raw in test_cases:
        lhs_math = as_operand(lhs_raw, lhs_bits, lhs_signed)
        rhs_math = as_operand(rhs_raw, rhs_bits, rhs_signed)
        prod_value = (lhs_math * rhs_math) % (1 << (lhs_bits + rhs_bits))
        # memory always holds the unsigned (wrapped) bit patterns
        lhs_value = lhs_math % (1 << lhs_bits)
        rhs_value = rhs_math % (1 << rhs_bits)
        subtest = self.subTest(lhs_signed=lhs_signed,
                               rhs_signed=rhs_signed,
                               lhs_value=hex(lhs_value),
                               rhs_value=hex(rhs_value),
                               prod_value=hex(prod_value))
        with subtest, make_state() as state:
            # SSA values are stored as tuples, hence the 1-tuple
            state[code.ptr_in] = (base_addr,)
            operands = (
                (lhs_addr, lhs_value, code.lhs_size_in_words),
                (rhs_addr, rhs_value, code.rhs_size_in_words),
            )
            for addr, value, size_in_words in operands:
                for word_idx in range(size_in_words):
                    word = value >> (GPR_SIZE_IN_BITS * word_idx)
                    word &= GPR_VALUE_MASK
                    state.store(addr + word_idx * GPR_SIZE_IN_BYTES, word)
            code.fn.sim(state)
            # reassemble the product from its words, least significant first
            prod = sum(
                state.load(dest_addr + GPR_SIZE_IN_BYTES * word_idx)
                << (GPR_SIZE_IN_BITS * word_idx)
                for word_idx in range(code.dest_size_in_words))
            self.assertEqual(hex(prod), hex(prod_value),
                             f"failed: state={state}")
1852
def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
    # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
    """Run the pre-register-allocation simulation over a grid of sizes.

    Builds a `ToomCookMul`-based multiplier from `instances` and the given
    signedness, then for every pair of operand sizes (in words) taken from
    {1, 2, 3, 4, 6, 8, 12, 16} runs `tst_toom_mul_sim` on a fixed set of
    operand patterns.
    """
    def mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
        toom_mul = ToomCookMul(
            fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
            rhs_signed=rhs_signed, instances=instances)
        return toom_mul.retval, toom_mul

    # candidate sizes are 2**k and 3 * 2**k, restricted to 1..16 words
    candidates = OSet()  # type: OSet[int]
    for shift in range(6):
        for base in (1, 3):
            candidates.add(base << shift)
    sizes_in_words = OSet(
        size for size in sorted(candidates) if 1 <= size <= 16)

    # repeating-byte patterns (0x8080... and 0x4040...), independent of the
    # operand sizes; they get truncated to the operand width later
    pattern_80 = (0x80 << 2048) // 0xFF
    pattern_40 = (0x40 << 2048) // 0xFF

    for lhs_size_in_words in sizes_in_words:
        for rhs_size_in_words in sizes_in_words:
            lhs_top_bit = 1 << (GPR_SIZE_IN_BITS * lhs_size_in_words - 1)
            rhs_top_bit = 1 << (GPR_SIZE_IN_BITS * rhs_size_in_words - 1)
            with self.subTest(lhs_size_in_words=lhs_size_in_words,
                              rhs_size_in_words=rhs_size_in_words,
                              lhs_signed=lhs_signed,
                              rhs_signed=rhs_signed):
                test_cases = [
                    (-1, -1),
                    (pattern_80, pattern_80),
                    (pattern_40, pattern_80),
                    (pattern_80, pattern_40),
                    (pattern_40, pattern_40),
                    (lhs_top_bit, rhs_top_bit),
                    (1, rhs_top_bit),
                    (lhs_top_bit, 1),
                    (1, 1),
                ]  # type: list[tuple[int, int]]
                code = Mul(mul=mul,
                           lhs_size_in_words=lhs_size_in_words,
                           rhs_size_in_words=rhs_size_in_words)
                self.tst_toom_mul_sim(
                    code=code,
                    lhs_signed=lhs_signed, rhs_signed=rhs_signed,
                    get_state_factory=get_pre_ra_state_factory,
                    test_cases=test_cases)
1897
def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
    # single Toom-2 instance, unsigned lhs * unsigned rhs
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances, lhs_signed=False, rhs_signed=False)
1902
def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
    # single Toom-2 instance, unsigned lhs * signed rhs
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances, lhs_signed=False, rhs_signed=True)
1907
def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
    # single Toom-2 instance, signed lhs * unsigned rhs
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances, lhs_signed=True, rhs_signed=False)
1912
def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
    # single Toom-2 instance, signed lhs * signed rhs
    instances = (ToomCookInstance.make_toom_2(),)
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        instances, lhs_signed=True, rhs_signed=True)
1917
def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
    # the same Toom-2 instance listed four times, unsigned * unsigned
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=False, rhs_signed=False)
1923
def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
    # the same Toom-2 instance listed four times, unsigned * signed
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=False, rhs_signed=True)
1929
def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
    # the same Toom-2 instance listed four times, signed * unsigned
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=True, rhs_signed=False)
1935
def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
    # the same Toom-2 instance listed four times, signed * signed
    toom_2 = ToomCookInstance.make_toom_2()
    self.tst_toom_mul_all_sizes_pre_ra_sim(
        (toom_2,) * 4, lhs_signed=True, rhs_signed=True)
1941
1942
if __name__ == "__main__":
    # executed as a script rather than imported: hand off to unittest's
    # command-line runner
    unittest.main()