From 8a5d41477305c7c1b781cb9c8e14cd99ab7d542b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 12 Oct 2022 20:34:41 -0700 Subject: [PATCH] construct interference graph --- src/bigint_presentation_code/toom_cook.py | 117 +++++++++++++++++----- 1 file changed, 94 insertions(+), 23 deletions(-) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index e8e97bf..ce26d68 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -9,6 +9,7 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, unique, EnumMeta from functools import lru_cache +from itertools import combinations from typing import (Sequence, AbstractSet, Iterable, Mapping, TYPE_CHECKING, Sequence, Sized, TypeVar, Generic) @@ -259,16 +260,16 @@ class StackSlot(GPRRangeOrStackLoc): return self.length -_RegType = TypeVar("_RegType", bound=RegType) +_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) @plain_data(frozen=True, eq=False) @final -class SSAVal(Generic[_RegType]): +class SSAVal(Generic[_RegT_co]): __slots__ = "op", "arg_name", "ty", "arg_index" def __init__(self, op, arg_name, ty): - # type: (Op, str, _RegType) -> None + # type: (Op, str, _RegT_co) -> None self.op = op """the Op that writes this SSAVal""" @@ -319,13 +320,18 @@ class Op(metaclass=ABCMeta): if False: yield ... + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + if False: + yield ... + def __init__(self): pass @plain_data(unsafe_hash=True, frozen=True) @final -class OpCopy(Op, Generic[_RegType]): +class OpCopy(Op, Generic[_RegT_co]): __slots__ = "dest", "src" def inputs(self): @@ -337,7 +343,7 @@ class OpCopy(Op, Generic[_RegType]): return {"dest": self.dest} def __init__(self, src): - # type: (SSAVal[_RegType]) -> None + # type: (SSAVal[_RegT_co]) -> None self.dest = SSAVal(self, "dest", src.ty) self.src = src @@ -425,6 +431,11 @@ class OpAddSubE(Op): self.CY_out = SSAVal(self, "CY_out", CY_in.ty) self.is_sub = is_sub + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.RA + yield self.RT, self.RB + @plain_data(unsafe_hash=True, frozen=True) @final @@ -452,6 +463,15 @@ class OpBigIntMulDiv(Op): # type: () -> Iterable[EqualityConstraint] yield EqualityConstraint([self.RC], [self.RS]) + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.RA + yield self.RT, self.RB + yield self.RT, self.RC + yield self.RT, self.RS + yield self.RS, self.RA + yield self.RS, self.RB + @final @unique @@ -481,6 +501,11 @@ class OpBigIntShift(Op): self.sh = sh self.kind = kind + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + yield self.RT, self.inp + yield self.RT, self.sh + @plain_data(unsafe_hash=True, frozen=True) @final @@ -539,6 +564,11 @@ class OpLoad(Op): self.offset = offset self.mem = mem + def get_extra_interferences(self): + # type: () -> Iterable[tuple[SSAVal, SSAVal]] + if self.RT.ty.length > 1: + yield self.RT, self.RA + @plain_data(unsafe_hash=True, frozen=True) @final @@ -662,12 +692,19 @@ class LiveInterval: last_use = max(self.last_use, use) return LiveInterval(first_write=self.first_write, last_use=last_use) + @property + def live_after_op_range(self): + """the range of op indexes where self is live immediately after the + Op at each index + """ + return range(self.first_write, self.last_use) + @final -class MergedRegSet(Mapping[SSAVal[_RegType], int]): +class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): def __init__(self, reg_set): - # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None - self.__items = {} # type: dict[SSAVal[_RegType], int] + # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None + self.__items = {} # type: dict[SSAVal[_RegT_co], int] if isinstance(reg_set, SSAVal): reg_set = [(reg_set, 0)] for ssa_val, offset in reg_set: @@ -711,7 +748,7 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]): @staticmethod def from_equality_constraint(constraint_sequence): - # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType] + # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co] if len(constraint_sequence) == 1: # any type allowed with len = 1 return MergedRegSet(constraint_sequence[0]) @@ -742,22 +779,22 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]): return range(self.__start, self.__stop) def offset_by(self, amount): - # type: (int) -> MergedRegSet[_RegType] + # type: (int) -> MergedRegSet[_RegT_co] return MergedRegSet((k, v + amount) for k, v in self.items()) def normalized(self): - # type: () -> MergedRegSet[_RegType] + # type: () -> MergedRegSet[_RegT_co] return self.offset_by(-self.start) def with_offset_to_match(self, target): - # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType] + # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co] for ssa_val, offset in self.items(): if ssa_val in target: return self.offset_by(target[ssa_val] - offset) raise ValueError("can't change offset to match unrelated MergedRegSet") def __getitem__(self, item): - # type: (SSAVal[_RegType]) -> int + # type: (SSAVal[_RegT_co]) -> int return self.__items[item] def __iter__(self): @@ -774,10 +811,10 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]): @final -class MergedRegSets(Mapping[SSAVal, MergedRegSet]): +class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]): def __init__(self, ops): # type: (Iterable[Op]) -> None - merged_sets = {} # type: dict[SSAVal, MergedRegSet] + merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]] for op in ops: for val in (*op.inputs().values(), *op.outputs().values()): if val not in merged_sets: @@ -803,13 +840,16 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet]): def __len__(self): return len(self.__map) + def __repr__(self): + return f"MergedRegSets(data={self.__map})" + @final -class LiveIntervals(Mapping[MergedRegSet, LiveInterval]): +class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): def __init__(self, ops): # type: (list[Op]) -> None self.__merged_reg_sets = MergedRegSets(ops) - live_intervals = {} # type: dict[MergedRegSet, LiveInterval] + live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval] for op_idx, op in enumerate(ops): for val in op.inputs().values(): live_intervals[self.__merged_reg_sets[val]] += op_idx @@ -820,26 +860,42 @@ class LiveIntervals(Mapping[MergedRegSet, LiveInterval]): else: live_intervals[reg_set] += op_idx self.__live_intervals = live_intervals + live_after = [] # type: list[set[MergedRegSet[_RegT_co]]] + live_after += (set() for _ in ops) + for reg_set, live_interval in self.__live_intervals.items(): + for i in live_interval.live_after_op_range: + live_after[i].add(reg_set) + self.__live_after = [frozenset(i) for i in live_after] @property def merged_reg_sets(self): return self.__merged_reg_sets def __getitem__(self, key): - # type: (MergedRegSet) -> LiveInterval + # type: (MergedRegSet[_RegT_co]) -> LiveInterval return self.__live_intervals[key] def __iter__(self): return iter(self.__live_intervals) + def reg_sets_live_after(self, op_index): + # type: (int) -> frozenset[MergedRegSet[_RegT_co]] + return self.__live_after[op_index] + + def __repr__(self): + reg_sets_live_after = dict(enumerate(self.__live_after)) + return (f"LiveIntervals(live_intervals={self.__live_intervals}, " + f"merged_reg_sets={self.merged_reg_sets}, " + f"reg_sets_live_after={reg_sets_live_after})") + @final -class IGNode: +class IGNode(Generic[_RegT_co]): """ interference graph node """ __slots__ = "merged_reg_set", "edges" def __init__(self, merged_reg_set, edges=()): - # type: (MergedRegSet, Iterable[IGNode]) -> None + # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None self.merged_reg_set = merged_reg_set self.edges = set(edges) @@ -871,18 +927,24 @@ class IGNode: @final -class InterferenceGraph(Mapping[MergedRegSet, IGNode]): +class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]): def __init__(self, merged_reg_sets): - # type: (Iterable[MergedRegSet]) -> None + # type: (Iterable[MergedRegSet[_RegT_co]]) -> None self.__nodes = {i: IGNode(i) for i in merged_reg_sets} def __getitem__(self, key): - # type: (MergedRegSet) -> IGNode + # type: (MergedRegSet[_RegT_co]) -> IGNode return self.__nodes[key] def __iter__(self): return iter(self.__nodes) + def __repr__(self): + nodes = {} + nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] + nodes_text = ", ".join(nodes_text) + return f"InterferenceGraph(nodes={{{nodes_text}}})" + @plain_data() class AllocationFailed: @@ -899,6 +961,15 @@ def try_allocate_registers_without_spilling(ops): # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed live_intervals = LiveIntervals(ops) + merged_reg_sets = live_intervals.merged_reg_sets + interference_graph = InterferenceGraph(merged_reg_sets.values()) + for op_idx, op in enumerate(ops): + reg_sets = live_intervals.reg_sets_live_after(op_idx) + for i, j in combinations(reg_sets, 2): + interference_graph[i].add_edge(interference_graph[j]) + for i, j in op.get_extra_interferences(): + interference_graph[merged_reg_sets[i]].add_edge( + interference_graph[merged_reg_sets[j]]) raise NotImplementedError -- 2.30.2