89c3ea23b2b4554cc1836a6773adc643713882c4
[bigint-presentation-code.git] / src / bigint_presentation_code / matrix.py
1 import operator
2 from enum import Enum, unique
3 from fractions import Fraction
4 from numbers import Rational
5 from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar
6
7 from bigint_presentation_code.util import final
8
# Element type parameter; Matrix is declared as Generic[_T] over it.
_T = TypeVar("_T")
# Second, independent element type parameter, used where the result's
# element type may differ from the input's (e.g. Matrix.cast).
_T2 = TypeVar("_T2")
11
12
@final
@unique
class SpecialMatrix(Enum):
    """Symbolic initial contents that ``Matrix(..., data=...)`` accepts
    in place of an explicit iterable of elements."""

    # Every element is zero (this is Matrix's default ``data``).
    Zero = 0
    # Ones on the main diagonal (min(height, width) entries), zeros elsewhere.
    Identity = 1
18
19
@final
class Matrix(Generic[_T]):
    """A dense ``height`` x ``width`` matrix stored as a flat row-major list.

    Every element is converted with ``element_type`` (``Fraction`` by
    default) both on construction and on item assignment, so stored values
    always have that type.  Elements are accessed as ``m[row, col]``;
    iteration yields the elements in row-major order.  The arithmetic
    operators return new matrices and leave their operands unchanged.
    """
    __slots__ = "__height", "__width", "__data", "__element_type"

    @property
    def height(self):
        # type: () -> int
        """Number of rows."""
        return self.__height

    @property
    def width(self):
        # type: () -> int
        """Number of columns."""
        return self.__width

    @property
    def element_type(self):
        # type: () -> Type[_T]
        """Type (converter) that every stored element is passed through."""
        return self.__element_type

    def __init__(self, height, width, data=SpecialMatrix.Zero,
                 element_type=Fraction):
        # type: (int, int, Iterable[_T | int | Any] | SpecialMatrix, Type[_T]) -> None
        """Create a ``height`` x ``width`` matrix.

        ``data`` is either a ``SpecialMatrix`` (``Zero``: all zeros;
        ``Identity``: ones on the main diagonal, zeros elsewhere) or an
        iterable of exactly ``height * width`` values in row-major order.
        Every value is converted with ``element_type``.

        Raises:
            ValueError: a dimension is negative, or ``data`` yields the
                wrong number of values.
        """
        if width < 0 or height < 0:
            raise ValueError("matrix size must be non-negative")
        self.__height = height
        self.__width = width
        self.__element_type = element_type
        if isinstance(data, SpecialMatrix):
            self.__data = [element_type(0) for _ in range(height * width)]
            if data is SpecialMatrix.Identity:
                for i in range(min(width, height)):
                    self[i, i] = element_type(1)
            else:
                assert data is SpecialMatrix.Zero
        else:
            self.__data = [element_type(v) for v in data]
            if len(self.__data) != height * width:
                raise ValueError("data has wrong length")

    def cast(self, element_type):
        # type: (Type[_T2]) -> Matrix[_T2]
        """Return a new matrix of the same shape with every element
        converted to ``element_type``."""
        data = self  # type: Iterable[Any]
        return Matrix(self.height, self.width, data, element_type=element_type)

    def __idx(self, row, col):
        # type: (int, int) -> int
        """Map ``(row, col)`` to a position in the row-major data list.

        Negative indexes are not supported; anything outside the matrix
        raises IndexError.
        """
        if 0 <= col < self.width and 0 <= row < self.height:
            return row * self.width + col
        raise IndexError(f"matrix index ({row}, {col}) out of range for a "
                         f"{self.height}x{self.width} matrix")

    def __getitem__(self, row_col):
        # type: (tuple[int, int]) -> _T
        """Return the element at ``m[row, col]``."""
        row, col = row_col
        return self.__data[self.__idx(row, col)]

    def __setitem__(self, row_col, value):
        # type: (tuple[int, int], _T | int) -> None
        """Store ``element_type(value)`` at ``m[row, col]``."""
        row, col = row_col
        self.__data[self.__idx(row, col)] = self.__element_type(value)

    def copy(self):
        # type: () -> Matrix[_T]
        """Return an independent copy with the same shape and element type."""
        # The constructor takes (height, width); passing them swapped made
        # every non-square copy the wrong shape for its row-major data.
        return Matrix(self.height, self.width, data=self.__data,
                      element_type=self.element_type)

    def indexes(self):
        # type: () -> Iterable[tuple[int, int]]
        """Yield every ``(row, col)`` pair in row-major order."""
        for row in range(self.height):
            for col in range(self.width):
                yield row, col

    def __map(self, fn):
        # type: (Callable[[_T], Any]) -> Matrix[_T]
        """Return a same-shaped matrix (same element type) holding ``fn``
        applied to each element of ``self``."""
        return Matrix(self.height, self.width, (fn(v) for v in self.__data),
                      element_type=self.element_type)

    def __mul__(self, rhs):
        # type: (_T | int) -> Matrix[_T]
        """Scalar product: multiply every element by ``rhs``.

        A Matrix operand is rejected with NotImplemented (Python then raises
        TypeError); use ``@`` for matrix multiplication.
        """
        if isinstance(rhs, Matrix):
            return NotImplemented
        return self.__map(lambda v: v * rhs)  # type: ignore

    def __rmul__(self, lhs):
        # type: (_T | int) -> Matrix[_T]
        """Scalar product ``lhs * element`` for every element (the operand
        order is kept in case element multiplication is not commutative)."""
        if isinstance(lhs, Matrix):
            return NotImplemented
        return self.__map(lambda v: lhs * v)  # type: ignore

    def __truediv__(self, rhs):
        # type: (Rational | int) -> Matrix[_T]
        """Divide every element by the scalar ``rhs``."""
        if isinstance(rhs, Matrix):
            return NotImplemented
        return self.__map(lambda v: v / rhs)  # type: ignore

    def __matmul__(self, rhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Matrix product ``self @ rhs``; the result has ``self``'s element
        type.

        Raises:
            ValueError: ``self.width != rhs.height``.
        """
        if not isinstance(rhs, Matrix):
            return NotImplemented
        if self.width != rhs.height:
            raise ValueError(
                "lhs width must equal rhs height to multiply matrixes")
        retval = Matrix(self.height, rhs.width, element_type=self.element_type)
        for row in range(retval.height):
            for col in range(retval.width):
                # element_type(0) matches how __init__ builds zeros and
                # avoids shadowing the builtin ``sum``.
                acc = self.element_type(0)
                for i in range(self.width):
                    acc += self[row, i] * rhs[i, col]  # type: ignore
                retval[row, col] = acc
        return retval

    def __rmatmul__(self, lhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Reflected matrix product ``lhs @ self``."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__matmul__(self)

    def __elementwise_bin_op(self, rhs, op):
        # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
        """Apply ``op(self_element, rhs_element)`` pairwise and return the
        results as a new matrix with ``self``'s element type.

        Raises:
            ValueError: the two matrices have different shapes.
        """
        if self.height != rhs.height or self.width != rhs.width:
            raise ValueError(
                "matrix dimensions must match for element-wise operations")
        # Both data lists are row-major over the same shape, so zipping
        # them pairs up elements at equal (row, col).
        return Matrix(self.height, self.width,
                      (op(a, b) for a, b in zip(self.__data, rhs.__data)),
                      element_type=self.element_type)

    def __add__(self, rhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Element-wise sum of two same-shaped matrices."""
        if not isinstance(rhs, Matrix):
            return NotImplemented
        return self.__elementwise_bin_op(rhs, operator.add)

    def __radd__(self, lhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Reflected element-wise sum ``lhs + self``."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__add__(self)

    def __sub__(self, rhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Element-wise difference of two same-shaped matrices."""
        if not isinstance(rhs, Matrix):
            return NotImplemented
        return self.__elementwise_bin_op(rhs, operator.sub)

    def __rsub__(self, lhs):
        # type: (Matrix[_T]) -> Matrix[_T]
        """Reflected element-wise difference ``lhs - self``."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__sub__(self)

    def __iter__(self):
        # type: () -> Iterator[_T]
        """Iterate over the elements in row-major order."""
        return iter(self.__data)

    def __reversed__(self):
        # type: () -> Iterator[_T]
        """Iterate over the elements in reverse row-major order."""
        return reversed(self.__data)

    def __neg__(self):
        # type: () -> Matrix[_T]
        """Element-wise negation."""
        return self.__map(operator.neg)

    def __repr__(self):
        # type: () -> str
        """Return a constructor-style representation.

        Fractions with denominator 1 are shown as bare integers, and
        ``element_type`` is only spelled out when it is not ``Fraction``.
        """
        if self.height == 0 or self.width == 0:
            return f"Matrix(height={self.height}, width={self.width})"
        rows = []
        for row in range(self.height):
            cells = []
            for col in range(self.width):
                el = self[row, col]
                if isinstance(el, Fraction) and el.denominator == 1:
                    cells.append(str(el.numerator))
                else:
                    cells.append(repr(el))
            rows.append(", ".join(cells))
        lines = ",\n ".join(rows)
        element_type = ""
        if self.element_type is not Fraction:
            element_type = f"element_type={self.element_type}, "
        return (f"Matrix(height={self.height}, width={self.width}, "
                f"{element_type}data=[\n"
                f" {lines},\n])")

    def __eq__(self, rhs):
        # type: (object) -> bool
        """Matrices are equal when shape, elements and element type match."""
        if not isinstance(rhs, Matrix):
            return NotImplemented
        return (self.height == rhs.height
                and self.width == rhs.width
                and self.__data == rhs.__data
                and self.element_type == rhs.element_type)

    def inverse(self  # type: Matrix[Fraction]
                ):
        # type: () -> Matrix[Fraction]
        """Return the exact inverse of this square ``Fraction`` matrix.

        Uses Gauss-Jordan elimination with partial pivoting, adapted from
        https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C

        Raises:
            ValueError: the matrix is not square.
            TypeError: ``element_type`` is not ``Fraction``.
            ZeroDivisionError: the matrix is singular.
        """
        size = self.height
        if size != self.width:
            raise ValueError("can't invert a non-square matrix")
        if self.element_type is not Fraction:
            raise TypeError("can't invert a matrix with element_type that "
                            "isn't Fraction")
        inp = self.copy()
        retval = Matrix(size, size, data=SpecialMatrix.Identity)
        # Every step below varies the *row* index j of ``[j, k]``, so it
        # applies elementary column operations (swap, scale, subtract) to
        # ``inp`` and mirrors each one on ``retval``.  Column operations are
        # right-multiplications: once they reduce ``inp`` to the identity,
        # ``retval`` (which started as the identity) holds their product,
        # which is exactly the inverse of ``self``.
        for k in range(size):
            # Pivot: the largest-magnitude entry in row k at or right of
            # the diagonal.
            pivot = k
            pivot_mag = abs(inp[k, k])
            for i in range(k + 1, size):
                mag = abs(inp[k, i])
                if mag > pivot_mag:
                    pivot_mag = mag
                    pivot = i
            if pivot_mag == 0:
                raise ZeroDivisionError("Matrix is singular")
            if pivot != k:
                # Swap columns k and pivot.  Rows above k are already zero
                # in both columns of ``inp``, so only rows k.. need moving.
                for j in range(k, size):
                    inp[j, k], inp[j, pivot] = inp[j, pivot], inp[j, k]
                for j in range(size):
                    retval[j, k], retval[j, pivot] = (
                        retval[j, pivot], retval[j, k])
            # Scale column k so the pivot becomes 1.
            scale = 1 / inp[k, k]
            for j in range(k, size):
                inp[j, k] *= scale
            for j in range(size):
                retval[j, k] *= scale
            # Subtract multiples of column k to zero the rest of row k.
            for i in range(size):
                if i == k:
                    continue
                factor = inp[k, i]
                for j in range(k, size):
                    inp[j, i] -= inp[j, k] * factor
                for j in range(size):
                    retval[j, i] -= retval[j, k] * factor
        return retval
247 retval[j, i] -= retval[j, k] * f
248 return retval