2 from fractions
import Fraction
3 from numbers
import Rational
4 from typing
import Callable
, Iterable
# Declared instance attribute slots: the two dimensions plus the flat
# list holding every entry.
__slots__ = ("__height", "__width", "__data")
# Construct a height x width matrix of exact rational entries.
# height, width: matrix dimensions; a negative value raises ValueError.
# data: optional values in row-major order (height * width of them), each
#   converted with Fraction(); a length mismatch raises ValueError.
# All entries live in one flat list, self.__data.
18 def __init__(self
, height
, width
, data
=None):
19 # type: (int, int, Iterable[Rational | int] | None) -> None
20 if width
< 0 or height
< 0:
21 raise ValueError("matrix size must be non-negative")
22 self
.__height
= height
# NOTE(review): original lines 23 and 25-26 are missing from this
# extraction. Line 23 presumably sets self.__width (declared in
# __slots__). Lines 25-26 presumably skip the copy when data is None,
# since len(None) below would raise TypeError. len() also needs a sized
# container while the type comment allows any Iterable -- confirm data is
# materialised (e.g. with list()) on the hidden lines.
# Start from all-zero entries; sharing one Fraction() object across the
# list is safe because Fraction values are immutable.
24 self
.__data
= [Fraction()] * (height
* width
)
27 if len(data
) != len(self
.__data
):
28 raise ValueError("data has wrong length")
# Convert every supplied value to an exact Fraction while copying it in.
29 self
.__data
[:] = map(Fraction
, data
)
# Factory for a height x width matrix, presumably the identity. It has no
# `self` parameter, so it is presumably a static method; any decorator
# line is not visible.
# NOTE(review): original lines 34-35 are missing here. As shown, the
# default width=None would reach Matrix(height, None) and fail in the
# constructor's `width < 0` check with TypeError. The hidden lines
# presumably default width to height -- confirm.
32 def identity(height
, width
=None):
33 # type: (int, int | None) -> Matrix
36 retval
= Matrix(height
, width
)
# Visit the first min(height, width) main-diagonal positions. The loop
# body and the return (original lines 38-39) are not visible in this
# extraction.
37 for i
in range(min(height
, width
)):
# Translate a (row, col) position into an offset in the flat row-major
# entry list: row * width + col.
# Both coordinates are range-checked explicitly. A raw list index would
# accept negative values, and an out-of-range col would silently land on
# an entry of a neighbouring row.
# NOTE(review): the line after the return (original line 45) is not
# visible. If nothing follows, an out-of-range position falls through and
# returns None, which callers would then hit as a TypeError when indexing
# the list -- confirm an explicit IndexError is raised there.
41 def __idx(self
, row
, col
):
42 # type: (int, int) -> int
43 if 0 <= col
< self
.width
and 0 <= row
< self
.height
:
44 return row
* self
.width
+ col
# Read the entry at row_col, a (row, col) tuple as in m[row, col], and
# return the stored Fraction. Bounds checking is delegated to __idx.
# NOTE(review): original line 49 is not visible. `row` and `col` are used
# below, so it presumably unpacks `row, col = row_col`.
47 def __getitem__(self
, row_col
):
48 # type: (tuple[int, int]) -> Fraction
50 return self
.__data
[self
.__idx
(row
, col
)]
# Store value at row_col, a (row, col) tuple as in m[row, col] = value.
# The value is converted with Fraction() so every stored entry stays an
# exact rational. Bounds checking is delegated to __idx.
# NOTE(review): original line 54 is not visible. It presumably unpacks
# `row, col = row_col`.
52 def __setitem__(self
, row_col
, value
):
53 # type: (tuple[int, int], Rational | int) -> None
55 self
.__data
[self
.__idx
(row
, col
)] = Fraction(value
)
# Fragment of a method whose `def` line (original line 57) and trailing
# line(s) are not visible in this extraction. It creates a new Matrix and
# fills it with this matrix's entries.
# NOTE(review): the constructor's parameter order is (height, width, ...),
# but it is called here as Matrix(self.width, self.height). For a
# non-square matrix the new object therefore has swapped dimensions while
# holding the same row-major data, so it is neither a faithful copy nor a
# transpose. If this is meant to be a copy, it should likely be
# Matrix(self.height, self.width).
58 retval
= Matrix(self
.width
, self
.height
)
# Slice assignment copies the elements into retval's own list instead of
# making both matrices share one list object.
59 retval
.__data
[:] = self
.__data
# Fragment: nested loops visiting every (row, col) position row by row,
# i.e. in row-major order. The enclosing `def` (original line 62) and the
# loop body (line 65) are not visible. This is presumably the indexes()
# generator that other methods iterate over -- confirm.
63 for row
in range(self
.height
):
64 for col
in range(self
.width
):
# Scalar product self * rhs, where rhs is a Rational or int per the type
# comment. The visible loop walks every position via self.indexes().
# NOTE(review): original lines 69-70 and 72-73 are not visible. They are
# presumably the conversion of rhs, the creation of the result matrix,
# the per-entry update and the return.
67 def __mul__(self
, rhs
):
68 # type: (Rational | int) -> Matrix
71 for i
in self
.indexes():
def __rmul__(self, lhs):
    # type: (Rational | int) -> Matrix
    """Return ``lhs * self`` for a scalar ``lhs``.

    Scalar multiplication is commutative, so the work is handed to
    ``__mul__`` with the same operand.
    """
    product = self.__mul__(lhs)
    return product
