working on rewriting compiler ir to fix reg alloc issues
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
index eacceb410d8726d49d32f9be1192b8446c46a707..666df1402c2745d57de81bc3430ec3484efa429e 100644 (file)
@@ -1,20 +1,23 @@
+from collections import defaultdict
 import enum
 from enum import Enum, unique
-from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
-from weakref import WeakValueDictionary
+from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
+                                           OSet, FMap)
+from functools import lru_cache
 
 
 @final
 class Fn:
     def __init__(self):
         self.ops = []  # type: list[Op]
-        op_names = WeakValueDictionary()
-        self.__op_names = op_names  # type: WeakValueDictionary[str, Op]
+        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
         self.__next_name_suffix = 2
 
     def _add_op_with_unused_name(self, op, name=""):
@@ -32,278 +35,464 @@ class Fn:
             self.__next_name_suffix += 1
 
     def __repr__(self):
+        # type: () -> str
         return "<Fn>"
 
 
 @unique
 @final
-class RegKind(Enum):
-    GPR = enum.auto()
+class BaseTy(Enum):
+    I64 = enum.auto()
     CA = enum.auto()
     VL_MAXVL = enum.auto()
 
     @cached_property
     def only_scalar(self):
-        if self is RegKind.GPR:
+        # type: () -> bool
+        if self is BaseTy.I64:
             return False
-        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+        elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
             return True
         else:
             assert_never(self)
 
     @cached_property
-    def reg_count(self):
-        if self is RegKind.GPR:
+    def max_reg_len(self):
+        # type: () -> int
+        if self is BaseTy.I64:
             return 128
-        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+        elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
             return 1
         else:
             assert_never(self)
 
     def __repr__(self):
-        return "RegKind." + self._name_
+        return "BaseTy." + self._name_
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OperandType:
-    __slots__ = "kind", "vec"
+class Ty:
+    __slots__ = "base_ty", "reg_len"
 
-    def __init__(self, kind, vec):
-        # type: (RegKind, bool) -> None
-        self.kind = kind
-        if kind.only_scalar and vec:
-            raise ValueError(f"kind={kind} must have vec=False")
-        self.vec = vec
-
-    def get_length(self, maxvl):
-        # type: (int) -> int
-        # here's where subvl and elwid would be accounted for
-        if self.vec:
-            return maxvl
-        return 1
+    @staticmethod
+    def validate(base_ty, reg_len):
+        # type: (BaseTy, int) -> str | None
+        """ return a string with the error if the combination is invalid,
+        otherwise return None
+        """
+        if base_ty.only_scalar and reg_len != 1:
+            return f"can't create a vector of an only-scalar type: {base_ty}"
+        if reg_len < 1 or reg_len > base_ty.max_reg_len:
+            return "reg_len out of range"
+        return None
+
+    def __init__(self, base_ty, reg_len):
+        # type: (BaseTy, int) -> None
+        msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+        if msg is not None:
+            raise ValueError(msg)
+        self.base_ty = base_ty
+        self.reg_len = reg_len
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@unique
 @final
-class RegShape:
-    __slots__ = "kind", "length"
+class LocKind(Enum):
+    GPR = enum.auto()
+    StackI64 = enum.auto()
+    CA = enum.auto()
+    VL_MAXVL = enum.auto()
 
-    def __init__(self, kind, length=1):
-        # type: (RegKind, int) -> None
-        self.kind = kind
-        if length < 1 or length > kind.reg_count:
-            raise ValueError("invalid length")
-        self.length = length
+    @cached_property
+    def base_ty(self):
+        # type: () -> BaseTy
+        if self is LocKind.GPR or self is LocKind.StackI64:
+            return BaseTy.I64
+        if self is LocKind.CA:
+            return BaseTy.CA
+        if self is LocKind.VL_MAXVL:
+            return BaseTy.VL_MAXVL
+        else:
+            assert_never(self)
 
-    def try_concat(self, *others):
-        # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
-        kind = self.kind
-        length = self.length
-        for other in others:
-            if isinstance(other, (Reg, RegClass)):
-                other = other.shape
-            if other is None:
-                return None
-            if other.kind != self.kind:
-                return None
-            length += other.length
-        if length > kind.reg_count:
-            return None
-        return RegShape(kind=kind, length=length)
+    @cached_property
+    def loc_count(self):
+        # type: () -> int
+        if self is LocKind.StackI64:
+            return 1024
+        if self is LocKind.GPR or self is LocKind.CA \
+                or self is LocKind.VL_MAXVL:
+            return self.base_ty.max_reg_len
+        else:
+            assert_never(self)
+
+    def __repr__(self):
+        return "LocKind." + self._name_
 
 
-@plain_data(frozen=True, unsafe_hash=True)
 @final
-class Reg:
-    __slots__ = "shape", "start"
-
-    def __init__(self, shape, start):
-        # type: (RegShape, int) -> None
-        self.shape = shape
-        if start < 0 or start + shape.length > shape.kind.reg_count:
-            raise ValueError("start not in valid range")
-        self.start = start
+@unique
+class LocSubKind(Enum):
+    BASE_GPR = enum.auto()
+    SV_EXTRA2_VGPR = enum.auto()
+    SV_EXTRA2_SGPR = enum.auto()
+    SV_EXTRA3_VGPR = enum.auto()
+    SV_EXTRA3_SGPR = enum.auto()
+    StackI64 = enum.auto()
+    CA = enum.auto()
+    VL_MAXVL = enum.auto()
 
-    @property
+    @cached_property
     def kind(self):
-        return self.shape.kind
+        # type: () -> LocKind
+        # pyright fails typechecking when using `in` here:
+        # reported: https://github.com/microsoft/pyright/issues/4102
+        if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
+                or self is LocSubKind.SV_EXTRA2_SGPR \
+                or self is LocSubKind.SV_EXTRA3_VGPR \
+                or self is LocSubKind.SV_EXTRA3_SGPR:
+            return LocKind.GPR
+        if self is LocSubKind.StackI64:
+            return LocKind.StackI64
+        if self is LocSubKind.CA:
+            return LocKind.CA
+        if self is LocSubKind.VL_MAXVL:
+            return LocKind.VL_MAXVL
+        assert_never(self)
 
     @property
-    def length(self):
-        return self.shape.length
+    def base_ty(self):
+        return self.kind.base_ty
+
+    @lru_cache()
+    def allocatable_locs(self, ty):
+        # type: (Ty) -> LocSet
+        if ty.base_ty != self.base_ty:
+            raise ValueError("type mismatch")
+        raise NotImplementedError  # FIXME: finish
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericTy:
+    __slots__ = "base_ty", "is_vec"
+
+    def __init__(self, base_ty, is_vec):
+        # type: (BaseTy, bool) -> None
+        self.base_ty = base_ty
+        if base_ty.only_scalar and is_vec:
+            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
+        self.is_vec = is_vec
+
+    def instantiate(self, maxvl):
+        # type: (int) -> Ty
+        # here's where subvl and elwid would be accounted for
+        if self.is_vec:
+            return Ty(self.base_ty, maxvl)
+        return Ty(self.base_ty, 1)
+
+    def can_instantiate_to(self, ty):
+        # type: (Ty) -> bool
+        if self.base_ty != ty.base_ty:
+            return False
+        if self.is_vec:
+            return True
+        return ty.reg_len == 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Loc:
+    __slots__ = "kind", "start", "reg_len"
+
+    @staticmethod
+    def validate(kind, start, reg_len):
+        # type: (LocKind, int, int) -> str | None
+        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
+        if msg is not None:
+            return msg
+        if reg_len > kind.loc_count:
+            return "invalid reg_len"
+        if start < 0 or start + reg_len > kind.loc_count:
+            return "start not in valid range"
+        return None
+
+    @staticmethod
+    def try_make(kind, start, reg_len):
+        # type: (LocKind, int, int) -> Loc | None
+        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
+        if msg is not None:  # any message means the combination is invalid
+            return None
+        return Loc(kind=kind, start=start, reg_len=reg_len)
+
+    def __init__(self, kind, start, reg_len):
+        # type: (LocKind, int, int) -> None
+        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+        if msg is not None:
+            raise ValueError(msg)
+        self.kind = kind
+        self.reg_len = reg_len
+        self.start = start
 
     def conflicts(self, other):
-        # type: (Reg) -> bool
-        return (self.kind == other.kind
+        # type: (Loc) -> bool
+        return (self.kind != other.kind
                 and self.start < other.stop and other.start < self.stop)
 
+    @staticmethod
+    def make_ty(kind, reg_len):
+        # type: (LocKind, int) -> Ty
+        return Ty(base_ty=kind.base_ty, reg_len=reg_len)
+
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        return self.make_ty(kind=self.kind, reg_len=self.reg_len)
+
     @property
     def stop(self):
-        return self.start + self.length
+        # type: () -> int
+        return self.start + self.reg_len
 
     def try_concat(self, *others):
-        # type: (*Reg | None) -> Reg | None
-        shape = self.shape.try_concat(*others)
-        if shape is None:
-            return None
+        # type: (*Loc | None) -> Loc | None
+        reg_len = self.reg_len
         stop = self.stop
         for other in others:
-            assert other is not None, "already caught by RegShape.try_concat"
+            if other is None or other.kind != self.kind:
+                return None
             if stop != other.start:
                 return None
             stop = other.stop
-        return Reg(shape, self.start)
+            reg_len += other.reg_len
+        return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
 
 
+@plain_data(frozen=True, eq=False, repr=False)
 @final
-class RegClass(AbstractSet[Reg]):
-    def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
-        # type: (Iterable[Reg | int], RegShape | None, int) -> None
-        for reg_or_start in regs_or_starts:
-            if isinstance(reg_or_start, Reg):
-                if shape is None:
-                    shape = reg_or_start.shape
-                elif shape != reg_or_start.shape:
-                    raise ValueError(f"conflicting RegShapes: {shape} and "
-                                     f"{reg_or_start.shape}")
-                start = reg_or_start.start
-            else:
-                start = reg_or_start
-            if start < 0:
-                raise ValueError("a Reg's start is out of range")
-            starts_bitset |= 1 << start
-        if starts_bitset == 0:
-            shape = None
-        self.__shape = shape
-        self.__starts_bitset = starts_bitset
-        if shape is None:
-            if starts_bitset != 0:
-                raise ValueError("non-empty RegClass must have non-None shape")
+class LocSet(AbstractSet[Loc]):
+    __slots__ = "starts", "ty"
+
+    def __init__(self, __locs=()):
+        # type: (Iterable[Loc]) -> None
+        if isinstance(__locs, LocSet):
+            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
+            self.ty = __locs.ty  # type: Ty | None
             return
-        if self.stops_bitset >= 1 << shape.kind.reg_count:
-            raise ValueError("a Reg's start is out of range")
-
-    @property
-    def shape(self):
-        # type: () -> RegShape | None
-        return self.__shape
-
-    @property
-    def starts_bitset(self):
-        # type: () -> int
-        return self.__starts_bitset
-
-    @property
-    def stops_bitset(self):
-        # type: () -> int
-        if self.__shape is None:
-            return 0
-        return self.__starts_bitset << self.__shape.length
-
-    @cached_property
-    def starts(self):
-        # type: () -> OFSet[int]
-        if self.length is None:
-            return OFSet()
-        # TODO: fixme
-        # return OFSet(for i in range(self.length))
+        starts = {i: BitSet() for i in LocKind}
+        ty = None
+        for loc in __locs:
+            if ty is None:
+                ty = loc.ty
+            if ty != loc.ty:
+                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+            starts[loc.kind].add(loc.start)
+        self.starts = FMap(
+            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
+        self.ty = ty
 
     @cached_property
     def stops(self):
-        # type: () -> OFSet[int]
-        if self.__shape is None:
-            return OFSet()
-        return OFSet(i + self.__shape.length for i in self.__starts)
+        # type: () -> FMap[LocKind, FBitSet]
+        if self.ty is None:
+            return FMap()
+        sh = self.ty.reg_len
+        return FMap(
+            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
 
     @property
-    def kind(self):
-        if self.__shape is None:
+    def kinds(self):
+        # type: () -> AbstractSet[LocKind]
+        return self.starts.keys()
+
+    @property
+    def reg_len(self):
+        # type: () -> int | None
+        if self.ty is None:
             return None
-        return self.__shape.kind
+        return self.ty.reg_len
 
     @property
-    def length(self):
-        """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
-        if self.__shape is None:
+    def base_ty(self):
+        # type: () -> BaseTy | None
+        if self.ty is None:
             return None
-        return self.__shape.length
+        return self.ty.base_ty
 
     def concat(self, *others):
-        # type: (*RegClass) -> RegClass
-        shape = self.__shape
-        if shape is None:
-            return RegClass()
-        shape = shape.try_concat(*others)
-        if shape is None:
-            return RegClass()
-        starts = OSet(self.starts)
-        offset = shape.length
+        # type: (*LocSet) -> LocSet
+        if self.ty is None:
+            return LocSet()
+        base_ty = self.ty.base_ty
+        reg_len = self.ty.reg_len
+        starts = {k: BitSet(v) for k, v in self.starts.items()}
         for other in others:
-            assert other.__shape is not None, \
-                "already caught by RegShape.try_concat"
-            starts &= OSet(i - offset for i in other.starts)
-            offset += other.__shape.length
-        return RegClass(starts, shape=shape)
-
-    def __contains__(self, reg):
-        # type: (Reg) -> bool
-        return reg.shape == self.shape and reg.start in self.starts
+            if other.ty is None:
+                return LocSet()
+            if other.ty.base_ty != base_ty:
+                return LocSet()
+            for kind, other_starts in other.starts.items():
+                if kind not in starts:
+                    continue
+                starts[kind].bits &= other_starts.bits >> reg_len
+                if starts[kind] == 0:
+                    del starts[kind]
+                    if len(starts) == 0:
+                        return LocSet()
+            reg_len += other.ty.reg_len
+
+        def locs():
+            # type: () -> Iterable[Loc]
+            for kind, v in starts.items():
+                for start in v:
+                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
+                    if loc is not None:
+                        yield loc
+        return LocSet(locs())
+
+    def __contains__(self, loc):
+        # type: (Loc | Any) -> bool
+        if not isinstance(loc, Loc) or loc.ty != self.ty:
+            return False  # not a Loc, or a Loc of a different type
+        if loc.kind not in self.starts:
+            return False
+        return loc.start in self.starts[loc.kind]
 
     def __iter__(self):
-        # type: () -> Iterator[Reg]
-        if self.shape is None:
+        # type: () -> Iterator[Loc]
+        if self.ty is None:
             return
-        for start in self.starts:
-            yield Reg(shape=self.shape, start=start)
+        for kind, starts in self.starts.items():
+            for start in starts:
+                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
+
+    @cached_property
+    def __len(self):
+        return sum((len(v) for v in self.starts.values()), 0)
 
     def __len__(self):
-        return len(self.starts)
+        return self.__len
 
-    def __hash__(self):
+    @cached_property
+    def __hash(self):
         return super()._hash()
 
+    def __hash__(self):
+        return self.__hash
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class Operand:
-    __slots__ = "ty", "regs"
-
-    def __init__(self, ty, regs=None):
-        # type: (OperandType, OFSet[int] | None) -> None
-        pass
-
-
-OT_VGPR = OperandType(RegKind.GPR, vec=True)
-OT_SGPR = OperandType(RegKind.GPR, vec=False)
-OT_CA = OperandType(RegKind.CA, vec=False)
-OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+class GenericOperandDesc:
+    """generic Op operand descriptor"""
+    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index"
+
+    def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None):
+        # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None
+        self.ty = ty
+        self.sub_kinds = OFSet(sub_kinds)
+        if len(self.sub_kinds) == 0:
+            raise ValueError("sub_kinds can't be empty")
+        self.fixed_loc = fixed_loc
+        if fixed_loc is not None:
+            if tied_input_index is not None:
+                raise ValueError("operand can't be both tied and fixed")
+            if not ty.can_instantiate_to(fixed_loc.ty):
+                raise ValueError(
+                    f"fixed_loc has incompatible type for given generic "
+                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
+            if len(self.sub_kinds) != 1:
+                raise ValueError(
+                    "multiple sub_kinds not allowed for fixed operand")
+            for sub_kind in self.sub_kinds:
+                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
+                    raise ValueError(
+                        f"fixed_loc not in given sub_kind: "
+                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
+        for sub_kind in self.sub_kinds:
+            if sub_kind.base_ty != ty.base_ty:
+                raise ValueError(f"sub_kind is incompatible with type: "
+                                 f"sub_kind={sub_kind} ty={ty}")
+        if tied_input_index is not None and tied_input_index < 0:
+            raise ValueError("invalid tied_input_index")
+        self.tied_input_index = tied_input_index
+
+    def tied_to_input(self, tied_input_index):
+        # type: (int) -> Self
+        return GenericOperandDesc(self.ty, self.sub_kinds,
+                                  tied_input_index=tied_input_index)
+
+    def with_fixed_loc(self, fixed_loc):
+        # type: (Loc) -> Self
+        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+
+    def instantiate(self, maxvl):
+        # type: (int) -> OperandDesc
+        ty = self.ty.instantiate(maxvl=maxvl)
+
+        def locs():
+            # type: () -> Iterable[Loc]
+            if self.fixed_loc is not None:
+                if ty != self.fixed_loc.ty:
+                    raise ValueError(
+                        f"instantiation failed: type mismatch with fixed_loc: "
+                        f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+                yield self.fixed_loc
+                return
+            for sub_kind in self.sub_kinds:
+                yield from sub_kind.allocatable_locs(ty)
+        return OperandDesc(loc_set=LocSet(locs()),
+                           tied_input_index=self.tied_input_index)
 
 
 @plain_data(frozen=True, unsafe_hash=True)
-class TiedOutput:
-    __slots__ = "input_index", "output_index"
-
-    def __init__(self, input_index, output_index):
-        # type: (int, int) -> None
-        self.input_index = input_index
-        self.output_index = output_index
-
-
-Constraint = Union[TiedOutput, NoReturn]
+@final
+class OperandDesc:
+    """Op operand descriptor"""
+    __slots__ = "loc_set", "tied_input_index"
+
+    def __init__(self, loc_set, tied_input_index):
+        # type: (LocSet, int | None) -> None
+        if len(loc_set) == 0:
+            raise ValueError("loc_set must not be empty")
+        self.loc_set = loc_set
+        self.tied_input_index = tied_input_index
+
+
+OD_BASE_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.BASE_GPR])
+OD_EXTRA3_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
+OD_EXTRA3_VGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
+OD_EXTRA2_SGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
+OD_EXTRA2_VGPR = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
+OD_CA = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
+    sub_kinds=[LocSubKind.CA])
+OD_VL = GenericOperandDesc(
+    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
+    sub_kinds=[LocSubKind.VL_MAXVL])
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OpProperties:
-    __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+class GenericOpProperties:
+    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                  "is_copy", "is_load_immediate", "has_side_effects")
 
     def __init__(self, demo_asm,  # type: str
-                 inputs,  # type: Iterable[OperandType]
-                 outputs,  # type: Iterable[OperandType]
+                 inputs,  # type: Iterable[GenericOperandDesc]
+                 outputs,  # type: Iterable[GenericOperandDesc]
                  immediates,  # type: Iterable[range]
-                 constraints,  # type: Iterable[Constraint]
                  is_copy=False,  # type: bool
                  is_load_immediate=False,  # type: bool
                  has_side_effects=False,  # type: bool
@@ -313,17 +502,22 @@ class OpProperties:
         self.inputs = tuple(inputs)
         self.outputs = tuple(outputs)
         self.immediates = tuple(immediates)
-        self.constraints = tuple(constraints)
         self.is_copy = is_copy
         self.is_load_immediate = is_load_immediate
         self.has_side_effects = has_side_effects
 
+    def instantiate(self, maxvl):
+        # type: (int) -> OpProperties
+        raise NotImplementedError  # FIXME: finish
+
+
+# FIXME: add OpProperties
 
 @unique
 @final
 class OpKind(Enum):
     def __init__(self, properties):
-        # type: (OpProperties) -> None
+        # type: (GenericOpProperties) -> None
         super().__init__()
         self.properties = properties
 
@@ -451,13 +645,13 @@ class SSAVal:
         self.sliced_op_outputs = tuple(processed)
 
     def __add__(self, other):
-        # type: (SSAVal) -> SSAVal
+        # type: (SSAVal | Any) -> SSAVal
         if not isinstance(other, SSAVal):
             return NotImplemented
         return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
 
     def __radd__(self, other):
-        # type: (SSAVal) -> SSAVal
+        # type: (SSAVal | Any) -> SSAVal
         if isinstance(other, SSAVal):
             return other.__add__(self)
         return NotImplemented
@@ -465,7 +659,7 @@ class SSAVal:
     @cached_property
     def expanded_sliced_op_outputs(self):
         # type: () -> tuple[tuple[Op, int, int], ...]
-        retval = []
+        retval = []  # type: list[tuple[Op, int, int]]
         for op, output_index, range_ in self.sliced_op_outputs:
             for i in range_:
                 retval.append((op, output_index, i))
@@ -490,7 +684,7 @@ class SSAVal:
         # type: () -> str
         if len(self.sliced_op_outputs) == 0:
             return "SSAVal([])"
-        parts = []
+        parts = []  # type: list[str]
         for op, output_index, range_ in self.sliced_op_outputs:
             out_len = op.properties.outputs[output_index].get_length(op.maxvl)
             parts.append(f"<{op.name}#{output_index}>")
@@ -513,13 +707,14 @@ class Op:
         self.maxvl = maxvl
         outputs_len = len(self.properties.outputs)
         self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
-        self.name = fn._add_op_with_unused_name(self, name)
+        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore
 
     @property
     def properties(self):
         return self.kind.properties
 
     def __eq__(self, other):
+        # type: (Op | Any) -> bool
         if isinstance(other, Op):
             return self is other
         return NotImplemented