add is_copy_related to interference graph edges
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 2 Dec 2022 07:19:59 +0000 (23:19 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 2 Dec 2022 07:19:59 +0000 (23:19 -0800)
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/register_allocator.py

index 12a32f1b6ad28bcf2258e804cc0958f1507047a4..0553a40660c5cc8ecd19ca067d12691e0eae0106 100644 (file)
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from contextlib import contextmanager
 import enum
 from abc import ABCMeta, abstractmethod
@@ -413,6 +414,8 @@ class FnAnalysis:
         """ map from SSAValSubRegs to the original SSAValSubRegs that they are
         a copy of, looking through all layers of copies. The map excludes all
         SSAValSubRegs that aren't copies of other SSAValSubRegs.
+        This ignores inputs of copy Ops that aren't actually being copied
+        (e.g. the VL input of VecCopyToReg).
         """
         retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
         for op in self.op_indexes.keys():
@@ -434,6 +437,35 @@ class FnAnalysis:
                 retval[out] = inp
         return FMap(retval)
 
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
+        """ map from SSAVals to the full set of SSAVals that are related by
+        being sources/destinations of copies, transitively looking through all
+        copies.
+        This ignores inputs of copy Ops that aren't actually being copied
+        (e.g. the VL input of VecCopyToReg).
+        """
+        sets_map = {i: OSet([i]) for i in self.uses.keys()}
+        for k, v in self.copies.items():
+            k_set = sets_map[k.ssa_val]
+            v_set = sets_map[v.ssa_val]
+            # merge k_set and v_set
+            if k_set is v_set:
+                continue
+            k_set |= v_set
+            for i in k_set:
+                sets_map[i] = k_set
+        # this way we construct each OFSet only once rather than
+        # for each SSAVal
+        sets_set = {id(i): i for i in sets_map.values()}
+        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
+        for v in sets_set.values():
+            v = OFSet(v)
+            for k in v:
+                retval[k] = v
+        return FMap(retval)
+
     @cached_property
     def const_ssa_vals(self):
         # type: () -> FMap[SSAVal, tuple[int, ...]]
index 73381e41493bb0d3c1f6f76168ba250a0554c914..be1522ecc6c38fa3787e7b4502854af1799bb17b 100644 (file)
@@ -5,6 +5,7 @@ this uses an algorithm based on:
 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
 """
 
+from functools import reduce
 from itertools import combinations
 from typing import Iterable, Iterator, Mapping, TextIO
 
@@ -253,6 +254,15 @@ class MergedSSAVal(metaclass=InternedMeta):
                 f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                 f"live_interval={self.live_interval})")
 
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> OFSet[SSAVal]
+        sets = OSet()  # type: OSet[OFSet[SSAVal]]
+        # avoid merging the same sets multiple times
+        for ssa_val in self.ssa_vals:
+            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
+        return OFSet(v for s in sets for v in s)
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
@@ -322,7 +332,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                         f"{self.__merged_ssa_val_map[ssa_val]}")
                 self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
+            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
             self.__map[merged_ssa_val] = retval
             added = None
             return retval
@@ -337,7 +347,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
     def merge_into_one_node(self, final_merged_ssa_val):
         # type: (MergedSSAVal) -> IGNode
         source_nodes = OSet()  # type: OSet[IGNode]
-        edges = OSet()  # type: OSet[IGNode]
+        edges = {}  # type: dict[IGNode, IGEdge]
         loc = None  # type: Loc | None
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
@@ -354,16 +364,21 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
             elif source_node.loc is not None and loc != source_node.loc:
                 raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                  f"values: {loc} != {source_node.loc}")
-            edges |= source_node.edges
+            for n, edge in source_node.edges.items():
+                if n in edges:
+                    edge = edge.merged(edges[n])
+                edges[n] = edge
         if len(source_nodes) == 1:
             return source_nodes.pop()  # merging a single node is a no-op
         # we're finished checking validity, now we can modify stuff
-        edges -= source_nodes
+        for n in source_nodes:
+            edges.pop(n, None)
         retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                         loc=loc)
         for node in edges:
-            node.edges -= source_nodes
-            node.edges.add(retval)
+            edge = reduce(IGEdge.merged,
+                          (node.edges.pop(n) for n in source_nodes))
+            node.edges[retval] = edge
         for node in source_nodes:
             del self.__map[node.merged_ssa_val]
         self.__map[final_merged_ssa_val] = retval
@@ -437,21 +452,37 @@ class IGNodeReprState:
         self.did_full_repr = OSet()  # type: OSet[IGNode]
 
 
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class IGEdge:
+    """ interference graph edge """
+    __slots__ = "is_copy_related",
+
+    def __init__(self, is_copy_related):
+        # type: (bool) -> None
+        self.is_copy_related = is_copy_related
+
+    def merged(self, other):
+        # type: (IGEdge) -> IGEdge
+        is_copy_related = self.is_copy_related | other.is_copy_related
+        return IGEdge(is_copy_related=is_copy_related)
+
+
 @final
 class IGNode:
     """ interference graph node """
     __slots__ = "merged_ssa_val", "edges", "loc"
 
     def __init__(self, merged_ssa_val, edges, loc):
-        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
         self.merged_ssa_val = merged_ssa_val
-        self.edges = OSet(edges)
+        self.edges = edges
         self.loc = loc
 
-    def add_edge(self, other):
-        # type: (IGNode) -> None
-        self.edges.add(other)
-        other.edges.add(self)
+    def add_edge(self, other, edge):
+        # type: (IGNode, IGEdge) -> None
+        self.edges[other] = edge
+        other.edges[self] = edge
 
     def __eq__(self, other):
         # type: (object) -> bool
@@ -465,15 +496,18 @@ class IGNode:
 
     def __repr__(self, repr_state=None, short=False):
         # type: (None | IGNodeReprState, bool) -> str
-        if repr_state is None:
-            repr_state = IGNodeReprState()
-        node_id = repr_state.node_ids.get(self, None)
+        rs = repr_state
+        del repr_state
+        if rs is None:
+            rs = IGNodeReprState()
+        node_id = rs.node_ids.get(self, None)
         if node_id is None:
-            repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
-        if short or self in repr_state.did_full_repr:
+            rs.node_ids[self] = node_id = len(rs.node_ids)
+        if short or self in rs.did_full_repr:
             return f"<IGNode #{node_id}>"
-        repr_state.did_full_repr.add(self)
-        edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+        rs.did_full_repr.add(self)
+        edges = ", ".join(
+            f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
         return (f"IGNode(#{node_id}, "
                 f"merged_ssa_val={self.merged_ssa_val}, "
                 f"edges={{{edges}}}, "
@@ -539,8 +573,16 @@ def allocate_registers(fn, debug_out=None):
                 interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
             if i.loc_set.max_conflicts_with(j.loc_set) != 0:
+                # can't use:
+                # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
+                #     j.copy_related_ssa_vals)
+                # since it is too coarse
+
+                # TODO: fill in is_copy_related afterwards
+                # using fn_analysis.copies
                 interference_graph.nodes[i].add_edge(
-                    interference_graph.nodes[j])
+                    interference_graph.nodes[j],
+                    edge=IGEdge(is_copy_related=False))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)