76a8a994b11d1123e1233b3c8e4a03857ebdc437
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 """
2 Toom-Cook multiplication algorithm generator for SVP64
3 """
4 from abc import abstractmethod
5 from enum import Enum
6 from fractions import Fraction
7 from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
8
9 from nmutil.plain_data import plain_data
10
11 from bigint_presentation_code.compiler_ir import Fn, Op
12 from bigint_presentation_code.matrix import Matrix
13 from bigint_presentation_code.util import Literal, OSet, final
14
15
@final
class PointAtInfinity(Enum):
    """Marker type for the evaluation point at infinity.

    The enum has exactly one member.  Its ``repr`` is the bare member name
    (``POINT_AT_INFINITY``) rather than Enum's default ``<Class.NAME: value>``
    form, so it reads like a plain constant inside containers and messages.
    """
    POINT_AT_INFINITY = "POINT_AT_INFINITY"

    def __repr__(self):
        member_name = self.name
        return member_name
22
23
# Convenience alias for the sole member of ``PointAtInfinity``.
POINT_AT_INFINITY = PointAtInfinity["POINT_AT_INFINITY"]

# Word width in bits.
WORD_BITS = 64

# Every kind of value accepted as the ``coeffs`` argument of ``EvalOpPoly``.
# The string entries are lazily-evaluated forward references.
_EvalOpPolyCoefficients = Union[
    "Mapping[int | None, Fraction | int]",
    "EvalOpPoly",
    Fraction,
    int,
    None,
]
29
30
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class EvalOpPoly:
    """Polynomial of degree at most one over the input variables.

    Represents ``const_coeff + sum(var_coeffs[i] * x_i)`` with exact
    ``Fraction`` coefficients.  Products of two variables are not
    representable.  ``var_coeffs`` is kept free of trailing zeros, so equal
    polynomials have equal fields (and therefore compare and hash equal).
    """
    __slots__ = "const_coeff", "var_coeffs"

    def __init__(
        self, coeffs=None,  # type: _EvalOpPolyCoefficients
        const_coeff=None,  # type: Fraction | int | None
        var_coeffs=(),  # type: Iterable[Fraction | int] | None
    ):
        """Create a polynomial.

        Pass either ``coeffs`` alone, or ``const_coeff`` and/or
        ``var_coeffs`` (the coefficient of ``x_i`` at position ``i``).
        ``coeffs`` may be:

        - another ``EvalOpPoly``, which is copied;
        - a constant ``int``/``Fraction``;
        - a mapping from variable index (``None`` for the constant term) to
          coefficient.

        With no arguments the result is the zero polynomial.

        Raises:
            ValueError: if ``coeffs`` is combined with ``const_coeff`` or
                with a ``var_coeffs`` other than ``()``/``None``, or if a
                mapping key is a negative variable index.
        """
        if coeffs is not None:
            # ``()`` (the default) and ``None`` both mean "var_coeffs not given"
            if const_coeff is not None or (
                    var_coeffs is not None and var_coeffs != ()):
                raise ValueError(
                    "can't specify const_coeff or "
                    "var_coeffs along with coeffs")
            if isinstance(coeffs, EvalOpPoly):
                # already canonical; its fields are immutable so share them
                self.const_coeff = coeffs.const_coeff
                self.var_coeffs = coeffs.var_coeffs
                return
            final_var_coeffs = []  # type: list[Fraction]
            if isinstance(coeffs, (int, Fraction)):
                const_coeff = Fraction(coeffs)
            else:
                const_coeff = 0
                for var, coeff in coeffs.items():
                    # validate every key, even ones with a zero coefficient,
                    # so invalid indexes are rejected consistently
                    if var is not None and var < 0:
                        raise ValueError("invalid variable index")
                    if coeff == 0:
                        continue
                    coeff = Fraction(coeff)
                    if var is None:
                        const_coeff = coeff
                    elif var < len(final_var_coeffs):
                        final_var_coeffs[var] = coeff
                    else:
                        # zero-fill any variables skipped so far
                        padding = var - len(final_var_coeffs)
                        final_var_coeffs.extend((Fraction(),) * padding)
                        final_var_coeffs.append(coeff)
        elif var_coeffs is None:
            final_var_coeffs = []
        else:
            final_var_coeffs = [Fraction(v) for v in var_coeffs]
        # canonical form: no trailing zero variable coefficients
        while final_var_coeffs and final_var_coeffs[-1] == 0:
            final_var_coeffs.pop()
        if const_coeff is None:
            const_coeff = 0
        self.const_coeff = Fraction(const_coeff)
        self.var_coeffs = tuple(final_var_coeffs)

    @property
    def coefficients(self):
        # type: () -> dict[int | None, Fraction]
        """Non-zero coefficients keyed by variable index (``None`` = const)."""
        retval = {}  # type: dict[int | None, Fraction]
        if self.const_coeff != 0:
            retval[None] = self.const_coeff
        for var, coeff in enumerate(self.var_coeffs):
            if coeff != 0:
                retval[var] = coeff
        return retval

    @property
    def is_const(self):
        # type: () -> bool
        """True if the polynomial has no variable terms."""
        return self.var_coeffs == ()

    def coeff(self, var):
        # type: (int | None) -> Fraction
        """Return the coefficient of ``x_var`` (the constant term for ``None``).

        Variables past the end of ``var_coeffs`` have coefficient zero.

        Raises:
            ValueError: if ``var`` is negative.
        """
        if var is None:
            return self.const_coeff
        if var < 0:
            raise ValueError("invalid variable index")
        if var < len(self.var_coeffs):
            return self.var_coeffs[var]
        return Fraction()

    def __add__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return ``self + rhs``; ``rhs`` may be a constant or a polynomial."""
        rhs = EvalOpPoly(rhs)
        width = max(len(self.var_coeffs), len(rhs.var_coeffs))
        # ``coeff`` yields zero past the end of either operand
        return EvalOpPoly(
            const_coeff=self.const_coeff + rhs.const_coeff,
            var_coeffs=[self.coeff(v) + rhs.coeff(v) for v in range(width)])

    __radd__ = __add__

    def __neg__(self):
        """Return ``-self`` (every coefficient negated)."""
        return EvalOpPoly(const_coeff=-self.const_coeff,
                          var_coeffs=(-v for v in self.var_coeffs))

    def __sub__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return ``self - rhs``."""
        return self + -rhs

    def __rsub__(self, lhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return ``lhs - self`` for a constant ``lhs``."""
        return lhs + -self

    def __mul__(self, rhs):
        # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
        """Return ``self * rhs``.

        At most one operand may have variable terms, since products of
        variables (degree > 1) can't be represented.

        Raises:
            ValueError: if both operands have variable terms.
        """
        poly = self
        if isinstance(rhs, EvalOpPoly):
            if poly.is_const:
                # move the (possibly) non-constant operand to the left
                poly, rhs = rhs, poly
            if not rhs.is_const:
                raise ValueError("can't represent exponents larger than one")
            rhs = rhs.const_coeff
        if rhs == 0:
            return EvalOpPoly()
        return EvalOpPoly(const_coeff=poly.const_coeff * rhs,
                          var_coeffs=(i * rhs for i in poly.var_coeffs))

    __rmul__ = __mul__

    def __truediv__(self, rhs):
        # type: (int | Fraction) -> EvalOpPoly
        """Divide every coefficient by the constant ``rhs``.

        Raises:
            ZeroDivisionError: if ``rhs`` is zero.
        """
        if rhs == 0:
            raise ZeroDivisionError()
        return EvalOpPoly(const_coeff=self.const_coeff / rhs,
                          var_coeffs=(i / rhs for i in self.var_coeffs))

    def __repr__(self):
        return f"EvalOpPoly({self.coefficients})"
159
160
# operand types: either a constant ``int`` or another ``EvalOp``
_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")


@plain_data(frozen=True, unsafe_hash=True)
class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
    """Base class for one step of the evaluation formulas of a Toom-Cook
    instance.

    Each operand is either an ``int`` constant or another ``EvalOp``.
    ``poly`` is the ``EvalOpPoly`` (over the input parts) that the op
    computes; it is produced by the subclass's ``_make_poly`` at
    construction time.
    """
    __slots__ = "lhs", "rhs", "poly"

    @property
    def lhs_poly(self):
        # type: () -> EvalOpPoly
        """``lhs`` as a polynomial: ints become constants, ops give ``.poly``."""
        if isinstance(self.lhs, int):
            return EvalOpPoly(self.lhs)
        return self.lhs.poly

    @property
    def rhs_poly(self):
        # type: () -> EvalOpPoly
        """``rhs`` as a polynomial: ints become constants, ops give ``.poly``."""
        if isinstance(self.rhs, int):
            return EvalOpPoly(self.rhs)
        return self.rhs.poly

    @abstractmethod
    def _make_poly(self):
        # type: () -> EvalOpPoly
        """Compute the polynomial this op evaluates; subclasses override."""
        # EvalOp does not use ABCMeta, so @abstractmethod alone does not
        # prevent direct instantiation; fail loudly rather than silently
        # storing ``None`` in ``poly``.
        raise NotImplementedError(
            f"{type(self).__name__} must override _make_poly")

    def __init__(self, lhs, rhs):
        # type: (_EvalOpLHS, _EvalOpRHS) -> None
        self.lhs = lhs
        self.rhs = rhs
        # operands must be set first: _make_poly reads lhs/rhs
        self.poly = self._make_poly()
193
194
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """Op computing ``lhs + rhs``."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        augend = self.lhs_poly
        addend = self.rhs_poly
        return augend + addend
203
204
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """Op computing ``lhs - rhs``."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        minuend = self.lhs_poly
        subtrahend = self.rhs_poly
        return minuend - subtrahend
213
214
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpMul(EvalOp[_EvalOpLHS, int]):
    """Op computing ``lhs * rhs`` for an ``int`` factor ``rhs``."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        multiplicand = self.lhs_poly
        factor = self.rhs
        return multiplicand * factor
223
224
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
    """Op computing ``lhs / rhs`` for an ``int`` divisor ``rhs``.

    The name indicates the division is expected to be exact; that is not
    checked here (the resulting polynomial has exact ``Fraction``
    coefficients either way).
    """
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        dividend = self.lhs_poly
        divisor = self.rhs
        return dividend / divisor
233
234
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpInput(EvalOp[int, Literal[0]]):
    """Leaf op reading input part number ``lhs`` (also ``part_index``).

    ``rhs`` is unused and must be ``0``.
    """
    __slots__ = ()

    def __init__(self, lhs, rhs=0):
        # type: (...) -> None
        # pick the first violated precondition (lhs is checked before rhs)
        problem = None
        if lhs < 0:
            problem = "Input part_index (lhs) must be >= 0"
        elif rhs != 0:
            problem = "Input rhs must be 0"
        if problem is not None:
            raise ValueError(problem)
        super().__init__(lhs, rhs)

    @property
    def part_index(self):
        """Index of the input part this op reads; an alias for ``lhs``."""
        return self.lhs

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # the single variable x_{part_index} with coefficient 1
        index = self.part_index
        return EvalOpPoly({index: 1})
255
256
@plain_data(frozen=True, unsafe_hash=True)
@final
class ToomCookInstance:
    """A Toom-Cook splitting scheme whose evaluation ops are verified.

    Attributes:
        lhs_part_count: number of parts the lhs is split into.
        rhs_part_count: number of parts the rhs is split into.
        eval_points: the points the part-polynomials are evaluated at.
            ``POINT_AT_INFINITY`` selects the highest coefficient.
        lhs_eval_ops, rhs_eval_ops: ``EvalOp`` trees evaluating each operand.
        prod_eval_ops: ``EvalOp`` trees recovering the product coefficients.

    ``__init__`` checks every op against the polynomial derived from the
    corresponding matrix.
    """
    __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
                 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")

    @property
    def prod_part_count(self):
        """Number of parts in the product: ``lhs + rhs - 1``."""
        return self.lhs_part_count + self.rhs_part_count - 1

    @staticmethod
    def make_eval_matrix(width, eval_points):
        # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
        """Return the ``len(eval_points)`` x ``width`` evaluation matrix.

        Row ``r`` holds ``eval_points[r] ** c`` in column ``c``, except that a
        ``POINT_AT_INFINITY`` row has a single 1 in the last column.
        """
        retval = Matrix(height=len(eval_points), width=width)
        for row, col in retval.indexes():
            eval_point = eval_points[row]
            if eval_point is POINT_AT_INFINITY:
                retval[row, col] = int(col == width - 1)
            else:
                retval[row, col] = eval_point ** col
        return retval

    @staticmethod
    def make_input_poly_vector(height):
        # type: (int) -> Matrix[EvalOpPoly]
        """Return the column vector ``[x_0, x_1, ..., x_{height-1}]``."""
        return Matrix(height=height, width=1, element_type=EvalOpPoly,
                      data=(EvalOpPoly({i: 1}) for i in range(height)))

    def _polys_from_matrix(self, matrix, input_count):
        # type: (Matrix[Fraction], int) -> list[EvalOpPoly]
        """Return ``matrix @ [x_0, ..., x_{input_count-1}]`` as a list."""
        poly_matrix = matrix.cast(EvalOpPoly)
        return list(poly_matrix @ self.make_input_poly_vector(input_count))

    def get_lhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix for the lhs parts."""
        return self.make_eval_matrix(self.lhs_part_count, self.eval_points)

    def get_lhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial for each entry of ``lhs_eval_ops``."""
        return self._polys_from_matrix(self.get_lhs_eval_matrix(),
                                       self.lhs_part_count)

    def get_rhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix for the rhs parts."""
        return self.make_eval_matrix(self.rhs_part_count, self.eval_points)

    def get_rhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial for each entry of ``rhs_eval_ops``."""
        return self._polys_from_matrix(self.get_rhs_eval_matrix(),
                                       self.rhs_part_count)

    def get_prod_inverse_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix for the product parts (inverse of the matrix
        returned by ``get_prod_eval_matrix``)."""
        return self.make_eval_matrix(self.prod_part_count, self.eval_points)

    def get_prod_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Matrix mapping the pointwise products back to product parts."""
        return self.get_prod_inverse_eval_matrix().inverse()

    def get_prod_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial for each entry of ``prod_eval_ops``."""
        return self._polys_from_matrix(self.get_prod_eval_matrix(),
                                       self.prod_part_count)

    @staticmethod
    def _check_eval_ops(name, expected_polys, eval_ops):
        # type: (str, list[EvalOpPoly], tuple[EvalOp[Any, Any], ...]) -> None
        """Raise ``ValueError`` unless each op computes its expected poly.

        ``name`` is the attribute name used in the error message.
        """
        # lengths were validated in __init__, so zip drops nothing
        for i, (expected, eval_op) in enumerate(zip(expected_polys, eval_ops)):
            if expected != eval_op.poly:
                raise ValueError(
                    f"{name}[{i}] is incorrect: expected polynomial: "
                    f"{expected} found polynomial: {eval_op.poly}")

    def __init__(
        self, lhs_part_count,  # type: int
        rhs_part_count,  # type: int
        eval_points,  # type: Iterable[PointAtInfinity | int]
        lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        prod_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
    ):
        # type: (...) -> None
        """Store and validate a Toom-Cook instance.

        Raises:
            ValueError: if a part count is below 2, an eval point is
                repeated, a sequence has the wrong length (each must have
                ``prod_part_count`` entries), or an op computes the wrong
                polynomial.
        """
        self.lhs_part_count = lhs_part_count
        if self.lhs_part_count < 2:
            raise ValueError("lhs_part_count must be at least 2")
        self.rhs_part_count = rhs_part_count
        if self.rhs_part_count < 2:
            raise ValueError("rhs_part_count must be at least 2")
        self.eval_points = tuple(eval_points)
        if len(self.eval_points) != len(set(self.eval_points)):
            raise ValueError("duplicate eval points")
        self.lhs_eval_ops = tuple(lhs_eval_ops)
        if len(self.lhs_eval_ops) != self.prod_part_count:
            raise ValueError("wrong number of lhs_eval_ops")
        self.rhs_eval_ops = tuple(rhs_eval_ops)
        if len(self.rhs_eval_ops) != self.prod_part_count:
            raise ValueError("wrong number of rhs_eval_ops")
        if len(self.eval_points) != self.prod_part_count:
            raise ValueError("wrong number of eval_points")
        self.prod_eval_ops = tuple(prod_eval_ops)
        if len(self.prod_eval_ops) != self.prod_part_count:
            raise ValueError("wrong number of prod_eval_ops")

        self._check_eval_ops("lhs_eval_ops", self.get_lhs_eval_polys(),
                             self.lhs_eval_ops)
        self._check_eval_ops("rhs_eval_ops", self.get_rhs_eval_polys(),
                             self.rhs_eval_ops)
        # computing the prod polys inverts the product eval matrix, which
        # also checks it (presumably raising if it is singular)
        self._check_eval_ops("prod_eval_ops", self.get_prod_eval_polys(),
                             self.prod_eval_ops)

    @staticmethod
    def make_toom_2():
        # type: () -> ToomCookInstance
        """Return Toom-2 (Karatsuba): 2 x 2 parts, evaluated at 0, 1, inf.

        The middle product coefficient is recovered as ``w1 - w0 - w_inf``.
        """
        return ToomCookInstance(
            lhs_part_count=2,
            rhs_part_count=2,
            eval_points=[0, 1, POINT_AT_INFINITY],
            lhs_eval_ops=[
                EvalOpInput(0),
                EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
                EvalOpInput(1),
            ],
            rhs_eval_ops=[
                EvalOpInput(0),
                EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
                EvalOpInput(1),
            ],
            prod_eval_ops=[
                EvalOpInput(0),
                EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
                          EvalOpInput(2)),
                EvalOpInput(2),
            ],
        )
392
393
def toom_cook_mul(fn, word_count, instances):
    # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
    """Placeholder for the Toom-Cook multiplication generator.

    The signature takes a ``Fn``, a word count, and a sequence of
    ``ToomCookInstance`` objects, and is declared to return an ``OSet[Op]``.
    There is no implementation yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("toom_cook_mul is not implemented yet")