try_allocate_registers_without_spilling works!
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
index 24b86494e37c71f10c9491e697efb5f0823e37b8..bfcf159c25e4fc2cced574fdf019977d0fcd102a 100644 (file)
@@ -6,12 +6,13 @@ from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, EnumMeta, unique
 from functools import lru_cache
-from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
-                    TypeVar, cast)
+from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
 
 from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
+from bigint_presentation_code.ordered_set import OFSet, OSet
+
 if TYPE_CHECKING:
     from typing_extensions import final
 else:
@@ -111,8 +112,8 @@ class GPRRange(RegLoc, Sequence["GPRRange"]):
 
     def get_subreg_at_offset(self, subreg_type, offset):
         # type: (RegType, int) -> GPRRange
-        if not isinstance(subreg_type, GPRRangeType):
-            raise ValueError(f"subreg_type is not a "
+        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
+            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                              f"GPRRangeType: {subreg_type}")
         if offset < 0 or offset + subreg_type.length > self.stop:
             raise ValueError(f"sub-register offset is out of range: {offset}")
@@ -150,30 +151,11 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
 
 
 @final
-class RegClass(AbstractSet[RegLoc]):
+class RegClass(OFSet[RegLoc]):
     """ an ordered set of registers.
     earlier registers are preferred by the register allocator.
     """
 
-    def __init__(self, regs):
-        # type: (Iterable[RegLoc]) -> None
-
-        # use dict to maintain order
-        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]
-
-    def __len__(self):
-        return len(self.__regs)
-
-    def __iter__(self):
-        return iter(self.__regs)
-
-    def __contains__(self, v):
-        # type: (RegLoc) -> bool
-        return v in self.__regs
-
-    def __hash__(self):
-        return super()._hash()
-
     @lru_cache(maxsize=None, typed=True)
     def max_conflicts_with(self, other):
         # type: (RegClass | RegLoc) -> int
@@ -251,12 +233,11 @@ class GPRType(GPRRangeType):
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class FixedGPRRangeType(GPRRangeType):
+class FixedGPRRangeType(RegType):
     __slots__ = "reg",
 
     def __init__(self, reg):
         # type: (GPRRange) -> None
-        super().__init__(length=reg.length)
         self.reg = reg
 
     @property
@@ -264,6 +245,11 @@ class FixedGPRRangeType(GPRRangeType):
         # type: () -> RegClass
         return RegClass([self.reg])
 
+    @property
+    def length(self):
+        # type: () -> int
+        return self.reg.length
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
@@ -384,7 +370,9 @@ class SSAVal(Generic[_RegType]):
     def __hash__(self):
         return hash((id(self.op), self.arg_name))
 
-    def __repr__(self):
+    def __repr__(self, long=False):
+        if not long:
+            return f"<#{self.op.id}.{self.arg_name}>"
         fields_list = []
         for name in fields(self):
             v = getattr(self, name, None)
@@ -449,17 +437,33 @@ class Op(metaclass=ABCMeta):
 
     @cached_property
     def id(self):
+        # type: () -> int
+        # use cached_property rather than done in init so id is usable even if
+        # init hasn't run
         retval = Op.__NEXT_ID
         Op.__NEXT_ID += 1
         return retval
 
+    def __init__(self):
+        self.id  # initialize
+
     @final
     def __repr__(self, just_id=False):
         fields_list = [f"#{self.id}"]
+        outputs = None
+        try:
+            outputs = self.outputs()
+        except AttributeError:
+            pass
         if not just_id:
             for name in fields(self):
                 v = getattr(self, name, _NOT_SET)
-                fields_list.append(f"{name}={v!r}")
+                if ((outputs is None or name in outputs)
+                        and isinstance(v, SSAVal)):
+                    v = v.__repr__(long=True)
+                else:
+                    v = repr(v)
+                fields_list.append(f"{name}={v}")
         fields_str = ', '.join(fields_list)
         return f"{self.__class__.__name__}({fields_str})"
 
@@ -479,6 +483,7 @@ class OpLoadFromStackSlot(Op):
 
     def __init__(self, src):
         # type: (SSAVal[GPRRangeType]) -> None
+        super().__init__()
         self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
         self.src = src
 
@@ -498,6 +503,7 @@ class OpStoreToStackSlot(Op):
 
     def __init__(self, src):
         # type: (SSAVal[StackSlotType]) -> None
+        super().__init__()
         self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
         self.src = src
 
@@ -520,11 +526,17 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]):
 
     def __init__(self, src, dest_ty=None):
         # type: (SSAVal[_RegSrcType], _RegType | None) -> None
+        super().__init__()
         if dest_ty is None:
             dest_ty = cast(_RegType, src.ty)
         if isinstance(src.ty, GPRRangeType) \
+                and isinstance(dest_ty, FixedGPRRangeType):
+            if src.ty.length != dest_ty.reg.length:
+                raise ValueError(f"incompatible source and destination "
+                                 f"types: {src.ty} and {dest_ty}")
+        elif isinstance(src.ty, FixedGPRRangeType) \
                 and isinstance(dest_ty, GPRRangeType):
-            if src.ty.length != dest_ty.length:
+            if src.ty.reg.length != dest_ty.length:
                 raise ValueError(f"incompatible source and destination "
                                  f"types: {src.ty} and {dest_ty}")
         elif src.ty != dest_ty:
@@ -550,6 +562,7 @@ class OpConcat(Op):
 
     def __init__(self, sources):
         # type: (Iterable[SSAVal[GPRRangeType]]) -> None
+        super().__init__()
         sources = tuple(sources)
         self.dest = SSAVal(self, "dest", GPRRangeType(
             sum(i.ty.length for i in sources)))
@@ -575,6 +588,7 @@ class OpSplit(Op):
 
     def __init__(self, src, split_indexes):
         # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+        super().__init__()
         ranges = []  # type: list[GPRRangeType]
         last = 0
         for i in split_indexes:
@@ -608,6 +622,7 @@ class OpAddSubE(Op):
 
     def __init__(self, RA, RB, CY_in, is_sub):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+        super().__init__()
         if RA.ty != RB.ty:
             raise TypeError(f"source types must match: "
                             f"{RA} doesn't match {RB}")
@@ -639,6 +654,7 @@ class OpBigIntMulDiv(Op):
 
     def __init__(self, RA, RB, RC, is_div):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
         self.RB = RB
@@ -683,6 +699,7 @@ class OpBigIntShift(Op):
 
     def __init__(self, inp, sh, kind):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", inp.ty)
         self.inp = inp
         self.sh = sh
@@ -709,6 +726,7 @@ class OpLI(Op):
 
     def __init__(self, value, length=1):
         # type: (int, int) -> None
+        super().__init__()
         self.out = SSAVal(self, "out", GPRRangeType(length))
         self.value = value
 
@@ -728,6 +746,7 @@ class OpClearCY(Op):
 
     def __init__(self):
         # type: () -> None
+        super().__init__()
         self.out = SSAVal(self, "out", CYType())
 
 
@@ -746,6 +765,7 @@ class OpLoad(Op):
 
     def __init__(self, RA, offset, mem, length=1):
         # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+        super().__init__()
         self.RT = SSAVal(self, "RT", GPRRangeType(length))
         self.RA = RA
         self.offset = offset
@@ -772,6 +792,7 @@ class OpStore(Op):
 
     def __init__(self, RS, RA, offset, mem_in):
         # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+        super().__init__()
         self.RS = RS
         self.RA = RA
         self.offset = offset
@@ -794,6 +815,7 @@ class OpFuncArg(Op):
 
     def __init__(self, ty):
         # type: (FixedGPRRangeType) -> None
+        super().__init__()
         self.out = SSAVal(self, "out", ty)
 
 
@@ -812,6 +834,7 @@ class OpInputMem(Op):
 
     def __init__(self):
         # type: () -> None
+        super().__init__()
         self.out = SSAVal(self, "out", GlobalMemType())
 
 
@@ -830,7 +853,7 @@ def op_set_to_list(ops):
         ops_to_pending_input_count_map[op] = input_count
         worklists[input_count][op] = None
     retval = []  # type: list[Op]
-    ready_vals = set()  # type: set[SSAVal]
+    ready_vals = OSet()  # type: OSet[SSAVal]
     while len(worklists[0]) != 0:
         writing_op = next(iter(worklists[0]))
         del worklists[0][writing_op]