WIP: copy merging -- currently broken
author Jacob Lifshay <programmerjake@gmail.com>
Thu, 8 Dec 2022 08:45:05 +0000 (00:45 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Thu, 8 Dec 2022 08:45:05 +0000 (00:45 -0800)
currently not merging everything it should

src/bigint_presentation_code/_tests/test_register_allocator.py
src/bigint_presentation_code/register_allocator.py

index 5e9d25afe1fbf5e61eb6aafb190ea3d0563b6fb5..d0d9068404d9512c1ec595167e9e2f2b0122c4fe 100644 (file)
@@ -48,7 +48,7 @@ class TestRegisterAllocator(unittest.TestCase):
 
     def test_register_allocate(self):
         fn, _arg = self.make_add_fn()
-        reg_assignments = allocate_registers(fn)
+        reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
 
         self.assertEqual(
             repr(reg_assignments),
@@ -112,22 +112,37 @@ class TestRegisterAllocator(unittest.TestCase):
 
         self.assertEqual(
             repr(reg_assignments),
-            "{<add.outputs[0]: <I64*32>>: "
+            "{"
+            "<add.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.inp1.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<add.inp0.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
-            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<st.inp1.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<st.inp0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<add.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<ca.outputs[0]: <CA>>: "
@@ -140,30 +155,17 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.inp0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<vl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<arg.out0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
-            "<arg.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
+            "}"
         )
         state = GenAsmState(reg_assignments)
         fn.gen_asm(state)
@@ -174,15 +176,13 @@ class TestRegisterAllocator(unittest.TestCase):
             'setvl 0, 0, 32, 0, 1, 1',
             'sv.ld *14, 0(3)',
             'setvl 0, 0, 32, 0, 1, 1',
-            'sv.or *46, *14, *14',
             'setvl 0, 0, 32, 0, 1, 1',
-            'sv.addi *14, 0, 0',
+            'sv.addi *46, 0, 0',
             'setvl 0, 0, 32, 0, 1, 1',
             'subfc 0, 0, 0',
             'setvl 0, 0, 32, 0, 1, 1',
-            'sv.or *78, *46, *46',
+            'sv.or *78, *14, *14',
             'setvl 0, 0, 32, 0, 1, 1',
-            'sv.or *46, *14, *14',
             'setvl 0, 0, 32, 0, 1, 1',
             'sv.adde *14, *78, *46',
             'setvl 0, 0, 32, 0, 1, 1',
@@ -200,26 +200,42 @@ class TestRegisterAllocator(unittest.TestCase):
             # type: (str, str) -> None
             self.assertNotIn(name, graphs, "duplicate graph name")
             graphs[name] = dot
-        allocated = allocate_registers(fn, dump_graph=dump_graph)
+        allocated = allocate_registers(fn, dump_graph=dump_graph,
+                                       debug_out=sys.stdout)
+        dump_graphs(self, graphs)
         self.assertEqual(
             repr(allocated),
             "{"
             "<add.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.inp1.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<add.inp0.copy.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
-            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<st.inp1.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<st.inp0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<add.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<ca.outputs[0]: <CA>>: "
@@ -232,33 +248,18 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<li.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.out0.copy.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
             "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.outputs[0]: <I64*32>>: "
-            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<ld.inp0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<vl.outputs[0]: <VL_MAXVL>>: "
-            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
-            "<arg.out0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
-            "<arg.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)"
             "}"
         )
-        dump_graphs(self, graphs)
         # FIXME: is_copy_related is not correct, it's missing a bunch of
         # edges (which aren't interference edges)
         self.assertEqual(graphs, {
index e4e3ab68257dc322c23b389433a788ec56007bbf..540af0a3e31316203cfefff1851ca2be3e1d816d 100644 (file)
@@ -6,8 +6,8 @@ this uses an algorithm based on:
 """
 
 from functools import reduce
-from itertools import chain, combinations
-from typing import Callable, Iterable, Iterator, Mapping, TextIO, Tuple
+from itertools import chain, combinations, count
+from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
@@ -275,7 +275,7 @@ class MergedSSAVal(metaclass=InternedMeta):
                         lhs_src = self.fn_analysis.copies.get(lhs, lhs)
                         rhs_src = self.fn_analysis.copies.get(rhs, rhs)
                         if lhs_src == rhs_src:
-                            return lhs_src, rhs_src
+                            return lhs, rhs
         return None
 
 
@@ -319,6 +319,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # type: (...) -> None
         self.__merged_ssa_val_map = _private_merged_ssa_val_map
         self.__map = {}  # type: dict[MergedSSAVal, IGNode]
+        self.__next_node_id = 0
 
     def __getitem__(self, __key):
         # type: (MergedSSAVal) -> IGNode
@@ -347,9 +348,11 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                         f"{self.__merged_ssa_val_map[ssa_val]}")
                 self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None,
-                            ignored=False)
+            retval = IGNode(
+                node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
+                edges={}, loc=None, ignored=False)
             self.__map[merged_ssa_val] = retval
+            self.__next_node_id += 1
             added = None
             return retval
         finally:
@@ -391,12 +394,18 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # we're finished checking validity, now we can modify stuff
         for n in source_nodes:
             edges.pop(n, None)
-        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
-                        loc=loc, ignored=False)
+        retval = IGNode(
+            node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
+            edges=edges, loc=loc, ignored=False)
+        self.__next_node_id += 1
+        empty_e = IGEdge()
         for node in edges:
             edge = reduce(IGEdge.merged,
-                          (node.edges.pop(n) for n in source_nodes))
-            node.edges[retval] = edge
+                          (node.edges.pop(n, empty_e) for n in source_nodes))
+            if edge == empty_e:
+                node.edges.pop(retval, None)
+            else:
+                node.edges[retval] = edge
         for node in source_nodes:
             del self.__map[node.merged_ssa_val]
         self.__map[final_merged_ssa_val] = retval
@@ -452,7 +461,7 @@ class InterferenceGraph:
             lhs, rhs = rhs, lhs
         lhs_merged = node1.merged_ssa_val.with_offset_to_match(
             lhs.ssa_val, additional_offset=-lhs.reg_idx)
-        rhs_merged = node1.merged_ssa_val.with_offset_to_match(
+        rhs_merged = node2.merged_ssa_val.with_offset_to_match(
             rhs.ssa_val, additional_offset=-rhs.reg_idx)
         return lhs_merged.merged(rhs_merged)
 
@@ -474,12 +483,14 @@ class InterferenceGraph:
             raise ValueError(
                 "can't get local_colorability_score of ignored node")
         loc_set = node.loc_set
-        edges = node.edges.items()
+        edges = node.edges
         if merged_in_copy is not None:
             loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set
-            edges = chain(edges, merged_in_copy.edges.items())
+            edges = edges.copy()
+            for neighbor, edge in merged_in_copy.edges.items():
+                edges[neighbor] = edge.merged(edges.get(neighbor))
         retval = len(loc_set)
-        for neighbor, edge in edges:
+        for neighbor, edge in edges.items():
             if neighbor.ignored or not edge.interferes:
                 continue
             if neighbor == merged_in_copy or neighbor == node:
@@ -512,8 +523,12 @@ class InterferenceGraph:
         s = self.nodes.__repr__(repr_state)
         return f"InterferenceGraph(nodes={s}, <...>)"
 
-    def dump_to_dot(self):
-        # type: () -> str
+    def dump_to_dot(
+            self, highlighted_nodes=(),  # type: Container[IGNode]
+            node_scores=None,  # type: None | dict[IGNode, int]
+            edge_scores=None,  # type: None | dict[tuple[IGNode, IGNode], int]
+    ):
+        # type: (...) -> str
 
         def quote(s):
             # type: (object) -> str
@@ -523,10 +538,15 @@ class InterferenceGraph:
             s = s.replace('\n', r'\n')
             return f'"{s}"'
 
+        if node_scores is None:
+            node_scores = {}
+        if edge_scores is None:
+            edge_scores = {}
+
         edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
         node_ids = {}  # type: dict[IGNode, str]
         for node in self.nodes.values():
-            node_ids[node] = quote(len(node_ids))
+            node_ids[node] = quote(node.node_id)
             for neighbor, edge in node.edges.items():
                 edge_key = (node, neighbor)
                 # ensure we only insert each edge once by checking for
@@ -539,10 +559,23 @@ class InterferenceGraph:
         ]
         for node, node_id in node_ids.items():
             label_lines = []  # type: list[str]
+            score = node_scores.get(node)
+            if score is not None:
+                label_lines.append(f"score={score}")
             for k, v in node.merged_ssa_val.ssa_val_offsets.items():
                 label_lines.append(f"{k}: {v}")
             label = quote("\n".join(label_lines))
-            lines.append(f"    {node_id} [label = {label}]")
+            style = "dotted" if node.ignored else "solid"
+            color = "black"
+            if node in highlighted_nodes:
+                style = "bold"
+                color = "green"
+            style = quote(style)
+            color = quote(color)
+            lines.append(f"    {node_id} ["
+                         f"label = {label}, "
+                         f"style = {style}, "
+                         f"color = {color}]")
 
         def append_edge(node1, node2, label, color, style):
             # type: (IGNode, IGNode, str, str, str) -> None
@@ -555,11 +588,17 @@ class InterferenceGraph:
                          f"style = {style}, "
                          f"decorate = true]")
         for (node1, node2), edge in edges.items():
+            score = edge_scores.get((node1, node2))
+            if score is None:
+                score = edge_scores.get((node2, node1))
+            label_prefix = ""
+            if score is not None:
+                label_prefix = f"score={score}\n"
             if edge.interferes:
-                append_edge(node1, node2, label="interferes",
+                append_edge(node1, node2, label=label_prefix + "interferes",
                             color="darkred", style="bold")
             if edge.copy_relation is not None:
-                append_edge(node1, node2, label="copy related",
+                append_edge(node1, node2, label=label_prefix + "copy related",
                             color="blue", style="dashed")
         lines.append("}")
         return "\n".join(lines)
@@ -567,11 +606,10 @@ class InterferenceGraph:
 
 @plain_data(repr=False)
 class IGNodeReprState:
-    __slots__ = "node_ids", "did_full_repr"
+    __slots__ = "did_full_repr",
 
     def __init__(self):
         super().__init__()
-        self.node_ids = {}  # type: dict[IGNode, int]
         self.did_full_repr = OSet()  # type: OSet[IGNode]
 
 
@@ -600,10 +638,11 @@ class IGEdge:
 @final
 class IGNode:
     """ interference graph node """
-    __slots__ = "merged_ssa_val", "edges", "loc", "ignored"
+    __slots__ = "node_id", "merged_ssa_val", "edges", "loc", "ignored"
 
-    def __init__(self, merged_ssa_val, edges, loc, ignored):
-        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
+    def __init__(self, node_id, merged_ssa_val, edges, loc, ignored):
+        # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
+        self.node_id = node_id
         self.merged_ssa_val = merged_ssa_val
         self.edges = edges
         self.loc = loc
@@ -611,6 +650,8 @@ class IGNode:
 
     def merge_edge(self, other, edge):
         # type: (IGNode, IGEdge) -> None
+        if self == other:
+            raise ValueError("can't have self-loops")
         old_edge = self.edges.get(other, None)
         assert old_edge is other.edges.get(self, None), "inconsistent edges"
         edge = edge.merged(old_edge)
@@ -637,15 +678,12 @@ class IGNode:
         del repr_state
         if rs is None:
             rs = IGNodeReprState()
-        node_id = rs.node_ids.get(self, None)
-        if node_id is None:
-            rs.node_ids[self] = node_id = len(rs.node_ids)
         if short or self in rs.did_full_repr:
-            return f"<IGNode #{node_id}>"
+            return f"<IGNode #{self.node_id}>"
         rs.did_full_repr.add(self)
         edges = ", ".join(
             f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
-        return (f"IGNode(#{node_id}, "
+        return (f"IGNode(#{self.node_id}, "
                 f"merged_ssa_val={self.merged_ssa_val}, "
                 f"edges={{{edges}}}, "
                 f"loc={self.loc}, "
@@ -734,28 +772,120 @@ def allocate_registers(
     if dump_graph is not None:
         dump_graph("initial", interference_graph.dump_to_dot())
 
-    # TODO: implement copy-merging
-
     node_stack = []  # type: list[IGNode]
-    while True:
+
+    debug_node_scores = {}  # type: dict[IGNode, int]
+    debug_edge_scores = {}  # type: dict[tuple[IGNode, IGNode], int]
+
+    def find_best_node(has_copy_relation):
+        # type: (bool) -> None | IGNode
         best_node = None  # type: None | IGNode
         best_score = 0
         for node in interference_graph.nodes.values():
             if node.ignored:
                 continue
+            node_has_copy_relation = False
+            for neighbor, edge in node.edges.items():
+                if neighbor.ignored:
+                    continue
+                if edge.copy_relation is not None:
+                    node_has_copy_relation = True
+                    break
+            if node_has_copy_relation != has_copy_relation:
+                continue
             score = interference_graph.local_colorability_score(node)
+            debug_node_scores[node] = score
             if best_node is None or score > best_score:
                 best_node = node
                 best_score = score
             if best_score > 0:
                 # it's locally colorable, no need to find a better one
                 break
+        if debug_out is not None:
+            print(f"find_best_node(has_copy_relation={has_copy_relation}):\n"
+                  f"{best_node}", file=debug_out, flush=True)
+        return best_node
+    # copy-merging algorithm based on Iterated Register Coalescing, section 5:
+    # https://dl.acm.org/doi/pdf/10.1145/229542.229546
+    # Build step is above.
+    for step in count():
+        debug_node_scores.clear()
+        debug_edge_scores.clear()
+        # Simplify:
+        best_node = find_best_node(has_copy_relation=False)
+        if best_node is not None:
+            if dump_graph is not None:
+                dump_graph(
+                    f"step_{step}_simplify", interference_graph.dump_to_dot(
+                        highlighted_nodes=[best_node],
+                        node_scores=debug_node_scores,
+                        edge_scores=debug_edge_scores))
+            node_stack.append(best_node)
+            best_node.ignored = True
+            continue
+        # Coalesce (aka. do copy-merges):
+        did_any_copy_merges = False
+        for node in interference_graph.nodes.values():
+            if node.ignored:
+                continue
+            for neighbor, edge in node.edges.items():
+                if neighbor.ignored:
+                    continue
+                if edge.copy_relation is None:
+                    continue
+                try:
+                    score = interference_graph.local_colorability_score(
+                        node, merged_in_copy=neighbor)
+                except BadMergedSSAVal:
+                    continue
+                if (neighbor, node) in debug_edge_scores:
+                    debug_edge_scores[(neighbor, node)] = score
+                else:
+                    debug_edge_scores[(node, neighbor)] = score
+                if score > 0:  # merged node is locally colorable
+                    if dump_graph is not None:
+                        dump_graph(
+                            f"step_{step}_copy_merge",
+                            interference_graph.dump_to_dot(
+                                highlighted_nodes=[node, neighbor],
+                                node_scores=debug_node_scores,
+                                edge_scores=debug_edge_scores))
+                    if debug_out is not None:
+                        print(f"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
+                              file=debug_out, flush=True)
+                    merged_node = interference_graph.copy_merge(node, neighbor)
+                    if dump_graph is not None:
+                        dump_graph(
+                            f"step_{step}_copy_merge_result",
+                            interference_graph.dump_to_dot(
+                                highlighted_nodes=[merged_node]))
+                    if debug_out is not None:
+                        print(f"merged_node:\n"
+                              f"{merged_node}", file=debug_out, flush=True)
+                    did_any_copy_merges = True
+                    break
+            if did_any_copy_merges:
+                break
+        if did_any_copy_merges:
+            continue
+        # Freeze:
+        best_node = find_best_node(has_copy_relation=True)
+        if best_node is not None:
+            if dump_graph is not None:
+                dump_graph(f"step_{step}_freeze",
+                           interference_graph.dump_to_dot(
+                               highlighted_nodes=[best_node],
+                               node_scores=debug_node_scores,
+                               edge_scores=debug_edge_scores))
+            # no need to clear copy relations since best_node won't be
+            # considered since it's now ignored.
+            node_stack.append(best_node)
+            best_node.ignored = True
+            continue
+        break
 
-        if best_node is None:
-            break
-        node_stack.append(best_node)
-        best_node.ignored = True
-
+    if dump_graph is not None:
+        dump_graph("final", interference_graph.dump_to_dot())
     if debug_out is not None:
         print(f"After deciding node allocation order:\n"
               f"{node_stack}", file=debug_out, flush=True)