working on adding signed multiplication -- needed for toom-cook
[bigint-presentation-code.git] / src / bigint_presentation_code / _tests / test_toom_cook.py
1 import unittest
2 from typing import Callable
3
4 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES,
5 BaseSimState, Fn,
6 GenAsmState, OpKind,
7 PostRASimState,
8 PreRASimState, SSAVal)
9 from bigint_presentation_code.register_allocator import allocate_registers
10 from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul,
11 toom_cook_mul)
12
13
def simple_umul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> SSAVal
    """Call `simple_mul` with both operands flagged as unsigned.

    `fn`, `lhs` and `rhs` are forwarded unchanged; the op name passed to
    `simple_mul` is always "simple_umul".  Returns whatever `simple_mul`
    returns.
    """
    product = simple_mul(
        fn=fn,
        lhs=lhs, lhs_signed=False,
        rhs=rhs, rhs_signed=False,
        name="simple_umul",
    )
    return product
18
19
class Mul:
    """Test harness: a `Fn` that loads two operands, multiplies them and
    stores the product.

    Every memory access is relative to the single `ptr_in` argument, using
    these byte offsets:

    * `dest_offset` (0): the product, `dest_size_in_words` words long
    * `lhs_offset`: the left operand, directly after the product
    * `rhs_offset`: the right operand, directly after the left operand
    """

    def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
        # type: (Callable[[Fn, SSAVal, SSAVal], SSAVal], int, int) -> None
        """Build the whole function into `self.fn`.

        mul: called as `mul(fn, lhs, rhs)` with the two loaded values; the
            value it returns is what gets stored at `dest_offset`.
        lhs_size_in_words: number of words loaded for the left operand.
        rhs_size_in_words: number of words loaded for the right operand.
        """
        super().__init__()
        self.fn = fn = Fn()

        # sizes -- the product gets one word per input word
        self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words
        self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES
        self.lhs_size_in_words = lhs_size_in_words
        self.lhs_size_in_bytes = self.lhs_size_in_words * GPR_SIZE_IN_BYTES
        self.rhs_size_in_words = rhs_size_in_words
        self.rhs_size_in_bytes = self.rhs_size_in_words * GPR_SIZE_IN_BYTES

        # byte offsets from ptr_in: dest first, then lhs, then rhs
        self.dest_offset = 0
        self.lhs_offset = self.dest_size_in_bytes + self.dest_offset
        self.rhs_offset = self.lhs_size_in_bytes + self.lhs_offset

        self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3,
                                       name="ptr_in").outputs[0]

        # load the left operand
        lhs_setvl = fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[lhs_size_in_words],
            maxvl=lhs_size_in_words, name="lhs_setvl")
        load_lhs = fn.append_new_op(
            kind=OpKind.SvLd, immediates=[self.lhs_offset],
            input_vals=[self.ptr_in, lhs_setvl.outputs[0]],
            name="load_lhs", maxvl=lhs_size_in_words)

        # load the right operand -- maxvl follows rhs_size_in_words, the
        # same way load_lhs and rhs_setvl do, instead of a fixed 3
        rhs_setvl = fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[rhs_size_in_words],
            maxvl=rhs_size_in_words, name="rhs_setvl")
        load_rhs = fn.append_new_op(
            kind=OpKind.SvLd, immediates=[self.rhs_offset],
            input_vals=[self.ptr_in, rhs_setvl.outputs[0]],
            name="load_rhs", maxvl=rhs_size_in_words)

        # let the multiplier under test produce the product value
        retval = mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])

        # store the product
        dest_setvl = fn.append_new_op(
            kind=OpKind.SetVLI, immediates=[self.dest_size_in_words],
            maxvl=self.dest_size_in_words, name="dest_setvl")
        fn.append_new_op(
            kind=OpKind.SvStd,
            input_vals=[retval, self.ptr_in, dest_setvl.outputs[0]],
            immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
            name="store_dest")
59
60
61 class TestToomCook(unittest.TestCase):
62 maxDiff = None
63
def test_toom_2_repr(self):
    # Pins the exact repr() of the instance built by
    # ToomCookInstance.make_toom_2().
    expected_repr = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
        "eval_points=(0, 1, POINT_AT_INFINITY), "
        "lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
    )
    actual_repr = repr(ToomCookInstance.make_toom_2())
    # to regenerate expected_repr after an intentional change:
    # print(repr(actual_repr))
    self.assertEqual(actual_repr, expected_repr)
