make Matrix support element types other than Fraction
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 20 Oct 2022 03:23:14 +0000 (20:23 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 20 Oct 2022 03:23:59 +0000 (20:23 -0700)
src/bigint_presentation_code/matrix.py
src/bigint_presentation_code/test_matrix.py

index 3e1e1540143eb8b58cd8f59174f56dc7663736a7..49acddffa097f05b150c927d0a48c48a5efdec80 100644 (file)
@@ -1,42 +1,59 @@
 import operator
+from enum import Enum, unique
 from fractions import Fraction
 from numbers import Rational
-from typing import Callable, Iterable
+from typing import Callable, Generic, Iterable, Iterator, Type, TypeVar
 
+from bigint_presentation_code.util import final
 
-class Matrix:
-    __slots__ = "__height", "__width", "__data"
+_T = TypeVar("_T")
+
+
+@final
+@unique
+class SpecialMatrix(Enum):
+    Zero = 0
+    Identity = 1
+
+
+@final
+class Matrix(Generic[_T]):
+    __slots__ = "__height", "__width", "__data", "__element_type"
 
     @property
     def height(self):
+        # type: () -> int
         return self.__height
 
     @property
     def width(self):
+        # type: () -> int
         return self.__width
 
-    def __init__(self, height, width, data=None):
-        # type: (int, int, Iterable[Rational | int] | None) -> None
+    @property
+    def element_type(self):
+        # type: () -> Type[_T]
+        return self.__element_type
+
+    def __init__(self, height, width, data=SpecialMatrix.Zero,
+                 element_type=Fraction):
+        # type: (int, int, Iterable[_T | int] | SpecialMatrix, Type[_T]) -> None
         if width < 0 or height < 0:
             raise ValueError("matrix size must be non-negative")
         self.__height = height
         self.__width = width
-        self.__data = [Fraction()] * (height * width)
-        if data is not None:
-            data = list(data)
-            if len(data) != len(self.__data):
+        self.__element_type = element_type
+        if isinstance(data, SpecialMatrix):
+            self.__data = [element_type(0) for _ in range(height * width)]
+            if data is SpecialMatrix.Identity:
+                for i in range(min(width, height)):
+                    self[i, i] = element_type(1)
+            else:
+                assert data is SpecialMatrix.Zero
+        else:
+            self.__data = [element_type(v) for v in data]
+            if len(self.__data) != height * width:
                 raise ValueError("data has wrong length")
-            self.__data[:] = map(Fraction, data)
-
-    @staticmethod
-    def identity(height, width=None):
-        # type: (int, int | None) -> Matrix
-        if width is None:
-            width = height
-        retval = Matrix(height, width)
-        for i in range(min(height, width)):
-            retval[i, i] = 1
-        return retval
 
     def __idx(self, row, col):
         # type: (int, int) -> int
@@ -45,65 +62,67 @@ class Matrix:
         raise IndexError()
 
     def __getitem__(self, row_col):
-        # type: (tuple[int, int]) -> Fraction
+        # type: (tuple[int, int]) -> _T
         row, col = row_col
         return self.__data[self.__idx(row, col)]
 
     def __setitem__(self, row_col, value):
-        # type: (tuple[int, int], Rational | int) -> None
+        # type: (tuple[int, int], _T | int) -> None
         row, col = row_col
-        self.__data[self.__idx(row, col)] = Fraction(value)
+        self.__data[self.__idx(row, col)] = self.__element_type(value)
 
     def copy(self):
-        retval = Matrix(self.width, self.height)
-        retval.__data[:] = self.__data
-        return retval
+        # type: () -> Matrix[_T]
+        return Matrix(self.width, self.height, data=self.__data,
+                      element_type=self.element_type)
 
     def indexes(self):
+        # type: () -> Iterable[tuple[int, int]]
         for row in range(self.height):
             for col in range(self.width):
                 yield row, col
 
     def __mul__(self, rhs):
-        # type: (Rational | int) -> Matrix
-        rhs = Fraction(rhs)
+        # type: (_T | int) -> Matrix[_T]
         retval = self.copy()
         for i in self.indexes():
-            retval[i] *= rhs
+            retval[i] *= rhs  # type: ignore
         return retval
 
     def __rmul__(self, lhs):
-        # type: (Rational | int) -> Matrix
-        return self.__mul__(lhs)
+        # type: (_T | int) -> Matrix[_T]
+        retval = self.copy()
+        for i in self.indexes():
+            retval[i] = lhs * retval[i]  # type: ignore
+        return retval
 
     def __truediv__(self, rhs):
         # type: (Rational | int) -> Matrix
-        rhs = 1 / Fraction(rhs)
         retval = self.copy()
         for i in self.indexes():
-            retval[i] *= rhs
+            retval[i] /= rhs  # type: ignore
         return retval
 
     def __matmul__(self, rhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         if self.width != rhs.height:
             raise ValueError(
                 "lhs width must equal rhs height to multiply matrixes")
-        retval = Matrix(self.height, rhs.width)
+        retval = Matrix(self.height, rhs.width, element_type=self.element_type)
         for row in range(retval.height):
             for col in range(retval.width):
-                sum = Fraction()
+                sum = self.element_type()
                 for i in range(self.width):
-                    sum += self[row, i] * rhs[i, col]
+                    sum += self[row, i] * rhs[i, col]  # type: ignore
                 retval[row, col] = sum
         return retval
 
     def __rmatmul__(self, lhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         return lhs.__matmul__(self)
 
     def __elementwise_bin_op(self, rhs, op):
-        # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
+        # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
         if self.height != rhs.height or self.width != rhs.width:
             raise ValueError(
                 "matrix dimensions must match for element-wise operations")
@@ -113,34 +132,38 @@ class Matrix:
         return retval
 
     def __add__(self, rhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         return self.__elementwise_bin_op(rhs, operator.add)
 
     def __radd__(self, lhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         return lhs.__add__(self)
 
     def __sub__(self, rhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         return self.__elementwise_bin_op(rhs, operator.sub)
 
     def __rsub__(self, lhs):
-        # type: (Matrix) -> Matrix
+        # type: (Matrix[_T]) -> Matrix[_T]
         return lhs.__sub__(self)
 
     def __iter__(self):
+        # type: () -> Iterator[_T]
         return iter(self.__data)
 
     def __reversed__(self):
+        # type: () -> Iterator[_T]
         return reversed(self.__data)
 
     def __neg__(self):
+        # type: () -> Matrix[_T]
         retval = self.copy()
         for i in retval.indexes():
-            retval[i] = -retval[i]
+            retval[i] = -retval[i]  # type: ignore
         return retval
 
     def __repr__(self):
+        # type: () -> str
         if self.height == 0 or self.width == 0:
             return f"Matrix(height={self.height}, width={self.width})"
         lines = []
@@ -148,28 +171,40 @@ class Matrix:
         for row in range(self.height):
             line.clear()
             for col in range(self.width):
-                if self[row, col].denominator == 1:
-                    line.append(str(self[row, col].numerator))
+                el = self[row, col]
+                if isinstance(el, Fraction) and el.denominator == 1:
+                    line.append(str(el.numerator))
                 else:
-                    line.append(repr(self[row, col]))
+                    line.append(repr(el))
             lines.append(", ".join(line))
         lines = ",\n    ".join(lines)
-        return (f"Matrix(height={self.height}, width={self.width}, data=[\n"
+        element_type = ""
+        if self.element_type is not Fraction:
+            element_type = f"element_type={self.element_type}, "
+        return (f"Matrix(height={self.height}, width={self.width}, "
+                f"{element_type}data=[\n"
                 f"    {lines},\n])")
 
     def __eq__(self, rhs):
+        # type: (object) -> bool
         if not isinstance(rhs, Matrix):
             return NotImplemented
         return (self.height == rhs.height
                 and self.width == rhs.width
-                and self.__data == rhs.__data)
+                and self.__data == rhs.__data
+                and self.element_type == rhs.element_type)
 
-    def inverse(self):
+    def inverse(self  # type: Matrix[Fraction]
+                ):
+        # type: () -> Matrix[Fraction]
         size = self.height
         if size != self.width:
             raise ValueError("can't invert a non-square matrix")
+        if self.element_type is not Fraction:
+            raise TypeError("can't invert a matrix with element_type that "
+                            "isn't Fraction")
         inp = self.copy()
-        retval = Matrix.identity(size)
+        retval = Matrix(size, size, data=SpecialMatrix.Identity)
         # the algorithm is adapted from:
         # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
         for k in range(size):
index ef39742cf0c225ab4242d0b2d0b892d570848c73..1a56df005b31718711305d3cbd5f128fb0b4d354 100644 (file)
@@ -1,7 +1,7 @@
 import unittest
 from fractions import Fraction
 
-from bigint_presentation_code.matrix import Matrix
+from bigint_presentation_code.matrix import Matrix, SpecialMatrix
 
 
 class TestMatrix(unittest.TestCase):
@@ -36,15 +36,15 @@ class TestMatrix(unittest.TestCase):
                          Matrix(2, 2, [41, 32, 23, 14]))
 
     def test_identity(self):
-        self.assertEqual(Matrix.identity(2, 2),
+        self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity),
                          Matrix(2, 2, [1, 0,
                                        0, 1]))
-        self.assertEqual(Matrix.identity(1, 3),
+        self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity),
                          Matrix(1, 3, [1, 0, 0]))
-        self.assertEqual(Matrix.identity(2, 3),
+        self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity),
                          Matrix(2, 3, [1, 0, 0,
                                        0, 1, 0]))
-        self.assertEqual(Matrix.identity(3),
+        self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity),
                          Matrix(3, 3, [1, 0, 0,
                                        0, 1, 0,
                                        0, 0, 1]))