working on rewriting compiler ir to fix reg alloc issues
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 """
2 Toom-Cook multiplication algorithm generator for SVP64
3 """
4 from abc import abstractmethod
5 from enum import Enum
6 from fractions import Fraction
7 from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
8
9 from nmutil.plain_data import plain_data
10
11 from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub,
12 OpBigIntMulDiv, OpConcat,
13 OpLI, OpSetCA, OpSetVLImm,
14 OpSplit, SSAGPRRange)
15 from bigint_presentation_code.matrix import Matrix
16 from bigint_presentation_code.type_util import Literal, final
17
18
@final
class PointAtInfinity(Enum):
    """Single-member enum used as the marker for the point at infinity."""
    POINT_AT_INFINITY = "POINT_AT_INFINITY"

    def __repr__(self):
        # type: () -> str
        """Return only the member name instead of the default Enum repr."""
        return self.name
25
26
# module-level shorthand for the only PointAtInfinity member
POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
# NOTE(review): presumably the width of one machine word in bits; not
# referenced by the code visible in this file -- confirm against callers
WORD_BITS = 64

# everything EvalOpPoly's `coeffs` argument accepts: a mapping from variable
# index (None meaning the constant term) to coefficient, another EvalOpPoly,
# a bare number, or None
_EvalOpPolyCoefficients = Union[
    "Mapping[int | None, Fraction | int]",
    "EvalOpPoly",
    Fraction,
    int,
    None,
]
32
33
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class EvalOpPoly:
    """Linear polynomial over indexed variables with `Fraction` coefficients.

    Represents `const_coeff + sum(var_coeffs[i] * var_i)`. Multiplying two
    non-constant polynomials is rejected, so no variable ever gets an
    exponent larger than one. `var_coeffs` is stored as a tuple with
    trailing zeros removed.
    """
    __slots__ = "const_coeff", "var_coeffs"

    def __init__(
        self, coeffs=None,  # type: _EvalOpPolyCoefficients
        const_coeff=None,  # type: Fraction | int | None
        var_coeffs=(),  # type: Iterable[Fraction | int] | None
    ):
        """Build a polynomial from `coeffs` or from explicit coefficients.

        `coeffs` may be another `EvalOpPoly` (copied), an `int`/`Fraction`
        (constant polynomial) or a mapping from variable index to
        coefficient, where the key `None` selects the constant term.
        Otherwise `const_coeff` and `var_coeffs` are used, with
        `var_coeffs[i]` being the coefficient of variable `i`.

        Raises `ValueError` when `coeffs` is combined with `const_coeff` or
        `var_coeffs`, or when a mapping has a negative variable index with
        a non-zero coefficient.
        """
        if coeffs is None:
            # explicit form: constant term plus positional coefficients
            if var_coeffs is None:
                terms = []  # type: list[Fraction]
            else:
                terms = [Fraction(v) for v in var_coeffs]
            constant = 0 if const_coeff is None else const_coeff
        else:
            if const_coeff is not None or var_coeffs != ():
                raise ValueError(
                    "can't specify const_coeff or "
                    "var_coeffs along with coeffs")
            if isinstance(coeffs, EvalOpPoly):
                # plain copy -- the source polynomial is already normalized
                self.const_coeff = coeffs.const_coeff
                self.var_coeffs = coeffs.var_coeffs
                return
            terms = []
            if isinstance(coeffs, (int, Fraction)):
                constant = Fraction(coeffs)
            else:
                constant = 0
                for index, value in coeffs.items():
                    if value == 0:
                        # zero entries never influence the result
                        continue
                    value = Fraction(value)
                    if index is None:
                        constant = value
                    elif index < 0:
                        raise ValueError("invalid variable index")
                    elif index < len(terms):
                        terms[index] = value
                    else:
                        # pad the gap up to `index` with zeros first
                        terms.extend([Fraction()] * (index - len(terms)))
                        terms.append(value)
        # normalize so equal polynomials have equal fields
        while terms and terms[-1] == 0:
            terms.pop()
        self.const_coeff = Fraction(constant)
        self.var_coeffs = tuple(terms)

    def __add__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return the coefficient-wise sum of `self` and `rhs`."""
        other = EvalOpPoly(rhs)
        mine, theirs = self.var_coeffs, other.var_coeffs
        shared = min(len(mine), len(theirs))
        summed = [a + b for a, b in zip(mine, theirs)]
        # at most one of these tails is non-empty
        summed.extend(mine[shared:])
        summed.extend(theirs[shared:])
        return EvalOpPoly(const_coeff=self.const_coeff + other.const_coeff,
                          var_coeffs=summed)

    @property
    def coefficients(self):
        # type: () -> dict[int | None, Fraction]
        """All non-zero coefficients; the constant term is under `None`."""
        nonzero = {}  # type: dict[int | None, Fraction]
        if self.const_coeff != 0:
            nonzero[None] = self.const_coeff
        nonzero.update(
            (index, value) for index, value in enumerate(self.var_coeffs)
            if value != 0)
        return nonzero

    @property
    def is_const(self):
        # type: () -> bool
        """True when no variable has a non-zero coefficient."""
        return len(self.var_coeffs) == 0

    def coeff(self, var):
        # type: (int | None) -> Fraction
        """Return the coefficient of variable `var` (`None` = constant).

        Raises `ValueError` for a negative `var`.
        """
        if var is None:
            return self.const_coeff
        if var < 0:
            raise ValueError("invalid variable index")
        in_range = var < len(self.var_coeffs)
        return self.var_coeffs[var] if in_range else Fraction()

    __radd__ = __add__

    def __neg__(self):
        """Return the polynomial with every coefficient negated."""
        return EvalOpPoly(const_coeff=-self.const_coeff,
                          var_coeffs=[-value for value in self.var_coeffs])

    def __sub__(self, rhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return `self - rhs`, computed as `self + (-rhs)`."""
        negated = -rhs
        return self + negated

    def __rsub__(self, lhs):
        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
        """Return `lhs - self`, computed as `lhs + (-self)`."""
        negated = -self
        return lhs + negated

    def __mul__(self, rhs):
        # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
        """Return `self` scaled by `rhs`.

        When `rhs` is an `EvalOpPoly`, at least one operand must be
        constant, otherwise `ValueError` is raised.
        """
        poly, factor = self, rhs
        if isinstance(factor, EvalOpPoly):
            if poly.is_const:
                # make sure the (possibly) non-constant operand is `poly`
                poly, factor = factor, poly
            if not factor.is_const:
                raise ValueError("can't represent exponents larger than one")
            factor = factor.const_coeff
        if factor == 0:
            return EvalOpPoly()
        return EvalOpPoly(
            const_coeff=poly.const_coeff * factor,
            var_coeffs=[value * factor for value in poly.var_coeffs])

    __rmul__ = __mul__

    def __truediv__(self, rhs):
        # type: (int | Fraction) -> EvalOpPoly
        """Return `self` with every coefficient divided by `rhs`.

        Raises `ZeroDivisionError` when `rhs` is zero.
        """
        if rhs == 0:
            raise ZeroDivisionError()
        return EvalOpPoly(
            const_coeff=self.const_coeff / rhs,
            var_coeffs=[value / rhs for value in self.var_coeffs])

    def __repr__(self):
        """Show the polynomial as its dict of non-zero coefficients."""
        return f"EvalOpPoly({self.coefficients})"
