WIP: refactor the register allocator to use the new IR (SSAVal/SSAUse-based compiler_ir2)
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 1 Nov 2022 06:26:53 +0000 (23:26 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 1 Nov 2022 06:26:53 +0000 (23:26 -0700)
src/bigint_presentation_code/_tests/test_compiler_ir2.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py [new file with mode: 0644]

index 74c38e90e5fc0339455c482e37be2a5a904fd6d2..4aa8e3eda5f31cdceaa1751d431ab6178f162838 100644 (file)
@@ -39,30 +39,32 @@ class TestCompilerIR(unittest.TestCase):
             "Op(kind=OpKind.FuncArgR3, "
             "inputs=[], "
             "immediates=[], "
-            "outputs=(<arg#0: <I64>>,), name='arg')",
+            "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<vl#0: <VL_MAXVL>>,), name='vl')",
+            "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
             "Op(kind=OpKind.SvLd, "
-            "inputs=[<arg#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+            "inputs=[<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
-            "outputs=(<ld#0: <I64*32>>,), name='ld')",
+            "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
             "Op(kind=OpKind.SvLI, "
-            "inputs=[<vl#0: <VL_MAXVL>>], "
+            "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
-            "outputs=(<li#0: <I64*32>>,), name='li')",
+            "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
             "Op(kind=OpKind.SetCA, "
             "inputs=[], "
             "immediates=[], "
-            "outputs=(<ca#0: <CA>>,), name='ca')",
+            "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
             "Op(kind=OpKind.SvAddE, "
-            "inputs=[<ld#0: <I64*32>>, <li#0: <I64*32>>, <ca#0: <CA>>, "
-            "<vl#0: <VL_MAXVL>>], "
+            "inputs=[<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+            "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<add#0: <I64*32>>, <add#1: <CA>>), name='add')",
+            "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+            "name='add')",
             "Op(kind=OpKind.SvStd, "
-            "inputs=[<add#0: <I64*32>>, <arg#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+            "inputs=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
             "outputs=(), name='st')",
         ])
@@ -150,90 +152,93 @@ class TestCompilerIR(unittest.TestCase):
             "Op(kind=OpKind.FuncArgR3, "
             "inputs=[], "
             "immediates=[], "
-            "outputs=(<arg#0: <I64>>,), name='arg')",
+            "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
             "Op(kind=OpKind.CopyFromReg, "
-            "inputs=[<arg#0: <I64>>], "
+            "inputs=[<arg.outputs[0]: <I64>>], "
             "immediates=[], "
-            "outputs=(<2#0: <I64>>,), name='2')",
+            "outputs=(<2.outputs[0]: <I64>>,), name='2')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<vl#0: <VL_MAXVL>>,), name='vl')",
+            "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
             "Op(kind=OpKind.CopyToReg, "
-            "inputs=[<2#0: <I64>>], "
+            "inputs=[<2.outputs[0]: <I64>>], "
             "immediates=[], "
-            "outputs=(<3#0: <I64>>,), name='3')",
+            "outputs=(<3.outputs[0]: <I64>>,), name='3')",
             "Op(kind=OpKind.SvLd, "
