a7c5450f6f88ec5e3f4d204ad5449d7eccb669e4
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 """
2 Toom-Cook multiplication algorithm generator for SVP64
3 """
4 import math
5 from abc import abstractmethod
6 from enum import Enum
7 from fractions import Fraction
8 from typing import Iterable, Mapping, Tuple, Union
9
10 from cached_property import cached_property
11 from nmutil.plain_data import plain_data
12
13 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, BaseTy, Fn,
14 OpKind, SSAVal, Ty)
15 from bigint_presentation_code.matrix import Matrix
16 from bigint_presentation_code.type_util import Literal, final
17 from bigint_presentation_code.util import InternedMeta
18
19
@final
class PointAtInfinity(Enum):
    """Marker for the special Toom-Cook evaluation point "infinity".

    `ToomCookInstance.make_eval_matrix` turns this point into a row that
    selects the highest-degree coefficient of the evaluated polynomial.
    """
    POINT_AT_INFINITY = "POINT_AT_INFINITY"

    def __repr__(self):
        # print just the bare member name instead of Enum's `<...: '...'>`
        return f"{self.name}"
26
27
# module-level shorthand for the single `PointAtInfinity` member
POINT_AT_INFINITY = PointAtInfinity["POINT_AT_INFINITY"]

# Everything `EvalOpPoly(coeffs)` accepts as `coeffs`: a mapping from variable
# index (None for the constant term) to coefficient, another polynomial, a
# bare constant, or None (meaning: build from the `const_coeff`/`var_coeffs`
# keyword arguments instead).
_EvalOpPolyCoefficients = Union[
    "Mapping[int | None, Fraction | int]",
    "EvalOpPoly",
    Fraction,
    int,
    None,
]
32
33
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class EvalOpPoly:
    """Polynomial of degree at most one over input variables x_0, x_1, ...

    Its value is `const_coeff + sum(var_coeffs[i] * x_i)`, with every
    coefficient stored as a `Fraction`. Trailing zero entries are always
    stripped from `var_coeffs`, so equal polynomials have identical fields
    and therefore compare and hash equal. Multiplying two non-constant
    polynomials raises `ValueError` since the result isn't representable.
    """
    __slots__ = "const_coeff", "var_coeffs"

    def __init__(
        self, coeffs=None,  # type: _EvalOpPolyCoefficients
        const_coeff=None,  # type: Fraction | int | None
        var_coeffs=(),  # type: Iterable[Fraction | int] | None
    ):
        """Build either from `coeffs` (another polynomial, a constant, or a
        mapping from variable index -- None for the constant term -- to
        coefficient) or from `const_coeff`/`var_coeffs`; mixing the two
        styles raises `ValueError`.
        """
        if coeffs is None:
            if var_coeffs is None:
                coeff_list = []  # type: list[Fraction]
            else:
                coeff_list = [Fraction(v) for v in var_coeffs]
        else:
            if const_coeff is not None or var_coeffs != ():
                raise ValueError(
                    "can't specify const_coeff or "
                    "var_coeffs along with coeffs")
            if isinstance(coeffs, EvalOpPoly):
                # already canonical -- reuse its fields verbatim
                self.const_coeff = coeffs.const_coeff
                self.var_coeffs = coeffs.var_coeffs
                return
            if isinstance(coeffs, (int, Fraction)):
                const_coeff = Fraction(coeffs)
                coeff_list = []
            else:
                const_coeff, coeff_list = self._split_mapping(coeffs)
        # canonical form: no trailing zero coefficients
        while coeff_list and coeff_list[-1] == 0:
            coeff_list.pop()
        self.const_coeff = Fraction(0 if const_coeff is None else const_coeff)
        self.var_coeffs = tuple(coeff_list)

    @staticmethod
    def _split_mapping(coeffs):
        # type: (Mapping[int | None, Fraction | int]) -> tuple[Fraction | int, list[Fraction]]
        """split a `{var_index or None: coeff}` mapping into
        `(constant term, dense list of variable coefficients)`"""
        const = 0  # type: Fraction | int
        by_index = []  # type: list[Fraction]
        for var, value in coeffs.items():
            if value == 0:
                continue
            value = Fraction(value)
            if var is None:
                const = value
                continue
            if var < 0:
                raise ValueError("invalid variable index")
            missing = var + 1 - len(by_index)
            if missing > 0:
                by_index.extend([Fraction()] * missing)
            by_index[var] = value
        return const, by_index

    def __add__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        other = EvalOpPoly(rhs)
        width = max(len(self.var_coeffs), len(other.var_coeffs))
        # `coeff()` yields 0 past the end, so the shorter side pads itself
        summed = [self.coeff(i) + other.coeff(i) for i in range(width)]
        return EvalOpPoly(const_coeff=self.const_coeff + other.const_coeff,
                          var_coeffs=summed)

    __radd__ = __add__

    @property
    def coefficients(self):
        # type: () -> dict[int | None, Fraction]
        """nonzero coefficients keyed by variable index (None = constant)"""
        retval = {}  # type: dict[int | None, Fraction]
        if self.const_coeff != 0:
            retval[None] = self.const_coeff
        retval.update(
            (var, c) for var, c in enumerate(self.var_coeffs) if c != 0)
        return retval

    @property
    def is_const(self):
        # type: () -> bool
        """True when no variable has a nonzero coefficient"""
        return not self.var_coeffs

    def coeff(self, var):
        # type: (int | None) -> Fraction
        """coefficient of variable `var` (None selects the constant term)"""
        if var is None:
            return self.const_coeff
        if var < 0:
            raise ValueError("invalid variable index")
        if var >= len(self.var_coeffs):
            return Fraction()
        return self.var_coeffs[var]

    def __neg__(self):
        return EvalOpPoly(const_coeff=-self.const_coeff,
                          var_coeffs=[-c for c in self.var_coeffs])

    def __sub__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        return self + (-rhs)

    def __rsub__(self, lhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        return (-self) + lhs

    def __mul__(self, rhs):
        # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
        poly, scale = self, rhs
        if isinstance(scale, EvalOpPoly):
            # at least one factor must be constant; make `scale` that factor
            if poly.is_const:
                poly, scale = scale, poly
            if not scale.is_const:
                raise ValueError("can't represent exponents larger than one")
            scale = scale.const_coeff
        if scale == 0:
            return EvalOpPoly()
        return EvalOpPoly(const_coeff=poly.const_coeff * scale,
                          var_coeffs=[c * scale for c in poly.var_coeffs])

    __rmul__ = __mul__

    def __truediv__(self, rhs):
        # type: (int | Fraction) -> EvalOpPoly
        if rhs == 0:
            raise ZeroDivisionError()
        quotients = [c / rhs for c in self.var_coeffs]
        return EvalOpPoly(const_coeff=self.const_coeff / rhs,
                          var_coeffs=quotients)

    def __repr__(self):
        return f"EvalOpPoly({self.coefficients!r})"
162
163
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpValueRange:
    """Bounds on the integer value `eval_op` can produce given `inputs`.

    The bounds come from evaluating `poly` (the op's linear polynomial) over
    each input's `[min_value, max_value]` interval. `output_size` is the
    smallest number of GPR-sized words whose two's-complement range (signed
    if `is_signed`, else unsigned) holds every value in the range.
    """
    __slots__ = ("eval_op", "inputs", "min_value", "max_value",
                 "is_signed", "output_size")

    def __init__(self, eval_op, inputs):
        # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
        super().__init__()
        self.eval_op = eval_op
        self.inputs = inputs
        poly = self.poly
        min_value = max_value = poly.const_coeff
        for var, coeff in enumerate(poly.var_coeffs):
            if coeff == 0:
                continue
            # a negative coefficient swaps which input bound produces which
            # output bound
            term_min = self.inputs[var].min_value * coeff
            term_max = self.inputs[var].max_value * coeff
            if term_min > term_max:
                term_min, term_max = term_max, term_min
            min_value += term_min
            max_value += term_max
        # output values are always integers, so eliminate any fractional part
        # as impossible.
        self.min_value = math.ceil(min_value)  # exclude fractional part
        self.max_value = math.floor(max_value)  # exclude fractional part
        # decide signedness from the integer bound: a raw bound such as -1/2
        # still can't produce a negative integer
        self.is_signed = self.min_value < 0
        self.output_size = self._words_needed(
            self.min_value, self.max_value, self.is_signed)

    @staticmethod
    def _words_needed(min_value, max_value, is_signed):
        # type: (int, int, bool) -> int
        """smallest word count whose signed/unsigned range holds
        `[min_value, max_value]`"""
        words = 1
        while True:
            # recompute the exact bounds for this width; shifting the
            # previous bounds left would leave the low word of the maximum
            # as zeros instead of ones
            bits = words * GPR_SIZE_IN_BITS
            if is_signed:
                lo = -1 << (bits - 1)
                hi = (1 << (bits - 1)) - 1
            else:
                lo = 0
                hi = (1 << bits) - 1
            if lo <= min_value and max_value <= hi:
                return words
            words += 1

    @property
    def poly(self):
        # type: () -> EvalOpPoly
        """the linear polynomial `eval_op` computes (an int is a constant)"""
        # A plain property on purpose: `cached_property` caches into the
        # instance `__dict__`, which this `__slots__` class doesn't have.
        if isinstance(self.eval_op, int):
            return EvalOpPoly(const_coeff=self.eval_op)
        return self.eval_op.poly
208
209
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpGenIrOutput:
    """Generated IR value for an `EvalOp` (or integer constant), paired with
    the `EvalOpValueRange` describing the values it can hold.

    `output` must be exactly `value_range.output_size` registers long. The
    remaining properties are shortcuts into `value_range`.
    """
    __slots__ = "output", "value_range"

    def __init__(self, output, value_range):
        # type: (SSAVal, EvalOpValueRange) -> None
        super().__init__()
        size_matches = output.ty.reg_len == value_range.output_size
        if not size_matches:
            raise ValueError("wrong output size")
        self.output = output
        self.value_range = value_range

    @property
    def eval_op(self):
        # type: () -> EvalOp | int
        """the op (or integer constant) this output was generated for"""
        return self.value_range.eval_op

    @property
    def inputs(self):
        # type: () -> tuple[EvalOpGenIrInput, ...]
        """the inputs the value range was computed against"""
        return self.value_range.inputs

    @property
    def min_value(self):
        # type: () -> int
        """smallest value `output` can hold"""
        return self.value_range.min_value

    @property
    def max_value(self):
        # type: () -> int
        """largest value `output` can hold"""
        return self.value_range.max_value

    @property
    def is_signed(self):
        # type: () -> bool
        """whether `output` must be interpreted as signed"""
        return self.value_range.is_signed

    @property
    def output_size(self):
        # type: () -> int
        """register count of `output`"""
        return self.value_range.output_size

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        return self.output.current_debugging_value
265
266
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpGenIrInput:
    """One input to eval-op IR generation: an I64 SSA value, its signedness,
    and the range of values it may hold.

    If `is_signed` is None it is inferred from `min_value`, so both bounds
    must then be given. Missing bounds default to the full range
    representable in `ssa_val.ty.reg_len` GPR-sized words.
    """
    __slots__ = "ssa_val", "is_signed", "min_value", "max_value"

    def __init__(self, ssa_val, is_signed, min_value=None, max_value=None):
        # type: (SSAVal, bool | None, int | None, int | None) -> None
        super().__init__()
        self.ssa_val = ssa_val
        if ssa_val.base_ty != BaseTy.I64:
            raise ValueError("input must have a base_ty of BaseTy.I64")
        if is_signed is None:
            both_given = min_value is not None and max_value is not None
            if not both_given:
                raise ValueError("must specify either is_signed or both "
                                 "min_value and max_value")
            is_signed = min_value < 0
        self.is_signed = is_signed
        bit_width = ssa_val.ty.reg_len * GPR_SIZE_IN_BITS
        if is_signed:
            lowest = -1 << (bit_width - 1)
            highest = (1 << (bit_width - 1)) - 1
        else:
            lowest = 0
            highest = (1 << bit_width) - 1
        self.min_value = lowest if min_value is None else min_value
        self.max_value = highest if max_value is None else max_value
        if self.min_value > self.max_value:
            raise ValueError("invalid value range")

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        return self.ssa_val.current_debugging_value
312
313
@plain_data(frozen=True)
@final
class EvalOpGenIrState:
    """Emits IR for `EvalOp`s into `fn`.

    `inputs[i]` describes the value `EvalOpInput(i)` refers to. Results are
    memoized in `outputs_map`, so asking again for the same op (or integer
    constant) returns the previously generated output.
    """
    __slots__ = "fn", "inputs", "outputs_map"

    def __init__(self, fn, inputs):
        # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
        super().__init__()
        self.fn = fn
        self.inputs = tuple(inputs)
        self.outputs_map = {}  # type: dict[EvalOp | int, EvalOpGenIrOutput]

    def get_output(self, eval_op):
        # type: (EvalOp | int) -> EvalOpGenIrOutput
        """return the (possibly cached) generated output for `eval_op`"""
        if eval_op in self.outputs_map:
            return self.outputs_map[eval_op]
        value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs)
        if isinstance(eval_op, int):
            retval = self._make_const_output(eval_op, value_range)
        else:
            retval = eval_op.make_output(state=self,
                                         output_value_range=value_range)
        if retval.value_range != value_range:
            raise ValueError("wrong value_range")
        # if an entry was stored in the meantime, the earlier one wins
        return self.outputs_map.setdefault(eval_op, retval)

    def _make_const_output(self, value, value_range):
        # type: (int, EvalOpValueRange) -> EvalOpGenIrOutput
        """load integer constant `value` and resize it to fit `value_range`"""
        load = self.fn.append_new_op(OpKind.LI, immediates=[value],
                                     name=f"li_{value}")
        resized = cast_to_size(
            fn=self.fn, ssa_val=load.outputs[0],
            dest_size=value_range.output_size,
            src_signed=value_range.is_signed, name=f"cast_{value}")
        return EvalOpGenIrOutput(output=resized, value_range=value_range)
