working on rewriting compiler ir to fix reg alloc issues
[bigint-presentation-code.git] / src / bigint_presentation_code / matrix.py
index 89c3ea23b2b4554cc1836a6773adc643713882c4..0674c02269b3bdcf4ef3dbe741e7acab562ee905 100644 (file)
@@ -1,10 +1,9 @@
-import operator
 from enum import Enum, unique
 from fractions import Fraction
-from numbers import Rational
+import operator
 from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar
 
-from bigint_presentation_code.util import final
+from bigint_presentation_code.type_util import final
 
 _T = TypeVar("_T")
 _T2 = TypeVar("_T2")
@@ -103,7 +102,7 @@ class Matrix(Generic[_T]):
         return retval
 
     def __truediv__(self, rhs):
-        # type: (Rational | int) -> Matrix
+        # type: (_T | int) -> Matrix[_T]
         retval = self.copy()
         for i in self.indexes():
             retval[i] /= rhs  # type: ignore
@@ -128,7 +127,7 @@ class Matrix(Generic[_T]):
         return lhs.__matmul__(self)
 
     def __elementwise_bin_op(self, rhs, op):
-        # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
+        # type: (Matrix[_T], Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
         if self.height != rhs.height or self.width != rhs.width:
             raise ValueError(
                 "matrix dimensions must match for element-wise operations")
@@ -172,8 +171,8 @@ class Matrix(Generic[_T]):
         # type: () -> str
         if self.height == 0 or self.width == 0:
             return f"Matrix(height={self.height}, width={self.width})"
-        lines = []
-        line = []
+        lines = []  # type: list[str]
+        line = []  # type: list[str]
         for row in range(self.height):
             line.clear()
             for col in range(self.width):
@@ -183,16 +182,16 @@ class Matrix(Generic[_T]):
                 else:
                     line.append(repr(el))
             lines.append(", ".join(line))
-        lines = ",\n    ".join(lines)
+        lines_str = ",\n    ".join(lines)
         element_type = ""
         if self.element_type is not Fraction:
             element_type = f"element_type={self.element_type}, "
         return (f"Matrix(height={self.height}, width={self.width}, "
                 f"{element_type}data=[\n"
-                f"    {lines},\n])")
+                f"    {lines_str},\n])")
 
     def __eq__(self, rhs):
-        # type: (object) -> bool
+        # type: (Matrix[Any] | Any) -> bool
         if not isinstance(rhs, Matrix):
             return NotImplemented
         return (self.height == rhs.height