From: Jacob Lifshay Date: Fri, 14 Oct 2022 06:09:38 +0000 (-0700) Subject: split compiler IR and register allocator out into their own files X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=096a65b3d3ad74faf28776f32f4b5513e4b8278d;p=bigint-presentation-code.git split compiler IR and register allocator out into their own files --- diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py new file mode 100644 index 0000000..645ff0e --- /dev/null +++ b/src/bigint_presentation_code/compiler_ir.py @@ -0,0 +1,783 @@ +""" +Compiler IR for Toom-Cook algorithm generator for SVP64 +""" + +from abc import ABCMeta, abstractmethod +from collections import defaultdict +from enum import Enum, EnumMeta, unique +from functools import lru_cache +from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence, + TypeVar) + +from nmutil.plain_data import plain_data + +if TYPE_CHECKING: + from typing_extensions import final +else: + def final(v): + return v + + +class ABCEnumMeta(EnumMeta, ABCMeta): + pass + + +class RegLoc(metaclass=ABCMeta): + __slots__ = () + + @abstractmethod + def conflicts(self, other): + # type: (RegLoc) -> bool + ... + + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> RegLoc + if self not in subreg_type.reg_class: + raise ValueError(f"register not a member of subreg_type: " + f"reg={self} subreg_type={subreg_type}") + if offset != 0: + raise ValueError(f"non-zero sub-register offset not supported " + f"for register: {self}") + return self + + +GPR_COUNT = 128 + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GPRRange(RegLoc, Sequence["GPRRange"]): + __slots__ = "start", "length" + + def __init__(self, start, length=None): + # type: (int | range, int | None) -> None + if isinstance(start, range): + if length is not None: + raise TypeError("can't specify length when input is a range") + if start.step != 1: + raise ValueError("range must have a step of 1") + length = len(start) + start = start.start + elif length is None: + length = 1 + if length <= 0 or start < 0 or start + length > GPR_COUNT: + raise ValueError("invalid GPRRange") + self.start = start + self.length = length + + @property + def stop(self): + return self.start + self.length + + @property + def step(self): + return 1 + + @property + def range(self): + return range(self.start, self.stop, self.step) + + def __len__(self): + return self.length + + def __getitem__(self, item): + # type: (int | slice) -> GPRRange + return GPRRange(self.range[item]) + + def __contains__(self, value): + # type: (GPRRange) -> bool + return value.start >= self.start and value.stop <= self.stop + + def index(self, sub, start=None, end=None): + # type: (GPRRange, int | None, int | None) -> int + r = self.range[start:end] + if sub.start < r.start or sub.stop > r.stop: + raise ValueError("GPR range not found") + return sub.start - self.start + + def count(self, sub, start=None, end=None): + # type: (GPRRange, int | None, int | None) -> int + r = self.range[start:end] + if len(r) == 0: + return 0 + return int(sub in GPRRange(r)) + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, GPRRange): + return self.stop > other.start and other.stop > self.start + return False + + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> GPRRange + if not isinstance(subreg_type, GPRRangeType): + raise ValueError(f"subreg_type is not a " + f"GPRRangeType: {subreg_type}") + if offset < 0 or offset + subreg_type.length > self.stop: + raise ValueError(f"sub-register offset is out of range: {offset}") + return GPRRange(self.start + offset, subreg_type.length) + + +SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) + + +@final +@unique +class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta): + CY = "CY" + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, XERBit): + return self == other + return False + + +@final +@unique +class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): + """singleton representing all non-StackSlot memory -- treated as a single + physical register for register allocation purposes. + """ + GlobalMem = "GlobalMem" + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, GlobalMem): + return self == other + return False + + +@final +class RegClass(AbstractSet[RegLoc]): + """ an ordered set of registers. + earlier registers are preferred by the register allocator. + """ + + def __init__(self, regs): + # type: (Iterable[RegLoc]) -> None + + # use dict to maintain order + self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None] + + def __len__(self): + return len(self.__regs) + + def __iter__(self): + return iter(self.__regs) + + def __contains__(self, v): + # type: (RegLoc) -> bool + return v in self.__regs + + def __hash__(self): + return super()._hash() + + @lru_cache(maxsize=None, typed=True) + def max_conflicts_with(self, other): + # type: (RegClass | RegLoc) -> int + """the largest number of registers in `self` that a single register + from `other` can conflict with + """ + if isinstance(other, RegClass): + return max(self.max_conflicts_with(i) for i in other) + else: + return sum(other.conflicts(i) for i in self) + + +@plain_data(frozen=True, unsafe_hash=True) +class RegType(metaclass=ABCMeta): + __slots__ = () + + @property + @abstractmethod + def reg_class(self): + # type: () -> RegClass + return ... + + +_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) + + +@plain_data(frozen=True, eq=False) +class GPRRangeType(RegType): + __slots__ = "length", + + def __init__(self, length): + # type: (int) -> None + if length < 1 or length > GPR_COUNT: + raise ValueError("invalid length") + self.length = length + + @staticmethod + @lru_cache(maxsize=None) + def __get_reg_class(length): + # type: (int) -> RegClass + regs = [] + for start in range(GPR_COUNT - length): + reg = GPRRange(start, length) + if any(i in reg for i in SPECIAL_GPRS): + continue + regs.append(reg) + return RegClass(regs) + + @property + def reg_class(self): + # type: () -> RegClass + return GPRRangeType.__get_reg_class(self.length) + + @final + def __eq__(self, other): + if isinstance(other, GPRRangeType): + return self.length == other.length + return False + + @final + def __hash__(self): + return hash(self.length) + + +@plain_data(frozen=True, eq=False) +@final +class GPRType(GPRRangeType): + __slots__ = () + + def __init__(self, length=1): + if length != 1: + raise ValueError("length must be 1") + super().__init__(length=1) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class CYType(RegType): + __slots__ = () + + @property + def reg_class(self): + # type: () -> RegClass + return RegClass([XERBit.CY]) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GlobalMemType(RegType): + __slots__ = () + + @property + def reg_class(self): + # type: () -> RegClass + return RegClass([GlobalMem.GlobalMem]) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class StackSlot(RegLoc): + __slots__ = "start_slot", "length_in_slots", + + def __init__(self, start_slot, length_in_slots): + # type: (int, int) -> None + self.start_slot = start_slot + if length_in_slots < 1: + raise ValueError("invalid length_in_slots") + self.length_in_slots = length_in_slots + + @property + def stop_slot(self): + return self.start_slot + self.length_in_slots + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, StackSlot): + return (self.stop_slot > other.start_slot + and other.stop_slot > self.start_slot) + return False + + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> StackSlot + if not isinstance(subreg_type, StackSlotType): + raise ValueError(f"subreg_type is not a " + f"StackSlotType: {subreg_type}") + if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot: + raise ValueError(f"sub-register offset is out of range: {offset}") + return StackSlot(self.start_slot + offset, subreg_type.length_in_slots) + + +STACK_SLOT_COUNT = 128 + + +@plain_data(frozen=True, eq=False) +@final +class StackSlotType(RegType): + __slots__ = "length_in_slots", + + def __init__(self, length_in_slots=1): + # type: (int) -> None + if length_in_slots < 1: + raise ValueError("invalid length_in_slots") + self.length_in_slots = length_in_slots + + @staticmethod + @lru_cache(maxsize=None) + def __get_reg_class(length_in_slots): + # type: (int) -> RegClass + regs = [] + for start in range(STACK_SLOT_COUNT - length_in_slots): + reg = StackSlot(start, length_in_slots) + regs.append(reg) + return RegClass(regs) + + @property + def reg_class(self): + # type: () -> RegClass + return StackSlotType.__get_reg_class(self.length_in_slots) + + @final + def __eq__(self, other): + if isinstance(other, StackSlotType): + return self.length_in_slots == other.length_in_slots + return False + + @final + def __hash__(self): + return hash(self.length_in_slots) + + +@plain_data(frozen=True, eq=False) +@final +class SSAVal(Generic[_RegT_co]): + __slots__ = "op", "arg_name", "ty", "arg_index" + + def __init__(self, op, arg_name, ty): + # type: (Op, str, _RegT_co) -> None + self.op = op + """the Op that writes this SSAVal""" + + self.arg_name = arg_name + """the name of the argument of self.op that writes this SSAVal""" + + self.ty = ty + + def __eq__(self, rhs): + if isinstance(rhs, SSAVal): + return (self.op is rhs.op + and self.arg_name == rhs.arg_name) + return False + + def __hash__(self): + return hash((id(self.op), self.arg_name)) + + +@final +@plain_data(unsafe_hash=True, frozen=True) +class EqualityConstraint: + __slots__ = "lhs", "rhs" + + def __init__(self, lhs, rhs): + # type: (list[SSAVal], list[SSAVal]) -> None + self.lhs = lhs + self.rhs = rhs + if len(lhs) == 0 or len(rhs) == 0: + raise ValueError("can't constrain an empty list to be equal") + + +@plain_data(unsafe_hash=True, frozen=True) +class Op(metaclass=ABCMeta): + __slots__ = () + + @abstractmethod + def inputs(self): + # type: () -> dict[str, SSAVal] + ... + + @abstractmethod + def outputs(self): + # type: () -> dict[str, SSAVal] + ... + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + if False: + yield ... + + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + if False: + yield ... + + def __init__(self): + pass + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLoadFromStackSlot(Op): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, src): + # type: (SSAVal[GPRRangeType]) -> None + self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) + self.src = src + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpStoreToStackSlot(Op): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, src): + # type: (SSAVal[StackSlotType]) -> None + self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) + self.src = src + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpCopy(Op, Generic[_RegT_co]): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, src): + # type: (SSAVal[_RegT_co]) -> None + self.dest = SSAVal(self, "dest", src.ty) + self.src = src + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpConcat(Op): + __slots__ = "dest", "sources" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {f"sources[{i}]": v for i, v in enumerate(self.sources)} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, sources): + # type: (Iterable[SSAVal[GPRRangeType]]) -> None + sources = tuple(sources) + self.dest = SSAVal(self, "dest", GPRRangeType( + sum(i.ty.length for i in sources))) + self.sources = sources + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + yield EqualityConstraint([self.dest], [*self.sources]) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpSplit(Op): + __slots__ = "results", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {i.arg_name: i for i in self.results} + + def __init__(self, src, split_indexes): + # type: (SSAVal[GPRRangeType], Iterable[int]) -> None + ranges = [] # type: list[GPRRangeType] + last = 0 + for i in split_indexes: + if not (0 < i < src.ty.length): + raise ValueError(f"invalid split index: {i}, must be in " + f"0 < i < {src.ty.length}") + ranges.append(GPRRangeType(i - last)) + last = i + ranges.append(GPRRangeType(src.ty.length - last)) + self.src = src + self.results = tuple( + SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges)) + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + yield EqualityConstraint([*self.results], [self.src]) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpAddSubE(Op): + __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"RT": self.RT, "CY_out": self.CY_out} + + def __init__(self, RA, RB, CY_in, is_sub): + # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None + if RA.ty != RB.ty: + raise TypeError(f"source types must match: " + f"{RA} doesn't match {RB}") + self.RT = SSAVal(self, "RT", RA.ty) + self.RA = RA + self.RB = RB + self.CY_in = CY_in + self.CY_out = SSAVal(self, "CY_out", CY_in.ty) + self.is_sub = is_sub + + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.RA + yield self.RT, self.RB + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpBigIntMulDiv(Op): + __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"RA": self.RA, "RB": self.RB, "RC": self.RC} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"RT": self.RT, "RS": self.RS} + + def __init__(self, RA, RB, RC, is_div): + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None + self.RT = SSAVal(self, "RT", RA.ty) + self.RA = RA + self.RB = RB + self.RC = RC + self.RS = SSAVal(self, "RS", RC.ty) + self.is_div = is_div + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + yield EqualityConstraint([self.RC], [self.RS]) + + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.RA + yield self.RT, self.RB + yield self.RT, self.RC + yield self.RT, self.RS + yield self.RS, self.RA + yield self.RS, self.RB + + +@final +@unique +class ShiftKind(Enum): + Sl = "sl" + Sr = "sr" + Sra = "sra" + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpBigIntShift(Op): + __slots__ = "RT", "inp", "sh", "kind" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"inp": self.inp, "sh": self.sh} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"RT": self.RT} + + def __init__(self, inp, sh, kind): + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None + self.RT = SSAVal(self, "RT", inp.ty) + self.inp = inp + self.sh = sh + self.kind = kind + + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.inp + yield self.RT, self.sh + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLI(Op): + __slots__ = "out", "value" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"out": self.out} + + def __init__(self, value, length=1): + # type: (int, int) -> None + self.out = SSAVal(self, "out", GPRRangeType(length)) + self.value = value + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpClearCY(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"out": self.out} + + def __init__(self): + # type: () -> None + self.out = SSAVal(self, "out", CYType()) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLoad(Op): + __slots__ = "RT", "RA", "offset", "mem" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"RA": self.RA, "mem": self.mem} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"RT": self.RT} + + def __init__(self, RA, offset, mem, length=1): + # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None + self.RT = SSAVal(self, "RT", GPRRangeType(length)) + self.RA = RA + self.offset = offset + self.mem = mem + + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + if self.RT.ty.length > 1: + yield self.RT, self.RA + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpStore(Op): + __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"mem_out": self.mem_out} + + def __init__(self, RS, RA, offset, mem_in): + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None + self.RS = RS + self.RA = RA + self.offset = offset + self.mem_in = mem_in + self.mem_out = SSAVal(self, "mem_out", mem_in.ty) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpFuncArg(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"out": self.out} + + def __init__(self, ty): + # type: (RegType) -> None + self.out = SSAVal(self, "out", ty) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpInputMem(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"out": self.out} + + def __init__(self): + # type: () -> None + self.out = SSAVal(self, "out", GlobalMemType()) + + +def op_set_to_list(ops): + # type: (Iterable[Op]) -> list[Op] + worklists = [set()] # type: list[set[Op]] + input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]] + ops_to_pending_input_count_map = {} # type: dict[Op, int] + for op in ops: + input_count = 0 + for val in op.inputs().values(): + input_count += 1 + input_vals_to_ops_map[val].add(op) + while len(worklists) <= input_count: + worklists.append(set()) + ops_to_pending_input_count_map[op] = input_count + worklists[input_count].add(op) + retval = [] # type: list[Op] + ready_vals = set() # type: set[SSAVal] + while len(worklists[0]) != 0: + writing_op = worklists[0].pop() + retval.append(writing_op) + for val in writing_op.outputs().values(): + if val in ready_vals: + raise ValueError(f"multiple instructions must not write " + f"to the same SSA value: {val}") + ready_vals.add(val) + for reading_op in input_vals_to_ops_map[val]: + pending = ops_to_pending_input_count_map[reading_op] + worklists[pending].remove(reading_op) + pending -= 1 + worklists[pending].add(reading_op) + ops_to_pending_input_count_map[reading_op] = pending + for worklist in worklists: + for op in worklist: + raise ValueError(f"instruction is part of a dependency loop or " + f"its inputs are never written: {op}") + return retval diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py new file mode 100644 index 0000000..297e3e5 --- /dev/null +++ b/src/bigint_presentation_code/register_allocator.py @@ -0,0 +1,414 @@ +""" +Register Allocator for Toom-Cook algorithm generator for SVP64 + +this uses an algorithm based on: +[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) +""" + +from itertools import combinations +from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar + +from nmutil.plain_data import plain_data + +from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, + RegLoc, RegType, SSAVal) + +if TYPE_CHECKING: + from typing_extensions import Self, final +else: + def final(v): + return v + + +_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) + + +@plain_data(unsafe_hash=True, order=True, frozen=True) +class LiveInterval: + __slots__ = "first_write", "last_use" + + def __init__(self, first_write, last_use=None): + # type: (int, int | None) -> None + if last_use is None: + last_use = first_write + if last_use < first_write: + raise ValueError("uses must be after first_write") + if first_write < 0 or last_use < 0: + raise ValueError("indexes must be nonnegative") + self.first_write = first_write + self.last_use = last_use + + def overlaps(self, other): + # type: (LiveInterval) -> bool + if self.first_write == other.first_write: + return True + return self.last_use > other.first_write \ + and other.last_use > self.first_write + + def __add__(self, use): + # type: (int) -> LiveInterval + last_use = max(self.last_use, use) + return LiveInterval(first_write=self.first_write, last_use=last_use) + + @property + def live_after_op_range(self): + """the range of op indexes where self is live immediately after the + Op at each index + """ + return range(self.first_write, self.last_use) + + +@final +class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): + def __init__(self, reg_set): + # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None + self.__items = {} # type: dict[SSAVal[_RegT_co], int] + if isinstance(reg_set, SSAVal): + reg_set = [(reg_set, 0)] + for ssa_val, offset in reg_set: + if ssa_val in self.__items: + other = self.__items[ssa_val] + if offset != other: + raise ValueError( + f"can't merge register sets: conflicting offsets: " + f"for {ssa_val}: {offset} != {other}") + else: + self.__items[ssa_val] = offset + first_item = None + for i in self.__items.items(): + first_item = i + break + if first_item is None: + raise ValueError("can't have empty MergedRegs") + first_ssa_val, start = first_item + ty = first_ssa_val.ty + if isinstance(ty, GPRRangeType): + stop = start + ty.length + for ssa_val, offset in self.__items.items(): + if not isinstance(ssa_val.ty, GPRRangeType): + raise ValueError(f"can't merge incompatible types: " + f"{ssa_val.ty} and {ty}") + stop = max(stop, offset + ssa_val.ty.length) + start = min(start, offset) + ty = GPRRangeType(stop - start) + else: + stop = 1 + for ssa_val, offset in self.__items.items(): + if offset != 0: + raise ValueError(f"can't have non-zero offset " + f"for {ssa_val.ty}") + if ty != ssa_val.ty: + raise ValueError(f"can't merge incompatible types: " + f"{ssa_val.ty} and {ty}") + self.__start = start # type: int + self.__stop = stop # type: int + self.__ty = ty # type: RegType + self.__hash = hash(frozenset(self.items())) + + @staticmethod + def from_equality_constraint(constraint_sequence): + # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co] + if len(constraint_sequence) == 1: + # any type allowed with len = 1 + return MergedRegSet(constraint_sequence[0]) + offset = 0 + retval = [] + for val in constraint_sequence: + if not isinstance(val.ty, GPRRangeType): + raise ValueError("equality constraint sequences must only " + "have SSAVal type GPRRangeType") + retval.append((val, offset)) + offset += val.ty.length + return MergedRegSet(retval) + + @property + def ty(self): + return self.__ty + + @property + def stop(self): + return self.__stop + + @property + def start(self): + return self.__start + + @property + def range(self): + return range(self.__start, self.__stop) + + def offset_by(self, amount): + # type: (int) -> MergedRegSet[_RegT_co] + return MergedRegSet((k, v + amount) for k, v in self.items()) + + def normalized(self): + # type: () -> MergedRegSet[_RegT_co] + return self.offset_by(-self.start) + + def with_offset_to_match(self, target): + # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co] + for ssa_val, offset in self.items(): + if ssa_val in target: + return self.offset_by(target[ssa_val] - offset) + raise ValueError("can't change offset to match unrelated MergedRegSet") + + def __getitem__(self, item): + # type: (SSAVal[_RegT_co]) -> int + return self.__items[item] + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def __hash__(self): + return self.__hash + + def __repr__(self): + return f"MergedRegSet({list(self.__items.items())})" + + +@final +class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]): + def __init__(self, ops): + # type: (Iterable[Op]) -> None + merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]] + for op in ops: + for val in (*op.inputs().values(), *op.outputs().values()): + if val not in merged_sets: + merged_sets[val] = MergedRegSet(val) + for e in op.get_equality_constraints(): + lhs_set = MergedRegSet.from_equality_constraint(e.lhs) + rhs_set = MergedRegSet.from_equality_constraint(e.rhs) + lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set) + rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set) + full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()]) + for val in full_set.keys(): + merged_sets[val] = full_set + + self.__map = {k: v.normalized() for k, v in merged_sets.items()} + + def __getitem__(self, key): + # type: (SSAVal) -> MergedRegSet + return self.__map[key] + + def __iter__(self): + return iter(self.__map) + + def __len__(self): + return len(self.__map) + + def __repr__(self): + return f"MergedRegSets(data={self.__map})" + + +@final +class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): + def __init__(self, ops): + # type: (list[Op]) -> None + self.__merged_reg_sets = MergedRegSets(ops) + live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval] + for op_idx, op in enumerate(ops): + for val in op.inputs().values(): + live_intervals[self.__merged_reg_sets[val]] += op_idx + for val in op.outputs().values(): + reg_set = self.__merged_reg_sets[val] + if reg_set not in live_intervals: + live_intervals[reg_set] = LiveInterval(op_idx) + else: + live_intervals[reg_set] += op_idx + self.__live_intervals = live_intervals + live_after = [] # type: list[set[MergedRegSet[_RegT_co]]] + live_after += (set() for _ in ops) + for reg_set, live_interval in self.__live_intervals.items(): + for i in live_interval.live_after_op_range: + live_after[i].add(reg_set) + self.__live_after = [frozenset(i) for i in live_after] + + @property + def merged_reg_sets(self): + return self.__merged_reg_sets + + def __getitem__(self, key): + # type: (MergedRegSet[_RegT_co]) -> LiveInterval + return self.__live_intervals[key] + + def __iter__(self): + return iter(self.__live_intervals) + + def reg_sets_live_after(self, op_index): + # type: (int) -> frozenset[MergedRegSet[_RegT_co]] + return self.__live_after[op_index] + + def __repr__(self): + reg_sets_live_after = dict(enumerate(self.__live_after)) + return (f"LiveIntervals(live_intervals={self.__live_intervals}, " + f"merged_reg_sets={self.merged_reg_sets}, " + f"reg_sets_live_after={reg_sets_live_after})") + + +@final +class IGNode(Generic[_RegT_co]): + """ interference graph node """ + __slots__ = "merged_reg_set", "edges", "reg" + + def __init__(self, merged_reg_set, edges=(), reg=None): + # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None + self.merged_reg_set = merged_reg_set + self.edges = set(edges) + self.reg = reg + + def add_edge(self, other): + # type: (IGNode) -> None + self.edges.add(other) + other.edges.add(self) + + def __eq__(self, other): + # type: (object) -> bool + if isinstance(other, IGNode): + return self.merged_reg_set == other.merged_reg_set + return NotImplemented + + def __hash__(self): + return hash(self.merged_reg_set) + + def __repr__(self, nodes=None): + # type: (None | dict[IGNode, int]) -> str + if nodes is None: + nodes = {} + if self in nodes: + return f"" + nodes[self] = len(nodes) + edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" + return (f"IGNode(#{nodes[self]}, " + f"merged_reg_set={self.merged_reg_set}, " + f"edges={edges}, " + f"reg={self.reg})") + + @property + def reg_class(self): + # type: () -> RegClass + return self.merged_reg_set.ty.reg_class + + def reg_conflicts_with_neighbors(self, reg): + # type: (RegLoc) -> bool + for neighbor in self.edges: + if neighbor.reg is not None and neighbor.reg.conflicts(reg): + return True + return False + + +@final +class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]): + def __init__(self, merged_reg_sets): + # type: (Iterable[MergedRegSet[_RegT_co]]) -> None + self.__nodes = {i: IGNode(i) for i in merged_reg_sets} + + def __getitem__(self, key): + # type: (MergedRegSet[_RegT_co]) -> IGNode + return self.__nodes[key] + + def __iter__(self): + return iter(self.__nodes) + + def __repr__(self): + nodes = {} + nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] + nodes_text = ", ".join(nodes_text) + return f"InterferenceGraph(nodes={{{nodes_text}}})" + + +@plain_data() +class AllocationFailed: + __slots__ = "node", "live_intervals", "interference_graph" + + def __init__(self, node, live_intervals, interference_graph): + # type: (IGNode, LiveIntervals, InterferenceGraph) -> None + self.node = node + self.live_intervals = live_intervals + self.interference_graph = interference_graph + + +def try_allocate_registers_without_spilling(ops): + # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed + + live_intervals = LiveIntervals(ops) + merged_reg_sets = live_intervals.merged_reg_sets + interference_graph = InterferenceGraph(merged_reg_sets.values()) + for op_idx, op in enumerate(ops): + reg_sets = live_intervals.reg_sets_live_after(op_idx) + for i, j in combinations(reg_sets, 2): + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) + for i, j in op.get_extra_interferences(): + i = merged_reg_sets[i] + j = merged_reg_sets[j] + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) + + nodes_remaining = set(interference_graph.values()) + + def local_colorability_score(node): + # type: (IGNode) -> int + """ returns a positive integer if node is locally colorable, returns + zero or a negative integer if node isn't known to be locally + colorable, the more negative the value, the less colorable + """ + if node not in nodes_remaining: + raise ValueError() + retval = len(node.reg_class) + for neighbor in node.edges: + if neighbor in nodes_remaining: + retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) + return retval + + node_stack = [] # type: list[IGNode] + while True: + best_node = None # type: None | IGNode + best_score = 0 + for node in nodes_remaining: + score = local_colorability_score(node) + if best_node is None or score > best_score: + best_node = node + best_score = score + if best_score > 0: + # it's locally colorable, no need to find a better one + break + + if best_node is None: + break + node_stack.append(best_node) + nodes_remaining.remove(best_node) + + retval = {} # type: dict[SSAVal, RegLoc] + + while len(node_stack) > 0: + node = node_stack.pop() + if node.reg is not None: + if node.reg_conflicts_with_neighbors(node.reg): + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + else: + # pick the first non-conflicting register in node.reg_class, since + # register classes are ordered from most preferred to least + # preferred register. + for reg in node.reg_class: + if not node.reg_conflicts_with_neighbors(reg): + node.reg = reg + break + if node.reg is None: + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + + for ssa_val, offset in node.merged_reg_set.items(): + retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) + + return retval + + +def allocate_registers(ops): + # type: (list[Op]) -> None + raise NotImplementedError diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py new file mode 100644 index 0000000..231f2fb --- /dev/null +++ b/src/bigint_presentation_code/test_compiler_ir.py @@ -0,0 +1,11 @@ +import unittest + +from bigint_presentation_code.compiler_ir import Op, op_set_to_list + + +class TestCompilerIR(unittest.TestCase): + pass # no tests yet, just testing importing + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py new file mode 100644 index 0000000..1345cc7 --- /dev/null +++ b/src/bigint_presentation_code/test_register_allocator.py @@ -0,0 +1,14 @@ +import unittest + +from bigint_presentation_code.compiler_ir import Op +from bigint_presentation_code.register_allocator import ( + AllocationFailed, allocate_registers, + try_allocate_registers_without_spilling) + + +class TestCompilerIR(unittest.TestCase): + pass # no tests yet, just testing importing + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py index 9851afb..8fe6cea 100644 --- a/src/bigint_presentation_code/test_toom_cook.py +++ b/src/bigint_presentation_code/test_toom_cook.py @@ -1,5 +1,6 @@ import unittest -from bigint_presentation_code.toom_cook import Op + +import bigint_presentation_code.toom_cook class TestToomCook(unittest.TestCase): diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 88e8f7e..c014d09 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -4,1174 +4,5 @@ Toom-Cook algorithm generator for SVP64 the register allocator uses an algorithm based on: [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) """ - -from abc import ABCMeta, abstractmethod -from collections import defaultdict -from enum import Enum, unique, EnumMeta -from functools import lru_cache -from itertools import combinations -from typing import (Sequence, AbstractSet, Iterable, Mapping, - TYPE_CHECKING, Sequence, TypeVar, Generic) - -from nmutil.plain_data import plain_data - -if TYPE_CHECKING: - from typing_extensions import final, Self -else: - def final(v): - return v - - -class ABCEnumMeta(EnumMeta, ABCMeta): - pass - - -class RegLoc(metaclass=ABCMeta): - __slots__ = () - - @abstractmethod - def conflicts(self, other): - # type: (RegLoc) -> bool - ... - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> RegLoc - if self not in subreg_type.reg_class: - raise ValueError(f"register not a member of subreg_type: " - f"reg={self} subreg_type={subreg_type}") - if offset != 0: - raise ValueError(f"non-zero sub-register offset not supported " - f"for register: {self}") - return self - - -GPR_COUNT = 128 - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GPRRange(RegLoc, Sequence["GPRRange"]): - __slots__ = "start", "length" - - def __init__(self, start, length=None): - # type: (int | range, int | None) -> None - if isinstance(start, range): - if length is not None: - raise TypeError("can't specify length when input is a range") - if start.step != 1: - raise ValueError("range must have a step of 1") - length = len(start) - start = start.start - elif length is None: - length = 1 - if length <= 0 or start < 0 or start + length > GPR_COUNT: - raise ValueError("invalid GPRRange") - self.start = start - self.length = length - - @property - def stop(self): - return self.start + self.length - - @property - def step(self): - return 1 - - @property - def range(self): - return range(self.start, self.stop, self.step) - - def __len__(self): - return self.length - - def __getitem__(self, item): - # type: (int | slice) -> GPRRange - return GPRRange(self.range[item]) - - def __contains__(self, value): - # type: (GPRRange) -> bool - return value.start >= self.start and value.stop <= self.stop - - def index(self, sub, start=None, end=None): - # type: (GPRRange, int | None, int | None) -> int - r = self.range[start:end] - if sub.start < r.start or sub.stop > r.stop: - raise ValueError("GPR range not found") - return sub.start - self.start - - def count(self, sub, start=None, end=None): - # type: (GPRRange, int | None, int | None) -> int - r = self.range[start:end] - if len(r) == 0: - return 0 - return int(sub in GPRRange(r)) - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, GPRRange): - return self.stop > other.start and other.stop > self.start - return False - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> GPRRange - if not isinstance(subreg_type, GPRRangeType): - raise ValueError(f"subreg_type is not a " - f"GPRRangeType: {subreg_type}") - if offset < 0 or offset + subreg_type.length > self.stop: - raise ValueError(f"sub-register offset is out of range: {offset}") - return GPRRange(self.start + offset, subreg_type.length) - - -SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) - - -@final -@unique -class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta): - CY = "CY" - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, XERBit): - return self == other - return False - - -@final -@unique -class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): - """singleton representing all non-StackSlot memory -- treated as a single - physical register for register allocation purposes. - """ - GlobalMem = "GlobalMem" - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, GlobalMem): - return self == other - return False - - -@final -class RegClass(AbstractSet[RegLoc]): - """ an ordered set of registers. - earlier registers are preferred by the register allocator. - """ - def __init__(self, regs): - # type: (Iterable[RegLoc]) -> None - - # use dict to maintain order - self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None] - - def __len__(self): - return len(self.__regs) - - def __iter__(self): - return iter(self.__regs) - - def __contains__(self, v): - # type: (RegLoc) -> bool - return v in self.__regs - - def __hash__(self): - return super()._hash() - - @lru_cache(maxsize=None, typed=True) - def max_conflicts_with(self, other): - # type: (RegClass | RegLoc) -> int - """the largest number of registers in `self` that a single register - from `other` can conflict with - """ - if isinstance(other, RegClass): - return max(self.max_conflicts_with(i) for i in other) - else: - return sum(other.conflicts(i) for i in self) - - -@plain_data(frozen=True, unsafe_hash=True) -class RegType(metaclass=ABCMeta): - __slots__ = () - - @property - @abstractmethod - def reg_class(self): - # type: () -> RegClass - return ... - - -@plain_data(frozen=True, eq=False) -class GPRRangeType(RegType): - __slots__ = "length", - - def __init__(self, length): - # type: (int) -> None - if length < 1 or length > GPR_COUNT: - raise ValueError("invalid length") - self.length = length - - @staticmethod - @lru_cache(maxsize=None) - def __get_reg_class(length): - # type: (int) -> RegClass - regs = [] - for start in range(GPR_COUNT - length): - reg = GPRRange(start, length) - if any(i in reg for i in SPECIAL_GPRS): - continue - regs.append(reg) - return RegClass(regs) - - @property - def reg_class(self): - # type: () -> RegClass - return GPRRangeType.__get_reg_class(self.length) - - @final - def __eq__(self, other): - if isinstance(other, GPRRangeType): - return self.length == other.length - return False - - @final - def __hash__(self): - return hash(self.length) - - -@plain_data(frozen=True, eq=False) -@final -class GPRType(GPRRangeType): - __slots__ = () - - def __init__(self, length=1): - if length != 1: - raise ValueError("length must be 1") - super().__init__(length=1) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class CYType(RegType): - __slots__ = () - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([XERBit.CY]) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GlobalMemType(RegType): - __slots__ = () - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([GlobalMem.GlobalMem]) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class StackSlot(RegLoc): - __slots__ = "start_slot", "length_in_slots", - - def __init__(self, start_slot, length_in_slots): - # type: (int, int) -> None - self.start_slot = start_slot - if length_in_slots < 1: - raise ValueError("invalid length_in_slots") - self.length_in_slots = length_in_slots - - @property - def stop_slot(self): - return self.start_slot + self.length_in_slots - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, StackSlot): - return (self.stop_slot > other.start_slot - and other.stop_slot > self.start_slot) - return False - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> StackSlot - if not isinstance(subreg_type, StackSlotType): - raise ValueError(f"subreg_type is not a " - f"StackSlotType: {subreg_type}") - if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot: - raise ValueError(f"sub-register offset is out of range: {offset}") - return StackSlot(self.start_slot + offset, subreg_type.length_in_slots) - - -STACK_SLOT_COUNT = 128 - - -@plain_data(frozen=True, eq=False) -@final -class StackSlotType(RegType): - __slots__ = "length_in_slots", - - def __init__(self, length_in_slots=1): - # type: (int) -> None - if length_in_slots < 1: - raise ValueError("invalid length_in_slots") - self.length_in_slots = length_in_slots - - @staticmethod - @lru_cache(maxsize=None) - def __get_reg_class(length_in_slots): - # type: (int) -> RegClass - regs = [] - for start in range(STACK_SLOT_COUNT - length_in_slots): - reg = StackSlot(start, length_in_slots) - regs.append(reg) - return RegClass(regs) - - @property - def reg_class(self): - # type: () -> RegClass - return StackSlotType.__get_reg_class(self.length_in_slots) - - @final - def __eq__(self, other): - if isinstance(other, StackSlotType): - return self.length_in_slots == other.length_in_slots - return False - - @final - def __hash__(self): - return hash(self.length_in_slots) - - -_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) - - -@plain_data(frozen=True, eq=False) -@final -class SSAVal(Generic[_RegT_co]): - __slots__ = "op", "arg_name", "ty", "arg_index" - - def __init__(self, op, arg_name, ty): - # type: (Op, str, _RegT_co) -> None - self.op = op - """the Op that writes this SSAVal""" - - self.arg_name = arg_name - """the name of the argument of self.op that writes this SSAVal""" - - self.ty = ty - - def __eq__(self, rhs): - if isinstance(rhs, SSAVal): - return (self.op is rhs.op - and self.arg_name == rhs.arg_name) - return False - - def __hash__(self): - return hash((id(self.op), self.arg_name)) - - -@final -@plain_data(unsafe_hash=True, frozen=True) -class EqualityConstraint: - __slots__ = "lhs", "rhs" - - def __init__(self, lhs, rhs): - # type: (list[SSAVal], list[SSAVal]) -> None - self.lhs = lhs - self.rhs = rhs - if len(lhs) == 0 or len(rhs) == 0: - raise ValueError("can't constrain an empty list to be equal") - - -@plain_data(unsafe_hash=True, frozen=True) -class Op(metaclass=ABCMeta): - __slots__ = () - - @abstractmethod - def inputs(self): - # type: () -> dict[str, SSAVal] - ... - - @abstractmethod - def outputs(self): - # type: () -> dict[str, SSAVal] - ... - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - if False: - yield ... - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - if False: - yield ... - - def __init__(self): - pass - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpLoadFromStackSlot(Op): - __slots__ = "dest", "src" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"src": self.src} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, src): - # type: (SSAVal[GPRRangeType]) -> None - self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) - self.src = src - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpStoreToStackSlot(Op): - __slots__ = "dest", "src" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"src": self.src} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, src): - # type: (SSAVal[StackSlotType]) -> None - self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) - self.src = src - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpCopy(Op, Generic[_RegT_co]): - __slots__ = "dest", "src" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"src": self.src} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, src): - # type: (SSAVal[_RegT_co]) -> None - self.dest = SSAVal(self, "dest", src.ty) - self.src = src - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpConcat(Op): - __slots__ = "dest", "sources" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {f"sources[{i}]": v for i, v in enumerate(self.sources)} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, sources): - # type: (Iterable[SSAVal[GPRRangeType]]) -> None - sources = tuple(sources) - self.dest = SSAVal(self, "dest", GPRRangeType( - sum(i.ty.length for i in sources))) - self.sources = sources - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([self.dest], [*self.sources]) - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpSplit(Op): - __slots__ = "results", "src" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"src": self.src} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {i.arg_name: i for i in self.results} - - def __init__(self, src, split_indexes): - # type: (SSAVal[GPRRangeType], Iterable[int]) -> None - ranges = [] # type: list[GPRRangeType] - last = 0 - for i in split_indexes: - if not (0 < i < src.ty.length): - raise ValueError(f"invalid split index: {i}, must be in " - f"0 < i < {src.ty.length}") - ranges.append(GPRRangeType(i - last)) - last = i - ranges.append(GPRRangeType(src.ty.length - last)) - self.src = src - self.results = tuple( - SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges)) - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([*self.results], [self.src]) - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpAddSubE(Op): - __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT, "CY_out": self.CY_out} - - def __init__(self, RA, RB, CY_in, is_sub): - # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None - if RA.ty != RB.ty: - raise TypeError(f"source types must match: " - f"{RA} doesn't match {RB}") - self.RT = SSAVal(self, "RT", RA.ty) - self.RA = RA - self.RB = RB - self.CY_in = CY_in - self.CY_out = SSAVal(self, "CY_out", CY_in.ty) - self.is_sub = is_sub - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.RA - yield self.RT, self.RB - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpBigIntMulDiv(Op): - __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "RB": self.RB, "RC": self.RC} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT, "RS": self.RS} - - def __init__(self, RA, RB, RC, is_div): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None - self.RT = SSAVal(self, "RT", RA.ty) - self.RA = RA - self.RB = RB - self.RC = RC - self.RS = SSAVal(self, "RS", RC.ty) - self.is_div = is_div - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([self.RC], [self.RS]) - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.RA - yield self.RT, self.RB - yield self.RT, self.RC - yield self.RT, self.RS - yield self.RS, self.RA - yield self.RS, self.RB - - -@final -@unique -class ShiftKind(Enum): - Sl = "sl" - Sr = "sr" - Sra = "sra" - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpBigIntShift(Op): - __slots__ = "RT", "inp", "sh", "kind" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"inp": self.inp, "sh": self.sh} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT} - - def __init__(self, inp, sh, kind): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None - self.RT = SSAVal(self, "RT", inp.ty) - self.inp = inp - self.sh = sh - self.kind = kind - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.inp - yield self.RT, self.sh - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpLI(Op): - __slots__ = "out", "value" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, value, length=1): - # type: (int, int) -> None - self.out = SSAVal(self, "out", GPRRangeType(length)) - self.value = value - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpClearCY(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self): - # type: () -> None - self.out = SSAVal(self, "out", CYType()) - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpLoad(Op): - __slots__ = "RT", "RA", "offset", "mem" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "mem": self.mem} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT} - - def __init__(self, RA, offset, mem, length=1): - # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None - self.RT = SSAVal(self, "RT", GPRRangeType(length)) - self.RA = RA - self.offset = offset - self.mem = mem - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - if self.RT.ty.length > 1: - yield self.RT, self.RA - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpStore(Op): - __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"mem_out": self.mem_out} - - def __init__(self, RS, RA, offset, mem_in): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None - self.RS = RS - self.RA = RA - self.offset = offset - self.mem_in = mem_in - self.mem_out = SSAVal(self, "mem_out", mem_in.ty) - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpFuncArg(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, ty): - # type: (RegType) -> None - self.out = SSAVal(self, "out", ty) - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpInputMem(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self): - # type: () -> None - self.out = SSAVal(self, "out", GlobalMemType()) - - -def op_set_to_list(ops): - # type: (Iterable[Op]) -> list[Op] - worklists = [set()] # type: list[set[Op]] - input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]] - ops_to_pending_input_count_map = {} # type: dict[Op, int] - for op in ops: - input_count = 0 - for val in op.inputs().values(): - input_count += 1 - input_vals_to_ops_map[val].add(op) - while len(worklists) <= input_count: - worklists.append(set()) - ops_to_pending_input_count_map[op] = input_count - worklists[input_count].add(op) - retval = [] # type: list[Op] - ready_vals = set() # type: set[SSAVal] - while len(worklists[0]) != 0: - writing_op = worklists[0].pop() - retval.append(writing_op) - for val in writing_op.outputs().values(): - if val in ready_vals: - raise ValueError(f"multiple instructions must not write " - f"to the same SSA value: {val}") - ready_vals.add(val) - for reading_op in input_vals_to_ops_map[val]: - pending = ops_to_pending_input_count_map[reading_op] - worklists[pending].remove(reading_op) - pending -= 1 - worklists[pending].add(reading_op) - ops_to_pending_input_count_map[reading_op] = pending - for worklist in worklists: - for op in worklist: - raise ValueError(f"instruction is part of a dependency loop or " - f"its inputs are never written: {op}") - return retval - - -@plain_data(unsafe_hash=True, order=True, frozen=True) -class LiveInterval: - __slots__ = "first_write", "last_use" - - def __init__(self, first_write, last_use=None): - # type: (int, int | None) -> None - if last_use is None: - last_use = first_write - if last_use < first_write: - raise ValueError("uses must be after first_write") - if first_write < 0 or last_use < 0: - raise ValueError("indexes must be nonnegative") - self.first_write = first_write - self.last_use = last_use - - def overlaps(self, other): - # type: (LiveInterval) -> bool - if self.first_write == other.first_write: - return True - return self.last_use > other.first_write \ - and other.last_use > self.first_write - - def __add__(self, use): - # type: (int) -> LiveInterval - last_use = max(self.last_use, use) - return LiveInterval(first_write=self.first_write, last_use=last_use) - - @property - def live_after_op_range(self): - """the range of op indexes where self is live immediately after the - Op at each index - """ - return range(self.first_write, self.last_use) - - -@final -class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): - def __init__(self, reg_set): - # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None - self.__items = {} # type: dict[SSAVal[_RegT_co], int] - if isinstance(reg_set, SSAVal): - reg_set = [(reg_set, 0)] - for ssa_val, offset in reg_set: - if ssa_val in self.__items: - other = self.__items[ssa_val] - if offset != other: - raise ValueError( - f"can't merge register sets: conflicting offsets: " - f"for {ssa_val}: {offset} != {other}") - else: - self.__items[ssa_val] = offset - first_item = None - for i in self.__items.items(): - first_item = i - break - if first_item is None: - raise ValueError("can't have empty MergedRegs") - first_ssa_val, start = first_item - ty = first_ssa_val.ty - if isinstance(ty, GPRRangeType): - stop = start + ty.length - for ssa_val, offset in self.__items.items(): - if not isinstance(ssa_val.ty, GPRRangeType): - raise ValueError(f"can't merge incompatible types: " - f"{ssa_val.ty} and {ty}") - stop = max(stop, offset + ssa_val.ty.length) - start = min(start, offset) - ty = GPRRangeType(stop - start) - else: - stop = 1 - for ssa_val, offset in self.__items.items(): - if offset != 0: - raise ValueError(f"can't have non-zero offset " - f"for {ssa_val.ty}") - if ty != ssa_val.ty: - raise ValueError(f"can't merge incompatible types: " - f"{ssa_val.ty} and {ty}") - self.__start = start # type: int - self.__stop = stop # type: int - self.__ty = ty # type: RegType - self.__hash = hash(frozenset(self.items())) - - @staticmethod - def from_equality_constraint(constraint_sequence): - # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co] - if len(constraint_sequence) == 1: - # any type allowed with len = 1 - return MergedRegSet(constraint_sequence[0]) - offset = 0 - retval = [] - for val in constraint_sequence: - if not isinstance(val.ty, GPRRangeType): - raise ValueError("equality constraint sequences must only " - "have SSAVal type GPRRangeType") - retval.append((val, offset)) - offset += val.ty.length - return MergedRegSet(retval) - - @property - def ty(self): - return self.__ty - - @property - def stop(self): - return self.__stop - - @property - def start(self): - return self.__start - - @property - def range(self): - return range(self.__start, self.__stop) - - def offset_by(self, amount): - # type: (int) -> MergedRegSet[_RegT_co] - return MergedRegSet((k, v + amount) for k, v in self.items()) - - def normalized(self): - # type: () -> MergedRegSet[_RegT_co] - return self.offset_by(-self.start) - - def with_offset_to_match(self, target): - # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co] - for ssa_val, offset in self.items(): - if ssa_val in target: - return self.offset_by(target[ssa_val] - offset) - raise ValueError("can't change offset to match unrelated MergedRegSet") - - def __getitem__(self, item): - # type: (SSAVal[_RegT_co]) -> int - return self.__items[item] - - def __iter__(self): - return iter(self.__items) - - def __len__(self): - return len(self.__items) - - def __hash__(self): - return self.__hash - - def __repr__(self): - return f"MergedRegSet({list(self.__items.items())})" - - -@final -class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]): - def __init__(self, ops): - # type: (Iterable[Op]) -> None - merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]] - for op in ops: - for val in (*op.inputs().values(), *op.outputs().values()): - if val not in merged_sets: - merged_sets[val] = MergedRegSet(val) - for e in op.get_equality_constraints(): - lhs_set = MergedRegSet.from_equality_constraint(e.lhs) - rhs_set = MergedRegSet.from_equality_constraint(e.rhs) - lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set) - rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set) - full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()]) - for val in full_set.keys(): - merged_sets[val] = full_set - - self.__map = {k: v.normalized() for k, v in merged_sets.items()} - - def __getitem__(self, key): - # type: (SSAVal) -> MergedRegSet - return self.__map[key] - - def __iter__(self): - return iter(self.__map) - - def __len__(self): - return len(self.__map) - - def __repr__(self): - return f"MergedRegSets(data={self.__map})" - - -@final -class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): - def __init__(self, ops): - # type: (list[Op]) -> None - self.__merged_reg_sets = MergedRegSets(ops) - live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval] - for op_idx, op in enumerate(ops): - for val in op.inputs().values(): - live_intervals[self.__merged_reg_sets[val]] += op_idx - for val in op.outputs().values(): - reg_set = self.__merged_reg_sets[val] - if reg_set not in live_intervals: - live_intervals[reg_set] = LiveInterval(op_idx) - else: - live_intervals[reg_set] += op_idx - self.__live_intervals = live_intervals - live_after = [] # type: list[set[MergedRegSet[_RegT_co]]] - live_after += (set() for _ in ops) - for reg_set, live_interval in self.__live_intervals.items(): - for i in live_interval.live_after_op_range: - live_after[i].add(reg_set) - self.__live_after = [frozenset(i) for i in live_after] - - @property - def merged_reg_sets(self): - return self.__merged_reg_sets - - def __getitem__(self, key): - # type: (MergedRegSet[_RegT_co]) -> LiveInterval - return self.__live_intervals[key] - - def __iter__(self): - return iter(self.__live_intervals) - - def reg_sets_live_after(self, op_index): - # type: (int) -> frozenset[MergedRegSet[_RegT_co]] - return self.__live_after[op_index] - - def __repr__(self): - reg_sets_live_after = dict(enumerate(self.__live_after)) - return (f"LiveIntervals(live_intervals={self.__live_intervals}, " - f"merged_reg_sets={self.merged_reg_sets}, " - f"reg_sets_live_after={reg_sets_live_after})") - - -@final -class IGNode(Generic[_RegT_co]): - """ interference graph node """ - __slots__ = "merged_reg_set", "edges", "reg" - - def __init__(self, merged_reg_set, edges=(), reg=None): - # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None - self.merged_reg_set = merged_reg_set - self.edges = set(edges) - self.reg = reg - - def add_edge(self, other): - # type: (IGNode) -> None - self.edges.add(other) - other.edges.add(self) - - def __eq__(self, other): - # type: (object) -> bool - if isinstance(other, IGNode): - return self.merged_reg_set == other.merged_reg_set - return NotImplemented - - def __hash__(self): - return hash(self.merged_reg_set) - - def __repr__(self, nodes=None): - # type: (None | dict[IGNode, int]) -> str - if nodes is None: - nodes = {} - if self in nodes: - return f"" - nodes[self] = len(nodes) - edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" - return (f"IGNode(#{nodes[self]}, " - f"merged_reg_set={self.merged_reg_set}, " - f"edges={edges}, " - f"reg={self.reg})") - - @property - def reg_class(self): - # type: () -> RegClass - return self.merged_reg_set.ty.reg_class - - def reg_conflicts_with_neighbors(self, reg): - # type: (RegLoc) -> bool - for neighbor in self.edges: - if neighbor.reg is not None and neighbor.reg.conflicts(reg): - return True - return False - - -@final -class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]): - def __init__(self, merged_reg_sets): - # type: (Iterable[MergedRegSet[_RegT_co]]) -> None - self.__nodes = {i: IGNode(i) for i in merged_reg_sets} - - def __getitem__(self, key): - # type: (MergedRegSet[_RegT_co]) -> IGNode - return self.__nodes[key] - - def __iter__(self): - return iter(self.__nodes) - - def __repr__(self): - nodes = {} - nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] - nodes_text = ", ".join(nodes_text) - return f"InterferenceGraph(nodes={{{nodes_text}}})" - - -@plain_data() -class AllocationFailed: - __slots__ = "node", "live_intervals", "interference_graph" - - def __init__(self, node, live_intervals, interference_graph): - # type: (IGNode, LiveIntervals, InterferenceGraph) -> None - self.node = node - self.live_intervals = live_intervals - self.interference_graph = interference_graph - - -def try_allocate_registers_without_spilling(ops): - # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed - - live_intervals = LiveIntervals(ops) - merged_reg_sets = live_intervals.merged_reg_sets - interference_graph = InterferenceGraph(merged_reg_sets.values()) - for op_idx, op in enumerate(ops): - reg_sets = live_intervals.reg_sets_live_after(op_idx) - for i, j in combinations(reg_sets, 2): - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: - interference_graph[i].add_edge(interference_graph[j]) - for i, j in op.get_extra_interferences(): - i = merged_reg_sets[i] - j = merged_reg_sets[j] - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: - interference_graph[i].add_edge(interference_graph[j]) - - nodes_remaining = set(interference_graph.values()) - - def local_colorability_score(node): - # type: (IGNode) -> int - """ returns a positive integer if node is locally colorable, returns - zero or a negative integer if node isn't known to be locally - colorable, the more negative the value, the less colorable - """ - if node not in nodes_remaining: - raise ValueError() - retval = len(node.reg_class) - for neighbor in node.edges: - if neighbor in nodes_remaining: - retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) - return retval - - node_stack = [] # type: list[IGNode] - while True: - best_node = None # type: None | IGNode - best_score = 0 - for node in nodes_remaining: - score = local_colorability_score(node) - if best_node is None or score > best_score: - best_node = node - best_score = score - if best_score > 0: - # it's locally colorable, no need to find a better one - break - - if best_node is None: - break - node_stack.append(best_node) - nodes_remaining.remove(best_node) - - retval = {} # type: dict[SSAVal, RegLoc] - - while len(node_stack) > 0: - node = node_stack.pop() - if node.reg is not None: - if node.reg_conflicts_with_neighbors(node.reg): - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) - else: - # pick the first non-conflicting register in node.reg_class, since - # register classes are ordered from most preferred to least - # preferred register. - for reg in node.reg_class: - if not node.reg_conflicts_with_neighbors(reg): - node.reg = reg - break - if node.reg is None: - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) - - for ssa_val, offset in node.merged_reg_set.items(): - retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) - - return retval - - -def allocate_registers(ops): - # type: (list[Op]) -> None - raise NotImplementedError +from bigint_presentation_code.compiler_ir import Op +from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed