add Matrix class
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 6 Oct 2022 05:03:33 +0000 (22:03 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 6 Oct 2022 05:03:33 +0000 (22:03 -0700)
.gitignore
setup.py
src/bigint_presentation_code/matrix.py [new file with mode: 0644]
src/bigint_presentation_code/test_matrix.py [new file with mode: 0644]

index d0bb043e47040006c72567ae9f15b8256548c1b1..4134655cc4f0968d8562505e4b9080cabcc6920b 100644 (file)
@@ -4,3 +4,4 @@ __pycache__
 *.gtkw
 *.egg-info
 *.il
+/.vscode
index 5aaa3c25584589cc8435b12a6c1cd7bfa174c780..36d91f44a9c231cf00429a520d481affb91caea3 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -6,6 +6,7 @@ README = Path(__file__).with_name('README.md').read_text("UTF-8")
 version = '0.0.1'
 
 install_requires = [
+    "libresoc-nmutil",
     'libresoc-openpower-isa',
 ]
 
diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py
new file mode 100644 (file)
index 0000000..2636be8
--- /dev/null
@@ -0,0 +1,207 @@
+import operator
+from typing import Callable, Iterable
+from fractions import Fraction
+from numbers import Rational
+
+
+class Matrix:
+    __slots__ = "__height", "__width", "__data"
+
+    @property
+    def height(self):
+        return self.__height
+
+    @property
+    def width(self):
+        return self.__width
+
+    def __init__(self, height, width, data=None):
+        # type: (int, int, Iterable[Rational | int] | None) -> None
+        if width < 0 or height < 0:
+            raise ValueError("matrix size must be non-negative")
+        self.__height = height
+        self.__width = width
+        self.__data = [Fraction()] * (height * width)
+        if data is not None:
+            data = list(data)
+            if len(data) != len(self.__data):
+                raise ValueError("data has wrong length")
+            self.__data[:] = map(Fraction, data)
+
+    @staticmethod
+    def identity(height, width=None):
+        # type: (int, int | None) -> Matrix
+        if width is None:
+            width = height
+        retval = Matrix(height, width)
+        for i in range(min(height, width)):
+            retval[i, i] = 1
+        return retval
+
+    def __idx(self, row, col):
+        # type: (int, int) -> int
+        if 0 <= col < self.width and 0 <= row < self.height:
+            return row * self.width + col
+        raise IndexError()
+
+    def __getitem__(self, row_col):
+        # type: (tuple[int, int]) -> Fraction
+        row, col = row_col
+        return self.__data[self.__idx(row, col)]
+
+    def __setitem__(self, row_col, value):
+        # type: (tuple[int, int], Rational | int) -> None
+        row, col = row_col
+        self.__data[self.__idx(row, col)] = Fraction(value)
+
+    def copy(self):
+        retval = Matrix(self.width, self.height)
+        retval.__data[:] = self.__data
+        return retval
+
+    def indexes(self):
+        for row in range(self.height):
+            for col in range(self.width):
+                yield row, col
+
+    def __mul__(self, rhs):
+        # type: (Rational | int) -> Matrix
+        rhs = Fraction(rhs)
+        retval = self.copy()
+        for i in self.indexes():
+            retval[i] *= rhs
+        return retval
+
+    def __rmul__(self, lhs):
+        # type: (Rational | int) -> Matrix
+        return self.__mul__(lhs)
+
+    def __truediv__(self, rhs):
+        # type: (Rational | int) -> Matrix
+        rhs = 1 / Fraction(rhs)
+        retval = self.copy()
+        for i in self.indexes():
+            retval[i] *= rhs
+        return retval
+
+    def __matmul__(self, rhs):
+        # type: (Matrix) -> Matrix
+        if self.width != rhs.height:
+            raise ValueError(
+                "lhs width must equal rhs height to multiply matrixes")
+        retval = Matrix(self.height, rhs.width)
+        for row in range(retval.height):
+            for col in range(retval.width):
+                sum = Fraction()
+                for i in range(self.width):
+                    sum += self[row, i] * rhs[i, col]
+                retval[row, col] = sum
+        return retval
+
+    def __rmatmul__(self, lhs):
+        # type: (Matrix) -> Matrix
+        return lhs.__matmul__(self)
+
+    def __elementwise_bin_op(self, rhs, op):
+        # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
+        if self.height != rhs.height or self.width != rhs.width:
+            raise ValueError(
+                "matrix dimensions must match for element-wise operations")
+        retval = self.copy()
+        for i in retval.indexes():
+            retval[i] = op(retval[i], rhs[i])
+        return retval
+
+    def __add__(self, rhs):
+        # type: (Matrix) -> Matrix
+        return self.__elementwise_bin_op(rhs, operator.add)
+
+    def __radd__(self, lhs):
+        # type: (Matrix) -> Matrix
+        return lhs.__add__(self)
+
+    def __sub__(self, rhs):
+        # type: (Matrix) -> Matrix
+        return self.__elementwise_bin_op(rhs, operator.sub)
+
+    def __rsub__(self, lhs):
+        # type: (Matrix) -> Matrix
+        return lhs.__sub__(self)
+
+    def __iter__(self):
+        return iter(self.__data)
+
+    def __reversed__(self):
+        return reversed(self.__data)
+
+    def __neg__(self):
+        retval = self.copy()
+        for i in retval.indexes():
+            retval[i] = -retval[i]
+        return retval
+
+    def __repr__(self):
+        if self.height == 0 or self.width == 0:
+            return f"Matrix(height={self.height}, width={self.width})"
+        lines = []
+        line = []
+        for row in range(self.height):
+            line.clear()
+            for col in range(self.width):
+                if self[row, col].denominator == 1:
+                    line.append(str(self[row, col].numerator))
+                else:
+                    line.append(repr(self[row, col]))
+            lines.append(", ".join(line))
+        lines = ",\n    ".join(lines)
+        return (f"Matrix(height={self.height}, width={self.width}, data=[\n"
+                f"    {lines},\n])")
+
+    def __eq__(self, rhs):
+        if not isinstance(rhs, Matrix):
+            return NotImplemented
+        return (self.height == rhs.height
+                and self.width == rhs.width
+                and self.__data == rhs.__data)
+
+    def inverse(self):
+        size = self.height
+        if size != self.width:
+            raise ValueError("can't invert a non-square matrix")
+        inp = self.copy()
+        retval = Matrix.identity(size)
+        # the algorithm is adapted from:
+        # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
+        for k in range(size):
+            f = abs(inp[k, k])  # Find pivot.
+            p = k
+            for i in range(k + 1, size):
+                g = abs(inp[k, i])
+                if g > f:
+                    f = g
+                    p = i
+            if f == 0:
+                raise ZeroDivisionError("Matrix is singular")
+            if p != k:  # Swap rows.
+                for j in range(k, size):
+                    f = inp[j, k]
+                    inp[j, k] = inp[j, p]
+                    inp[j, p] = f
+                for j in range(size):
+                    f = retval[j, k]
+                    retval[j, k] = retval[j, p]
+                    retval[j, p] = f
+            f = 1 / inp[k, k]  # Scale row so pivot is 1.
+            for j in range(k, size):
+                inp[j, k] *= f
+            for j in range(size):
+                retval[j, k] *= f
+            for i in range(size):  # Subtract to get zeros.
+                if i == k:
+                    continue
+                f = inp[k, i]
+                for j in range(k, size):
+                    inp[j, i] -= inp[j, k] * f
+                for j in range(size):
+                    retval[j, i] -= retval[j, k] * f
+        return retval
diff --git a/src/bigint_presentation_code/test_matrix.py b/src/bigint_presentation_code/test_matrix.py
new file mode 100644 (file)
index 0000000..ef39742
--- /dev/null
@@ -0,0 +1,113 @@
+import unittest
+from fractions import Fraction
+
+from bigint_presentation_code.matrix import Matrix
+
+
+class TestMatrix(unittest.TestCase):
+    def test_repr(self):
+        self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
+                                            3, 4, 5])),
+                         'Matrix(height=2, width=3, data=[\n'
+                         '    0, 1, 2,\n'
+                         '    3, 4, 5,\n'
+                         '])')
+        self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
+                                            3, 4, 5])),
+                         'Matrix(height=2, width=3, data=[\n'
+                         '    0, 1, Fraction(2, 3),\n'
+                         '    3, 4, 5,\n'
+                         '])')
+        self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
+        self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
+
+    def test_eq(self):
+        self.assertFalse(Matrix(1, 1) == 5)
+        self.assertFalse(5 == Matrix(1, 1))
+        self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
+        self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
+        self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
+        self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
+        self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
+
+    def test_add(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+                         + Matrix(2, 2, [40, 30, 20, 10]),
+                         Matrix(2, 2, [41, 32, 23, 14]))
+
+    def test_identity(self):
+        self.assertEqual(Matrix.identity(2, 2),
+                         Matrix(2, 2, [1, 0,
+                                       0, 1]))
+        self.assertEqual(Matrix.identity(1, 3),
+                         Matrix(1, 3, [1, 0, 0]))
+        self.assertEqual(Matrix.identity(2, 3),
+                         Matrix(2, 3, [1, 0, 0,
+                                       0, 1, 0]))
+        self.assertEqual(Matrix.identity(3),
+                         Matrix(3, 3, [1, 0, 0,
+                                       0, 1, 0,
+                                       0, 0, 1]))
+
+    def test_sub(self):
+        self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
+                         - Matrix(2, 2, [-1, -2, -3, -4]),
+                         Matrix(2, 2, [41, 32, 23, 14]))
+
+    def test_neg(self):
+        self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
+                         Matrix(2, 2, [-40, -30, -20, -10]))
+
+    def test_mul(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
+                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+        self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
+                         Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+
+    def test_matmul(self):
+        self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+                         @ Matrix(2, 2, [4, 3, 2, 1]),
+                         Matrix(2, 2, [8, 5, 20, 13]))
+        self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
+                         @ Matrix(2, 1, [1, 2]),
+                         Matrix(3, 1, [16, 10, 4]))
+
+    def test_inverse(self):
+        self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
+        self.assertEqual(Matrix(1, 1, [2]).inverse(),
+                         Matrix(1, 1, [Fraction(1, 2)]))
+        self.assertEqual(Matrix(1, 1, [1]).inverse(),
+                         Matrix(1, 1, [1]))
+        self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
+                         Matrix(2, 2, [1, 0, -1, 1]))
+        self.assertEqual(Matrix(3, 3, [0, 1, 0,
+                                       1, 0, 0,
+                                       0, 0, 1]).inverse(),
+                         Matrix(3, 3, [0, 1, 0,
+                                       1, 0, 0,
+                                       0, 0, 1]))
+        _1_2 = Fraction(1, 2)
+        _1_3 = Fraction(1, 3)
+        _1_6 = Fraction(1, 6)
+        self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
+                                       1, 1, 1, 1, 1,
+                                       1, -1, 1, -1, 1,
+                                       1, -2, 4, -8, 16,
+                                       0, 0, 0, 0, 1]).inverse(),
+                         Matrix(5, 5, [1, 0, 0, 0, 0,
+                                       _1_2, _1_3, -1, _1_6, -2,
+                                       -1, _1_2, _1_2, 0, -1,
+                                       -_1_2, _1_6, _1_2, -_1_6, 2,
+                                       0, 0, 0, 0, 1]))
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(1, 1, [0]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [0, 0, 1, 1]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [1, 0, 1, 0]).inverse()
+        with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+            Matrix(2, 2, [1, 1, 1, 1]).inverse()
+
+
+if __name__ == "__main__":
+    unittest.main()