03f3e9327c827c4d0403889653e9fa3f37ff76f1
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 """
2 Toom-Cook multiplication algorithm generator for SVP64
3 """
4 from abc import ABCMeta, abstractmethod
5 from enum import Enum
6 from fractions import Fraction
7 from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
8
9 from nmutil.plain_data import plain_data
10
11 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, Fn, OpKind,
12 SSAVal)
13 from bigint_presentation_code.matrix import Matrix
14 from bigint_presentation_code.type_util import Literal, final
15
16
@final
class PointAtInfinity(Enum):
    """single-member enum whose only value marks the point at infinity"""
    POINT_AT_INFINITY = "POINT_AT_INFINITY"

    def __repr__(self):
        # show only the bare member name instead of Enum's default
        # `<PointAtInfinity.POINT_AT_INFINITY: ...>` form
        return str(self.name)
23
24
# module-level alias for the only member of `PointAtInfinity`
POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
# word size used by this module -- taken from the compiler IR's GPR size
WORD_BITS = GPR_SIZE_IN_BITS

# everything that `EvalOpPoly.__init__` accepts as its `coeffs` argument
_EvalOpPolyCoefficients = Union[
    "Mapping[int | None, Fraction | int]", "EvalOpPoly", Fraction, int, None]
30
31
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class EvalOpPoly:
    """polynomial with a constant term and first-degree variable terms

    `const_coeff` is the constant term and `var_coeffs[i]` is the
    coefficient of variable number `i`; `const_coeff` is a `Fraction` and
    `var_coeffs` is a tuple of `Fraction`s.
    """
    __slots__ = "const_coeff", "var_coeffs"

    def __init__(
        self, coeffs=None,  # type: _EvalOpPolyCoefficients
        const_coeff=None,  # type: Fraction | int | None
        var_coeffs=(),  # type: Iterable[Fraction | int] | None
    ):
        """create a polynomial from `coeffs` or `const_coeff`/`var_coeffs`

        `coeffs` may be another `EvalOpPoly` (copied), an `int`/`Fraction`
        (a constant polynomial), or a mapping from variable index to
        coefficient where the key `None` stands for the constant term.
        When `coeffs` is `None` the polynomial is built from `const_coeff`
        and `var_coeffs` instead, dropping trailing zero coefficients.

        Raises `ValueError` if `coeffs` is combined with `const_coeff` or
        `var_coeffs`, or if a mapping has a negative variable index.
        """
        if coeffs is None:
            if var_coeffs is None:
                terms = []  # type: list[Fraction]
            else:
                terms = [Fraction(v) for v in var_coeffs]
            # strip trailing zeros so `var_coeffs` never ends in a zero
            while terms and terms[-1] == 0:
                terms.pop()
            constant = 0 if const_coeff is None else const_coeff
        else:
            if const_coeff is not None or var_coeffs != ():
                raise ValueError(
                    "can't specify const_coeff or "
                    "var_coeffs along with coeffs")
            if isinstance(coeffs, EvalOpPoly):
                # copy: both fields are already in their final form
                self.const_coeff = coeffs.const_coeff
                self.var_coeffs = coeffs.var_coeffs
                return
            terms = []
            if isinstance(coeffs, (int, Fraction)):
                constant = Fraction(coeffs)
            else:
                constant = 0
                for var, coeff in coeffs.items():
                    if coeff == 0:
                        # zero terms are simply left out
                        continue
                    coeff = Fraction(coeff)
                    if var is None:
                        constant = coeff
                        continue
                    if var < 0:
                        raise ValueError("invalid variable index")
                    # pad with zeros until index `var` exists, then store
                    missing = var + 1 - len(terms)
                    if missing > 0:
                        terms.extend([Fraction()] * missing)
                    terms[var] = coeff
        self.const_coeff = Fraction(constant)
        self.var_coeffs = tuple(terms)

    def __add__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """return the sum of `self` and `rhs`"""
        other = EvalOpPoly(rhs)
        summed = [a + b for a, b in zip(self.var_coeffs, other.var_coeffs)]
        shared = len(summed)
        # at most one of these tails is non-empty: the longer operand
        # contributes its remaining coefficients unchanged
        summed.extend(self.var_coeffs[shared:])
        summed.extend(other.var_coeffs[shared:])
        return EvalOpPoly(const_coeff=self.const_coeff + other.const_coeff,
                          var_coeffs=summed)

    __radd__ = __add__

    @property
    def coefficients(self):
        # type: () -> dict[int | None, Fraction]
        """all non-zero coefficients, keyed by variable index

        the constant term, if non-zero, is stored under the key `None`
        """
        nonzero = {}  # type: dict[int | None, Fraction]
        if self.const_coeff != 0:
            nonzero[None] = self.const_coeff
        nonzero.update(
            (var, coeff) for var, coeff in enumerate(self.var_coeffs)
            if coeff != 0)
        return nonzero

    @property
    def is_const(self):
        # type: () -> bool
        """`True` if there are no variable coefficients"""
        return not self.var_coeffs

    def coeff(self, var):
        # type: (int | None) -> Fraction
        """return the coefficient of variable `var`

        `None` selects the constant term; indexes past the end of
        `var_coeffs` give zero. Raises `ValueError` for negative indexes.
        """
        if var is None:
            return self.const_coeff
        if var < 0:
            raise ValueError("invalid variable index")
        return self.var_coeffs[var] if var < len(self.var_coeffs) \
            else Fraction()

    def __neg__(self):
        """return the polynomial with every coefficient negated"""
        return EvalOpPoly(const_coeff=-self.const_coeff,
                          var_coeffs=[-v for v in self.var_coeffs])

    def __sub__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """return `self - rhs`, computed as `self + (-rhs)`"""
        negated = -rhs
        return self + negated

    def __rsub__(self, lhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """return `lhs - self`, computed as `lhs + (-self)`"""
        negated = -self
        return lhs + negated

    def __mul__(self, rhs):
        # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
        """return `self * rhs`; at least one operand must be constant

        Raises `ValueError` when both operands have variable terms.
        """
        poly, factor = self, rhs
        if isinstance(factor, EvalOpPoly):
            if poly.is_const:
                # put the constant polynomial on the factor side
                poly, factor = factor, poly
            if not factor.is_const:
                raise ValueError("can't represent exponents larger than one")
            factor = factor.const_coeff
        if factor == 0:
            return EvalOpPoly()
        return EvalOpPoly(const_coeff=poly.const_coeff * factor,
                          var_coeffs=[c * factor for c in poly.var_coeffs])

    __rmul__ = __mul__

    def __truediv__(self, rhs):
        # type: (int | Fraction) -> EvalOpPoly
        """return `self` with every coefficient divided by `rhs`

        Raises `ZeroDivisionError` when `rhs` is zero.
        """
        if rhs == 0:
            raise ZeroDivisionError()
        return EvalOpPoly(const_coeff=self.const_coeff / rhs,
                          var_coeffs=[c / rhs for c in self.var_coeffs])

    def __repr__(self):
        terms = self.coefficients
        return f"EvalOpPoly({terms})"
160
161
@plain_data(frozen=True, unsafe_hash=True)
class EvalOpValueRange:
    """smallest and largest value of an `EvalOp`'s polynomial

    Input variable `i` is taken to range over all unsigned values that fit
    in `inputs_words[i]` words of `WORD_BITS` bits. `is_signed` is set when
    the minimum is negative.
    """
    __slots__ = ("eval_op", "inputs_words", "min_value", "max_value",
                 "is_signed")

    def __init__(self, eval_op, inputs_words):
        # type: (EvalOp[Any, Any], Iterable[int]) -> None
        """compute the value range of `eval_op.poly`

        Raises `ValueError` if any entry of `inputs_words` is not positive.
        """
        super().__init__()
        self.eval_op = eval_op
        self.inputs_words = tuple(inputs_words)
        for words in self.inputs_words:
            if words <= 0:
                raise ValueError(f"invalid word count: {words}")
        lowest = highest = eval_op.poly.const_coeff
        for var, coeff in enumerate(eval_op.poly.var_coeffs):
            if coeff == 0:
                continue
            var_max = (1 << self.inputs_words[var] * WORD_BITS) - 1
            # the variable's minimum is 0, so the term ranges between 0 and
            # `var_max * coeff`; which end is which depends on coeff's sign
            extreme = var_max * coeff
            if extreme < 0:
                lowest += extreme
            else:
                highest += extreme
        self.min_value = lowest
        self.max_value = highest
        self.is_signed = lowest < 0
190
191
# operand type variables for `EvalOp`: each side is constrained to be either
# a plain `int` or another `EvalOp`
_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
194
195
@plain_data(frozen=True, unsafe_hash=True)
class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS], metaclass=ABCMeta):
    """abstract base for a binary operation with a cached polynomial

    `lhs` and `rhs` are the operands (an `int` or another `EvalOp`) and
    `poly` is the `EvalOpPoly` that `_make_poly` produced for them.
    """
    __slots__ = "lhs", "rhs", "poly"

    def __init__(self, lhs, rhs):
        # type: (_EvalOpLHS, _EvalOpRHS) -> None
        """store the operands, then compute and store `poly`"""
        super().__init__()
        self.lhs = lhs
        self.rhs = rhs
        # must come last: `_make_poly` reads the operands set above
        self.poly = self._make_poly()

    @property
    def lhs_poly(self):
        # type: () -> EvalOpPoly
        """polynomial of `lhs`; an `int` becomes a constant polynomial"""
        operand = self.lhs
        return EvalOpPoly(operand) if isinstance(operand, int) \
            else operand.poly

    @property
    def rhs_poly(self):
        # type: () -> EvalOpPoly
        """polynomial of `rhs`; an `int` becomes a constant polynomial"""
        operand = self.rhs
        return EvalOpPoly(operand) if isinstance(operand, int) \
            else operand.poly

    @abstractmethod
    def _make_poly(self):
        # type: () -> EvalOpPoly
        """return the polynomial for this operation -- set by subclasses"""
        ...
225
226
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """`EvalOp` whose polynomial is `lhs_poly + rhs_poly`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        left, right = self.lhs_poly, self.rhs_poly
        return left + right
235
236
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """`EvalOp` whose polynomial is `lhs_poly - rhs_poly`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        left, right = self.lhs_poly, self.rhs_poly
        return left - right
245
246
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpMul(EvalOp[_EvalOpLHS, int]):
    """`EvalOp` whose polynomial is `lhs_poly` multiplied by `rhs`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # `rhs` is used directly as the factor, not converted via `rhs_poly`
        left = self.lhs_poly
        return left * self.rhs
255
256
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
    """`EvalOp` whose polynomial is `lhs_poly` divided by `rhs`"""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # `rhs` is used directly as the divisor, not converted via `rhs_poly`
        left = self.lhs_poly
        return left / self.rhs
265
266
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpInput(EvalOp[int, Literal[0]]):
    """`EvalOp` standing for input part number `lhs`; `rhs` is always 0"""
    __slots__ = ()

    def __init__(self, lhs, rhs=0):
        # type: (int, int) -> None
        """create the input op for part `lhs`

        Raises `ValueError` if `lhs` is negative or `rhs` is not 0.
        """
        if lhs < 0:
            raise ValueError("Input part_index (lhs) must be >= 0")
        if rhs != 0:
            raise ValueError("Input rhs must be 0")
        super().__init__(lhs, rhs)

    @property
    def part_index(self):
        """index of the input part -- an alias for `lhs`"""
        return self.lhs

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # the polynomial is just the input's variable with coefficient 1
        single_term = {self.part_index: 1}
        return EvalOpPoly(single_term)
287
288
@plain_data(frozen=True, unsafe_hash=True)
@final
class ToomCookInstance:
    """one validated Toom-Cook configuration

    Holds the number of parts of each operand, the evaluation points, and
    the `EvalOp`s for the lhs, the rhs and the product. `__init__` checks
    every op's polynomial against the one derived from the evaluation
    matrices and raises `ValueError` on any mismatch.
    """
    __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
                 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")

    def __init__(
        self, lhs_part_count,  # type: int
        rhs_part_count,  # type: int
        eval_points,  # type: Iterable[PointAtInfinity | int]
        lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        prod_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
    ):
        # type: (...) -> None
        """store and validate all fields

        Raises `ValueError` if a part count is below 2, if eval points are
        duplicated, if any sequence doesn't have `prod_part_count` entries,
        or if an eval op's polynomial differs from the expected one.
        """
        self.lhs_part_count = lhs_part_count
        if lhs_part_count < 2:
            raise ValueError("lhs_part_count must be at least 2")
        self.rhs_part_count = rhs_part_count
        if rhs_part_count < 2:
            raise ValueError("rhs_part_count must be at least 2")
        self.eval_points = tuple(eval_points)
        if len(set(self.eval_points)) != len(self.eval_points):
            raise ValueError("duplicate eval points")
        # every sequence needs exactly one entry per product part
        expected_len = self.prod_part_count
        self.lhs_eval_ops = tuple(lhs_eval_ops)
        if len(self.lhs_eval_ops) != expected_len:
            raise ValueError("wrong number of lhs_eval_ops")
        self.rhs_eval_ops = tuple(rhs_eval_ops)
        if len(self.rhs_eval_ops) != expected_len:
            raise ValueError("wrong number of rhs_eval_ops")
        if len(self.eval_points) != expected_len:
            raise ValueError("wrong number of eval_points")
        self.prod_eval_ops = tuple(prod_eval_ops)
        if len(self.prod_eval_ops) != expected_len:
            raise ValueError("wrong number of prod_eval_ops")

        expected_lhs = self.get_lhs_eval_polys()
        for i, eval_op in enumerate(self.lhs_eval_ops):
            if expected_lhs[i] != eval_op.poly:
                raise ValueError(
                    f"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
                    f"{expected_lhs[i]} found polynomial: {eval_op.poly}")

        expected_rhs = self.get_rhs_eval_polys()
        for i, eval_op in enumerate(self.rhs_eval_ops):
            if expected_rhs[i] != eval_op.poly:
                raise ValueError(
                    f"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
                    f"{expected_rhs[i]} found polynomial: {eval_op.poly}")

        expected_prod = self.get_prod_eval_polys()  # also checks matrix
        for i, eval_op in enumerate(self.prod_eval_ops):
            if expected_prod[i] != eval_op.poly:
                raise ValueError(
                    f"prod_eval_ops[{i}] is incorrect: expected polynomial: "
                    f"{expected_prod[i]} found polynomial: {eval_op.poly}")

    @property
    def prod_part_count(self):
        """number of product parts: `lhs_part_count + rhs_part_count - 1`"""
        return self.lhs_part_count + self.rhs_part_count - 1

    @staticmethod
    def make_eval_matrix(width, eval_points):
        # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
        """make the matrix with one row per eval point and `width` columns

        The entry at `[row, col]` is `eval_points[row] ** col`, except that
        a `POINT_AT_INFINITY` row is 1 in the last column and 0 elsewhere.
        """
        matrix = Matrix(height=len(eval_points), width=width)
        last_col = width - 1
        for row, col in matrix.indexes():
            point = eval_points[row]
            if point is POINT_AT_INFINITY:
                entry = 1 if col == last_col else 0
            else:
                entry = point ** col
            matrix[row, col] = entry
        return matrix

    @staticmethod
    def make_input_poly_vector(height):
        # type: (int) -> Matrix[EvalOpPoly]
        """make a column vector whose entry `i` is the polynomial of
        input variable `i`"""
        input_polys = (EvalOpPoly({i: 1}) for i in range(height))
        return Matrix(height=height, width=1, element_type=EvalOpPoly,
                      data=input_polys)

    def get_lhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """evaluation matrix with `lhs_part_count` columns"""
        return self.make_eval_matrix(self.lhs_part_count, self.eval_points)

    def get_lhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """expected polynomials for `lhs_eval_ops`"""
        matrix = self.get_lhs_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.lhs_part_count)
        return list(matrix @ inputs)

    def get_rhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """evaluation matrix with `rhs_part_count` columns"""
        return self.make_eval_matrix(self.rhs_part_count, self.eval_points)

    def get_rhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """expected polynomials for `rhs_eval_ops`"""
        matrix = self.get_rhs_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.rhs_part_count)
        return list(matrix @ inputs)

    def get_prod_inverse_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """evaluation matrix with `prod_part_count` columns"""
        return self.make_eval_matrix(self.prod_part_count, self.eval_points)

    def get_prod_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """inverse of `get_prod_inverse_eval_matrix()`"""
        forward = self.get_prod_inverse_eval_matrix()
        return forward.inverse()

    def get_prod_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """expected polynomials for `prod_eval_ops`"""
        matrix = self.get_prod_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.prod_part_count)
        return list(matrix @ inputs)

    def reversed(self):
        # type: () -> ToomCookInstance
        """return a ToomCookInstance where lhs/rhs are reversed"""
        # only the operand-specific fields swap places; the eval points and
        # the product ops are passed through unchanged
        return ToomCookInstance(
            lhs_part_count=self.rhs_part_count,
            rhs_part_count=self.lhs_part_count,
            eval_points=self.eval_points,
            lhs_eval_ops=self.rhs_eval_ops,
            rhs_eval_ops=self.lhs_eval_ops,
            prod_eval_ops=self.prod_eval_ops,
        )

    @staticmethod
    def make_toom_2():
        # type: () -> ToomCookInstance
        """make an instance of Toom-2 aka Karatsuba multiplication"""
        inp = EvalOpInput

        def operand_eval_ops():
            # lhs and rhs use the same ops: inp0, inp0 + inp1, inp1
            return [inp(0), EvalOpAdd(inp(0), inp(1)), inp(1)]

        # middle product part: (inp1 - inp0) - inp2
        middle = EvalOpSub(EvalOpSub(inp(1), inp(0)), inp(2))
        return ToomCookInstance(
            lhs_part_count=2,
            rhs_part_count=2,
            eval_points=[0, 1, POINT_AT_INFINITY],
            lhs_eval_ops=operand_eval_ops(),
            rhs_eval_ops=operand_eval_ops(),
            prod_eval_ops=[inp(0), middle, inp(2)],
        )

    @staticmethod
    def make_toom_2_5():
        # type: () -> ToomCookInstance
        """makes an instance of Toom-2.5"""
        inp = EvalOpInput
        # sub-expressions of the lhs evaluation
        lhs_even = EvalOpAdd(inp(0), inp(2))
        # sub-expressions of the product: (inp1 - inp2) / 2 and
        # (inp1 + inp2) / 2
        half_diff = EvalOpExactDiv(EvalOpSub(inp(1), inp(2)), 2)
        half_sum = EvalOpExactDiv(EvalOpAdd(inp(1), inp(2)), 2)
        return ToomCookInstance(
            lhs_part_count=3,
            rhs_part_count=2,
            eval_points=[0, 1, -1, POINT_AT_INFINITY],
            lhs_eval_ops=[
                inp(0),
                EvalOpAdd(lhs_even, inp(1)),
                EvalOpSub(lhs_even, inp(1)),
                inp(2),
            ],
            rhs_eval_ops=[
                inp(0),
                EvalOpAdd(inp(0), inp(1)),
                EvalOpSub(inp(0), inp(1)),
                inp(1),
            ],
            prod_eval_ops=[
                inp(0),
                EvalOpSub(half_diff, inp(3)),
                EvalOpSub(half_sum, inp(0)),
                inp(3),
            ],
        )

    # TODO: add make_toom_3
471
472
def simple_mul(fn, lhs, rhs):
    # type: (Fn, SSAVal, SSAVal) -> SSAVal
    """ simple O(n^2) big-int unsigned multiply

    Appends the ops for the multiplication to `fn` and returns the output
    of the final `Concat` op. The operands are swapped when needed so that
    `lhs` is the one with the larger `reg_len`.
    """
    if rhs.ty.reg_len > lhs.ty.reg_len:
        lhs, rhs = rhs, lhs
    rhs_len = rhs.ty.reg_len
    lhs_len = lhs.ty.reg_len
    # split rhs into elements
    rhs_vl_op = fn.append_new_op(
        kind=OpKind.SetVLI, immediates=[rhs_len], name="rhs_setvl")
    rhs_words = fn.append_new_op(
        kind=OpKind.Spread, input_vals=[rhs, rhs_vl_op.outputs[0]],
        maxvl=rhs_len, name="rhs_spread").outputs
    # every vector op below runs with lhs's length
    lhs_vl_op = fn.append_new_op(
        kind=OpKind.SetVLI, immediates=[lhs_len], name="lhs_setvl")
    vl = lhs_vl_op.outputs[0]
    zero = fn.append_new_op(
        kind=OpKind.LI, immediates=[0], name="zero").outputs[0]
    partial = None  # type: tuple[SSAVal, ...] | None
    for shift, rhs_word in enumerate(rhs_words):
        mul = fn.append_new_op(
            kind=OpKind.SvMAddEDU, input_vals=[lhs, rhs_word, zero, vl],
            maxvl=lhs_len, name=f"mul{shift}")
        if partial is None:
            # first rhs word: nothing to add to yet, so the running result
            # is just this product's spread-out outputs
            mul_spread = fn.append_new_op(
                kind=OpKind.Spread, input_vals=[mul.outputs[0], vl],
                name=f"mul{shift}_rt_spread", maxvl=lhs_len)
            partial = (*mul_spread.outputs, mul.outputs[1])
            continue
        # the first `shift` entries are left alone; the rest are added to
        # this product
        untouched = partial[:shift]  # type: tuple[SSAVal, ...]
        addend = partial[shift:]
        addend_concat = fn.append_new_op(
            kind=OpKind.Concat, input_vals=[*addend, vl],
            name=f"add{shift}_rb_concat", maxvl=lhs_len)
        clear_ca = fn.append_new_op(
            kind=OpKind.ClearCA, name=f"clear_ca{shift}")
        add = fn.append_new_op(
            kind=OpKind.SvAddE,
            input_vals=[mul.outputs[0], addend_concat.outputs[0],
                        clear_ca.outputs[0], vl],
            maxvl=lhs_len, name=f"add{shift}")
        add_spread = fn.append_new_op(
            kind=OpKind.Spread, input_vals=[add.outputs[0], vl],
            name=f"add{shift}_rt_spread", maxvl=lhs_len)
        add_hi = fn.append_new_op(
            kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]],
            name=f"add_hi{shift}")
        partial = (*untouched, *add_spread.outputs, add_hi.outputs[0])
    assert partial is not None
    # join all the entries back into one value
    result_len = len(partial)
    retval_vl_op = fn.append_new_op(
        kind=OpKind.SetVLI, immediates=[result_len], name="retval_setvl")
    concat_retval = fn.append_new_op(
        kind=OpKind.Concat, input_vals=[*partial, retval_vl_op.outputs[0]],
        name="concat_retval", maxvl=result_len)
    return concat_retval.outputs[0]
531
532
def toom_cook_mul(fn, lhs, rhs, instances):
    # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal
    """ placeholder -- not implemented yet, always raises
    `NotImplementedError` """
    raise NotImplementedError()