ToomCookInstance works!
author Jacob Lifshay <programmerjake@gmail.com>
Thu, 20 Oct 2022 06:46:02 +0000 (23:46 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Thu, 20 Oct 2022 06:46:02 +0000 (23:46 -0700)
src/bigint_presentation_code/matrix.py
src/bigint_presentation_code/test_toom_cook.py
src/bigint_presentation_code/toom_cook.py

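diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py
index 49acddffa097f05b150c927d0a48c48a5efdec80..89c3ea23b2b4554cc1836a6773adc643713882c4 100644
--- a/src/bigint_presentation_code/matrix.py
+++ b/src/bigint_presentation_code/matrix.py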
@@ -2,11 +2,12 @@ import operator
 from enum import Enum, unique
 from fractions import Fraction
 from numbers import Rational
-from typing import Callable, Generic, Iterable, Iterator, Type, TypeVar
+from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar
 
 from bigint_presentation_code.util import final
 
 _T = TypeVar("_T")
+_T2 = TypeVar("_T2")
 
 
 @final
@@ -37,7 +38,7 @@ class Matrix(Generic[_T]):
 
     def __init__(self, height, width, data=SpecialMatrix.Zero,
                  element_type=Fraction):
-        # type: (int, int, Iterable[_T | int] | SpecialMatrix, Type[_T]) -> None
+        # type: (int, int, Iterable[_T | int | Any] | SpecialMatrix, Type[_T]) -> None
         if width < 0 or height < 0:
             raise ValueError("matrix size must be non-negative")
         self.__height = height
@@ -55,6 +56,11 @@ class Matrix(Generic[_T]):
             if len(self.__data) != height * width:
                 raise ValueError("data has wrong length")
 
+    def cast(self, element_type):
+        # type: (Type[_T2]) -> Matrix[_T2]
+        data = self  # type: Iterable[Any]
+        return Matrix(self.height, self.width, data, element_type=element_type)
+
     def __idx(self, row, col):
         # type: (int, int) -> int
         if 0 <= col < self.width and 0 <= row < self.height:
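diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py
index 8fe6cea693e98687abc4a877e830680ba963126c..f880a81ca406519fe27d9295c96e022393b930b4 100644
--- a/src/bigint_presentation_code/test_toom_cook.py
+++ b/src/bigint_presentation_code/test_toom_cook.py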
@@ -1,10 +1,46 @@
 import unittest
 
-import bigint_presentation_code.toom_cook
+from bigint_presentation_code.toom_cook import ToomCookInstance
 
 
 class TestToomCook(unittest.TestCase):
-    pass  # no tests yet, just testing importing
+    def test_toom_2(self):
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        print(repr(repr(TOOM_2)))
+        self.assertEqual(
+            repr(TOOM_2),
+            "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
+            "eval_points=(0, 1, POINT_AT_INFINITY), "
+            "lhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+            " rhs_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpAdd(lhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+            " prod_eval_ops=("
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "EvalOpSub(lhs="
+            "EvalOpSub(lhs="
+            "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
+            "rhs="
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+            "poly=EvalOpPoly({"
+            "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+            "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
+        )
 
 
 if __name__ == "__main__":
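diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py
index 865015e52a3f7e5b517106b27f273b2131f67cfe..76a8a994b11d1123e1233b3c8e4a03857ebdc437 100644
--- a/src/bigint_presentation_code/toom_cook.py
+++ b/src/bigint_presentation_code/toom_cook.py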
@@ -4,12 +4,13 @@ Toom-Cook multiplication algorithm generator for SVP64
 from abc import abstractmethod
 from enum import Enum
 from fractions import Fraction
-from typing import Any, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import Fn, Op
-from bigint_presentation_code.util import Literal, OFSet, OSet, final
+from bigint_presentation_code.matrix import Matrix
+from bigint_presentation_code.util import Literal, OSet, final
 
 
 @final
@@ -23,39 +24,105 @@ class PointAtInfinity(Enum):
 POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
 WORD_BITS = 64
 
+_EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]",
+                                "EvalOpPoly", Fraction, int, None]
 
-@plain_data(frozen=True, unsafe_hash=True)
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class EvalOpPoly:
     """polynomial"""
-    __slots__ = "coefficients",
-
-    def __init__(self, coefficients=()):
-        # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None
-        if isinstance(coefficients, EvalOpPoly):
-            coefficients = coefficients.coefficients
-        elif isinstance(coefficients, (int, Fraction)):
-            coefficients = coefficients,
-        v = list(map(Fraction, coefficients))
-        while len(v) != 0 and v[-1] == 0:
-            v.pop()
-        self.coefficients = tuple(v)  # type: tuple[Fraction, ...]
+    __slots__ = "const_coeff", "var_coeffs"
+
+    def __init__(
+        self, coeffs=None,  # type: _EvalOpPolyCoefficients
+        const_coeff=None,  # type: Fraction | int | None
+        var_coeffs=(),  # type: Iterable[Fraction | int] | None
+    ):
+        if coeffs is not None:
+            if const_coeff is not None or var_coeffs != ():
+                raise ValueError(
+                    "can't specify const_coeff or "
+                    "var_coeffs along with coeffs")
+            if isinstance(coeffs, EvalOpPoly):
+                self.const_coeff = coeffs.const_coeff
+                self.var_coeffs = coeffs.var_coeffs
+                return
+            if isinstance(coeffs, (int, Fraction)):
+                const_coeff = Fraction(coeffs)
+                final_var_coeffs = []  # type: list[Fraction]
+            else:
+                const_coeff = 0
+                final_var_coeffs = []
+                for var, coeff in coeffs.items():
+                    if coeff == 0:
+                        continue
+                    coeff = Fraction(coeff)
+                    if var is None:
+                        const_coeff = coeff
+                        continue
+                    if var < 0:
+                        raise ValueError("invalid variable index")
+                    if var >= len(final_var_coeffs):
+                        additional = var - len(final_var_coeffs)
+                        final_var_coeffs.extend((Fraction(),) * additional)
+                        final_var_coeffs.append(coeff)
+                    else:
+                        final_var_coeffs[var] = coeff
+        else:
+            if var_coeffs is None:
+                final_var_coeffs = []
+            else:
+                final_var_coeffs = [Fraction(v) for v in var_coeffs]
+                while len(final_var_coeffs) > 0 and final_var_coeffs[-1] == 0:
+                    final_var_coeffs.pop()
+        if const_coeff is None:
+            const_coeff = 0
+        self.const_coeff = Fraction(const_coeff)
+        self.var_coeffs = tuple(final_var_coeffs)
 
     def __add__(self, rhs):
         # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
         rhs = EvalOpPoly(rhs)
-        retval = list(self.coefficients)
-        extra = len(rhs.coefficients) - len(retval)
-        if extra > 0:
-            retval.extend([Fraction(0)] * extra)
-        for i, v in enumerate(rhs.coefficients):
-            retval[i] += v
-        return EvalOpPoly(retval)
+        const_coeff = self.const_coeff + rhs.const_coeff
+        var_coeffs = list(self.var_coeffs)
+        if len(rhs.var_coeffs) > len(var_coeffs):
+            var_coeffs.extend(rhs.var_coeffs[len(var_coeffs):])
+        for var in range(min(len(self.var_coeffs), len(rhs.var_coeffs))):
+            var_coeffs[var] += rhs.var_coeffs[var]
+        return EvalOpPoly(const_coeff=const_coeff, var_coeffs=var_coeffs)
+
+    @property
+    def coefficients(self):
+        # type: () -> dict[int | None, Fraction]
+        retval = {}  # type: dict[int | None, Fraction]
+        if self.const_coeff != 0:
+            retval[None] = self.const_coeff
+        for var, coeff in enumerate(self.var_coeffs):
+            if coeff != 0:
+                retval[var] = coeff
+        return retval
+
+    @property
+    def is_const(self):
+        # type: () -> bool
+        return self.var_coeffs == ()
+
+    def coeff(self, var):
+        # type: (int | None) -> Fraction
+        if var is None:
+            return self.const_coeff
+        if var < 0:
+            raise ValueError("invalid variable index")
+        if var < len(self.var_coeffs):
+            return self.var_coeffs[var]
+        return Fraction()
 
     __radd__ = __add__
 
     def __neg__(self):
-        return EvalOpPoly(-v for v in self.coefficients)
+        return EvalOpPoly(const_coeff=-self.const_coeff,
+                          var_coeffs=(-v for v in self.var_coeffs))
 
     def __sub__(self, rhs):
         # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
@@ -66,8 +133,17 @@ class EvalOpPoly:
         return lhs + -self
 
     def __mul__(self, rhs):
-        # type: (int | Fraction) -> EvalOpPoly
-        return EvalOpPoly(v * rhs for v in self.coefficients)
+        # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
+        if isinstance(rhs, EvalOpPoly):
+            if self.is_const:
+                self, rhs = rhs, self
+            if not rhs.is_const:
+                raise ValueError("can't represent exponents larger than one")
+            rhs = rhs.const_coeff
+        if rhs == 0:
+            return EvalOpPoly()
+        return EvalOpPoly(const_coeff=self.const_coeff * rhs,
+                          var_coeffs=(i * rhs for i in self.var_coeffs))
 
     __rmul__ = __mul__
 
@@ -75,7 +151,11 @@ class EvalOpPoly:
         # type: (int | Fraction) -> EvalOpPoly
         if rhs == 0:
             raise ZeroDivisionError()
-        return EvalOpPoly(v / rhs for v in self.coefficients)
+        return EvalOpPoly(const_coeff=self.const_coeff / rhs,
+                          var_coeffs=(i / rhs for i in self.var_coeffs))
+
+    def __repr__(self):
+        return f"EvalOpPoly({self.coefficients})"
 
 
 _EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
@@ -160,63 +240,155 @@ class EvalOpInput(EvalOp[int, Literal[0]]):
     def __init__(self, lhs, rhs=0):
         # type: (...) -> None
         if lhs < 0:
-            raise ValueError("Input split_index (lhs) must be >= 0")
+            raise ValueError("Input part_index (lhs) must be >= 0")
         if rhs != 0:
             raise ValueError("Input rhs must be 0")
         super().__init__(lhs, rhs)
 
     @property
-    def split_index(self):
+    def part_index(self):
         return self.lhs
 
     def _make_poly(self):
         # type: () -> EvalOpPoly
-        return EvalOpPoly([0] * self.split_index + [1])
+        return EvalOpPoly({self.part_index: 1})
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class ToomCookInstance:
-    __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points",
-                 "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops")
+    __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
+                 "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
+
+    @property
+    def prod_part_count(self):
+        return self.lhs_part_count + self.rhs_part_count - 1
+
+    @staticmethod
+    def make_eval_matrix(width, eval_points):
+        # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
+        retval = Matrix(height=len(eval_points), width=width)
+        for row, col in retval.indexes():
+            eval_point = eval_points[row]
+            if eval_point is POINT_AT_INFINITY:
+                retval[row, col] = int(col == width - 1)
+            else:
+                retval[row, col] = eval_point ** col
+        return retval
+
+    def get_lhs_eval_matrix(self):
+        # type: () -> Matrix[Fraction]
+        return self.make_eval_matrix(self.lhs_part_count, self.eval_points)
+
+    @staticmethod
+    def make_input_poly_vector(height):
+        # type: (int) -> Matrix[EvalOpPoly]
+        return Matrix(height=height, width=1, element_type=EvalOpPoly,
+                      data=(EvalOpPoly({i: 1}) for i in range(height)))
+
+    def get_lhs_eval_polys(self):
+        # type: () -> list[EvalOpPoly]
+        return list(self.get_lhs_eval_matrix().cast(EvalOpPoly)
+                    @ self.make_input_poly_vector(self.lhs_part_count))
+
+    def get_rhs_eval_matrix(self):
+        # type: () -> Matrix[Fraction]
+        return self.make_eval_matrix(self.rhs_part_count, self.eval_points)
+
+    def get_rhs_eval_polys(self):
+        # type: () -> list[EvalOpPoly]
+        return list(self.get_rhs_eval_matrix().cast(EvalOpPoly)
+                    @ self.make_input_poly_vector(self.rhs_part_count))
+
+    def get_prod_inverse_eval_matrix(self):
+        # type: () -> Matrix[Fraction]
+        return self.make_eval_matrix(self.prod_part_count, self.eval_points)
+
+    def get_prod_eval_matrix(self):
+        # type: () -> Matrix[Fraction]
+        return self.get_prod_inverse_eval_matrix().inverse()
+
+    def get_prod_eval_polys(self):
+        # type: () -> list[EvalOpPoly]
+        return list(self.get_prod_eval_matrix().cast(EvalOpPoly)
+                    @ self.make_input_poly_vector(self.prod_part_count))
 
     def __init__(
-        self, lhs_split_count,  # type: int
-        rhs_split_count,  # type: int
+        self, lhs_part_count,  # type: int
+        rhs_part_count,  # type: int
         eval_points,  # type: Iterable[PointAtInfinity | int]
         lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
         rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
-        product_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+        prod_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
     ):
         # type: (...) -> None
-        self.lhs_split_count = lhs_split_count
-        if self.lhs_split_count < 2:
-            raise ValueError("lhs_split_count must be at least 2")
-        self.rhs_split_count = rhs_split_count
-        if self.rhs_split_count < 2:
-            raise ValueError("rhs_split_count must be at least 2")
+        self.lhs_part_count = lhs_part_count
+        if self.lhs_part_count < 2:
+            raise ValueError("lhs_part_count must be at least 2")
+        self.rhs_part_count = rhs_part_count
+        if self.rhs_part_count < 2:
+            raise ValueError("rhs_part_count must be at least 2")
         eval_points = list(eval_points)
-        self.eval_points = OFSet(eval_points)
-        if len(self.eval_points) != len(eval_points):
+        self.eval_points = tuple(eval_points)
+        if len(self.eval_points) != len(set(self.eval_points)):
             raise ValueError("duplicate eval points")
         self.lhs_eval_ops = tuple(lhs_eval_ops)
-        if len(self.lhs_eval_ops) != len(self.eval_points):
+        if len(self.lhs_eval_ops) != self.prod_part_count:
             raise ValueError("wrong number of lhs_eval_ops")
         self.rhs_eval_ops = tuple(rhs_eval_ops)
-        if len(self.rhs_eval_ops) != len(self.eval_points):
+        if len(self.rhs_eval_ops) != self.prod_part_count:
             raise ValueError("wrong number of rhs_eval_ops")
-        if self.lhs_split_count < 2:
-            raise ValueError("lhs_split_count must be at least 2")
-        if self.rhs_split_count < 2:
-            raise ValueError("rhs_split_count must be at least 2")
-        if (self.lhs_split_count + self.rhs_split_count - 1
-                != len(self.eval_points)):
+        if len(self.eval_points) != self.prod_part_count:
             raise ValueError("wrong number of eval_points")
-        self.product_eval_ops = tuple(product_eval_ops)
-        if len(self.product_eval_ops) != len(self.eval_points):
-            raise ValueError("wrong number of product_eval_ops")
-        # TODO: compute and check matrix and all the *_eval_ops
-        raise NotImplementedError
+        self.prod_eval_ops = tuple(prod_eval_ops)
+        if len(self.prod_eval_ops) != self.prod_part_count:
+            raise ValueError("wrong number of prod_eval_ops")
+
+        lhs_eval_polys = self.get_lhs_eval_polys()
+        for i, eval_op in enumerate(self.lhs_eval_ops):
+            if lhs_eval_polys[i] != eval_op.poly:
+                raise ValueError(
+                    f"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
+                    f"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+        rhs_eval_polys = self.get_rhs_eval_polys()
+        for i, eval_op in enumerate(self.rhs_eval_ops):
+            if rhs_eval_polys[i] != eval_op.poly:
+                raise ValueError(
+                    f"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
+                    f"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+        prod_eval_polys = self.get_prod_eval_polys()  # also checks matrix
+        for i, eval_op in enumerate(self.prod_eval_ops):
+            if prod_eval_polys[i] != eval_op.poly:
+                raise ValueError(
+                    f"prod_eval_ops[{i}] is incorrect: expected polynomial: "
+                    f"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+    @staticmethod
+    def make_toom_2():
+        # type: () -> ToomCookInstance
+        return ToomCookInstance(
+            lhs_part_count=2,
+            rhs_part_count=2,
+            eval_points=[0, 1, POINT_AT_INFINITY],
+            lhs_eval_ops=[
+                EvalOpInput(0),
+                EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
+                EvalOpInput(1),
+            ],
+            rhs_eval_ops=[
+                EvalOpInput(0),
+                EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
+                EvalOpInput(1),
+            ],
+            prod_eval_ops=[
+                EvalOpInput(0),
+                EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
+                          EvalOpInput(2)),
+                EvalOpInput(2),
+            ],
+        )
 
 
 def toom_cook_mul(fn, word_count, instances):