working on toom-cook multiplication
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 19 Oct 2022 09:01:16 +0000 (02:01 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 19 Oct 2022 09:01:16 +0000 (02:01 -0700)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/ordered_set.py [deleted file]
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/toom_cook.py
src/bigint_presentation_code/util.py [new file with mode: 0644]
src/bigint_presentation_code/util.pyi [new file with mode: 0644]

index 3ebb7cff3e9aeb1404c1773af5b1f5a9152bf8a6..517e5423e267ed031da12ef08c01af7c7090b39e 100644 (file)
@@ -8,18 +8,11 @@ from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, EnumMeta, unique
 from functools import lru_cache
-from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type,
-                    TypeVar, cast)
+from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
 
 from nmutil.plain_data import fields, plain_data
 
-from bigint_presentation_code.ordered_set import OFSet, OSet
-
-if TYPE_CHECKING:
-    from typing_extensions import final
-else:
-    def final(v):
-        return v
+from bigint_presentation_code.util import OFSet, OSet, final
 
 
 class ABCEnumMeta(EnumMeta, ABCMeta):
diff --git a/src/bigint_presentation_code/ordered_set.py b/src/bigint_presentation_code/ordered_set.py
deleted file mode 100644 (file)
index 018f97b..0000000
+++ /dev/null
@@ -1,59 +0,0 @@
-from typing import AbstractSet, Iterable, MutableSet, TypeVar
-
-_T_co = TypeVar("_T_co", covariant=True)
-_T = TypeVar("_T")
-
-
-class OFSet(AbstractSet[_T_co]):
-    """ ordered frozen set """
-
-    def __init__(self, items=()):
-        # type: (Iterable[_T_co]) -> None
-        self.__items = {v: None for v in items}
-
-    def __contains__(self, x):
-        return x in self.__items
-
-    def __iter__(self):
-        return iter(self.__items)
-
-    def __len__(self):
-        return len(self.__items)
-
-    def __hash__(self):
-        return self._hash()
-
-    def __repr__(self):
-        if len(self) == 0:
-            return "OFSet()"
-        return f"OFSet({list(self)})"
-
-
-class OSet(MutableSet[_T]):
-    """ ordered mutable set """
-
-    def __init__(self, items=()):
-        # type: (Iterable[_T]) -> None
-        self.__items = {v: None for v in items}
-
-    def __contains__(self, x):
-        return x in self.__items
-
-    def __iter__(self):
-        return iter(self.__items)
-
-    def __len__(self):
-        return len(self.__items)
-
-    def add(self, value):
-        # type: (_T) -> None
-        self.__items[value] = None
-
-    def discard(self, value):
-        # type: (_T) -> None
-        self.__items.pop(value, None)
-
-    def __repr__(self):
-        if len(self) == 0:
-            return "OSet()"
-        return f"OSet({list(self)})"
index a22299f842badfdb4d48439907f7ec373b31ed52..b8269e4bc354f6c31b5433b10d19b71a39a38c1e 100644 (file)
@@ -6,20 +6,13 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
+from typing import Generic, Iterable, Mapping, TypeVar
 
 from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                   RegLoc, RegType, SSAVal)
-from bigint_presentation_code.ordered_set import OFSet, OSet
-
-if TYPE_CHECKING:
-    from typing_extensions import final
-else:
-    def final(v):
-        return v
-
+from bigint_presentation_code.util import OFSet, OSet, final
 
 _RegType = TypeVar("_RegType", bound=RegType)
 
index c014d09a59a5d93a0428249bca81ac1d9bcf1fb1..865015e52a3f7e5b517106b27f273b2131f67cfe 100644 (file)
@@ -1,8 +1,226 @@
 """
-Toom-Cook algorithm generator for SVP64
-
-the register allocator uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+Toom-Cook multiplication algorithm generator for SVP64
 """
