2 Toom-Cook multiplication algorithm generator for SVP64
4 from abc
import ABCMeta
, abstractmethod
6 from fractions
import Fraction
7 from typing
import Any
, Generic
, Iterable
, Mapping
, TypeVar
, Union
9 from nmutil
.plain_data
import plain_data
11 from bigint_presentation_code
.compiler_ir
import (GPR_SIZE_IN_BITS
, Fn
, OpKind
,
13 from bigint_presentation_code
.matrix
import Matrix
14 from bigint_presentation_code
.type_util
import Literal
, final
18 class PointAtInfinity(Enum
):
19 POINT_AT_INFINITY
= "POINT_AT_INFINITY"
# Module-level alias for the single PointAtInfinity member, so evaluation
# points can be compared with `is POINT_AT_INFINITY`.
POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY

# Bits per machine word; the value-range code multiplies word counts by this.
WORD_BITS = GPR_SIZE_IN_BITS

# Everything accepted as the `coeffs` argument of EvalOpPoly: a mapping from
# variable index (None = constant term) to coefficient, another EvalOpPoly,
# a bare number, or None.
_EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]",
                                "EvalOpPoly", Fraction, int, None]
32 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
36 __slots__
= "const_coeff", "var_coeffs"
39 self
, coeffs
=None, # type: _EvalOpPolyCoefficients
40 const_coeff
=None, # type: Fraction | int | None
41 var_coeffs
=(), # type: Iterable[Fraction | int] | None
43 if coeffs
is not None:
44 if const_coeff
is not None or var_coeffs
!= ():
46 "can't specify const_coeff or "
47 "var_coeffs along with coeffs")
48 if isinstance(coeffs
, EvalOpPoly
):
49 self
.const_coeff
= coeffs
.const_coeff
50 self
.var_coeffs
= coeffs
.var_coeffs
52 if isinstance(coeffs
, (int, Fraction
)):
53 const_coeff
= Fraction(coeffs
)
54 final_var_coeffs
= [] # type: list[Fraction]
58 for var
, coeff
in coeffs
.items():
61 coeff
= Fraction(coeff
)
66 raise ValueError("invalid variable index")
67 if var
>= len(final_var_coeffs
):
68 additional
= var
- len(final_var_coeffs
)
69 final_var_coeffs
.extend((Fraction(),) * additional
)
70 final_var_coeffs
.append(coeff
)
72 final_var_coeffs
[var
] = coeff
74 if var_coeffs
is None:
77 final_var_coeffs
= [Fraction(v
) for v
in var_coeffs
]
78 while len(final_var_coeffs
) > 0 and final_var_coeffs
[-1] == 0:
79 final_var_coeffs
.pop()
80 if const_coeff
is None:
82 self
.const_coeff
= Fraction(const_coeff
)
83 self
.var_coeffs
= tuple(final_var_coeffs
)
85 def __add__(self
, rhs
):
86 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
88 const_coeff
= self
.const_coeff
+ rhs
.const_coeff
89 var_coeffs
= list(self
.var_coeffs
)
90 if len(rhs
.var_coeffs
) > len(var_coeffs
):
91 var_coeffs
.extend(rhs
.var_coeffs
[len(var_coeffs
):])
92 for var
in range(min(len(self
.var_coeffs
), len(rhs
.var_coeffs
))):
93 var_coeffs
[var
] += rhs
.var_coeffs
[var
]
94 return EvalOpPoly(const_coeff
=const_coeff
, var_coeffs
=var_coeffs
)
97 def coefficients(self
):
98 # type: () -> dict[int | None, Fraction]
99 retval
= {} # type: dict[int | None, Fraction]
100 if self
.const_coeff
!= 0:
101 retval
[None] = self
.const_coeff
102 for var
, coeff
in enumerate(self
.var_coeffs
):
110 return self
.var_coeffs
== ()
112 def coeff(self
, var
):
113 # type: (int | None) -> Fraction
115 return self
.const_coeff
117 raise ValueError("invalid variable index")
118 if var
< len(self
.var_coeffs
):
119 return self
.var_coeffs
[var
]
125 return EvalOpPoly(const_coeff
=-self
.const_coeff
,
126 var_coeffs
=(-v
for v
in self
.var_coeffs
))
128 def __sub__(self
, rhs
):
129 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
132 def __rsub__(self
, lhs
):
133 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
136 def __mul__(self
, rhs
):
137 # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
138 if isinstance(rhs
, EvalOpPoly
):
140 self
, rhs
= rhs
, self
142 raise ValueError("can't represent exponents larger than one")
143 rhs
= rhs
.const_coeff
146 return EvalOpPoly(const_coeff
=self
.const_coeff
* rhs
,
147 var_coeffs
=(i
* rhs
for i
in self
.var_coeffs
))
def __truediv__(self, rhs):
    # type: (int | Fraction) -> EvalOpPoly
    """Return a new polynomial with every coefficient divided by `rhs`.

    Args:
        rhs: the scalar divisor.

    Returns:
        An EvalOpPoly whose constant and variable coefficients are the
        corresponding coefficients of `self` divided by `rhs`.

    Raises:
        ZeroDivisionError: if `rhs` is zero.
    """
    # NOTE(review): the guard line was missing from the extracted source;
    # only the `raise` survived. `rhs == 0` is the condition implied by it.
    if rhs == 0:
        raise ZeroDivisionError()
    return EvalOpPoly(const_coeff=self.const_coeff / rhs,
                      var_coeffs=(i / rhs for i in self.var_coeffs))
159 return f
"EvalOpPoly({self.coefficients})"
162 @plain_data(frozen
=True, unsafe_hash
=True)
163 class EvalOpValueRange
:
164 __slots__
= ("eval_op", "inputs_words", "min_value", "max_value",
167 def __init__(self
, eval_op
, inputs_words
):
168 # type: (EvalOp[Any, Any], Iterable[int]) -> None
170 self
.eval_op
= eval_op
171 self
.inputs_words
= tuple(inputs_words
)
172 for words
in self
.inputs_words
:
174 raise ValueError(f
"invalid word count: {words}")
175 min_value
= max_value
= eval_op
.poly
.const_coeff
176 for var
, coeff
in enumerate(eval_op
.poly
.var_coeffs
):
180 var_max
= (1 << self
.inputs_words
[var
] * WORD_BITS
) - 1
181 term_min
= var_min
* coeff
182 term_max
= var_max
* coeff
183 if term_min
> term_max
:
184 term_min
, term_max
= term_max
, term_min
185 min_value
+= term_min
186 max_value
+= term_max
187 self
.min_value
= min_value
188 self
.max_value
= max_value
189 self
.is_signed
= min_value
< 0
# Constrained type variables for the two operands of an EvalOp: each operand
# is either a plain int or another EvalOp (forward reference, hence a string).
_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
196 @plain_data(frozen
=True, unsafe_hash
=True)
197 class EvalOp(Generic
[_EvalOpLHS
, _EvalOpRHS
], metaclass
=ABCMeta
):
198 __slots__
= "lhs", "rhs", "poly"
202 # type: () -> EvalOpPoly
203 if isinstance(self
.lhs
, int):
204 return EvalOpPoly(self
.lhs
)
209 # type: () -> EvalOpPoly
210 if isinstance(self
.rhs
, int):
211 return EvalOpPoly(self
.rhs
)
215 def _make_poly(self
):
216 # type: () -> EvalOpPoly
219 def __init__(self
, lhs
, rhs
):
220 # type: (_EvalOpLHS, _EvalOpRHS) -> None
224 self
.poly
= self
._make
_poly
()
227 @plain_data(frozen
=True, unsafe_hash
=True)
229 class EvalOpAdd(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
def _make_poly(self):
    # type: () -> EvalOpPoly
    """Return the polynomial of this addition: lhs polynomial + rhs polynomial."""
    return self.lhs_poly + self.rhs_poly
237 @plain_data(frozen
=True, unsafe_hash
=True)
239 class EvalOpSub(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
def _make_poly(self):
    # type: () -> EvalOpPoly
    """Return the polynomial of this subtraction: lhs polynomial - rhs polynomial."""
    return self.lhs_poly - self.rhs_poly
247 @plain_data(frozen
=True, unsafe_hash
=True)
249 class EvalOpMul(EvalOp
[_EvalOpLHS
, int]):
def _make_poly(self):
    # type: () -> EvalOpPoly
    """Return the lhs polynomial multiplied by `rhs`.

    `rhs` is used directly rather than via `rhs_poly`; the enclosing class
    declares its right operand type as `int`.
    """
    return self.lhs_poly * self.rhs
257 @plain_data(frozen
=True, unsafe_hash
=True)
259 class EvalOpExactDiv(EvalOp
[_EvalOpLHS
, int]):
def _make_poly(self):
    # type: () -> EvalOpPoly
    """Return the lhs polynomial divided by `rhs`.

    `rhs` is used directly rather than via `rhs_poly`; the enclosing class
    declares its right operand type as `int`.
    """
    return self.lhs_poly / self.rhs
267 @plain_data(frozen
=True, unsafe_hash
=True)
269 class EvalOpInput(EvalOp
[int, Literal
[0]]):
272 def __init__(self
, lhs
, rhs
=0):
273 # type: (int, int) -> None
275 raise ValueError("Input part_index (lhs) must be >= 0")
277 raise ValueError("Input rhs must be 0")
278 super().__init
__(lhs
, rhs
)
281 def part_index(self
):
def _make_poly(self):
    # type: () -> EvalOpPoly
    """Return the polynomial for one input part.

    The result has coefficient 1 on the variable whose index is
    `self.part_index` and no other terms.
    """
    return EvalOpPoly({self.part_index: 1})
289 @plain_data(frozen
=True, unsafe_hash
=True)
291 class ToomCookInstance
:
292 __slots__
= ("lhs_part_count", "rhs_part_count", "eval_points",
293 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
def prod_part_count(self):
    """Return the number of parts in the product.

    Computed as `lhs_part_count + rhs_part_count - 1`.
    """
    # NOTE(review): other code in this file reads `self.prod_part_count`
    # without calling it, so this is presumably wrapped in `@property` on a
    # line not visible here -- confirm.
    return self.lhs_part_count + self.rhs_part_count - 1
def make_eval_matrix(width, eval_points):
    # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
    """Build the evaluation matrix for `eval_points` with `width` columns.

    Row `r` corresponds to `eval_points[r]`. For an ordinary point the entry
    at `[r, c]` is `eval_points[r] ** c`. For POINT_AT_INFINITY the row is
    zero everywhere except a 1 in the last column.

    Args:
        width: number of columns.
        eval_points: one evaluation point per row.

    Returns:
        The filled matrix, with `len(eval_points)` rows and `width` columns.
    """
    # NOTE(review): takes no `self` but is called as `self.make_eval_matrix`
    # elsewhere in this file, so it is presumably a `@staticmethod`; the
    # decorator line is not visible here -- confirm.
    retval = Matrix(height=len(eval_points), width=width)
    for row, col in retval.indexes():
        eval_point = eval_points[row]
        if eval_point is POINT_AT_INFINITY:
            # Only the highest-order column is selected.
            retval[row, col] = int(col == width - 1)
        else:
            # NOTE(review): this `else:` and the final `return retval` were
            # missing from the extracted source and have been reconstructed.
            retval[row, col] = eval_point ** col
    return retval
def get_lhs_eval_matrix(self):
    # type: () -> Matrix[Fraction]
    """Return the evaluation matrix for the lhs operand.

    Built by `make_eval_matrix` with `lhs_part_count` columns and one row
    per entry in `eval_points`.
    """
    return self.make_eval_matrix(self.lhs_part_count, self.eval_points)
def make_input_poly_vector(height):
    # type: (int) -> Matrix[EvalOpPoly]
    """Return a `height` x 1 matrix of single-variable input polynomials.

    Element `i` is `EvalOpPoly({i: 1})`: coefficient 1 on variable `i`.
    """
    # NOTE(review): takes no `self` but is called through `self` elsewhere in
    # this file, so it is presumably a `@staticmethod`; the decorator line is
    # not visible here -- confirm.
    return Matrix(height=height, width=1, element_type=EvalOpPoly,
                  data=(EvalOpPoly({i: 1}) for i in range(height)))
def get_lhs_eval_polys(self):
    # type: () -> list[EvalOpPoly]
    """Return the expected polynomials for the lhs evaluation ops.

    The lhs evaluation matrix is cast to EvalOpPoly elements and
    matrix-multiplied by the column vector of `lhs_part_count` input
    polynomials; the product is then collected with `list()`.
    """
    return list(self.get_lhs_eval_matrix().cast(EvalOpPoly)
                @ self.make_input_poly_vector(self.lhs_part_count))
def get_rhs_eval_matrix(self):
    # type: () -> Matrix[Fraction]
    """Return the evaluation matrix for the rhs operand.

    Built by `make_eval_matrix` with `rhs_part_count` columns and one row
    per entry in `eval_points`.
    """
    return self.make_eval_matrix(self.rhs_part_count, self.eval_points)
def get_rhs_eval_polys(self):
    # type: () -> list[EvalOpPoly]
    """Return the expected polynomials for the rhs evaluation ops.

    The rhs evaluation matrix is cast to EvalOpPoly elements and
    matrix-multiplied by the column vector of `rhs_part_count` input
    polynomials; the product is then collected with `list()`.
    """
    return list(self.get_rhs_eval_matrix().cast(EvalOpPoly)
                @ self.make_input_poly_vector(self.rhs_part_count))
def get_prod_inverse_eval_matrix(self):
    # type: () -> Matrix[Fraction]
    """Return the evaluation matrix at product width.

    Built by `make_eval_matrix` with `prod_part_count` columns and one row
    per entry in `eval_points`. `get_prod_eval_matrix` returns the inverse
    of this matrix.
    """
    return self.make_eval_matrix(self.prod_part_count, self.eval_points)
def get_prod_eval_matrix(self):
    # type: () -> Matrix[Fraction]
    """Return the inverse of the matrix from `get_prod_inverse_eval_matrix`."""
    return self.get_prod_inverse_eval_matrix().inverse()
def get_prod_eval_polys(self):
    # type: () -> list[EvalOpPoly]
    """Return the expected polynomials for the product evaluation ops.

    The product evaluation matrix (the inverted one) is cast to EvalOpPoly
    elements and matrix-multiplied by the column vector of
    `prod_part_count` input polynomials; the product is then collected
    with `list()`.
    """
    return list(self.get_prod_eval_matrix().cast(EvalOpPoly)
                @ self.make_input_poly_vector(self.prod_part_count))
349 self
, lhs_part_count
, # type: int
350 rhs_part_count
, # type: int
351 eval_points
, # type: Iterable[PointAtInfinity | int]
352 lhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
353 rhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
354 prod_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
356 # type: (...) -> None
357 self
.lhs_part_count
= lhs_part_count
358 if self
.lhs_part_count
< 2:
359 raise ValueError("lhs_part_count must be at least 2")
360 self
.rhs_part_count
= rhs_part_count
361 if self
.rhs_part_count
< 2:
362 raise ValueError("rhs_part_count must be at least 2")
363 eval_points
= list(eval_points
)
364 self
.eval_points
= tuple(eval_points
)
365 if len(self
.eval_points
) != len(set(self
.eval_points
)):
366 raise ValueError("duplicate eval points")
367 self
.lhs_eval_ops
= tuple(lhs_eval_ops
)
368 if len(self
.lhs_eval_ops
) != self
.prod_part_count
:
369 raise ValueError("wrong number of lhs_eval_ops")
370 self
.rhs_eval_ops
= tuple(rhs_eval_ops
)
371 if len(self
.rhs_eval_ops
) != self
.prod_part_count
:
372 raise ValueError("wrong number of rhs_eval_ops")
373 if len(self
.eval_points
) != self
.prod_part_count
:
374 raise ValueError("wrong number of eval_points")
375 self
.prod_eval_ops
= tuple(prod_eval_ops
)
376 if len(self
.prod_eval_ops
) != self
.prod_part_count
:
377 raise ValueError("wrong number of prod_eval_ops")
379 lhs_eval_polys
= self
.get_lhs_eval_polys()
380 for i
, eval_op
in enumerate(self
.lhs_eval_ops
):
381 if lhs_eval_polys
[i
] != eval_op
.poly
:
383 f
"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
384 f
"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
386 rhs_eval_polys
= self
.get_rhs_eval_polys()
387 for i
, eval_op
in enumerate(self
.rhs_eval_ops
):
388 if rhs_eval_polys
[i
] != eval_op
.poly
:
390 f
"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
391 f
"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
393 prod_eval_polys
= self
.get_prod_eval_polys() # also checks matrix
394 for i
, eval_op
in enumerate(self
.prod_eval_ops
):
395 if prod_eval_polys
[i
] != eval_op
.poly
:
397 f
"prod_eval_ops[{i}] is incorrect: expected polynomial: "
398 f
"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
401 # type: () -> ToomCookInstance
402 """return a ToomCookInstance where lhs/rhs are reversed"""
403 return ToomCookInstance(
404 lhs_part_count
=self
.rhs_part_count
,
405 rhs_part_count
=self
.lhs_part_count
,
406 eval_points
=self
.eval_points
,
407 lhs_eval_ops
=self
.rhs_eval_ops
,
408 rhs_eval_ops
=self
.lhs_eval_ops
,
409 prod_eval_ops
=self
.prod_eval_ops
)
413 # type: () -> ToomCookInstance
414 """make an instance of Toom-2 aka Karatsuba multiplication"""
415 return ToomCookInstance(
418 eval_points
=[0, 1, POINT_AT_INFINITY
],
421 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
426 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
431 EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
439 # type: () -> ToomCookInstance
440 """makes an instance of Toom-2.5"""
441 inp_0_plus_inp_2
= EvalOpAdd(EvalOpInput(0), EvalOpInput(2))
442 inp_1_minus_inp_2
= EvalOpSub(EvalOpInput(1), EvalOpInput(2))
443 inp_1_plus_inp_2
= EvalOpAdd(EvalOpInput(1), EvalOpInput(2))
444 inp_1_minus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_minus_inp_2
, 2)
445 inp_1_plus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_plus_inp_2
, 2)
446 return ToomCookInstance(
449 eval_points
=[0, 1, -1, POINT_AT_INFINITY
],
452 EvalOpAdd(inp_0_plus_inp_2
, EvalOpInput(1)),
453 EvalOpSub(inp_0_plus_inp_2
, EvalOpInput(1)),
458 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
459 EvalOpSub(EvalOpInput(0), EvalOpInput(1)),
464 EvalOpSub(inp_1_minus_inp_2_all_div_2
, EvalOpInput(3)),
465 EvalOpSub(inp_1_plus_inp_2_all_div_2
, EvalOpInput(0)),
# TODO: add make_toom_3
473 def simple_mul(fn
, lhs
, rhs
):
474 # type: (Fn, SSAVal, SSAVal) -> SSAVal
475 """ simple O(n^2) big-int unsigned multiply """
476 if lhs
.ty
.reg_len
< rhs
.ty
.reg_len
:
478 # split rhs into elements
479 rhs_setvl
= fn
.append_new_op(kind
=OpKind
.SetVLI
,
480 immediates
=[rhs
.ty
.reg_len
], name
="rhs_setvl")
481 rhs_spread
= fn
.append_new_op(
482 kind
=OpKind
.Spread
, input_vals
=[rhs
, rhs_setvl
.outputs
[0]],
483 maxvl
=rhs
.ty
.reg_len
, name
="rhs_spread")
484 rhs_words
= rhs_spread
.outputs
485 spread_retval
= None # type: tuple[SSAVal, ...] | None
486 maxvl
= lhs
.ty
.reg_len
487 lhs_setvl
= fn
.append_new_op(kind
=OpKind
.SetVLI
,
488 immediates
=[lhs
.ty
.reg_len
], name
="lhs_setvl")
489 vl
= lhs_setvl
.outputs
[0]
490 zero_op
= fn
.append_new_op(kind
=OpKind
.LI
, immediates
=[0], name
="zero")
491 zero
= zero_op
.outputs
[0]
492 for shift
, rhs_word
in enumerate(rhs_words
):
493 mul
= fn
.append_new_op(kind
=OpKind
.SvMAddEDU
,
494 input_vals
=[lhs
, rhs_word
, zero
, vl
],
495 maxvl
=maxvl
, name
=f
"mul{shift}")
496 if spread_retval
is None:
497 mul_rt_spread
= fn
.append_new_op(
498 kind
=OpKind
.Spread
, input_vals
=[mul
.outputs
[0], vl
],
499 name
=f
"mul{shift}_rt_spread", maxvl
=maxvl
)
500 spread_retval
= (*mul_rt_spread
.outputs
, mul
.outputs
[1])
502 first_part
= spread_retval
[:shift
] # type: tuple[SSAVal, ...]
503 last_part
= spread_retval
[shift
:]
505 add_rb_concat
= fn
.append_new_op(
506 kind
=OpKind
.Concat
, input_vals
=[*last_part
, vl
],
507 name
=f
"add{shift}_rb_concat", maxvl
=maxvl
)
508 clear_ca
= fn
.append_new_op(kind
=OpKind
.ClearCA
,
509 name
=f
"clear_ca{shift}")
510 add
= fn
.append_new_op(
511 kind
=OpKind
.SvAddE
, input_vals
=[
512 mul
.outputs
[0], add_rb_concat
.outputs
[0],
513 clear_ca
.outputs
[0], vl
],
514 maxvl
=maxvl
, name
=f
"add{shift}")
515 add_rt_spread
= fn
.append_new_op(
516 kind
=OpKind
.Spread
, input_vals
=[add
.outputs
[0], vl
],
517 name
=f
"add{shift}_rt_spread", maxvl
=maxvl
)
518 add_hi
= fn
.append_new_op(
519 kind
=OpKind
.AddZE
, input_vals
=[mul
.outputs
[1], add
.outputs
[1]],
520 name
=f
"add_hi{shift}")
522 *first_part
, *add_rt_spread
.outputs
, add_hi
.outputs
[0])
523 assert spread_retval
is not None
524 lhs_setvl
= fn
.append_new_op(
525 kind
=OpKind
.SetVLI
, immediates
=[len(spread_retval
)],
527 concat_retval
= fn
.append_new_op(
528 kind
=OpKind
.Concat
, input_vals
=[*spread_retval
, lhs_setvl
.outputs
[0]],
529 name
="concat_retval", maxvl
=len(spread_retval
))
530 return concat_retval
.outputs
[0]
def toom_cook_mul(fn, lhs, rhs, instances):
    # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal
    """Placeholder for Toom-Cook multiplication; not implemented yet.

    Args:
        fn: the function being built.
        lhs: left-hand operand.
        rhs: right-hand operand.
        instances: Toom-Cook instances; currently unused.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError