-
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from enum import Enum, unique, EnumMeta
from functools import lru_cache
from itertools import combinations
# bug fix: `Sequence` was imported twice in this import
from typing import (AbstractSet, Generic, Iterable, Mapping, Sequence,
                    TYPE_CHECKING, TypeVar)

from nmutil.plain_data import plain_data

if TYPE_CHECKING:
    from typing_extensions import final, Self
else:
    # runtime stand-in so the module doesn't require typing_extensions
    # outside of static type checking; @final becomes a no-op
    def final(v):
        return v
-
-
class ABCEnumMeta(EnumMeta, ABCMeta):
    # combined metaclass so a class can inherit from both Enum and an ABC
    # (e.g. RegLoc) without a metaclass conflict
    pass
-
-
class RegLoc(metaclass=ABCMeta):
    """abstract base class of all register locations a value can be
    allocated to (GPR ranges, XER bits, stack slots, global memory)."""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """return True if `self` and `other` overlap, i.e. can't both be
        allocated at the same time"""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """return the sub-register of type `subreg_type` at `offset` within
        `self`; base implementation only supports offset 0 and returns
        `self` unchanged -- locations with internal structure override this
        """
        if self not in subreg_type.reg_class:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset != 0:
            raise ValueError(f"non-zero sub-register offset not supported "
                             f"for register: {self}")
        return self
-
-
# total number of general-purpose registers visible to the allocator
GPR_COUNT = 128
-
-
@plain_data(frozen=True, unsafe_hash=True)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """a contiguous range of GPRs usable as a register location.

    also behaves as an immutable Sequence of its length-1 sub-ranges.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        # type: () -> int
        """index one past the last GPR in this range"""
        return self.start + self.length

    @property
    def step(self):
        # type: () -> int
        """always 1 -- GPR ranges are contiguous"""
        return 1

    @property
    def range(self):
        # type: () -> range
        """this GPRRange as a builtin range object"""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # delegate index/slice bounds handling to builtin range
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        # a sub-range occurs at most once in a contiguous range
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if isinstance(other, GPRRange):
            # half-open interval overlap test
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        if not isinstance(subreg_type, GPRRangeType):
            raise ValueError(f"subreg_type is not a "
                             f"GPRRangeType: {subreg_type}")
        # bug fix: `offset` is relative to `self.start`, so it must be
        # compared against `self.length`, not the absolute `self.stop` --
        # the old check let out-of-range offsets through whenever start > 0
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)
-
-
# GPRs excluded from allocation -- presumably ABI-reserved registers
# (r0, r1=stack pointer, r2=TOC, r13) -- TODO confirm against the ABI
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
-
-
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """bits of the XER register usable as register locations; currently
    only the carry bit (CY) is modeled"""
    CY = "CY"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # each XER bit conflicts only with itself
        if isinstance(other, XERBit):
            return self == other
        return False
-
-
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """singleton representing all non-StackSlot memory -- treated as a single
    physical register for register allocation purposes.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # the single GlobalMem member conflicts only with itself
        if isinstance(other, GlobalMem):
            return self == other
        return False
-
-
@final
class RegClass(AbstractSet[RegLoc]):
    """ an ordered set of registers.
    earlier registers are preferred by the register allocator.
    """
    def __init__(self, regs):
        # type: (Iterable[RegLoc]) -> None

        # use dict to maintain order
        self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None]

    def __len__(self):
        return len(self.__regs)

    def __iter__(self):
        return iter(self.__regs)

    def __contains__(self, v):
        # type: (RegLoc) -> bool
        return v in self.__regs

    def __hash__(self):
        # collections.abc.Set provides the _hash helper, which hashes by
        # frozen contents (order-independent)
        return super()._hash()

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """the largest number of registers in `self` that a single register
        from `other` can conflict with
        """
        # NOTE(review): lru_cache on a method also caches on `self`, which
        # keeps every RegClass ever queried alive for the process lifetime
        # (ruff B019) -- acceptable only if RegClass instances are few;
        # verify.
        if isinstance(other, RegClass):
            return max(self.max_conflicts_with(i) for i in other)
        else:
            return sum(other.conflicts(i) for i in self)
-
-
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """abstract base class of value types; each type knows the RegClass of
    locations values of that type can be allocated to"""
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """the ordered RegClass of all locations a value of this type can
        be allocated to"""
        return ...