346
347
@plain_data(frozen=True, unsafe_hash=True)
class EvalOp(metaclass=InternedMeta):
    """Base class for nodes of a Toom-Cook evaluation/interpolation
    expression.

    `lhs`/`rhs` are the operands (another `EvalOp` or an `int` constant).
    `poly` is the linear `EvalOpPoly` the node computes; `__init__` builds it
    once via the subclass's `_make_poly`. Subclasses also implement
    `make_output` to emit IR for the node.
    NOTE(review): `InternedMeta` isn't visibly `ABCMeta`-derived, so the
    `@abstractmethod`s may not be enforced at instantiation -- confirm.
    """
    __slots__ = "lhs", "rhs", "poly"

    def __init__(self, lhs, rhs):
        # type: (EvalOp | int, EvalOp | int) -> None
        super().__init__()
        self.lhs = lhs
        self.rhs = rhs
        # computed eagerly so `poly` is an ordinary stored field
        self.poly = self._make_poly()

    @staticmethod
    def _operand_poly(operand):
        # type: (EvalOp | int) -> EvalOpPoly
        """polynomial of one operand: an int becomes a constant polynomial"""
        if isinstance(operand, int):
            return EvalOpPoly(operand)
        return operand.poly

    @property
    def lhs_poly(self):
        # type: () -> EvalOpPoly
        return self._operand_poly(self.lhs)

    @property
    def rhs_poly(self):
        # type: () -> EvalOpPoly
        return self._operand_poly(self.rhs)

    @abstractmethod
    def _make_poly(self):
        # type: () -> EvalOpPoly
        """compute this node's polynomial from its operands"""
        ...

    @abstractmethod
    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        """emit IR for this node into `state.fn`"""
        ...
