1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
from dataclasses import dataclass
from functools import lru_cache, reduce
from itertools import combinations, count
from typing import Any, Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple

from cached_property import cached_property
from nmutil.plain_data import plain_data, replace

from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
                                                  LocSet, Op, ProgramRange,
                                                  SSAVal, SSAValSubReg, Ty)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import FMap, Interned, OFSet, OSet


class BadMergedSSAVal(ValueError):
    pass


_CopyRelation = Tuple[SSAValSubReg, SSAValSubReg]


@dataclass(frozen=True, repr=False, eq=False)
@final
class MergedSSAVal(Interned):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    fn_analysis: FnAnalysis
    ssa_val_offsets: "FMap[SSAVal, int]"
    first_ssa_val: SSAVal
    loc_set: LocSet
    first_loc: Loc

    def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
        object.__setattr__(self, "fn_analysis", fn_analysis)
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        object.__setattr__(self, "ssa_val_offsets", FMap(ssa_val_offsets))
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        object.__setattr__(self, "first_ssa_val", first_ssa_val)
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        if loc_set is not None and loc_set.ty != self.ty:
            raise ValueError(
                f"invalid loc_set, type doesn't match: "
                f"{loc_set.ty} != {self.ty}")
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = None
        for loc in loc_set:
            first_loc = loc
            break
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        object.__setattr__(self, "first_loc", first_loc)
        assert loc_set.ty == self.ty, "logic error somewhere"
        object.__setattr__(self, "loc_set", loc_set)
        self.__mergable_check()

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals
        to illegally overlap. This is required so that copy-merging
        doesn't merge things that can't legally be merged.
        Spread arguments are one of the things that can force two values to
        illegally overlap.
        """
        ops = OSet()  # type: Iterable[Op]
        for ssa_val in self.ssa_vals:
            ops.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                ops.add(use.op)
        ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
        vals = {}  # type: dict[int, SSAValSubReg]
        for op in ops:
            for inp in op.input_vals:
                try:
                    ssa_val_offset = self.ssa_val_offsets[inp]
                except KeyError:
                    continue
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                try:
                    ssa_val_offset = self.ssa_val_offsets[out]
                except KeyError:
                    continue
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal("attempted to merge two outputs "
                                              "of the same instruction")
                    vals[reg_offset] = reg

    def __hash__(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))

    def __eq__(self, other):
        # type: (MergedSSAVal | Any) -> bool
        if isinstance(other, MergedSSAVal):
            return self.fn_analysis == other.fn_analysis and \
                self.ssa_val_offsets == other.ssa_val_offsets and \
                self.loc_set == other.loc_set
        return NotImplemented

    @property
    def only_loc(self):
        # type: () -> Loc | None
        return self.loc_set.only_loc

    @cached_property
    def offset(self):
        # type: () -> int
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
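        # shift all offsets so that the smallest one becomes 0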
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def with_loc(self, loc):
        # type: (Loc) -> MergedSSAVal
        if loc not in self.loc_set:
            raise ValueError(
                f"Loc is not allowed -- not a member of `self.loc_set`: "
                f"{loc} not in {self.loc_set}")
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=self.ssa_val_offsets,
                            loc_set=LocSet([loc]))

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        sets = OSet()  # type: OSet[OFSet[SSAVal]]
        # avoid merging the same sets multiple times
        for ssa_val in self.ssa_vals:
            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
        return OFSet(v for s in sets for v in s)

    def get_copy_relation(self, other):
        # type: (MergedSSAVal) -> None | _CopyRelation
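        # look for a pair of sub-registers, one from each MergedSSAVal,
        # that trace back to the same source value through the function's
        # copies; such a pair makes the two MergedSSAVals candidates for
        # copy-merging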
        for lhs_ssa_val in self.ssa_vals:
            for rhs_ssa_val in other.ssa_vals:
                for lhs in lhs_ssa_val.ssa_val_sub_regs:
                    for rhs in rhs_ssa_val.ssa_val_sub_regs:
                        lhs_src = self.fn_analysis.copies.get(lhs, lhs)
                        rhs_src = self.fn_analysis.copies.get(rhs, rhs)
                        if lhs_src == rhs_src:
                            return lhs, rhs
        return None

    @lru_cache(maxsize=None, typed=True)
    def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
        # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
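        # align both MergedSSAVals so the two copy-related sub-registers
        # land on the same offset, optionally pin them to the given Locs,
        # then merge them into a single normalized MergedSSAVal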
        cr_lhs, cr_rhs = copy_relation
        if cr_lhs.ssa_val not in self.ssa_vals:
            cr_lhs, cr_rhs = cr_rhs, cr_lhs
        lhs_merged = self.with_offset_to_match(
            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
        if lhs_loc is not None:
            lhs_merged = lhs_merged.with_loc(lhs_loc)
        rhs_merged = rhs.with_offset_to_match(
            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
        if rhs_loc is not None:
            rhs_merged = rhs_merged.with_loc(rhs_loc)
        return lhs_merged.merged(rhs_merged).normalized()


@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    def __init__(self):
        # type: (...) -> None
        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        s = ",\n".join(repr(v) for v in self.__ig_node_map)
        return f"SSAValToMergedSSAValMap({{{s}}})"


@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    def __init__(
            self, *,
            _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]
        self.__next_node_id = 0

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(
                node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
                edges={}, loc=merged_ssa_val.only_loc, ignored=False)
            self.__map[merged_ssa_val] = retval
            self.__next_node_id += 1
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = {}  # type: dict[IGNode, IGEdge]
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            if source_node.ignored:
                raise ValueError(f"can't merge ignored nodes: {source_node}")
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if source_node.loc != source_node.merged_ssa_val.only_loc:
                raise ValueError(
                    f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
                    f"{source_node.loc} != "
                    f"{source_node.merged_ssa_val.only_loc}")
            for n, edge in source_node.edges.items():
                if n in edges:
                    edge = edge.merged(edges[n])
                edges[n] = edge
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        for n in source_nodes:
            edges.pop(n, None)
        loc = final_merged_ssa_val.only_loc
        for n, edge in edges.items():
            if edge.copy_relation is None or not edge.interferes:
                continue
        retval = IGNode(
            node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
            edges=edges, loc=loc, ignored=False)
        self.__next_node_id += 1
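        # re-point every remaining neighbor at the new merged node: all of a
        # neighbor's edges to the old source nodes collapse into one edge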
        empty_e = IGEdge()
        for node in edges:
            edge = reduce(IGEdge.merged,
                          (node.edges.pop(n, empty_e) for n in source_nodes))
            if edge == empty_e:
                node.edges.pop(retval, None)
            else:
                node.edges[retval] = edge
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"


@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for i in merged_ssa_vals:
            self.nodes.add_node(i)

    def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        merged = merged1.with_offset_to_match(ssa_val1)
        return merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)).normalized()

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        return self.nodes.merge_into_one_node(self.merge_preview(
            ssa_val1=ssa_val1, ssa_val2=ssa_val2,
            additional_offset=additional_offset))

    def copy_merge(self, node1, node2):
        # type: (IGNode, IGNode) -> IGNode
        return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))

    def local_colorability_score(self, node, merged_in_copy=None):
        # type: (IGNode, None | IGNode) -> int
501 """ returns a positive integer if node is locally colorable, returns
502 zero or a negative integer if node isn't known to be locally
503 colorable, the more negative the value, the less colorable.
504
505 if `merged_in_copy` is not `None`, then the node used is what would be
506 the result of `self.copy_merge(node, merged_in_copy)`.
507 """
        if node.ignored:
            raise ValueError(
                "can't get local_colorability_score of ignored node")
        loc_set = node.loc_set
        edges = node.edges
        if merged_in_copy is not None:
            if merged_in_copy.ignored:
                raise ValueError(
                    "can't get local_colorability_score of ignored node")
            loc_set = node.copy_merge_preview(merged_in_copy).loc_set
            edges = edges.copy()
            for neighbor, edge in merged_in_copy.edges.items():
                edges[neighbor] = edge.merged(edges.get(neighbor))
        retval = len(loc_set)
        for neighbor, edge in edges.items():
            if neighbor.ignored or not edge.interferes:
                continue
            if neighbor == merged_in_copy or neighbor == node:
                continue
            retval -= loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
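        # start with one node per output SSAVal, then merge only what
        # correctness demands: spread values with the value they were
        # spread from, and outputs with their tied inputs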
        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={s}, <...>)"

    def dump_to_dot(
            self, highlighted_nodes=(),  # type: Container[IGNode]
            node_scores=None,  # type: None | dict[IGNode, int]
            edge_scores=None,  # type: None | dict[tuple[IGNode, IGNode], int]
    ):
        # type: (...) -> str

        def quote(s):
            # type: (object) -> str
            s = str(s)
            s = s.replace('\\', r'\\')
            s = s.replace('"', r'\"')
            s = s.replace('\n', r'\n')
            return f'"{s}"'

        if node_scores is None:
            node_scores = {}
        if edge_scores is None:
            edge_scores = {}

        edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
        node_ids = {}  # type: dict[IGNode, str]
        for node in self.nodes.values():
            node_ids[node] = quote(node.node_id)
            for neighbor, edge in node.edges.items():
                edge_key = (node, neighbor)
                # ensure we only insert each edge once by checking for
                # both directions
                if edge_key not in edges and edge_key[::-1] not in edges:
                    edges[edge_key] = edge
        lines = [
            "graph {",
            " graph [pack = true]",
        ]
        for node, node_id in node_ids.items():
            label_lines = []  # type: list[str]
            score = node_scores.get(node)
            if score is not None:
                label_lines.append(f"score={score}")
            for k, v in node.merged_ssa_val.ssa_val_offsets.items():
                label_lines.append(f"{k}: {v}")
            label = quote("\n".join(label_lines))
            style = "dotted" if node.ignored else "solid"
            color = "black"
            if node in highlighted_nodes:
                style = "bold"
                color = "green"
            style = quote(style)
            color = quote(color)
            lines.append(f" {node_id} ["
                         f"label = {label}, "
                         f"style = {style}, "
                         f"color = {color}]")

        def append_edge(node1, node2, label, color, style):
            # type: (IGNode, IGNode, str, str, str) -> None
            label = quote(label)
            color = quote(color)
            style = quote(style)
            lines.append(f" {node_ids[node1]} -- {node_ids[node2]} ["
                         f"label = {label}, "
                         f"color = {color}, "
                         f"style = {style}, "
                         f"decorate = true]")
        for (node1, node2), edge in edges.items():
            score = edge_scores.get((node1, node2))
            if score is None:
                score = edge_scores.get((node2, node1))
            label_prefix = ""
            if score is not None:
                label_prefix = f"score={score}\n"
            if edge.interferes:
                append_edge(node1, node2, label=label_prefix + "interferes",
                            color="darkred", style="bold")
            if edge.copy_relation is not None:
                append_edge(node1, node2, label=label_prefix + "copy related",
                            color="blue", style="dashed")
        lines.append("}")
        return "\n".join(lines)


@plain_data(repr=False)
class IGNodeReprState:
    __slots__ = "did_full_repr",

    def __init__(self):
        super().__init__()
        self.did_full_repr = OSet()  # type: OSet[IGNode]


@plain_data(frozen=True, unsafe_hash=True)
@final
class IGEdge:
    """ interference graph edge """
    __slots__ = "interferes", "copy_relation"

    def __init__(self, interferes=False, copy_relation=None):
        # type: (bool, None | _CopyRelation) -> None
        self.interferes = interferes
        self.copy_relation = copy_relation

    def merged(self, other):
        # type: (IGEdge | None) -> IGEdge
        if other is None:
            return self
        copy_relation = self.copy_relation
        if copy_relation is None:
            copy_relation = other.copy_relation
        interferes = self.interferes | other.interferes
        return IGEdge(interferes=interferes, copy_relation=copy_relation)


@final
class IGNode:
    """ interference graph node """
    __slots__ = "node_id", "merged_ssa_val", "edges", "loc", "ignored"

    def __init__(self, node_id, merged_ssa_val, edges, loc, ignored):
        # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
        self.node_id = node_id
        self.merged_ssa_val = merged_ssa_val
        self.edges = edges
        self.loc = loc
        self.ignored = ignored

    def merge_edge(self, other, edge):
        # type: (IGNode, IGEdge) -> None
        if self == other:
            raise ValueError("can't have self-loops")
        old_edge = self.edges.get(other, None)
        assert old_edge is other.edges.get(self, None), "inconsistent edges"
        edge = edge.merged(old_edge)
        if edge == IGEdge():
            self.edges.pop(other, None)
            other.edges.pop(self, None)
        else:
            self.edges[other] = edge
            other.edges[self] = edge

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.node_id == other.node_id
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.node_id)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        rs = repr_state
        del repr_state
        if rs is None:
            rs = IGNodeReprState()
        if short or self in rs.did_full_repr:
            return f"<IGNode #{self.node_id}>"
        rs.did_full_repr.add(self)
        edges = ", ".join(
            f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
        return (f"IGNode(#{self.node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc}, "
                f"ignored={self.ignored})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        for neighbor, edge in self.edges.items():
            if not edge.interferes:
                continue
            if neighbor.loc is not None and neighbor.loc.conflicts(loc):
                return True
        return False

    def copy_merge_preview(self, rhs_node):
        # type: (IGNode) -> MergedSSAVal
        try:
            copy_relation = self.edges[rhs_node].copy_relation
        except KeyError:
            raise ValueError("nodes aren't copy related")
        if copy_relation is None:
            raise ValueError("nodes aren't copy related")
        return self.merged_ssa_val.copy_merged(
            lhs_loc=self.loc,
            rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
            copy_relation=copy_relation)


class AllocationFailedError(Exception):
    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()


def allocate_registers(
        fn,  # type: Fn
        debug_out=None,  # type: TextIO | None
        dump_graph=None,  # type: Callable[[str, str], None] | None
):
    # type: (...) -> dict[SSAVal, Loc]

    # insert enough copies that no manual spilling is necessary; all
    # spilling is done by the register allocator naturally allocating
    # SSAVals to stack slots
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

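    # record copy relations between every pair of nodes whose values share
    # a copy source, so the coalescing phase can consider merging them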
    for i, j in combinations(interference_graph.nodes.values(), 2):
        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
        i.merge_edge(j, IGEdge(copy_relation=copy_relation))

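    # add interference edges: two merged values interfere when they are live
    # at the same program point, their loc_sets can actually conflict, and
    # they can't simply be copy-merged instead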
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
                continue
            node_i = interference_graph.nodes[i]
            node_j = interference_graph.nodes[j]
            if node_j in node_i.edges:
                if node_i.edges[node_j].copy_relation is not None:
                    try:
                        _ = node_i.copy_merge_preview(node_j)
                        continue  # doesn't interfere if copy merging succeeds
                    except BadMergedSSAVal:
                        pass
            node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)
    if dump_graph is not None:
        dump_graph("initial", interference_graph.dump_to_dot())

    node_stack = []  # type: list[IGNode]

    debug_node_scores = {}  # type: dict[IGNode, int]
    debug_edge_scores = {}  # type: dict[tuple[IGNode, IGNode], int]

    def find_best_node(has_copy_relation):
        # type: (bool) -> None | IGNode
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            node_has_copy_relation = False
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is not None:
                    node_has_copy_relation = True
                    break
            if node_has_copy_relation != has_copy_relation:
                continue
            score = interference_graph.local_colorability_score(node)
            debug_node_scores[node] = score
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
                if best_score > 0:
                    # it's locally colorable, no need to find a better one
                    break
        if debug_out is not None:
            print(f"find_best_node(has_copy_relation={has_copy_relation}):\n"
                  f"{best_node}", file=debug_out, flush=True)
        return best_node
    # copy-merging algorithm based on Iterated Register Coalescing, section 5:
    # https://dl.acm.org/doi/pdf/10.1145/229542.229546
    # Build step is above.
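    # Each iteration below applies the first step that makes progress:
    # * Simplify: push the best-scoring node that has no usable copy
    #   relation onto node_stack and ignore it from then on.
    # * Coalesce: copy-merge a pair of copy-related nodes whenever the
    #   merged node would still be locally colorable.
    # * Freeze: give up on copy-merging the best remaining copy-related
    #   node and push it onto node_stack like in Simplify.
    # When none of these applies, the loop ends and the Select step below
    # pops nodes off node_stack and assigns each one a Loc.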
    for step in count():
        debug_node_scores.clear()
        debug_edge_scores.clear()
        # Simplify:
        best_node = find_best_node(has_copy_relation=False)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(
                    f"step_{step}_simplify", interference_graph.dump_to_dot(
                        highlighted_nodes=[best_node],
                        node_scores=debug_node_scores,
                        edge_scores=debug_edge_scores))
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        # Coalesce (aka. do copy-merges):
        did_any_copy_merges = False
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is None:
                    continue
                try:
                    score = interference_graph.local_colorability_score(
                        node, merged_in_copy=neighbor)
                except BadMergedSSAVal:
                    continue
                if (neighbor, node) in debug_edge_scores:
                    debug_edge_scores[(neighbor, node)] = score
                else:
                    debug_edge_scores[(node, neighbor)] = score
                if score > 0:  # merged node is locally colorable
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[node, neighbor],
                                node_scores=debug_node_scores,
                                edge_scores=debug_edge_scores))
                    if debug_out is not None:
                        print(f"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
                              file=debug_out, flush=True)
                    merged_node = interference_graph.copy_merge(node, neighbor)
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge_result",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[merged_node]))
                    if debug_out is not None:
                        print(f"merged_node:\n"
                              f"{merged_node}", file=debug_out, flush=True)
                    did_any_copy_merges = True
                    break
            if did_any_copy_merges:
                break
        if did_any_copy_merges:
            continue
        # Freeze:
        best_node = find_best_node(has_copy_relation=True)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(f"step_{step}_freeze",
                           interference_graph.dump_to_dot(
                               highlighted_nodes=[best_node],
                               node_scores=debug_node_scores,
                               edge_scores=debug_edge_scores))
            # no need to clear copy relations: best_node won't be
            # considered again because it's now ignored.
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        break

    if dump_graph is not None:
        dump_graph("final", interference_graph.dump_to_dot())
    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # Locs to try allocating, ordered from most preferred to least
            # preferred
            locs = OSet()
            # prefer eliminating copies
            for neighbor, edge in node.edges.items():
                if neighbor.loc is None or edge.copy_relation is None:
                    continue
                try:
                    merged = node.copy_merge_preview(neighbor)
                except BadMergedSSAVal:
                    continue
                # get merged_loc if merged.loc_set has a single Loc
                merged_loc = merged.only_loc
                if merged_loc is None:
                    continue
                ssa_val = node.merged_ssa_val.first_ssa_val
                ssa_val_loc = merged_loc.get_subloc_at_offset(
                    subloc_ty=ssa_val.ty,
                    offset=merged.ssa_val_offsets[ssa_val])
                node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
                    superloc_ty=node.merged_ssa_val.ty,
                    offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
                assert node_loc in node.merged_ssa_val.loc_set, "logic error"
                locs.add(node_loc)
            # add node's allowed Locs as fallback
            for loc in node.loc_set:
                # TODO: add in order of preference
                locs.add(loc)
            # pick the first non-conflicting register in locs, since locs is
            # ordered from most preferred to least preferred register.
            for loc in locs:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval