# working on code some more
# [bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from dataclasses import dataclass
9 from functools import lru_cache, reduce
10 from itertools import combinations, count
11 from typing import Any, Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
12
13 from cached_property import cached_property
14 from nmutil.plain_data import plain_data, replace
15
16 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
17 LocSet, Op, ProgramRange,
18 SSAVal, SSAValSubReg, Ty)
19 from bigint_presentation_code.type_util import final
20 from bigint_presentation_code.util import FMap, Interned, OFSet, OSet
21
22
class BadMergedSSAVal(ValueError):
    """raised when a `MergedSSAVal` can't be constructed or merged -- e.g.
    when it would be empty, would mix base types, would have no valid `Loc`s
    left, or would force two independent values to illegally overlap."""
    pass
25
26
27 _CopyRelation = Tuple[SSAValSubReg, SSAValSubReg]
28
29
@dataclass(frozen=True, repr=False, eq=False)
@final
class MergedSSAVal(Interned):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    # all fields are assigned via object.__setattr__ in __init__ because the
    # dataclass is frozen
    fn_analysis: FnAnalysis
    ssa_val_offsets: "FMap[SSAVal, int]"
    first_ssa_val: SSAVal
    loc_set: LocSet
    first_loc: Loc

    def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
        """create a MergedSSAVal.

        raises `BadMergedSSAVal` if the result would be empty, would mix
        base types, would have no allowed `Loc`s left, or would force two
        independent values to illegally overlap.
        """
        object.__setattr__(self, "fn_analysis", fn_analysis)
        # a bare SSAVal is shorthand for a single value at offset 0
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        object.__setattr__(self, "ssa_val_offsets", FMap(ssa_val_offsets))
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        object.__setattr__(self, "first_ssa_val", first_ssa_val)
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        if loc_set is not None and loc_set.ty != self.ty:
            raise ValueError(
                f"invalid loc_set, type doesn't match: "
                f"{loc_set.ty} != {self.ty}")
        # progressively intersect the allowed Locs: each iteration rebuilds
        # loc_set, keeping only Locs also valid for the current ssa_val
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                # NOTE: this closure reads the enclosing loop variables, but
                # it is called immediately below within the same iteration,
                # so late binding is not a problem here
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # translate from this ssa_val's Loc to the Loc of the
                    # whole MergedSSAVal
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = None
        for loc in loc_set:
            first_loc = loc
            break
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        object.__setattr__(self, "first_loc", first_loc)
        assert loc_set.ty == self.ty, "logic error somewhere"
        object.__setattr__(self, "loc_set", loc_set)
        self.__mergable_check()

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals
        to illegally overlap. This is required to avoid copy merging merging
        things that can't be merged.
        spread arguments are one of the things that can force two values to
        illegally overlap.
        """
        # collect every op that defines or uses a member SSAVal, then walk
        # them in program order
        ops = OSet()  # type: Iterable[Op]
        for ssa_val in self.ssa_vals:
            ops.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                ops.add(use.op)
        ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
        # tracks, for each register offset in the merged value, the
        # sub-register that last wrote it
        vals = {}  # type: dict[int, SSAValSubReg]
        for op in ops:
            for inp in op.input_vals:
                try:
                    ssa_val_offset = self.ssa_val_offsets[inp]
                except KeyError:
                    continue  # input isn't part of this MergedSSAVal
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    # NOTE(review): assumes the defining op appears earlier
                    # in `ops`, so `reg_offset` is already in `vals` -- a
                    # KeyError here would indicate a logic error -- confirm
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            # each register offset may be written by at most one output of
            # any single instruction
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                try:
                    ssa_val_offset = self.ssa_val_offsets[out]
                except KeyError:
                    continue  # output isn't part of this MergedSSAVal
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal("attempted to merge two outputs "
                                              "of the same instruction")
                    vals[reg_offset] = reg

    def __hash__(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))

    def __eq__(self, other):
        # type: (MergedSSAVal | Any) -> bool
        if isinstance(other, MergedSSAVal):
            return self.fn_analysis == other.fn_analysis and \
                self.ssa_val_offsets == other.ssa_val_offsets and \
                self.loc_set == other.loc_set
        return NotImplemented

    @property
    def only_loc(self):
        # type: () -> Loc | None
        """the single allowed Loc, or None when more than one is possible"""
        return self.loc_set.only_loc

    @cached_property
    def offset(self):
        # type: () -> int
        # the minimum pre-spread offset of any member SSAVal
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the overall type; reg_len spans all members. raises
        BadMergedSSAVal on mismatched base types."""
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """a copy of self with every offset shifted by `amount`"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """a copy of self shifted so that `self.offset == 0`"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """a copy of self shifted so a shared SSAVal's offset matches its
        offset in `target` (plus `additional_offset`). raises ValueError if
        self and `target` share no SSAVal."""
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def with_loc(self, loc):
        # type: (Loc) -> MergedSSAVal
        """a copy of self pinned to the single Loc `loc`"""
        if loc not in self.loc_set:
            raise ValueError(
                f"Loc is not allowed -- not a member of `self.loc_set`: "
                f"{loc} not in {self.loc_set}")
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=self.ssa_val_offsets,
                            loc_set=LocSet([loc]))

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """the union of self and `others`; shared SSAVals must agree on
        their offsets"""
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """the ProgramRange covering the live ranges of all member SSAVals"""
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """all SSAVals copy-related to any member SSAVal"""
        sets = OSet()  # type: OSet[OFSet[SSAVal]]
        # avoid merging the same sets multiple times
        for ssa_val in self.ssa_vals:
            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
        return OFSet(v for s in sets for v in s)

    def get_copy_relation(self, other):
        # type: (MergedSSAVal) -> None | _CopyRelation
        """the first (lhs, rhs) sub-register pair -- lhs from self, rhs
        from `other` -- that share a copy source, or None if none exists"""
        for lhs_ssa_val in self.ssa_vals:
            for rhs_ssa_val in other.ssa_vals:
                for lhs in lhs_ssa_val.ssa_val_sub_regs:
                    for rhs in rhs_ssa_val.ssa_val_sub_regs:
                        lhs_src = self.fn_analysis.copies.get(lhs, lhs)
                        rhs_src = self.fn_analysis.copies.get(rhs, rhs)
                        if lhs_src == rhs_src:
                            return lhs, rhs
        return None

    @lru_cache(maxsize=None, typed=True)
    def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
        # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
        """merge self with `rhs`, aligning offsets so both halves of
        `copy_relation` land on the same register; optionally pinning each
        side to `lhs_loc`/`rhs_loc` first.

        NOTE(review): `lru_cache` on a method keeps `self` and its args
        alive for the cache's lifetime -- presumably acceptable since
        MergedSSAVal is Interned, but worth confirming.
        """
        # make cr_lhs be the half that belongs to self
        cr_lhs, cr_rhs = copy_relation
        if cr_lhs.ssa_val not in self.ssa_vals:
            cr_lhs, cr_rhs = cr_rhs, cr_lhs
        lhs_merged = self.with_offset_to_match(
            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
        if lhs_loc is not None:
            lhs_merged = lhs_merged.with_loc(lhs_loc)
        rhs_merged = rhs.with_offset_to_match(
            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
        if rhs_loc is not None:
            rhs_merged = rhs_merged.with_loc(rhs_loc)
        return lhs_merged.merged(rhs_merged).normalized()
321
322
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """read-only mapping from each `SSAVal` to the `MergedSSAVal` that
    contains it; mutated only through the paired `ig_node_map`."""

    def __init__(self):
        # type: (...) -> None
        # this dict is shared with the MergedSSAValToIGNodeMap created
        # below, which keeps both maps in sync
        self.__table = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__table)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__table[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        yield from self.__table

    def __len__(self):
        # type: () -> int
        return len(self.__table)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the companion map from MergedSSAVals to IGNodes"""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        body = ",\n".join(map(repr, self.__ig_node_map))
        return f"SSAValToMergedSSAValMap({{{body}}})"
352
353
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """mapping from `MergedSSAVal`s to their interference-graph nodes.

    shares its `SSAVal` -> `MergedSSAVal` dict with the owning
    `SSAValToMergedSSAValMap`, keeping both maps in sync.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        # shared with SSAValToMergedSSAValMap -- mutated in place here
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]
        # each IGNode gets a unique monotonically increasing id, which is
        # what IGNode bases identity/hashing on
        self.__next_node_id = 0

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """return the existing IGNode for `merged_ssa_val`, or create one.

        raises ValueError if any member SSAVal already belongs to a
        different MergedSSAVal.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        # counts entries added to the shared map so they can be rolled back
        # on error; set to None once everything succeeded
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(
                node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
                edges={}, loc=merged_ssa_val.only_loc, ignored=False)
            self.__map[merged_ssa_val] = retval
            self.__next_node_id += 1
            added = None  # success -- disable the rollback below
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """replace all IGNodes whose SSAVals appear in
        `final_merged_ssa_val` with one merged IGNode, rewiring neighbors'
        edges and re-checking which edges still interfere."""
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = {}  # type: dict[IGNode, IGEdge]
        # validity checks come first -- nothing is modified until they pass
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            if source_node.ignored:
                raise ValueError(f"can't merge ignored nodes: {source_node}")
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if source_node.loc != source_node.merged_ssa_val.only_loc:
                raise ValueError(
                    f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
                    f"{source_node.loc} != "
                    f"{source_node.merged_ssa_val.only_loc}")
            for n, edge in source_node.edges.items():
                if n in edges:
                    edge = edge.merged(edges[n])
                edges[n] = edge
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        # drop edges between the source nodes themselves
        for n in source_nodes:
            edges.pop(n, None)
        loc = final_merged_ssa_val.only_loc
        for n, edge in edges.items():
            if edge.copy_relation is None or not edge.interferes:
                continue
            try:
                # if merging works, then the edge can't interfere
                _ = final_merged_ssa_val.copy_merged(
                    lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc,
                    copy_relation=edge.copy_relation)
            except BadMergedSSAVal:
                continue
            edges[n] = replace(edge, interferes=False)
        retval = IGNode(
            node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
            edges=edges, loc=loc, ignored=False)
        self.__next_node_id += 1
        empty_e = IGEdge()
        # rewire each neighbor: fold its edges to all source nodes into a
        # single edge to the new node (or remove it if the fold is empty)
        for node in edges:
            edge = reduce(IGEdge.merged,
                          (node.edges.pop(n, empty_e) for n in source_nodes))
            if edge == empty_e:
                node.edges.pop(retval, None)
            else:
                node.edges[retval] = edge
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
474
475
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """interference graph for the register allocator: nodes are
    `MergedSSAVal`s wrapped in `IGNode`s, edges record interference and/or
    copy-relatedness between nodes."""
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        # nodes shares its backing dicts with merged_ssa_val_map
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for i in merged_ssa_vals:
            self.nodes.add_node(i)

    def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
        """the MergedSSAVal that merging the nodes containing `ssa_val1`
        and `ssa_val2` would produce, without modifying the graph"""
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        merged = merged1.with_offset_to_match(ssa_val1)
        return merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)).normalized()

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """merge the nodes containing `ssa_val1` and `ssa_val2`"""
        return self.nodes.merge_into_one_node(self.merge_preview(
            ssa_val1=ssa_val1, ssa_val2=ssa_val2,
            additional_offset=additional_offset))

    def copy_merge(self, node1, node2):
        # type: (IGNode, IGNode) -> IGNode
        """merge two copy-related nodes into one"""
        return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))

    def local_colorability_score(self, node, merged_in_copy=None):
        # type: (IGNode, None | IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable.

        if `merged_in_copy` is not `None`, then the node used is what would be
        the result of `self.copy_merge(node, merged_in_copy)`.
        """
        if node.ignored:
            raise ValueError(
                "can't get local_colorability_score of ignored node")
        loc_set = node.loc_set
        edges = node.edges
        if merged_in_copy is not None:
            if merged_in_copy.ignored:
                raise ValueError(
                    "can't get local_colorability_score of ignored node")
            # use the loc_set/edges the merged node would have
            loc_set = node.copy_merge_preview(merged_in_copy).loc_set
            edges = edges.copy()
            for neighbor, edge in merged_in_copy.edges.items():
                edges[neighbor] = edge.merged(edges.get(neighbor))
        # available Locs minus the worst-case number consumed by
        # interfering, not-yet-ignored neighbors
        retval = len(loc_set)
        for neighbor, edge in edges.items():
            if neighbor.ignored or not edge.interferes:
                continue
            if neighbor == merged_in_copy or neighbor == node:
                continue
            retval -= loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """build an InterferenceGraph with only the merges required for
        correctness: spread values merged with their unspread start, and
        outputs merged with their tied inputs"""
        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={s}, <...>)"

    def dump_to_dot(
        self, highlighted_nodes=(),  # type: Container[IGNode]
        node_scores=None,  # type: None | dict[IGNode, int]
        edge_scores=None,  # type: None | dict[tuple[IGNode, IGNode], int]
    ):
        # type: (...) -> str
        """render the graph in graphviz `dot` format: interference edges
        bold/darkred, copy-relation edges dashed/blue, highlighted nodes
        bold/green, ignored nodes dotted."""

        def quote(s):
            # type: (object) -> str
            # escape for use inside a double-quoted dot string
            s = str(s)
            s = s.replace('\\', r'\\')
            s = s.replace('"', r'\"')
            s = s.replace('\n', r'\n')
            return f'"{s}"'

        if node_scores is None:
            node_scores = {}
        if edge_scores is None:
            edge_scores = {}

        edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
        node_ids = {}  # type: dict[IGNode, str]
        for node in self.nodes.values():
            node_ids[node] = quote(node.node_id)
            for neighbor, edge in node.edges.items():
                edge_key = (node, neighbor)
                # ensure we only insert each edge once by checking for
                # both directions
                if edge_key not in edges and edge_key[::-1] not in edges:
                    edges[edge_key] = edge
        lines = [
            "graph {",
            " graph [pack = true]",
        ]
        for node, node_id in node_ids.items():
            label_lines = []  # type: list[str]
            score = node_scores.get(node)
            if score is not None:
                label_lines.append(f"score={score}")
            for k, v in node.merged_ssa_val.ssa_val_offsets.items():
                label_lines.append(f"{k}: {v}")
            label = quote("\n".join(label_lines))
            style = "dotted" if node.ignored else "solid"
            color = "black"
            if node in highlighted_nodes:
                style = "bold"
                color = "green"
            style = quote(style)
            color = quote(color)
            lines.append(f" {node_id} ["
                         f"label = {label}, "
                         f"style = {style}, "
                         f"color = {color}]")

        def append_edge(node1, node2, label, color, style):
            # type: (IGNode, IGNode, str, str, str) -> None
            label = quote(label)
            color = quote(color)
            style = quote(style)
            lines.append(f" {node_ids[node1]} -- {node_ids[node2]} ["
                         f"label = {label}, "
                         f"color = {color}, "
                         f"style = {style}, "
                         f"decorate = true]")
        for (node1, node2), edge in edges.items():
            score = edge_scores.get((node1, node2))
            if score is None:
                score = edge_scores.get((node2, node1))
            label_prefix = ""
            if score is not None:
                label_prefix = f"score={score}\n"
            # an edge can be both interfering and copy-related; draw one
            # dot edge per relation
            if edge.interferes:
                append_edge(node1, node2, label=label_prefix + "interferes",
                            color="darkred", style="bold")
            if edge.copy_relation is not None:
                append_edge(node1, node2, label=label_prefix + "copy related",
                            color="blue", style="dashed")
        lines.append("}")
        return "\n".join(lines)
642
643
@plain_data(repr=False)
class IGNodeReprState:
    """shared state threaded through `IGNode.__repr__` calls: remembers
    which nodes were already fully printed so each node's full repr is
    emitted at most once (avoiding unbounded recursion through edges)"""
    __slots__ = "did_full_repr",

    def __init__(self):
        super().__init__()
        self.did_full_repr = OSet()  # type: OSet[IGNode]
651
652
@plain_data(frozen=True, unsafe_hash=True)
@final
class IGEdge:
    """ interference graph edge """
    __slots__ = "interferes", "copy_relation"

    def __init__(self, interferes=False, copy_relation=None):
        # type: (bool, None | _CopyRelation) -> None
        self.interferes = interferes
        self.copy_relation = copy_relation

    def merged(self, other):
        # type: (IGEdge | None) -> IGEdge
        """combine this edge with `other`: the result interferes if either
        side does, and keeps self's copy_relation, falling back to
        other's. `None` merges as a no-op."""
        if other is None:
            return self
        if self.copy_relation is not None:
            rel = self.copy_relation
        else:
            rel = other.copy_relation
        return IGEdge(interferes=self.interferes | other.interferes,
                      copy_relation=rel)
