working on generating output assembly
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 18 Oct 2022 07:36:02 +0000 (00:36 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 18 Oct 2022 07:36:02 +0000 (00:36 -0700)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/matrix.py
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/test_compiler_ir.py
src/bigint_presentation_code/test_register_allocator.py

index aa37fa48eaa8da590481956ee95a028e5d85e51e..f264b018baeaa14ed83e2e431a47da5b39a2b944 100644 (file)
@@ -1,14 +1,16 @@
 """
 Compiler IR for Toom-Cook algorithm generator for SVP64
+
+This assumes VL != 0 throughout.
 """
 
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, EnumMeta, unique
 from functools import lru_cache
-from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
+from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type,
+                    TypeVar, cast)
 
-from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
 from bigint_presentation_code.ordered_set import OFSet, OSet
@@ -126,7 +128,7 @@ SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
 @final
 @unique
 class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
-    CY = "CY"
+    CA = "CA"
 
     def conflicts(self, other):
         # type: (RegLoc) -> bool
@@ -150,6 +152,19 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
         return False
 
 
+@final
+@unique
+class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
+    VL_MAXVL = "VL_MAXVL"
+    """VL and MAXVL"""
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, VL):
+            return self == other
+        return False
+
+
 @final
 class RegClass(OFSet[RegLoc]):
     """ an ordered set of registers.
@@ -180,13 +195,15 @@ class RegType(metaclass=ABCMeta):
 
 
 _RegType = TypeVar("_RegType", bound=RegType)
+_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
 
 
 @plain_data(frozen=True, eq=False)
+@final
 class GPRRangeType(RegType):
     __slots__ = "length",
 
-    def __init__(self, length):
+    def __init__(self, length=1):
         # type: (int) -> None
         if length < 1 or length > GPR_COUNT:
             raise ValueError("invalid length")
@@ -205,6 +222,7 @@ class GPRRangeType(RegType):
         return RegClass(regs)
 
     @property
+    @final
     def reg_class(self):
         # type: () -> RegClass
         return GPRRangeType.__get_reg_class(self.length)
@@ -220,15 +238,8 @@ class GPRRangeType(RegType):
         return hash(self.length)
 
 
-@plain_data(frozen=True, eq=False)
-@final
-class GPRType(GPRRangeType):
-    __slots__ = ()
-
-    def __init__(self, length=1):
-        if length != 1:
-            raise ValueError("length must be 1")
-        super().__init__(length=1)
+GPRType = GPRRangeType
+"""a length=1 GPRRangeType"""
 
 
 @plain_data(frozen=True, unsafe_hash=True)
@@ -253,13 +264,13 @@ class FixedGPRRangeType(RegType):
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class CYType(RegType):
+class CAType(RegType):
     __slots__ = ()
 
     @property
     def reg_class(self):
         # type: () -> RegClass
-        return RegClass([XERBit.CY])
+        return RegClass([XERBit.CA])
 
 
 @plain_data(frozen=True, unsafe_hash=True)
@@ -273,6 +284,39 @@ class GlobalMemType(RegType):
         return RegClass([GlobalMem.GlobalMem])
 
 
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class KnownVLType(RegType):
+    __slots__ = "length",
+
+    def __init__(self, length):
+        # type: (int) -> None
+        if not (0 < length <= 64):
+            raise ValueError("invalid VL value")
+        self.length = length
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return RegClass([VL.VL_MAXVL])
+
+
+def assert_vl_is(vl, expected_vl):
+    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
+    if vl is None:
+        vl = 1
+    elif isinstance(vl, SSAVal):
+        vl = vl.ty.length
+    elif isinstance(vl, KnownVLType):
+        vl = vl.length
+    if vl != expected_vl:
+        raise ValueError(
+            f"wrong VL: expected {expected_vl} got {vl}")
+
+
+STACK_SLOT_SIZE = 8
+
+
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class StackSlot(RegLoc):
@@ -289,6 +333,10 @@ class StackSlot(RegLoc):
     def stop_slot(self):
         return self.start_slot + self.length_in_slots
 
+    @property
+    def start_byte(self):
+        return self.start_slot * STACK_SLOT_SIZE
+
     def conflicts(self, other):
         # type: (RegLoc) -> bool
         if isinstance(other, StackSlot):
@@ -386,6 +434,11 @@ class SSAVal(Generic[_RegType]):
         return f"SSAVal({fields_str})"
 
 
+SSAGPRRange = SSAVal[GPRRangeType]
+SSAGPR = SSAVal[GPRType]
+SSAKnownVL = SSAVal[KnownVLType]
+
+
 @final
 @plain_data(unsafe_hash=True, frozen=True)
 class EqualityConstraint:
@@ -424,6 +477,177 @@ class _NotSet:
 _NOT_SET = _NotSet()
 
 
+@plain_data(frozen=True, unsafe_hash=True)
+class AsmTemplateSegment(Generic[_RegType], metaclass=ABCMeta):
+    __slots__ = "ssa_val",
+
+    def __init__(self, ssa_val):
+        # type: (SSAVal[_RegType]) -> None
+        self.ssa_val = ssa_val
+
+    def render(self, regs):
+        # type: (dict[SSAVal, RegLoc]) -> str
+        return self._render(regs[self.ssa_val])
+
+    @abstractmethod
+    def _render(self, reg):
+        # type: (RegLoc) -> str
+        ...
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ATSGPR(AsmTemplateSegment[GPRRangeType]):
+    __slots__ = "offset",
+
+    def __init__(self, ssa_val, offset=0):
+        # type: (SSAGPRRange, int) -> None
+        super().__init__(ssa_val)
+        self.offset = offset
+
+    def _render(self, reg):
+        # type: (RegLoc) -> str
+        if not isinstance(reg, GPRRange):
+            raise TypeError()
+        return str(reg.start + self.offset)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ATSStackSlot(AsmTemplateSegment[StackSlotType]):
+    __slots__ = ()
+
+    def _render(self, reg):
+        # type: (RegLoc) -> str
+        if not isinstance(reg, StackSlot):
+            raise TypeError()
+        return f"{reg.start_slot}(1)"
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ATSCopyGPRRange(AsmTemplateSegment["GPRRangeType | FixedGPRRangeType"]):
+    __slots__ = "src_ssa_val",
+
+    def __init__(self, ssa_val, src_ssa_val):
+        # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
+        self.ssa_val = ssa_val
+        self.src_ssa_val = src_ssa_val
+
+    def render(self, regs):
+        # type: (dict[SSAVal, RegLoc]) -> str
+        src = regs[self.src_ssa_val]
+        dest = regs[self.ssa_val]
+        if not isinstance(dest, GPRRange):
+            raise TypeError()
+        if not isinstance(src, GPRRange):
+            raise TypeError()
+        if src.length != dest.length:
+            raise ValueError()
+        if src == dest:
+            return ""
+        mrr = ""
+        sv_ = "sv."
+        if src.length == 1:
+            sv_ = ""
+        elif src.conflicts(dest) and src.start > dest.start:
+            mrr = "/mrr"
+        return f"{sv_}or{mrr} *{dest.start}, *{src.start}, *{src.start}\n"
+
+    def _render(self, reg):
+        # type: (RegLoc) -> str
+        raise TypeError("must call self.render")
+
+
+@final
+class AsmTemplate(Sequence["str | AsmTemplateSegment"]):
+    @staticmethod
+    def __process_segments(segments):
+        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
+        for i in segments:
+            if isinstance(i, AsmTemplate):
+                yield from i
+            else:
+                yield i
+
+    def __init__(self, segments=()):
+        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
+        self.__segments = tuple(self.__process_segments(segments))
+
+    def __getitem__(self, index):
+        # type: (int) -> str | AsmTemplateSegment
+        return self.__segments[index]
+
+    def __len__(self):
+        return len(self.__segments)
+
+    def __iter__(self):
+        return iter(self.__segments)
+
+    def __hash__(self):
+        return hash(self.__segments)
+
+    def render(self, regs):
+        # type: (dict[SSAVal, RegLoc]) -> str
+        retval = []  # type: list[str]
+        for segment in self:
+            if isinstance(segment, AsmTemplateSegment):
+                retval.append(segment.render(regs))
+            else:
+                retval.append(segment)
+        return "".join(retval)
+
+
+@final
+class AsmContext:
+    def __init__(self, assigned_registers):
+        # type: (dict[SSAVal, RegLoc]) -> None
+        self.__assigned_registers = assigned_registers
+
+    def reg(self, ssa_val, expected_ty):
+        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
+        try:
+            reg = self.__assigned_registers[ssa_val]
+        except KeyError as e:
+            raise ValueError(f"SSAVal not assigned a register: {ssa_val}")
+        wrong_len = (isinstance(reg, GPRRange)
+                     and reg.length != ssa_val.ty.length)
+        if not isinstance(reg, expected_ty) or wrong_len:
+            raise TypeError(
+                f"SSAVal is assigned a register of the wrong type: "
+                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
+        return reg
+
+    def gpr_range(self, ssa_val):
+        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
+        return self.reg(ssa_val, GPRRange)
+
+    def stack_slot(self, ssa_val):
+        # type: (SSAVal[StackSlotType]) -> StackSlot
+        return self.reg(ssa_val, StackSlot)
+
+    def gpr(self, ssa_val, vec, offset=0):
+        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
+        reg = self.gpr_range(ssa_val).start + offset
+        return "*" * vec + str(reg)
+
+    def vgpr(self, ssa_val, offset=0):
+        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
+        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)
+
+    def sgpr(self, ssa_val, offset=0):
+        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
+        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)
+
+    def needs_sv(self, *regs):
+        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
+        for reg in regs:
+            reg = self.gpr_range(reg)
+            if reg.length != 1 or reg.start >= 32:
+                return True
+        return False
+
+
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 class Op(metaclass=ABCMeta):
     __slots__ = "id", "fn"
@@ -476,45 +700,77 @@ class Op(metaclass=ABCMeta):
         fields_str = ', '.join(fields_list)
         return f"{self.__class__.__name__}({fields_str})"
 
+    @abstractmethod
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        """get the lines of assembly for this Op"""
+        ...
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLoadFromStackSlot(Op):
-    __slots__ = "dest", "src"
+    __slots__ = "dest", "src", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
+        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, fn, src):
-        # type: (Fn, SSAVal[GPRRangeType]) -> None
+    def __init__(self, fn, src, vl=None):
+        # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
         super().__init__(fn)
-        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
+        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
         self.src = src
+        self.vl = vl
+        assert_vl_is(vl, self.dest.ty.length)
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1)
+        src = ctx.stack_slot(self.src)
+        if ctx.needs_sv(self.dest):
+            return [f"sv.ld {dest}, {src.start_byte}(1)"]
+        return [f"ld {dest}, {src.start_byte}(1)"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpStoreToStackSlot(Op):
-    __slots__ = "dest", "src"
+    __slots__ = "dest", "src", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
+        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, fn, src):
-        # type: (Fn, SSAVal[StackSlotType]) -> None
+    def __init__(self, fn, src, vl=None):
+        # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
         super().__init__(fn)
-        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
+        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
         self.src = src
+        self.vl = vl
+        assert_vl_is(vl, src.ty.length)
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        src = ctx.gpr(self.src, vec=self.src.ty.length != 1)
+        dest = ctx.stack_slot(self.dest)
+        if ctx.needs_sv(self.src):
+            return [f"sv.std {src}, {dest.start_byte}(1)"]
+        return [f"std {src}, {dest.start_byte}(1)"]
 
 
 _RegSrcType = TypeVar("_RegSrcType", bound=RegType)
@@ -523,18 +779,21 @@ _RegSrcType = TypeVar("_RegSrcType", bound=RegType)
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpCopy(Op, Generic[_RegSrcType, _RegType]):
-    __slots__ = "dest", "src"
+    __slots__ = "dest", "src", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"src": self.src}
+        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
-    def __init__(self, fn, src, dest_ty=None):
-        # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None
+    def __init__(self, fn, src, dest_ty=None, vl=None):
+        # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
         super().__init__(fn)
         if dest_ty is None:
             dest_ty = cast(_RegType, src.ty)
@@ -543,17 +802,44 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
             if src.ty.length != dest_ty.reg.length:
                 raise ValueError(f"incompatible source and destination "
                                  f"types: {src.ty} and {dest_ty}")
+            length = src.ty.length
         elif isinstance(src.ty, FixedGPRRangeType) \
                 and isinstance(dest_ty, GPRRangeType):
             if src.ty.reg.length != dest_ty.length:
                 raise ValueError(f"incompatible source and destination "
                                  f"types: {src.ty} and {dest_ty}")
+            length = src.ty.length
         elif src.ty != dest_ty:
             raise ValueError(f"incompatible source and destination "
                              f"types: {src.ty} and {dest_ty}")
+        elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
+            length = src.ty.length
+        else:
+            length = 1
 
         self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
         self.src = src
+        self.vl = vl
+        assert_vl_is(vl, length)
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
+            return []
+        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
+                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
+            vec = self.dest.ty.length != 1
+            dest = ctx.gpr_range(self.dest)  # type: ignore
+            src = ctx.gpr_range(self.src)  # type: ignore
+            dest_s = ctx.gpr(self.dest, vec=vec)  # type: ignore
+            src_s = ctx.gpr(self.src, vec=vec)  # type: ignore
+            mrr = ""
+            if src.conflicts(dest) and src.start > dest.start:
+                mrr = "/mrr"
+            if ctx.needs_sv(self.src, self.dest):  # type: ignore
+                return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
+            return [f"or {dest_s}, {src_s}, {src_s}"]
+        raise NotImplementedError
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
@@ -570,7 +856,7 @@ class OpConcat(Op):
         return {"dest": self.dest}
 
     def __init__(self, fn, sources):
-        # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None
+        # type: (Fn, Iterable[SSAGPRRange]) -> None
         super().__init__(fn)
         sources = tuple(sources)
         self.dest = SSAVal(self, "dest", GPRRangeType(
@@ -581,6 +867,10 @@ class OpConcat(Op):
         # type: () -> Iterable[EqualityConstraint]
         yield EqualityConstraint([self.dest], [*self.sources])
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        return []
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -596,7 +886,7 @@ class OpSplit(Op):
         return {i.arg_name: i for i in self.results}
 
     def __init__(self, fn, src, split_indexes):
-        # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None
+        # type: (Fn, SSAGPRRange, Iterable[int]) -> None
         super().__init__(fn)
         ranges = []  # type: list[GPRRangeType]
         last = 0
@@ -615,54 +905,86 @@ class OpSplit(Op):
         # type: () -> Iterable[EqualityConstraint]
         yield EqualityConstraint([*self.results], [self.src])
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        return []
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
-class OpAddSubE(Op):
-    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
+class OpBigIntAddSub(Op):
+    __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        retval["lhs"] = self.lhs
+        retval["rhs"] = self.rhs
+        retval["CA_in"] = self.CA_in
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT, "CY_out": self.CY_out}
+        return {"out": self.out, "CA_out": self.CA_out}
 
-    def __init__(self, fn, RA, RB, CY_in, is_sub):
-        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+    def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
+        # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
         super().__init__(fn)
-        if RA.ty != RB.ty:
+        if lhs.ty != rhs.ty:
             raise TypeError(f"source types must match: "
-                            f"{RA} doesn't match {RB}")
-        self.RT = SSAVal(self, "RT", RA.ty)
-        self.RA = RA
-        self.RB = RB
-        self.CY_in = CY_in
-        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
+                            f"{lhs} doesn't match {rhs}")
+        self.out = SSAVal(self, "out", lhs.ty)
+        self.lhs = lhs
+        self.rhs = rhs
+        self.CA_in = CA_in
+        self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
         self.is_sub = is_sub
+        self.vl = vl
+        assert_vl_is(vl, lhs.ty.length)
 
     def get_extra_interferences(self):
         # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        yield self.RT, self.RA
-        yield self.RT, self.RB
+        yield self.out, self.lhs
+        yield self.out, self.rhs
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        vec = self.out.ty.length != 1
+        out = ctx.gpr(self.out, vec=vec)
+        RA = ctx.gpr(self.lhs, vec=vec)
+        RB = ctx.gpr(self.rhs, vec=vec)
+        mnemonic = "adde"
+        if self.is_sub:
+            mnemonic = "subfe"
+            RA, RB = RB, RA  # reorder to match subfe
+        if ctx.needs_sv(self.out, self.lhs, self.rhs):
+            return [f"sv.{mnemonic} {out}, {RA}, {RB}"]
+        return [f"{mnemonic} {out}, {RA}, {RB}"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpBigIntMulDiv(Op):
-    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
+    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        retval["RA"] = self.RA
+        retval["RB"] = self.RB
+        retval["RC"] = self.RC
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT, "RS": self.RS}
 
-    def __init__(self, fn, RA, RB, RC, is_div):
-        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+    def __init__(self, fn, RA, RB, RC, is_div, vl):
+        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
         super().__init__(fn)
         self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
@@ -670,6 +992,8 @@ class OpBigIntMulDiv(Op):
         self.RC = RC
         self.RS = SSAVal(self, "RS", RC.ty)
         self.is_div = is_div
+        self.vl = vl
+        assert_vl_is(vl, RA.ty.length)
 
     def get_equality_constraints(self):
         # type: () -> Iterable[EqualityConstraint]
@@ -684,6 +1008,18 @@ class OpBigIntMulDiv(Op):
         yield self.RS, self.RA
         yield self.RS, self.RB
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        vec = self.RT.ty.length != 1
+        RT = ctx.gpr(self.RT, vec=vec)
+        RA = ctx.gpr(self.RA, vec=vec)
+        RB = ctx.sgpr(self.RB)
+        RC = ctx.sgpr(self.RC)
+        mnemonic = "maddedu"
+        if self.is_div:
+            mnemonic = "divmod2du/mrr"
+        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
+
 
 @final
 @unique
@@ -692,58 +1028,179 @@ class ShiftKind(Enum):
     Sr = "sr"
     Sra = "sra"
 
+    def make_big_int_carry_in(self, fn, inp):
+        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
+        if self is ShiftKind.Sl or self is ShiftKind.Sr:
+            li = OpLI(fn, 0)
+            return li.out, [li]
+        else:
+            assert self is ShiftKind.Sra
+            split = OpSplit(fn, inp, [inp.ty.length - 1])
+            shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra)
+            return shr.out, [split, shr]
+
+    def make_big_int_shift(self, fn, inp, sh, vl):
+        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
+        carry_in, ops = self.make_big_int_carry_in(fn, inp)
+        big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
+        ops.append(big_int_shift)
+        return big_int_shift.out, ops
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpBigIntShift(Op):
-    __slots__ = "RT", "inp", "sh", "kind"
+    __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"inp": self.inp, "sh": self.sh}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        retval["inp"] = self.inp
+        retval["sh"] = self.sh
+        retval["carry_in"] = self.carry_in
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RT": self.RT}
+        return {"out": self.out, "_out_padding": self._out_padding}
 
-    def __init__(self, fn, inp, sh, kind):
-        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+    def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
+        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
         super().__init__(fn)
-        self.RT = SSAVal(self, "RT", inp.ty)
+        self.out = SSAVal(self, "out", inp.ty)
+        self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
+        self.carry_in = carry_in
         self.inp = inp
         self.sh = sh
         self.kind = kind
+        self.vl = vl
+        assert_vl_is(vl, inp.ty.length)
 
     def get_extra_interferences(self):
         # type: () -> Iterable[tuple[SSAVal, SSAVal]]
-        yield self.RT, self.inp
-        yield self.RT, self.sh
+        yield self.out, self.sh
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        if self.kind is ShiftKind.Sl:
+            yield EqualityConstraint([self.carry_in, self.inp],
+                                     [self.out, self._out_padding])
+        else:
+            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
+            yield EqualityConstraint([self.inp, self.carry_in],
+                                     [self._out_padding, self.out])
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        vec = self.out.ty.length != 1
+        if self.kind is ShiftKind.Sl:
+            RT = ctx.gpr(self.out, vec=vec)
+            RA = ctx.gpr(self.out, vec=vec, offset=-1)
+            RB = ctx.sgpr(self.sh)
+            mrr = "/mrr" if vec else ""
+            return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
+        else:
+            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
+            RT = ctx.gpr(self.out, vec=vec)
+            RA = ctx.gpr(self.out, vec=vec, offset=1)
+            RB = ctx.sgpr(self.sh)
+            return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
+@final
+class OpShiftImm(Op):
+    __slots__ = "out", "inp", "sh", "kind", "ca_out"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"inp": self.inp}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        if self.ca_out is not None:
+            return {"out": self.out, "ca_out": self.ca_out}
+        return {"out": self.out}
+
+    def __init__(self, fn, inp, sh, kind):
+        # type: (Fn, SSAGPR, int, ShiftKind) -> None
+        super().__init__(fn)
+        self.out = SSAVal(self, "out", inp.ty)
+        self.inp = inp
+        if not (0 <= sh < 64):
+            raise ValueError("shift amount out of range")
+        self.sh = sh
+        self.kind = kind
+        if self.kind is ShiftKind.Sra:
+            self.ca_out = SSAVal(self, "ca_out", CAType())
+        else:
+            self.ca_out = None
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        out = ctx.sgpr(self.out)
+        inp = ctx.sgpr(self.inp)
+        if self.kind is ShiftKind.Sl:
+            mnemonic = "rldicr"
+            args = f"{self.sh}, {63 - self.sh}"
+        elif self.kind is ShiftKind.Sr:
+            mnemonic = "rldicl"
+            v = (64 - self.sh) % 64
+            args = f"{v}, {self.sh}"
+        else:
+            assert self.kind is ShiftKind.Sra
+            mnemonic = "sradi"
+            args = f"{self.sh}"
+        if ctx.needs_sv(self.out, self.inp):
+            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
+        return [f"{mnemonic} {out}, {inp}, {args}"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLI(Op):
-    __slots__ = "out", "value"
+    __slots__ = "out", "value", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, fn, value, length=1):
-        # type: (Fn, int, int) -> None
+    def __init__(self, fn, value, vl=None):
+        # type: (Fn, int, SSAKnownVL | None) -> None
         super().__init__(fn)
+        if vl is None:
+            length = 1
+        else:
+            length = vl.ty.length
         self.out = SSAVal(self, "out", GPRRangeType(length))
+        if not (-1 << 15 <= value <= (1 << 15) - 1):
+            raise ValueError(f"value out of range: {value}")
         self.value = value
+        self.vl = vl
+        assert_vl_is(vl, length)
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        vec = self.out.ty.length != 1
+        out = ctx.gpr(self.out, vec=vec)
+        if ctx.needs_sv(self.out):
+            return [f"sv.addi {out}, 0, {self.value}"]
+        return [f"addi {out}, 0, {self.value}"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
-class OpClearCY(Op):
-    __slots__ = "out",
+class OpSetCA(Op):
+    __slots__ = "out", "value"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
@@ -753,60 +1210,110 @@ class OpClearCY(Op):
         # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, fn):
-        # type: (Fn) -> None
+    def __init__(self, fn, value):
+        # type: (Fn, bool) -> None
         super().__init__(fn)
-        self.out = SSAVal(self, "out", CYType())
+        self.out = SSAVal(self, "out", CAType())
+        self.value = value
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        if self.value:
+            return ["subfic 0, 0, -1"]
+        return ["addic 0, 0, 0"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpLoad(Op):
-    __slots__ = "RT", "RA", "offset", "mem"
+    __slots__ = "RT", "RA", "offset", "mem", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RA": self.RA, "mem": self.mem}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        retval["RA"] = self.RA
+        retval["mem"] = self.mem
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, fn, RA, offset, mem, length=1):
-        # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+    def __init__(self, fn, RA, offset, mem, vl=None):
+        # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
         super().__init__(fn)
+        if vl is None:
+            length = 1
+        else:
+            length = vl.ty.length
         self.RT = SSAVal(self, "RT", GPRRangeType(length))
         self.RA = RA
+        if not (-1 << 15 <= offset <= (1 << 15) - 1):
+            raise ValueError(f"offset out of range: {offset}")
+        if offset % 4 != 0:
+            raise ValueError(f"offset not aligned: {offset}")
         self.offset = offset
         self.mem = mem
+        self.vl = vl
+        assert_vl_is(vl, length)
 
     def get_extra_interferences(self):
         # type: () -> Iterable[tuple[SSAVal, SSAVal]]
         if self.RT.ty.length > 1:
             yield self.RT, self.RA
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
+        RA = ctx.sgpr(self.RA)
+        if ctx.needs_sv(self.RT, self.RA):
+            return [f"sv.ld {RT}, {self.offset}({RA})"]
+        return [f"ld {RT}, {self.offset}({RA})"]
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
 class OpStore(Op):
-    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
+    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"
 
     def inputs(self):
         # type: () -> dict[str, SSAVal]
-        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
+        retval = {}  # type: dict[str, SSAVal[Any]]
+        retval["RS"] = self.RS
+        retval["RA"] = self.RA
+        retval["mem_in"] = self.mem_in
+        if self.vl is not None:
+            retval["vl"] = self.vl
+        return retval
 
     def outputs(self):
         # type: () -> dict[str, SSAVal]
         return {"mem_out": self.mem_out}
 
-    def __init__(self, fn, RS, RA, offset, mem_in):
-        # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+    def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
+        # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
         super().__init__(fn)
         self.RS = RS
         self.RA = RA
+        if not (-1 << 15 <= offset <= (1 << 15) - 1):
+            raise ValueError(f"offset out of range: {offset}")
+        if offset % 4 != 0:
+            raise ValueError(f"offset not aligned: {offset}")
         self.offset = offset
         self.mem_in = mem_in
         self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
+        self.vl = vl
+        assert_vl_is(vl, RS.ty.length)
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
+        RA = ctx.sgpr(self.RA)
+        if ctx.needs_sv(self.RS, self.RA):
+            return [f"sv.std {RS}, {self.offset}({RA})"]
+        return [f"std {RS}, {self.offset}({RA})"]
 
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
@@ -827,6 +1334,10 @@ class OpFuncArg(Op):
         super().__init__(fn)
         self.out = SSAVal(self, "out", ty)
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        return []
+
 
 @plain_data(unsafe_hash=True, frozen=True, repr=False)
 @final
@@ -846,6 +1357,33 @@ class OpInputMem(Op):
         super().__init__(fn)
         self.out = SSAVal(self, "out", GlobalMemType())
 
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        return []
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
+@final
+class OpSetVLImm(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, fn, length):
+        # type: (Fn, int) -> None
+        super().__init__(fn)
+        self.out = SSAVal(self, "out", KnownVLType(length))
+
+    def get_asm_lines(self, ctx):
+        # type: (AsmContext) -> list[str]
+        return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
+
 
 def op_set_to_list(ops):
     # type: (Iterable[Op]) -> list[Op]
index 2636be85106c6b946de3d511e10cd4ac9cb740c3..3e1e1540143eb8b58cd8f59174f56dc7663736a7 100644 (file)
@@ -1,7 +1,7 @@
 import operator
-from typing import Callable, Iterable
 from fractions import Fraction
 from numbers import Rational
+from typing import Callable, Iterable
 
 
 class Matrix:
index b44c32dd76023c948ff9be447d34eb16e02a77bf..a22299f842badfdb4d48439907f7ec373b31ed52 100644 (file)
@@ -342,6 +342,13 @@ class AllocationFailed:
         self.interference_graph = interference_graph
 
 
+class AllocationFailedError(Exception):
+    def __init__(self, msg, allocation_failed):
+        # type: (str, AllocationFailed) -> None
+        super().__init__(msg, allocation_failed)
+        self.allocation_failed = allocation_failed
+
+
 def try_allocate_registers_without_spilling(ops):
     # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
 
@@ -422,5 +429,10 @@ def try_allocate_registers_without_spilling(ops):
 
 
 def allocate_registers(ops):
-    # type: (list[Op]) -> None
-    raise NotImplementedError
+    # type: (list[Op]) -> dict[SSAVal, RegLoc]
+    retval = try_allocate_registers_without_spilling(ops)
+    if isinstance(retval, AllocationFailed):
+        # TODO: implement spilling
+        raise AllocationFailedError(
+            "spilling required but not yet implemented", retval)
+    return retval
index ff52641fa273a654f80f02865b41517c419b1f33..0c8a3eee4bd90617d102664614922875abab3e13 100644 (file)
@@ -1,7 +1,11 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType,
-                                                  Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore,
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn,
+                                                  GPRRange, GPRType,
+                                                  OpBigIntAddSub, OpConcat,
+                                                  OpCopy, OpFuncArg,
+                                                  OpInputMem, OpLI, OpLoad,
+                                                  OpSetCA, OpSetVLImm, OpStore,
                                                   op_set_to_list)
 
 
@@ -15,32 +19,41 @@ class TestCompilerIR(unittest.TestCase):
         arg = op1.dest
         op2 = OpInputMem(fn)
         mem = op2.out
-        op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
-        a = op3.RT
-        op4 = OpLI(fn, 1)
-        b_0 = op4.out
-        op5 = OpLI(fn, 0, length=31)
-        b_rest = op5.out
-        op6 = OpConcat(fn, [b_0, b_rest])
-        b = op6.dest
-        op7 = OpClearCY(fn)
-        cy = op7.out
-        op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
-        s = op8.RT
-        op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
-        mem = op9.mem_out
+        op3 = OpSetVLImm(fn, 32)
+        vl = op3.out
+        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+        a = op4.RT
+        op5 = OpLI(fn, 1)
+        b_0 = op5.out
+        op6 = OpSetVLImm(fn, 31)
+        vl = op6.out
+        op7 = OpLI(fn, 0, vl=vl)
+        b_rest = op7.out
+        op8 = OpConcat(fn, [b_0, b_rest])
+        b = op8.dest
+        op9 = OpSetVLImm(fn, 32)
+        vl = op9.out
+        op10 = OpSetCA(fn, False)
+        ca = op10.out
+        op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+        s = op11.out
+        op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        mem = op12.mem_out
 
         expected_ops = [
-            op7,  # OpClearCY()
-            op5,  # OpLI(0, length=31)
-            op4,  # OpLI(1)
-            op2,  # OpInputMem()
-            op0,  # OpFuncArg(FixedGPRRangeType(GPRRange(3)))
-            op6,  # OpConcat([b_0, b_rest])
-            op1,  # OpCopy(op0.out, GPRType())
-            op3,  # OpLoad(arg, offset=0, mem=mem, length=32)
-            op8,  # OpAddSubE(a, b, cy, is_sub=False)
-            op9,  # OpStore(s, arg, offset=0, mem_in=mem)
+            op10,  # OpSetCA(fn, False)
+            op9,  # OpSetVLImm(fn, 32)
+            op6,  # OpSetVLImm(fn, 31)
+            op5,  # OpLI(fn, 1)
+            op3,  # OpSetVLImm(fn, 32)
+            op2,  # OpInputMem(fn)
+            op0,  # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+            op7,  # OpLI(fn, 0, vl=vl)
+            op1,  # OpCopy(fn, op0.out, GPRType())
+            op8,  # OpConcat(fn, [b_0, b_rest])
+            op4,  # OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+            op11,  # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+            op12,  # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
         ]
 
         ops = op_set_to_list(fn.ops[::-1])
index bdc193889f8160ae17758be0d25cc2eea37ecf35..65582643e789af865c3e0372413aa62e5a8e97a9 100644 (file)
@@ -1,12 +1,14 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange,
-                                                  GPRType, GlobalMem, Op, OpAddSubE,
-                                                  OpClearCY, OpConcat, OpCopy,
-                                                  OpFuncArg, OpInputMem, OpLI,
-                                                  OpLoad, OpStore, XERBit)
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+                                                  GlobalMem, GPRRange, GPRType,
+                                                  OpBigIntAddSub, OpConcat,
+                                                  OpCopy, OpFuncArg,
+                                                  OpInputMem, OpLI, OpLoad,
+                                                  OpSetCA, OpSetVLImm, OpStore,
+                                                  XERBit)
 from bigint_presentation_code.register_allocator import (
-    AllocationFailed, allocate_registers, MergedRegSet,
+    AllocationFailed, MergedRegSet, allocate_registers,
     try_allocate_registers_without_spilling)
 
 
@@ -15,26 +17,26 @@ class TestMergedRegSet(unittest.TestCase):
 
     def test_from_equality_constraint(self):
         fn = Fn()
-        op0 = OpLI(fn, 0, length=1)
-        op1 = OpLI(fn, 0, length=2)
-        op2 = OpLI(fn, 0, length=3)
+        li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
+        li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
+        li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
         self.assertEqual(MergedRegSet.from_equality_constraint([
-            op0.out,
-            op1.out,
-            op2.out,
+            li0x1.out,
+            li0x2.out,
+            li0x3.out,
         ]), MergedRegSet({
-            op0.out: 0,
-            op1.out: 1,
-            op2.out: 3,
+            li0x1.out: 0,
+            li0x2.out: 1,
+            li0x3.out: 3,
         }.items()))
         self.assertEqual(MergedRegSet.from_equality_constraint([
-            op1.out,
-            op0.out,
-            op2.out,
+            li0x2.out,
+            li0x1.out,
+            li0x3.out,
         ]), MergedRegSet({
-            op1.out: 0,
-            op0.out: 2,
-            op2.out: 3,
+            li0x2.out: 0,
+            li0x1.out: 2,
+            li0x3.out: 3,
         }.items()))
 
 
@@ -43,38 +45,53 @@ class TestRegisterAllocator(unittest.TestCase):
 
     def test_try_alloc_fail(self):
         fn = Fn()
-        op0 = OpLI(fn, 0, length=52)
-        op1 = OpLI(fn, 0, length=64)
-        op2 = OpConcat(fn, [op0.out, op1.out])
+        op0 = OpSetVLImm(fn, 52)
+        op1 = OpLI(fn, 0, vl=op0.out)
+        op2 = OpSetVLImm(fn, 64)
+        op3 = OpLI(fn, 0, vl=op2.out)
+        op4 = OpConcat(fn, [op1.out, op3.out])
 
         reg_assignments = try_allocate_registers_without_spilling(fn.ops)
         self.assertEqual(
             repr(reg_assignments),
             "AllocationFailed("
             "node=IGNode(#0, merged_reg_set=MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
             "edges={}, reg=None), "
             "live_intervals=LiveIntervals("
             "live_intervals={"
-            "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): "
-            "LiveInterval(first_write=0, last_use=2)}, "
+            "MergedRegSet([(<#0.out>, 0)]): "
+            "LiveInterval(first_write=0, last_use=1), "
+            "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): "
+            "LiveInterval(first_write=1, last_use=4), "
+            "MergedRegSet([(<#2.out>, 0)]): "
+            "LiveInterval(first_write=2, last_use=3)}, "
             "merged_reg_sets=MergedRegSets(data={"
-            "<#0.out>: MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "<#0.out>: MergedRegSet([(<#0.out>, 0)]), "
             "<#1.out>: MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
-            "<#2.dest>: MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), "
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+            "<#2.out>: MergedRegSet([(<#2.out>, 0)]), "
+            "<#3.out>: MergedRegSet(["
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+            "<#4.dest>: MergedRegSet(["
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), "
             "reg_sets_live_after={"
-            "0: OFSet([MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+            "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), "
             "1: OFSet([MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
-            "2: OFSet()}), "
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+            "2: OFSet([MergedRegSet(["
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+            "MergedRegSet([(<#2.out>, 0)])]), "
+            "3: OFSet([MergedRegSet(["
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+            "4: OFSet()}), "
             "interference_graph=InterferenceGraph(nodes={"
-            "...: IGNode(#0, "
-            "merged_reg_set=MergedRegSet(["
-            "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+            "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), "
+            "edges={}, reg=None), "
+            "...: IGNode(#1, merged_reg_set=MergedRegSet(["
+            "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+            "edges={}, reg=None), "
+            "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), "
             "edges={}, reg=None)}))"
         )
 
@@ -85,20 +102,18 @@ class TestRegisterAllocator(unittest.TestCase):
         arg = op1.dest
         op2 = OpInputMem(fn)
         mem = op2.out
-        op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
-        a = op3.RT
-        op4 = OpLI(fn, 1)
-        b_0 = op4.out
-        op5 = OpLI(fn, 0, length=31)
-        b_rest = op5.out
-        op6 = OpConcat(fn, [b_0, b_rest])
-        b = op6.dest
-        op7 = OpClearCY(fn)
-        cy = op7.out
-        op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
-        s = op8.RT
-        op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
-        mem = op9.mem_out
+        op3 = OpSetVLImm(fn, 32)
+        vl = op3.out
+        op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+        a = op4.RT
+        op5 = OpLI(fn, 0, vl=vl)
+        b = op5.out
+        op6 = OpSetCA(fn, True)
+        ca = op6.out
+        op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+        s = op7.out
+        op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+        mem = op8.mem_out
 
         reg_assignments = try_allocate_registers_without_spilling(fn.ops)
 
@@ -106,14 +121,13 @@ class TestRegisterAllocator(unittest.TestCase):
             op0.out: GPRRange(start=3, length=1),
             op1.dest: GPRRange(start=3, length=1),
             op2.out: GlobalMem.GlobalMem,
-            op3.RT: GPRRange(start=78, length=32),
-            op4.out: GPRRange(start=46, length=1),
-            op5.out: GPRRange(start=47, length=31),
-            op6.dest: GPRRange(start=46, length=32),
-            op7.out: XERBit.CY,
-            op8.RT: GPRRange(start=14, length=32),
-            op8.CY_out: XERBit.CY,
-            op9.mem_out: GlobalMem.GlobalMem,
+            op3.out: VL.VL_MAXVL,
+            op4.RT: GPRRange(start=78, length=32),
+            op5.out: GPRRange(start=46, length=32),
+            op6.out: XERBit.CA,
+            op7.out: GPRRange(start=14, length=32),
+            op7.CA_out: XERBit.CA,
+            op8.mem_out: GlobalMem.GlobalMem,
         }
 
         self.assertEqual(reg_assignments, expected_reg_assignments)
@@ -121,14 +135,21 @@ class TestRegisterAllocator(unittest.TestCase):
     def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
         # type: (list[GPRRange], GPRRange) -> None
         fn = Fn()
-        li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)]
-        concat = OpConcat(fn, [i.out for i in li_ops])
+        inputs = []
+        expected_reg_assignments = {}
+        for i, r in enumerate(expected_regs):
+            vl = OpSetVLImm(fn, r.length).out
+            expected_reg_assignments[vl] = VL.VL_MAXVL
+            inp = OpLI(fn, i, vl=vl).out
+            inputs.append(inp)
+            expected_reg_assignments[inp] = r
+        concat = OpConcat(fn, inputs)
+        expected_reg_assignments[concat.dest] = expected_dest_reg
 
         reg_assignments = try_allocate_registers_without_spilling(fn.ops)
 
-        expected_reg_assignments = {concat.dest: expected_dest_reg}
-        for li_op, reg in zip(li_ops, expected_regs):
-            expected_reg_assignments[li_op.out] = reg
+        for inp, reg in zip(inputs, expected_regs):
+            expected_reg_assignments[inp] = reg
 
         self.assertEqual(reg_assignments, expected_reg_assignments)