162
163
# an EvalOp operand is either a plain int or another EvalOp
_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
166
167
@plain_data(frozen=True, unsafe_hash=True)
class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
    """Base class for one node of an evaluation expression tree.

    Each node stores its two operands (`lhs`, `rhs`) plus `poly`, the
    `EvalOpPoly` the node evaluates to, which is computed once at
    construction time by the subclass's `_make_poly`.
    """
    __slots__ = "lhs", "rhs", "poly"

    def __init__(self, lhs, rhs):
        # type: (_EvalOpLHS, _EvalOpRHS) -> None
        self.lhs = lhs
        self.rhs = rhs
        # must come last: _make_poly reads the operands set above
        self.poly = self._make_poly()

    @property
    def lhs_poly(self):
        # type: () -> EvalOpPoly
        """Polynomial of the left operand; an int becomes a constant."""
        operand = self.lhs
        return EvalOpPoly(operand) if isinstance(operand, int) \
            else operand.poly

    @property
    def rhs_poly(self):
        # type: () -> EvalOpPoly
        """Polynomial of the right operand; an int becomes a constant."""
        operand = self.rhs
        return EvalOpPoly(operand) if isinstance(operand, int) \
            else operand.poly

    @abstractmethod
    def _make_poly(self):
        # type: () -> EvalOpPoly
        """Compute the polynomial this op evaluates to (subclass hook)."""
        ...