673
674
@final
class IGNode:
    """ interference graph node """
    __slots__ = "node_id", "merged_ssa_val", "edges", "loc", "ignored"

    def __init__(self, node_id, merged_ssa_val, edges, loc, ignored):
        # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
        self.node_id = node_id  # unique id -- the basis of __eq__/__hash__
        self.merged_ssa_val = merged_ssa_val
        self.edges = edges  # symmetric adjacency map: neighbor -> IGEdge
        self.loc = loc  # the assigned Loc, or None if not yet allocated
        self.ignored = ignored  # True while removed from graph consideration

    def merge_edge(self, other, edge):
        # type: (IGNode, IGEdge) -> None
        """merge `edge` into the edge between self and `other`, keeping both
        nodes' adjacency maps in sync; an edge that merges to empty (no
        interference, no copy relation) is removed entirely"""
        if self == other:
            raise ValueError("can't have self-loops")
        old_edge = self.edges.get(other, None)
        assert old_edge is other.edges.get(self, None), "inconsistent edges"
        edge = edge.merged(old_edge)
        if edge == IGEdge():
            self.edges.pop(other, None)
            other.edges.pop(self, None)
        else:
            self.edges[other] = edge
            other.edges[self] = edge

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.node_id == other.node_id
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.node_id)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        rs = repr_state
        del repr_state
        if rs is None:
            rs = IGNodeReprState()
        # emit the short form for nodes already fully printed, so nodes
        # that reference each other through edges don't recurse forever
        if short or self in rs.did_full_repr:
            return f"<IGNode #{self.node_id}>"
        rs.did_full_repr.add(self)
        edges = ", ".join(
            f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
        return (f"IGNode(#{self.node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc}, "
                f"ignored={self.ignored})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the set of Locs this node may be allocated to"""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """True if assigning `loc` would conflict with the already-assigned
        Loc of any interfering neighbor"""
        for neighbor, edge in self.edges.items():
            if not edge.interferes:
                continue
            if neighbor.loc is not None and neighbor.loc.conflicts(loc):
                return True
        return False

    def copy_merge_preview(self, rhs_node):
        # type: (IGNode) -> MergedSSAVal
        """the MergedSSAVal that copy-merging self with `rhs_node` would
        produce, without modifying the graph.

        raises ValueError if the nodes aren't copy related.
        """
        try:
            copy_relation = self.edges[rhs_node].copy_relation
        except KeyError:
            raise ValueError("nodes aren't copy related")
        if copy_relation is None:
            raise ValueError("nodes aren't copy related")
        return self.merged_ssa_val.copy_merged(
            lhs_loc=self.loc,
            rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
            copy_relation=copy_relation)
755
756
class AllocationFailedError(Exception):
    """raised when the register allocator can't assign a Loc to an IGNode"""

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        rs = IGNodeReprState() if repr_state is None else repr_state
        # short repr for the node; the graph repr will reuse `rs` and thus
        # avoid repeating full node reprs
        node_repr = self.node.__repr__(rs, True)
        graph_repr = self.interference_graph.__repr__(rs)
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={node_repr}, "
                f"interference_graph={graph_repr})")

    def __str__(self):
        # type: () -> str
        return repr(self)
776
777
def allocate_registers(
    fn,  # type: Fn
    debug_out=None,  # type: TextIO | None
    dump_graph=None,  # type: Callable[[str, str], None] | None
):
    # type: (...) -> dict[SSAVal, Loc]
    """allocate a Loc for every SSAVal in `fn` by graph coloring with
    iterated register coalescing.

    `debug_out`: optional text stream for progress/debug output.
    `dump_graph`: optional callback called with (name, dot_source) to dump
    the interference graph at interesting points.

    raises AllocationFailedError if no conflict-free assignment is found.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    # record copy relations between every pair of nodes
    for i, j in combinations(interference_graph.nodes.values(), 2):
        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
        i.merge_edge(j, IGEdge(copy_relation=copy_relation))

    # add interference edges: values live at the same program point whose
    # Loc sets can conflict interfere -- unless a copy-merge would succeed
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
                continue
            node_i = interference_graph.nodes[i]
            node_j = interference_graph.nodes[j]
            if node_j in node_i.edges:
                if node_i.edges[node_j].copy_relation is not None:
                    try:
                        _ = node_i.copy_merge_preview(node_j)
                        continue  # doesn't interfere if copy merging succeeds
                    except BadMergedSSAVal:
                        pass
            node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)
    if dump_graph is not None:
        dump_graph("initial", interference_graph.dump_to_dot())

    # nodes in removal order; Locs are assigned in reverse order below
    node_stack = []  # type: list[IGNode]

    debug_node_scores = {}  # type: dict[IGNode, int]
    debug_edge_scores = {}  # type: dict[tuple[IGNode, IGNode], int]

    def find_best_node(has_copy_relation):
        # type: (bool) -> None | IGNode
        """pick the not-ignored node with the best local colorability score
        among nodes whose copy-relatedness (to any not-ignored neighbor)
        matches `has_copy_relation`"""
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            node_has_copy_relation = False
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is not None:
                    node_has_copy_relation = True
                    break
            if node_has_copy_relation != has_copy_relation:
                continue
            score = interference_graph.local_colorability_score(node)
            debug_node_scores[node] = score
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
                if best_score > 0:
                    # it's locally colorable, no need to find a better one
                    break
        if debug_out is not None:
            print(f"find_best_node(has_copy_relation={has_copy_relation}):\n"
                  f"{best_node}", file=debug_out, flush=True)
        return best_node
    # copy-merging algorithm based on Iterated Register Coalescing, section 5:
    # https://dl.acm.org/doi/pdf/10.1145/229542.229546
    # Build step is above.
    for step in count():
        debug_node_scores.clear()
        debug_edge_scores.clear()
        # Simplify:
        best_node = find_best_node(has_copy_relation=False)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(
                    f"step_{step}_simplify", interference_graph.dump_to_dot(
                        highlighted_nodes=[best_node],
                        node_scores=debug_node_scores,
                        edge_scores=debug_edge_scores))
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        # Coalesce (aka. do copy-merges):
        did_any_copy_merges = False
        for node in interference_graph.nodes.values():
            if node.ignored:
                continue
            for neighbor, edge in node.edges.items():
                if neighbor.ignored:
                    continue
                if edge.copy_relation is None:
                    continue
                try:
                    score = interference_graph.local_colorability_score(
                        node, merged_in_copy=neighbor)
                except BadMergedSSAVal:
                    continue
                if (neighbor, node) in debug_edge_scores:
                    debug_edge_scores[(neighbor, node)] = score
                else:
                    debug_edge_scores[(node, neighbor)] = score
                if score > 0:  # merged node is locally colorable
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[node, neighbor],
                                node_scores=debug_node_scores,
                                edge_scores=debug_edge_scores))
                    if debug_out is not None:
                        print(f"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
                              file=debug_out, flush=True)
                    merged_node = interference_graph.copy_merge(node, neighbor)
                    if dump_graph is not None:
                        dump_graph(
                            f"step_{step}_copy_merge_result",
                            interference_graph.dump_to_dot(
                                highlighted_nodes=[merged_node]))
                    if debug_out is not None:
                        print(f"merged_node:\n"
                              f"{merged_node}", file=debug_out, flush=True)
                    did_any_copy_merges = True
                    break
            if did_any_copy_merges:
                break
        if did_any_copy_merges:
            continue
        # Freeze:
        best_node = find_best_node(has_copy_relation=True)
        if best_node is not None:
            if dump_graph is not None:
                dump_graph(f"step_{step}_freeze",
                           interference_graph.dump_to_dot(
                               highlighted_nodes=[best_node],
                               node_scores=debug_node_scores,
                               edge_scores=debug_edge_scores))
            # no need to clear copy relations since best_node won't be
            # considered since it's now ignored.
            node_stack.append(best_node)
            best_node.ignored = True
            continue
        break

    if dump_graph is not None:
        dump_graph("final", interference_graph.dump_to_dot())
    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    # Select: pop nodes off the stack and assign Locs in reverse removal
    # order
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # Locs to try allocating, ordered from most preferred to least
            # preferred
            locs = OSet()
            # prefer eliminating copies
            for neighbor, edge in node.edges.items():
                if neighbor.loc is None or edge.copy_relation is None:
                    continue
                try:
                    merged = node.copy_merge_preview(neighbor)
                except BadMergedSSAVal:
                    continue
                # get merged_loc if merged.loc_set has a single Loc
                merged_loc = merged.only_loc
                if merged_loc is None:
                    continue
                # translate merged_loc back into the Loc this node alone
                # would occupy, going through one of its SSAVals
                ssa_val = node.merged_ssa_val.first_ssa_val
                ssa_val_loc = merged_loc.get_subloc_at_offset(
                    subloc_ty=ssa_val.ty,
                    offset=merged.ssa_val_offsets[ssa_val])
                node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
                    superloc_ty=node.merged_ssa_val.ty,
                    offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
                assert node_loc in node.merged_ssa_val.loc_set, "logic error"
                locs.add(node_loc)
            # add node's allowed Locs as fallback
            for loc in node.loc_set:
                # TODO: add in order of preference
                locs.add(loc)
            # pick the first non-conflicting register in locs, since locs is
            # ordered from most preferred to least preferred register.
            for loc in locs:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        # record the final Loc for every SSAVal in the node
        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval