6e7d87c27948a5e5a1abd6fec3cece7772f2a334
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from functools import reduce
9 from itertools import combinations
10 from typing import Callable, Iterable, Iterator, Mapping, TextIO
11
12 from cached_property import cached_property
13 from nmutil.plain_data import plain_data
14
15 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
16 LocSet, Op, ProgramRange,
17 SSAVal, SSAValSubReg, Ty)
18 from bigint_presentation_code.type_util import final
19 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
20
21
class BadMergedSSAVal(ValueError):
    """Raised when a group of `SSAVal`s can't form a valid `MergedSSAVal`."""
24
25
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                 "first_loc")

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """Build the merged value and compute the `Loc`s it may occupy.

        `ssa_val_offsets` is either a mapping from `SSAVal` to offset, or a
        single `SSAVal` which is treated as `{ssa_val: 0}`.

        Raises `BadMergedSSAVal` if the mapping is empty, if no `Loc`
        satisfies every member and its uses, or if the mergability check
        fails. Reading `self.ty` also raises it on a `BaseTy` mismatch.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        loc_set = None  # type: None | LocSet
        # every pass narrows `loc_set` to the Locs that are also acceptable
        # for the current member, so the final set is the intersection
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # translate the member's Loc into the Loc of the whole
                    # merged value
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            # `locs()` reads the previous `loc_set`, so the generator is
            # consumed right here, before `loc_set` is rebound
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = None
        for loc in loc_set:
            first_loc = loc
            break
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        self.first_loc = first_loc
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet
        self.__mergable_check()

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals
        to illegally overlap. This is required to avoid copy merging merging
        things that can't be merged.
        spread arguments are one of the things that can force two values to
        illegally overlap.
        """
        ops = OSet()  # type: Iterable[Op]
        for ssa_val in self.ssa_vals:
            ops.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                ops.add(use.op)
        # visit the defining and using ops in `op_indexes` order
        ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
        # sub-register most recently written at each offset
        vals = {}  # type: dict[int, SSAValSubReg]
        for op in ops:
            for inp in op.input_vals:
                try:
                    ssa_val_offset = self.ssa_val_offsets[inp]
                except KeyError:
                    continue  # input isn't part of this MergedSSAVal
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            # offsets that outputs of this op may still claim
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                try:
                    ssa_val_offset = self.ssa_val_offsets[out]
                except KeyError:
                    continue  # output isn't part of this MergedSSAVal
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal("attempted to merge two outputs "
                                              "of the same instruction")
                    vals[reg_offset] = reg

    @cached_property
    def __hash(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """the smallest before-spread offset of any member"""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """the `BaseTy` of the first member"""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, i.e. the keys of `ssa_val_offsets`"""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the `Ty` spanning all members, measured from `self.offset`.

        Raises `BadMergedSSAVal` if a member's `BaseTy` differs from
        `self.base_ty`.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """`ssa_val_offsets` with each member's
        `defining_descriptor.reg_offset_in_unspread` subtracted
        """
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a `MergedSSAVal` with `amount` added to every offset"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a `MergedSSAVal` shifted by `-self.offset`"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """shift `self` so the first member it shares with `target` has
        `target`'s offset for it plus `additional_offset`. A bare `SSAVal`
        target counts as offset 0.

        Raises `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of `self` and `others`.

        Raises `ValueError` on a `fn_analysis` mismatch and `BadMergedSSAVal`
        if an `SSAVal` appears with two different offsets.
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """the smallest `ProgramRange` covering every member's live range"""
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """union of `fn_analysis.copy_related_ssa_vals` over all members"""
        sets = OSet()  # type: OSet[OFSet[SSAVal]]
        # avoid merging the same sets multiple times
        for ssa_val in self.ssa_vals:
            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
        return OFSet(v for s in sets for v in s)

    def is_copy_related(self, other):
        # type: (MergedSSAVal) -> bool
        """return True if any sub-register of `self` and any sub-register of
        `other` map to the same entry through `fn_analysis.copies`.
        Sub-registers missing from `copies` stand for themselves.
        """
        copies = self.fn_analysis.copies
        # Resolve every sub-register exactly once. Looking it up again inside
        # the rhs loop would re-apply `copies` to an already-resolved value.
        lhs_sources = set()
        for lhs_ssa_val in self.ssa_vals:
            for lhs in lhs_ssa_val.ssa_val_sub_regs:
                lhs_sources.add(copies.get(lhs, lhs))
        for rhs_ssa_val in other.ssa_vals:
            for rhs in rhs_ssa_val.ssa_val_sub_regs:
                if copies.get(rhs, rhs) in lhs_sources:
                    return True
        return False
277
278
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """Mapping view from each `SSAVal` to the `MergedSSAVal` containing it.

    This class defines no mutating methods of its own. The backing dict is
    handed to the `MergedSSAValToIGNodeMap` available as `ig_node_map`.
    """

    def __init__(self):
        # type: (...) -> None
        backing = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__map = backing
        # the node map receives the very same dict object
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=backing)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the `MergedSSAValToIGNodeMap` sharing this map's backing dict"""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # iterating the node map yields its `MergedSSAVal` keys
        reprs = [repr(merged_ssa_val) for merged_ssa_val in self.__ig_node_map]
        s = ",\n".join(reprs)
        return f"SSAValToMergedSSAValMap({{{s}}})"
308
309
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """Mapping from each `MergedSSAVal` to its `IGNode`.

    Adding or merging nodes also updates the `SSAVal` -> `MergedSSAVal` dict
    passed in at construction, keeping the two maps consistent.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Return the node for `merged_ssa_val`, creating it if needed.

        Raises `ValueError` if one of its `SSAVal`s already belongs to a
        different `MergedSSAVal`. Entries added before the failure are
        removed again.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        # number of entries added so far; None once the add has succeeded
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
            self.__map[merged_ssa_val] = retval
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Replace the nodes holding `final_merged_ssa_val`'s `SSAVal`s with
        one node for `final_merged_ssa_val`, and return that node.

        Edges of the replaced nodes are combined with `IGEdge.merged` and
        re-attached to the new node. If only one node is involved it is
        returned unchanged.

        Raises `ValueError` if a replaced node holds an `SSAVal` that is not
        in `final_merged_ssa_val`, or if two replaced nodes have different
        non-None `loc`s. Raises `KeyError` if an `SSAVal` has no node yet.
        """
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = {}  # type: dict[IGNode, IGEdge]
        loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if loc is None:
                loc = source_node.loc
            elif source_node.loc is not None and loc != source_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {loc} != {source_node.loc}")
            for n, edge in source_node.edges.items():
                if n in edges:
                    edge = edge.merged(edges[n])
                edges[n] = edge
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        for n in source_nodes:
            edges.pop(n, None)
        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                        loc=loc)
        for node in edges:
            # A neighbor need not be adjacent to every source node, so only
            # pop the edges that exist. An unconditional pop raises KeyError
            # part-way through and leaves the graph half-updated.
            old_edges = [node.edges.pop(n)
                         for n in source_nodes if n in node.edges]
            if old_edges:
                edge = reduce(IGEdge.merged, old_edges)
            else:
                # no back-edges recorded; fall back to the merged forward edge
                edge = edges[node]
            node.edges[retval] = edge
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
407
408
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """Interference graph whose nodes are `IGNode`s keyed by `MergedSSAVal`."""
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        """create a graph with one edge-less node per given `MergedSSAVal`"""
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for merged_ssa_val in merged_ssa_vals:
            self.nodes.add_node(merged_ssa_val)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """merge the nodes holding `ssa_val1` and `ssa_val2` into one node in
        which `ssa_val1` has offset 0 and `ssa_val2` has offset
        `additional_offset`; returns the resulting node.
        """
        lookup = self.merged_ssa_val_map
        first = lookup[ssa_val1]
        second = lookup[ssa_val2]
        aligned_first = first.with_offset_to_match(ssa_val1)
        aligned_second = second.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        combined = aligned_first.merged(aligned_second)
        return self.nodes.merge_into_one_node(combined)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """build a graph with a node per op output, merging only spread
        inputs/outputs with their `unspread_start` and outputs with their
        tied inputs.
        """
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                if use.unspread_start != use:
                    graph.merge(use.unspread_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for output in op.outputs:
                graph.nodes.add_node(MergedSSAVal(fn_analysis, output))
                if output.unspread_start != output:
                    graph.merge(output.unspread_start, output,
                                additional_offset=output.reg_offset_in_unspread)
                if output.tied_input is not None:
                    graph.merge(output.tied_input.ssa_val, output)
        return graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        nodes_text = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={nodes_text}, <...>)"

    def dump_to_dot(self):
        # type: () -> str
        """return the graph as GraphViz DOT text, one `--` line per edge"""

        def quote(s):
            # type: (object) -> str
            # escape backslashes first so later escapes aren't doubled
            escaped = (str(s)
                       .replace('\\', r'\\')
                       .replace('"', r'\"')
                       .replace('\n', r'\n'))
            return f'"{escaped}"'

        node_ids = {}  # type: dict[IGNode, str]
        edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
        for node in self.nodes.values():
            node_ids[node] = quote(len(node_ids))
            for neighbor, edge in node.edges.items():
                forward = (node, neighbor)
                backward = forward[::-1]
                # each edge is stored on both endpoints; keep only the
                # direction that was seen first
                if forward in edges or backward in edges:
                    continue
                edges[forward] = edge
        lines = ["graph {"]
        for node, node_id in node_ids.items():
            offsets = node.merged_ssa_val.ssa_val_offsets
            label_lines = [f"{k}: {v}" for k, v in offsets.items()]
            label = quote("\n".join(label_lines))
            lines.append(f" {node_id} [label = {label}]")
        for (node1, node2), edge in edges.items():
            label = quote(repr(edge))
            lines.append(f" {node_ids[node1]} -- {node_ids[node2]} "
                         f"[label = {label}]")
        lines.append("}")
        return "\n".join(lines)
490
491
@plain_data(repr=False)
class IGNodeReprState:
    """State threaded through nested `__repr__` calls on graph objects."""
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        super().__init__()
        # id assigned to each node the first time it is printed
        self.node_ids = dict()  # type: dict[IGNode, int]
        # nodes whose full (non-short) repr was already produced
        self.did_full_repr = OSet()  # type: OSet[IGNode]
500
501
@plain_data(frozen=True, unsafe_hash=True)
@final
class IGEdge:
    """ interference graph edge """
    __slots__ = "is_copy_related",

    def __init__(self, is_copy_related):
        # type: (bool) -> None
        self.is_copy_related = is_copy_related

    def merged(self, other):
        # type: (IGEdge) -> IGEdge
        """return an edge that is copy-related if either input edge is"""
        return IGEdge(
            is_copy_related=self.is_copy_related | other.is_copy_related)
516
517
@final
class IGNode:
    """ interference graph node """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = edges
        self.loc = loc

    def add_edge(self, other, edge):
        # type: (IGNode, IGEdge) -> None
        """record `edge` on both endpoints, replacing any existing edge"""
        for src, dest in ((self, other), (other, self)):
            src.edges[dest] = edge

    def __eq__(self, other):
        # type: (object) -> bool
        # nodes compare by the value they hold, not by identity
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        """repr that prints each node in full at most once per `repr_state`;
        later mentions, and all `short` ones, print as `<IGNode #id>`.
        """
        rs = IGNodeReprState() if repr_state is None else repr_state
        node_id = rs.node_ids.get(self, None)
        if node_id is None:
            # ids are handed out in order of first appearance
            node_id = len(rs.node_ids)
            rs.node_ids[self] = node_id
        if short or self in rs.did_full_repr:
            return f"<IGNode #{node_id}>"
        rs.did_full_repr.add(self)
        edge_parts = []
        for neighbor, edge in self.edges.items():
            edge_parts.append(f"{neighbor.__repr__(rs, True)}: {edge}")
        edges = ", ".join(edge_parts)
        return (f"IGNode(#{node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the `loc_set` of this node's `merged_ssa_val`"""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """return True if any neighbor with a non-None `loc` conflicts with
        `loc`
        """
        return any(
            neighbor.loc is not None and neighbor.loc.conflicts(loc)
            for neighbor in self.edges)
574
575
class AllocationFailedError(Exception):
    """Register allocation failure; carries the node and the graph."""

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        self.node = node
        self.interference_graph = interference_graph
        super().__init__(msg, node, interference_graph)

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        rs = IGNodeReprState() if repr_state is None else repr_state
        # the node goes first so its id is fixed before the graph is printed
        node_text = self.node.__repr__(rs, True)
        graph_text = self.interference_graph.__repr__(rs)
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={node_text}, "
                f"interference_graph="
                f"{graph_text})")

    def __str__(self):
        # type: () -> str
        # str() and repr() deliberately give the same text
        return self.__repr__()
