From 8efbb81a5dcd6a38e4269550b121b5de4adf1440 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 7 Dec 2022 20:58:37 -0800 Subject: [PATCH] refactor preparing for copy merging --- .../register_allocator.py | 160 +++++++++++------- 1 file changed, 96 insertions(+), 64 deletions(-) diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index a1d1785..e4e3ab6 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -6,8 +6,8 @@ this uses an algorithm based on: """ from functools import reduce -from itertools import combinations -from typing import Callable, Iterable, Iterator, Mapping, TextIO +from itertools import chain, combinations +from typing import Callable, Iterable, Iterator, Mapping, TextIO, Tuple from cached_property import cached_property from nmutil.plain_data import plain_data @@ -23,6 +23,9 @@ class BadMergedSSAVal(ValueError): pass +_CopyRelation = Tuple[SSAValSubReg, SSAValSubReg] + + @plain_data(frozen=True, repr=False) @final class MergedSSAVal(metaclass=InternedMeta): @@ -263,17 +266,17 @@ class MergedSSAVal(metaclass=InternedMeta): sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val]) return OFSet(v for s in sets for v in s) - def is_copy_related(self, other): - # type: (MergedSSAVal) -> bool + def get_copy_relation(self, other): + # type: (MergedSSAVal) -> None | _CopyRelation for lhs_ssa_val in self.ssa_vals: for rhs_ssa_val in other.ssa_vals: for lhs in lhs_ssa_val.ssa_val_sub_regs: for rhs in rhs_ssa_val.ssa_val_sub_regs: - lhs = self.fn_analysis.copies.get(lhs, lhs) - rhs = self.fn_analysis.copies.get(rhs, rhs) - if lhs == rhs: - return True - return False + lhs_src = self.fn_analysis.copies.get(lhs, lhs) + rhs_src = self.fn_analysis.copies.get(rhs, rhs) + if lhs_src == rhs_src: + return lhs_src, rhs_src + return None @final @@ -344,7 +347,8 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): f"{self.__merged_ssa_val_map[ssa_val]}") self.__merged_ssa_val_map[ssa_val] = merged_ssa_val added += 1 - retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None) + retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None, + ignored=False) self.__map[merged_ssa_val] = retval added = None return retval @@ -364,6 +368,8 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): for ssa_val in final_merged_ssa_val.ssa_vals: merged_ssa_val = self.__merged_ssa_val_map[ssa_val] source_node = self.__map[merged_ssa_val] + if source_node.ignored: + raise ValueError(f"can't merge ignored nodes: {source_node}") source_nodes.add(source_node) for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals: raise ValueError( @@ -386,7 +392,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): for n in source_nodes: edges.pop(n, None) retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges, - loc=loc) + loc=loc, ignored=False) for node in edges: edge = reduce(IGEdge.merged, (node.edges.pop(n) for n in source_nodes)) @@ -419,14 +425,67 @@ class InterferenceGraph: for i in merged_ssa_vals: self.nodes.add_node(i) - def merge(self, ssa_val1, ssa_val2, additional_offset=0): - # type: (SSAVal, SSAVal, int) -> IGNode + def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0): + # type: (SSAVal, SSAVal, int) -> MergedSSAVal merged1 = self.merged_ssa_val_map[ssa_val1] merged2 = self.merged_ssa_val_map[ssa_val2] merged = merged1.with_offset_to_match(ssa_val1) - merged = merged.merged(merged2.with_offset_to_match( + return merged.merged(merged2.with_offset_to_match( ssa_val2, additional_offset=additional_offset)) - return self.nodes.merge_into_one_node(merged) + + def merge(self, ssa_val1, ssa_val2, additional_offset=0): + # type: (SSAVal, SSAVal, int) -> IGNode + return self.nodes.merge_into_one_node(self.merge_preview( + ssa_val1=ssa_val1, ssa_val2=ssa_val2, + additional_offset=additional_offset)) + + def copy_merge_preview(self, node1, node2): + # type: (IGNode, IGNode) -> MergedSSAVal + try: + copy_relation = node1.edges[node2].copy_relation + except KeyError: + raise ValueError("nodes aren't copy related") + if copy_relation is None: + raise ValueError("nodes aren't copy related") + lhs, rhs = copy_relation + if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals: + lhs, rhs = rhs, lhs + lhs_merged = node1.merged_ssa_val.with_offset_to_match( + lhs.ssa_val, additional_offset=-lhs.reg_idx) + rhs_merged = node1.merged_ssa_val.with_offset_to_match( + rhs.ssa_val, additional_offset=-rhs.reg_idx) + return lhs_merged.merged(rhs_merged) + + def copy_merge(self, node1, node2): + # type: (IGNode, IGNode) -> IGNode + return self.nodes.merge_into_one_node(self.copy_merge_preview( + node1=node1, node2=node2)) + + def local_colorability_score(self, node, merged_in_copy=None): + # type: (IGNode, None | IGNode) -> int + """ returns a positive integer if node is locally colorable, returns + zero or a negative integer if node isn't known to be locally + colorable, the more negative the value, the less colorable. + + if `merged_in_copy` is not `None`, then the node used is what would be + the result of `self.copy_merge(node, merged_in_copy)`. + """ + if node.ignored: + raise ValueError( + "can't get local_colorability_score of ignored node") + loc_set = node.loc_set + edges = node.edges.items() + if merged_in_copy is not None: + loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set + edges = chain(edges, merged_in_copy.edges.items()) + retval = len(loc_set) + for neighbor, edge in edges: + if neighbor.ignored or not edge.interferes: + continue + if neighbor == merged_in_copy or neighbor == node: + continue + retval -= loc_set.max_conflicts_with(neighbor.loc_set) + return retval @staticmethod def minimally_merged(fn_analysis): @@ -499,7 +558,7 @@ class InterferenceGraph: if edge.interferes: append_edge(node1, node2, label="interferes", color="darkred", style="bold") - if edge.is_copy_related: + if edge.copy_relation is not None: append_edge(node1, node2, label="copy related", color="blue", style="dashed") lines.append("}") @@ -520,32 +579,35 @@ class IGNodeReprState: @final class IGEdge: """ interference graph edge """ - __slots__ = "interferes", "is_copy_related" + __slots__ = "interferes", "copy_relation" - def __init__(self, interferes=False, is_copy_related=False): - # type: (bool, bool) -> None + def __init__(self, interferes=False, copy_relation=None): + # type: (bool, None | _CopyRelation) -> None self.interferes = interferes - self.is_copy_related = is_copy_related + self.copy_relation = copy_relation def merged(self, other): # type: (IGEdge | None) -> IGEdge if other is None: return self - is_copy_related = self.is_copy_related | other.is_copy_related + copy_relation = self.copy_relation + if copy_relation is None: + copy_relation = other.copy_relation interferes = self.interferes | other.interferes - return IGEdge(interferes=interferes, is_copy_related=is_copy_related) + return IGEdge(interferes=interferes, copy_relation=copy_relation) @final class IGNode: """ interference graph node """ - __slots__ = "merged_ssa_val", "edges", "loc" + __slots__ = "merged_ssa_val", "edges", "loc", "ignored" - def __init__(self, merged_ssa_val, edges, loc): - # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None + def __init__(self, merged_ssa_val, edges, loc, ignored): + # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None self.merged_ssa_val = merged_ssa_val self.edges = edges self.loc = loc + self.ignored = ignored def merge_edge(self, other, edge): # type: (IGNode, IGEdge) -> None @@ -586,7 +648,8 @@ class IGNode: return (f"IGNode(#{node_id}, " f"merged_ssa_val={self.merged_ssa_val}, " f"edges={{{edges}}}, " - f"loc={self.loc})") + f"loc={self.loc}, " + f"ignored={self.ignored})") @property def loc_set(self): @@ -662,12 +725,8 @@ def allocate_registers( file=debug_out, flush=True) for i, j in combinations(interference_graph.nodes.values(), 2): - # can't use: - # is_copy_related = (not i.merged_ssa_val.copy_related_ssa_vals - # .isdisjoint(j.merged_ssa_val.copy_related_ssa_vals)) - # since it is too coarse - is_copy_related = i.merged_ssa_val.is_copy_related(j.merged_ssa_val) - i.merge_edge(j, IGEdge(is_copy_related=is_copy_related)) + copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val) + i.merge_edge(j, IGEdge(copy_relation=copy_relation)) if debug_out is not None: print(f"After adding interference graph edges:\n" @@ -675,38 +734,16 @@ def allocate_registers( if dump_graph is not None: dump_graph("initial", interference_graph.dump_to_dot()) - nodes_remaining = OSet(interference_graph.nodes.values()) - - local_colorability_score_cache = {} # type: dict[IGNode, int] - - def local_colorability_score(node): - # type: (IGNode) -> int - """ returns a positive integer if node is locally colorable, returns - zero or a negative integer if node isn't known to be locally - colorable, the more negative the value, the less colorable - """ - if node not in nodes_remaining: - raise ValueError() - retval = local_colorability_score_cache.get(node, None) - if retval is not None: - return retval - retval = len(node.loc_set) - for neighbor, edge in node.edges.items(): - if not edge.interferes: - continue - if neighbor in nodes_remaining: - retval -= node.loc_set.max_conflicts_with(neighbor.loc_set) - local_colorability_score_cache[node] = retval - return retval - # TODO: implement copy-merging node_stack = [] # type: list[IGNode] while True: best_node = None # type: None | IGNode best_score = 0 - for node in nodes_remaining: - score = local_colorability_score(node) + for node in interference_graph.nodes.values(): + if node.ignored: + continue + score = interference_graph.local_colorability_score(node) if best_node is None or score > best_score: best_node = node best_score = score @@ -717,12 +754,7 @@ def allocate_registers( if best_node is None: break node_stack.append(best_node) - nodes_remaining.remove(best_node) - local_colorability_score_cache.pop(best_node, None) - for neighbor, edge in best_node.edges.items(): - if not edge.interferes: - continue - local_colorability_score_cache.pop(neighbor, None) + best_node.ignored = True if debug_out is not None: print(f"After deciding node allocation order:\n" -- 2.30.2