-from bigint_presentation_code.compiler_ir import Op
-from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed
+from abc import abstractmethod
+from enum import Enum
+from fractions import Fraction
+from typing import Any, Generic, Iterable, Sequence, TypeVar
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import Fn, Op
+from bigint_presentation_code.util import Literal, OFSet, OSet, final
+
+
+@final
+class PointAtInfinity(Enum):
+    POINT_AT_INFINITY = "POINT_AT_INFINITY"
+
+    def __repr__(self):
+        return self.name
+
+
+POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
+WORD_BITS = 64
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpPoly:
+    """polynomial"""
+    __slots__ = "coefficients",
+
+    def __init__(self, coefficients=()):
+        # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None
+        if isinstance(coefficients, EvalOpPoly):
+            coefficients = coefficients.coefficients
+        elif isinstance(coefficients, (int, Fraction)):
+            coefficients = coefficients,
+        v = list(map(Fraction, coefficients))
+        while len(v) != 0 and v[-1] == 0:
+            v.pop()
+        self.coefficients = tuple(v)  # type: tuple[Fraction, ...]
+
+    def __add__(self, rhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        rhs = EvalOpPoly(rhs)
+        retval = list(self.coefficients)
+        extra = len(rhs.coefficients) - len(retval)
+        if extra > 0:
+            retval.extend([Fraction(0)] * extra)
+        for i, v in enumerate(rhs.coefficients):
+            retval[i] += v
+        return EvalOpPoly(retval)
+
+    __radd__ = __add__
+
+    def __neg__(self):
+        return EvalOpPoly(-v for v in self.coefficients)
+
+    def __sub__(self, rhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        return self + -rhs
+
+    def __rsub__(self, lhs):
+        # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+        return lhs + -self
+
+    def __mul__(self, rhs):
+        # type: (int | Fraction) -> EvalOpPoly
+        return EvalOpPoly(v * rhs for v in self.coefficients)
+
+    __rmul__ = __mul__
+
+    def __truediv__(self, rhs):
+        # type: (int | Fraction) -> EvalOpPoly
+        if rhs == 0:
+            raise ZeroDivisionError()
+        return EvalOpPoly(v / rhs for v in self.coefficients)
+
+
+_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
+_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = "lhs", "rhs", "poly"
+
+    @property
+    def lhs_poly(self):
+        # type: () -> EvalOpPoly
+        if isinstance(self.lhs, int):
+            return EvalOpPoly(self.lhs)
+        return self.lhs.poly
+
+    @property
+    def rhs_poly(self):
+        # type: () -> EvalOpPoly
+        if isinstance(self.rhs, int):
+            return EvalOpPoly(self.rhs)
+        return self.rhs.poly
+
+    @abstractmethod
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        ...
+
+    def __init__(self, lhs, rhs):
+        # type: (_EvalOpLHS, _EvalOpRHS) -> None
+        self.lhs = lhs
+        self.rhs = rhs
+        self.poly = self._make_poly()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly + self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly - self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpMul(EvalOp[_EvalOpLHS, int]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly * self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
+    __slots__ = ()
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return self.lhs_poly / self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpInput(EvalOp[int, Literal[0]]):
+    __slots__ = ()
+
+    def __init__(self, lhs, rhs=0):
+        # type: (...) -> None
+        if lhs < 0:
+            raise ValueError("Input split_index (lhs) must be >= 0")
+        if rhs != 0:
+            raise ValueError("Input rhs must be 0")
+        super().__init__(lhs, rhs)
+
+    @property
+    def split_index(self):
+        return self.lhs
+
+    def _make_poly(self):
+        # type: () -> EvalOpPoly
+        return EvalOpPoly([0] * self.split_index + [1])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ToomCookInstance:
+    __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points",
+                 "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops")
+
+    def __init__(
+        self, lhs_split_count,  # type: int
+        rhs_split_count,  # type: int
+        eval_points,  # type: Iterable[PointAtInfinity | int]
+        lhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+        rhs_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+        product_eval_ops,  # type: Iterable[EvalOp[Any, Any]]
+    ):
+        # type: (...) -> None
+        self.lhs_split_count = lhs_split_count
+        if self.lhs_split_count < 2:
+            raise ValueError("lhs_split_count must be at least 2")
+        self.rhs_split_count = rhs_split_count
+        if self.rhs_split_count < 2:
+            raise ValueError("rhs_split_count must be at least 2")
+        eval_points = list(eval_points)
+        self.eval_points = OFSet(eval_points)
+        if len(self.eval_points) != len(eval_points):
+            raise ValueError("duplicate eval points")
+        self.lhs_eval_ops = tuple(lhs_eval_ops)
+        if len(self.lhs_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of lhs_eval_ops")
+        self.rhs_eval_ops = tuple(rhs_eval_ops)
+        if len(self.rhs_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of rhs_eval_ops")
+        if self.lhs_split_count < 2:
+            raise ValueError("lhs_split_count must be at least 2")
+        if self.rhs_split_count < 2:
+            raise ValueError("rhs_split_count must be at least 2")
+        if (self.lhs_split_count + self.rhs_split_count - 1
+                != len(self.eval_points)):
+            raise ValueError("wrong number of eval_points")
+        self.product_eval_ops = tuple(product_eval_ops)
+        if len(self.product_eval_ops) != len(self.eval_points):
+            raise ValueError("wrong number of product_eval_ops")
+        # TODO: compute and check matrix and all the *_eval_ops
+        raise NotImplementedError
+
+
+def toom_cook_mul(fn, word_count, instances):
+    # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
+    retval = OSet()  # type: OSet[Op]
+    raise NotImplementedError
+    return retval
diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py
new file mode 100644 (file)
index 0000000..b8b2934
--- /dev/null
@@ -0,0 +1,113 @@
+from typing import (TYPE_CHECKING, AbstractSet, Iterable, Iterator, Mapping,
+                    MutableSet, TypeVar, Union)
+
+if TYPE_CHECKING:
+    from typing_extensions import Literal, final
+else:
+    def final(v):
+        return v
+
+    class _Literal:
+        def __getitem__(self, v):
+            if isinstance(v, tuple):
+                return Union[tuple(type(i) for i in v)]
+            return type(v)
+
+    Literal = _Literal()
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"]
+
+
+class OFSet(AbstractSet[_T_co]):
+    """ ordered frozen set """
+    __slots__ = "__items",
+
+    def __init__(self, items=()):
+        # type: (Iterable[_T_co]) -> None
+        self.__items = {v: None for v in items}
+
+    def __contains__(self, x):
+        return x in self.__items
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def __hash__(self):
+        return self._hash()
+
+    def __repr__(self):
+        if len(self) == 0:
+            return "OFSet()"
+        return f"OFSet({list(self)})"
+
+
+class OSet(MutableSet[_T]):
+    """ ordered mutable set """
+    __slots__ = "__items",
+
+    def __init__(self, items=()):
+        # type: (Iterable[_T]) -> None
+        self.__items = {v: None for v in items}
+
+    def __contains__(self, x):
+        return x in self.__items
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def add(self, value):
+        # type: (_T) -> None
+        self.__items[value] = None
+
+    def discard(self, value):
+        # type: (_T) -> None
+        self.__items.pop(value, None)
+
+    def __repr__(self):
+        if len(self) == 0:
+            return "OSet()"
+        return f"OSet({list(self)})"
+
+
+class FMap(Mapping[_T, _T_co]):
+    """ordered frozen hashable mapping"""
+    __slots__ = "__items", "__hash"
+
+    def __init__(self, items=()):
+        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
+        self.__items = dict(items)  # type: dict[_T, _T_co]
+        self.__hash = None  # type: None | int
+
+    def __getitem__(self, item):
+        # type: (_T) -> _T_co
+        return self.__items[item]
+
+    def __iter__(self):
+        # type: () -> Iterator[_T]
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def __eq__(self, other):
+        # type: (object) -> bool
+        if isinstance(other, FMap):
+            return self.__items == other.__items
+        return super().__eq__(other)
+
+    def __hash__(self):
+        if self.__hash is None:
+            self.__hash = hash(frozenset(self.items()))
+        return self.__hash
+
+    def __repr__(self):
+        return f"FMap({self.__items})"
diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi
new file mode 100644 (file)
index 0000000..48445b1
--- /dev/null
@@ -0,0 +1,81 @@
+from typing import (AbstractSet, Iterable, Iterator, Mapping,
+                    MutableSet, TypeVar, overload)
+from typing_extensions import final, Literal
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"]
+
+
+class OFSet(AbstractSet[_T_co]):
+    """ ordered frozen set """
+
+    def __init__(self, items: Iterable[_T_co] = ()):
+        ...
+
+    def __contains__(self, x: object) -> bool:
+        ...
+
+    def __iter__(self) -> Iterator[_T_co]:
+        ...
+
+    def __len__(self) -> int:
+        ...
+
+    def __hash__(self) -> int:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+class OSet(MutableSet[_T]):
+    """ ordered mutable set """
+
+    def __init__(self, items: Iterable[_T] = ()):
+        ...
+
+    def __contains__(self, x: object) -> bool:
+        ...
+
+    def __iter__(self) -> Iterator[_T]:
+        ...
+
+    def __len__(self) -> int:
+        ...
+
+    def add(self, value: _T) -> None:
+        ...
+
+    def discard(self, value: _T) -> None:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+class FMap(Mapping[_T, _T_co]):
+    """ordered frozen hashable mapping"""
+    @overload
+    def __init__(self, items: Mapping[_T, _T_co] = ...): ...
+    @overload
+    def __init__(self, items: Iterable[tuple[_T, _T_co]] = ...): ...
+
+    def __getitem__(self, item: _T) -> _T_co:
+        ...
+
+    def __iter__(self) -> Iterator[_T]:
+        ...
+
+    def __len__(self) -> int:
+        ...
+
+    def __eq__(self, other: object) -> bool:
+        ...
+
+    def __hash__(self) -> int:
+        ...
+
+    def __repr__(self) -> str:
+        ...