595
596
def allocate_registers(
    fn,  # type: Fn
    debug_out=None,  # type: TextIO | None
    dump_graph=None,  # type: Callable[[str, str], None] | None
):
    # type: (...) -> dict[SSAVal, Loc]
    """Assign a `Loc` to every `SSAVal` of `fn` and return the mapping.

    `fn` is modified by `fn.pre_ra_insert_copies()`. If `debug_out` is given,
    progress is printed to it. If `dump_graph` is given, it is called as
    `dump_graph("initial", dot_text)` once the edges have been added.

    Raises `AllocationFailedError` if a node is pre-allocated to a `Loc`
    that conflicts with a neighbor, or if no `Loc` in a node's `loc_set` is
    free of conflicts.
    """
    debugging = debug_out is not None

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    if debugging:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debugging:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    # merged values live at the same program point interfere if their
    # loc_sets can conflict
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet(
            interference_graph.merged_ssa_val_map[ssa_val]
            for ssa_val in ssa_vals)  # type: OSet[MergedSSAVal]
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
                continue
            # can't use:
            # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
            #     j.copy_related_ssa_vals)
            # since it is too coarse
            node_i = interference_graph.nodes[i]
            node_j = interference_graph.nodes[j]
            node_i.add_edge(
                node_j, edge=IGEdge(is_copy_related=i.is_copy_related(j)))
        if debugging:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debugging:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)
    if dump_graph is not None:
        dump_graph("initial", interference_graph.dump_to_dot())

    nodes_remaining = OSet(interference_graph.nodes.values())

    local_colorability_score_cache = {}  # type: dict[IGNode, int]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        cached = local_colorability_score_cache.get(node)
        if cached is not None:
            return cached
        # only neighbors that are still unordered count against the node
        conflicts = sum(
            node.loc_set.max_conflicts_with(neighbor.loc_set)
            for neighbor in node.edges if neighbor in nodes_remaining)
        score = len(node.loc_set) - conflicts
        local_colorability_score_cache[node] = score
        return score

    # TODO: implement copy-merging

    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for candidate in nodes_remaining:
            candidate_score = local_colorability_score(candidate)
            if best_node is None or candidate_score > best_score:
                best_node, best_score = candidate, candidate_score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break  # nothing left to order
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)
        # scores of the removed node and of its neighbors are now stale
        local_colorability_score_cache.pop(best_node, None)
        for neighbor in best_node.edges:
            local_colorability_score_cache.pop(neighbor, None)

    if debugging:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    # allocate in reverse of the order in which nodes were pushed
    while node_stack:
        node = node_stack.pop()
        if node.loc is None:
            # pick the first non-conflicting Loc in node.loc_set
            # NOTE(review): an earlier comment said register classes are
            # ordered most preferred first - confirm LocSet iterates that way
            for loc in node.loc_set:
                if node.loc_conflicts_with_neighbors(loc):
                    continue
                node.loc = loc
                break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)
        elif node.loc_conflicts_with_neighbors(node.loc):
            raise AllocationFailedError(
                "IGNode is pre-allocated to a conflicting Loc",
                node=node, interference_graph=interference_graph)

        if debugging:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        # every member gets its slice of the node's Loc
        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debugging:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval