split compiler IR and register allocator out into their own files
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 06:09:38 +0000 (23:09 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 14 Oct 2022 06:09:38 +0000 (23:09 -0700)
src/bigint_presentation_code/compiler_ir.py [new file with mode: 0644]
src/bigint_presentation_code/register_allocator.py [new file with mode: 0644]
src/bigint_presentation_code/test_compiler_ir.py [new file with mode: 0644]
src/bigint_presentation_code/test_register_allocator.py [new file with mode: 0644]
src/bigint_presentation_code/test_toom_cook.py
src/bigint_presentation_code/toom_cook.py

diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py
new file mode 100644 (file)
index 0000000..645ff0e
--- /dev/null
@@ -0,0 +1,783 @@
+"""
+Compiler IR for Toom-Cook algorithm generator for SVP64
+"""
+
+from abc import ABCMeta, abstractmethod
+from collections import defaultdict
+from enum import Enum, EnumMeta, unique
+from functools import lru_cache
+from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
+                    TypeVar)
+
+from nmutil.plain_data import plain_data
+
+if TYPE_CHECKING:
+    from typing_extensions import final
+else:
+    def final(v):
+        return v
+
+
+class ABCEnumMeta(EnumMeta, ABCMeta):
+    pass
+
+
+class RegLoc(metaclass=ABCMeta):
+    __slots__ = ()
+
+    @abstractmethod
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        ...
+
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> RegLoc
+        if self not in subreg_type.reg_class:
+            raise ValueError(f"register not a member of subreg_type: "
+                             f"reg={self} subreg_type={subreg_type}")
+        if offset != 0:
+            raise ValueError(f"non-zero sub-register offset not supported "
+                             f"for register: {self}")
+        return self
+
+
+GPR_COUNT = 128
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GPRRange(RegLoc, Sequence["GPRRange"]):
+    __slots__ = "start", "length"
+
+    def __init__(self, start, length=None):
+        # type: (int | range, int | None) -> None
+        if isinstance(start, range):
+            if length is not None:
+                raise TypeError("can't specify length when input is a range")
+            if start.step != 1:
+                raise ValueError("range must have a step of 1")
+            length = len(start)
+            start = start.start
+        elif length is None:
+            length = 1
+        if length <= 0 or start < 0 or start + length > GPR_COUNT:
+            raise ValueError("invalid GPRRange")
+        self.start = start
+        self.length = length
+
+    @property
+    def stop(self):
+        return self.start + self.length
+
+    @property
+    def step(self):
+        return 1
+
+    @property
+    def range(self):
+        return range(self.start, self.stop, self.step)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, item):
+        # type: (int | slice) -> GPRRange
+        return GPRRange(self.range[item])
+
+    def __contains__(self, value):
+        # type: (GPRRange) -> bool
+        return value.start >= self.start and value.stop <= self.stop
+
+    def index(self, sub, start=None, end=None):
+        # type: (GPRRange, int | None, int | None) -> int
+        r = self.range[start:end]
+        if sub.start < r.start or sub.stop > r.stop:
+            raise ValueError("GPR range not found")
+        return sub.start - self.start
+
+    def count(self, sub, start=None, end=None):
+        # type: (GPRRange, int | None, int | None) -> int
+        r = self.range[start:end]
+        if len(r) == 0:
+            return 0
+        return int(sub in GPRRange(r))
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, GPRRange):
+            return self.stop > other.start and other.stop > self.start
+        return False
+
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> GPRRange
+        if not isinstance(subreg_type, GPRRangeType):
+            raise ValueError(f"subreg_type is not a "
+                             f"GPRRangeType: {subreg_type}")
+        if offset < 0 or offset + subreg_type.length > self.stop:
+            raise ValueError(f"sub-register offset is out of range: {offset}")
+        return GPRRange(self.start + offset, subreg_type.length)
+
+
+SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
+
+
+@final
+@unique
+class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
+    CY = "CY"
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, XERBit):
+            return self == other
+        return False
+
+
+@final
+@unique
+class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
+    """singleton representing all non-StackSlot memory -- treated as a single
+    physical register for register allocation purposes.
+    """
+    GlobalMem = "GlobalMem"
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, GlobalMem):
+            return self == other
+        return False
+
+
+@final
+class RegClass(AbstractSet[RegLoc]):
+    """ an ordered set of registers.
+    earlier registers are preferred by the register allocator.
+    """
+
+    def __init__(self, regs):
+        # type: (Iterable[RegLoc]) -> None
+
+        # use dict to maintain order
+        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]
+
+    def __len__(self):
+        return len(self.__regs)
+
+    def __iter__(self):
+        return iter(self.__regs)
+
+    def __contains__(self, v):
+        # type: (RegLoc) -> bool
+        return v in self.__regs
+
+    def __hash__(self):
+        return super()._hash()
+
+    @lru_cache(maxsize=None, typed=True)
+    def max_conflicts_with(self, other):
+        # type: (RegClass | RegLoc) -> int
+        """the largest number of registers in `self` that a single register
+        from `other` can conflict with
+        """
+        if isinstance(other, RegClass):
+            return max(self.max_conflicts_with(i) for i in other)
+        else:
+            return sum(other.conflicts(i) for i in self)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class RegType(metaclass=ABCMeta):
+    __slots__ = ()
+
+    @property
+    @abstractmethod
+    def reg_class(self):
+        # type: () -> RegClass
+        return ...
+
+
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+
+
+@plain_data(frozen=True, eq=False)
+class GPRRangeType(RegType):
+    __slots__ = "length",
+
+    def __init__(self, length):
+        # type: (int) -> None
+        if length < 1 or length > GPR_COUNT:
+            raise ValueError("invalid length")
+        self.length = length
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def __get_reg_class(length):
+        # type: (int) -> RegClass
+        regs = []
+        for start in range(GPR_COUNT - length):
+            reg = GPRRange(start, length)
+            if any(i in reg for i in SPECIAL_GPRS):
+                continue
+            regs.append(reg)
+        return RegClass(regs)
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return GPRRangeType.__get_reg_class(self.length)
+
+    @final
+    def __eq__(self, other):
+        if isinstance(other, GPRRangeType):
+            return self.length == other.length
+        return False
+
+    @final
+    def __hash__(self):
+        return hash(self.length)
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class GPRType(GPRRangeType):
+    __slots__ = ()
+
+    def __init__(self, length=1):
+        if length != 1:
+            raise ValueError("length must be 1")
+        super().__init__(length=1)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class CYType(RegType):
+    __slots__ = ()
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return RegClass([XERBit.CY])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GlobalMemType(RegType):
+    __slots__ = ()
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return RegClass([GlobalMem.GlobalMem])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class StackSlot(RegLoc):
+    __slots__ = "start_slot", "length_in_slots",
+
+    def __init__(self, start_slot, length_in_slots):
+        # type: (int, int) -> None
+        self.start_slot = start_slot
+        if length_in_slots < 1:
+            raise ValueError("invalid length_in_slots")
+        self.length_in_slots = length_in_slots
+
+    @property
+    def stop_slot(self):
+        return self.start_slot + self.length_in_slots
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, StackSlot):
+            return (self.stop_slot > other.start_slot
+                    and other.stop_slot > self.start_slot)
+        return False
+
+    def get_subreg_at_offset(self, subreg_type, offset):
+        # type: (RegType, int) -> StackSlot
+        if not isinstance(subreg_type, StackSlotType):
+            raise ValueError(f"subreg_type is not a "
+                             f"StackSlotType: {subreg_type}")
+        if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot:
+            raise ValueError(f"sub-register offset is out of range: {offset}")
+        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
+
+
+STACK_SLOT_COUNT = 128
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class StackSlotType(RegType):
+    __slots__ = "length_in_slots",
+
+    def __init__(self, length_in_slots=1):
+        # type: (int) -> None
+        if length_in_slots < 1:
+            raise ValueError("invalid length_in_slots")
+        self.length_in_slots = length_in_slots
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def __get_reg_class(length_in_slots):
+        # type: (int) -> RegClass
+        regs = []
+        for start in range(STACK_SLOT_COUNT - length_in_slots):
+            reg = StackSlot(start, length_in_slots)
+            regs.append(reg)
+        return RegClass(regs)
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return StackSlotType.__get_reg_class(self.length_in_slots)
+
+    @final
+    def __eq__(self, other):
+        if isinstance(other, StackSlotType):
+            return self.length_in_slots == other.length_in_slots
+        return False
+
+    @final
+    def __hash__(self):
+        return hash(self.length_in_slots)
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class SSAVal(Generic[_RegT_co]):
+    __slots__ = "op", "arg_name", "ty", "arg_index"
+
+    def __init__(self, op, arg_name, ty):
+        # type: (Op, str, _RegT_co) -> None
+        self.op = op
+        """the Op that writes this SSAVal"""
+
+        self.arg_name = arg_name
+        """the name of the argument of self.op that writes this SSAVal"""
+
+        self.ty = ty
+
+    def __eq__(self, rhs):
+        if isinstance(rhs, SSAVal):
+            return (self.op is rhs.op
+                    and self.arg_name == rhs.arg_name)
+        return False
+
+    def __hash__(self):
+        return hash((id(self.op), self.arg_name))
+
+
+@final
+@plain_data(unsafe_hash=True, frozen=True)
+class EqualityConstraint:
+    __slots__ = "lhs", "rhs"
+
+    def __init__(self, lhs, rhs):
+        # type: (list[SSAVal], list[SSAVal]) -> None
+        self.lhs = lhs
+        self.rhs = rhs
+        if len(lhs) == 0 or len(rhs) == 0:
+            raise ValueError("can't constrain an empty list to be equal")
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+class Op(metaclass=ABCMeta):
+    __slots__ = ()
+
+    @abstractmethod
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        ...
+
+    @abstractmethod
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        ...
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        if False:
+            yield ...
+
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        if False:
+            yield ...
+
+    def __init__(self):
+        pass
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoadFromStackSlot(Op):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, src):
+        # type: (SSAVal[GPRRangeType]) -> None
+        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
+        self.src = src
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStoreToStackSlot(Op):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, src):
+        # type: (SSAVal[StackSlotType]) -> None
+        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
+        self.src = src
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpCopy(Op, Generic[_RegT_co]):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, src):
+        # type: (SSAVal[_RegT_co]) -> None
+        self.dest = SSAVal(self, "dest", src.ty)
+        self.src = src
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpConcat(Op):
+    __slots__ = "dest", "sources"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, sources):
+        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
+        sources = tuple(sources)
+        self.dest = SSAVal(self, "dest", GPRRangeType(
+            sum(i.ty.length for i in sources)))
+        self.sources = sources
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint([self.dest], [*self.sources])
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpSplit(Op):
+    __slots__ = "results", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {i.arg_name: i for i in self.results}
+
+    def __init__(self, src, split_indexes):
+        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+        ranges = []  # type: list[GPRRangeType]
+        last = 0
+        for i in split_indexes:
+            if not (0 < i < src.ty.length):
+                raise ValueError(f"invalid split index: {i}, must be in "
+                                 f"0 < i < {src.ty.length}")
+            ranges.append(GPRRangeType(i - last))
+            last = i
+        ranges.append(GPRRangeType(src.ty.length - last))
+        self.src = src
+        self.results = tuple(
+            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint([*self.results], [self.src])
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpAddSubE(Op):
+    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RT": self.RT, "CY_out": self.CY_out}
+
+    def __init__(self, RA, RB, CY_in, is_sub):
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+        if RA.ty != RB.ty:
+            raise TypeError(f"source types must match: "
+                            f"{RA} doesn't match {RB}")
+        self.RT = SSAVal(self, "RT", RA.ty)
+        self.RA = RA
+        self.RB = RB
+        self.CY_in = CY_in
+        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
+        self.is_sub = is_sub
+
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.RA
+        yield self.RT, self.RB
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntMulDiv(Op):
+    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RT": self.RT, "RS": self.RS}
+
+    def __init__(self, RA, RB, RC, is_div):
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+        self.RT = SSAVal(self, "RT", RA.ty)
+        self.RA = RA
+        self.RB = RB
+        self.RC = RC
+        self.RS = SSAVal(self, "RS", RC.ty)
+        self.is_div = is_div
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint([self.RC], [self.RS])
+
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.RA
+        yield self.RT, self.RB
+        yield self.RT, self.RC
+        yield self.RT, self.RS
+        yield self.RS, self.RA
+        yield self.RS, self.RB
+
+
+@final
+@unique
+class ShiftKind(Enum):
+    Sl = "sl"
+    Sr = "sr"
+    Sra = "sra"
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntShift(Op):
+    __slots__ = "RT", "inp", "sh", "kind"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"inp": self.inp, "sh": self.sh}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RT": self.RT}
+
+    def __init__(self, inp, sh, kind):
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+        self.RT = SSAVal(self, "RT", inp.ty)
+        self.inp = inp
+        self.sh = sh
+        self.kind = kind
+
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        yield self.RT, self.inp
+        yield self.RT, self.sh
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLI(Op):
+    __slots__ = "out", "value"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, value, length=1):
+        # type: (int, int) -> None
+        self.out = SSAVal(self, "out", GPRRangeType(length))
+        self.value = value
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpClearCY(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"out": self.out}
+
+    def __init__(self):
+        # type: () -> None
+        self.out = SSAVal(self, "out", CYType())
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoad(Op):
+    __slots__ = "RT", "RA", "offset", "mem"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RA": self.RA, "mem": self.mem}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RT": self.RT}
+
+    def __init__(self, RA, offset, mem, length=1):
+        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+        self.RT = SSAVal(self, "RT", GPRRangeType(length))
+        self.RA = RA
+        self.offset = offset
+        self.mem = mem
+
+    def get_extra_interferences(self):
+        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+        if self.RT.ty.length > 1:
+            yield self.RT, self.RA
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStore(Op):
+    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"mem_out": self.mem_out}
+
+    def __init__(self, RS, RA, offset, mem_in):
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+        self.RS = RS
+        self.RA = RA
+        self.offset = offset
+        self.mem_in = mem_in
+        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpFuncArg(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, ty):
+        # type: (RegType) -> None
+        self.out = SSAVal(self, "out", ty)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpInputMem(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"out": self.out}
+
+    def __init__(self):
+        # type: () -> None
+        self.out = SSAVal(self, "out", GlobalMemType())
+
+
+def op_set_to_list(ops):
+    # type: (Iterable[Op]) -> list[Op]
+    worklists = [set()]  # type: list[set[Op]]
+    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
+    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
+    for op in ops:
+        input_count = 0
+        for val in op.inputs().values():
+            input_count += 1
+            input_vals_to_ops_map[val].add(op)
+        while len(worklists) <= input_count:
+            worklists.append(set())
+        ops_to_pending_input_count_map[op] = input_count
+        worklists[input_count].add(op)
+    retval = []  # type: list[Op]
+    ready_vals = set()  # type: set[SSAVal]
+    while len(worklists[0]) != 0:
+        writing_op = worklists[0].pop()
+        retval.append(writing_op)
+        for val in writing_op.outputs().values():
+            if val in ready_vals:
+                raise ValueError(f"multiple instructions must not write "
+                                 f"to the same SSA value: {val}")
+            ready_vals.add(val)
+            for reading_op in input_vals_to_ops_map[val]:
+                pending = ops_to_pending_input_count_map[reading_op]
+                worklists[pending].remove(reading_op)
+                pending -= 1
+                worklists[pending].add(reading_op)
+                ops_to_pending_input_count_map[reading_op] = pending
+    for worklist in worklists:
+        for op in worklist:
+            raise ValueError(f"instruction is part of a dependency loop or "
+                             f"its inputs are never written: {op}")
+    return retval
diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py
new file mode 100644 (file)
index 0000000..297e3e5
--- /dev/null
@@ -0,0 +1,414 @@
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
+from itertools import combinations
+from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
+                                                  RegLoc, RegType, SSAVal)
+
+if TYPE_CHECKING:
+    from typing_extensions import Self, final
+else:
+    def final(v):
+        return v
+
+
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+
+
+@plain_data(unsafe_hash=True, order=True, frozen=True)
+class LiveInterval:
+    __slots__ = "first_write", "last_use"
+
+    def __init__(self, first_write, last_use=None):
+        # type: (int, int | None) -> None
+        if last_use is None:
+            last_use = first_write
+        if last_use < first_write:
+            raise ValueError("uses must be after first_write")
+        if first_write < 0 or last_use < 0:
+            raise ValueError("indexes must be nonnegative")
+        self.first_write = first_write
+        self.last_use = last_use
+
+    def overlaps(self, other):
+        # type: (LiveInterval) -> bool
+        if self.first_write == other.first_write:
+            return True
+        return self.last_use > other.first_write \
+            and other.last_use > self.first_write
+
+    def __add__(self, use):
+        # type: (int) -> LiveInterval
+        last_use = max(self.last_use, use)
+        return LiveInterval(first_write=self.first_write, last_use=last_use)
+
+    @property
+    def live_after_op_range(self):
+        """the range of op indexes where self is live immediately after the
+        Op at each index
+        """
+        return range(self.first_write, self.last_use)
+
+
+@final
+class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
+    def __init__(self, reg_set):
+        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
+        self.__items = {}  # type: dict[SSAVal[_RegT_co], int]
+        if isinstance(reg_set, SSAVal):
+            reg_set = [(reg_set, 0)]
+        for ssa_val, offset in reg_set:
+            if ssa_val in self.__items:
+                other = self.__items[ssa_val]
+                if offset != other:
+                    raise ValueError(
+                        f"can't merge register sets: conflicting offsets: "
+                        f"for {ssa_val}: {offset} != {other}")
+            else:
+                self.__items[ssa_val] = offset
+        first_item = None
+        for i in self.__items.items():
+            first_item = i
+            break
+        if first_item is None:
+            raise ValueError("can't have empty MergedRegs")
+        first_ssa_val, start = first_item
+        ty = first_ssa_val.ty
+        if isinstance(ty, GPRRangeType):
+            stop = start + ty.length
+            for ssa_val, offset in self.__items.items():
+                if not isinstance(ssa_val.ty, GPRRangeType):
+                    raise ValueError(f"can't merge incompatible types: "
+                                     f"{ssa_val.ty} and {ty}")
+                stop = max(stop, offset + ssa_val.ty.length)
+                start = min(start, offset)
+            ty = GPRRangeType(stop - start)
+        else:
+            stop = 1
+            for ssa_val, offset in self.__items.items():
+                if offset != 0:
+                    raise ValueError(f"can't have non-zero offset "
+                                     f"for {ssa_val.ty}")
+                if ty != ssa_val.ty:
+                    raise ValueError(f"can't merge incompatible types: "
+                                     f"{ssa_val.ty} and {ty}")
+        self.__start = start  # type: int
+        self.__stop = stop  # type: int
+        self.__ty = ty  # type: RegType
+        self.__hash = hash(frozenset(self.items()))
+
+    @staticmethod
+    def from_equality_constraint(constraint_sequence):
+        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
+        if len(constraint_sequence) == 1:
+            # any type allowed with len = 1
+            return MergedRegSet(constraint_sequence[0])
+        offset = 0
+        retval = []
+        for val in constraint_sequence:
+            if not isinstance(val.ty, GPRRangeType):
+                raise ValueError("equality constraint sequences must only "
+                                 "have SSAVal type GPRRangeType")
+            retval.append((val, offset))
+            offset += val.ty.length
+        return MergedRegSet(retval)
+
+    @property
+    def ty(self):
+        return self.__ty
+
+    @property
+    def stop(self):
+        return self.__stop
+
+    @property
+    def start(self):
+        return self.__start
+
+    @property
+    def range(self):
+        return range(self.__start, self.__stop)
+
+    def offset_by(self, amount):
+        # type: (int) -> MergedRegSet[_RegT_co]
+        return MergedRegSet((k, v + amount) for k, v in self.items())
+
+    def normalized(self):
+        # type: () -> MergedRegSet[_RegT_co]
+        return self.offset_by(-self.start)
+
+    def with_offset_to_match(self, target):
+        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
+        for ssa_val, offset in self.items():
+            if ssa_val in target:
+                return self.offset_by(target[ssa_val] - offset)
+        raise ValueError("can't change offset to match unrelated MergedRegSet")
+
+    def __getitem__(self, item):
+        # type: (SSAVal[_RegT_co]) -> int
+        return self.__items[item]
+
+    def __iter__(self):
+        return iter(self.__items)
+
+    def __len__(self):
+        return len(self.__items)
+
+    def __hash__(self):
+        return self.__hash
+
+    def __repr__(self):
+        return f"MergedRegSet({list(self.__items.items())})"
+
+
+@final
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
+    def __init__(self, ops):
+        # type: (Iterable[Op]) -> None
+        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegT_co]]
+        for op in ops:
+            for val in (*op.inputs().values(), *op.outputs().values()):
+                if val not in merged_sets:
+                    merged_sets[val] = MergedRegSet(val)
+            for e in op.get_equality_constraints():
+                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
+                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
+                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
+                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
+                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+                for val in full_set.keys():
+                    merged_sets[val] = full_set
+
+        self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+
+    def __getitem__(self, key):
+        # type: (SSAVal) -> MergedRegSet
+        return self.__map[key]
+
+    def __iter__(self):
+        return iter(self.__map)
+
+    def __len__(self):
+        return len(self.__map)
+
+    def __repr__(self):
+        return f"MergedRegSets(data={self.__map})"
+
+
+@final
+class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
+    def __init__(self, ops):
+        # type: (list[Op]) -> None
+        self.__merged_reg_sets = MergedRegSets(ops)
+        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
+        for op_idx, op in enumerate(ops):
+            for val in op.inputs().values():
+                live_intervals[self.__merged_reg_sets[val]] += op_idx
+            for val in op.outputs().values():
+                reg_set = self.__merged_reg_sets[val]
+                if reg_set not in live_intervals:
+                    live_intervals[reg_set] = LiveInterval(op_idx)
+                else:
+                    live_intervals[reg_set] += op_idx
+        self.__live_intervals = live_intervals
+        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
+        live_after += (set() for _ in ops)
+        for reg_set, live_interval in self.__live_intervals.items():
+            for i in live_interval.live_after_op_range:
+                live_after[i].add(reg_set)
+        self.__live_after = [frozenset(i) for i in live_after]
+
+    @property
+    def merged_reg_sets(self):
+        return self.__merged_reg_sets
+
+    def __getitem__(self, key):
+        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
+        return self.__live_intervals[key]
+
+    def __iter__(self):
+        return iter(self.__live_intervals)
+
+    def reg_sets_live_after(self, op_index):
+        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+        return self.__live_after[op_index]
+
+    def __repr__(self):
+        reg_sets_live_after = dict(enumerate(self.__live_after))
+        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
+                f"merged_reg_sets={self.merged_reg_sets}, "
+                f"reg_sets_live_after={reg_sets_live_after})")
+
+
+@final
+class IGNode(Generic[_RegT_co]):
+    """ interference graph node """
+    __slots__ = "merged_reg_set", "edges", "reg"
+
+    def __init__(self, merged_reg_set, edges=(), reg=None):
+        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
+        self.merged_reg_set = merged_reg_set
+        self.edges = set(edges)
+        self.reg = reg
+
+    def add_edge(self, other):
+        # type: (IGNode) -> None
+        self.edges.add(other)
+        other.edges.add(self)
+
+    def __eq__(self, other):
+        # type: (object) -> bool
+        if isinstance(other, IGNode):
+            return self.merged_reg_set == other.merged_reg_set
+        return NotImplemented
+
+    def __hash__(self):
+        return hash(self.merged_reg_set)
+
+    def __repr__(self, nodes=None):
+        # type: (None | dict[IGNode, int]) -> str
+        if nodes is None:
+            nodes = {}
+        if self in nodes:
+            return f"<IGNode #{nodes[self]}>"
+        nodes[self] = len(nodes)
+        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
+        return (f"IGNode(#{nodes[self]}, "
+                f"merged_reg_set={self.merged_reg_set}, "
+                f"edges={edges}, "
+                f"reg={self.reg})")
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return self.merged_reg_set.ty.reg_class
+
+    def reg_conflicts_with_neighbors(self, reg):
+        # type: (RegLoc) -> bool
+        for neighbor in self.edges:
+            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+                return True
+        return False
+
+
+@final
+class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
+    def __init__(self, merged_reg_sets):
+        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
+        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
+
+    def __getitem__(self, key):
+        # type: (MergedRegSet[_RegT_co]) -> IGNode
+        return self.__nodes[key]
+
+    def __iter__(self):
+        return iter(self.__nodes)
+
+    def __repr__(self):
+        nodes = {}
+        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
+        nodes_text = ", ".join(nodes_text)
+        return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
+
+@plain_data()
+class AllocationFailed:
+    __slots__ = "node", "live_intervals", "interference_graph"
+
+    def __init__(self, node, live_intervals, interference_graph):
+        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+        self.node = node
+        self.live_intervals = live_intervals
+        self.interference_graph = interference_graph
+
+
+def try_allocate_registers_without_spilling(ops):
+    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
+
+    live_intervals = LiveIntervals(ops)
+    merged_reg_sets = live_intervals.merged_reg_sets
+    interference_graph = InterferenceGraph(merged_reg_sets.values())
+    for op_idx, op in enumerate(ops):
+        reg_sets = live_intervals.reg_sets_live_after(op_idx)
+        for i, j in combinations(reg_sets, 2):
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
+        for i, j in op.get_extra_interferences():
+            i = merged_reg_sets[i]
+            j = merged_reg_sets[j]
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
+
+    nodes_remaining = set(interference_graph.values())
+
+    def local_colorability_score(node):
+        # type: (IGNode) -> int
+        """ returns a positive integer if node is locally colorable, returns
+        zero or a negative integer if node isn't known to be locally
+        colorable, the more negative the value, the less colorable
+        """
+        if node not in nodes_remaining:
+            raise ValueError()
+        retval = len(node.reg_class)
+        for neighbor in node.edges:
+            if neighbor in nodes_remaining:
+                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+        return retval
+
+    node_stack = []  # type: list[IGNode]
+    while True:
+        best_node = None  # type: None | IGNode
+        best_score = 0
+        for node in nodes_remaining:
+            score = local_colorability_score(node)
+            if best_node is None or score > best_score:
+                best_node = node
+                best_score = score
+            if best_score > 0:
+                # it's locally colorable, no need to find a better one
+                break
+
+        if best_node is None:
+            break
+        node_stack.append(best_node)
+        nodes_remaining.remove(best_node)
+
+    retval = {}  # type: dict[SSAVal, RegLoc]
+
+    while len(node_stack) > 0:
+        node = node_stack.pop()
+        if node.reg is not None:
+            if node.reg_conflicts_with_neighbors(node.reg):
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+        else:
+            # pick the first non-conflicting register in node.reg_class, since
+            # register classes are ordered from most preferred to least
+            # preferred register.
+            for reg in node.reg_class:
+                if not node.reg_conflicts_with_neighbors(reg):
+                    node.reg = reg
+                    break
+            if node.reg is None:
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+
+        for ssa_val, offset in node.merged_reg_set.items():
+            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
+
+    return retval
+
+
+def allocate_registers(ops):
+    # type: (list[Op]) -> None
+    raise NotImplementedError
diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py
new file mode 100644 (file)
index 0000000..231f2fb
--- /dev/null
@@ -0,0 +1,11 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir import Op, op_set_to_list
+
+
+class TestCompilerIR(unittest.TestCase):
+    pass  # no tests yet, just testing importing
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py
new file mode 100644 (file)
index 0000000..1345cc7
--- /dev/null
@@ -0,0 +1,14 @@
+import unittest
+
+from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.register_allocator import (
+    AllocationFailed, allocate_registers,
+    try_allocate_registers_without_spilling)
+
+
+class TestCompilerIR(unittest.TestCase):
+    pass  # no tests yet, just testing importing
+
+
+if __name__ == "__main__":
+    unittest.main()
index 9851afb59b5326100bc420c74f1272b50cef7120..8fe6cea693e98687abc4a877e830680ba963126c 100644 (file)
@@ -1,5 +1,6 @@
 import unittest
-from bigint_presentation_code.toom_cook import Op
+
+import bigint_presentation_code.toom_cook
 
 
 class TestToomCook(unittest.TestCase):
index 88e8f7e27c2e3715ec73c90b5fa2df192e4af7c9..c014d09a59a5d93a0428249bca81ac1d9bcf1fb1 100644 (file)
@@ -4,1174 +4,5 @@ Toom-Cook algorithm generator for SVP64
 the register allocator uses an algorithm based on:
 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
 """
-
-from abc import ABCMeta, abstractmethod
-from collections import defaultdict
-from enum import Enum, unique, EnumMeta
-from functools import lru_cache
-from itertools import combinations
-from typing import (Sequence, AbstractSet, Iterable, Mapping,
-                    TYPE_CHECKING, Sequence, TypeVar, Generic)
-
-from nmutil.plain_data import plain_data
-
-if TYPE_CHECKING:
-    from typing_extensions import final, Self
-else:
-    def final(v):
-        return v
-
-
-class ABCEnumMeta(EnumMeta, ABCMeta):
-    pass
-
-
-class RegLoc(metaclass=ABCMeta):
-    __slots__ = ()
-
-    @abstractmethod
-    def conflicts(self, other):
-        # type: (RegLoc) -> bool
-        ...
-
-    def get_subreg_at_offset(self, subreg_type, offset):
-        # type: (RegType, int) -> RegLoc
-        if self not in subreg_type.reg_class:
-            raise ValueError(f"register not a member of subreg_type: "
-                             f"reg={self} subreg_type={subreg_type}")
-        if offset != 0:
-            raise ValueError(f"non-zero sub-register offset not supported "
-                             f"for register: {self}")
-        return self
-
-
-GPR_COUNT = 128
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GPRRange(RegLoc, Sequence["GPRRange"]):
-    __slots__ = "start", "length"
-
-    def __init__(self, start, length=None):
-        # type: (int | range, int | None) -> None
-        if isinstance(start, range):
-            if length is not None:
-                raise TypeError("can't specify length when input is a range")
-            if start.step != 1:
-                raise ValueError("range must have a step of 1")
-            length = len(start)
-            start = start.start
-        elif length is None:
-            length = 1
-        if length <= 0 or start < 0 or start + length > GPR_COUNT:
-            raise ValueError("invalid GPRRange")
-        self.start = start
-        self.length = length
-
-    @property
-    def stop(self):
-        return self.start + self.length
-
-    @property
-    def step(self):
-        return 1
-
-    @property
-    def range(self):
-        return range(self.start, self.stop, self.step)
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, item):
-        # type: (int | slice) -> GPRRange
-        return GPRRange(self.range[item])
-
-    def __contains__(self, value):
-        # type: (GPRRange) -> bool
-        return value.start >= self.start and value.stop <= self.stop
-
-    def index(self, sub, start=None, end=None):
-        # type: (GPRRange, int | None, int | None) -> int
-        r = self.range[start:end]
-        if sub.start < r.start or sub.stop > r.stop:
-            raise ValueError("GPR range not found")
-        return sub.start - self.start
-
-    def count(self, sub, start=None, end=None):
-        # type: (GPRRange, int | None, int | None) -> int
-        r = self.range[start:end]
-        if len(r) == 0:
-            return 0
-        return int(sub in GPRRange(r))
-
-    def conflicts(self, other):
-        # type: (RegLoc) -> bool
-        if isinstance(other, GPRRange):
-            return self.stop > other.start and other.stop > self.start
-        return False
-
-    def get_subreg_at_offset(self, subreg_type, offset):
-        # type: (RegType, int) -> GPRRange
-        if not isinstance(subreg_type, GPRRangeType):
-            raise ValueError(f"subreg_type is not a "
-                             f"GPRRangeType: {subreg_type}")
-        if offset < 0 or offset + subreg_type.length > self.stop:
-            raise ValueError(f"sub-register offset is out of range: {offset}")
-        return GPRRange(self.start + offset, subreg_type.length)
-
-
-SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
-
-
-@final
-@unique
-class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
-    CY = "CY"
-
-    def conflicts(self, other):
-        # type: (RegLoc) -> bool
-        if isinstance(other, XERBit):
-            return self == other
-        return False
-
-
-@final
-@unique
-class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
-    """singleton representing all non-StackSlot memory -- treated as a single
-    physical register for register allocation purposes.
-    """
-    GlobalMem = "GlobalMem"
-
-    def conflicts(self, other):
-        # type: (RegLoc) -> bool
-        if isinstance(other, GlobalMem):
-            return self == other
-        return False
-
-
-@final
-class RegClass(AbstractSet[RegLoc]):
-    """ an ordered set of registers.
-    earlier registers are preferred by the register allocator.
-    """
-    def __init__(self, regs):
-        # type: (Iterable[RegLoc]) -> None
-
-        # use dict to maintain order
-        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]
-
-    def __len__(self):
-        return len(self.__regs)
-
-    def __iter__(self):
-        return iter(self.__regs)
-
-    def __contains__(self, v):
-        # type: (RegLoc) -> bool
-        return v in self.__regs
-
-    def __hash__(self):
-        return super()._hash()
-
-    @lru_cache(maxsize=None, typed=True)
-    def max_conflicts_with(self, other):
-        # type: (RegClass | RegLoc) -> int
-        """the largest number of registers in `self` that a single register
-        from `other` can conflict with
-        """
-        if isinstance(other, RegClass):
-            return max(self.max_conflicts_with(i) for i in other)
-        else:
-            return sum(other.conflicts(i) for i in self)
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-class RegType(metaclass=ABCMeta):
-    __slots__ = ()
-
-    @property
-    @abstractmethod
-    def reg_class(self):
-        # type: () -> RegClass
-        return ...
-
-
-@plain_data(frozen=True, eq=False)
-class GPRRangeType(RegType):
-    __slots__ = "length",
-
-    def __init__(self, length):
-        # type: (int) -> None
-        if length < 1 or length > GPR_COUNT:
-            raise ValueError("invalid length")
-        self.length = length
-
-    @staticmethod
-    @lru_cache(maxsize=None)
-    def __get_reg_class(length):
-        # type: (int) -> RegClass
-        regs = []
-        for start in range(GPR_COUNT - length):
-            reg = GPRRange(start, length)
-            if any(i in reg for i in SPECIAL_GPRS):
-                continue
-            regs.append(reg)
-        return RegClass(regs)
-
-    @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return GPRRangeType.__get_reg_class(self.length)
-
-    @final
-    def __eq__(self, other):
-        if isinstance(other, GPRRangeType):
-            return self.length == other.length
-        return False
-
-    @final
-    def __hash__(self):
-        return hash(self.length)
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class GPRType(GPRRangeType):
-    __slots__ = ()
-
-    def __init__(self, length=1):
-        if length != 1:
-            raise ValueError("length must be 1")
-        super().__init__(length=1)
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class CYType(RegType):
-    __slots__ = ()
-
-    @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return RegClass([XERBit.CY])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GlobalMemType(RegType):
-    __slots__ = ()
-
-    @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return RegClass([GlobalMem.GlobalMem])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class StackSlot(RegLoc):
-    __slots__ = "start_slot", "length_in_slots",
-
-    def __init__(self, start_slot, length_in_slots):
-        # type: (int, int) -> None
-        self.start_slot = start_slot
-        if length_in_slots < 1:
-            raise ValueError("invalid length_in_slots")
-        self.length_in_slots = length_in_slots
-
-    @property
-    def stop_slot(self):
-        return self.start_slot + self.length_in_slots
-
-    def conflicts(self, other):
-        # type: (RegLoc) -> bool
-        if isinstance(other, StackSlot):
-            return (self.stop_slot > other.start_slot
-                    and other.stop_slot > self.start_slot)
-        return False
-
-    def get_subreg_at_offset(self, subreg_type, offset):
-        # type: (RegType, int) -> StackSlot
-        if not isinstance(subreg_type, StackSlotType):
-            raise ValueError(f"subreg_type is not a "
-                             f"StackSlotType: {subreg_type}")
-        if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot:
-            raise ValueError(f"sub-register offset is out of range: {offset}")
-        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
-
-
-STACK_SLOT_COUNT = 128
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class StackSlotType(RegType):
-    __slots__ = "length_in_slots",
-
-    def __init__(self, length_in_slots=1):
-        # type: (int) -> None
-        if length_in_slots < 1:
-            raise ValueError("invalid length_in_slots")
-        self.length_in_slots = length_in_slots
-
-    @staticmethod
-    @lru_cache(maxsize=None)
-    def __get_reg_class(length_in_slots):
-        # type: (int) -> RegClass
-        regs = []
-        for start in range(STACK_SLOT_COUNT - length_in_slots):
-            reg = StackSlot(start, length_in_slots)
-            regs.append(reg)
-        return RegClass(regs)
-
-    @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return StackSlotType.__get_reg_class(self.length_in_slots)
-
-    @final
-    def __eq__(self, other):
-        if isinstance(other, StackSlotType):
-            return self.length_in_slots == other.length_in_slots
-        return False
-
-    @final
-    def __hash__(self):
-        return hash(self.length_in_slots)
-
-
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class SSAVal(Generic[_RegT_co]):
-    __slots__ = "op", "arg_name", "ty", "arg_index"
-
-    def __init__(self, op, arg_name, ty):
-        # type: (Op, str, _RegT_co) -> None
-        self.op = op
-        """the Op that writes this SSAVal"""
-
-        self.arg_name = arg_name
-        """the name of the argument of self.op that writes this SSAVal"""
-
-        self.ty = ty
-
-    def __eq__(self, rhs):
-        if isinstance(rhs, SSAVal):
-            return (self.op is rhs.op
-                    and self.arg_name == rhs.arg_name)
-        return False
-
-    def __hash__(self):
-        return hash((id(self.op), self.arg_name))
-
-
-@final
-@plain_data(unsafe_hash=True, frozen=True)
-class EqualityConstraint:
-    __slots__ = "lhs", "rhs"
-
-    def __init__(self, lhs, rhs):
-        # type: (list[SSAVal], list[SSAVal]) -> None
-        self.lhs = lhs
-        self.rhs = rhs
-        if len(lhs) == 0 or len(rhs) == 0:
-            raise ValueError("can't constrain an empty list to be equal")
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-class Op(metaclass=ABCMeta):
-    __slots__ = ()
-
-    @abstractmethod
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        ...
-
-    @abstractmethod
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        ...
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        if False:
-            yield ...
-
-    def get_extra_interferences(self):
-        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        if False:
-            yield ...
-
-    def __init__(self):
-        pass
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLoadFromStackSlot(Op):
-    __slots__ = "dest", "src"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"dest": self.dest}
-
-    def __init__(self, src):
-        # type: (SSAVal[GPRRangeType]) -> None
-        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
-        self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpStoreToStackSlot(Op):
-    __slots__ = "dest", "src"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"dest": self.dest}
-
-    def __init__(self, src):
-        # type: (SSAVal[StackSlotType]) -> None
-        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
-        self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpCopy(Op, Generic[_RegT_co]):
-    __slots__ = "dest", "src"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"dest": self.dest}
-
-    def __init__(self, src):
-        # type: (SSAVal[_RegT_co]) -> None
-        self.dest = SSAVal(self, "dest", src.ty)
-        self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpConcat(Op):
-    __slots__ = "dest", "sources"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"dest": self.dest}
-
-    def __init__(self, sources):
-        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
-        sources = tuple(sources)
-        self.dest = SSAVal(self, "dest", GPRRangeType(
-            sum(i.ty.length for i in sources)))
-        self.sources = sources
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint([self.dest], [*self.sources])
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpSplit(Op):
-    __slots__ = "results", "src"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {i.arg_name: i for i in self.results}
-
-    def __init__(self, src, split_indexes):
-        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
-        ranges = []  # type: list[GPRRangeType]
-        last = 0
-        for i in split_indexes:
-            if not (0 < i < src.ty.length):
-                raise ValueError(f"invalid split index: {i}, must be in "
-                                 f"0 < i < {src.ty.length}")
-            ranges.append(GPRRangeType(i - last))
-            last = i
-        ranges.append(GPRRangeType(src.ty.length - last))
-        self.src = src
-        self.results = tuple(
-            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint([*self.results], [self.src])
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpAddSubE(Op):
-    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT, "CY_out": self.CY_out}
-
-    def __init__(self, RA, RB, CY_in, is_sub):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
-        if RA.ty != RB.ty:
-            raise TypeError(f"source types must match: "
-                            f"{RA} doesn't match {RB}")
-        self.RT = SSAVal(self, "RT", RA.ty)
-        self.RA = RA
-        self.RB = RB
-        self.CY_in = CY_in
-        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
-        self.is_sub = is_sub
-
-    def get_extra_interferences(self):
-        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        yield self.RT, self.RA
-        yield self.RT, self.RB
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpBigIntMulDiv(Op):
-    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT, "RS": self.RS}
-
-    def __init__(self, RA, RB, RC, is_div):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
-        self.RT = SSAVal(self, "RT", RA.ty)
-        self.RA = RA
-        self.RB = RB
-        self.RC = RC
-        self.RS = SSAVal(self, "RS", RC.ty)
-        self.is_div = is_div
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint([self.RC], [self.RS])
-
-    def get_extra_interferences(self):
-        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        yield self.RT, self.RA
-        yield self.RT, self.RB
-        yield self.RT, self.RC
-        yield self.RT, self.RS
-        yield self.RS, self.RA
-        yield self.RS, self.RB
-
-
-@final
-@unique
-class ShiftKind(Enum):
-    Sl = "sl"
-    Sr = "sr"
-    Sra = "sra"
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpBigIntShift(Op):
-    __slots__ = "RT", "inp", "sh", "kind"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"inp": self.inp, "sh": self.sh}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT}
-
-    def __init__(self, inp, sh, kind):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
-        self.RT = SSAVal(self, "RT", inp.ty)
-        self.inp = inp
-        self.sh = sh
-        self.kind = kind
-
-    def get_extra_interferences(self):
-        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        yield self.RT, self.inp
-        yield self.RT, self.sh
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLI(Op):
-    __slots__ = "out", "value"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"out": self.out}
-
-    def __init__(self, value, length=1):
-        # type: (int, int) -> None
-        self.out = SSAVal(self, "out", GPRRangeType(length))
-        self.value = value
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpClearCY(Op):
-    __slots__ = "out",
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"out": self.out}
-
-    def __init__(self):
-        # type: () -> None
-        self.out = SSAVal(self, "out", CYType())
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLoad(Op):
-    __slots__ = "RT", "RA", "offset", "mem"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "mem": self.mem}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT}
-
-    def __init__(self, RA, offset, mem, length=1):
-        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
-        self.RT = SSAVal(self, "RT", GPRRangeType(length))
-        self.RA = RA
-        self.offset = offset
-        self.mem = mem
-
-    def get_extra_interferences(self):
-        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        if self.RT.ty.length > 1:
-            yield self.RT, self.RA
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpStore(Op):
-    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"mem_out": self.mem_out}
-
-    def __init__(self, RS, RA, offset, mem_in):
-        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
-        self.RS = RS
-        self.RA = RA
-        self.offset = offset
-        self.mem_in = mem_in
-        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpFuncArg(Op):
-    __slots__ = "out",
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"out": self.out}
-
-    def __init__(self, ty):
-        # type: (RegType) -> None
-        self.out = SSAVal(self, "out", ty)
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpInputMem(Op):
-    __slots__ = "out",
-
-    def inputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {}
-
-    def outputs(self):
-        # type: () -> dict[str, SSAVal]
-        return {"out": self.out}
-
-    def __init__(self):
-        # type: () -> None
-        self.out = SSAVal(self, "out", GlobalMemType())
-
-
-def op_set_to_list(ops):
-    # type: (Iterable[Op]) -> list[Op]
-    worklists = [set()]  # type: list[set[Op]]
-    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
-    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
-    for op in ops:
-        input_count = 0
-        for val in op.inputs().values():
-            input_count += 1
-            input_vals_to_ops_map[val].add(op)
-        while len(worklists) <= input_count:
-            worklists.append(set())
-        ops_to_pending_input_count_map[op] = input_count
-        worklists[input_count].add(op)
-    retval = []  # type: list[Op]
-    ready_vals = set()  # type: set[SSAVal]
-    while len(worklists[0]) != 0:
-        writing_op = worklists[0].pop()
-        retval.append(writing_op)
-        for val in writing_op.outputs().values():
-            if val in ready_vals:
-                raise ValueError(f"multiple instructions must not write "
-                                 f"to the same SSA value: {val}")
-            ready_vals.add(val)
-            for reading_op in input_vals_to_ops_map[val]:
-                pending = ops_to_pending_input_count_map[reading_op]
-                worklists[pending].remove(reading_op)
-                pending -= 1
-                worklists[pending].add(reading_op)
-                ops_to_pending_input_count_map[reading_op] = pending
-    for worklist in worklists:
-        for op in worklist:
-            raise ValueError(f"instruction is part of a dependency loop or "
-                             f"its inputs are never written: {op}")
-    return retval
-
-
-@plain_data(unsafe_hash=True, order=True, frozen=True)
-class LiveInterval:
-    __slots__ = "first_write", "last_use"
-
-    def __init__(self, first_write, last_use=None):
-        # type: (int, int | None) -> None
-        if last_use is None:
-            last_use = first_write
-        if last_use < first_write:
-            raise ValueError("uses must be after first_write")
-        if first_write < 0 or last_use < 0:
-            raise ValueError("indexes must be nonnegative")
-        self.first_write = first_write
-        self.last_use = last_use
-
-    def overlaps(self, other):
-        # type: (LiveInterval) -> bool
-        if self.first_write == other.first_write:
-            return True
-        return self.last_use > other.first_write \
-            and other.last_use > self.first_write
-
-    def __add__(self, use):
-        # type: (int) -> LiveInterval
-        last_use = max(self.last_use, use)
-        return LiveInterval(first_write=self.first_write, last_use=last_use)
-
-    @property
-    def live_after_op_range(self):
-        """the range of op indexes where self is live immediately after the
-        Op at each index
-        """
-        return range(self.first_write, self.last_use)
-
-
-@final
-class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
-    def __init__(self, reg_set):
-        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
-        self.__items = {}  # type: dict[SSAVal[_RegT_co], int]
-        if isinstance(reg_set, SSAVal):
-            reg_set = [(reg_set, 0)]
-        for ssa_val, offset in reg_set:
-            if ssa_val in self.__items:
-                other = self.__items[ssa_val]
-                if offset != other:
-                    raise ValueError(
-                        f"can't merge register sets: conflicting offsets: "
-                        f"for {ssa_val}: {offset} != {other}")
-            else:
-                self.__items[ssa_val] = offset
-        first_item = None
-        for i in self.__items.items():
-            first_item = i
-            break
-        if first_item is None:
-            raise ValueError("can't have empty MergedRegs")
-        first_ssa_val, start = first_item
-        ty = first_ssa_val.ty
-        if isinstance(ty, GPRRangeType):
-            stop = start + ty.length
-            for ssa_val, offset in self.__items.items():
-                if not isinstance(ssa_val.ty, GPRRangeType):
-                    raise ValueError(f"can't merge incompatible types: "
-                                     f"{ssa_val.ty} and {ty}")
-                stop = max(stop, offset + ssa_val.ty.length)
-                start = min(start, offset)
-            ty = GPRRangeType(stop - start)
-        else:
-            stop = 1
-            for ssa_val, offset in self.__items.items():
-                if offset != 0:
-                    raise ValueError(f"can't have non-zero offset "
-                                     f"for {ssa_val.ty}")
-                if ty != ssa_val.ty:
-                    raise ValueError(f"can't merge incompatible types: "
-                                     f"{ssa_val.ty} and {ty}")
-        self.__start = start  # type: int
-        self.__stop = stop  # type: int
-        self.__ty = ty  # type: RegType
-        self.__hash = hash(frozenset(self.items()))
-
-    @staticmethod
-    def from_equality_constraint(constraint_sequence):
-        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
-        if len(constraint_sequence) == 1:
-            # any type allowed with len = 1
-            return MergedRegSet(constraint_sequence[0])
-        offset = 0
-        retval = []
-        for val in constraint_sequence:
-            if not isinstance(val.ty, GPRRangeType):
-                raise ValueError("equality constraint sequences must only "
-                                 "have SSAVal type GPRRangeType")
-            retval.append((val, offset))
-            offset += val.ty.length
-        return MergedRegSet(retval)
-
-    @property
-    def ty(self):
-        return self.__ty
-
-    @property
-    def stop(self):
-        return self.__stop
-
-    @property
-    def start(self):
-        return self.__start
-
-    @property
-    def range(self):
-        return range(self.__start, self.__stop)
-
-    def offset_by(self, amount):
-        # type: (int) -> MergedRegSet[_RegT_co]
-        return MergedRegSet((k, v + amount) for k, v in self.items())
-
-    def normalized(self):
-        # type: () -> MergedRegSet[_RegT_co]
-        return self.offset_by(-self.start)
-
-    def with_offset_to_match(self, target):
-        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
-        for ssa_val, offset in self.items():
-            if ssa_val in target:
-                return self.offset_by(target[ssa_val] - offset)
-        raise ValueError("can't change offset to match unrelated MergedRegSet")
-
-    def __getitem__(self, item):
-        # type: (SSAVal[_RegT_co]) -> int
-        return self.__items[item]
-
-    def __iter__(self):
-        return iter(self.__items)
-
-    def __len__(self):
-        return len(self.__items)
-
-    def __hash__(self):
-        return self.__hash
-
-    def __repr__(self):
-        return f"MergedRegSet({list(self.__items.items())})"
-
-
-@final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
-    def __init__(self, ops):
-        # type: (Iterable[Op]) -> None
-        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegT_co]]
-        for op in ops:
-            for val in (*op.inputs().values(), *op.outputs().values()):
-                if val not in merged_sets:
-                    merged_sets[val] = MergedRegSet(val)
-            for e in op.get_equality_constraints():
-                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
-                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
-                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
-                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
-                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
-                for val in full_set.keys():
-                    merged_sets[val] = full_set
-
-        self.__map = {k: v.normalized() for k, v in merged_sets.items()}
-
-    def __getitem__(self, key):
-        # type: (SSAVal) -> MergedRegSet
-        return self.__map[key]
-
-    def __iter__(self):
-        return iter(self.__map)
-
-    def __len__(self):
-        return len(self.__map)
-
-    def __repr__(self):
-        return f"MergedRegSets(data={self.__map})"
-
-
-@final
-class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
-    def __init__(self, ops):
-        # type: (list[Op]) -> None
-        self.__merged_reg_sets = MergedRegSets(ops)
-        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
-        for op_idx, op in enumerate(ops):
-            for val in op.inputs().values():
-                live_intervals[self.__merged_reg_sets[val]] += op_idx
-            for val in op.outputs().values():
-                reg_set = self.__merged_reg_sets[val]
-                if reg_set not in live_intervals:
-                    live_intervals[reg_set] = LiveInterval(op_idx)
-                else:
-                    live_intervals[reg_set] += op_idx
-        self.__live_intervals = live_intervals
-        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
-        live_after += (set() for _ in ops)
-        for reg_set, live_interval in self.__live_intervals.items():
-            for i in live_interval.live_after_op_range:
-                live_after[i].add(reg_set)
-        self.__live_after = [frozenset(i) for i in live_after]
-
-    @property
-    def merged_reg_sets(self):
-        return self.__merged_reg_sets
-
-    def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
-        return self.__live_intervals[key]
-
-    def __iter__(self):
-        return iter(self.__live_intervals)
-
-    def reg_sets_live_after(self, op_index):
-        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
-        return self.__live_after[op_index]
-
-    def __repr__(self):
-        reg_sets_live_after = dict(enumerate(self.__live_after))
-        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
-                f"merged_reg_sets={self.merged_reg_sets}, "
-                f"reg_sets_live_after={reg_sets_live_after})")
-
-
-@final
-class IGNode(Generic[_RegT_co]):
-    """ interference graph node """
-    __slots__ = "merged_reg_set", "edges", "reg"
-
-    def __init__(self, merged_reg_set, edges=(), reg=None):
-        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
-        self.merged_reg_set = merged_reg_set
-        self.edges = set(edges)
-        self.reg = reg
-
-    def add_edge(self, other):
-        # type: (IGNode) -> None
-        self.edges.add(other)
-        other.edges.add(self)
-
-    def __eq__(self, other):
-        # type: (object) -> bool
-        if isinstance(other, IGNode):
-            return self.merged_reg_set == other.merged_reg_set
-        return NotImplemented
-
-    def __hash__(self):
-        return hash(self.merged_reg_set)
-
-    def __repr__(self, nodes=None):
-        # type: (None | dict[IGNode, int]) -> str
-        if nodes is None:
-            nodes = {}
-        if self in nodes:
-            return f"<IGNode #{nodes[self]}>"
-        nodes[self] = len(nodes)
-        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
-        return (f"IGNode(#{nodes[self]}, "
-                f"merged_reg_set={self.merged_reg_set}, "
-                f"edges={edges}, "
-                f"reg={self.reg})")
-
-    @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return self.merged_reg_set.ty.reg_class
-
-    def reg_conflicts_with_neighbors(self, reg):
-        # type: (RegLoc) -> bool
-        for neighbor in self.edges:
-            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
-                return True
-        return False
-
-
-@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
-    def __init__(self, merged_reg_sets):
-        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
-        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
-
-    def __getitem__(self, key):
-        # type: (MergedRegSet[_RegT_co]) -> IGNode
-        return self.__nodes[key]
-
-    def __iter__(self):
-        return iter(self.__nodes)
-
-    def __repr__(self):
-        nodes = {}
-        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
-        nodes_text = ", ".join(nodes_text)
-        return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
-@plain_data()
-class AllocationFailed:
-    __slots__ = "node", "live_intervals", "interference_graph"
-
-    def __init__(self, node, live_intervals, interference_graph):
-        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
-        self.node = node
-        self.live_intervals = live_intervals
-        self.interference_graph = interference_graph
-
-
-def try_allocate_registers_without_spilling(ops):
-    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
-
-    live_intervals = LiveIntervals(ops)
-    merged_reg_sets = live_intervals.merged_reg_sets
-    interference_graph = InterferenceGraph(merged_reg_sets.values())
-    for op_idx, op in enumerate(ops):
-        reg_sets = live_intervals.reg_sets_live_after(op_idx)
-        for i, j in combinations(reg_sets, 2):
-            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
-                interference_graph[i].add_edge(interference_graph[j])
-        for i, j in op.get_extra_interferences():
-            i = merged_reg_sets[i]
-            j = merged_reg_sets[j]
-            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
-                interference_graph[i].add_edge(interference_graph[j])
-
-    nodes_remaining = set(interference_graph.values())
-
-    def local_colorability_score(node):
-        # type: (IGNode) -> int
-        """ returns a positive integer if node is locally colorable, returns
-        zero or a negative integer if node isn't known to be locally
-        colorable, the more negative the value, the less colorable
-        """
-        if node not in nodes_remaining:
-            raise ValueError()
-        retval = len(node.reg_class)
-        for neighbor in node.edges:
-            if neighbor in nodes_remaining:
-                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
-        return retval
-
-    node_stack = []  # type: list[IGNode]
-    while True:
-        best_node = None  # type: None | IGNode
-        best_score = 0
-        for node in nodes_remaining:
-            score = local_colorability_score(node)
-            if best_node is None or score > best_score:
-                best_node = node
-                best_score = score
-            if best_score > 0:
-                # it's locally colorable, no need to find a better one
-                break
-
-        if best_node is None:
-            break
-        node_stack.append(best_node)
-        nodes_remaining.remove(best_node)
-
-    retval = {}  # type: dict[SSAVal, RegLoc]
-
-    while len(node_stack) > 0:
-        node = node_stack.pop()
-        if node.reg is not None:
-            if node.reg_conflicts_with_neighbors(node.reg):
-                return AllocationFailed(node=node,
-                                        live_intervals=live_intervals,
-                                        interference_graph=interference_graph)
-        else:
-            # pick the first non-conflicting register in node.reg_class, since
-            # register classes are ordered from most preferred to least
-            # preferred register.
-            for reg in node.reg_class:
-                if not node.reg_conflicts_with_neighbors(reg):
-                    node.reg = reg
-                    break
-            if node.reg is None:
-                return AllocationFailed(node=node,
-                                        live_intervals=live_intervals,
-                                        interference_graph=interference_graph)
-
-        for ssa_val, offset in node.merged_reg_set.items():
-            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
-
-    return retval
-
-
-def allocate_registers(ops):
-    # type: (list[Op]) -> None
-    raise NotImplementedError
+from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed