1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from functools import reduce
9 from itertools import combinations, count
10 from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
11
12 from cached_property import cached_property
13 from nmutil.plain_data import plain_data, replace
14
15 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
16 LocSet, Op, ProgramRange,
17 SSAVal, SSAValSubReg, Ty)
18 from bigint_presentation_code.type_util import final
19 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
20
21
22 class BadMergedSSAVal(ValueError):
23 pass
24
25
26 _CopyRelation = Tuple[SSAValSubReg, SSAValSubReg]
27
28
29 @plain_data(frozen=True, repr=False)
30 @final
31 class MergedSSAVal(metaclass=InternedMeta):
32 """a set of `SSAVal`s along with their offsets, all register allocated as
33 a single unit.
34
35 Definition of the term `offset` for this class:
36
37 Let `locs[x]` be the `Loc` that `x` is assigned to after register
38 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
39 for each `SSAVal` `ssa_val` in `msv` is defined as:
40
41 ```
42 msv.ssa_val_offsets[ssa_val] = (msv.offset
43 + locs[ssa_val].start - locs[msv].start)
44 ```
45
46 Example:
47 ```
48 v1.ty == <I64*4>
49 v2.ty == <I64*2>
50 v3.ty == <I64>
51 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
52 msv.ty == <I64*6>
53 ```
54 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
55 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
56 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
57 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
58 """
    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                 "first_loc")

    def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        if loc_set is not None and loc_set.ty != self.ty:
            raise ValueError(
                f"invalid loc_set, type doesn't match: "
                f"{loc_set.ty} != {self.ty}")
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = None
        for loc in loc_set:
            first_loc = loc
            break
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        self.first_loc = first_loc
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set # type: LocSet
        self.__mergable_check()

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals to
        illegally overlap. This is required so that copy merging doesn't
        merge things that can't be merged.
        Spread arguments are one of the things that can force two values to
        illegally overlap.
        """
        ops = OSet() # type: Iterable[Op]
        for ssa_val in self.ssa_vals:
            ops.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                ops.add(use.op)
        ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
        vals = {} # type: dict[int, SSAValSubReg]
        for op in ops:
            for inp in op.input_vals:
                try:
                    ssa_val_offset = self.ssa_val_offsets[inp]
                except KeyError:
                    continue
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                try:
                    ssa_val_offset = self.ssa_val_offsets[out]
                except KeyError:
                    continue
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal("attempted to merge two outputs "
                                              "of the same instruction")
                    vals[reg_offset] = reg

    @cached_property
    def __hash(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @property
    def only_loc(self):
        # type: () -> Loc | None
        return self.loc_set.only_loc

    @cached_property
    def offset(self):
        # type: () -> int
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        retval = {} # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def with_loc(self, loc):
        # type: (Loc) -> MergedSSAVal
        if loc not in self.loc_set:
            raise ValueError(
                f"Loc is not allowed -- not a member of `self.loc_set`: "
                f"{loc} not in {self.loc_set}")
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=self.ssa_val_offsets,
                            loc_set=LocSet([loc]))

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        sets = OSet() # type: OSet[OFSet[SSAVal]]
        # avoid merging the same sets multiple times
        for ssa_val in self.ssa_vals:
            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
        return OFSet(v for s in sets for v in s)

    def get_copy_relation(self, other):
        # type: (MergedSSAVal) -> None | _CopyRelation
        for lhs_ssa_val in self.ssa_vals:
            for rhs_ssa_val in other.ssa_vals:
                for lhs in lhs_ssa_val.ssa_val_sub_regs:
                    for rhs in rhs_ssa_val.ssa_val_sub_regs:
                        lhs_src = self.fn_analysis.copies.get(lhs, lhs)
                        rhs_src = self.fn_analysis.copies.get(rhs, rhs)
                        if lhs_src == rhs_src:
                            return lhs, rhs
        return None

    def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
        # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
        cr_lhs, cr_rhs = copy_relation
        if cr_lhs.ssa_val not in self.ssa_vals:
            cr_lhs, cr_rhs = cr_rhs, cr_lhs
        lhs_merged = self.with_offset_to_match(
            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
        if lhs_loc is not None:
            lhs_merged = lhs_merged.with_loc(lhs_loc)
        rhs_merged = rhs.with_offset_to_match(
            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
        if rhs_loc is not None:
            rhs_merged = rhs_merged.with_loc(rhs_loc)
        return lhs_merged.merged(rhs_merged).normalized()


@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    def __init__(self):
        # type: (...) -> None
        self.__map = {} # type: dict[SSAVal, MergedSSAVal]
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        s = ",\n".join(repr(v) for v in self.__ig_node_map)
        return f"SSAValToMergedSSAValMap({{{s}}})"


@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    def __init__(
            self, *,
            _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {} # type: dict[MergedSSAVal, IGNode]
        self.__next_node_id = 0

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
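        # `added` counts how many entries have been written into
        # `__merged_ssa_val_map` so far; it is set to None once everything
        # succeeded, so the `finally` block below only rolls back a partial
        # insertion.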
        added = 0 # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(
                node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
                edges={}, loc=merged_ssa_val.only_loc, ignored=False)
            self.__map[merged_ssa_val] = retval
            self.__next_node_id += 1
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        source_nodes = OSet() # type: OSet[IGNode]
        edges = {} # type: dict[IGNode, IGEdge]
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            if source_node.ignored:
                raise ValueError(f"can't merge ignored nodes: {source_node}")
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if source_node.loc != source_node.merged_ssa_val.only_loc:
                raise ValueError(
                    f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
                    f"{source_node.loc} != "
                    f"{source_node.merged_ssa_val.only_loc}")
            for n, edge in source_node.edges.items():
                if n in edges:
                    edge = edge.merged(edges[n])
                edges[n] = edge
        if len(source_nodes) == 1:
            return source_nodes.pop() # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        for n in source_nodes:
            edges.pop(n, None)
        loc = final_merged_ssa_val.only_loc
        for n, edge in edges.items():
            if edge.copy_relation is None or not edge.interferes:
                continue
            try:
                # if merging works, then the edge can't interfere
                _ = final_merged_ssa_val.copy_merged(
                    lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc,
                    copy_relation=edge.copy_relation)
            except BadMergedSSAVal:
                continue
            edges[n] = replace(edge, interferes=False)
        retval = IGNode(
            node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
            edges=edges, loc=loc, ignored=False)
        self.__next_node_id += 1
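        # rewire the neighbors: each neighbor drops its edges to the old
        # source nodes and gets a single merged edge to the new node instead
        # (or no edge at all if the merged edge carries no information).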
        empty_e = IGEdge()
        for node in edges:
            edge = reduce(IGEdge.merged,
                          (node.edges.pop(n, empty_e) for n in source_nodes))
            if edge == empty_e:
                node.edges.pop(retval, None)
            else:
                node.edges[retval] = edge
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"


@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for i in merged_ssa_vals:
            self.nodes.add_node(i)

    def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        merged = merged1.with_offset_to_match(ssa_val1)
        return merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)).normalized()

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        return self.nodes.merge_into_one_node(self.merge_preview(
            ssa_val1=ssa_val1, ssa_val2=ssa_val2,
            additional_offset=additional_offset))

    def copy_merge(self, node1, node2):
        # type: (IGNode, IGNode) -> IGNode
        return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))

    def local_colorability_score(self, node, merged_in_copy=None):
        # type: (IGNode, None | IGNode) -> int
501 """ returns a positive integer if node is locally colorable, returns
502 zero or a negative integer if node isn't known to be locally
503 colorable, the more negative the value, the less colorable.
504
505 if `merged_in_copy` is not `None`, then the node used is what would be
506 the result of `self.copy_merge(node, merged_in_copy)`.
507 """
        if node.ignored:
            raise ValueError(
                "can't get local_colorability_score of ignored node")
        loc_set = node.loc_set
        edges = node.edges
        if merged_in_copy is not None:
            if merged_in_copy.ignored:
                raise ValueError(
                    "can't get local_colorability_score of ignored node")
            loc_set = node.copy_merge_preview(merged_in_copy).loc_set
            edges = edges.copy()
            for neighbor, edge in merged_in_copy.edges.items():
                edges[neighbor] = edge.merged(edges.get(neighbor))
        retval = len(loc_set)
        for neighbor, edge in edges.items():
            if neighbor.ignored or not edge.interferes:
                continue
            if neighbor == merged_in_copy or neighbor == node:
                continue
            retval -= loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
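        # merge only what correctness requires: spread inputs/outputs are
        # merged with the value they were spread from, and tied outputs are
        # merged with their tied input; every other output starts out as its
        # own node.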
        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={s}, <...>)"

    def dump_to_dot(
            self, highlighted_nodes=(), # type: Container[IGNode]
            node_scores=None, # type: None | dict[IGNode, int]
            edge_scores=None, # type: None | dict[tuple[IGNode, IGNode], int]
    ):
        # type: (...) -> str

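        # emits an undirected graphviz graph: interference edges are drawn
        # bold dark red, copy-relation edges dashed blue, ignored nodes
        # dotted, highlighted nodes bold green, and any provided scores are
        # prepended to the node/edge labels.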
        def quote(s):
            # type: (object) -> str
            s = str(s)
            s = s.replace('\\', r'\\')
            s = s.replace('"', r'\"')
            s = s.replace('\n', r'\n')
            return f'"{s}"'

        if node_scores is None:
            node_scores = {}
        if edge_scores is None:
            edge_scores = {}

        edges = {} # type: dict[tuple[IGNode, IGNode], IGEdge]
        node_ids = {} # type: dict[IGNode, str]
        for node in self.nodes.values():
            node_ids[node] = quote(node.node_id)
            for neighbor, edge in node.edges.items():
                edge_key = (node, neighbor)
                # ensure we only insert each edge once by checking for
                # both directions
                if edge_key not in edges and edge_key[::-1] not in edges:
                    edges[edge_key] = edge
        lines = [
            "graph {",
            " graph [pack = true]",
        ]
        for node, node_id in node_ids.items():
            label_lines = [] # type: list[str]
            score = node_scores.get(node)
            if score is not None:
                label_lines.append(f"score={score}")
            for k, v in node.merged_ssa_val.ssa_val_offsets.items():
                label_lines.append(f"{k}: {v}")
            label = quote("\n".join(label_lines))
            style = "dotted" if node.ignored else "solid"
            color = "black"
            if node in highlighted_nodes:
                style = "bold"
                color = "green"
            style = quote(style)
            color = quote(color)
            lines.append(f" {node_id} ["
                         f"label = {label}, "
                         f"style = {style}, "
                         f"color = {color}]")

        def append_edge(node1, node2, label, color, style):
            # type: (IGNode, IGNode, str, str, str) -> None
            label = quote(label)
            color = quote(color)
            style = quote(style)
            lines.append(f" {node_ids[node1]} -- {node_ids[node2]} ["
                         f"label = {label}, "
                         f"color = {color}, "
                         f"style = {style}, "
                         f"decorate = true]")
        for (node1, node2), edge in edges.items():
            score = edge_scores.get((node1, node2))
            if score is None:
                score = edge_scores.get((node2, node1))
            label_prefix = ""
            if score is not None:
                label_prefix = f"score={score}\n"
            if edge.interferes:
                append_edge(node1, node2, label=label_prefix + "interferes",
                            color="darkred", style="bold")
            if edge.copy_relation is not None:
                append_edge(node1, node2, label=label_prefix + "copy related",
                            color="blue", style="dashed")
        lines.append("}")
        return "\n".join(lines)



@plain_data(repr=False)
class IGNodeReprState:
    __slots__ = "did_full_repr",

    def __init__(self):
        super().__init__()
        self.did_full_repr = OSet() # type: OSet[IGNode]


@plain_data(frozen=True, unsafe_hash=True)
@final
class IGEdge:
    """ interference graph edge """
    __slots__ = "interferes", "copy_relation"

    def __init__(self, interferes=False, copy_relation=None):
        # type: (bool, None | _CopyRelation) -> None
        self.interferes = interferes
        self.copy_relation = copy_relation

    def merged(self, other):
        # type: (IGEdge | None) -> IGEdge
        if other is None:
            return self
        copy_relation = self.copy_relation
        if copy_relation is None:
            copy_relation = other.copy_relation
        interferes = self.interferes | other.interferes
        return IGEdge(interferes=interferes, copy_relation=copy_relation)


@final
class IGNode:
    """ interference graph node """
    __slots__ = "node_id", "merged_ssa_val", "edges", "loc", "ignored"

    def __init__(self, node_id, merged_ssa_val, edges, loc, ignored):
        # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
        self.node_id = node_id
        self.merged_ssa_val = merged_ssa_val
        self.edges = edges
        self.loc = loc
        self.ignored = ignored

    def merge_edge(self, other, edge):
        # type: (IGNode, IGEdge) -> None
        if self == other:
            raise ValueError("can't have self-loops")
        old_edge = self.edges.get(other, None)
        assert old_edge is other.edges.get(self, None), "inconsistent edges"
        edge = edge.merged(old_edge)
        if edge == IGEdge():
            self.edges.pop(other, None)
            other.edges.pop(self, None)
        else:
            self.edges[other] = edge
            other.edges[self] = edge

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.node_id == other.node_id
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.node_id)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        rs = repr_state
        del repr_state
        if rs is None:
            rs = IGNodeReprState()
        if short or self in rs.did_full_repr:
            return f"<IGNode #{self.node_id}>"
        rs.did_full_repr.add(self)
        edges = ", ".join(
            f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
        return (f"IGNode(#{self.node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc}, "
                f"ignored={self.ignored})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        for neighbor, edge in self.edges.items():
            if not edge.interferes:
                continue
            if neighbor.loc is not None and neighbor.loc.conflicts(loc):
                return True
        return False

    def copy_merge_preview(self, rhs_node):
        # type: (IGNode) -> MergedSSAVal
        try:
            copy_relation = self.edges[rhs_node].copy_relation
        except KeyError:
            raise ValueError("nodes aren't copy related")
        if copy_relation is None:
            raise ValueError("nodes aren't copy related")
        return self.merged_ssa_val.copy_merged(
            lhs_loc=self.loc,
            rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
            copy_relation=copy_relation)


class AllocationFailedError(Exception):
    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()


def allocate_registers(
        fn, # type: Fn
        debug_out=None, # type: TextIO | None
        dump_graph=None, # type: Callable[[str, str], None] | None
):
    # type: (...) -> dict[SSAVal, Loc]

    # inserts enough copies that no manual spilling is necessary; all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    for i, j in combinations(interference_graph.nodes.values(), 2):
        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
        i.merge_edge(j, IGEdge(copy_relation=copy_relation))

    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
                continue
            node_i = interference_graph.nodes[i]
            node_j = interference_graph.nodes[j]
            if node_j in node_i.edges:
                if node_i.edges[node_j].copy_relation is not None:
                    try:
                        _ = node_i.copy_merge_preview(node_j)
                        continue # doesn't interfere if copy merging succeeds
                    except BadMergedSSAVal:
                        pass
            node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)
    if dump_graph is not None:
        dump_graph("initial", interference_graph.dump_to_dot())

    node_stack = [] # type: list[IGNode]

    debug_node_scores = {} # type: dict[IGNode, int]
    debug_edge_scores = {} # type: dict[tuple[IGNode, IGNode], int]

    def find_best_node(has_copy_relation):
        # type: (bool) -> None | IGNode
        best_node = None # type: None | IGNode
        best_score = 0
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            node_has_copy_relation = False
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is not None:
                    node_has_copy_relation = True
                    break
            if node_has_copy_relation != has_copy_relation:
                continue
            score = interference_graph.local_colorability_score(node)
            debug_node_scores[node] = score
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
                if best_score > 0:
                    # it's locally colorable, no need to find a better one
                    break
        if debug_out is not None:
            print(f"find_best_node(has_copy_relation={has_copy_relation}):\n"
                  f"{best_node}", file=debug_out, flush=True)
        return best_node
    # copy-merging algorithm based on Iterated Register Coalescing, section 5:
    # https://dl.acm.org/doi/pdf/10.1145/229542.229546
    # Build step is above.
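    # The loop below repeats Simplify, Coalesce, and Freeze until none of them
    # makes progress; the Select step happens later, when Locs are assigned
    # while popping node_stack.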
    for step in count():
        debug_node_scores.clear()
        debug_edge_scores.clear()
        # Simplify:
        best_node = find_best_node(has_copy_relation=False)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(
                    f"step_{step}_simplify", interference_graph.dump_to_dot(
                        highlighted_nodes=[best_node],
                        node_scores=debug_node_scores,
                        edge_scores=debug_edge_scores))
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        # Coalesce (aka. do copy-merges):
        did_any_copy_merges = False
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is None:
                    continue
                try:
                    score = interference_graph.local_colorability_score(
                        node, merged_in_copy=neighbor)
                except BadMergedSSAVal:
                    continue
                if (neighbor, node) in debug_edge_scores:
                    debug_edge_scores[(neighbor, node)] = score
                else:
                    debug_edge_scores[(node, neighbor)] = score
                if score > 0: # merged node is locally colorable
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[node, neighbor],
                                node_scores=debug_node_scores,
                                edge_scores=debug_edge_scores))
                    if debug_out is not None:
                        print(f"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
                              file=debug_out, flush=True)
                    merged_node = interference_graph.copy_merge(node, neighbor)
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge_result",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[merged_node]))
                    if debug_out is not None:
                        print(f"merged_node:\n"
                              f"{merged_node}", file=debug_out, flush=True)
                    did_any_copy_merges = True
                    break
            if did_any_copy_merges:
                break
        if did_any_copy_merges:
            continue
        # Freeze:
        best_node = find_best_node(has_copy_relation=True)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(f"step_{step}_freeze",
                           interference_graph.dump_to_dot(
                               highlighted_nodes=[best_node],
                               node_scores=debug_node_scores,
                               edge_scores=debug_edge_scores))
            # no need to clear copy relations since best_node won't be
            # considered anymore now that it's ignored.
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        break

    if dump_graph is not None:
        dump_graph("final", interference_graph.dump_to_dot())
    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {} # type: dict[SSAVal, Loc]

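    # Select: pop nodes in reverse of the order they were removed from the
    # graph and assign each one a Loc.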
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # Locs to try allocating, ordered from most preferred to least
            # preferred
            locs = OSet()
            # prefer eliminating copies
            for neighbor, edge in node.edges.items():
                if neighbor.loc is None or edge.copy_relation is None:
                    continue
                try:
                    merged = node.copy_merge_preview(neighbor)
                except BadMergedSSAVal:
                    continue
                # get merged_loc if merged.loc_set has a single Loc
                merged_loc = merged.only_loc
                if merged_loc is None:
                    continue
                ssa_val = node.merged_ssa_val.first_ssa_val
                ssa_val_loc = merged_loc.get_subloc_at_offset(
                    subloc_ty=ssa_val.ty,
                    offset=merged.ssa_val_offsets[ssa_val])
                node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
                    superloc_ty=node.merged_ssa_val.ty,
                    offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
                assert node_loc in node.merged_ssa_val.loc_set, "logic error"
                locs.add(node_loc)
            # add node's allowed Locs as fallback
            for loc in node.loc_set:
                # TODO: add in order of preference
                locs.add(loc)
            # pick the first non-conflicting register in locs, since locs is
            # ordered from most preferred to least preferred register.
            for loc in locs:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval
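

# A minimal usage sketch (hypothetical names; building an actual `Fn` via the
# compiler_ir API is not shown here, so treat this as an illustration rather
# than part of the allocator):
#
#     import sys
#
#     def dump_graph(name, dot_text):  # write each stage's graph to a file
#         with open(f"ig_{name}.dot", "w") as f:
#             f.write(dot_text)
#
#     ssa_val_to_loc = allocate_registers(fn, debug_out=sys.stderr,
#                                         dump_graph=dump_graph)
#     for ssa_val, loc in ssa_val_to_loc.items():
#         print(ssa_val, "->", loc)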