76a8a994b11d1123e1233b3c8e4a03857ebdc437
2 Toom-Cook multiplication algorithm generator for SVP64
4 from abc
import abstractmethod
6 from fractions
import Fraction
7 from typing
import Any
, Generic
, Iterable
, Mapping
, Sequence
, TypeVar
, Union
9 from nmutil
.plain_data
import plain_data
11 from bigint_presentation_code
.compiler_ir
import Fn
, Op
12 from bigint_presentation_code
.matrix
import Matrix
13 from bigint_presentation_code
.util
import Literal
, OSet
, final
17 class PointAtInfinity(Enum
):
18 POINT_AT_INFINITY
= "POINT_AT_INFINITY"
24 POINT_AT_INFINITY
= PointAtInfinity
.POINT_AT_INFINITY
27 _EvalOpPolyCoefficients
= Union
["Mapping[int | None, Fraction | int]",
28 "EvalOpPoly", Fraction
, int, None]
31 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
35 __slots__
= "const_coeff", "var_coeffs"
38 self
, coeffs
=None, # type: _EvalOpPolyCoefficients
39 const_coeff
=None, # type: Fraction | int | None
40 var_coeffs
=(), # type: Iterable[Fraction | int] | None
42 if coeffs
is not None:
43 if const_coeff
is not None or var_coeffs
!= ():
45 "can't specify const_coeff or "
46 "var_coeffs along with coeffs")
47 if isinstance(coeffs
, EvalOpPoly
):
48 self
.const_coeff
= coeffs
.const_coeff
49 self
.var_coeffs
= coeffs
.var_coeffs
51 if isinstance(coeffs
, (int, Fraction
)):
52 const_coeff
= Fraction(coeffs
)
53 final_var_coeffs
= [] # type: list[Fraction]
57 for var
, coeff
in coeffs
.items():
60 coeff
= Fraction(coeff
)
65 raise ValueError("invalid variable index")
66 if var
>= len(final_var_coeffs
):
67 additional
= var
- len(final_var_coeffs
)
68 final_var_coeffs
.extend((Fraction(),) * additional
)
69 final_var_coeffs
.append(coeff
)
71 final_var_coeffs
[var
] = coeff
73 if var_coeffs
is None:
76 final_var_coeffs
= [Fraction(v
) for v
in var_coeffs
]
77 while len(final_var_coeffs
) > 0 and final_var_coeffs
[-1] == 0:
78 final_var_coeffs
.pop()
79 if const_coeff
is None:
81 self
.const_coeff
= Fraction(const_coeff
)
82 self
.var_coeffs
= tuple(final_var_coeffs
)
84 def __add__(self
, rhs
):
85 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
87 const_coeff
= self
.const_coeff
+ rhs
.const_coeff
88 var_coeffs
= list(self
.var_coeffs
)
89 if len(rhs
.var_coeffs
) > len(var_coeffs
):
90 var_coeffs
.extend(rhs
.var_coeffs
[len(var_coeffs
):])
91 for var
in range(min(len(self
.var_coeffs
), len(rhs
.var_coeffs
))):
92 var_coeffs
[var
] += rhs
.var_coeffs
[var
]
93 return EvalOpPoly(const_coeff
=const_coeff
, var_coeffs
=var_coeffs
)
96 def coefficients(self
):
97 # type: () -> dict[int | None, Fraction]
98 retval
= {} # type: dict[int | None, Fraction]
99 if self
.const_coeff
!= 0:
100 retval
[None] = self
.const_coeff
101 for var
, coeff
in enumerate(self
.var_coeffs
):
109 return self
.var_coeffs
== ()
111 def coeff(self
, var
):
112 # type: (int | None) -> Fraction
114 return self
.const_coeff
116 raise ValueError("invalid variable index")
117 if var
< len(self
.var_coeffs
):
118 return self
.var_coeffs
[var
]
124 return EvalOpPoly(const_coeff
=-self
.const_coeff
,
125 var_coeffs
=(-v
for v
in self
.var_coeffs
))
127 def __sub__(self
, rhs
):
128 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
131 def __rsub__(self
, lhs
):
132 # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
135 def __mul__(self
, rhs
):
136 # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
137 if isinstance(rhs
, EvalOpPoly
):
139 self
, rhs
= rhs
, self
141 raise ValueError("can't represent exponents larger than one")
142 rhs
= rhs
.const_coeff
145 return EvalOpPoly(const_coeff
=self
.const_coeff
* rhs
,
146 var_coeffs
=(i
* rhs
for i
in self
.var_coeffs
))
150 def __truediv__(self
, rhs
):
151 # type: (int | Fraction) -> EvalOpPoly
153 raise ZeroDivisionError()
154 return EvalOpPoly(const_coeff
=self
.const_coeff
/ rhs
,
155 var_coeffs
=(i
/ rhs
for i
in self
.var_coeffs
))
158 return f
"EvalOpPoly({self.coefficients})"
161 _EvalOpLHS
= TypeVar("_EvalOpLHS", int, "EvalOp")
162 _EvalOpRHS
= TypeVar("_EvalOpRHS", int, "EvalOp")
165 @plain_data(frozen
=True, unsafe_hash
=True)
166 class EvalOp(Generic
[_EvalOpLHS
, _EvalOpRHS
]):
167 __slots__
= "lhs", "rhs", "poly"
171 # type: () -> EvalOpPoly
172 if isinstance(self
.lhs
, int):
173 return EvalOpPoly(self
.lhs
)
178 # type: () -> EvalOpPoly
179 if isinstance(self
.rhs
, int):
180 return EvalOpPoly(self
.rhs
)
184 def _make_poly(self
):
185 # type: () -> EvalOpPoly
188 def __init__(self
, lhs
, rhs
):
189 # type: (_EvalOpLHS, _EvalOpRHS) -> None
192 self
.poly
= self
._make
_poly
()
195 @plain_data(frozen
=True, unsafe_hash
=True)
197 class EvalOpAdd(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
200 def _make_poly(self
):
201 # type: () -> EvalOpPoly
202 return self
.lhs_poly
+ self
.rhs_poly
205 @plain_data(frozen
=True, unsafe_hash
=True)
207 class EvalOpSub(EvalOp
[_EvalOpLHS
, _EvalOpRHS
]):
210 def _make_poly(self
):
211 # type: () -> EvalOpPoly
212 return self
.lhs_poly
- self
.rhs_poly
215 @plain_data(frozen
=True, unsafe_hash
=True)
217 class EvalOpMul(EvalOp
[_EvalOpLHS
, int]):
220 def _make_poly(self
):
221 # type: () -> EvalOpPoly
222 return self
.lhs_poly
* self
.rhs
225 @plain_data(frozen
=True, unsafe_hash
=True)
227 class EvalOpExactDiv(EvalOp
[_EvalOpLHS
, int]):
230 def _make_poly(self
):
231 # type: () -> EvalOpPoly
232 return self
.lhs_poly
/ self
.rhs
235 @plain_data(frozen
=True, unsafe_hash
=True)
237 class EvalOpInput(EvalOp
[int, Literal
[0]]):
240 def __init__(self
, lhs
, rhs
=0):
241 # type: (...) -> None
243 raise ValueError("Input part_index (lhs) must be >= 0")
245 raise ValueError("Input rhs must be 0")
246 super().__init
__(lhs
, rhs
)
249 def part_index(self
):
252 def _make_poly(self
):
253 # type: () -> EvalOpPoly
254 return EvalOpPoly({self
.part_index
: 1})
257 @plain_data(frozen
=True, unsafe_hash
=True)
259 class ToomCookInstance
:
260 __slots__
= ("lhs_part_count", "rhs_part_count", "eval_points",
261 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
264 def prod_part_count(self
):
265 return self
.lhs_part_count
+ self
.rhs_part_count
- 1
268 def make_eval_matrix(width
, eval_points
):
269 # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
270 retval
= Matrix(height
=len(eval_points
), width
=width
)
271 for row
, col
in retval
.indexes():
272 eval_point
= eval_points
[row
]
273 if eval_point
is POINT_AT_INFINITY
:
274 retval
[row
, col
] = int(col
== width
- 1)
276 retval
[row
, col
] = eval_point
** col
279 def get_lhs_eval_matrix(self
):
280 # type: () -> Matrix[Fraction]
281 return self
.make_eval_matrix(self
.lhs_part_count
, self
.eval_points
)
284 def make_input_poly_vector(height
):
285 # type: (int) -> Matrix[EvalOpPoly]
286 return Matrix(height
=height
, width
=1, element_type
=EvalOpPoly
,
287 data
=(EvalOpPoly({i
: 1}) for i
in range(height
)))
289 def get_lhs_eval_polys(self
):
290 # type: () -> list[EvalOpPoly]
291 return list(self
.get_lhs_eval_matrix().cast(EvalOpPoly
)
292 @ self
.make_input_poly_vector(self
.lhs_part_count
))
294 def get_rhs_eval_matrix(self
):
295 # type: () -> Matrix[Fraction]
296 return self
.make_eval_matrix(self
.rhs_part_count
, self
.eval_points
)
298 def get_rhs_eval_polys(self
):
299 # type: () -> list[EvalOpPoly]
300 return list(self
.get_rhs_eval_matrix().cast(EvalOpPoly
)
301 @ self
.make_input_poly_vector(self
.rhs_part_count
))
303 def get_prod_inverse_eval_matrix(self
):
304 # type: () -> Matrix[Fraction]
305 return self
.make_eval_matrix(self
.prod_part_count
, self
.eval_points
)
307 def get_prod_eval_matrix(self
):
308 # type: () -> Matrix[Fraction]
309 return self
.get_prod_inverse_eval_matrix().inverse()
311 def get_prod_eval_polys(self
):
312 # type: () -> list[EvalOpPoly]
313 return list(self
.get_prod_eval_matrix().cast(EvalOpPoly
)
314 @ self
.make_input_poly_vector(self
.prod_part_count
))
317 self
, lhs_part_count
, # type: int
318 rhs_part_count
, # type: int
319 eval_points
, # type: Iterable[PointAtInfinity | int]
320 lhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
321 rhs_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
322 prod_eval_ops
, # type: Iterable[EvalOp[Any, Any]]
324 # type: (...) -> None
325 self
.lhs_part_count
= lhs_part_count
326 if self
.lhs_part_count
< 2:
327 raise ValueError("lhs_part_count must be at least 2")
328 self
.rhs_part_count
= rhs_part_count
329 if self
.rhs_part_count
< 2:
330 raise ValueError("rhs_part_count must be at least 2")
331 eval_points
= list(eval_points
)
332 self
.eval_points
= tuple(eval_points
)
333 if len(self
.eval_points
) != len(set(self
.eval_points
)):
334 raise ValueError("duplicate eval points")
335 self
.lhs_eval_ops
= tuple(lhs_eval_ops
)
336 if len(self
.lhs_eval_ops
) != self
.prod_part_count
:
337 raise ValueError("wrong number of lhs_eval_ops")
338 self
.rhs_eval_ops
= tuple(rhs_eval_ops
)
339 if len(self
.rhs_eval_ops
) != self
.prod_part_count
:
340 raise ValueError("wrong number of rhs_eval_ops")
341 if len(self
.eval_points
) != self
.prod_part_count
:
342 raise ValueError("wrong number of eval_points")
343 self
.prod_eval_ops
= tuple(prod_eval_ops
)
344 if len(self
.prod_eval_ops
) != self
.prod_part_count
:
345 raise ValueError("wrong number of prod_eval_ops")
347 lhs_eval_polys
= self
.get_lhs_eval_polys()
348 for i
, eval_op
in enumerate(self
.lhs_eval_ops
):
349 if lhs_eval_polys
[i
] != eval_op
.poly
:
351 f
"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
352 f
"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
354 rhs_eval_polys
= self
.get_rhs_eval_polys()
355 for i
, eval_op
in enumerate(self
.rhs_eval_ops
):
356 if rhs_eval_polys
[i
] != eval_op
.poly
:
358 f
"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
359 f
"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
361 prod_eval_polys
= self
.get_prod_eval_polys() # also checks matrix
362 for i
, eval_op
in enumerate(self
.prod_eval_ops
):
363 if prod_eval_polys
[i
] != eval_op
.poly
:
365 f
"prod_eval_ops[{i}] is incorrect: expected polynomial: "
366 f
"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
370 # type: () -> ToomCookInstance
371 return ToomCookInstance(
374 eval_points
=[0, 1, POINT_AT_INFINITY
],
377 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
382 EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
387 EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
394 def toom_cook_mul(fn
, word_count
, instances
):
395 # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
396 retval
= OSet() # type: OSet[Op]
397 raise NotImplementedError