382
383
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpAdd(EvalOp):
    """`lhs + rhs`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        return self.lhs_poly + self.rhs_poly

    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        fn = state.fn
        size = output_value_range.output_size
        # widen both operands to the result size (lhs fully first, then rhs)
        widened = []
        for operand, cast_name in ((self.lhs, "add_lhs_cast"),
                                   (self.rhs, "add_rhs_cast")):
            generated = state.get_output(operand)
            widened.append(cast_to_size(
                fn=fn, ssa_val=generated.output, dest_size=size,
                src_signed=generated.is_signed, name=cast_name))
        lhs_output, rhs_output = widened
        setvl = fn.append_new_op(
            OpKind.SetVLI, immediates=[size], name="setvl", maxvl=size)
        # carry-in of zero starts a plain multi-word add
        clear_ca = fn.append_new_op(OpKind.ClearCA, name="clear_ca")
        add = fn.append_new_op(
            OpKind.SvAddE,
            input_vals=[lhs_output, rhs_output, clear_ca.outputs[0],
                        setvl.outputs[0]],
            maxvl=size, name="add")
        return EvalOpGenIrOutput(output=add.outputs[0],
                                 value_range=output_value_range)
415
416
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpSub(EvalOp):
    """`lhs - rhs`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        return self.lhs_poly - self.rhs_poly

    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        lhs = state.get_output(self.lhs)
        lhs_output = cast_to_size(
            fn=state.fn, ssa_val=lhs.output,
            dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
            name="sub_lhs_cast")
        rhs = state.get_output(self.rhs)
        rhs_output = cast_to_size(
            fn=state.fn, ssa_val=rhs.output,
            dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
            name="sub_rhs_cast")
        setvl = state.fn.append_new_op(
            OpKind.SetVLI, immediates=[output_value_range.output_size],
            name="setvl", maxvl=output_value_range.output_size)
        set_ca = state.fn.append_new_op(OpKind.SetCA, name="set_ca")
        # PowerISA subfe computes RT = ~RA + RB + CA, i.e. RB - RA when CA is
        # set, so passing (rhs, lhs) produces lhs - rhs.
        sub = state.fn.append_new_op(
            OpKind.SvSubFE, input_vals=[
                rhs_output, lhs_output, set_ca.outputs[0], setvl.outputs[0]],
            maxvl=output_value_range.output_size, name="sub")
        return EvalOpGenIrOutput(
            output=sub.outputs[0], value_range=output_value_range)
448
449
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpMul(EvalOp):
    """`lhs * rhs`, where `rhs` must be an `int` constant.

    IR generation (`make_output`) is not implemented yet.
    """
    __slots__ = ()
    rhs: int

    def _make_poly(self):
        # type: () -> EvalOpPoly
        factor = self.rhs
        if isinstance(factor, int):  # type: ignore
            return self.lhs_poly * factor
        raise TypeError("invalid rhs type")

    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        raise NotImplementedError  # FIXME: finish
465
466
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpExactDiv(EvalOp):
    """`lhs / rhs`, where `rhs` must be an `int` constant.

    As the name says, the quotient is presumably expected to be exact (no
    remainder). IR generation (`make_output`) is not implemented yet.
    """
    __slots__ = ()
    rhs: int

    def _make_poly(self):
        # type: () -> EvalOpPoly
        divisor = self.rhs
        if isinstance(divisor, int):  # type: ignore
            return self.lhs_poly / divisor
        raise TypeError("invalid rhs type")

    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        raise NotImplementedError  # FIXME: finish