196
197
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """Op whose polynomial is the sum of its operands' polynomials."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        augend, addend = self.lhs_poly, self.rhs_poly
        return augend + addend
206
207
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
    """Op whose polynomial is `lhs` polynomial minus `rhs` polynomial."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        minuend, subtrahend = self.lhs_poly, self.rhs_poly
        return minuend - subtrahend
216
217
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpMul(EvalOp[_EvalOpLHS, int]):
    """Op whose polynomial is the `lhs` polynomial times the int `rhs`."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        multiplicand = self.lhs_poly
        return multiplicand * self.rhs
226
227
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
    """Op whose polynomial is the `lhs` polynomial divided by int `rhs`."""
    __slots__ = ()

    def _make_poly(self):
        # type: () -> EvalOpPoly
        dividend = self.lhs_poly
        return dividend / self.rhs
236
237
@plain_data(frozen=True, unsafe_hash=True)
@final
class EvalOpInput(EvalOp[int, Literal[0]]):
    """Leaf op standing for input part number `lhs`; `rhs` is always 0."""
    __slots__ = ()

    def __init__(self, lhs, rhs=0):
        # type: (int, int) -> None
        """Validate the part index, then defer to `EvalOp.__init__`.

        Raises `ValueError` if `lhs` is negative or `rhs` is not 0.
        """
        if lhs < 0:
            raise ValueError("Input part_index (lhs) must be >= 0")
        if rhs != 0:
            raise ValueError("Input rhs must be 0")
        super().__init__(lhs, rhs)

    @property
    def part_index(self):
        """Index of the input part this leaf refers to (alias of `lhs`)."""
        return self.lhs

    def _make_poly(self):
        # type: () -> EvalOpPoly
        # the polynomial is just the variable for this part, coefficient 1
        index = self.part_index
        return EvalOpPoly({index: 1})
258
259
@plain_data(frozen=True, unsafe_hash=True)
@final
class ToomCookInstance:
    """One concrete Toom-Cook scheme, validated on construction.

    Stores how many parts each operand is split into, the evaluation
    points, and the `EvalOp` trees for evaluating the lhs parts, the rhs
    parts and for computing the product parts. `__init__` recomputes the
    expected polynomial of every op from the evaluation matrices and
    raises `ValueError` when a supplied op does not match.
    """
    __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
                 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")

    def __init__(
        self, lhs_part_count,  # type: int
        rhs_part_count,  # type: int
        eval_points,  # type: Iterable[PointAtInfinity | int]
        lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
        prod_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
    ):
        # type: (...) -> None
        """Store all fields as tuples/ints and validate them.

        Raises `ValueError` when a part count is below 2, when evaluation
        points repeat, when any sequence does not have exactly
        `prod_part_count` entries, or when an op's polynomial differs from
        the one derived from the evaluation matrices.
        """
        self.lhs_part_count = lhs_part_count
        if self.lhs_part_count < 2:
            raise ValueError("lhs_part_count must be at least 2")
        self.rhs_part_count = rhs_part_count
        if self.rhs_part_count < 2:
            raise ValueError("rhs_part_count must be at least 2")

        self.eval_points = tuple(eval_points)
        if len(set(self.eval_points)) != len(self.eval_points):
            raise ValueError("duplicate eval points")

        # every sequence must have one entry per product part
        expected_count = self.prod_part_count
        self.lhs_eval_ops = tuple(lhs_eval_ops)
        if len(self.lhs_eval_ops) != expected_count:
            raise ValueError("wrong number of lhs_eval_ops")
        self.rhs_eval_ops = tuple(rhs_eval_ops)
        if len(self.rhs_eval_ops) != expected_count:
            raise ValueError("wrong number of rhs_eval_ops")
        if len(self.eval_points) != expected_count:
            raise ValueError("wrong number of eval_points")
        self.prod_eval_ops = tuple(prod_eval_ops)
        if len(self.prod_eval_ops) != expected_count:
            raise ValueError("wrong number of prod_eval_ops")

        self._check_eval_ops(
            "lhs_eval_ops", self.lhs_eval_ops, self.get_lhs_eval_polys())
        self._check_eval_ops(
            "rhs_eval_ops", self.rhs_eval_ops, self.get_rhs_eval_polys())
        # computing the product polynomials inverts the evaluation matrix,
        # so this also checks the matrix
        self._check_eval_ops(
            "prod_eval_ops", self.prod_eval_ops, self.get_prod_eval_polys())

    @staticmethod
    def _check_eval_ops(name, eval_ops, expected_polys):
        # type: (str, tuple[EvalOp[Any, Any], ...], list[EvalOpPoly]) -> None
        """Raise `ValueError` at the first op whose polynomial is wrong."""
        for i, eval_op in enumerate(eval_ops):
            expected = expected_polys[i]
            if expected != eval_op.poly:
                raise ValueError(
                    f"{name}[{i}] is incorrect: expected polynomial: "
                    f"{expected} found polynomial: {eval_op.poly}")

    @property
    def prod_part_count(self):
        """Number of product parts: `lhs_part_count + rhs_part_count - 1`."""
        return self.lhs_part_count + self.rhs_part_count - 1

    @staticmethod
    def make_eval_matrix(width, eval_points):
        # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
        """Build the matrix with one row per evaluation point.

        For an integer point `p` the entry in column `col` is `p ** col`.
        For `POINT_AT_INFINITY` the row is all zeros except for a 1 in the
        last column.
        """
        matrix = Matrix(height=len(eval_points), width=width)
        last_col = width - 1
        for row, col in matrix.indexes():
            point = eval_points[row]
            if point is not POINT_AT_INFINITY:
                matrix[row, col] = point ** col
                continue
            matrix[row, col] = 1 if col == last_col else 0
        return matrix

    @staticmethod
    def make_input_poly_vector(height):
        # type: (int) -> Matrix[EvalOpPoly]
        """Column vector whose entry `i` is the polynomial `1 * var_i`."""
        unit_polys = (EvalOpPoly({index: 1}) for index in range(height))
        return Matrix(height=height, width=1, element_type=EvalOpPoly,
                      data=unit_polys)

    def get_lhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix sized for the lhs parts."""
        return self.make_eval_matrix(self.lhs_part_count, self.eval_points)

    def get_lhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial of each lhs evaluation op."""
        eval_matrix = self.get_lhs_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.lhs_part_count)
        return list(eval_matrix @ inputs)

    def get_rhs_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix sized for the rhs parts."""
        return self.make_eval_matrix(self.rhs_part_count, self.eval_points)

    def get_rhs_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial of each rhs evaluation op."""
        eval_matrix = self.get_rhs_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.rhs_part_count)
        return list(eval_matrix @ inputs)

    def get_prod_inverse_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Evaluation matrix sized for the product parts."""
        return self.make_eval_matrix(self.prod_part_count, self.eval_points)

    def get_prod_eval_matrix(self):
        # type: () -> Matrix[Fraction]
        """Inverse of `get_prod_inverse_eval_matrix()`."""
        forward = self.get_prod_inverse_eval_matrix()
        return forward.inverse()

    def get_prod_eval_polys(self):
        # type: () -> list[EvalOpPoly]
        """Expected polynomial of each product op."""
        eval_matrix = self.get_prod_eval_matrix().cast(EvalOpPoly)
        inputs = self.make_input_poly_vector(self.prod_part_count)
        return list(eval_matrix @ inputs)

    def reversed(self):
        # type: () -> ToomCookInstance
        """return a ToomCookInstance where lhs/rhs are reversed"""
        swapped = dict(
            lhs_part_count=self.rhs_part_count,
            rhs_part_count=self.lhs_part_count,
            eval_points=self.eval_points,
            lhs_eval_ops=self.rhs_eval_ops,
            rhs_eval_ops=self.lhs_eval_ops,
            prod_eval_ops=self.prod_eval_ops)
        return ToomCookInstance(**swapped)

    @staticmethod
    def make_toom_2():
        # type: () -> ToomCookInstance
        """make an instance of Toom-2 aka Karatsuba multiplication"""
        part_0, part_1 = EvalOpInput(0), EvalOpInput(1)
        # both operands use the same ops for the points 0, 1 and infinity
        operand_eval_ops = [part_0, EvalOpAdd(part_0, part_1), part_1]
        at_0, at_1, at_inf = EvalOpInput(0), EvalOpInput(1), EvalOpInput(2)
        middle = EvalOpSub(EvalOpSub(at_1, at_0), at_inf)
        return ToomCookInstance(
            lhs_part_count=2,
            rhs_part_count=2,
            eval_points=[0, 1, POINT_AT_INFINITY],
            lhs_eval_ops=operand_eval_ops,
            rhs_eval_ops=list(operand_eval_ops),
            prod_eval_ops=[at_0, middle, at_inf],
        )

    @staticmethod
    def make_toom_2_5():
        # type: () -> ToomCookInstance
        """makes an instance of Toom-2.5"""
        in_0, in_1 = EvalOpInput(0), EvalOpInput(1)
        in_2, in_3 = EvalOpInput(2), EvalOpInput(3)
        # shared sub-expressions
        sum_0_2 = EvalOpAdd(in_0, in_2)
        half_diff_1_2 = EvalOpExactDiv(EvalOpSub(in_1, in_2), 2)
        half_sum_1_2 = EvalOpExactDiv(EvalOpAdd(in_1, in_2), 2)
        lhs_ops = [
            in_0,
            EvalOpAdd(sum_0_2, in_1),
            EvalOpSub(sum_0_2, in_1),
            in_2,
        ]
        rhs_ops = [
            in_0,
            EvalOpAdd(in_0, in_1),
            EvalOpSub(in_0, in_1),
            in_1,
        ]
        prod_ops = [
            in_0,
            EvalOpSub(half_diff_1_2, in_3),
            EvalOpSub(half_sum_1_2, in_0),
            in_3,
        ]
        return ToomCookInstance(
            lhs_part_count=3,
            rhs_part_count=2,
            eval_points=[0, 1, -1, POINT_AT_INFINITY],
            lhs_eval_ops=lhs_ops,
            rhs_eval_ops=rhs_ops,
            prod_eval_ops=prod_ops,
        )

    # TODO: add make_toom_3
442
443
def simple_mul(fn, lhs, rhs):
    # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
    """ simple O(n^2) big-int unsigned multiply

    Emits ops into `fn`: the longer operand is multiplied by each word of
    the shorter operand in turn, and every partial product after the first
    is added into the running result at an offset of `shift` words.
    """
    # make `wide` the operand with at least as many words as `narrow`
    wide, narrow = lhs, rhs
    if wide.ty.length < narrow.ty.length:
        wide, narrow = narrow, wide
    # split the shorter operand into elements
    narrow_words = OpSplit(fn, narrow, range(1, narrow.ty.length)).results
    product = None
    vl = OpSetVLImm(fn, wide.ty.length).out
    zero = OpLI(fn, 0).out
    for shift, word in enumerate(narrow_words):
        mul = OpBigIntMulDiv(fn, RA=wide, RB=word, RC=zero,
                             is_div=False, vl=vl)
        if product is None:
            # first partial product: nothing to accumulate into yet
            product = OpConcat(fn, [mul.RT, mul.RS]).dest
            continue
        # words below `shift` are already final; the rest gets added to
        kept_words, overlap_words = OpSplit(fn, product, [shift]).results
        carry_in = OpSetCA(fn, False).out
        add = OpBigIntAddSub(
            fn, lhs=mul.RT, rhs=overlap_words, CA_in=carry_in,
            is_sub=False, vl=vl)
        # NOTE(review): presumably mul.RS is the top word of the partial
        # product; adding zero just folds in the carry from `add`
        add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out,
                                is_sub=False)
        product = OpConcat(fn, [kept_words, add.out, add_hi.out]).dest
    assert product is not None
    return product
469
470
def toom_cook_mul(fn, lhs, rhs, instances):
    # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
    """Not implemented yet: always raises `NotImplementedError`."""
    raise NotImplementedError()