refactor preparing for copy merging
author Jacob Lifshay <programmerjake@gmail.com>
Thu, 8 Dec 2022 04:58:37 +0000 (20:58 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Thu, 8 Dec 2022 06:04:07 +0000 (22:04 -0800)
src/bigint_presentation_code/register_allocator.py

index a1d17859c13fae76884d888132431c81fec24fac..e4e3ab68257dc322c23b389433a788ec56007bbf 100644 (file)
@@ -6,8 +6,8 @@ this uses an algorithm based on:
 """
 
 from functools import reduce
-from itertools import combinations
-from typing import Callable, Iterable, Iterator, Mapping, TextIO
+from itertools import chain, combinations
+from typing import Callable, Iterable, Iterator, Mapping, TextIO, Tuple
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
@@ -23,6 +23,9 @@ class BadMergedSSAVal(ValueError):
     pass
 
 
+# a copy relation: a pair of `SSAValSubReg`s, as returned by
+# `MergedSSAVal.get_copy_relation` and stored in `IGEdge.copy_relation`
+_CopyRelation = Tuple[SSAValSubReg, SSAValSubReg]
+
+
 @plain_data(frozen=True, repr=False)
 @final
 class MergedSSAVal(metaclass=InternedMeta):
@@ -263,17 +266,17 @@ class MergedSSAVal(metaclass=InternedMeta):
             sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
         return OFSet(v for s in sets for v in s)
 
-    def is_copy_related(self, other):
-        # type: (MergedSSAVal) -> bool
+    def get_copy_relation(self, other):
+        # type: (MergedSSAVal) -> None | _CopyRelation
+        """ find a pair of sub-registers, one from `self` and one from
+        `other`, whose sources according to `self.fn_analysis.copies` are
+        equal (a sub-register with no entry in `copies` is treated as its
+        own source).
+
+        returns `(lhs, rhs)` where `lhs` is a sub-register of one of
+        `self.ssa_vals` and `rhs` is a sub-register of one of
+        `other.ssa_vals`, or `None` if there is no such pair.
+        """
         for lhs_ssa_val in self.ssa_vals:
             for rhs_ssa_val in other.ssa_vals:
                 for lhs in lhs_ssa_val.ssa_val_sub_regs:
                     for rhs in rhs_ssa_val.ssa_val_sub_regs:
-                        lhs = self.fn_analysis.copies.get(lhs, lhs)
-                        rhs = self.fn_analysis.copies.get(rhs, rhs)
-                        if lhs == rhs:
-                            return True
-        return False
+                        lhs_src = self.fn_analysis.copies.get(lhs, lhs)
+                        rhs_src = self.fn_analysis.copies.get(rhs, rhs)
+                        if lhs_src == rhs_src:
+                            # return the related sub-registers themselves:
+                            # `lhs_src` and `rhs_src` are equal here, so
+                            # returning them would lose which sub-register
+                            # of each side is involved.
+                            return lhs, rhs
+        return None
 
 
 @final
@@ -344,7 +347,8 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                         f"{self.__merged_ssa_val_map[ssa_val]}")
                 self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
+            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None,
+                            ignored=False)
             self.__map[merged_ssa_val] = retval
             added = None
             return retval
@@ -364,6 +368,8 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
             source_node = self.__map[merged_ssa_val]
+            if source_node.ignored:
+                raise ValueError(f"can't merge ignored nodes: {source_node}")
             source_nodes.add(source_node)
             for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                 raise ValueError(
@@ -386,7 +392,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         for n in source_nodes:
             edges.pop(n, None)
         retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
-                        loc=loc)
+                        loc=loc, ignored=False)
         for node in edges:
             edge = reduce(IGEdge.merged,
                           (node.edges.pop(n) for n in source_nodes))
@@ -419,14 +425,67 @@ class InterferenceGraph:
         for i in merged_ssa_vals:
             self.nodes.add_node(i)
 
-    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
-        # type: (SSAVal, SSAVal, int) -> IGNode
+    def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0):
+        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
+        """ build the `MergedSSAVal` that `self.merge` passes to
+        `self.nodes.merge_into_one_node` for the same arguments.
+
+        this only looks up and combines the two merged values; it doesn't
+        call `merge_into_one_node` itself.
+        """
         merged1 = self.merged_ssa_val_map[ssa_val1]
         merged2 = self.merged_ssa_val_map[ssa_val2]
         merged = merged1.with_offset_to_match(ssa_val1)
-        merged = merged.merged(merged2.with_offset_to_match(
+        return merged.merged(merged2.with_offset_to_match(
             ssa_val2, additional_offset=additional_offset))
-        return self.nodes.merge_into_one_node(merged)
+
+    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
+        # type: (SSAVal, SSAVal, int) -> IGNode
+        """ pass the result of `self.merge_preview` for the same arguments
+        to `self.nodes.merge_into_one_node` and return the resulting node.
+        """
+        preview = self.merge_preview(
+            ssa_val1, ssa_val2, additional_offset=additional_offset)
+        return self.nodes.merge_into_one_node(preview)
+
+    def copy_merge_preview(self, node1, node2):
+        # type: (IGNode, IGNode) -> MergedSSAVal
+        """ build the `MergedSSAVal` that `self.copy_merge(node1, node2)`
+        passes to `self.nodes.merge_into_one_node`, using the copy relation
+        stored on the edge between the two nodes.
+
+        raises `ValueError` if `node1` has no edge to `node2` or that edge's
+        `copy_relation` is `None`.
+        """
+        try:
+            copy_relation = node1.edges[node2].copy_relation
+        except KeyError:
+            raise ValueError("nodes aren't copy related")
+        if copy_relation is None:
+            raise ValueError("nodes aren't copy related")
+        lhs, rhs = copy_relation
+        if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals:
+            # the relation is stored the other way around; make `lhs` the
+            # side paired with `node1`.
+            lhs, rhs = rhs, lhs
+        lhs_merged = node1.merged_ssa_val.with_offset_to_match(
+            lhs.ssa_val, additional_offset=-lhs.reg_idx)
+        # `rhs` is the side paired with `node2`, so it must be matched
+        # against `node2`'s merged value; using `node1`'s here would merge
+        # `node1`'s value with itself.
+        rhs_merged = node2.merged_ssa_val.with_offset_to_match(
+            rhs.ssa_val, additional_offset=-rhs.reg_idx)
+        return lhs_merged.merged(rhs_merged)
+
+    def copy_merge(self, node1, node2):
+        # type: (IGNode, IGNode) -> IGNode
+        """ pass the result of `self.copy_merge_preview(node1, node2)` to
+        `self.nodes.merge_into_one_node` and return the resulting node.
+        """
+        preview = self.copy_merge_preview(node1, node2)
+        return self.nodes.merge_into_one_node(preview)
+
+    def local_colorability_score(self, node, merged_in_copy=None):
+        # type: (IGNode, None | IGNode) -> int
+        """ returns a positive integer if node is locally colorable, returns
+        zero or a negative integer if node isn't known to be locally
+        colorable, the more negative the value, the less colorable.
+
+        if `merged_in_copy` is not `None`, then the node used is what would be
+        the result of `self.copy_merge(node, merged_in_copy)`.
+
+        neighbors that are ignored, or whose edge doesn't interfere, don't
+        affect the score.
+
+        raises `ValueError` if `node` is ignored.
+        """
+        if node.ignored:
+            raise ValueError(
+                "can't get local_colorability_score of ignored node")
+        loc_set = node.loc_set
+        edges = node.edges
+        if merged_in_copy is not None:
+            loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set
+            # combine both nodes' edges per neighbor (on a copy, so the
+            # graph isn't modified), so a neighbor shared by `node` and
+            # `merged_in_copy` is counted once rather than twice.
+            edges = dict(edges)
+            for neighbor, edge in merged_in_copy.edges.items():
+                edges[neighbor] = edge.merged(edges.get(neighbor))
+        retval = len(loc_set)
+        for neighbor, edge in edges.items():
+            if neighbor.ignored or not edge.interferes:
+                continue
+            # the two nodes being (hypothetically) merged aren't neighbors
+            # of the result.
+            if neighbor == merged_in_copy or neighbor == node:
+                continue
+            retval -= loc_set.max_conflicts_with(neighbor.loc_set)
+        return retval
 
     @staticmethod
     def minimally_merged(fn_analysis):
@@ -499,7 +558,7 @@ class InterferenceGraph:
             if edge.interferes:
                 append_edge(node1, node2, label="interferes",
                             color="darkred", style="bold")
-            if edge.is_copy_related:
+            if edge.copy_relation is not None:
                 append_edge(node1, node2, label="copy related",
                             color="blue", style="dashed")
         lines.append("}")
@@ -520,32 +579,35 @@ class IGNodeReprState:
 @final
 class IGEdge:
     """ interference graph edge """
-    __slots__ = "interferes", "is_copy_related"
+    __slots__ = "interferes", "copy_relation"
 
-    def __init__(self, interferes=False, is_copy_related=False):
-        # type: (bool, bool) -> None
+    def __init__(self, interferes=False, copy_relation=None):
+        # type: (bool, None | _CopyRelation) -> None
+        # interferes: whether the edge is an interference edge.
+        # copy_relation: `None` means no copy relation is recorded for
+        # this edge.
         self.interferes = interferes
-        self.is_copy_related = is_copy_related
+        self.copy_relation = copy_relation
 
     def merged(self, other):
         # type: (IGEdge | None) -> IGEdge
         if other is None:
             return self
-        is_copy_related = self.is_copy_related | other.is_copy_related
+        # keep `self`'s copy relation if it has one, otherwise fall back
+        # to `other`'s (which may also be `None`).
+        copy_relation = self.copy_relation
+        if copy_relation is None:
+            copy_relation = other.copy_relation
         interferes = self.interferes | other.interferes
-        return IGEdge(interferes=interferes, is_copy_related=is_copy_related)
+        return IGEdge(interferes=interferes, copy_relation=copy_relation)
 
 
 @final
 class IGNode:
     """ interference graph node """
-    __slots__ = "merged_ssa_val", "edges", "loc"
+    __slots__ = "merged_ssa_val", "edges", "loc", "ignored"
 
-    def __init__(self, merged_ssa_val, edges, loc):
-        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
+    def __init__(self, merged_ssa_val, edges, loc, ignored):
+        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
         self.merged_ssa_val = merged_ssa_val
         self.edges = edges
         self.loc = loc
+        self.ignored = ignored
 
     def merge_edge(self, other, edge):
         # type: (IGNode, IGEdge) -> None
@@ -586,7 +648,8 @@ class IGNode:
         return (f"IGNode(#{node_id}, "
                 f"merged_ssa_val={self.merged_ssa_val}, "
                 f"edges={{{edges}}}, "
-                f"loc={self.loc})")
+                f"loc={self.loc}, "
+                f"ignored={self.ignored})")
 
     @property
     def loc_set(self):
@@ -662,12 +725,8 @@ def allocate_registers(
                   file=debug_out, flush=True)
 
     for i, j in combinations(interference_graph.nodes.values(), 2):
-        # can't use:
-        # is_copy_related = (not i.merged_ssa_val.copy_related_ssa_vals
-        #     .isdisjoint(j.merged_ssa_val.copy_related_ssa_vals))
-        # since it is too coarse
-        is_copy_related = i.merged_ssa_val.is_copy_related(j.merged_ssa_val)
-        i.merge_edge(j, IGEdge(is_copy_related=is_copy_related))
+        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
+        i.merge_edge(j, IGEdge(copy_relation=copy_relation))
 
     if debug_out is not None:
         print(f"After adding interference graph edges:\n"
@@ -675,38 +734,16 @@ def allocate_registers(
     if dump_graph is not None:
         dump_graph("initial", interference_graph.dump_to_dot())
 
-    nodes_remaining = OSet(interference_graph.nodes.values())
-
-    local_colorability_score_cache = {}  # type: dict[IGNode, int]
-
-    def local_colorability_score(node):
-        # type: (IGNode) -> int
-        """ returns a positive integer if node is locally colorable, returns
-        zero or a negative integer if node isn't known to be locally
-        colorable, the more negative the value, the less colorable
-        """
-        if node not in nodes_remaining:
-            raise ValueError()
-        retval = local_colorability_score_cache.get(node, None)
-        if retval is not None:
-            return retval
-        retval = len(node.loc_set)
-        for neighbor, edge in node.edges.items():
-            if not edge.interferes:
-                continue
-            if neighbor in nodes_remaining:
-                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
-        local_colorability_score_cache[node] = retval
-        return retval
-
     # TODO: implement copy-merging
 
     node_stack = []  # type: list[IGNode]
     while True:
         best_node = None  # type: None | IGNode
         best_score = 0
-        for node in nodes_remaining:
-            score = local_colorability_score(node)
+        for node in interference_graph.nodes.values():
+            if node.ignored:
+                continue
+            score = interference_graph.local_colorability_score(node)
             if best_node is None or score > best_score:
                 best_node = node
                 best_score = score
@@ -717,12 +754,7 @@ def allocate_registers(
         if best_node is None:
             break
         node_stack.append(best_node)
-        nodes_remaining.remove(best_node)
-        local_colorability_score_cache.pop(best_node, None)
-        for neighbor, edge in best_node.edges.items():
-            if not edge.interferes:
-                continue
-            local_colorability_score_cache.pop(neighbor, None)
+        best_node.ignored = True
 
     if debug_out is not None:
         print(f"After deciding node allocation order:\n"