482
483
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpInput(EvalOp):
    """Leaf op standing for input part number `lhs` (`part_index`).

    `rhs` exists only to fit the `EvalOp` shape and must be 0.
    """
    __slots__ = ()
    lhs: int
    rhs: Literal[0]

    def __init__(self, lhs, rhs=0):
        # type: (int, int) -> None
        if lhs < 0:
            raise ValueError("Input part_index (lhs) must be >= 0")
        if rhs != 0:
            raise ValueError("Input rhs must be 0")
        super().__init__(lhs, rhs)

    @property
    def part_index(self):
        """index into `EvalOpGenIrState.inputs` (stored as `lhs`)"""
        return self.lhs

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # just the variable x_{part_index} with coefficient one
        return EvalOpPoly(var_coeffs=[0] * self.part_index + [1])

    def make_output(self, state, output_value_range):
        # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
        index = self.part_index
        source = state.inputs[index]
        resized = cast_to_size(
            fn=state.fn, ssa_val=source.ssa_val, src_signed=source.is_signed,
            dest_size=output_value_range.output_size,
            name=f"input_{index}_cast")
        return EvalOpGenIrOutput(output=resized,
                                 value_range=output_value_range)
515
516
@plain_data(frozen=True, unsafe_hash=True)
@final
class ToomCookInstance:
    """Description of one Toom-Cook multiplication variant.

    Records how many parts the lhs and rhs are split into, the points both
    operands are evaluated at, and the `EvalOp` sequences that evaluate the
    lhs parts, evaluate the rhs parts, and interpolate the product parts from
    the pointwise products. `__init__` checks every supplied `EvalOp` against
    the polynomial derived from the evaluation matrices and raises
    `ValueError` on any mismatch.
    """
    __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
                 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")

    @property
    def prod_part_count(self):
        # product of polynomials with n and m coefficients has n + m - 1
        return self.lhs_part_count + self.rhs_part_count - 1

    @staticmethod
    def make_eval_matrix(width, eval_points):
        # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
        """matrix whose row `r` evaluates a `width`-coefficient polynomial
        (lowest degree first) at `eval_points[r]`"""
        retval = Matrix(height=len(eval_points), width=width)
        for row, col in retval.indexes():
            point = eval_points[row]
            if point is not POINT_AT_INFINITY:
                retval[row, col] = point ** col
            else:
                # "evaluating at infinity" selects the leading coefficient
                retval[row, col] = int(col == width - 1)
        return retval

    @staticmethod
    def make_input_poly_vector(height):
        # type: (int) -> Matrix[EvalOpPoly]
        """column vector `[x_0, ..., x_{height-1}]` of input variables"""
        variables = (EvalOpPoly({index: 1}) for index in range(height))
        return Matrix(height=height, width=1, element_type=EvalOpPoly,
                      data=variables)

    @classmethod
    def _apply_to_inputs(cls, matrix, part_count):
        # type: (Matrix[Fraction], int) -> list[EvalOpPoly]
        """apply `matrix` symbolically to `part_count` input variables"""
        inputs = cls.make_input_poly_vector(part_count)
        return list(matrix.cast(EvalOpPoly) @ inputs)

    def get_lhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        return self.make_eval_matrix(self.lhs_part_count, self.eval_points)

    def get_lhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        return self._apply_to_inputs(self.get_lhs_eval_matrix(),
                                     self.lhs_part_count)

    def get_rhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        return self.make_eval_matrix(self.rhs_part_count, self.eval_points)

    def get_rhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        return self._apply_to_inputs(self.get_rhs_eval_matrix(),
                                     self.rhs_part_count)

    def get_prod_inverse_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        return self.make_eval_matrix(self.prod_part_count, self.eval_points)

    def get_prod_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """interpolation matrix: inverse of the product's evaluation matrix"""
        return self.get_prod_inverse_eval_matrix().inverse()

    def get_prod_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        return self._apply_to_inputs(self.get_prod_eval_matrix(),
                                     self.prod_part_count)

    def __init__(
        self, lhs_part_count,  # type: int
        rhs_part_count,  # type: int
        eval_points,  # type: Iterable[PointAtInfinity | int]
        lhs_eval_ops,  # type: Iterable[EvalOp]
        rhs_eval_ops,  # type: Iterable[EvalOp]
        prod_eval_ops,  # type: Iterable[EvalOp]
    ):
        # type: (...) -> None
        self.lhs_part_count = lhs_part_count
        if lhs_part_count < 2:
            raise ValueError("lhs_part_count must be at least 2")
        self.rhs_part_count = rhs_part_count
        if rhs_part_count < 2:
            raise ValueError("rhs_part_count must be at least 2")
        self.eval_points = tuple(eval_points)
        if len(set(self.eval_points)) != len(self.eval_points):
            raise ValueError("duplicate eval points")
        expected_count = self.prod_part_count
        self.lhs_eval_ops = tuple(lhs_eval_ops)
        if len(self.lhs_eval_ops) != expected_count:
            raise ValueError("wrong number of lhs_eval_ops")
        self.rhs_eval_ops = tuple(rhs_eval_ops)
        if len(self.rhs_eval_ops) != expected_count:
            raise ValueError("wrong number of rhs_eval_ops")
        if len(self.eval_points) != expected_count:
            raise ValueError("wrong number of eval_points")
        self.prod_eval_ops = tuple(prod_eval_ops)
        if len(self.prod_eval_ops) != expected_count:
            raise ValueError("wrong number of prod_eval_ops")

        # expected polynomials are computed lazily, one side at a time, so a
        # bad lhs is reported before any rhs/prod work is done; computing the
        # prod polys also checks the matrix (via `inverse()`)
        checks = (
            ("lhs", self.get_lhs_eval_polys, self.lhs_eval_ops),
            ("rhs", self.get_rhs_eval_polys, self.rhs_eval_ops),
            ("prod", self.get_prod_eval_polys, self.prod_eval_ops),
        )
        for label, get_expected_polys, eval_ops in checks:
            expected_polys = get_expected_polys()
            for i, eval_op in enumerate(eval_ops):
                if expected_polys[i] != eval_op.poly:
                    raise ValueError(
                        f"{label}_eval_ops[{i}] is incorrect: expected "
                        f"polynomial: {expected_polys[i]} found polynomial: "
                        f"{eval_op.poly}")

    def reversed(self):
        # type: () -> ToomCookInstance
        """return a ToomCookInstance where lhs/rhs are reversed"""
        return ToomCookInstance(
            lhs_part_count=self.rhs_part_count,
            rhs_part_count=self.lhs_part_count,
            eval_points=self.eval_points,
            lhs_eval_ops=self.rhs_eval_ops,
            rhs_eval_ops=self.lhs_eval_ops,
            prod_eval_ops=self.prod_eval_ops)

    @staticmethod
    def make_toom_2():
        # type: () -> ToomCookInstance
        """make an instance of Toom-2 aka Karatsuba multiplication"""
        x0, x1, x2 = EvalOpInput(0), EvalOpInput(1), EvalOpInput(2)
        # evaluating a 2-part operand at 0, 1 and infinity
        operand_evals = [x0, EvalOpAdd(x0, x1), x1]
        return ToomCookInstance(
            lhs_part_count=2,
            rhs_part_count=2,
            eval_points=[0, 1, POINT_AT_INFINITY],
            lhs_eval_ops=operand_evals,
            rhs_eval_ops=operand_evals,
            # middle coefficient: p(1) - p(0) - p(inf)
            prod_eval_ops=[x0, EvalOpSub(EvalOpSub(x1, x0), x2), x2],
        )

    @staticmethod
    def make_toom_2_5():
        # type: () -> ToomCookInstance
        """makes an instance of Toom-2.5"""
        x0, x1, x2, x3 = (EvalOpInput(i) for i in range(4))
        x0_plus_x2 = EvalOpAdd(x0, x2)
        # with r0..r3 the pointwise products at 0, 1, -1, infinity:
        # (r1 - r2) / 2 = c1 + c3 and (r1 + r2) / 2 = c0 + c2
        half_diff = EvalOpExactDiv(EvalOpSub(x1, x2), 2)
        half_sum = EvalOpExactDiv(EvalOpAdd(x1, x2), 2)
        return ToomCookInstance(
            lhs_part_count=3,
            rhs_part_count=2,
            eval_points=[0, 1, -1, POINT_AT_INFINITY],
            lhs_eval_ops=[
                x0,
                EvalOpAdd(x0_plus_x2, x1),
                EvalOpSub(x0_plus_x2, x1),
                x2,
            ],
            rhs_eval_ops=[
                x0,
                EvalOpAdd(x0, x1),
                EvalOpSub(x0, x1),
                x1,
            ],
            prod_eval_ops=[
                x0,
                EvalOpSub(half_diff, x3),
                EvalOpSub(half_sum, x0),
                x3,
            ],
        )

    # TODO: add make_toom_3
699
700
@plain_data(frozen=True, unsafe_hash=True)
@final
class PartialProduct:
    """One term for `sum_partial_products`.

    The term is a multi-word value given as single-word I64 SSA values
    (least-significant word first), shifted left by `shift_in_words` words.
    `is_signed` selects sign- vs zero-extension when it is widened.
    `subtract` makes the term subtracted from, rather than added to, the sum.
    """
    __slots__ = "ssa_val_spread", "shift_in_words", "is_signed", "subtract"

    def __init__(self, ssa_val_spread, shift_in_words, is_signed, subtract):
        # type: (Iterable[SSAVal], int, bool, bool) -> None
        super().__init__()
        if shift_in_words < 0:
            raise ValueError("invalid shift_in_words")
        self.ssa_val_spread = tuple(ssa_val_spread)
        # Validate the stored tuple, not the argument: `ssa_val_spread` may be
        # a one-shot iterator that `tuple()` above has already exhausted.
        word_ty = Ty(base_ty=BaseTy.I64, reg_len=1)
        for ssa_val in self.ssa_val_spread:
            if ssa_val.ty != word_ty:
                raise ValueError("invalid ssa_val.ty")
        self.shift_in_words = shift_in_words
        self.is_signed = is_signed
        self.subtract = subtract
717
718
def sum_partial_products(fn, partial_products, retval_size, name):
    # type: (Fn, Iterable[PartialProduct], int, str) -> SSAVal
    """Emit IR that sums `partial_products` into a `retval_size`-word value.

    Each `PartialProduct` is shifted left by its `shift_in_words`. It is then
    added to the running sum, or subtracted if `partial_product.subtract`.

    The running sum is `retval_spread`, a list of single-word SSA values with
    the least-significant word first. A partial product that starts at or
    past the current end of the sum, is not subtracted, and follows an
    unsigned sum is simply appended. Any gap is padded with zero words and
    no arithmetic is emitted for it. Every other partial product is combined
    with the overlapping high part of the sum by a vectorized carry chain:
    `SvAddE`, or `SvSubFE` when subtracting.

    The final sum is truncated or extended to `retval_size` words and
    concatenated into a single SSA value, which is returned.
    """
    retval_spread = []  # type: list[SSAVal]
    retval_signed = False
    zero = fn.append_new_op(OpKind.LI, immediates=[0],
                            name=f"{name}_zero").outputs[0]
    # appears to record whether the top word of `retval_spread` is an extra
    # word reserved to absorb carries -- NOTE(review): confirm intent
    has_carry_word = False
    for idx, partial_product in enumerate(partial_products):
        shift_in_words = partial_product.shift_in_words
        spread = list(partial_product.ssa_val_spread)
        if (not retval_signed and shift_in_words >= len(retval_spread)
                and not partial_product.subtract):
            # no overlap with the current sum: append without any arithmetic
            retval_spread.extend(
                [zero] * (shift_in_words - len(retval_spread)))
            retval_spread.extend(spread)
            retval_signed = partial_product.is_signed
            has_carry_word = False
            continue
        # NOTE(review): this fires if the first partial product has
        # subtract=True; `assert` is also stripped under `python -O`.
        assert len(retval_spread) != 0, "logic error"
        # number of words of the current sum at or above the shift point
        # (negative if the partial product starts past the end of the sum)
        retval_hi_len = len(retval_spread) - shift_in_words
        if retval_hi_len <= len(spread):
            maxvl = len(spread) + 1
            has_carry_word = True
        elif has_carry_word:
            maxvl = retval_hi_len
        else:
            maxvl = retval_hi_len + 1
            has_carry_word = True
        # NOTE(review): every branch above leaves has_carry_word True, so
        # this block appears to be unreachable.
        if not has_carry_word:
            maxvl += 1
            has_carry_word = True
        # never compute words at or above `retval_size`; the final cast below
        # would discard them anyway
        if maxvl > retval_size - shift_in_words:
            maxvl = retval_size - shift_in_words
            has_carry_word = False
        # bring both sides to exactly `maxvl` words starting at the shift
        retval_spread = cast_to_size_spread(
            fn=fn, ssa_vals=retval_spread, src_signed=retval_signed,
            dest_size=maxvl + shift_in_words, name=f"{name}_{idx}_cast_retval")
        spread = cast_to_size_spread(
            fn=fn, ssa_vals=spread, src_signed=partial_product.is_signed,
            dest_size=maxvl, name=f"{name}_{idx}_cast_pp")
        setvl = fn.append_new_op(
            OpKind.SetVLI, immediates=[maxvl],
            maxvl=maxvl, name=f"{name}_{idx}_setvl")
        retval_concat = fn.append_new_op(
            kind=OpKind.Concat,
            input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]],
            name=f"{name}_{idx}_retval_concat", maxvl=maxvl)
        pp_concat = fn.append_new_op(
            kind=OpKind.Concat,
            input_vals=[*spread, setvl.outputs[0]],
            name=f"{name}_{idx}_pp_concat", maxvl=maxvl)
        if partial_product.subtract:
            # PowerISA subfe: RT = ~RA + RB + CA, i.e. RB - RA with CA set,
            # so this computes retval - partial product
            set_ca = fn.append_new_op(kind=OpKind.SetCA,
                                      name=f"{name}_{idx}_set_ca")
            add_sub = fn.append_new_op(
                kind=OpKind.SvSubFE, input_vals=[
                    pp_concat.outputs[0], retval_concat.outputs[0],
                    set_ca.outputs[0], setvl.outputs[0]],
                maxvl=maxvl, name=f"{name}_{idx}_sub")
        else:
            clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
                                        name=f"{name}_{idx}_clear_ca")
            add_sub = fn.append_new_op(
                kind=OpKind.SvAddE, input_vals=[
                    retval_concat.outputs[0], pp_concat.outputs[0],
                    clear_ca.outputs[0], setvl.outputs[0]],
                maxvl=maxvl, name=f"{name}_{idx}_add")
        # replace the high part of the running sum with the new words
        retval_spread[shift_in_words:] = fn.append_new_op(
            kind=OpKind.Spread,
            input_vals=[add_sub.outputs[0], setvl.outputs[0]],
            name=f"{name}_{idx}_sum_spread", maxvl=maxvl).outputs
    retval_spread = cast_to_size_spread(
        fn=fn, ssa_vals=retval_spread, src_signed=retval_signed,
        dest_size=retval_size, name=f"{name}_retval_cast")
    retval_setvl = fn.append_new_op(
        OpKind.SetVLI, immediates=[retval_size],
        maxvl=retval_size, name=f"{name}_setvl")
    retval_concat = fn.append_new_op(
        kind=OpKind.Concat,
        input_vals=[*retval_spread, retval_setvl.outputs[0]],
        name=f"{name}_concat", maxvl=retval_size)
    return retval_concat.outputs[0]