# Divide every entry by the scalar rhs (a Rational or int per the type
# comment).
# rhs is first replaced by its exact reciprocal, presumably so the
# per-entry step can be a multiplication. 1 / Fraction(0) raises
# ZeroDivisionError, so dividing by zero fails on that line.
# NOTE(review): original lines 82 and 84-85 are not visible.
79 def __truediv__(self
, rhs
):
80 # type: (Rational | int) -> Matrix
81 rhs
= 1 / Fraction(rhs
)
83 for i
in self
.indexes():
# Matrix product self @ rhs using the textbook triple loop.
# The result has shape self.height x rhs.width. Each entry is the sum of
# self[row, i] * rhs[i, col] over i in range(self.width).
# A width/height mismatch is rejected. Only the message of that raise is
# visible; the line that starts it (original line 90) is missing.
# NOTE(review): the accumulator is named `sum`, shadowing the builtin in
# this method. Its initialisation (original line 95) and the return
# (line 99) are not visible.
87 def __matmul__(self
, rhs
):
88 # type: (Matrix) -> Matrix
89 if self
.width
!= rhs
.height
:
91 "lhs width must equal rhs height to multiply matrixes")
92 retval
= Matrix(self
.height
, rhs
.width
)
93 for row
in range(retval
.height
):
94 for col
in range(retval
.width
):
96 for i
in range(self
.width
):
97 sum += self
[row
, i
] * rhs
[i
, col
]
98 retval
[row
, col
] = sum
def __rmatmul__(self, lhs):
    # type: (Matrix) -> Matrix
    """Reflected matrix product ``lhs @ self``.

    Only another Matrix is a meaningful left operand. Anything else, such
    as an int or a Fraction (neither defines ``__matmul__``), gets
    ``NotImplemented``. Python then raises its usual "unsupported operand"
    TypeError, rather than leaking an AttributeError from calling
    ``lhs.__matmul__`` on an object that lacks it.
    """
    if not isinstance(lhs, Matrix):
        return NotImplemented
    return lhs.__matmul__(self)
# Shared helper for element-wise binary operations; + and - call it with
# operator.add and operator.sub. Both matrices must have the same height
# and width. Each result entry is then op(left_entry, rhs_entry).
# NOTE(review): rhs attributes are read without a type check, so a
# non-Matrix rhs fails with AttributeError.
# Original lines 108, 110 and 113 are not visible. They are presumably:
#   - the raise for the shape mismatch;
#   - the creation of retval, apparently starting from this matrix's
#     entries since op reads retval[i];
#   - the return.
105 def __elementwise_bin_op(self
, rhs
, op
):
106 # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
107 if self
.height
!= rhs
.height
or self
.width
!= rhs
.width
:
109 "matrix dimensions must match for element-wise operations")
111 for i
in retval
.indexes():
112 retval
[i
] = op(retval
[i
], rhs
[i
])
def __add__(self, rhs):
    # type: (Matrix) -> Matrix
    """Element-wise sum ``self + rhs`` of two equally-shaped matrices.

    The per-entry addition and the shape check are delegated to the
    shared element-wise helper, which rejects mismatched shapes.

    A non-Matrix ``rhs`` returns ``NotImplemented``, the convention
    ``__eq__`` already follows. Python can then try ``rhs.__radd__`` and
    otherwise raises TypeError, instead of the helper failing with
    AttributeError when it reads ``rhs.height``.
    """
    if not isinstance(rhs, Matrix):
        return NotImplemented
    return self.__elementwise_bin_op(rhs, operator.add)
def __radd__(self, lhs):
    # type: (Matrix) -> Matrix
    """Reflected addition: evaluate ``lhs + self`` via ``lhs.__add__``.

    Whatever ``lhs.__add__`` returns, including ``NotImplemented``, is
    handed straight back to the interpreter.
    """
    forward = lhs.__add__
    return forward(self)
def __sub__(self, rhs):
    # type: (Matrix) -> Matrix
    """Element-wise difference ``self - rhs`` of two equally-shaped matrices.

    The per-entry subtraction and the shape check are delegated to the
    shared element-wise helper, which rejects mismatched shapes.

    A non-Matrix ``rhs`` returns ``NotImplemented``, the convention
    ``__eq__`` already follows. Python can then try ``rhs.__rsub__`` and
    otherwise raises TypeError, instead of the helper failing with
    AttributeError when it reads ``rhs.height``.
    """
    if not isinstance(rhs, Matrix):
        return NotImplemented
    return self.__elementwise_bin_op(rhs, operator.sub)
def __rsub__(self, lhs):
    # type: (Matrix) -> Matrix
    """Reflected subtraction: compute ``lhs - self``.

    Mind the operand order. ``lhs`` is the minuend, so the call goes to
    ``lhs.__sub__`` with this matrix as its argument.
    """
    subtract = lhs.__sub__
    return subtract(self)
# Fragment: return an iterator over the flat entry list, i.e. the entries
# in row-major order. The enclosing `def` (original line 131, presumably
# __iter__) is not visible.
132 return iter(self
.__data
)
def __reversed__(self):
    """Return a reverse iterator over the entries, last entry first.

    Entries are held in one flat row-major list, so iteration starts at
    the bottom-right entry and ends at the top-left one.
    """
    entries = self.__data
    return reversed(entries)
# Fragment: replace every entry of retval with its negation. The `def`
# line, the creation of retval and the return (original lines 137-138
# and 141) are not visible. This is presumably __neg__ working on a copy
# of self -- confirm.
139 for i
in retval
.indexes():
140 retval
[i
] = -retval
[i
]
# Fragment of the repr builder. The `def` line and original lines 143,
# 146-147, 149, 153 and 158-159 are not visible.
# An empty matrix (zero height or zero width) is rendered without a data
# list.
144 if self
.height
== 0 or self
.width
== 0:
145 return f
"Matrix(height={self.height}, width={self.width})"
# Otherwise each row is rendered as one comma-separated line. Integral
# entries print as their bare numerator; other entries use repr() of the
# Fraction, e.g. "Fraction(1, 2)". The hidden line 153 is presumably the
# `else:` of the denominator test.
148 for row
in range(self
.height
):
150 for col
in range(self
.width
):
151 if self
[row
, col
].denominator
== 1:
152 line
.append(str(self
[row
, col
].numerator
))
154 line
.append(repr(self
[row
, col
]))
155 lines
.append(", ".join(line
))
# `lines` is rebound here from the list of row strings to the joined text.
156 lines
= ",\n ".join(lines
)
157 return (f
"Matrix(height={self.height}, width={self.width}, data=[\n"
def __eq__(self, rhs):
    """Compare two matrices for exact equality.

    Matrices are equal when their heights, widths and entry lists all
    match. Comparing with anything that is not a Matrix returns
    ``NotImplemented`` so Python can fall back to the other operand.
    """
    if isinstance(rhs, Matrix):
        same_shape = self.height == rhs.height and self.width == rhs.width
        return same_shape and self.__data == rhs.__data
    return NotImplemented
# Fragment of the matrix-inverse routine. The `def` line and several body
# lines are not visible in this extraction: original 167-168, 171, 177,
# 179-183, 187, 189, 191, 193, 196, 198, 200-202, and anything after 206.
# Visible behaviour:
#   - a non-square matrix raises ValueError;
#   - the result starts from Matrix.identity(size);
#   - Gauss-Jordan elimination with a pivot search then runs on a working
#     matrix `inp` (presumably a copy of self, assigned on a hidden line),
#     and the visible swap and elimination steps are applied to the result
#     as well.
# Entries are Fractions, so the arithmetic is exact.
# `f` is reused for several quantities: first the pivot magnitude, then
# the pivot's reciprocal, and in the elimination step a multiplier set on
# a hidden line.
169 if size
!= self
.width
:
170 raise ValueError("can't invert a non-square matrix")
172 retval
= Matrix
.identity(size
)
173 # the algorithm is adapted from:
174 # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
175 for k
in range(size
):
176 f
= abs(inp
[k
, k
]) # Find pivot.
# Scan the remaining candidates for a pivot of larger magnitude. The loop
# body and the condition guarding the raise below are not visible; a zero
# best pivot presumably means the matrix is singular.
178 for i
in range(k
+ 1, size
):
184 raise ZeroDivisionError("Matrix is singular")
# NOTE(review): the accessors treat the first index as the row. Yet every
# access in this loop puts the running index first ([j, k], [j, p],
# [j, i]), so the visible "swap" and elimination statements operate on
# columns k, p and i. That still produces the inverse if all hidden lines
# use the same transposed orientation (inv(A^T) == inv(A)^T and the
# identity is symmetric) -- confirm before changing the index order.
185 if p
!= k
: # Swap rows.
186 for j
in range(k
, size
):
188 inp
[j
, k
] = inp
[j
, p
]
190 for j
in range(size
):
192 retval
[j
, k
] = retval
[j
, p
]
194 f
= 1 / inp
[k
, k
] # Scale row so pivot is 1.
195 for j
in range(k
, size
):
197 for j
in range(size
):
199 for i
in range(size
): # Subtract to get zeros.
203 for j
in range(k
, size
):
204 inp
[j
, i
] -= inp
[j
, k
] * f
205 for j
in range(size
):
206 retval
[j
, i
] -= retval
[j
, k
] * f