-
-
@plain_data(frozen=True, eq=False)
class GPRRangeType(RegType):
    """type of a value occupying a contiguous range of `length` GPRs"""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        regs = []
        # bug fix: the last valid start index is GPR_COUNT - length
        # *inclusive*, so the range needs + 1 -- the old code silently
        # dropped the final candidate range
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            # skip ranges overlapping reserved GPRs
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        # type: (object) -> bool
        # equality by length only (eq=False disables plain_data's __eq__);
        # note this deliberately makes GPRType(1) == GPRRangeType(1)
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)
-
-
@plain_data(frozen=True, eq=False)
@final
class GPRType(GPRRangeType):
    """type of a value occupying exactly one GPR"""
    __slots__ = ()

    def __init__(self, length=1):
        # type: (int) -> None
        # `length` exists only for signature compatibility with
        # GPRRangeType; it must always be 1
        if length != 1:
            raise ValueError("length must be 1")
        super().__init__(length=1)
-
-
@plain_data(frozen=True, unsafe_hash=True)
@final
class CYType(RegType):
    """type of a carry-flag value -- allocatable only to XERBit.CY"""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        return RegClass([XERBit.CY])
-
-
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """type of the global-memory state value -- allocatable only to the
    GlobalMem singleton"""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        return RegClass([GlobalMem.GlobalMem])
-
-
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """a contiguous run of stack slots usable as a register location"""
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        if start_slot < 0:
            # robustness fix: mirror GPRRange's validation -- a negative
            # start previously produced nonsense slots silently
            raise ValueError("invalid start_slot")
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        # type: () -> int
        """slot index one past the last slot in this run"""
        return self.start_slot + self.length_in_slots

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if isinstance(other, StackSlot):
            # half-open interval overlap test
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # bug fix: `offset` is relative to `self.start_slot`, so it must be
        # compared against `self.length_in_slots`, not the absolute
        # `self.stop_slot` -- the old check let out-of-range offsets
        # through whenever start_slot > 0
        if offset < 0 or offset + subreg_type.length_in_slots \
                > self.length_in_slots:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
-
-
# total number of stack slots available to the allocator
STACK_SLOT_COUNT = 128
-
-
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """type of a value occupying `length_in_slots` contiguous stack slots"""
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        regs = []
        # bug fix: the last valid start slot is
        # STACK_SLOT_COUNT - length_in_slots *inclusive*, so the range
        # needs + 1 -- the old code silently dropped the final candidate
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        # type: (object) -> bool
        # equality by slot count only (eq=False disables plain_data's)
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
-
-
# covariant type variable for the RegType carried by SSAVal and friends
_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
-
-
@plain_data(frozen=True, eq=False)
@final
class SSAVal(Generic[_RegT_co]):
    """a value in SSA form, identified by the Op instance that writes it
    plus the name of the output argument it is written through.

    equality/hashing use the *identity* of `op` and `arg_name` only, so two
    SSAVals are the same value iff they are the same output of the same Op
    instance.
    """
    # bug fix: removed unused "arg_index" from __slots__ -- it was never
    # assigned anywhere, so plain_data-generated methods that walk the
    # declared fields would hit an AttributeError
    __slots__ = "op", "arg_name", "ty"

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegT_co) -> None
        self.op = op  # the Op that writes this SSAVal
        self.arg_name = arg_name  # name of self.op's argument writing this
        self.ty = ty  # the RegType of this value

    def __eq__(self, rhs):
        # type: (object) -> bool
        if isinstance(rhs, SSAVal):
            # op is compared by identity: each Op instance is a distinct
            # writer even if structurally equal to another
            return (self.op is rhs.op
                    and self.arg_name == rhs.arg_name)
        return False

    def __hash__(self):
        return hash((id(self.op), self.arg_name))
-
-
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """constraint that the concatenation of the `lhs` SSAVals must be
    allocated to the same registers as the concatenation of `rhs`"""
    # NOTE(review): lhs/rhs are lists (unhashable), so the unsafe_hash
    # generated hash would raise if ever used -- confirm instances are
    # never used as dict keys / set members
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        self.lhs = lhs
        self.rhs = rhs
        if len(lhs) == 0 or len(rhs) == 0:
            raise ValueError("can't constrain an empty list to be equal")
-
-
@plain_data(unsafe_hash=True, frozen=True)
class Op(metaclass=ABCMeta):
    """abstract base class of all operations in the SSA IR"""
    __slots__ = ()

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to the SSAVal read through it"""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to the SSAVal written through it"""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield the EqualityConstraints this Op imposes; none by default.
        the dead `if False: yield` makes this a generator function."""
        if False:
            yield ...

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """yield pairs of SSAVals that must not share registers beyond
        what liveness alone implies; none by default."""
        if False:
            yield ...

    def __init__(self):
        pass
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoadFromStackSlot(Op):
    # NOTE(review): judging by the types (GPR-range src -> StackSlot dest)
    # this op *writes* a stack slot from a GPR range; the name looks
    # swapped with OpStoreToStackSlot -- confirm against callers
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[GPRRangeType]) -> None
        # dest gets one stack slot per GPR in src
        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
        self.src = src
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStoreToStackSlot(Op):
    # NOTE(review): judging by the types (StackSlot src -> GPR-range dest)
    # this op *reads* a stack slot into a GPR range; the name looks
    # swapped with OpLoadFromStackSlot -- confirm against callers
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[StackSlotType]) -> None
        # dest gets one GPR per stack slot in src
        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
        self.src = src
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op, Generic[_RegT_co]):
    """copy: dest is a fresh value of the same type as src"""
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[_RegT_co]) -> None
        self.dest = SSAVal(self, "dest", src.ty)
        self.src = src
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpConcat(Op):
    """concatenate several GPR-range values into one contiguous value"""
    __slots__ = "dest", "sources"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {f"sources[{i}]": v for i, v in enumerate(self.sources)}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, sources):
        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
        sources = tuple(sources)
        # dest is exactly big enough to hold all sources end-to-end
        self.dest = SSAVal(self, "dest", GPRRangeType(
            sum(i.ty.length for i in sources)))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        # dest must be allocated to the same registers as the sources laid
        # out consecutively, making the concat itself a no-op
        yield EqualityConstraint([self.dest], [*self.sources])
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpSplit(Op):
    """split a GPR-range value at the given indexes into several values"""
    __slots__ = "results", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {i.arg_name: i for i in self.results}

    def __init__(self, src, split_indexes):
        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
        ranges = [] # type: list[GPRRangeType]
        last = 0
        for i in split_indexes:
            # split points must lie strictly inside src; out-of-order
            # (non-increasing) indexes are rejected indirectly by
            # GPRRangeType's length validation below
            if not (0 < i < src.ty.length):
                raise ValueError(f"invalid split index: {i}, must be in "
                                 f"0 < i < {src.ty.length}")
            ranges.append(GPRRangeType(i - last))
            last = i
        # remainder after the last split point
        ranges.append(GPRRangeType(src.ty.length - last))
        self.src = src
        # NOTE(review): arg_name is f"results{i}" without brackets, unlike
        # OpConcat's "sources[{i}]" input keys -- confirm that's intended
        self.results = tuple(
            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        # the results laid out consecutively must share src's registers
        yield EqualityConstraint([*self.results], [self.src])
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpAddSubE(Op):
    """multi-GPR add (or subtract when is_sub) with the carry chained
    through CY: RT = RA +/- RB, consuming CY_in and producing CY_out"""
    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "CY_out": self.CY_out}

    def __init__(self, RA, RB, CY_in, is_sub):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
        if RA.ty != RB.ty:
            raise TypeError(f"source types must match: "
                            f"{RA} doesn't match {RB}")
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.CY_in = CY_in
        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
        self.is_sub = is_sub

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # RT must not overlap either source register range
        yield self.RT, self.RA
        yield self.RT, self.RB
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntMulDiv(Op):
    """big-integer multiply (or divide when is_div) of the multi-GPR RA by
    single-GPR RB, with single-GPR RC chained in and RS chained out"""
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, RA, RB, RC, is_div):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        # RC and RS must be allocated to the same register
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # neither output may overlap any of the inputs still in use
        yield self.RT, self.RA
        yield self.RT, self.RB
        yield self.RT, self.RC
        yield self.RT, self.RS
        yield self.RS, self.RA
        yield self.RS, self.RB
-
-
@final
@unique
class ShiftKind(Enum):
    """which shift operation OpBigIntShift performs"""
    Sl = "sl"    # shift left
    Sr = "sr"    # shift right
    Sra = "sra"  # shift right algebraic (sign-extending)
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntShift(Op):
    """big-integer shift of the multi-GPR `inp` by single-GPR `sh`, with
    the direction/sign behavior chosen by `kind`"""
    __slots__ = "RT", "inp", "sh", "kind"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"inp": self.inp, "sh": self.sh}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT}

    def __init__(self, inp, sh, kind):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
        self.RT = SSAVal(self, "RT", inp.ty)
        self.inp = inp
        self.sh = sh
        self.kind = kind

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # the output range must not overlap either input
        yield self.RT, self.inp
        yield self.RT, self.sh
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLI(Op):
    """load-immediate: materialize constant `value` into a GPR range of
    `length` registers"""
    __slots__ = "out", "value"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, value, length=1):
        # type: (int, int) -> None
        self.out = SSAVal(self, "out", GPRRangeType(length))
        self.value = value
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpClearCY(Op):
    """produce a cleared carry-flag value"""
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self):
        # type: () -> None
        self.out = SSAVal(self, "out", CYType())
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoad(Op):
    """load `length` GPRs from memory addressed by RA + offset; reads the
    global memory state `mem` so loads are ordered against stores"""
    __slots__ = "RT", "RA", "offset", "mem"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RA": self.RA, "mem": self.mem}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT}

    def __init__(self, RA, offset, mem, length=1):
        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
        self.RT = SSAVal(self, "RT", GPRRangeType(length))
        self.RA = RA
        self.offset = offset
        self.mem = mem

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # for multi-register loads RT must not overlap the address register
        # RA -- presumably because RA is still needed while successive RT
        # registers are written; confirm
        if self.RT.ty.length > 1:
            yield self.RT, self.RA
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStore(Op):
    """store the GPR range RS to memory addressed by RA + offset; threads
    the global memory state (mem_in -> mem_out) to order memory ops"""
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"mem_out": self.mem_out}

    def __init__(self, RS, RA, offset, mem_in):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
        self.RS = RS
        self.RA = RA
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpFuncArg(Op):
    """source of a function-argument value of the given RegType"""
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, ty):
        # type: (RegType) -> None
        self.out = SSAVal(self, "out", ty)
-
-
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpInputMem(Op):
    """source of the initial global-memory state on entry"""
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self):
        # type: () -> None
        self.out = SSAVal(self, "out", GlobalMemType())
-
-
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """topologically sort `ops` so every Op appears after the Ops that
    write its inputs; raises ValueError on dependency loops, on inputs
    that are never written, and on multiply-written SSA values.
    note: the order among simultaneously-ready Ops is arbitrary (set pop).
    """
    # buckets[n] holds the Ops still waiting on n unwritten input values
    buckets = [set()]  # type: list[set[Op]]
    readers = defaultdict(set)  # type: dict[SSAVal, set[Op]]
    pending_counts = {}  # type: dict[Op, int]
    for op in ops:
        n_inputs = 0
        for inp in op.inputs().values():
            n_inputs += 1
            readers[inp].add(op)
        while len(buckets) <= n_inputs:
            buckets.append(set())
        pending_counts[op] = n_inputs
        buckets[n_inputs].add(op)
    ordered = []  # type: list[Op]
    written = set()  # type: set[SSAVal]
    # repeatedly emit a ready Op, then retire its outputs, moving each
    # reader one bucket closer to ready
    while len(buckets[0]) != 0:
        producer = buckets[0].pop()
        ordered.append(producer)
        for out in producer.outputs().values():
            if out in written:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {out}")
            written.add(out)
            for reader in readers[out]:
                remaining = pending_counts[reader] - 1
                buckets[remaining + 1].remove(reader)
                buckets[remaining].add(reader)
                pending_counts[reader] = remaining
    # anything left in any bucket never became ready
    for bucket in buckets:
        for op in bucket:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return ordered
-
-
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """interval [first_write, last_use] of op indexes over which a value
    is live"""
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        if last_use is None:
            last_use = first_write
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        # values written at the same op index always overlap; otherwise
        # each must be written strictly before the other's last use
        if self.first_write == other.first_write:
            return True
        return self.last_use > other.first_write \
            and other.last_use > self.first_write

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a new interval extended to also cover a use at `use`"""
        last_use = max(self.last_use, use)
        return LiveInterval(first_write=self.first_write, last_use=last_use)

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)
-
-
@final
class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
    """a set of SSAVals that must be allocated together as one unit, as an
    immutable, hashable Mapping from each SSAVal to its offset (in GPRs)
    within the merged value"""
    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
        self.__items = {} # type: dict[SSAVal[_RegT_co], int]
        if isinstance(reg_set, SSAVal):
            # a bare SSAVal means a singleton set at offset 0
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            if ssa_val in self.__items:
                other = self.__items[ssa_val]
                # the same SSAVal may repeat only at an identical offset
                if offset != other:
                    raise ValueError(
                        f"can't merge register sets: conflicting offsets: "
                        f"for {ssa_val}: {offset} != {other}")
            else:
                self.__items[ssa_val] = offset
        # grab an arbitrary first item to establish the overall type
        first_item = None
        for i in self.__items.items():
            first_item = i
            break
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty
        if isinstance(ty, GPRRangeType):
            # GPR ranges may sit at arbitrary offsets: compute the bounding
            # [start, stop) range and type the set by its total length
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            # non-GPR types must all be identical and at offset 0
            stop = 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start # type: int
        self.__stop = stop # type: int
        self.__ty = ty # type: RegType
        # precomputed hash: contents never change after construction
        self.__hash = hash(frozenset(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
        """build a MergedRegSet from one side of an EqualityConstraint,
        laying the values out consecutively"""
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        offset = 0
        retval = []
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            retval.append((val, offset))
            offset += val.ty.length
        return MergedRegSet(retval)

    @property
    def ty(self):
        # type: () -> RegType
        """the RegType covering the whole merged set"""
        return self.__ty

    @property
    def stop(self):
        # type: () -> int
        return self.__stop

    @property
    def start(self):
        # type: () -> int
        return self.__start

    @property
    def range(self):
        # type: () -> range
        """the offset range [start, stop) covered by the merged set"""
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegT_co]
        """return a copy with every offset shifted by `amount`"""
        return MergedRegSet((k, v + amount) for k, v in self.items())

    def normalized(self):
        # type: () -> MergedRegSet[_RegT_co]
        """return a copy shifted so that start == 0"""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
        """return a copy shifted so SSAVals shared with `target` line up"""
        for ssa_val, offset in self.items():
            if ssa_val in target:
                return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegT_co]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
-
-
@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
    """map from every SSAVal used by `ops` to the MergedRegSet it belongs
    to, after union-merging all equality constraints"""
    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]]
        for op in ops:
            # start every value in its own singleton set
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                # merge both sides into one set with aligned offsets
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
                # point every member at the new combined set
                for val in full_set.keys():
                    merged_sets[val] = full_set

        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"
-
-
@final
class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
    """map from each MergedRegSet used by `ops` to its LiveInterval.

    assumes `ops` is ordered so each value's writer appears before all of
    its readers (e.g. output of op_set_to_list); an input read before any
    write raises KeyError.
    """
    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # extend the interval of the (already-written) value
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # precompute, per op index, the sets live immediately after that op
        live_after = [] # type: list[set[MergedRegSet[_RegT_co]]]
        live_after += (set() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [frozenset(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        # bug fix: Mapping declares __len__ abstract; without it
        # LiveIntervals could not be instantiated at all (TypeError)
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
-
-
@final
class IGNode(Generic[_RegT_co]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set  # the value this node colors
        self.edges = set(edges)  # interfering neighbor nodes
        self.reg = reg  # assigned RegLoc, or None while uncolored

    def add_edge(self, other):
        # type: (IGNode) -> None
        """add a symmetric interference edge between self and other"""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        # identity is the merged_reg_set only; mutable edges/reg are
        # deliberately excluded so mutation doesn't change hashing
        if isinstance(other, IGNode):
            return self.merged_reg_set == other.merged_reg_set
        return NotImplemented

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` records already-printed nodes so the recursion through
        # edges terminates on graph cycles
        if nodes is None:
            nodes = {}
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, "
                f"reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        """the RegClass this node's value can be allocated to"""
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """return True if assigning `reg` would conflict with a register
        already assigned to any neighbor"""
        for neighbor in self.edges:
            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
                return True
        return False
-
-
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
    """interference graph, as a Mapping from each MergedRegSet to its
    IGNode; edges are added afterwards via IGNode.add_edge"""
    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegT_co]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        # bug fix: Mapping declares __len__ abstract; without it the class
        # could not be instantiated (TypeError at the construction site in
        # try_allocate_registers_without_spilling)
        return len(self.__nodes)

    def __repr__(self):
        nodes = {}
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
@plain_data()
class AllocationFailed:
    """result returned when the allocator can't color a node; carries the
    failing node plus the analyses, for diagnostics/spilling"""
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        self.node = node  # the IGNode that could not be colored
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
-
-
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """attempt graph-coloring register allocation for `ops` (which must be
    ordered so each value's writer precedes its readers, e.g. output of
    op_set_to_list). returns a map from every SSAVal to its assigned
    RegLoc, or AllocationFailed when some node can't be colored -- the
    caller is then expected to spill and retry.
    """

    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    # build interference edges: values live at the same time interfere,
    # unless their register classes can never conflict anyway
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = set(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly remove the most-colorable node; nodes
    # removed later get colored earlier (LIFO stack order)
    node_stack = [] # type: list[IGNode]
    while True:
        best_node = None # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {} # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes and assign each the first register that
    # doesn't conflict with its already-colored neighbors
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # pre-assigned register: just verify it still fits
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        # translate the merged set's single RegLoc into per-SSAVal sub-regs
        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
-
-
def allocate_registers(ops):
    # type: (list[Op]) -> None
    # TODO: full allocator -- presumably should call
    # try_allocate_registers_without_spilling and spill/retry on
    # AllocationFailed
    raise NotImplementedError