2 Toom-Cook multiplication algorithm generator for SVP64
4 from abc
import abstractmethod
6 from fractions
import Fraction
7 from typing
import Any
, Generic
, Iterable
, Mapping
, TypeVar
, Union
9 from nmutil
.plain_data
import plain_data
11 from bigint_presentation_code
.compiler_ir
import (Fn
, OpBigIntAddSub
,
12 OpBigIntMulDiv
, OpConcat
,
13 OpLI
, OpSetCA
, OpSetVLImm
,
15 from bigint_presentation_code
.matrix
import Matrix
16 from bigint_presentation_code
.type_util
import Literal
, final
20 class PointAtInfinity(Enum
):
21 POINT_AT_INFINITY
= "POINT_AT_INFINITY"
27 POINT_AT_INFINITY
= PointAtInfinity
.POINT_AT_INFINITY
30 _EvalOpPolyCoefficients
= Union
["Mapping[int | None, Fraction | int]",
31 "EvalOpPoly", Fraction
, int, None]
34 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
38 __slots__
= "const_coeff", "var_coeffs"
41 self
, coeffs
=None, # type: _EvalOpPolyCoefficients
42 const_coeff
=None, # type: Fraction | int | None
43 var_coeffs
=(), # type: Iterable[Fraction | int] | None
45 if coeffs
is not None:
46 if const_coeff
is not None or var_coeffs
!= ():
48 "can't specify const_coeff or "
49 "var_coeffs along with coeffs")
50 if isinstance(coeffs
, EvalOpPoly
):
51 self
.const_coeff
= coeffs
.const_coeff
52 self
.var_coeffs
= coeffs
.var_coeffs
54 if isinstance(coeffs
, (int, Fraction
)):
55 const_coeff
= Fraction(coeffs
)
56 final_var_coeffs
= [] # type: list[Fraction]
60 for var
, coeff
in coeffs
.items():
63 coeff
= Fraction(coeff
)
68 raise ValueError("invalid variable index")
69 if var
>= len(final_var_coeffs
):
70 additional
= var
- len(final_var_coeffs
)
71 final_var_coeffs
.extend((Fraction(),) * additional
)
72 final_var_coeffs
.append(coeff
)
74 final_var_coeffs
[var
] = coeff
76 if var_coeffs
is None:
79 final_var_coeffs
= [Fraction(v
) for v
in var_coeffs
]
80 while len(final_var_coeffs
) > 0 and final_var_coeffs
[-1] == 0:
81 final_var_coeffs
.pop()
82 if const_coeff
is None:
84 self
.const_coeff
= Fraction(const_coeff
)
85 self
.var_coeffs
= tuple(final_var_coeffs
)
87 def __add__(self
, rhs
):
88 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
90 const_coeff
= self
.const_coeff
+ rhs
.const_coeff
91 var_coeffs
= list(self
.var_coeffs
)
92 if len(rhs
.var_coeffs
) > len(var_coeffs
):
93 var_coeffs
.extend(rhs
.var_coeffs
[len(var_coeffs
):])
94 for var
in range(min(len(self
.var_coeffs
), len(rhs
.var_coeffs
))):
95 var_coeffs
[var
] += rhs
.var_coeffs
[var
]
96 return EvalOpPoly(const_coeff
=const_coeff
, var_coeffs
=var_coeffs
)
99 def coefficients(self
):
100 # type: () -> dict[int | None, Fraction]
101 retval
= {} # type: dict[int | None, Fraction]
102 if self
.const_coeff
!= 0:
103 retval
[None] = self
.const_coeff
104 for var
, coeff
in enumerate(self
.var_coeffs
):
112 return self
.var_coeffs
== ()
114 def coeff(self
, var
):
115 # type: (int | None) -> Fraction
117 return self
.const_coeff
119 raise ValueError("invalid variable index")
120 if var
< len(self
.var_coeffs
):
121 return self
.var_coeffs
[var
]
127 return EvalOpPoly(const_coeff
=-self
.const_coeff
,
128 var_coeffs
=(-v
for v
in self
.var_coeffs
))
130 def __sub__(self
, rhs
):
131 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
134 def __rsub__(self
, lhs
):
135 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
138 def __mul__(self
, rhs
):
139 # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
140 if isinstance(rhs
, EvalOpPoly
):
142 self
, rhs
= rhs
, self
144 raise ValueError("can't represent exponents larger than one")
145 rhs
= rhs
.const_coeff
148 return EvalOpPoly(const_coeff
=self
.const_coeff
* rhs
,
149 var_coeffs
=(i
* rhs
for i
in self
.var_coeffs
))
153 def __truediv__(self
, rhs
):
154 # type: (int | Fraction) -> EvalOpPoly
156 raise ZeroDivisionError()
157 return EvalOpPoly(const_coeff
=self
.const_coeff
/ rhs
,
158 var_coeffs
=(i
/ rhs
for i
in self
.var_coeffs
))
161 return f
"EvalOpPoly({self.coefficients})"
164 _EvalOpLHS
= TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
165 _EvalOpRHS
= TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
168 @plain_data(frozen
=True, unsafe_hash
=True)
169 class EvalOp(Generic
[_EvalOpLHS
, _EvalOpRHS
]):
170 __slots__
= "lhs", "rhs", "poly"
174 # type: () -> EvalOpPoly
175 if isinstance(self
.lhs
, int):
176 return EvalOpPoly(self
.lhs
)
181 # type: () -> EvalOpPoly
182 if isinstance(self
.rhs
, int):
183 return EvalOpPoly(self
.rhs
)
187 def _make_poly(self
):
188 # type: () -> EvalOpPoly
191 def __init__(self
, lhs
, rhs
):
192 # type: (_EvalOpLHS, _EvalOpRHS) -> None
195 self
.poly
= self
._make
_poly
()
198 @plain_data(frozen
=True, unsafe_hash
=True)
200 class EvalOpAdd(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
203 def _make_poly(self
):
204 # type: () -> EvalOpPoly
205 return self
.lhs_poly
+ self
.rhs_poly
208 @plain_data(frozen
=True, unsafe_hash
=True)
210 class EvalOpSub(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
213 def _make_poly(self
):
214 # type: () -> EvalOpPoly
215 return self
.lhs_poly
- self
.rhs_poly
218 @plain_data(frozen
=True, unsafe_hash
=True)
220 class EvalOpMul(EvalOp
[_EvalOpLHS
, int]):
223 def _make_poly(self
):
224 # type: () -> EvalOpPoly
225 return self
.lhs_poly
* self
.rhs
228 @plain_data(frozen
=True, unsafe_hash
=True)
230 class EvalOpExactDiv(EvalOp
[_EvalOpLHS
, int]):
233 def _make_poly(self
):
234 # type: () -> EvalOpPoly
235 return self
.lhs_poly
/ self
.rhs
238 @plain_data(frozen
=True, unsafe_hash
=True)
240 class EvalOpInput(EvalOp
[int, Literal
[0]]):
243 def __init__(self
, lhs
, rhs
=0):
244 # type: (int, int) -> None
246 raise ValueError("Input part_index (lhs) must be >= 0")
248 raise ValueError("Input rhs must be 0")
249 super().__init
__(lhs
, rhs
)
252 def part_index(self
):
255 def _make_poly(self
):
256 # type: () -> EvalOpPoly
257 return EvalOpPoly({self
.part_index
: 1})
260 @plain_data(frozen
=True, unsafe_hash
=True)
262 class ToomCookInstance
:
263 __slots__
= ("lhs_part_count", "rhs_part_count", "eval_points",
264 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
267 def prod_part_count(self
):
268 return self
.lhs_part_count
+ self
.rhs_part_count
- 1
271 def make_eval_matrix(width
, eval_points
):
272 # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
273 retval
= Matrix(height
=len(eval_points
), width
=width
)
274 for row
, col
in retval
.indexes():
275 eval_point
= eval_points
[row
]
276 if eval_point
is POINT_AT_INFINITY
:
277 retval
[row
, col
] = int(col
== width
- 1)
279 retval
[row
, col
] = eval_point
** col
282 def get_lhs_eval_matrix(self
):
283 # type: () -> Matrix[Fraction]
284 return self
.make_eval_matrix(self
.lhs_part_count
, self
.eval_points
)
287 def make_input_poly_vector(height
):
288 # type: (int) -> Matrix[EvalOpPoly]
289 return Matrix(height
=height
, width
=1, element_type
=EvalOpPoly
,
290 data
=(EvalOpPoly({i
: 1}) for i
in range(height
)))
292 def get_lhs_eval_polys(self
):
293 # type: () -> list[EvalOpPoly]
294 return list(self
.get_lhs_eval_matrix().cast(EvalOpPoly
)
295 @ self
.make_input_poly_vector(self
.lhs_part_count
))
297 def get_rhs_eval_matrix(self
):
298 # type: () -> Matrix[Fraction]
299 return self
.make_eval_matrix(self
.rhs_part_count
, self
.eval_points
)
301 def get_rhs_eval_polys(self
):
302 # type: () -> list[EvalOpPoly]
303 return list(self
.get_rhs_eval_matrix().cast(EvalOpPoly
)
304 @ self
.make_input_poly_vector(self
.rhs_part_count
))
306 def get_prod_inverse_eval_matrix(self
):
307 # type: () -> Matrix[Fraction]
308 return self
.make_eval_matrix(self
.prod_part_count
, self
.eval_points
)
310 def get_prod_eval_matrix(self
):
311 # type: () -> Matrix[Fraction]
312 return self
.get_prod_inverse_eval_matrix().inverse()
314 def get_prod_eval_polys(self
):
315 # type: () -> list[EvalOpPoly]
316 return list(self
.get_prod_eval_matrix().cast(EvalOpPoly
)
317 @ self
.make_input_poly_vector(self
.prod_part_count
))
320 self
, lhs_part_count
, # type: int
321 rhs_part_count
, # type: int
322 eval_points
, # type: Iterable[PointAtInfinity | int]
323 lhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
324 rhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
325 prod_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
327 # type: (...) -> None
328 self
.lhs_part_count
= lhs_part_count
329 if self
.lhs_part_count
< 2:
330 raise ValueError("lhs_part_count must be at least 2")
331 self
.rhs_part_count
= rhs_part_count
332 if self
.rhs_part_count
< 2:
333 raise ValueError("rhs_part_count must be at least 2")
334 eval_points
= list(eval_points
)
335 self
.eval_points
= tuple(eval_points
)
336 if len(self
.eval_points
) != len(set(self
.eval_points
)):
337 raise ValueError("duplicate eval points")
338 self
.lhs_eval_ops
= tuple(lhs_eval_ops
)
339 if len(self
.lhs_eval_ops
) != self
.prod_part_count
:
340 raise ValueError("wrong number of lhs_eval_ops")
341 self
.rhs_eval_ops
= tuple(rhs_eval_ops
)
342 if len(self
.rhs_eval_ops
) != self
.prod_part_count
:
343 raise ValueError("wrong number of rhs_eval_ops")
344 if len(self
.eval_points
) != self
.prod_part_count
:
345 raise ValueError("wrong number of eval_points")
346 self
.prod_eval_ops
= tuple(prod_eval_ops
)
347 if len(self
.prod_eval_ops
) != self
.prod_part_count
:
348 raise ValueError("wrong number of prod_eval_ops")
350 lhs_eval_polys
= self
.get_lhs_eval_polys()
351 for i
, eval_op
in enumerate(self
.lhs_eval_ops
):
352 if lhs_eval_polys
[i
] != eval_op
.poly
:
354 f
"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
355 f
"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
357 rhs_eval_polys
= self
.get_rhs_eval_polys()
358 for i
, eval_op
in enumerate(self
.rhs_eval_ops
):
359 if rhs_eval_polys
[i
] != eval_op
.poly
:
361 f
"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
362 f
"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
364 prod_eval_polys
= self
.get_prod_eval_polys() # also checks matrix
365 for i
, eval_op
in enumerate(self
.prod_eval_ops
):
366 if prod_eval_polys
[i
] != eval_op
.poly
:
368 f
"prod_eval_ops[{i}] is incorrect: expected polynomial: "
369 f
"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
372 # type: () -> ToomCookInstance
373 """return a ToomCookInstance where lhs/rhs are reversed"""
374 return ToomCookInstance(
375 lhs_part_count
=self
.rhs_part_count
,
376 rhs_part_count
=self
.lhs_part_count
,
377 eval_points
=self
.eval_points
,
378 lhs_eval_ops
=self
.rhs_eval_ops
,
379 rhs_eval_ops
=self
.lhs_eval_ops
,
380 prod_eval_ops
=self
.prod_eval_ops
)
384 # type: () -> ToomCookInstance
385 """make an instance of Toom-2 aka Karatsuba multiplication"""
386 return ToomCookInstance(
389 eval_points
=[0, 1, POINT_AT_INFINITY
],
392 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
397 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
402 EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
410 # type: () -> ToomCookInstance
411 """makes an instance of Toom-2.5"""
412 inp_0_plus_inp_2
= EvalOpAdd(EvalOpInput(0), EvalOpInput(2))
413 inp_1_minus_inp_2
= EvalOpSub(EvalOpInput(1), EvalOpInput(2))
414 inp_1_plus_inp_2
= EvalOpAdd(EvalOpInput(1), EvalOpInput(2))
415 inp_1_minus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_minus_inp_2
, 2)
416 inp_1_plus_inp_2_all_div_2
= EvalOpExactDiv(inp_1_plus_inp_2
, 2)
417 return ToomCookInstance(
420 eval_points
=[0, 1, -1, POINT_AT_INFINITY
],
423 EvalOpAdd(inp_0_plus_inp_2
, EvalOpInput(1)),
424 EvalOpSub(inp_0_plus_inp_2
, EvalOpInput(1)),
429 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
430 EvalOpSub(EvalOpInput(0), EvalOpInput(1)),
435 EvalOpSub(inp_1_minus_inp_2_all_div_2
, EvalOpInput(3)),
436 EvalOpSub(inp_1_plus_inp_2_all_div_2
, EvalOpInput(0)),
441 # TODO: add make_toom_3
444 def simple_mul(fn
, lhs
, rhs
):
445 # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
446 """ simple O(n^2) big-int unsigned multiply """
447 if lhs
.ty
.length
< rhs
.ty
.length
:
449 # split rhs into elements
450 rhs_words
= OpSplit(fn
, rhs
, range(1, rhs
.ty
.length
)).results
452 vl
= OpSetVLImm(fn
, lhs
.ty
.length
).out
453 zero
= OpLI(fn
, 0).out
454 for shift
, rhs_word
in enumerate(rhs_words
):
455 mul
= OpBigIntMulDiv(fn
, RA
=lhs
, RB
=rhs_word
, RC
=zero
,
458 retval
= OpConcat(fn
, [mul
.RT
, mul
.RS
]).dest
460 first_part
, last_part
= OpSplit(fn
, retval
, [shift
]).results
461 add
= OpBigIntAddSub(
462 fn
, lhs
=mul
.RT
, rhs
=last_part
, CA_in
=OpSetCA(fn
, False).out
,
464 add_hi
= OpBigIntAddSub(fn
, lhs
=mul
.RS
, rhs
=zero
, CA_in
=add
.CA_out
,
466 retval
= OpConcat(fn
, [first_part
, add
.out
, add_hi
.out
]).dest
467 assert retval
is not None
471 def toom_cook_mul(fn
, lhs
, rhs
, instances
):
472 # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
473 raise NotImplementedError