89c3ea23b2b4554cc1836a6773adc643713882c4
2 from enum
import Enum
, unique
3 from fractions
import Fraction
4 from numbers
import Rational
5 from typing
import Any
, Callable
, Generic
, Iterable
, Iterator
, Type
, TypeVar
7 from bigint_presentation_code
.util
import final
15 class SpecialMatrix(Enum
):
21 class Matrix(Generic
[_T
]):
22 __slots__
= "__height", "__width", "__data", "__element_type"
35 def element_type(self
):
36 # type: () -> Type[_T]
37 return self
.__element
_type
39 def __init__(self
, height
, width
, data
=SpecialMatrix
.Zero
,
40 element_type
=Fraction
):
41 # type: (int, int, Iterable[_T | int | Any] | SpecialMatrix, Type[_T]) -> None
42 if width
< 0 or height
< 0:
43 raise ValueError("matrix size must be non-negative")
44 self
.__height
= height
46 self
.__element
_type
= element_type
47 if isinstance(data
, SpecialMatrix
):
48 self
.__data
= [element_type(0) for _
in range(height
* width
)]
49 if data
is SpecialMatrix
.Identity
:
50 for i
in range(min(width
, height
)):
51 self
[i
, i
] = element_type(1)
53 assert data
is SpecialMatrix
.Zero
55 self
.__data
= [element_type(v
) for v
in data
]
56 if len(self
.__data
) != height
* width
:
57 raise ValueError("data has wrong length")
59 def cast(self
, element_type
):
60 # type: (Type[_T2]) -> Matrix[_T2]
61 data
= self
# type: Iterable[Any]
62 return Matrix(self
.height
, self
.width
, data
, element_type
=element_type
)
64 def __idx(self
, row
, col
):
65 # type: (int, int) -> int
66 if 0 <= col
< self
.width
and 0 <= row
< self
.height
:
67 return row
* self
.width
+ col
70 def __getitem__(self
, row_col
):
71 # type: (tuple[int, int]) -> _T
73 return self
.__data
[self
.__idx
(row
, col
)]
75 def __setitem__(self
, row_col
, value
):
76 # type: (tuple[int, int], _T | int) -> None
78 self
.__data
[self
.__idx
(row
, col
)] = self
.__element
_type
(value
)
81 # type: () -> Matrix[_T]
82 return Matrix(self
.width
, self
.height
, data
=self
.__data
,
83 element_type
=self
.element_type
)
86 # type: () -> Iterable[tuple[int, int]]
87 for row
in range(self
.height
):
88 for col
in range(self
.width
):
91 def __mul__(self
, rhs
):
92 # type: (_T | int) -> Matrix[_T]
94 for i
in self
.indexes():
95 retval
[i
] *= rhs
# type: ignore
98 def __rmul__(self
, lhs
):
99 # type: (_T | int) -> Matrix[_T]
101 for i
in self
.indexes():
102 retval
[i
] = lhs
* retval
[i
] # type: ignore
105 def __truediv__(self
, rhs
):
106 # type: (Rational | int) -> Matrix
108 for i
in self
.indexes():
109 retval
[i
] /= rhs
# type: ignore
112 def __matmul__(self
, rhs
):
113 # type: (Matrix[_T]) -> Matrix[_T]
114 if self
.width
!= rhs
.height
:
116 "lhs width must equal rhs height to multiply matrixes")
117 retval
= Matrix(self
.height
, rhs
.width
, element_type
=self
.element_type
)
118 for row
in range(retval
.height
):
119 for col
in range(retval
.width
):
120 sum = self
.element_type()
121 for i
in range(self
.width
):
122 sum += self
[row
, i
] * rhs
[i
, col
] # type: ignore
123 retval
[row
, col
] = sum
126 def __rmatmul__(self
, lhs
):
127 # type: (Matrix[_T]) -> Matrix[_T]
128 return lhs
.__matmul
__(self
)
130 def __elementwise_bin_op(self
, rhs
, op
):
131 # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
132 if self
.height
!= rhs
.height
or self
.width
!= rhs
.width
:
134 "matrix dimensions must match for element-wise operations")
136 for i
in retval
.indexes():
137 retval
[i
] = op(retval
[i
], rhs
[i
])
140 def __add__(self
, rhs
):
141 # type: (Matrix[_T]) -> Matrix[_T]
142 return self
.__elementwise
_bin
_op
(rhs
, operator
.add
)
144 def __radd__(self
, lhs
):
145 # type: (Matrix[_T]) -> Matrix[_T]
146 return lhs
.__add
__(self
)
148 def __sub__(self
, rhs
):
149 # type: (Matrix[_T]) -> Matrix[_T]
150 return self
.__elementwise
_bin
_op
(rhs
, operator
.sub
)
152 def __rsub__(self
, lhs
):
153 # type: (Matrix[_T]) -> Matrix[_T]
154 return lhs
.__sub
__(self
)
157 # type: () -> Iterator[_T]
158 return iter(self
.__data
)
160 def __reversed__(self
):
161 # type: () -> Iterator[_T]
162 return reversed(self
.__data
)
165 # type: () -> Matrix[_T]
167 for i
in retval
.indexes():
168 retval
[i
] = -retval
[i
] # type: ignore
173 if self
.height
== 0 or self
.width
== 0:
174 return f
"Matrix(height={self.height}, width={self.width})"
177 for row
in range(self
.height
):
179 for col
in range(self
.width
):
181 if isinstance(el
, Fraction
) and el
.denominator
== 1:
182 line
.append(str(el
.numerator
))
184 line
.append(repr(el
))
185 lines
.append(", ".join(line
))
186 lines
= ",\n ".join(lines
)
188 if self
.element_type
is not Fraction
:
189 element_type
= f
"element_type={self.element_type}, "
190 return (f
"Matrix(height={self.height}, width={self.width}, "
191 f
"{element_type}data=[\n"
194 def __eq__(self
, rhs
):
195 # type: (object) -> bool
196 if not isinstance(rhs
, Matrix
):
197 return NotImplemented
198 return (self
.height
== rhs
.height
199 and self
.width
== rhs
.width
200 and self
.__data
== rhs
.__data
201 and self
.element_type
== rhs
.element_type
)
203 def inverse(self
# type: Matrix[Fraction]
205 # type: () -> Matrix[Fraction]
207 if size
!= self
.width
:
208 raise ValueError("can't invert a non-square matrix")
209 if self
.element_type
is not Fraction
:
210 raise TypeError("can't invert a matrix with element_type that "
213 retval
= Matrix(size
, size
, data
=SpecialMatrix
.Identity
)
214 # the algorithm is adapted from:
215 # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
216 for k
in range(size
):
217 f
= abs(inp
[k
, k
]) # Find pivot.
219 for i
in range(k
+ 1, size
):
225 raise ZeroDivisionError("Matrix is singular")
226 if p
!= k
: # Swap rows.
227 for j
in range(k
, size
):
229 inp
[j
, k
] = inp
[j
, p
]
231 for j
in range(size
):
233 retval
[j
, k
] = retval
[j
, p
]
235 f
= 1 / inp
[k
, k
] # Scale row so pivot is 1.
236 for j
in range(k
, size
):
238 for j
in range(size
):
240 for i
in range(size
): # Subtract to get zeros.
244 for j
in range(k
, size
):
245 inp
[j
, i
] -= inp
[j
, k
] * f
246 for j
in range(size
):
247 retval
[j
, i
] -= retval
[j
, k
] * f