101
def test_toom_2_5_repr(self):
    # Pins the exact repr() of the instance built by
    # ToomCookInstance.make_toom_2_5().
    expected_repr = (
        "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({"
        "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
        "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, "
        "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
        "prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs=EvalOpInput(lhs=2, rhs=0, "
        "poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    actual_repr = repr(ToomCookInstance.make_toom_2_5())
    # to regenerate expected_repr after an intentional change:
    # print(repr(actual_repr))
    self.assertEqual(actual_repr, expected_repr)
166
def test_reversed_toom_2_5_repr(self):
    # Pins the exact repr() of ToomCookInstance.make_toom_2_5().reversed();
    # the expected text has lhs_part_count=2 and rhs_part_count=3.
    expected_repr = (
        "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
        "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
        "EvalOpSub(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
        " rhs_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpAdd(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
        " prod_eval_ops=("
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
        "rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
        "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
        "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
        "rhs="
        "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
        "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
        "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
        "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
        "poly=EvalOpPoly("
        "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
        "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
    )
    actual_repr = repr(ToomCookInstance.make_toom_2_5().reversed())
    # to regenerate expected_repr after an intentional change:
    # print(repr(actual_repr))
    self.assertEqual(actual_repr, expected_repr)
228
def test_simple_mul_192x192_pre_ra_sim(self):
    # Runs the shared 192x192 multiplication check against a fresh
    # PreRASimState.
    # FIXME: finish fixing simple_mul -- skipTest() raises, so nothing
    # below it runs until the skip is removed.
    self.skipTest("WIP")

    def fresh_pre_ra_state(_code):
        # type: (Mul) -> BaseSimState
        # the Mul harness is not needed to build an empty pre-RA state
        empty_state = PreRASimState(ssa_vals={}, memory={})
        return empty_state

    self.tst_simple_mul_192x192_sim(fresh_pre_ra_state)
236
def test_simple_mul_192x192_post_ra_sim(self):
    # Runs the shared 192x192 multiplication check against a PostRASimState
    # built from the result of allocate_registers() on the harness's Fn.
    # FIXME: finish fixing simple_mul -- skipTest() raises, so nothing
    # below it runs until the skip is removed.
    self.skipTest("WIP")

    def state_after_reg_alloc(code):
        # type: (Mul) -> BaseSimState
        return PostRASimState(
            ssa_val_to_loc_map=allocate_registers(code.fn),
            memory={},
            loc_values={},
        )

    self.tst_simple_mul_192x192_sim(state_after_reg_alloc)
246
def tst_simple_mul_192x192_sim(self, create_sim_state):
    # type: (Callable[[Mul], BaseSimState]) -> None
    # Shared body of the *_sim tests: builds a 3x3-word Mul harness around
    # simple_umul, writes both operands into the state returned by
    # create_sim_state(harness), simulates the Fn and compares the bytes
    # found at the destination with the expected product.
    self.skipTest("WIP")  # FIXME: finish fixing simple_mul
    # test multiplying:
    #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
    # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
    # ==
    # int("0x00074736574206e_6f69746163696c70"
    #     "_69746c756d207469_622d3438333e2d32"
    #     "_3931783239312079_7261727469627261", base=0)
    # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
    #                   'little')
    harness = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    sim_state = create_sim_state(harness)

    base_addr = 0x100
    result_addr = base_addr + harness.dest_offset
    lhs_addr = base_addr + harness.lhs_offset
    rhs_addr = base_addr + harness.rhs_offset

    # the value assigned is a one-element tuple, not a bare int
    sim_state[harness.ptr_in] = (base_addr,)

    # operand words, least-significant word first, stored 8 bytes apart
    operands = (
        (lhs_addr, (0x821a2342132c5b57,
                    0x4c6b5f2b19e1a53e,
                    0x000191acb262e15b)),
        (rhs_addr, (0x208a49071aeec507,
                    0xcf1f597598194ae6,
                    0x4a37c0567bcbab53)),
    )
    for operand_addr, words in operands:
        for index, word in enumerate(words):
            sim_state.store(operand_addr + 8 * index, word)

    harness.fn.sim(sim_state)

    byte_count = 6 * GPR_SIZE_IN_BYTES
    message = b"arbitrary 192x192->384-bit multiplication test"
    # zero-pad the message on the right up to the full output width
    want = message.ljust(byte_count, b'\0')
    got = bytes(
        [sim_state.load_byte(result_addr + i) for i in range(byte_count)])
    self.assertEqual(got, want)
279
def test_simple_mul_192x192_ops(self):
    # Pins the repr() of every op, in order, that a 3x3-word Mul harness
    # around simple_umul appends to its Fn.
    self.skipTest("WIP")  # FIXME: finish fixing simple_mul
    harness = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    expected_op_reprs = [
        # --- ops appended by the Mul harness before calling mul() ---
        "Op(kind=OpKind.FuncArgR3, "
        "input_vals=[], "
        "input_uses=(), immediates=[], "
        "outputs=(<ptr_in.outputs[0]: <I64>>,), "
        "name='ptr_in')",
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[3], "
        "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
        "name='lhs_setvl')",
        "Op(kind=OpKind.SvLd, "
        "input_vals=[<ptr_in.outputs[0]: <I64>>, "
        "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<load_lhs.input_uses[0]: <I64>>, "
        "<load_lhs.input_uses[1]: <VL_MAXVL>>), immediates=[48], "
        "outputs=(<load_lhs.outputs[0]: <I64*3>>,), "
        "name='load_lhs')",
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[3], "
        "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
        "name='rhs_setvl')",
        "Op(kind=OpKind.SvLd, "
        "input_vals=[<ptr_in.outputs[0]: <I64>>, "
        "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<load_rhs.input_uses[0]: <I64>>, "
        "<load_rhs.input_uses[1]: <VL_MAXVL>>), immediates=[72], "
        "outputs=(<load_rhs.outputs[0]: <I64*3>>,), "
        "name='load_rhs')",
        # --- ops expected between the loads and the final store ---
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[3], "
        "outputs=(<rhs_setvl2.outputs[0]: <VL_MAXVL>>,), "
        "name='rhs_setvl2')",
        "Op(kind=OpKind.Spread, "
        "input_vals=[<load_rhs.outputs[0]: <I64*3>>, "
        "<rhs_setvl2.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<rhs_spread.input_uses[0]: <I64*3>>, "
        "<rhs_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<rhs_spread.outputs[0]: <I64>>, "
        "<rhs_spread.outputs[1]: <I64>>, "
        "<rhs_spread.outputs[2]: <I64>>), "
        "name='rhs_spread')",
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[3], "
        "outputs=(<lhs_setvl3.outputs[0]: <VL_MAXVL>>,), "
        "name='lhs_setvl3')",
        "Op(kind=OpKind.LI, "
        "input_vals=[], "
        "input_uses=(), immediates=[0], "
        "outputs=(<zero.outputs[0]: <I64>>,), "
        "name='zero')",
        "Op(kind=OpKind.SvMAddEDU, "
        "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
        "<rhs_spread.outputs[0]: <I64>>, "
        "<zero.outputs[0]: <I64>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<mul0.input_uses[0]: <I64*3>>, "
        "<mul0.input_uses[1]: <I64>>, "
        "<mul0.input_uses[2]: <I64>>, "
        "<mul0.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<mul0.outputs[0]: <I64*3>>, "
        "<mul0.outputs[1]: <I64>>), "
        "name='mul0')",
        "Op(kind=OpKind.Spread, "
        "input_vals=[<mul0.outputs[0]: <I64*3>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<mul0_rt_spread.input_uses[0]: <I64*3>>, "
        "<mul0_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<mul0_rt_spread.outputs[0]: <I64>>, "
        "<mul0_rt_spread.outputs[1]: <I64>>, "
        "<mul0_rt_spread.outputs[2]: <I64>>), "
        "name='mul0_rt_spread')",
        "Op(kind=OpKind.SvMAddEDU, "
        "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
        "<rhs_spread.outputs[1]: <I64>>, "
        "<zero.outputs[0]: <I64>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<mul1.input_uses[0]: <I64*3>>, "
        "<mul1.input_uses[1]: <I64>>, "
        "<mul1.input_uses[2]: <I64>>, "
        "<mul1.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<mul1.outputs[0]: <I64*3>>, "
        "<mul1.outputs[1]: <I64>>), "
        "name='mul1')",
        "Op(kind=OpKind.Concat, "
        "input_vals=[<mul0_rt_spread.outputs[1]: <I64>>, "
        "<mul0_rt_spread.outputs[2]: <I64>>, "
        "<mul0.outputs[1]: <I64>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add1_rb_concat.input_uses[0]: <I64>>, "
        "<add1_rb_concat.input_uses[1]: <I64>>, "
        "<add1_rb_concat.input_uses[2]: <I64>>, "
        "<add1_rb_concat.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add1_rb_concat.outputs[0]: <I64*3>>,), "
        "name='add1_rb_concat')",
        "Op(kind=OpKind.ClearCA, "
        "input_vals=[], "
        "input_uses=(), immediates=[], "
        "outputs=(<clear_ca1.outputs[0]: <CA>>,), "
        "name='clear_ca1')",
        "Op(kind=OpKind.SvAddE, "
        "input_vals=[<mul1.outputs[0]: <I64*3>>, "
        "<add1_rb_concat.outputs[0]: <I64*3>>, "
        "<clear_ca1.outputs[0]: <CA>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add1.input_uses[0]: <I64*3>>, "
        "<add1.input_uses[1]: <I64*3>>, "
        "<add1.input_uses[2]: <CA>>, "
        "<add1.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add1.outputs[0]: <I64*3>>, "
        "<add1.outputs[1]: <CA>>), "
        "name='add1')",
        "Op(kind=OpKind.Spread, "
        "input_vals=[<add1.outputs[0]: <I64*3>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add1_rt_spread.input_uses[0]: <I64*3>>, "
        "<add1_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add1_rt_spread.outputs[0]: <I64>>, "
        "<add1_rt_spread.outputs[1]: <I64>>, "
        "<add1_rt_spread.outputs[2]: <I64>>), "
        "name='add1_rt_spread')",
        "Op(kind=OpKind.AddZE, "
        "input_vals=[<mul1.outputs[1]: <I64>>, "
        "<add1.outputs[1]: <CA>>], "
        "input_uses=(<add_hi1.input_uses[0]: <I64>>, "
        "<add_hi1.input_uses[1]: <CA>>), immediates=[], "
        "outputs=(<add_hi1.outputs[0]: <I64>>, "
        "<add_hi1.outputs[1]: <CA>>), "
        "name='add_hi1')",
        "Op(kind=OpKind.SvMAddEDU, "
        "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
        "<rhs_spread.outputs[2]: <I64>>, "
        "<zero.outputs[0]: <I64>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<mul2.input_uses[0]: <I64*3>>, "
        "<mul2.input_uses[1]: <I64>>, "
        "<mul2.input_uses[2]: <I64>>, "
        "<mul2.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<mul2.outputs[0]: <I64*3>>, "
        "<mul2.outputs[1]: <I64>>), "
        "name='mul2')",
        "Op(kind=OpKind.Concat, "
        "input_vals=[<add1_rt_spread.outputs[1]: <I64>>, "
        "<add1_rt_spread.outputs[2]: <I64>>, "
        "<add_hi1.outputs[0]: <I64>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add2_rb_concat.input_uses[0]: <I64>>, "
        "<add2_rb_concat.input_uses[1]: <I64>>, "
        "<add2_rb_concat.input_uses[2]: <I64>>, "
        "<add2_rb_concat.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add2_rb_concat.outputs[0]: <I64*3>>,), "
        "name='add2_rb_concat')",
        "Op(kind=OpKind.ClearCA, "
        "input_vals=[], "
        "input_uses=(), immediates=[], "
        "outputs=(<clear_ca2.outputs[0]: <CA>>,), "
        "name='clear_ca2')",
        "Op(kind=OpKind.SvAddE, "
        "input_vals=[<mul2.outputs[0]: <I64*3>>, "
        "<add2_rb_concat.outputs[0]: <I64*3>>, "
        "<clear_ca2.outputs[0]: <CA>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add2.input_uses[0]: <I64*3>>, "
        "<add2.input_uses[1]: <I64*3>>, "
        "<add2.input_uses[2]: <CA>>, "
        "<add2.input_uses[3]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add2.outputs[0]: <I64*3>>, "
        "<add2.outputs[1]: <CA>>), "
        "name='add2')",
        "Op(kind=OpKind.Spread, "
        "input_vals=[<add2.outputs[0]: <I64*3>>, "
        "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<add2_rt_spread.input_uses[0]: <I64*3>>, "
        "<add2_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<add2_rt_spread.outputs[0]: <I64>>, "
        "<add2_rt_spread.outputs[1]: <I64>>, "
        "<add2_rt_spread.outputs[2]: <I64>>), "
        "name='add2_rt_spread')",
        "Op(kind=OpKind.AddZE, "
        "input_vals=[<mul2.outputs[1]: <I64>>, "
        "<add2.outputs[1]: <CA>>], "
        "input_uses=(<add_hi2.input_uses[0]: <I64>>, "
        "<add_hi2.input_uses[1]: <CA>>), immediates=[], "
        "outputs=(<add_hi2.outputs[0]: <I64>>, "
        "<add_hi2.outputs[1]: <CA>>), "
        "name='add_hi2')",
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[6], "
        "outputs=(<retval_setvl.outputs[0]: <VL_MAXVL>>,), "
        "name='retval_setvl')",
        "Op(kind=OpKind.Concat, "
        "input_vals=[<mul0_rt_spread.outputs[0]: <I64>>, "
        "<add1_rt_spread.outputs[0]: <I64>>, "
        "<add2_rt_spread.outputs[0]: <I64>>, "
        "<add2_rt_spread.outputs[1]: <I64>>, "
        "<add2_rt_spread.outputs[2]: <I64>>, "
        "<add_hi2.outputs[0]: <I64>>, "
        "<retval_setvl.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<concat_retval.input_uses[0]: <I64>>, "
        "<concat_retval.input_uses[1]: <I64>>, "
        "<concat_retval.input_uses[2]: <I64>>, "
        "<concat_retval.input_uses[3]: <I64>>, "
        "<concat_retval.input_uses[4]: <I64>>, "
        "<concat_retval.input_uses[5]: <I64>>, "
        "<concat_retval.input_uses[6]: <VL_MAXVL>>), immediates=[], "
        "outputs=(<concat_retval.outputs[0]: <I64*6>>,), "
        "name='concat_retval')",
        # --- ops appended by the Mul harness after mul() returns ---
        "Op(kind=OpKind.SetVLI, "
        "input_vals=[], "
        "input_uses=(), immediates=[6], "
        "outputs=(<dest_setvl.outputs[0]: <VL_MAXVL>>,), "
        "name='dest_setvl')",
        "Op(kind=OpKind.SvStd, "
        "input_vals=[<concat_retval.outputs[0]: <I64*6>>, "
        "<ptr_in.outputs[0]: <I64>>, "
        "<dest_setvl.outputs[0]: <VL_MAXVL>>], "
        "input_uses=(<store_dest.input_uses[0]: <I64*6>>, "
        "<store_dest.input_uses[1]: <I64>>, "
        "<store_dest.input_uses[2]: <VL_MAXVL>>), immediates=[0], "
        "outputs=(), "
        "name='store_dest')",
    ]
    actual_op_reprs = list(map(repr, harness.fn.ops))
    self.assertEqual(actual_op_reprs, expected_op_reprs)
510
511 def test_simple_mul_192x192_reg_alloc(self):
512 self.skipTest("WIP") # FIXME: finish fixing simple_mul
513 code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
514 fn = code.fn
515 assigned_registers = allocate_registers(fn)
516 self.assertEqual(
517 repr(assigned_registers), "{"
518 "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
519 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
520 "<store_dest.inp1.copy.outputs[0]: <I64>>: "
521 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
522 "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
523 "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
524 "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
525 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
526 "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
527 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
528 "<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
529 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
530 "<concat_retval.out0.setvl.outputs[0]: <VL_MAXVL>>: "
531 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
532 "<concat_retval.outputs[0]: <I64*6>>: "
533 "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
534 "<concat_retval.inp0.copy.outputs[0]: <I64>>: "
535 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
536 "<concat_retval.inp1.copy.outputs[0]: <I64>>: "
537 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
538 "<concat_retval.inp2.copy.outputs[0]: <I64>>: "
539 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
540 "<concat_retval.inp3.copy.outputs[0]: <I64>>: "
541 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
542 "<concat_retval.inp4.copy.outputs[0]: <I64>>: "
543 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
544 "<concat_retval.inp5.copy.outputs[0]: <I64>>: "
545 "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
546 "<concat_retval.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
547 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
548 "<retval_setvl.outputs[0]: <VL_MAXVL>>: "
549 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
550 "<add_hi2.out0.copy.outputs[0]: <I64>>: "
551 "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
552 "<clear_ca2.outputs[0]: <CA>>: "
553 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
554 "<add2.outputs[1]: <CA>>: "
555 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
556 "<add_hi2.outputs[1]: <CA>>: "
557 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
558 "<add_hi2.outputs[0]: <I64>>: "
559 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
560 "<add_hi2.inp0.copy.outputs[0]: <I64>>: "
561 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
562 "<add2_rt_spread.out2.copy.outputs[0]: <I64>>: "
563 "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
564 "<add2_rt_spread.out1.copy.outputs[0]: <I64>>: "
565 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
566 "<add2_rt_spread.out0.copy.outputs[0]: <I64>>: "
567 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
568 "<add2_rt_spread.outputs[0]: <I64>>: "
569 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
570 "<add2_rt_spread.outputs[1]: <I64>>: "
571 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
572 "<add2_rt_spread.outputs[2]: <I64>>: "
573 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
574 "<add2_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
575 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
576 "<add2_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
577 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
578 "<add2_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
579 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
580 "<add2.out0.copy.outputs[0]: <I64*3>>: "
581 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
582 "<add2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
583 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
584 "<add2.outputs[0]: <I64*3>>: "
585 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
586 "<add2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
587 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
588 "<add2.inp1.copy.outputs[0]: <I64*3>>: "
589 "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
590 "<add2.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
591 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
592 "<add2.inp0.copy.outputs[0]: <I64*3>>: "
593 "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
594 "<add2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
595 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
596 "<add2_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
597 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
598 "<add2_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
599 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
600 "<add2_rb_concat.outputs[0]: <I64*3>>: "
601 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
602 "<add2_rb_concat.inp0.copy.outputs[0]: <I64>>: "
603 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
604 "<add2_rb_concat.inp1.copy.outputs[0]: <I64>>: "
605 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
606 "<add2_rb_concat.inp2.copy.outputs[0]: <I64>>: "
607 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
608 "<add2_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
609 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
610 "<mul2.out1.copy.outputs[0]: <I64>>: "
611 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
612 "<mul2.out0.copy.outputs[0]: <I64*3>>: "
613 "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
614 "<mul2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
615 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
616 "<mul2.inp2.copy.outputs[0]: <I64>>: "
617 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
618 "<mul2.outputs[1]: <I64>>: "
619 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
620 "<mul2.outputs[0]: <I64*3>>: "
621 "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
622 "<mul2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
623 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
624 "<mul2.inp1.copy.outputs[0]: <I64>>: "
625 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
626 "<mul2.inp0.copy.outputs[0]: <I64*3>>: "
627 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
628 "<mul2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
629 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
630 "<add_hi1.out0.copy.outputs[0]: <I64>>: "
631 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
632 "<clear_ca1.outputs[0]: <CA>>: "
633 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
634 "<add1.outputs[1]: <CA>>: "
635 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
636 "<add_hi1.outputs[1]: <CA>>: "
637 "Loc(kind=LocKind.CA, start=0, reg_len=1), "
638 "<add_hi1.outputs[0]: <I64>>: "
639 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
640 "<add_hi1.inp0.copy.outputs[0]: <I64>>: "
641 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
642 "<add1_rt_spread.out2.copy.outputs[0]: <I64>>: "
643 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
644 "<add1_rt_spread.out1.copy.outputs[0]: <I64>>: "
645 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
646 "<add1_rt_spread.out0.copy.outputs[0]: <I64>>: "
647 "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
648 "<add1_rt_spread.outputs[0]: <I64>>: "
649 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
650 "<add1_rt_spread.outputs[1]: <I64>>: "
651 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
652 "<add1_rt_spread.outputs[2]: <I64>>: "
653 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
654 "<add1_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
655 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
656 "<add1_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
657 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
658 "<add1_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
659 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
660 "<add1.out0.copy.outputs[0]: <I64*3>>: "
661 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
662 "<add1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
663 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
664 "<add1.outputs[0]: <I64*3>>: "
665 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
666 "<add1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
667 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
668 "<add1.inp1.copy.outputs[0]: <I64*3>>: "
669 "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
670 "<add1.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
671 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
672 "<add1.inp0.copy.outputs[0]: <I64*3>>: "
673 "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
674 "<add1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
675 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
676 "<add1_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
677 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
678 "<add1_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
679 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
680 "<add1_rb_concat.outputs[0]: <I64*3>>: "
681 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
682 "<add1_rb_concat.inp0.copy.outputs[0]: <I64>>: "
683 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
684 "<add1_rb_concat.inp1.copy.outputs[0]: <I64>>: "
685 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
686 "<add1_rb_concat.inp2.copy.outputs[0]: <I64>>: "
687 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
688 "<add1_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
689 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
690 "<mul1.out1.copy.outputs[0]: <I64>>: "
691 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
692 "<mul1.out0.copy.outputs[0]: <I64*3>>: "
693 "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
694 "<mul1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
695 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
696 "<mul1.inp2.copy.outputs[0]: <I64>>: "
697 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
698 "<mul1.outputs[1]: <I64>>: "
699 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
700 "<mul1.outputs[0]: <I64*3>>: "
701 "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
702 "<mul1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
703 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
704 "<mul1.inp1.copy.outputs[0]: <I64>>: "
705 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
706 "<mul1.inp0.copy.outputs[0]: <I64*3>>: "
707 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
708 "<mul1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
709 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
710 "<mul0_rt_spread.out2.copy.outputs[0]: <I64>>: "
711 "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
712 "<mul0_rt_spread.out1.copy.outputs[0]: <I64>>: "
713 "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
714 "<mul0_rt_spread.out0.copy.outputs[0]: <I64>>: "
715 "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
716 "<mul0_rt_spread.outputs[0]: <I64>>: "
717 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
718 "<mul0_rt_spread.outputs[1]: <I64>>: "
719 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
720 "<mul0_rt_spread.outputs[2]: <I64>>: "
721 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
722 "<mul0_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
723 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
724 "<mul0_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
725 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
726 "<mul0_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
727 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
728 "<mul0.out1.copy.outputs[0]: <I64>>: "
729 "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
730 "<mul0.out0.copy.outputs[0]: <I64*3>>: "
731 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
732 "<mul0.out0.setvl.outputs[0]: <VL_MAXVL>>: "
733 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
734 "<mul0.inp2.copy.outputs[0]: <I64>>: "
735 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
736 "<mul0.outputs[1]: <I64>>: "
737 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
738 "<mul0.outputs[0]: <I64*3>>: "
739 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
740 "<mul0.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
741 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
742 "<mul0.inp1.copy.outputs[0]: <I64>>: "
743 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
744 "<mul0.inp0.copy.outputs[0]: <I64*3>>: "
745 "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
746 "<mul0.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
747 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
748 "<zero.out0.copy.outputs[0]: <I64>>: "
749 "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
750 "<zero.outputs[0]: <I64>>: "
751 "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
752 "<lhs_setvl3.outputs[0]: <VL_MAXVL>>: "
753 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
754 "<rhs_spread.out2.copy.outputs[0]: <I64>>: "
755 "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
756 "<rhs_spread.out1.copy.outputs[0]: <I64>>: "
757 "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
758 "<rhs_spread.out0.copy.outputs[0]: <I64>>: "
759 "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
760 "<rhs_spread.outputs[0]: <I64>>: "
761 "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
762 "<rhs_spread.outputs[1]: <I64>>: "
763 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
764 "<rhs_spread.outputs[2]: <I64>>: "
765 "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
766 "<rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
767 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
768 "<rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
769 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
770 "<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
771 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
772 "<rhs_setvl2.outputs[0]: <VL_MAXVL>>: "
773 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
774 "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
775 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
776 "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
777 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
778 "<load_rhs.outputs[0]: <I64*3>>: "
779 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
780 "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
781 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
782 "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
783 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
784 "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
785 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
786 "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
787 "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
788 "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
789 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
790 "<load_lhs.outputs[0]: <I64*3>>: "
791 "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
792 "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
793 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
794 "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
795 "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
796 "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
797 "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
798 "<ptr_in.out0.copy.outputs[0]: <I64>>: "
799 "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
800 "<ptr_in.outputs[0]: <I64>>: "
801 "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
802 "}")
803
def test_simple_mul_192x192_asm(self):
    """Compare the assembly emitted for a 3-word by 3-word ``simple_umul``
    against a golden instruction listing.

    The ``skipTest`` call comes first, so -- assuming the standard
    ``unittest.TestCase.skipTest`` -- nothing below it currently runs; the
    rest documents the intended check once the FIXME is resolved.
    """
    self.skipTest("WIP")  # FIXME: finish fixing simple_mul

    # Golden output: one assembly instruction string per list entry, in the
    # exact order `gen_asm` is expected to append them.
    expected_asm = [
        'or 23, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 23, 23',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *3, 48(6)',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *20, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 6, 23, 23',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.ld *3, 72(6)',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *5, *3, *3',
        'or 4, 5, 5',
        'or 14, 6, 6',
        'or 19, 7, 7',
        'setvl 0, 0, 3, 0, 1, 1',
        'addi 3, 0, 0',
        'or 18, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *20, *20',
        'or 7, 4, 4',
        'or 6, 18, 18',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *3, *8, 7, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 15, 6, 6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 17, 3, 3',
        'or 12, 4, 4',
        'or 11, 5, 5',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *20, *20',
        'or 7, 14, 14',
        'or 3, 18, 18',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *4, *8, 7, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *6, *4, *4',
        'or 14, 3, 3',
        'or 3, 12, 12',
        'or 4, 11, 11',
        'or 5, 15, 15',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *9, *6, *6',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *6, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.adde *3, *9, *6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 16, 3, 3',
        'or 15, 4, 4',
        'or 12, 5, 5',
        'or 4, 14, 14',
        'addze *3, *4',
        'or 11, 3, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *8, *20, *20',
        'or 7, 19, 19',
        'or 3, 18, 18',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.maddedu *4, *8, 7, 3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or/mrr *6, *4, *4',
        'or 14, 3, 3',
        'or 3, 15, 15',
        'or 4, 12, 12',
        'or 5, 11, 11',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'addic 0, 0, 0',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *9, *6, *6',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.or *6, *3, *3',
        'setvl 0, 0, 3, 0, 1, 1',
        'sv.adde *3, *9, *6',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'setvl 0, 0, 3, 0, 1, 1',
        'or 12, 3, 3',
        'or 11, 4, 4',
        'or 10, 5, 5',
        'or 4, 14, 14',
        'addze *3, *4',
        'or 9, 3, 3',
        'setvl 0, 0, 6, 0, 1, 1',
        'or 3, 17, 17',
        'or 4, 16, 16',
        'or 5, 12, 12',
        'or 6, 11, 11',
        'or 7, 10, 10',
        'or 8, 9, 9',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.or/mrr *4, *3, *3',
        'or 3, 23, 23',
        'setvl 0, 0, 6, 0, 1, 1',
        'sv.std *4, 0(3)',
    ]

    # Build the multiply function, allocate registers for it, then emit
    # assembly into a fresh GenAsmState and compare with the listing above.
    mul_code = Mul(
        mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
    reg_assignments = allocate_registers(mul_code.fn)
    asm_state = GenAsmState(reg_assignments)
    mul_code.fn.gen_asm(asm_state)
    self.assertEqual(asm_state.output, expected_asm)
923
def test_toom_2_mul_256x256_asm(self):
    """Check the assembly emitted for an unsigned 256x256-bit multiply
    built with ``toom_cook_mul`` and two Toom-2 instances.

    Work in progress: the test is skipped up front and the expected
    listing is still an empty placeholder.
    """
    self.skipTest("WIP")  # FIXME: finish
    TOOM_2 = ToomCookInstance.make_toom_2()
    # Two Toom-2 instances are handed to toom_cook_mul -- presumably one
    # per recursion level; confirm against toom_cook_mul.
    instances = TOOM_2, TOOM_2

    def mul(fn, lhs, rhs):
        # type: (Fn, SSAVal, SSAVal) -> SSAVal
        return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
                             rhs_signed=False, instances=instances)

    # 256-bit operands need 4 words each: the sibling 192x192 test uses
    # 3 words, i.e. words are 64 bits.  This previously passed 3, which
    # contradicted the test name by building a 192x192 multiply.
    code = Mul(mul=mul, lhs_size_in_words=4, rhs_size_in_words=4)
    fn = code.fn
    assigned_registers = allocate_registers(fn)
    gen_asm_state = GenAsmState(assigned_registers)
    fn.gen_asm(gen_asm_state)
    # Placeholder: the expected instruction listing has not been filled in.
    self.assertEqual(gen_asm_state.output, [
    ])
940
941
# Run every test in this module when the file is executed directly as a
# script rather than imported.
if __name__ == "__main__":
    unittest.main()