working on generating output assembly
[bigint-presentation-code.git] / src / bigint_presentation_code / matrix.py
1 import operator
2 from fractions import Fraction
3 from numbers import Rational
4 from typing import Callable, Iterable
5
6
class Matrix:
    """A dense, mutable ``height`` x ``width`` matrix of exact rationals.

    Elements are stored row-major in a flat list; every value written in
    (via the constructor or ``m[row, col] = value``) is converted with
    ``Fraction(value)``, so all arithmetic is exact.

    Supported operators: ``m * k`` / ``k * m`` and ``m / k`` (scale by a
    rational ``k``), ``a @ b`` (matrix product), ``a + b`` / ``a - b``
    (element-wise, shapes must match), ``-m`` and ``==``.  Iterating yields
    the elements in row-major order.
    """
    __slots__ = "__height", "__width", "__data"

    @property
    def height(self):
        # type: () -> int
        """Number of rows."""
        return self.__height

    @property
    def width(self):
        # type: () -> int
        """Number of columns."""
        return self.__width

    def __init__(self, height, width, data=None):
        # type: (int, int, Iterable[Rational | int] | None) -> None
        """Create a ``height`` x ``width`` matrix.

        ``data``, if given, supplies all ``height * width`` elements in
        row-major order; otherwise every element is zero.

        Raises ValueError if a dimension is negative or ``data`` has the
        wrong number of elements.
        """
        if width < 0 or height < 0:
            raise ValueError("matrix size must be non-negative")
        self.__height = height
        self.__width = width
        # Fraction is immutable, so sharing a single zero object is safe.
        self.__data = [Fraction()] * (height * width)
        if data is not None:
            data = list(data)
            if len(data) != len(self.__data):
                raise ValueError("data has wrong length")
            self.__data[:] = map(Fraction, data)

    @staticmethod
    def identity(height, width=None):
        # type: (int, int | None) -> Matrix
        """Return a ``height`` x ``width`` matrix with ones on the main
        diagonal and zeros elsewhere; ``width`` defaults to ``height``."""
        if width is None:
            width = height
        retval = Matrix(height, width)
        for i in range(min(height, width)):
            retval[i, i] = 1
        return retval

    def __idx(self, row, col):
        # type: (int, int) -> int
        """Map ``(row, col)`` to a flat row-major index.

        Raises IndexError when out of range (negative indexes included).
        """
        if 0 <= col < self.width and 0 <= row < self.height:
            return row * self.width + col
        raise IndexError()

    def __getitem__(self, row_col):
        # type: (tuple[int, int]) -> Fraction
        """Return the element at ``m[row, col]``."""
        row, col = row_col
        return self.__data[self.__idx(row, col)]

    def __setitem__(self, row_col, value):
        # type: (tuple[int, int], Rational | int) -> None
        """Store ``Fraction(value)`` at ``m[row, col]``."""
        row, col = row_col
        self.__data[self.__idx(row, col)] = Fraction(value)

    def copy(self):
        # type: () -> Matrix
        """Return an independent copy with the same shape and elements."""
        # The constructor takes (height, width); passing them in the other
        # order would silently transpose the shape of non-square matrices.
        retval = Matrix(self.height, self.width)
        retval.__data[:] = self.__data
        return retval

    def indexes(self):
        """Yield every ``(row, col)`` pair in row-major order."""
        for row in range(self.height):
            for col in range(self.width):
                yield row, col

    def __map_elements(self, func):
        # type: (Callable[[Fraction], Fraction]) -> Matrix
        """Return a new same-shaped matrix with ``func`` applied to every
        element."""
        retval = Matrix(self.height, self.width)
        retval.__data[:] = (Fraction(func(v)) for v in self.__data)
        return retval

    def __mul__(self, rhs):
        # type: (Rational | int) -> Matrix
        """Return this matrix with every element multiplied by ``rhs``."""
        rhs = Fraction(rhs)
        return self.__map_elements(lambda v: v * rhs)

    def __rmul__(self, lhs):
        # type: (Rational | int) -> Matrix
        """Scalar multiplication is commutative: ``k * m == m * k``."""
        return self.__mul__(lhs)

    def __truediv__(self, rhs):
        # type: (Rational | int) -> Matrix
        """Return this matrix with every element divided by ``rhs``.

        Raises ZeroDivisionError if ``rhs`` is zero.
        """
        factor = 1 / Fraction(rhs)
        return self.__map_elements(lambda v: v * factor)

    def __matmul__(self, rhs):
        # type: (Matrix) -> Matrix
        """Return the matrix product ``self @ rhs``.

        Raises ValueError if ``self.width != rhs.height``.
        """
        if not isinstance(rhs, Matrix):
            return NotImplemented
        if self.width != rhs.height:
            raise ValueError(
                "lhs width must equal rhs height to multiply matrixes")
        retval = Matrix(self.height, rhs.width)
        for row in range(retval.height):
            for col in range(retval.width):
                # Start from Fraction() so an empty inner dimension still
                # yields a Fraction zero.
                retval[row, col] = sum(
                    (self[row, i] * rhs[i, col] for i in range(self.width)),
                    Fraction())
        return retval

    def __rmatmul__(self, lhs):
        # type: (Matrix) -> Matrix
        """Return ``lhs @ self``; only Matrix operands are supported."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__matmul__(self)

    def __elementwise_bin_op(self, rhs, op):
        # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
        """Return ``op(self[i], rhs[i])`` for every index ``i``.

        Returns NotImplemented for non-Matrix ``rhs`` so Python can try the
        other operand or raise TypeError.  Raises ValueError if shapes
        differ.
        """
        if not isinstance(rhs, Matrix):
            return NotImplemented
        if self.height != rhs.height or self.width != rhs.width:
            raise ValueError(
                "matrix dimensions must match for element-wise operations")
        retval = self.copy()
        for i in retval.indexes():
            retval[i] = op(retval[i], rhs[i])
        return retval

    def __add__(self, rhs):
        # type: (Matrix) -> Matrix
        """Element-wise ``self + rhs``."""
        return self.__elementwise_bin_op(rhs, operator.add)

    def __radd__(self, lhs):
        # type: (Matrix) -> Matrix
        """Element-wise ``lhs + self``; only Matrix operands are supported."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__add__(self)

    def __sub__(self, rhs):
        # type: (Matrix) -> Matrix
        """Element-wise ``self - rhs``."""
        return self.__elementwise_bin_op(rhs, operator.sub)

    def __rsub__(self, lhs):
        # type: (Matrix) -> Matrix
        """Element-wise ``lhs - self``; only Matrix operands are supported."""
        if not isinstance(lhs, Matrix):
            return NotImplemented
        return lhs.__sub__(self)

    def __iter__(self):
        """Iterate over the elements in row-major order."""
        return iter(self.__data)

    def __reversed__(self):
        """Iterate over the elements in reverse row-major order."""
        return reversed(self.__data)

    def __neg__(self):
        # type: () -> Matrix
        """Return a new matrix with every element negated."""
        return self.__map_elements(operator.neg)

    def __repr__(self):
        # type: () -> str
        if self.height == 0 or self.width == 0:
            return f"Matrix(height={self.height}, width={self.width})"

        def fmt(value):
            # type: (Fraction) -> str
            # Integers print bare; other rationals as Fraction(n, d).
            if value.denominator == 1:
                return str(value.numerator)
            return repr(value)

        lines = ",\n ".join(
            ", ".join(fmt(self[row, col]) for col in range(self.width))
            for row in range(self.height))
        return (f"Matrix(height={self.height}, width={self.width}, data=[\n"
                f" {lines},\n])")

    def __eq__(self, rhs):
        """Matrices are equal when shapes and all elements match.

        Defining ``__eq__`` without ``__hash__`` leaves instances
        unhashable, which suits a mutable type.
        """
        if not isinstance(rhs, Matrix):
            return NotImplemented
        return (self.height == rhs.height
                and self.width == rhs.width
                and self.__data == rhs.__data)

    def inverse(self):
        # type: () -> Matrix
        """Return the exact inverse of this square matrix.

        ``self`` is not modified.  Raises ValueError for a non-square matrix
        and ZeroDivisionError if the matrix is singular.
        """
        size = self.height
        if size != self.width:
            raise ValueError("can't invert a non-square matrix")
        inp = self.copy()
        retval = Matrix.identity(size)
        # The algorithm is adapted from:
        # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
        # The C code's a[x + y*n] is written here as inp[x, y], i.e. with the
        # indices transposed, so this is Gauss-Jordan elimination on the
        # transpose (column operations on ``inp``).  The identity start is
        # symmetric and inv(A.T).T == inv(A), so the result is still the
        # inverse of ``self``.
        for k in range(size):
            # Pivot: largest-magnitude entry among inp[k, k:].
            f = abs(inp[k, k])
            p = k
            for i in range(k + 1, size):
                g = abs(inp[k, i])
                if g > f:
                    f = g
                    p = i
            # Elements are exact Fractions, so this zero test is exact.
            if f == 0:
                raise ZeroDivisionError("Matrix is singular")
            if p != k:  # Swap pivot line k with line p.
                for j in range(k, size):
                    inp[j, k], inp[j, p] = inp[j, p], inp[j, k]
                for j in range(size):
                    retval[j, k], retval[j, p] = retval[j, p], retval[j, k]
            f = 1 / inp[k, k]  # Scale so the pivot becomes 1.
            for j in range(k, size):
                inp[j, k] *= f
            for j in range(size):
                retval[j, k] *= f
            for i in range(size):  # Subtract to zero the other lines.
                if i == k:
                    continue
                f = inp[k, i]
                for j in range(k, size):
                    inp[j, i] -= inp[j, k] * f
                for j in range(size):
                    retval[j, i] -= retval[j, k] * f
        return retval