-            "inputs=[<3#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+            "inputs=[<3.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
-            "outputs=(<ld#0: <I64*32>>,), name='ld')",
+            "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<4#0: <VL_MAXVL>>,), name='4')",
+            "outputs=(<4.outputs[0]: <VL_MAXVL>>,), name='4')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<ld#0: <I64*32>>, <4#0: <VL_MAXVL>>], "
+            "inputs=[<ld.outputs[0]: <I64*32>>, <4.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<5#0: <I64*32>>,), name='5')",
+            "outputs=(<5.outputs[0]: <I64*32>>,), name='5')",
             "Op(kind=OpKind.SvLI, "
-            "inputs=[<vl#0: <VL_MAXVL>>], "
+            "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
-            "outputs=(<li#0: <I64*32>>,), name='li')",
+            "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<6#0: <VL_MAXVL>>,), name='6')",
+            "outputs=(<6.outputs[0]: <VL_MAXVL>>,), name='6')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<li#0: <I64*32>>, <6#0: <VL_MAXVL>>], "
+            "inputs=[<li.outputs[0]: <I64*32>>, <6.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<7#0: <I64*32>>,), name='7')",
+            "outputs=(<7.outputs[0]: <I64*32>>,), name='7')",
             "Op(kind=OpKind.SetCA, "
             "inputs=[], "
             "immediates=[], "
-            "outputs=(<ca#0: <CA>>,), name='ca')",
+            "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<8#0: <VL_MAXVL>>,), name='8')",
+            "outputs=(<8.outputs[0]: <VL_MAXVL>>,), name='8')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<5#0: <I64*32>>, <8#0: <VL_MAXVL>>], "
+            "inputs=[<5.outputs[0]: <I64*32>>, <8.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<9#0: <I64*32>>,), name='9')",
+            "outputs=(<9.outputs[0]: <I64*32>>,), name='9')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<10#0: <VL_MAXVL>>,), name='10')",
+            "outputs=(<10.outputs[0]: <VL_MAXVL>>,), name='10')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<7#0: <I64*32>>, <10#0: <VL_MAXVL>>], "
+            "inputs=[<7.outputs[0]: <I64*32>>, <10.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<11#0: <I64*32>>,), name='11')",
+            "outputs=(<11.outputs[0]: <I64*32>>,), name='11')",
             "Op(kind=OpKind.SvAddE, "
-            "inputs=[<9#0: <I64*32>>, <11#0: <I64*32>>, <ca#0: <CA>>, "
-            "<vl#0: <VL_MAXVL>>], "
+            "inputs=[<9.outputs[0]: <I64*32>>, <11.outputs[0]: <I64*32>>, "
+            "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<add#0: <I64*32>>, <add#1: <CA>>), name='add')",
+            "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+            "name='add')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<12#0: <VL_MAXVL>>,), name='12')",
+            "outputs=(<12.outputs[0]: <VL_MAXVL>>,), name='12')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<add#0: <I64*32>>, <12#0: <VL_MAXVL>>], "
+            "inputs=[<add.outputs[0]: <I64*32>>, "
+            "<12.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<13#0: <I64*32>>,), name='13')",
+            "outputs=(<13.outputs[0]: <I64*32>>,), name='13')",
             "Op(kind=OpKind.SetVLI, "
             "inputs=[], "
             "immediates=[32], "
-            "outputs=(<14#0: <VL_MAXVL>>,), name='14')",
+            "outputs=(<14.outputs[0]: <VL_MAXVL>>,), name='14')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<13#0: <I64*32>>, <14#0: <VL_MAXVL>>], "
+            "inputs=[<13.outputs[0]: <I64*32>>, <14.outputs[0]: <VL_MAXVL>>], "
             "immediates=[], "
-            "outputs=(<15#0: <I64*32>>,), name='15')",
+            "outputs=(<15.outputs[0]: <I64*32>>,), name='15')",
             "Op(kind=OpKind.CopyToReg, "
-            "inputs=[<2#0: <I64>>], "
+            "inputs=[<2.outputs[0]: <I64>>], "
             "immediates=[], "
-            "outputs=(<16#0: <I64>>,), name='16')",
+            "outputs=(<16.outputs[0]: <I64>>,), name='16')",
             "Op(kind=OpKind.SvStd, "
-            "inputs=[<15#0: <I64*32>>, <16#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+            "inputs=[<15.outputs[0]: <I64*32>>, <16.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
             "immediates=[0], "
             "outputs=(), name='st')",
         ])
@@ -471,16 +476,17 @@ class TestCompilerIR(unittest.TestCase):
                     size_in_bytes=GPR_SIZE_IN_BYTES)
         self.assertEqual(
             repr(state),
-            "PreRASimState(ssa_vals={<arg#0: <I64>>: (0x100,)}, memory={\n"
+            "PreRASimState(ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)}, "
+            "memory={\n"
             "0x00100: <0xffffffffffffffff>,\n"
             "0x00108: <0xabcdef0123456789>})")
         fn.pre_ra_sim(state)
         self.assertEqual(
             repr(state),
             "PreRASimState(ssa_vals={\n"
-            "<arg#0: <I64>>: (0x100,),\n"
-            "<vl#0: <VL_MAXVL>>: (0x20,),\n"
-            "<ld#0: <I64*32>>: (\n"
+            "<arg.outputs[0]: <I64>>: (0x100,),\n"
+            "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
+            "<ld.outputs[0]: <I64*32>>: (\n"
             "    0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
@@ -489,7 +495,7 @@ class TestCompilerIR(unittest.TestCase):
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0),\n"
-            "<li#0: <I64*32>>: (\n"
+            "<li.outputs[0]: <I64*32>>: (\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
@@ -498,8 +504,8 @@ class TestCompilerIR(unittest.TestCase):
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0),\n"
-            "<ca#0: <CA>>: (0x1,),\n"
-            "<add#0: <I64*32>>: (\n"
+            "<ca.outputs[0]: <CA>>: (0x1,),\n"
+            "<add.outputs[0]: <I64*32>>: (\n"
             "    0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
@@ -508,7 +514,7 @@ class TestCompilerIR(unittest.TestCase):
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0,\n"
             "    0x0, 0x0, 0x0, 0x0),\n"
-            "<add#1: <CA>>: (0x0,),\n"
+            "<add.outputs[1]: <CA>>: (0x0,),\n"
             "}, memory={\n"
             "0x00100: <0x0000000000000000>,\n"
             "0x00108: <0xabcdef012345678a>,\n"
index 7109e776d23ce08fc70d1ec2e44b01cde4e77d44..8e365096dee70ed15da4f32303cd611d6c1ad275 100644 (file)
@@ -1,5 +1,5 @@
 import enum
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 from enum import Enum, unique
 from functools import lru_cache
 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
@@ -10,7 +10,7 @@ from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
 from bigint_presentation_code.type_util import Self, assert_never, final
-from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
+from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
 
 
 @final
@@ -103,6 +103,33 @@ class Fn:
                     assert_never(out.ty.base_ty)
 
 
+@plain_data(frozen=True, eq=False)
+@final
+class FnWithUses:
+    __slots__ = "fn", "uses"
+
+    def __init__(self, fn):
+        # type: (Fn) -> None
+        self.fn = fn
+        retval = {}  # type: dict[SSAVal, OSet[SSAUse]]
+        for op in fn.ops:
+            for idx, inp in enumerate(op.inputs):
+                retval[inp].add(SSAUse(op, idx))
+            for out in op.outputs:
+                retval[out] = OSet()
+        self.uses = FMap((k, OFSet(v)) for k, v in retval.items())
+
+    def __eq__(self, other):
+        # type: (FnWithUses | Any) -> bool
+        if isinstance(other, FnWithUses):
+            return self.fn == other.fn
+        return NotImplemented
+
+    def __hash__(self):
+        # type: () -> int
+        return hash(self.fn)
+
+
 @unique
 @final
 class BaseTy(Enum):
@@ -1074,41 +1101,93 @@ class OpKind(Enum):
     _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
 
 
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+class SSAValOrUse(metaclass=ABCMeta):
+    __slots__ = "op",
+
+    def __init__(self, op):
+        # type: (Op) -> None
+        self.op = op
+
+    @abstractmethod
+    def __repr__(self):
+        # type: () -> str
+        ...
+
+    @property
+    @abstractmethod
+    def defining_descriptor(self):
+        # type: () -> OperandDesc
+        ...
+
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        return self.defining_descriptor.ty
+
+    @cached_property
+    def ty_before_spread(self):
+        # type: () -> Ty
+        return self.defining_descriptor.ty_before_spread
+
+    @property
+    def base_ty(self):
+        # type: () -> BaseTy
+        return self.ty_before_spread.base_ty
+
+
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
-class SSAVal:
-    __slots__ = "op", "output_idx"
+class SSAVal(SSAValOrUse):
+    __slots__ = "output_idx",
 
     def __init__(self, op, output_idx):
         # type: (Op, int) -> None
-        self.op = op
+        super().__init__(op)
         if output_idx < 0 or output_idx >= len(op.properties.outputs):
             raise ValueError("invalid output_idx")
         self.output_idx = output_idx
 
     def __repr__(self):
         # type: () -> str
-        return f"<{self.op.name}#{self.output_idx}: {self.ty}>"
+        return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>"
+
+    @cached_property
+    def def_loc_set_before_spread(self):
+        # type: () -> LocSet
+        return self.defining_descriptor.loc_set_before_spread
 
     @cached_property
     def defining_descriptor(self):
         # type: () -> OperandDesc
         return self.op.properties.outputs[self.output_idx]
 
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAUse(SSAValOrUse):
+    __slots__ = "input_idx",
+
+    def __init__(self, op, input_idx):
+        # type: (Op, int) -> None
+        super().__init__(op)
+        self.input_idx = input_idx
+        if input_idx < 0 or input_idx >= len(op.inputs):
+            raise ValueError("input_idx out of range")
+
     @cached_property
-    def loc_set_before_spread(self):
+    def use_loc_set_before_spread(self):
         # type: () -> LocSet
         return self.defining_descriptor.loc_set_before_spread
 
     @cached_property
-    def ty(self):
-        # type: () -> Ty
-        return self.defining_descriptor.ty
+    def defining_descriptor(self):
+        # type: () -> OperandDesc
+        return self.op.properties.inputs[self.input_idx]
 
-    @cached_property
-    def ty_before_spread(self):
-        # type: () -> Ty
-        return self.defining_descriptor.ty_before_spread
+    def __repr__(self):
+        # type: () -> str
+        return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>"
 
 
 _T = TypeVar("_T")
@@ -1135,6 +1214,10 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
         self._verify_write_with_desc(idx, item, desc)
         return idx
 
+    def _on_set(self, idx, new_item, old_item):
+        # type: (int, _T, _T | None) -> None
+        pass
+
     @abstractmethod
     def _get_descriptors(self):
         # type: () -> tuple[_Desc, ...]
@@ -1212,6 +1295,10 @@ class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
             raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                              f"corresponding input's type {desc.ty!r}")
 
+    def _on_set(self, idx, new_item, old_item):
+        # type: (int, SSAVal, SSAVal | None) -> None
+        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
+
     def __init__(self, items, op):
         # type: (Iterable[SSAVal], Op) -> None
         if hasattr(op, "inputs"):
diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py
new file mode 100644 (file)
index 0000000..962a021
--- /dev/null
@@ -0,0 +1,471 @@
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
+from itertools import combinations
+from functools import reduce
+from typing import Generic, Iterable, Mapping
+from cached_property import cached_property
+import operator
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir2 import (
+    Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses)
+from bigint_presentation_code.type_util import final, Self
+from bigint_presentation_code.util import OFSet, OSet, FMap
+
+
+@plain_data(unsafe_hash=True, order=True, frozen=True)
+class LiveInterval:
+    __slots__ = "first_write", "last_use"
+
+    def __init__(self, first_write, last_use=None):
+        # type: (int, int | None) -> None
+        if last_use is None:
+            last_use = first_write
+        if last_use < first_write:
+            raise ValueError("uses must be after first_write")
+        if first_write < 0 or last_use < 0:
+            raise ValueError("indexes must be nonnegative")
+        self.first_write = first_write
+        self.last_use = last_use
+
+    def overlaps(self, other):
+        # type: (LiveInterval) -> bool
+        if self.first_write == other.first_write:
+            return True
+        return self.last_use > other.first_write \
+            and other.last_use > self.first_write
+
+    def __add__(self, use):
+        # type: (int) -> LiveInterval
+        last_use = max(self.last_use, use)
+        return LiveInterval(first_write=self.first_write, last_use=last_use)
+
+    @property
+    def live_after_op_range(self):
+        """the range of op indexes where self is live immediately after the
+        Op at each index
+        """
+        return range(self.first_write, self.last_use)
+
+
+class BadMergedSSAVal(ValueError):
+    pass
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class MergedSSAVal:
+    """a set of `SSAVal`s along with their offsets, all register allocated as
+    a single unit.
+
+    Definition of the term `offset` for this class:
+
+    Let `locs[x]` be the `Loc` that `x` is assigned to after register
+    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
+    for each `SSAVal` `ssa_val` in `msv` is defined as:
+
+    ```
+    msv.ssa_val_offsets[ssa_val] = (msv.offset
+                                    + locs[ssa_val].start - locs[msv].start)
+    ```
+
+    Example:
+    ```
+    v1.ty == <I64*4>
+    v2.ty == <I64*2>
+    v3.ty == <I64>
+    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
+    msv.ty == <I64*6>
+    ```
+    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
+    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
+    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
+    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
+    """
+    __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"
+
+    def __init__(self, fn_with_uses, ssa_val_offsets):
+        # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
+        self.fn_with_uses = fn_with_uses
+        if isinstance(ssa_val_offsets, SSAVal):
+            ssa_val_offsets = {ssa_val_offsets: 0}
+        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
+        base_ty = None
+        for ssa_val in self.ssa_val_offsets.keys():
+            base_ty = ssa_val.base_ty
+            break
+        if base_ty is None:
+            raise BadMergedSSAVal("MergedSSAVal can't be empty")
+        self.base_ty = base_ty  # type: BaseTy
+        # self.ty checks for mismatched base_ty
+        reg_len = self.ty.reg_len
+        loc_set = None  # type: None | LocSet
+        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
+            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0
+
+            def locs():
+                # type: () -> Iterable[Loc]
+                for loc in ssa_val.def_loc_set_before_spread:
+                    disallowed_by_use = False
+                    for use in fn_with_uses.uses[ssa_val]:
+                        use_spread_idx = \
+                            use.defining_descriptor.spread_index or 0
+                        # calculate the start for the use's Loc before spread
+                        # e.g. if the def's Loc before spread starts at r6
+                        # and the def's spread_index is 5
+                        # and the use's spread_index is 3
+                        # then the use's Loc before spread starts at r8
+                        # because 8 == 6 + 5 - 3
+                        start = loc.start + def_spread_idx - use_spread_idx
+                        use_loc = Loc.try_make(
+                            loc.kind, start=start,
+                            reg_len=use.ty_before_spread.reg_len)
+                        if (use_loc is None or
+                                use_loc not in use.use_loc_set_before_spread):
+                            disallowed_by_use = True
+                            break
+                    if disallowed_by_use:
+                        continue
+                    # FIXME: add spread consistency check
+                    start = loc.start - cur_offset + self.offset
+                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
+                    if loc is not None and (loc_set is None or loc in loc_set):
+                        yield loc
+            loc_set = LocSet(locs())
+        assert loc_set is not None, "already checked that self isn't empty"
+        if loc_set.ty is None:
+            raise BadMergedSSAVal("there are no valid Locs left")
+        assert loc_set.ty == self.ty, "logic error somewhere"
+        self.loc_set = loc_set  # type: LocSet
+
+    @cached_property
+    def offset(self):
+        # type: () -> int
+        return min(self.ssa_val_offsets_before_spread.values())
+
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        reg_len = 0
+        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
+            cur_ty = ssa_val.ty_before_spread
+            if self.base_ty != cur_ty.base_ty:
+                raise BadMergedSSAVal(
+                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
+            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
+        return Ty(base_ty=self.base_ty, reg_len=reg_len)
+
+    @cached_property
+    def ssa_val_offsets_before_spread(self):
+        # type: () -> FMap[SSAVal, int]
+        retval = {}  # type: dict[SSAVal, int]
+        for ssa_val, offset in self.ssa_val_offsets.items():
+            offset_before_spread = offset
+            spread_index = ssa_val.defining_descriptor.spread_index
+            if spread_index is not None:
+                assert ssa_val.ty.reg_len == 1, (
+                    "this function assumes spreading always converts a vector "
+                    "to a contiguous sequence of scalars, if that's changed "
+                    "in the future, then this function needs to be adjusted")
+                offset_before_spread -= spread_index
+            retval[ssa_val] = offset_before_spread
+        return FMap(retval)
+
+    def offset_by(self, amount):
+        # type: (int) -> MergedSSAVal
+        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
+        return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)
+
+    def normalized(self):
+        # type: () -> MergedSSAVal
+        return self.offset_by(-self.offset)
+
+    def with_offset_to_match(self, target):
+        # type: (MergedSSAVal) -> MergedSSAVal
+        for ssa_val, offset in self.ssa_val_offsets.items():
+            if ssa_val in target.ssa_val_offsets:
+                return self.offset_by(target.ssa_val_offsets[ssa_val] - offset)
+        raise ValueError("can't change offset to match unrelated MergedSSAVal")
+
+
+@final
+class MergedSSAVals(OFSet[MergedSSAVal]):
+    def __init__(self, merged_ssa_vals=()):
+        # type: (Iterable[MergedSSAVal]) -> None
+        super().__init__(merged_ssa_vals)
+        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
+        for merged_ssa_val in self:
+            for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
+                if ssa_val in merge_map:
+                    raise ValueError(
+                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
+                        f"{merged_ssa_val} and {merge_map[ssa_val]}")
+                merge_map[ssa_val] = merged_ssa_val
+        self.__merge_map = FMap(merge_map)
+
+    @cached_property
+    def merge_map(self):
+        # type: () -> FMap[SSAVal, MergedSSAVal]
+        return self.__merge_map
+
+# FIXME: work on code from here
+
+    @staticmethod
+    def minimally_merged(fn_with_uses):
+        # type: (FnWithUses) -> MergedSSAVals
+        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
+        for op in fn_with_uses.fn.ops:
+            for fn
+            for val in (*op.inputs().values(), *op.outputs().values()):
+                if val not in merged_sets:
+                    merged_sets[val] = MergedRegSet(val)
+            for e in op.get_equality_constraints():
+                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
+                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
+                items = []  # type: list[tuple[SSAVal, int]]
+                for i in e.lhs:
+                    s = merged_sets[i].with_offset_to_match(lhs_set)
+                    items.extend(s.items())
+                for i in e.rhs:
+                    s = merged_sets[i].with_offset_to_match(rhs_set)
+                    items.extend(s.items())
+                full_set = MergedRegSet(items)
+                for val in full_set.keys():
+                    merged_sets[val] = full_set
+
+        self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+
+
+@final
+class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
+    def __init__(self, ops):
+        # type: (list[Op]) -> None
+        self.__merged_reg_sets = MergedRegSets(ops)
+        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
+        for op_idx, op in enumerate(ops):
+            for val in op.inputs().values():
+                live_intervals[self.__merged_reg_sets[val]] += op_idx
+            for val in op.outputs().values():
+                reg_set = self.__merged_reg_sets[val]
+                if reg_set not in live_intervals:
+                    live_intervals[reg_set] = LiveInterval(op_idx)
+                else:
+                    live_intervals[reg_set] += op_idx
+        self.__live_intervals = live_intervals
+        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
+        live_after += (OSet() for _ in ops)
+        for reg_set, live_interval in self.__live_intervals.items():
+            for i in live_interval.live_after_op_range:
+                live_after[i].add(reg_set)
+        self.__live_after = [OFSet(i) for i in live_after]
+
+    @property
+    def merged_reg_sets(self):
+        return self.__merged_reg_sets
+
+    def __getitem__(self, key):
+        # type: (MergedRegSet[_RegType]) -> LiveInterval
+        return self.__live_intervals[key]
+
+    def __iter__(self):
+        return iter(self.__live_intervals)
+
+    def __len__(self):
+        return len(self.__live_intervals)
+
+    def reg_sets_live_after(self, op_index):
+        # type: (int) -> OFSet[MergedRegSet[_RegType]]
+        return self.__live_after[op_index]
+
+    def __repr__(self):
+        reg_sets_live_after = dict(enumerate(self.__live_after))
+        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
+                f"merged_reg_sets={self.merged_reg_sets}, "
+                f"reg_sets_live_after={reg_sets_live_after})")
+
+
+@final
+class IGNode(Generic[_RegType]):
+    """ interference graph node """
+    __slots__ = "merged_reg_set", "edges", "reg"
+
+    def __init__(self, merged_reg_set, edges=(), reg=None):
+        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
+        self.merged_reg_set = merged_reg_set
+        self.edges = OSet(edges)
+        self.reg = reg
+
+    def add_edge(self, other):
+        # type: (IGNode) -> None
+        self.edges.add(other)
+        other.edges.add(self)
+
+    def __eq__(self, other):
+        # type: (object) -> bool
+        if isinstance(other, IGNode):
+            return self.merged_reg_set == other.merged_reg_set
+        return NotImplemented
+
+    def __hash__(self):
+        return hash(self.merged_reg_set)
+
+    def __repr__(self, nodes=None):
+        # type: (None | dict[IGNode, int]) -> str
+        if nodes is None:
+            nodes = {}
+        if self in nodes:
+            return f"<IGNode #{nodes[self]}>"
+        nodes[self] = len(nodes)
+        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
+        return (f"IGNode(#{nodes[self]}, "
+                f"merged_reg_set={self.merged_reg_set}, "
+                f"edges={edges}, "
+                f"reg={self.reg})")
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return self.merged_reg_set.ty.reg_class
+
+    def reg_conflicts_with_neighbors(self, reg):
+        # type: (RegLoc) -> bool
+        for neighbor in self.edges:
+            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+                return True
+        return False
+
+
+@final
+class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
+    def __init__(self, merged_reg_sets):
+        # type: (Iterable[MergedRegSet[_RegType]]) -> None
+        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
+
+    def __getitem__(self, key):
+        # type: (MergedRegSet[_RegType]) -> IGNode
+        return self.__nodes[key]
+
+    def __iter__(self):
+        return iter(self.__nodes)
+
+    def __len__(self):
+        return len(self.__nodes)
+
+    def __repr__(self):
+        nodes = {}
+        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
+        nodes_text = ", ".join(nodes_text)
+        return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
+
+@plain_data()
+class AllocationFailed:
+    __slots__ = "node", "live_intervals", "interference_graph"
+
+    def __init__(self, node, live_intervals, interference_graph):
+        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+        self.node = node
+        self.live_intervals = live_intervals
+        self.interference_graph = interference_graph
+
+
+class AllocationFailedError(Exception):
+    def __init__(self, msg, allocation_failed):
+        # type: (str, AllocationFailed) -> None
+        super().__init__(msg, allocation_failed)
+        self.allocation_failed = allocation_failed
+
+
+def try_allocate_registers_without_spilling(ops):
+    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
+
+    live_intervals = LiveIntervals(ops)
+    merged_reg_sets = live_intervals.merged_reg_sets
+    interference_graph = InterferenceGraph(merged_reg_sets.values())
+    for op_idx, op in enumerate(ops):
+        reg_sets = live_intervals.reg_sets_live_after(op_idx)
+        for i, j in combinations(reg_sets, 2):
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
+        for i, j in op.get_extra_interferences():
+            i = merged_reg_sets[i]
+            j = merged_reg_sets[j]
+            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+                interference_graph[i].add_edge(interference_graph[j])
+
+    nodes_remaining = OSet(interference_graph.values())
+
+    def local_colorability_score(node):
+        # type: (IGNode) -> int
+        """ returns a positive integer if node is locally colorable, returns
+        zero or a negative integer if node isn't known to be locally
+        colorable, the more negative the value, the less colorable
+        """
+        if node not in nodes_remaining:
+            raise ValueError()
+        retval = len(node.reg_class)
+        for neighbor in node.edges:
+            if neighbor in nodes_remaining:
+                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+        return retval
+
+    node_stack = []  # type: list[IGNode]
+    while True:
+        best_node = None  # type: None | IGNode
+        best_score = 0
+        for node in nodes_remaining:
+            score = local_colorability_score(node)
+            if best_node is None or score > best_score:
+                best_node = node
+                best_score = score
+            if best_score > 0:
+                # it's locally colorable, no need to find a better one
+                break
+
+        if best_node is None:
+            break
+        node_stack.append(best_node)
+        nodes_remaining.remove(best_node)
+
+    retval = {}  # type: dict[SSAVal, RegLoc]
+
+    while len(node_stack) > 0:
+        node = node_stack.pop()
+        if node.reg is not None:
+            if node.reg_conflicts_with_neighbors(node.reg):
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+        else:
+            # pick the first non-conflicting register in node.reg_class, since
+            # register classes are ordered from most preferred to least
+            # preferred register.
+            for reg in node.reg_class:
+                if not node.reg_conflicts_with_neighbors(reg):
+                    node.reg = reg
+                    break
+            if node.reg is None:
+                return AllocationFailed(node=node,
+                                        live_intervals=live_intervals,
+                                        interference_graph=interference_graph)
+
+        for ssa_val, offset in node.merged_reg_set.items():
+            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
+
+    return retval
+
+
+def allocate_registers(ops):
+    # type: (list[Op]) -> dict[SSAVal, RegLoc]
+    retval = try_allocate_registers_without_spilling(ops)
+    if isinstance(retval, AllocationFailed):
+        # TODO: implement spilling
+        raise AllocationFailedError(
+            "spilling required but not yet implemented", retval)
+    return retval