copy-merging works as far as I can tell! -- some tests are still broken (out of date)
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
index 540af0a3e31316203cfefff1851ca2be3e1d816d..76fc966ccfbf721a85132f8ff2c2a1a0b880e881 100644 (file)
@@ -6,11 +6,11 @@ this uses an algorithm based on:
 """
 
 from functools import reduce
-from itertools import chain, combinations, count
+from itertools import combinations, count
 from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
 
 from cached_property import cached_property
-from nmutil.plain_data import plain_data
+from nmutil.plain_data import plain_data, replace
 
 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
                                                   LocSet, Op, ProgramRange,
@@ -59,8 +59,8 @@ class MergedSSAVal(metaclass=InternedMeta):
     __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                  "first_loc")
 
-    def __init__(self, fn_analysis, ssa_val_offsets):
-        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+    def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
+        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
         self.fn_analysis = fn_analysis
         if isinstance(ssa_val_offsets, SSAVal):
             ssa_val_offsets = {ssa_val_offsets: 0}
@@ -74,7 +74,10 @@ class MergedSSAVal(metaclass=InternedMeta):
         self.first_ssa_val = first_ssa_val  # type: SSAVal
         # self.ty checks for mismatched base_ty
         reg_len = self.ty.reg_len
-        loc_set = None  # type: None | LocSet
+        if loc_set is not None and loc_set.ty != self.ty:
+            raise ValueError(
+                f"invalid loc_set, type doesn't match: "
+                f"{loc_set.ty} != {self.ty}")
         for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
             def locs():
                 # type: () -> Iterable[Loc]
@@ -163,12 +166,17 @@ class MergedSSAVal(metaclass=InternedMeta):
     @cached_property
     def __hash(self):
         # type: () -> int
-        return hash((self.fn_analysis, self.ssa_val_offsets))
+        return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))
 
     def __hash__(self):
         # type: () -> int
         return self.__hash
 
+    @property
+    def only_loc(self):
+        # type: () -> Loc | None
+        return self.loc_set.only_loc
+
     @cached_property
     def offset(self):
         # type: () -> int
@@ -226,6 +234,16 @@ class MergedSSAVal(metaclass=InternedMeta):
                     ssa_val_offsets[ssa_val] + additional_offset - offset)
         raise ValueError("can't change offset to match unrelated MergedSSAVal")
 
+    def with_loc(self, loc):
+        # type: (Loc) -> MergedSSAVal
+        if loc not in self.loc_set:
+            raise ValueError(
+                f"Loc is not allowed -- not a member of `self.loc_set`: "
+                f"{loc} not in {self.loc_set}")
+        return MergedSSAVal(fn_analysis=self.fn_analysis,
+                            ssa_val_offsets=self.ssa_val_offsets,
+                            loc_set=LocSet([loc]))
+
     def merged(self, *others):
         # type: (*MergedSSAVal) -> MergedSSAVal
         retval = dict(self.ssa_val_offsets)
@@ -278,6 +296,21 @@ class MergedSSAVal(metaclass=InternedMeta):
                             return lhs, rhs
         return None
 
+    def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
+        # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
+        cr_lhs, cr_rhs = copy_relation
+        if cr_lhs.ssa_val not in self.ssa_vals:
+            cr_lhs, cr_rhs = cr_rhs, cr_lhs
+        lhs_merged = self.with_offset_to_match(
+            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
+        if lhs_loc is not None:
+            lhs_merged = lhs_merged.with_loc(lhs_loc)
+        rhs_merged = rhs.with_offset_to_match(
+            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
+        if rhs_loc is not None:
+            rhs_merged = rhs_merged.with_loc(rhs_loc)
+        return lhs_merged.merged(rhs_merged).normalized()
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
@@ -350,7 +383,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                 added += 1
             retval = IGNode(
                 node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
-                edges={}, loc=None, ignored=False)
+                edges={}, loc=merged_ssa_val.only_loc, ignored=False)
             self.__map[merged_ssa_val] = retval
             self.__next_node_id += 1
             added = None
@@ -367,7 +400,6 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # type: (MergedSSAVal) -> IGNode
         source_nodes = OSet()  # type: OSet[IGNode]
         edges = {}  # type: dict[IGNode, IGEdge]
-        loc = None  # type: Loc | None
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
             source_node = self.__map[merged_ssa_val]
@@ -380,11 +412,11 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                     f"but not in merged IGNode's merged_ssa_val: "
                     f"source_node={source_node} "
                     f"final_merged_ssa_val={final_merged_ssa_val}")
-            if loc is None:
-                loc = source_node.loc
-            elif source_node.loc is not None and loc != source_node.loc:
-                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
-                                 f"values: {loc} != {source_node.loc}")
+            if source_node.loc != source_node.merged_ssa_val.only_loc:
+                raise ValueError(
+                    f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
+                    f"{source_node.loc} != "
+                    f"{source_node.merged_ssa_val.only_loc}")
             for n, edge in source_node.edges.items():
                 if n in edges:
                     edge = edge.merged(edges[n])
@@ -394,6 +426,18 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # we're finished checking validity, now we can modify stuff
         for n in source_nodes:
             edges.pop(n, None)
+        loc = final_merged_ssa_val.only_loc
+        for n, edge in edges.items():
+            if edge.copy_relation is None or not edge.interferes:
+                continue
+            try:
+                # if merging works, then the edge can't interfere
+                _ = final_merged_ssa_val.copy_merged(
+                    lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc,
+                    copy_relation=edge.copy_relation)
+            except BadMergedSSAVal:
+                continue
+            edges[n] = replace(edge, interferes=False)
         retval = IGNode(
             node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
             edges=edges, loc=loc, ignored=False)
@@ -440,7 +484,7 @@ class InterferenceGraph:
         merged2 = self.merged_ssa_val_map[ssa_val2]
         merged = merged1.with_offset_to_match(ssa_val1)
         return merged.merged(merged2.with_offset_to_match(
-            ssa_val2, additional_offset=additional_offset))
+            ssa_val2, additional_offset=additional_offset)).normalized()
 
     def merge(self, ssa_val1, ssa_val2, additional_offset=0):
         # type: (SSAVal, SSAVal, int) -> IGNode
@@ -448,27 +492,9 @@ class InterferenceGraph:
             ssa_val1=ssa_val1, ssa_val2=ssa_val2,
             additional_offset=additional_offset))
 
-    def copy_merge_preview(self, node1, node2):
-        # type: (IGNode, IGNode) -> MergedSSAVal
-        try:
-            copy_relation = node1.edges[node2].copy_relation
-        except KeyError:
-            raise ValueError("nodes aren't copy related")
-        if copy_relation is None:
-            raise ValueError("nodes aren't copy related")
-        lhs, rhs = copy_relation
-        if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals:
-            lhs, rhs = rhs, lhs
-        lhs_merged = node1.merged_ssa_val.with_offset_to_match(
-            lhs.ssa_val, additional_offset=-lhs.reg_idx)
-        rhs_merged = node2.merged_ssa_val.with_offset_to_match(
-            rhs.ssa_val, additional_offset=-rhs.reg_idx)
-        return lhs_merged.merged(rhs_merged)
-
     def copy_merge(self, node1, node2):
         # type: (IGNode, IGNode) -> IGNode
-        return self.nodes.merge_into_one_node(self.copy_merge_preview(
-            node1=node1, node2=node2))
+        return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))
 
     def local_colorability_score(self, node, merged_in_copy=None):
         # type: (IGNode, None | IGNode) -> int
@@ -485,7 +511,10 @@ class InterferenceGraph:
         loc_set = node.loc_set
         edges = node.edges
         if merged_in_copy is not None:
-            loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set
+            if merged_in_copy.ignored:
+                raise ValueError(
+                    "can't get local_colorability_score of ignored node")
+            loc_set = node.copy_merge_preview(merged_in_copy).loc_set
             edges = edges.copy()
             for neighbor, edge in merged_in_copy.edges.items():
                 edges[neighbor] = edge.merged(edges.get(neighbor))
@@ -665,12 +694,12 @@ class IGNode:
     def __eq__(self, other):
         # type: (object) -> bool
         if isinstance(other, IGNode):
-            return self.merged_ssa_val == other.merged_ssa_val
+            return self.node_id == other.node_id
         return NotImplemented
 
     def __hash__(self):
         # type: () -> int
-        return hash(self.merged_ssa_val)
+        return hash(self.node_id)
 
     def __repr__(self, repr_state=None, short=False):
         # type: (None | IGNodeReprState, bool) -> str
@@ -703,6 +732,19 @@ class IGNode:
                 return True
         return False
 
+    def copy_merge_preview(self, rhs_node):
+        # type: (IGNode) -> MergedSSAVal
+        try:
+            copy_relation = self.edges[rhs_node].copy_relation
+        except KeyError:
+            raise ValueError("nodes aren't copy related")
+        if copy_relation is None:
+            raise ValueError("nodes aren't copy related")
+        return self.merged_ssa_val.copy_merged(
+            lhs_loc=self.loc,
+            rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
+            copy_relation=copy_relation)
+
 
 class AllocationFailedError(Exception):
     def __init__(self, msg, node, interference_graph):
@@ -748,24 +790,32 @@ def allocate_registers(
         print(f"After InterferenceGraph.minimally_merged():\n"
               f"{interference_graph}", file=debug_out, flush=True)
 
+    for i, j in combinations(interference_graph.nodes.values(), 2):
+        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
+        i.merge_edge(j, IGEdge(copy_relation=copy_relation))
+
     for pp, ssa_vals in fn_analysis.live_at.items():
         live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
         for ssa_val in ssa_vals:
             live_merged_ssa_vals.add(
                 interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
-            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
-                interference_graph.nodes[i].merge_edge(
-                    interference_graph.nodes[j],
-                    edge=IGEdge(interferes=True))
+            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
+                continue
+            node_i = interference_graph.nodes[i]
+            node_j = interference_graph.nodes[j]
+            if node_j in node_i.edges:
+                if node_i.edges[node_j].copy_relation is not None:
+                    try:
+                        _ = node_i.copy_merge_preview(node_j)
+                        continue  # doesn't interfere if copy merging succeeds
+                    except BadMergedSSAVal:
+                        pass
+            node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)
 
-    for i, j in combinations(interference_graph.nodes.values(), 2):
-        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
-        i.merge_edge(j, IGEdge(copy_relation=copy_relation))
-
     if debug_out is not None:
         print(f"After adding interference graph edges:\n"
               f"{interference_graph}", file=debug_out, flush=True)
@@ -900,10 +950,37 @@ def allocate_registers(
                     "IGNode is pre-allocated to a conflicting Loc",
                     node=node, interference_graph=interference_graph)
         else:
-            # pick the first non-conflicting register in node.reg_class, since
-            # register classes are ordered from most preferred to least
-            # preferred register.
+            # Locs to try allocating, ordered from most preferred to least
+            # preferred
+            locs = OSet()
+            # prefer eliminating copies
+            for neighbor, edge in node.edges.items():
+                if neighbor.loc is None or edge.copy_relation is None:
+                    continue
+                try:
+                    merged = node.copy_merge_preview(neighbor)
+                except BadMergedSSAVal:
+                    continue
+                # get merged_loc if merged.loc_set has a single Loc
+                merged_loc = merged.only_loc
+                if merged_loc is None:
+                    continue
+                ssa_val = node.merged_ssa_val.first_ssa_val
+                ssa_val_loc = merged_loc.get_subloc_at_offset(
+                    subloc_ty=ssa_val.ty,
+                    offset=merged.ssa_val_offsets[ssa_val])
+                node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
+                    superloc_ty=node.merged_ssa_val.ty,
+                    offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
+                assert node_loc in node.merged_ssa_val.loc_set, "logic error"
+                locs.add(node_loc)
+            # add node's allowed Locs as fallback
             for loc in node.loc_set:
+                # TODO: add in order of preference
+                locs.add(loc)
+            # pick the first non-conflicting register in locs, since locs is
+            # ordered from most preferred to least preferred register.
+            for loc in locs:
                 if not node.loc_conflicts_with_neighbors(loc):
                     node.loc = loc
                     break