801
802
def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name, retval_size=None):
    # type: (Fn, SSAVal, bool, SSAVal, bool, str, int | None) -> SSAVal
    """ simple O(n^2) big-int multiply

    Emits IR computing `lhs * rhs`, truncated to `retval_size` words (default:
    the sum of both operands' word counts).

    There is one vector multiply-add of the longer operand per word of the
    shorter one. The shifted partial products are summed with
    `sum_partial_products`. For a signed operand whose top word is negative,
    the other operand is subtracted, shifted left by the signed operand's
    length.

    NOTE(review): the sign corrections look correct only for
    `retval_size <= lhs + rhs` word counts (the default). With a larger
    `retval_size`, a negative product does not appear to be sign-extended,
    and the both-negative `+2**(64*(L+R))` term is not added -- confirm.
    """
    if retval_size is None:
        retval_size = lhs.ty.reg_len + rhs.ty.reg_len
    # make rhs the shorter operand: it's the one iterated word-by-word
    if lhs.ty.reg_len < rhs.ty.reg_len:
        lhs, rhs = rhs, lhs
        lhs_signed, rhs_signed = rhs_signed, lhs_signed
    # split rhs into elements; the VL's maxvl must match the Spread's maxvl,
    # as for every other SetVLI in this file
    rhs_setvl = fn.append_new_op(
        kind=OpKind.SetVLI, immediates=[rhs.ty.reg_len],
        name=f"{name}_rhs_setvl", maxvl=rhs.ty.reg_len)
    rhs_spread = fn.append_new_op(
        kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
        maxvl=rhs.ty.reg_len, name=f"{name}_rhs_spread")
    rhs_words = rhs_spread.outputs
    zero = fn.append_new_op(
        kind=OpKind.LI, immediates=[0], name=f"{name}_zero").outputs[0]
    maxvl = lhs.ty.reg_len
    lhs_setvl = fn.append_new_op(
        kind=OpKind.SetVLI, immediates=[maxvl], name=f"{name}_lhs_setvl",
        maxvl=maxvl)
    vl = lhs_setvl.outputs[0]

    def partial_products():
        # type: () -> Iterable[PartialProduct]
        for shift_in_words, rhs_word in enumerate(rhs_words):
            # lhs * rhs_word: a vector of words (outputs[0]) plus one extra
            # word (outputs[1]) placed above them as the most significant
            mul = fn.append_new_op(
                kind=OpKind.SvMAddEDU, input_vals=[lhs, rhs_word, zero, vl],
                maxvl=maxvl, name=f"{name}_{shift_in_words}_mul")
            mul_rt_spread = fn.append_new_op(
                kind=OpKind.Spread, input_vals=[mul.outputs[0], vl],
                name=f"{name}_{shift_in_words}_mul_rt_spread", maxvl=maxvl)
            yield PartialProduct(
                ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]],
                shift_in_words=shift_in_words,
                is_signed=False, subtract=False)
        if lhs_signed:
            # signed lhs = unsigned lhs - 2**(64*len(lhs)) * sign(lhs), so
            # subtract rhs shifted by len(lhs) words when lhs is negative
            lhs_spread = fn.append_new_op(
                kind=OpKind.Spread, input_vals=[lhs, lhs_setvl.outputs[0]],
                maxvl=lhs.ty.reg_len, name=f"{name}_lhs_spread")
            # arithmetic shift of the top word: all ones iff lhs is negative
            rhs_mask = fn.append_new_op(
                kind=OpKind.SRADI, input_vals=[lhs_spread.outputs[-1]],
                immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_rhs_mask")
            rhs_and = fn.append_new_op(
                kind=OpKind.SvAndVS,
                input_vals=[rhs, rhs_mask.outputs[0], rhs_setvl.outputs[0]],
                maxvl=rhs.ty.reg_len, name=f"{name}_rhs_and")
            rhs_and_spread = fn.append_new_op(
                kind=OpKind.Spread,
                input_vals=[rhs_and.outputs[0], rhs_setvl.outputs[0]],
                name=f"{name}_rhs_and_spread", maxvl=rhs.ty.reg_len)
            yield PartialProduct(
                ssa_val_spread=rhs_and_spread.outputs,
                shift_in_words=lhs.ty.reg_len, is_signed=False, subtract=True)
        if rhs_signed:
            # symmetric correction for a negative rhs; reuse the rhs words
            # spread above instead of emitting a second identical Spread
            lhs_mask = fn.append_new_op(
                kind=OpKind.SRADI, input_vals=[rhs_words[-1]],
                immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_lhs_mask")
            lhs_and = fn.append_new_op(
                kind=OpKind.SvAndVS,
                input_vals=[lhs, lhs_mask.outputs[0], lhs_setvl.outputs[0]],
                maxvl=lhs.ty.reg_len, name=f"{name}_lhs_and")
            lhs_and_spread = fn.append_new_op(
                kind=OpKind.Spread,
                input_vals=[lhs_and.outputs[0], lhs_setvl.outputs[0]],
                name=f"{name}_lhs_and_spread", maxvl=lhs.ty.reg_len)
            yield PartialProduct(
                ssa_val_spread=lhs_and_spread.outputs,
                shift_in_words=rhs.ty.reg_len, is_signed=False, subtract=True)
    return sum_partial_products(
        fn=fn, partial_products=partial_products(),
        retval_size=retval_size, name=name)
879
880
def cast_to_size(fn, ssa_val, src_signed, dest_size, name):
    # type: (Fn, SSAVal, bool, int, str) -> SSAVal
    """Resize the multi-word `ssa_val` to `dest_size` registers.

    High words are dropped when shrinking. When growing, they are filled by
    sign-extension (if `src_signed`) or zero-extension. `ssa_val` itself is
    returned, with no IR emitted, when it is already `dest_size` long.

    Raises ValueError if `dest_size` is not positive.
    """
    if dest_size <= 0:
        raise ValueError("invalid dest_size -- must be a positive integer")
    src_size = ssa_val.ty.reg_len
    if src_size == dest_size:
        return ssa_val
    in_vl = fn.append_new_op(
        OpKind.SetVLI, immediates=[src_size],
        maxvl=src_size, name=f"{name}_in_setvl").outputs[0]
    words = fn.append_new_op(
        OpKind.Spread, input_vals=[ssa_val, in_vl],
        name=f"{name}_spread", maxvl=src_size).outputs
    resized_words = cast_to_size_spread(
        fn=fn, ssa_vals=words, src_signed=src_signed,
        dest_size=dest_size, name=name)
    out_vl = fn.append_new_op(
        OpKind.SetVLI, immediates=[dest_size], maxvl=dest_size,
        name=f"{name}_out_setvl").outputs[0]
    joined = fn.append_new_op(
        OpKind.Concat, input_vals=[*resized_words, out_vl],
        name=f"{name}_concat", maxvl=dest_size)
    return joined.outputs[0]
903
904
def cast_to_size_spread(fn, ssa_vals, src_signed, dest_size, name):
    # type: (Fn, Iterable[SSAVal], bool, int, str) -> list[SSAVal]
    """Resize a list of single-word values to exactly `dest_size` words.

    The words are ordered least-significant first. Extra high words are
    dropped. Missing high words are filled with the sign of the top word (if
    `src_signed`) or with zero.

    Raises ValueError if `dest_size` is not positive, or if any value is not
    a single I64 register.
    """
    if dest_size <= 0:
        raise ValueError("invalid dest_size -- must be a positive integer")
    spread_values = list(ssa_vals)
    # Validate the materialized list, not `ssa_vals`: a one-shot iterator
    # would already be exhausted by `list()` above and skip every check.
    for ssa_val in spread_values:
        if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1):
            raise ValueError("invalid ssa_val.ty")
    missing = dest_size - len(spread_values)
    if missing == 0:
        return spread_values
    if missing < 0:
        del spread_values[dest_size:]
    elif src_signed:
        # shifting the top word right arithmetically by (word size - 1)
        # replicates its sign bit across a whole word
        sign = fn.append_new_op(
            OpKind.SRADI, input_vals=[spread_values[-1]],
            immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_sign")
        spread_values += [sign.outputs[0]] * missing
    else:
        zero = fn.append_new_op(
            OpKind.LI, immediates=[0], name=f"{name}_zero")
        spread_values += [zero.outputs[0]] * missing
    return spread_values
927
928
def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
    # type: (Fn, SSAVal, int, int, str) -> tuple[SSAVal, ...]
    """split ssa_val into part_count parts, where all but the last part have
    `part.ty.reg_len == part_size`.

    Parts are taken least-significant word first. The last part holds every
    remaining word, so it may be shorter or longer than `part_size`; no
    words are dropped. With `part_count == 1`, `ssa_val` is returned as-is.

    Raises ValueError if `part_size` or `part_count` is not positive, or if
    `ssa_val` is too short for the last part to be non-empty.
    """
    if part_size <= 0:
        raise ValueError("invalid part size, must be positive")
    if part_count <= 0:
        raise ValueError("invalid part count, must be positive")
    if part_count == 1:
        return (ssa_val,)
    too_short_reg_len = (part_count - 1) * part_size
    if ssa_val.ty.reg_len <= too_short_reg_len:
        raise ValueError(f"ssa_val is too short to split, must have "
                         f"reg_len > {too_short_reg_len}: {ssa_val}")
    maxvl = ssa_val.ty.reg_len
    setvl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
                             maxvl=maxvl, name=f"{name}_setvl")
    spread = fn.append_new_op(
        OpKind.Spread, input_vals=[ssa_val, setvl.outputs[0]],
        name=f"{name}_spread", maxvl=maxvl)
    retval = []  # type: list[SSAVal]
    for part in range(part_count):
        start = part * part_size
        # the last part takes every remaining word, as documented, instead of
        # silently dropping words past `part_count * part_size`
        stop = maxvl if part == part_count - 1 else start + part_size
        part_maxvl = stop - start
        # the VL must describe the same word count as the Concat below
        part_setvl = fn.append_new_op(
            OpKind.SetVLI, immediates=[part_maxvl], maxvl=part_maxvl,
            name=f"{name}_{part}_setvl")
        concat = fn.append_new_op(
            OpKind.Concat,
            input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]],
            name=f"{name}_{part}_concat", maxvl=part_maxvl)
        retval.append(concat.outputs[0])
    return tuple(retval)
