change InterferenceGraph edges to not imply interference
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 6 Dec 2022 08:43:30 +0000 (00:43 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 6 Dec 2022 08:43:30 +0000 (00:43 -0800)
nodes can be only copy-related but not interfering

src/bigint_presentation_code/_tests/test_register_allocator.py
src/bigint_presentation_code/register_allocator.py

index 9cf6270e581e2cfc2595cce0a96513f6c633de3f..5e9d25afe1fbf5e61eb6aafb190ea3d0563b6fb5 100644 (file)
@@ -262,54 +262,67 @@ class TestRegisterAllocator(unittest.TestCase):
         # FIXME: is_copy_related is not correct, it's missing a bunch of
         # edges (which aren't interference edges)
         self.assertEqual(graphs, {
-            'initial':
-            'graph {\n'
-            '    graph [pack = true]\n'
-            '    "0" [label = "<arg.outputs[0]: <I64>>: 0"]\n'
-            '    "1" [label = "<arg.out0.copy.outputs[0]: <I64>>: 0"]\n'
-            '    "2" [label = "<vl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "3" [label = "<ld.inp0.copy.outputs[0]: <I64>>: 0"]\n'
-            '    "4" [label = "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "5" [label = "<ld.outputs[0]: <I64*32>>: 0"]\n'
-            '    "6" [label = "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "7" [label = "<ld.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "8" [label = "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "9" [label = "<li.outputs[0]: <I64*32>>: 0"]\n'
-            '    "10" [label = "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "11" [label = "<li.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "12" [label = "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "13" [label = "<add.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "14" [label = "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "15" [label = "<add.inp1.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "16" [label = "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "17" [label = "<add.outputs[0]: <I64*32>>: 0"]\n'
-            '    "18" [label = "<ca.outputs[0]: <CA>>: 0\\n'
-            '<add.outputs[1]: <CA>>: 0"]\n'
-            '    "19" [label = "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "20" [label = "<add.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "21" [label = "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "22" [label = "<st.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
-            '    "23" [label = "<st.inp1.copy.outputs[0]: <I64>>: 0"]\n'
-            '    "24" [label = "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
-            '    "1" -- "3" [label = "IGEdge(is_copy_related=True)"]\n'
-            '    "1" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "7" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "20" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "1" -- "22" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "3" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "7" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "7" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "11" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "13" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "13" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "15" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
-            '    "22" -- "23" [label = "IGEdge(is_copy_related=False)"]\n'
-            '}'
+            'initial': r"""graph {
+    graph [pack = true]
+    "0" [label = "<arg.outputs[0]: <I64>>: 0"]
+    "1" [label = "<arg.out0.copy.outputs[0]: <I64>>: 0"]
+    "2" [label = "<vl.outputs[0]: <VL_MAXVL>>: 0"]
+    "3" [label = "<ld.inp0.copy.outputs[0]: <I64>>: 0"]
+    "4" [label = "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "5" [label = "<ld.outputs[0]: <I64*32>>: 0"]
+    "6" [label = "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "7" [label = "<ld.out0.copy.outputs[0]: <I64*32>>: 0"]
+    "8" [label = "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "9" [label = "<li.outputs[0]: <I64*32>>: 0"]
+    "10" [label = "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "11" [label = "<li.out0.copy.outputs[0]: <I64*32>>: 0"]
+    "12" [label = "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "13" [label = "<add.inp0.copy.outputs[0]: <I64*32>>: 0"]
+    "14" [label = "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "15" [label = "<add.inp1.copy.outputs[0]: <I64*32>>: 0"]
+    "16" [label = "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "17" [label = "<add.outputs[0]: <I64*32>>: 0"]
+    "18" [label = "<ca.outputs[0]: <CA>>: 0\n<add.outputs[1]: <CA>>: 0"]
+    "19" [label = "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "20" [label = "<add.out0.copy.outputs[0]: <I64*32>>: 0"]
+    "21" [label = "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "22" [label = "<st.inp0.copy.outputs[0]: <I64*32>>: 0"]
+    "23" [label = "<st.inp1.copy.outputs[0]: <I64>>: 0"]
+    "24" [label = "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: 0"]
+    "0" -- "1" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "0" -- "3" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "0" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "1" -- "3" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "3" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "1" -- "5" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "7" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "9" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "11" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "13" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "15" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "20" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "22" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "1" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "3" -- "5" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "3" -- "23" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "5" -- "7" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "5" -- "13" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "7" -- "9" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "7" -- "11" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "7" -- "13" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "9" -- "11" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "9" -- "15" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "11" -- "13" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "11" -- "15" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "13" -- "15" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "13" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "15" -- "17" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+    "17" -- "20" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "17" -- "22" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "20" -- "22" [label = "copy related", color = "blue", style = "dashed", decorate = true]
+    "22" -- "23" [label = "interferes", color = "darkred", style = "bold", decorate = true]
+}"""
         })
 
     def test_register_allocate_spread(self):
index 94de331e6862efd22039c64c03a51bd87c1533f9..a1d17859c13fae76884d888132431c81fec24fac 100644 (file)
@@ -484,10 +484,24 @@ class InterferenceGraph:
                 label_lines.append(f"{k}: {v}")
             label = quote("\n".join(label_lines))
             lines.append(f"    {node_id} [label = {label}]")
+
+        def append_edge(node1, node2, label, color, style):
+            # type: (IGNode, IGNode, str, str, str) -> None
+            label = quote(label)
+            color = quote(color)
+            style = quote(style)
+            lines.append(f"    {node_ids[node1]} -- {node_ids[node2]} ["
+                         f"label = {label}, "
+                         f"color = {color}, "
+                         f"style = {style}, "
+                         f"decorate = true]")
         for (node1, node2), edge in edges.items():
-            label = quote(repr(edge))
-            lines.append(f"    {node_ids[node1]} -- {node_ids[node2]} "
-                         f"[label = {label}]")
+            if edge.interferes:
+                append_edge(node1, node2, label="interferes",
+                            color="darkred", style="bold")
+            if edge.is_copy_related:
+                append_edge(node1, node2, label="copy related",
+                            color="blue", style="dashed")
         lines.append("}")
         return "\n".join(lines)
 
@@ -506,16 +520,20 @@ class IGNodeReprState:
 @final
 class IGEdge:
     """ interference graph edge """
-    __slots__ = "is_copy_related",
+    __slots__ = "interferes", "is_copy_related"
 
-    def __init__(self, is_copy_related):
-        # type: (bool) -> None
+    def __init__(self, interferes=False, is_copy_related=False):
+        # type: (bool, bool) -> None
+        self.interferes = interferes
         self.is_copy_related = is_copy_related
 
     def merged(self, other):
-        # type: (IGEdge) -> IGEdge
+        # type: (IGEdge | None) -> IGEdge
+        if other is None:
+            return self
         is_copy_related = self.is_copy_related | other.is_copy_related
-        return IGEdge(is_copy_related=is_copy_related)
+        interferes = self.interferes | other.interferes
+        return IGEdge(interferes=interferes, is_copy_related=is_copy_related)
 
 
 @final
@@ -529,10 +547,17 @@ class IGNode:
         self.edges = edges
         self.loc = loc
 
-    def add_edge(self, other, edge):
+    def merge_edge(self, other, edge):
         # type: (IGNode, IGEdge) -> None
-        self.edges[other] = edge
-        other.edges[self] = edge
+        old_edge = self.edges.get(other, None)
+        assert old_edge is other.edges.get(self, None), "inconsistent edges"
+        edge = edge.merged(old_edge)
+        if edge == IGEdge():
+            self.edges.pop(other, None)
+            other.edges.pop(self, None)
+        else:
+            self.edges[other] = edge
+            other.edges[self] = edge
 
     def __eq__(self, other):
         # type: (object) -> bool
@@ -570,7 +595,9 @@ class IGNode:
 
     def loc_conflicts_with_neighbors(self, loc):
         # type: (Loc) -> bool
-        for neighbor in self.edges:
+        for neighbor, edge in self.edges.items():
+            if not edge.interferes:
+                continue
             if neighbor.loc is not None and neighbor.loc.conflicts(loc):
                 return True
         return False
@@ -627,17 +654,21 @@ def allocate_registers(
                 interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
             if i.loc_set.max_conflicts_with(j.loc_set) != 0:
-                # can't use:
-                # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
-                #     j.copy_related_ssa_vals)
-                # since it is too coarse
-                interference_graph.nodes[i].add_edge(
+                interference_graph.nodes[i].merge_edge(
                     interference_graph.nodes[j],
-                    edge=IGEdge(is_copy_related=i.is_copy_related(j)))
+                    edge=IGEdge(interferes=True))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)
 
+    for i, j in combinations(interference_graph.nodes.values(), 2):
+        # can't use:
+        # is_copy_related = (not i.merged_ssa_val.copy_related_ssa_vals
+        #     .isdisjoint(j.merged_ssa_val.copy_related_ssa_vals))
+        # since it is too coarse
+        is_copy_related = i.merged_ssa_val.is_copy_related(j.merged_ssa_val)
+        i.merge_edge(j, IGEdge(is_copy_related=is_copy_related))
+
     if debug_out is not None:
         print(f"After adding interference graph edges:\n"
               f"{interference_graph}", file=debug_out, flush=True)
@@ -660,7 +691,9 @@ def allocate_registers(
         if retval is not None:
             return retval
         retval = len(node.loc_set)
-        for neighbor in node.edges:
+        for neighbor, edge in node.edges.items():
+            if not edge.interferes:
+                continue
             if neighbor in nodes_remaining:
                 retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
         local_colorability_score_cache[node] = retval
@@ -686,7 +719,9 @@ def allocate_registers(
         node_stack.append(best_node)
         nodes_remaining.remove(best_node)
         local_colorability_score_cache.pop(best_node, None)
-        for neighbor in best_node.edges:
+        for neighbor, edge in best_node.edges.items():
+            if not edge.interferes:
+                continue
             local_colorability_score_cache.pop(neighbor, None)
 
     if debug_out is not None: