2 Toom-Cook multiplication algorithm generator for SVP64
5 from abc
import abstractmethod
7 from fractions
import Fraction
8 from typing
import Iterable
, Mapping
, Tuple
, Union
10 from cached_property
import cached_property
11 from nmutil
.plain_data
import plain_data
13 from bigint_presentation_code
.compiler_ir
import (GPR_SIZE_IN_BITS
, BaseTy
, Fn
,
15 from bigint_presentation_code
.matrix
import Matrix
16 from bigint_presentation_code
.type_util
import Literal
, final
17 from bigint_presentation_code
.util
import InternedMeta
21 class PointAtInfinity(Enum
):
22 POINT_AT_INFINITY
= "POINT_AT_INFINITY"
28 POINT_AT_INFINITY
= PointAtInfinity
.POINT_AT_INFINITY
30 _EvalOpPolyCoefficients
= Union
["Mapping[int | None, Fraction | int]",
31 "EvalOpPoly", Fraction
, int, None]
34 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
38 __slots__
= "const_coeff", "var_coeffs"
41 self
, coeffs
=None, # type: _EvalOpPolyCoefficients
42 const_coeff
=None, # type: Fraction | int | None
43 var_coeffs
=(), # type: Iterable[Fraction | int] | None
45 if coeffs
is not None:
46 if const_coeff
is not None or var_coeffs
!= ():
48 "can't specify const_coeff or "
49 "var_coeffs along with coeffs")
50 if isinstance(coeffs
, EvalOpPoly
):
51 self
.const_coeff
= coeffs
.const_coeff
52 self
.var_coeffs
= coeffs
.var_coeffs
54 if isinstance(coeffs
, (int, Fraction
)):
55 const_coeff
= Fraction(coeffs
)
56 final_var_coeffs
= [] # type: list[Fraction]
60 for var
, coeff
in coeffs
.items():
63 coeff
= Fraction(coeff
)
68 raise ValueError("invalid variable index")
69 if var
>= len(final_var_coeffs
):
70 additional
= var
- len(final_var_coeffs
)
71 final_var_coeffs
.extend((Fraction(),) * additional
)
72 final_var_coeffs
.append(coeff
)
74 final_var_coeffs
[var
] = coeff
76 if var_coeffs
is None:
79 final_var_coeffs
= [Fraction(v
) for v
in var_coeffs
]
80 while len(final_var_coeffs
) > 0 and final_var_coeffs
[-1] == 0:
81 final_var_coeffs
.pop()
82 if const_coeff
is None:
84 self
.const_coeff
= Fraction(const_coeff
)
85 self
.var_coeffs
= tuple(final_var_coeffs
)
87 def __add__(self
, rhs
):
88 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
90 const_coeff
= self
.const_coeff
+ rhs
.const_coeff
91 var_coeffs
= list(self
.var_coeffs
)
92 if len(rhs
.var_coeffs
) > len(var_coeffs
):
93 var_coeffs
.extend(rhs
.var_coeffs
[len(var_coeffs
):])
94 for var
in range(min(len(self
.var_coeffs
), len(rhs
.var_coeffs
))):
95 var_coeffs
[var
] += rhs
.var_coeffs
[var
]
96 return EvalOpPoly(const_coeff
=const_coeff
, var_coeffs
=var_coeffs
)
99 def coefficients(self
):
100 # type: () -> dict[int | None, Fraction]
101 retval
= {} # type: dict[int | None, Fraction]
102 if self
.const_coeff
!= 0:
103 retval
[None] = self
.const_coeff
104 for var
, coeff
in enumerate(self
.var_coeffs
):
112 return self
.var_coeffs
== ()
114 def coeff(self
, var
):
115 # type: (int | None) -> Fraction
117 return self
.const_coeff
119 raise ValueError("invalid variable index")
120 if var
< len(self
.var_coeffs
):
121 return self
.var_coeffs
[var
]
127 return EvalOpPoly(const_coeff
=-self
.const_coeff
,
128 var_coeffs
=(-v
for v
in self
.var_coeffs
))
130 def __sub__(self
, rhs
):
131 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
134 def __rsub__(self
, lhs
):
135 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
138 def __mul__(self
, rhs
):
139 # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
140 if isinstance(rhs
, EvalOpPoly
):
142 self
, rhs
= rhs
, self
144 raise ValueError("can't represent exponents larger than one")
145 rhs
= rhs
.const_coeff
148 return EvalOpPoly(const_coeff
=self
.const_coeff
* rhs
,
149 var_coeffs
=(i
* rhs
for i
in self
.var_coeffs
))
153 def __truediv__(self
, rhs
):
154 # type: (int | Fraction) -> EvalOpPoly
156 raise ZeroDivisionError()
157 return EvalOpPoly(const_coeff
=self
.const_coeff
/ rhs
,
158 var_coeffs
=(i
/ rhs
for i
in self
.var_coeffs
))
161 return f
"EvalOpPoly({self.coefficients})"
164 @plain_data(frozen
=True, unsafe_hash
=True)
166 class EvalOpValueRange
:
167 __slots__
= ("eval_op", "inputs", "min_value", "max_value",
168 "is_signed", "output_size")
170 def __init__(self
, eval_op
, inputs
):
171 # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
173 self
.eval_op
= eval_op
175 min_value
= max_value
= self
.poly
.const_coeff
176 for var
, coeff
in enumerate(self
.poly
.var_coeffs
):
179 term_min
= self
.inputs
[var
].min_value
* coeff
180 term_max
= self
.inputs
[var
].max_value
* coeff
181 if term_min
> term_max
:
182 term_min
, term_max
= term_max
, term_min
183 min_value
+= term_min
184 max_value
+= term_max
185 # output values are always integers, so eliminate any fractional part
187 self
.min_value
= math
.ceil(min_value
) # exclude fractional part
188 self
.max_value
= math
.floor(max_value
) # exclude fractional part
189 self
.is_signed
= min_value
< 0
192 min_v
= -1 << (GPR_SIZE_IN_BITS
- 1)
193 max_v
= (1 << (GPR_SIZE_IN_BITS
- 1)) - 1
196 max_v
= (1 << GPR_SIZE_IN_BITS
) - 1
197 while not (min_v
<= self
.min_value
and self
.max_value
<= max_v
):
199 min_v
<<= GPR_SIZE_IN_BITS
200 max_v
<<= GPR_SIZE_IN_BITS
201 self
.output_size
= output_size
205 if isinstance(self
.eval_op
, int):
206 return EvalOpPoly(const_coeff
=self
.eval_op
)
207 return self
.eval_op
.poly
210 @plain_data(frozen
=True, unsafe_hash
=True)
212 class EvalOpGenIrOutput
:
213 __slots__
= "output", "value_range"
215 def __init__(self
, output
, value_range
):
216 # type: (SSAVal, EvalOpValueRange) -> None
218 if output
.ty
.reg_len
!= value_range
.output_size
:
219 raise ValueError("wrong output size")
221 self
.value_range
= value_range
225 # type: () -> EvalOp | int
226 return self
.value_range
.eval_op
230 # type: () -> tuple[EvalOpGenIrInput, ...]
231 return self
.value_range
.inputs
236 return self
.value_range
.min_value
241 return self
.value_range
.max_value
246 return self
.value_range
.is_signed
249 def output_size(self
):
251 return self
.value_range
.output_size
254 def current_debugging_value(self
):
255 # type: () -> tuple[int, ...]
256 """ get the current value for debugging in pdb or similar.
258 This is intended for use with
259 `PreRASimState.set_current_debugging_state`.
261 This is only intended for debugging, do not use in unit tests or
264 return self
.output
.current_debugging_value
267 @plain_data(frozen
=True, unsafe_hash
=True)
269 class EvalOpGenIrInput
:
270 __slots__
= "ssa_val", "is_signed", "min_value", "max_value"
272 def __init__(self
, ssa_val
, is_signed
, min_value
=None, max_value
=None):
273 # type: (SSAVal, bool | None, int | None, int | None) -> None
275 self
.ssa_val
= ssa_val
276 if ssa_val
.base_ty
!= BaseTy
.I64
:
277 raise ValueError("input must have a base_ty of BaseTy.I64")
278 if is_signed
is None:
279 if min_value
is None or max_value
is None:
280 raise ValueError("must specify either is_signed or both "
281 "min_value and max_value")
282 is_signed
= min_value
< 0
283 self
.is_signed
= is_signed
285 if min_value
is None:
286 min_value
= -1 << (ssa_val
.ty
.reg_len
* GPR_SIZE_IN_BITS
- 1)
287 if max_value
is None:
289 ssa_val
.ty
.reg_len
* GPR_SIZE_IN_BITS
- 1)) - 1
291 if min_value
is None:
293 if max_value
is None:
294 max_value
= (1 << (ssa_val
.ty
.reg_len
* GPR_SIZE_IN_BITS
)) - 1
295 self
.min_value
= min_value
296 self
.max_value
= max_value
297 if self
.min_value
> self
.max_value
:
298 raise ValueError("invalid value range")
301 def current_debugging_value(self
):
302 # type: () -> tuple[int, ...]
303 """ get the current value for debugging in pdb or similar.
305 This is intended for use with
306 `PreRASimState.set_current_debugging_state`.
308 This is only intended for debugging, do not use in unit tests or
311 return self
.ssa_val
.current_debugging_value
314 @plain_data(frozen
=True)
316 class EvalOpGenIrState
:
317 __slots__
= "fn", "inputs", "outputs_map"
319 def __init__(self
, fn
, inputs
):
320 # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
323 self
.inputs
= tuple(inputs
)
324 self
.outputs_map
= {} # type: dict[EvalOp | int, EvalOpGenIrOutput]
326 def get_output(self
, eval_op
):
327 # type: (EvalOp | int) -> EvalOpGenIrOutput
328 retval
= self
.outputs_map
.get(eval_op
, None)
329 if retval
is not None:
331 value_range
= EvalOpValueRange(eval_op
=eval_op
, inputs
=self
.inputs
)
332 if isinstance(eval_op
, int):
333 li
= self
.fn
.append_new_op(OpKind
.LI
, immediates
=[eval_op
],
334 name
=f
"li_{eval_op}")
335 output
= cast_to_size(
336 fn
=self
.fn
, ssa_val
=li
.outputs
[0],
337 dest_size
=value_range
.output_size
,
338 src_signed
=value_range
.is_signed
, name
=f
"cast_{eval_op}")
339 retval
= EvalOpGenIrOutput(output
=output
, value_range
=value_range
)
341 retval
= eval_op
.make_output(state
=self
,
342 output_value_range
=value_range
)
343 if retval
.value_range
!= value_range
:
344 raise ValueError("wrong value_range")
345 return self
.outputs_map
.setdefault(eval_op
, retval
)
348 @plain_data(frozen
=True, unsafe_hash
=True)
349 class EvalOp(metaclass
=InternedMeta
):
350 __slots__
= "lhs", "rhs", "poly"
354 # type: () -> EvalOpPoly
355 if isinstance(self
.lhs
, int):
356 return EvalOpPoly(self
.lhs
)
361 # type: () -> EvalOpPoly
362 if isinstance(self
.rhs
, int):
363 return EvalOpPoly(self
.rhs
)
367 def _make_poly(self
):
368 # type: () -> EvalOpPoly
372 def make_output(self
, state
, output_value_range
):
373 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
376 def __init__(self
, lhs
, rhs
):
377 # type: (EvalOp | int, EvalOp | int) -> None
381 self
.poly
= self
._make
_poly
()
384 @plain_data(frozen
=True, unsafe_hash
=True)
386 class EvalOpAdd(EvalOp
):
389 def _make_poly(self
):
390 # type: () -> EvalOpPoly
391 return self
.lhs_poly
+ self
.rhs_poly
393 def make_output(self
, state
, output_value_range
):
394 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
395 lhs
= state
.get_output(self
.lhs
)
396 lhs_output
= cast_to_size(
397 fn
=state
.fn
, ssa_val
=lhs
.output
,
398 dest_size
=output_value_range
.output_size
, src_signed
=lhs
.is_signed
,
400 rhs
= state
.get_output(self
.rhs
)
401 rhs_output
= cast_to_size(
402 fn
=state
.fn
, ssa_val
=rhs
.output
,
403 dest_size
=output_value_range
.output_size
, src_signed
=rhs
.is_signed
,
405 setvl
= state
.fn
.append_new_op(
406 OpKind
.SetVLI
, immediates
=[output_value_range
.output_size
],
407 name
="setvl", maxvl
=output_value_range
.output_size
)
408 clear_ca
= state
.fn
.append_new_op(OpKind
.ClearCA
, name
="clear_ca")
409 add
= state
.fn
.append_new_op(
410 OpKind
.SvAddE
, input_vals
=[
411 lhs_output
, rhs_output
, clear_ca
.outputs
[0], setvl
.outputs
[0]],
412 maxvl
=output_value_range
.output_size
, name
="add")
413 return EvalOpGenIrOutput(
414 output
=add
.outputs
[0], value_range
=output_value_range
)
417 @plain_data(frozen
=True, unsafe_hash
=True)
419 class EvalOpSub(EvalOp
):
422 def _make_poly(self
):
423 # type: () -> EvalOpPoly
424 return self
.lhs_poly
- self
.rhs_poly
426 def make_output(self
, state
, output_value_range
):
427 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
428 lhs
= state
.get_output(self
.lhs
)
429 lhs_output
= cast_to_size(
430 fn
=state
.fn
, ssa_val
=lhs
.output
,
431 dest_size
=output_value_range
.output_size
, src_signed
=lhs
.is_signed
,
433 rhs
= state
.get_output(self
.rhs
)
434 rhs_output
= cast_to_size(
435 fn
=state
.fn
, ssa_val
=rhs
.output
,
436 dest_size
=output_value_range
.output_size
, src_signed
=rhs
.is_signed
,
438 setvl
= state
.fn
.append_new_op(
439 OpKind
.SetVLI
, immediates
=[output_value_range
.output_size
],
440 name
="setvl", maxvl
=output_value_range
.output_size
)
441 set_ca
= state
.fn
.append_new_op(OpKind
.SetCA
, name
="set_ca")
442 sub
= state
.fn
.append_new_op(
443 OpKind
.SvSubFE
, input_vals
=[
444 rhs_output
, lhs_output
, set_ca
.outputs
[0], setvl
.outputs
[0]],
445 maxvl
=output_value_range
.output_size
, name
="sub")
446 return EvalOpGenIrOutput(
447 output
=sub
.outputs
[0], value_range
=output_value_range
)
450 @plain_data(frozen
=True, unsafe_hash
=True)
452 class EvalOpMul(EvalOp
):
456 def _make_poly(self
):
457 # type: () -> EvalOpPoly
458 if not isinstance(self
.rhs
, int): # type: ignore
459 raise TypeError("invalid rhs type")
460 return self
.lhs_poly
* self
.rhs
462 def make_output(self
, state
, output_value_range
):
463 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
464 raise NotImplementedError # FIXME: finish
467 @plain_data(frozen
=True, unsafe_hash
=True)
469 class EvalOpExactDiv(EvalOp
):
473 def _make_poly(self
):
474 # type: () -> EvalOpPoly
475 if not isinstance(self
.rhs
, int): # type: ignore
476 raise TypeError("invalid rhs type")
477 return self
.lhs_poly
/ self
.rhs
479 def make_output(self
, state
, output_value_range
):
480 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
481 raise NotImplementedError # FIXME: finish
484 @plain_data(frozen
=True, unsafe_hash
=True)
486 class EvalOpInput(EvalOp
):
491 def __init__(self
, lhs
, rhs
=0):
492 # type: (int, int) -> None
494 raise ValueError("Input part_index (lhs) must be >= 0")
496 raise ValueError("Input rhs must be 0")
497 super().__init
__(lhs
, rhs
)
500 def part_index(self
):
503 def _make_poly(self
):
504 # type: () -> EvalOpPoly
505 return EvalOpPoly({self
.part_index
: 1})
507 def make_output(self
, state
, output_value_range
):
508 # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
509 inp
= state
.inputs
[self
.part_index
]
510 output
= cast_to_size(
511 fn
=state
.fn
, ssa_val
=inp
.ssa_val
, src_signed
=inp
.is_signed
,
512 dest_size
=output_value_range
.output_size
,
513 name
=f
"input_{self.part_index}_cast")
514 return EvalOpGenIrOutput(output
=output
, value_range
=output_value_range
)
517 @plain_data(frozen
=True, unsafe_hash
=True)
519 class ToomCookInstance
:
520 __slots__
= ("lhs_part_count", "rhs_part_count", "eval_points",
521 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
524 def prod_part_count(self
):
525 return self
.lhs_part_count
+ self
.rhs_part_count
- 1
528 def make_eval_matrix(width
, eval_points
):
529 # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
530 retval
= Matrix(height
=len(eval_points
), width
=width
)
531 for row
, col
in retval
.indexes():
532 eval_point
= eval_points
[row
]
533 if eval_point
is POINT_AT_INFINITY
:
534 retval
[row
, col
] = int(col
== width
- 1)
536 retval
[row
, col
] = eval_point
** col
539 def get_lhs_eval_matrix(self
):
540 # type: () -> Matrix[Fraction]
541 return self
.make_eval_matrix(self
.lhs_part_count
, self
.eval_points
)
544 def make_input_poly_vector(height
):
545 # type: (int) -> Matrix[EvalOpPoly]
546 return Matrix(height
=height
, width
=1, element_type
=EvalOpPoly
,
547 data
=(EvalOpPoly({i
: 1}) for i
in range(height
)))
549 def get_lhs_eval_polys(self
):
550 # type: () -> list[EvalOpPoly]
551 return list(self
.get_lhs_eval_matrix().cast(EvalOpPoly
)
552 @ self
.make_input_poly_vector(self
.lhs_part_count
))
554 def get_rhs_eval_matrix(self
):
555 # type: () -> Matrix[Fraction]
556 return self
.make_eval_matrix(self
.rhs_part_count
, self
.eval_points
)
558 def get_rhs_eval_polys(self
):
559 # type: () -> list[EvalOpPoly]
560 return list(self
.get_rhs_eval_matrix().cast(EvalOpPoly
)
561 @ self
.make_input_poly_vector(self
.rhs_part_count
))
563 def get_prod_inverse_eval_matrix(self
):
564 # type: () -> Matrix[Fraction]
565 return self
.make_eval_matrix(self
.prod_part_count
, self
.eval_points
)
567 def get_prod_eval_matrix(self
):
568 # type: () -> Matrix[Fraction]
569 return self
.get_prod_inverse_eval_matrix().inverse()
571 def get_prod_eval_polys(self
):
572 # type: () -> list[EvalOpPoly]
573 return list(self
.get_prod_eval_matrix().cast(EvalOpPoly
)
574 @ self
.make_input_poly_vector(self
.prod_part_count
))
577 self
, lhs_part_count
, # type: int
578 rhs_part_count
, # type: int
579 eval_points
, # type: Iterable[PointAtInfinity | int]
580 lhs_eval_ops
, # type: Iterable[EvalOp]
581 rhs_eval_ops
, # type: Iterable[EvalOp]
582 prod_eval_ops
, # type: Iterable[EvalOp]
584 # type: (...) -> None
585 self
.lhs_part_count
= lhs_part_count
586 if self
.lhs_part_count
< 2:
587 raise ValueError("lhs_part_count must be at least 2")
588 self
.rhs_part_count
= rhs_part_count
589 if self
.rhs_part_count
< 2:
590 raise ValueError("rhs_part_count must be at least 2")
591 eval_points
= list(eval_points
)
592 self
.eval_points
= tuple(eval_points
)
593 if len(self
.eval_points
) != len(set(self
.eval_points
)):
594 raise ValueError("duplicate eval points")
595 self
.lhs_eval_ops
= tuple(lhs_eval_ops
)
596 if len(self
.lhs_eval_ops
) != self
.prod_part_count
:
597 raise ValueError("wrong number of lhs_eval_ops")
598 self
.rhs_eval_ops
= tuple(rhs_eval_ops
)
599 if len(self
.rhs_eval_ops
) != self
.prod_part_count
:
600 raise ValueError("wrong number of rhs_eval_ops")
601 if len(self
.eval_points
) != self
.prod_part_count
:
602 raise ValueError("wrong number of eval_points")
603 self
.prod_eval_ops
= tuple(prod_eval_ops
)
604 if len(self
.prod_eval_ops
) != self
.prod_part_count
:
605 raise ValueError("wrong number of prod_eval_ops")
607 lhs_eval_polys
= self
.get_lhs_eval_polys()
608 for i
, eval_op
in enumerate(self
.lhs_eval_ops
):
609 if lhs_eval_polys
[i
] != eval_op
.poly
:
611 f
"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
612 f
"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
614 rhs_eval_polys
= self
.get_rhs_eval_polys()
615 for i
, eval_op
in enumerate(self
.rhs_eval_ops
):
616 if rhs_eval_polys
[i
] != eval_op
.poly
:
618 f
"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
619 f
"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
621 prod_eval_polys
= self
.get_prod_eval_polys() # also checks matrix
622 for i
, eval_op
in enumerate(self
.prod_eval_ops
):
623 if prod_eval_polys
[i
] != eval_op
.poly
:
625 f
"prod_eval_ops[{i}] is incorrect: expected polynomial: "
626 f
"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
629 # type: () -> ToomCookInstance
630 """return a ToomCookInstance where lhs/rhs are reversed"""
631 return ToomCookInstance(
632 lhs_part_count
=self
.rhs_part_count
,
633 rhs_part_count
=self
.lhs_part_count
,
634 eval_points
=self
.eval_points
,
635 lhs_eval_ops
=self
.rhs_eval_ops
,
636 rhs_eval_ops
=self
.lhs_eval_ops
,
637 prod_eval_ops
=self
.prod_eval_ops
)
641 # type: () -> ToomCookInstance
642 """make an instance of Toom-2 aka Karatsuba multiplication"""
643 return ToomCookInstance(
646 eval_points
=[0, 1, POINT_AT_INFINITY
],
649 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
654 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
659 EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
667 # type: () -> ToomCookInstance
668 """makes an instance of Toom-2.5"""
669 inp_0_plus_inp_2
= EvalOpAdd(EvalOpInput(0), EvalOpInput(2))
670 inp_1_minus_inp_2
= EvalOpSub(EvalOpInput(1), EvalOpInput(2))
671 inp_1_plus_inp_2
= EvalOpAdd(EvalOpInput(1), EvalOpInput(2))
672 inp_1_minus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_minus_inp_2
, 2)
673 inp_1_plus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_plus_inp_2
, 2)
674 return ToomCookInstance(
677 eval_points
=[0, 1, -1, POINT_AT_INFINITY
],
680 EvalOpAdd(inp_0_plus_inp_2
, EvalOpInput(1)),
681 EvalOpSub(inp_0_plus_inp_2
, EvalOpInput(1)),
686 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
687 EvalOpSub(EvalOpInput(0), EvalOpInput(1)),
692 EvalOpSub(inp_1_minus_inp_2_all_div_2
, EvalOpInput(3)),
693 EvalOpSub(inp_1_plus_inp_2_all_div_2
, EvalOpInput(0)),
698 # TODO: add make_toom_3
701 @plain_data(frozen
=True, unsafe_hash
=True)
703 class PartialProduct
:
704 __slots__
= "ssa_val_spread", "shift_in_words", "is_signed", "subtract"
706 def __init__(self
, ssa_val_spread
, shift_in_words
, is_signed
, subtract
):
707 # type: (Iterable[SSAVal], int, bool, bool) -> None
708 if shift_in_words
< 0:
709 raise ValueError("invalid shift_in_words")
710 self
.ssa_val_spread
= tuple(ssa_val_spread
)
711 for ssa_val
in ssa_val_spread
:
712 if ssa_val
.ty
!= Ty(base_ty
=BaseTy
.I64
, reg_len
=1):
713 raise ValueError("invalid ssa_val.ty")
714 self
.shift_in_words
= shift_in_words
715 self
.is_signed
= is_signed
716 self
.subtract
= subtract
719 def sum_partial_products(fn
, partial_products
, retval_size
, name
):
720 # type: (Fn, Iterable[PartialProduct], int, str) -> SSAVal
721 retval_spread
= [] # type: list[SSAVal]
722 retval_signed
= False
723 zero
= fn
.append_new_op(OpKind
.LI
, immediates
=[0],
724 name
=f
"{name}_zero").outputs
[0]
725 has_carry_word
= False
726 for idx
, partial_product
in enumerate(partial_products
):
727 shift_in_words
= partial_product
.shift_in_words
728 spread
= list(partial_product
.ssa_val_spread
)
729 if (not retval_signed
and shift_in_words
>= len(retval_spread
)
730 and not partial_product
.subtract
):
731 retval_spread
.extend(
732 [zero
] * (shift_in_words
- len(retval_spread
)))
733 retval_spread
.extend(spread
)
734 retval_signed
= partial_product
.is_signed
735 has_carry_word
= False
737 assert len(retval_spread
) != 0, "logic error"
738 retval_hi_len
= len(retval_spread
) - shift_in_words
739 if retval_hi_len
<= len(spread
):
740 maxvl
= len(spread
) + 1
741 has_carry_word
= True
743 maxvl
= retval_hi_len
745 maxvl
= retval_hi_len
+ 1
746 has_carry_word
= True
747 if not has_carry_word
:
749 has_carry_word
= True
750 if maxvl
> retval_size
- shift_in_words
:
751 maxvl
= retval_size
- shift_in_words
752 has_carry_word
= False
753 retval_spread
= cast_to_size_spread(
754 fn
=fn
, ssa_vals
=retval_spread
, src_signed
=retval_signed
,
755 dest_size
=maxvl
+ shift_in_words
, name
=f
"{name}_{idx}_cast_retval")
756 spread
= cast_to_size_spread(
757 fn
=fn
, ssa_vals
=spread
, src_signed
=partial_product
.is_signed
,
758 dest_size
=maxvl
, name
=f
"{name}_{idx}_cast_pp")
759 setvl
= fn
.append_new_op(
760 OpKind
.SetVLI
, immediates
=[maxvl
],
761 maxvl
=maxvl
, name
=f
"{name}_{idx}_setvl")
762 retval_concat
= fn
.append_new_op(
764 input_vals
=[*retval_spread
[shift_in_words
:], setvl
.outputs
[0]],
765 name
=f
"{name}_{idx}_retval_concat", maxvl
=maxvl
)
766 pp_concat
= fn
.append_new_op(
768 input_vals
=[*spread
, setvl
.outputs
[0]],
769 name
=f
"{name}_{idx}_pp_concat", maxvl
=maxvl
)
770 if partial_product
.subtract
:
771 set_ca
= fn
.append_new_op(kind
=OpKind
.SetCA
,
772 name
=f
"{name}_{idx}_set_ca")
773 add_sub
= fn
.append_new_op(
774 kind
=OpKind
.SvSubFE
, input_vals
=[
775 pp_concat
.outputs
[0], retval_concat
.outputs
[0],
776 set_ca
.outputs
[0], setvl
.outputs
[0]],
777 maxvl
=maxvl
, name
=f
"{name}_{idx}_sub")
779 clear_ca
= fn
.append_new_op(kind
=OpKind
.ClearCA
,
780 name
=f
"{name}_{idx}_clear_ca")
781 add_sub
= fn
.append_new_op(
782 kind
=OpKind
.SvAddE
, input_vals
=[
783 retval_concat
.outputs
[0], pp_concat
.outputs
[0],
784 clear_ca
.outputs
[0], setvl
.outputs
[0]],
785 maxvl
=maxvl
, name
=f
"{name}_{idx}_add")
786 retval_spread
[shift_in_words
:] = fn
.append_new_op(
788 input_vals
=[add_sub
.outputs
[0], setvl
.outputs
[0]],
789 name
=f
"{name}_{idx}_sum_spread", maxvl
=maxvl
).outputs
790 retval_spread
= cast_to_size_spread(
791 fn
=fn
, ssa_vals
=retval_spread
, src_signed
=retval_signed
,
792 dest_size
=retval_size
, name
=f
"{name}_retval_cast")
793 retval_setvl
= fn
.append_new_op(
794 OpKind
.SetVLI
, immediates
=[retval_size
],
795 maxvl
=retval_size
, name
=f
"{name}_setvl")
796 retval_concat
= fn
.append_new_op(
798 input_vals
=[*retval_spread
, retval_setvl
.outputs
[0]],
799 name
=f
"{name}_concat", maxvl
=retval_size
)
800 return retval_concat
.outputs
[0]
803 def simple_mul(fn
, lhs
, lhs_signed
, rhs
, rhs_signed
, name
, retval_size
=None):
804 # type: (Fn, SSAVal, bool, SSAVal, bool, str, int | None) -> SSAVal
805 """ simple O(n^2) big-int multiply """
806 if retval_size
is None:
807 retval_size
= lhs
.ty
.reg_len
+ rhs
.ty
.reg_len
808 if lhs
.ty
.reg_len
< rhs
.ty
.reg_len
:
810 lhs_signed
, rhs_signed
= rhs_signed
, lhs_signed
811 # split rhs into elements
812 rhs_setvl
= fn
.append_new_op(
813 kind
=OpKind
.SetVLI
, immediates
=[rhs
.ty
.reg_len
],
814 name
=f
"{name}_rhs_setvl")
815 rhs_spread
= fn
.append_new_op(
816 kind
=OpKind
.Spread
, input_vals
=[rhs
, rhs_setvl
.outputs
[0]],
817 maxvl
=rhs
.ty
.reg_len
, name
=f
"{name}_rhs_spread")
818 rhs_words
= rhs_spread
.outputs
819 zero
= fn
.append_new_op(
820 kind
=OpKind
.LI
, immediates
=[0], name
=f
"{name}_zero").outputs
[0]
821 maxvl
= lhs
.ty
.reg_len
822 lhs_setvl
= fn
.append_new_op(
823 kind
=OpKind
.SetVLI
, immediates
=[maxvl
], name
=f
"{name}_lhs_setvl",
825 vl
= lhs_setvl
.outputs
[0]
827 def partial_products():
828 # type: () -> Iterable[PartialProduct]
829 for shift_in_words
, rhs_word
in enumerate(rhs_words
):
830 mul
= fn
.append_new_op(
831 kind
=OpKind
.SvMAddEDU
, input_vals
=[lhs
, rhs_word
, zero
, vl
],
832 maxvl
=maxvl
, name
=f
"{name}_{shift_in_words}_mul")
833 mul_rt_spread
= fn
.append_new_op(
834 kind
=OpKind
.Spread
, input_vals
=[mul
.outputs
[0], vl
],
835 name
=f
"{name}_{shift_in_words}_mul_rt_spread", maxvl
=maxvl
)
836 yield PartialProduct(
837 ssa_val_spread
=[*mul_rt_spread
.outputs
, mul
.outputs
[1]],
838 shift_in_words
=shift_in_words
,
839 is_signed
=False, subtract
=False)
841 lhs_spread
= fn
.append_new_op(
842 kind
=OpKind
.Spread
, input_vals
=[lhs
, lhs_setvl
.outputs
[0]],
843 maxvl
=lhs
.ty
.reg_len
, name
=f
"{name}_lhs_spread")
844 rhs_mask
= fn
.append_new_op(
845 kind
=OpKind
.SRADI
, input_vals
=[lhs_spread
.outputs
[-1]],
846 immediates
=[GPR_SIZE_IN_BITS
- 1], name
=f
"{name}_rhs_mask")
847 lhs_and
= fn
.append_new_op(
849 input_vals
=[rhs
, rhs_mask
.outputs
[0], rhs_setvl
.outputs
[0]],
850 maxvl
=rhs
.ty
.reg_len
, name
=f
"{name}_rhs_and")
851 rhs_and_spread
= fn
.append_new_op(
853 input_vals
=[lhs_and
.outputs
[0], rhs_setvl
.outputs
[0]],
854 name
=f
"{name}_rhs_and_spread", maxvl
=rhs
.ty
.reg_len
)
855 yield PartialProduct(
856 ssa_val_spread
=rhs_and_spread
.outputs
,
857 shift_in_words
=lhs
.ty
.reg_len
, is_signed
=False, subtract
=True)
859 rhs_spread
= fn
.append_new_op(
860 kind
=OpKind
.Spread
, input_vals
=[rhs
, rhs_setvl
.outputs
[0]],
861 maxvl
=rhs
.ty
.reg_len
, name
=f
"{name}_rhs_spread")
862 lhs_mask
= fn
.append_new_op(
863 kind
=OpKind
.SRADI
, input_vals
=[rhs_spread
.outputs
[-1]],
864 immediates
=[GPR_SIZE_IN_BITS
- 1], name
=f
"{name}_lhs_mask")
865 rhs_and
= fn
.append_new_op(
867 input_vals
=[lhs
, lhs_mask
.outputs
[0], lhs_setvl
.outputs
[0]],
868 maxvl
=lhs
.ty
.reg_len
, name
=f
"{name}_lhs_and")
869 lhs_and_spread
= fn
.append_new_op(
871 input_vals
=[rhs_and
.outputs
[0], lhs_setvl
.outputs
[0]],
872 name
=f
"{name}_lhs_and_spread", maxvl
=lhs
.ty
.reg_len
)
873 yield PartialProduct(
874 ssa_val_spread
=lhs_and_spread
.outputs
,
875 shift_in_words
=rhs
.ty
.reg_len
, is_signed
=False, subtract
=True)
876 return sum_partial_products(
877 fn
=fn
, partial_products
=partial_products(),
878 retval_size
=retval_size
, name
=name
)
881 def cast_to_size(fn
, ssa_val
, src_signed
, dest_size
, name
):
882 # type: (Fn, SSAVal, bool, int, str) -> SSAVal
884 raise ValueError("invalid dest_size -- must be a positive integer")
885 if ssa_val
.ty
.reg_len
== dest_size
:
887 in_setvl
= fn
.append_new_op(
888 OpKind
.SetVLI
, immediates
=[ssa_val
.ty
.reg_len
],
889 maxvl
=ssa_val
.ty
.reg_len
, name
=f
"{name}_in_setvl")
890 spread
= fn
.append_new_op(
891 OpKind
.Spread
, input_vals
=[ssa_val
, in_setvl
.outputs
[0]],
892 name
=f
"{name}_spread", maxvl
=ssa_val
.ty
.reg_len
)
893 spread_values
= cast_to_size_spread(
894 fn
=fn
, ssa_vals
=spread
.outputs
, src_signed
=src_signed
,
895 dest_size
=dest_size
, name
=name
)
896 out_setvl
= fn
.append_new_op(
897 OpKind
.SetVLI
, immediates
=[dest_size
], maxvl
=dest_size
,
898 name
=f
"{name}_out_setvl")
899 concat
= fn
.append_new_op(
900 OpKind
.Concat
, input_vals
=[*spread_values
, out_setvl
.outputs
[0]],
901 name
=f
"{name}_concat", maxvl
=dest_size
)
902 return concat
.outputs
[0]
905 def cast_to_size_spread(fn
, ssa_vals
, src_signed
, dest_size
, name
):
906 # type: (Fn, Iterable[SSAVal], bool, int, str) -> list[SSAVal]
908 raise ValueError("invalid dest_size -- must be a positive integer")
909 spread_values
= list(ssa_vals
)
910 for ssa_val
in ssa_vals
:
911 if ssa_val
.ty
!= Ty(base_ty
=BaseTy
.I64
, reg_len
=1):
912 raise ValueError("invalid ssa_val.ty")
913 if len(spread_values
) == dest_size
:
915 if len(spread_values
) > dest_size
:
916 spread_values
[dest_size
:] = []
918 sign
= fn
.append_new_op(
919 OpKind
.SRADI
, input_vals
=[spread_values
[-1]],
920 immediates
=[GPR_SIZE_IN_BITS
- 1], name
=f
"{name}_sign")
921 spread_values
+= [sign
.outputs
[0]] * (dest_size
- len(spread_values
))
923 zero
= fn
.append_new_op(
924 OpKind
.LI
, immediates
=[0], name
=f
"{name}_zero")
925 spread_values
+= [zero
.outputs
[0]] * (dest_size
- len(spread_values
))
929 def split_into_exact_sized_parts(fn
, ssa_val
, part_count
, part_size
, name
):
930 # type: (Fn, SSAVal, int, int, str) -> tuple[SSAVal, ...]
931 """split ssa_val into part_count parts, where all but the last part have
932 `part.ty.reg_len == part_size`.
935 raise ValueError("invalid part size, must be positive")
937 raise ValueError("invalid part count, must be positive")
940 too_short_reg_len
= (part_count
- 1) * part_size
941 if ssa_val
.ty
.reg_len
<= too_short_reg_len
:
942 raise ValueError(f
"ssa_val is too short to split, must have "
943 f
"reg_len > {too_short_reg_len}: {ssa_val}")
944 maxvl
= ssa_val
.ty
.reg_len
945 setvl
= fn
.append_new_op(OpKind
.SetVLI
, immediates
=[maxvl
],
946 maxvl
=maxvl
, name
=f
"{name}_setvl")
947 spread
= fn
.append_new_op(
948 OpKind
.Spread
, input_vals
=[ssa_val
, setvl
.outputs
[0]],
949 name
=f
"{name}_spread", maxvl
=maxvl
)
950 retval
= [] # type: list[SSAVal]
951 for part
in range(part_count
):
952 start
= part
* part_size
953 stop
= min(maxvl
, start
+ part_size
)
954 part_maxvl
= stop
- start
955 part_setvl
= fn
.append_new_op(
956 OpKind
.SetVLI
, immediates
=[part_size
], maxvl
=part_size
,
957 name
=f
"{name}_{part}_setvl")
958 concat
= fn
.append_new_op(
960 input_vals
=[*spread
.outputs
[start
:stop
], part_setvl
.outputs
[0]],
961 name
=f
"{name}_{part}_concat", maxvl
=part_maxvl
)
962 retval
.append(concat
.outputs
[0])
966 _TCIs
= Tuple
[ToomCookInstance
, ...]
969 @plain_data(frozen
=True)
973 "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
974 "retval_size", "start_instance_index", "instance", "part_size",
975 "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
976 "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
977 "prod_inputs", "prod_eval_state", "prod_parts",
978 "partial_products", "retval",
981 def __init__(self
, fn
, lhs
, lhs_signed
, rhs
, rhs_signed
, instances
,
982 retval_size
=None, start_instance_index
=0):
983 # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
986 self
.lhs_signed
= lhs_signed
988 self
.rhs_signed
= rhs_signed
989 self
.instances
= instances
990 if retval_size
is None:
991 retval_size
= lhs
.ty
.reg_len
+ rhs
.ty
.reg_len
992 self
.retval_size
= retval_size
993 if start_instance_index
< 0:
994 raise ValueError("start_instance_index must be non-negative")
995 self
.start_instance_index
= start_instance_index
997 self
.part_size
= 0 # type: int
998 while start_instance_index
< len(instances
):
999 self
.instance
= instances
[start_instance_index
]
1000 self
.part_size
= max(
1001 lhs
.ty
.reg_len
// self
.instance
.lhs_part_count
,
1002 rhs
.ty
.reg_len
// self
.instance
.rhs_part_count
)
1003 if self
.part_size
<= 0:
1004 self
.instance
= None
1005 start_instance_index
+= 1
1008 if self
.instance
is None:
1009 self
.retval
= simple_mul(fn
=fn
,
1010 lhs
=lhs
, lhs_signed
=lhs_signed
,
1011 rhs
=rhs
, rhs_signed
=rhs_signed
,
1012 name
="toom_cook_base_case")
1014 self
.lhs_parts
= split_into_exact_sized_parts(
1015 fn
=fn
, ssa_val
=lhs
, part_count
=self
.instance
.lhs_part_count
,
1016 part_size
=self
.part_size
, name
="lhs")
1017 self
.lhs_inputs
= [] # type: list[EvalOpGenIrInput]
1018 for part
, ssa_val
in enumerate(self
.lhs_parts
):
1019 self
.lhs_inputs
.append(EvalOpGenIrInput(
1021 is_signed
=lhs_signed
and part
== len(self
.lhs_parts
) - 1))
1022 self
.lhs_eval_state
= EvalOpGenIrState(fn
=fn
, inputs
=self
.lhs_inputs
)
1023 lhs_eval_ops
= self
.instance
.lhs_eval_ops
1024 self
.lhs_outputs
= [
1025 self
.lhs_eval_state
.get_output(i
) for i
in lhs_eval_ops
]
1026 self
.rhs_parts
= split_into_exact_sized_parts(
1027 fn
=fn
, ssa_val
=rhs
, part_count
=self
.instance
.rhs_part_count
,
1028 part_size
=self
.part_size
, name
="rhs")
1029 self
.rhs_inputs
= [] # type: list[EvalOpGenIrInput]
1030 for part
, ssa_val
in enumerate(self
.rhs_parts
):
1031 self
.rhs_inputs
.append(EvalOpGenIrInput(
1033 is_signed
=rhs_signed
and part
== len(self
.rhs_parts
) - 1))
1034 self
.rhs_eval_state
= EvalOpGenIrState(fn
=fn
, inputs
=self
.rhs_inputs
)
1035 rhs_eval_ops
= self
.instance
.rhs_eval_ops
1036 self
.rhs_outputs
= [
1037 self
.rhs_eval_state
.get_output(i
) for i
in rhs_eval_ops
]
1038 self
.prod_inputs
= [] # type: list[EvalOpGenIrInput]
1039 for lhs_output
, rhs_output
in zip(self
.lhs_outputs
, self
.rhs_outputs
):
1040 ssa_val
= toom_cook_mul(
1042 lhs
=lhs_output
.output
, lhs_signed
=lhs_output
.is_signed
,
1043 rhs
=rhs_output
.output
, rhs_signed
=rhs_output
.is_signed
,
1044 instances
=instances
,
1045 start_instance_index
=start_instance_index
+ 1)
1046 products
= (lhs_output
.min_value
* rhs_output
.min_value
,
1047 lhs_output
.min_value
* rhs_output
.max_value
,
1048 lhs_output
.max_value
* rhs_output
.min_value
,
1049 lhs_output
.max_value
* rhs_output
.max_value
)
1050 self
.prod_inputs
.append(EvalOpGenIrInput(
1053 min_value
=min(products
),
1054 max_value
=max(products
)))
1055 self
.prod_eval_state
= EvalOpGenIrState(fn
=fn
, inputs
=self
.prod_inputs
)
1056 prod_eval_ops
= self
.instance
.prod_eval_ops
1058 self
.prod_eval_state
.get_output(i
) for i
in prod_eval_ops
]
1060 def partial_products():
1061 # type: () -> Iterable[PartialProduct]
1062 for part
, prod_part
in enumerate(self
.prod_parts
):
1063 part_maxvl
= prod_part
.output
.ty
.reg_len
1064 part_setvl
= fn
.append_new_op(
1065 OpKind
.SetVLI
, immediates
=[part_maxvl
],
1066 name
=f
"prod_{part}_setvl", maxvl
=part_maxvl
)
1067 spread_part
= fn
.append_new_op(
1069 input_vals
=[prod_part
.output
, part_setvl
.outputs
[0]],
1070 name
=f
"prod_{part}_spread", maxvl
=part_maxvl
)
1071 yield PartialProduct(
1072 spread_part
.outputs
, shift_in_words
=part
* self
.part_size
,
1073 is_signed
=prod_part
.is_signed
, subtract
=False)
1074 self
.partial_products
= tuple(partial_products())
1075 self
.retval
= sum_partial_products(
1076 fn
=fn
, partial_products
=self
.partial_products
,
1077 retval_size
=retval_size
, name
="prod")
1080 def toom_cook_mul(fn
, lhs
, lhs_signed
, rhs
, rhs_signed
, instances
,
1081 retval_size
=None, start_instance_index
=0):
1082 # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
1084 fn
=fn
, lhs
=lhs
, lhs_signed
=lhs_signed
, rhs
=rhs
, rhs_signed
=rhs_signed
,
1085 instances
=instances
, retval_size
=retval_size
,
1086 start_instance_index
=start_instance_index
).retval