964
965
# shorthand for the type of the `instances` argument, used in the type
# comments of `ToomCookMul.__init__` and `toom_cook_mul`
_TCIs = Tuple[ToomCookInstance, ...]
967
968
@plain_data(frozen=True)
@final
class ToomCookMul:
    """Generate IR in `fn` that multiplies `lhs` by `rhs` using Toom-Cook.

    All the work happens in the constructor. Ops are appended to `fn` as a
    side effect and the product is stored in `self.retval`. The other
    attributes keep the intermediate values computed along the way.

    Arguments:
    fn: the `Fn` that ops are appended to
    lhs, rhs: the multiplicands
    lhs_signed, rhs_signed: whether `lhs` / `rhs` are signed
    instances: candidate Toom-Cook instances. They are searched in order
        starting at `start_instance_index`. The first one giving a positive
        `part_size` is used, and the recursive multiplications of the
        evaluated parts start searching at the following index.
    retval_size: size passed to `sum_partial_products` for the result;
        defaults to `lhs.ty.reg_len + rhs.ty.reg_len`
    start_instance_index: index into `instances` to start searching at;
        raises ValueError if negative

    When no usable instance remains, the product is computed by
    `simple_mul`. In that case only the attributes up to `instance` /
    `part_size`, plus `retval`, are assigned.
    """
    # every attribute `__init__` may assign; which ones are actually set
    # depends on whether an instance or the `simple_mul` base case is used
    __slots__ = (
        "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
        "retval_size", "start_instance_index", "instance", "part_size",
        "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
        "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
        "prod_inputs", "prod_eval_state", "prod_parts",
        "partial_products", "retval",
    )

    def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
                 retval_size=None, start_instance_index=0):
        # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
        self.fn = fn
        self.lhs = lhs
        self.lhs_signed = lhs_signed
        self.rhs = rhs
        self.rhs_signed = rhs_signed
        self.instances = instances
        if retval_size is None:
            retval_size = lhs.ty.reg_len + rhs.ty.reg_len
        self.retval_size = retval_size
        if start_instance_index < 0:
            raise ValueError("start_instance_index must be non-negative")
        self.start_instance_index = start_instance_index
        # find the first instance (from start_instance_index onward) giving a
        # positive part_size: the larger of each operand's word count divided
        # by that operand's part count
        self.instance = None
        self.part_size = 0  # type: int
        while start_instance_index < len(instances):
            self.instance = instances[start_instance_index]
            self.part_size = max(
                lhs.ty.reg_len // self.instance.lhs_part_count,
                rhs.ty.reg_len // self.instance.rhs_part_count)
            if self.part_size <= 0:
                # both operands are too short for this instance's part
                # counts -- try the next instance
                self.instance = None
                start_instance_index += 1
            else:
                break
        if self.instance is None:
            # base case: no usable instance is left, so multiply directly.
            # NOTE(review): retval_size is not passed to simple_mul, so the
            # result has whatever size simple_mul produces -- confirm that is
            # acceptable when a caller passes a non-default retval_size.
            self.retval = simple_mul(fn=fn,
                                     lhs=lhs, lhs_signed=lhs_signed,
                                     rhs=rhs, rhs_signed=rhs_signed,
                                     name="toom_cook_base_case")
            return
        # split lhs into parts of part_size words; only the last part (the
        # most significant, see the shifts in partial_products below) carries
        # lhs's signedness.
        # NOTE(review): since part_size is the max over both operands,
        # split_into_exact_sized_parts raises ValueError when the shorter
        # operand can't fill (part_count - 1) parts of that size, e.g.
        # lhs/rhs reg_len 10/3 with 2 parts each (part_size == 5).
        self.lhs_parts = split_into_exact_sized_parts(
            fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
            part_size=self.part_size, name="lhs")
        self.lhs_inputs = []  # type: list[EvalOpGenIrInput]
        for part, ssa_val in enumerate(self.lhs_parts):
            self.lhs_inputs.append(EvalOpGenIrInput(
                ssa_val=ssa_val,
                is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
        self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
        # one output per entry of the instance's lhs_eval_ops
        lhs_eval_ops = self.instance.lhs_eval_ops
        self.lhs_outputs = [
            self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
        # same as above, for rhs
        self.rhs_parts = split_into_exact_sized_parts(
            fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
            part_size=self.part_size, name="rhs")
        self.rhs_inputs = []  # type: list[EvalOpGenIrInput]
        for part, ssa_val in enumerate(self.rhs_parts):
            self.rhs_inputs.append(EvalOpGenIrInput(
                ssa_val=ssa_val,
                is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
        self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
        rhs_eval_ops = self.instance.rhs_eval_ops
        self.rhs_outputs = [
            self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
        # multiply each lhs/rhs output pair by recursing with the following
        # instances (using the default retval_size). zip stops at the
        # shorter list -- presumably the instance guarantees equal lengths;
        # confirm.
        self.prod_inputs = []  # type: list[EvalOpGenIrInput]
        for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
            ssa_val = toom_cook_mul(
                fn=fn,
                lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
                rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
                instances=instances,
                start_instance_index=start_instance_index + 1)
            # the product's range is bounded by the products of the
            # operands' range endpoints
            products = (lhs_output.min_value * rhs_output.min_value,
                        lhs_output.min_value * rhs_output.max_value,
                        lhs_output.max_value * rhs_output.min_value,
                        lhs_output.max_value * rhs_output.max_value)
            # NOTE(review): is_signed=None presumably lets signedness be
            # derived from min_value/max_value -- confirm in EvalOpGenIrInput
            self.prod_inputs.append(EvalOpGenIrInput(
                ssa_val=ssa_val,
                is_signed=None,
                min_value=min(products),
                max_value=max(products)))
        # combine the pointwise products into result parts using the
        # instance's prod_eval_ops
        self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
        prod_eval_ops = self.instance.prod_eval_ops
        self.prod_parts = [
            self.prod_eval_state.get_output(i) for i in prod_eval_ops]

        def partial_products():
            # type: () -> Iterable[PartialProduct]
            # spread each result part into words and shift it left by
            # `part * part_size` words
            for part, prod_part in enumerate(self.prod_parts):
                part_maxvl = prod_part.output.ty.reg_len
                part_setvl = fn.append_new_op(
                    OpKind.SetVLI, immediates=[part_maxvl],
                    name=f"prod_{part}_setvl", maxvl=part_maxvl)
                spread_part = fn.append_new_op(
                    OpKind.Spread,
                    input_vals=[prod_part.output, part_setvl.outputs[0]],
                    name=f"prod_{part}_spread", maxvl=part_maxvl)
                yield PartialProduct(
                    spread_part.outputs, shift_in_words=part * self.part_size,
                    is_signed=prod_part.is_signed, subtract=False)
        self.partial_products = tuple(partial_products())
        # add up all the shifted parts into a result sized by retval_size
        self.retval = sum_partial_products(
            fn=fn, partial_products=self.partial_products,
            retval_size=retval_size, name="prod")
1078
1079
def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
                  retval_size=None, start_instance_index=0):
    # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
    """Emit Toom-Cook multiplication IR into `fn` and return the product.

    Convenience wrapper around `ToomCookMul` for callers that only need the
    resulting `SSAVal`. Every argument is forwarded unchanged; see
    `ToomCookMul` for their meaning.
    """
    multiplier = ToomCookMul(
        fn=fn,
        lhs=lhs,
        lhs_signed=lhs_signed,
        rhs=rhs,
        rhs_signed=rhs_signed,
        instances=instances,
        retval_size=retval_size,
        start_instance_index=start_instance_index,
    )
    return multiplier.retval