working on toom-cook multiplication
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
index c014d09a59a5d93a0428249bca81ac1d9bcf1fb1..865015e52a3f7e5b517106b27f273b2131f67cfe 100644 (file)
@@ -1,8 +1,226 @@
 """
-Toom-Cook algorithm generator for SVP64
-
-the register allocator uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+Toom-Cook multiplication algorithm generator for SVP64
 """
-from bigint_presentation_code.compiler_ir import Op
-from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed
+from abc import abstractmethod
+from enum import Enum
+from fractions import Fraction
+from typing import Any, Generic, Iterable, Sequence, TypeVar
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import Fn, Op
+from bigint_presentation_code.util import Literal, OFSet, OSet, final
+
+
+@final
+class PointAtInfinity(Enum):
+    POINT_AT_INFINITY = "POINT_AT_INFINITY"
+
+    def __repr__(self):
+        return self.name
+
+
+POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
+WORD_BITS = 64
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpPoly:
+    """polynomial"""
+    __slots__ = "coefficients",
+
+    def __init__(self, coefficients=()):
+        # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None
+        if isinstance(coefficients, EvalOpPoly):
+            coefficients = coefficients.coefficients
+        elif isinstance(coefficients, (int, Fraction)):
+            coefficients = coefficients,
+        v = list(map(Fraction, coefficients))
+        while len(v) != 0 and v[-1] == 0:
+            v.pop()
+        self.coefficients = tuple(v)  # type: tuple[Fraction, ...]
+
+    def __add__(self, rhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        rhs = EvalOpPoly(rhs)
+        retval = list(self.coefficients)
+        extra = len(rhs.coefficients) - len(retval)
+        if extra > 0:
+            retval.extend([Fraction(0)] * extra)
+        for i, v in enumerate(rhs.coefficients):
+            retval[i] += v
+        return EvalOpPoly(retval)
+
+    __radd__ = __add__
+
+    def __neg__(self):
+        return EvalOpPoly(-v for v in self.coefficients)
+
+    def __sub__(self, rhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        return self + -rhs
+
+    def __rsub__(self, lhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        return lhs + -self
+
+    def __mul__(self, rhs):
+        # type: (int | Fraction) -> EvalOpPoly
+        return EvalOpPoly(v * rhs for v in self.coefficients)
+
+    __rmul__ = __mul__
+
+    def __truediv__(self, rhs):
+        # type: (int | Fraction) -> EvalOpPoly
+        if rhs == 0:
+            raise ZeroDivisionError()
+        return EvalOpPoly(v / rhs for v in self.coefficients)
+
+
+_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
+_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = "lhs", "rhs", "poly"
+
+    @property
+    def lhs_poly(self):
+        # type: () -> EvalOpPoly
+        if isinstance(self.lhs, int):
+            return EvalOpPoly(self.lhs)
+        return self.lhs.poly
+
+    @property
+    def rhs_poly(self):
+        # type: () -> EvalOpPoly
+        if isinstance(self.rhs, int):
+            return EvalOpPoly(self.rhs)
+        return self.rhs.poly
+
+    @abstractmethod
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        ...
+
+    def __init__(self, lhs, rhs):
+        # type: (_EvalOpLHS, _EvalOpRHS) -> None
+        self.lhs = lhs
+        self.rhs = rhs
+        self.poly = self._make_poly()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly + self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly - self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpMul(EvalOp[_EvalOpLHS, int]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly * self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly / self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpInput(EvalOp[int, Literal[0]]):
+    __slots__ = ()
+
+    def __init__(self, lhs, rhs=0):
+        # type: (...) -> None
+        if lhs < 0:
+            raise ValueError("Input split_index (lhs) must be >= 0")
+        if rhs != 0:
+            raise ValueError("Input rhs must be 0")
+        super().__init__(lhs, rhs)
+
+    @property
+    def split_index(self):
+        return self.lhs
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return EvalOpPoly([0] * self.split_index + [1])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ToomCookInstance:
+    __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points",
+                 "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops")
+
+    def __init__(
+        self, lhs_split_count,  # type: int
+        rhs_split_count,  # type: int
+        eval_points,  # type: Iterable[PointAtInfinity | int]
+        lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+        rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+        product_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+    ):
+        # type: (...) -> None
+        self.lhs_split_count = lhs_split_count
+        if self.lhs_split_count < 2:
+            raise ValueError("lhs_split_count must be at least 2")
+        self.rhs_split_count = rhs_split_count
+        if self.rhs_split_count < 2:
+            raise ValueError("rhs_split_count must be at least 2")
+        eval_points = list(eval_points)
+        self.eval_points = OFSet(eval_points)
+        if len(self.eval_points) != len(eval_points):
+            raise ValueError("duplicate eval points")
+        self.lhs_eval_ops = tuple(lhs_eval_ops)
+        if len(self.lhs_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of lhs_eval_ops")
+        self.rhs_eval_ops = tuple(rhs_eval_ops)
+        if len(self.rhs_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of rhs_eval_ops")
+        if self.lhs_split_count < 2:
+            raise ValueError("lhs_split_count must be at least 2")
+        if self.rhs_split_count < 2:
+            raise ValueError("rhs_split_count must be at least 2")
+        if (self.lhs_split_count + self.rhs_split_count - 1
+                != len(self.eval_points)):
+            raise ValueError("wrong number of eval_points")
+        self.product_eval_ops = tuple(product_eval_ops)
+        if len(self.product_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of product_eval_ops")
+        # TODO: compute and check matrix and all the *_eval_ops
+        raise NotImplementedError
+
+
+def toom_cook_mul(fn, word_count, instances):
+    # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
+    retval = OSet()  # type: OSet[Op]
+    raise NotImplementedError
+    return retval