be1522ecc6c38fa3787e7b4502854af1799bb17b
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from functools import reduce
9 from itertools import combinations
10 from typing import Iterable, Iterator, Mapping, TextIO
11
12 from cached_property import cached_property
13 from nmutil.plain_data import plain_data
14
15 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
16 LocSet, Op, ProgramRange,
17 SSAVal, SSAValSubReg, Ty)
18 from bigint_presentation_code.type_util import final
19 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
20
21
class BadMergedSSAVal(ValueError):
    """raised when a requested `MergedSSAVal` is invalid: e.g. it is empty,
    mixes `BaseTy`s, has mismatched offsets for the same `SSAVal`, has no
    valid `Loc`s left, or would force values that aren't known to be always
    equal (or two outputs of one instruction) to share registers."""
24
25
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                 "first_loc")

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """
        Arguments:
        fn_analysis: the `FnAnalysis` the `SSAVal`s come from.
        ssa_val_offsets: mapping from each `SSAVal` to its offset (see the
            class docstring), or a single `SSAVal`, treated as
            `{ssa_val: 0}`.

        Raises `BadMergedSSAVal` if the mapping is empty, the `SSAVal`s have
        mismatched `BaseTy`s, no `Loc` satisfies every def/use constraint,
        or merging would force independent values to illegally overlap.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        first_ssa_val = next(iter(self.ssa_vals), None)
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # intersect, over every SSAVal, the Locs for the whole merged value
        # that are compatible with that SSAVal's def and all of its uses.
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def allowed_by_uses(def_loc):
                # type: (Loc) -> bool
                """true if placing `ssa_val`'s def (before spread) at
                `def_loc` puts every use's Loc before spread inside that
                use's `use_loc_set_before_spread`."""
                for use in fn_analysis.uses[ssa_val]:
                    # calculate the start for the use's Loc before spread
                    # e.g. if the def's Loc before spread starts at r6
                    # and the def's reg_offset_in_unspread is 5
                    # and the use's reg_offset_in_unspread is 3
                    # then the use's Loc before spread starts at r8
                    # because 8 == 6 + 5 - 3
                    start = (def_loc.start + ssa_val.reg_offset_in_unspread
                             - use.reg_offset_in_unspread)
                    use_loc = Loc.try_make(
                        def_loc.kind, start=start,
                        reg_len=use.ty_before_spread.reg_len)
                    if (use_loc is None
                            or use_loc not in use.use_loc_set_before_spread):
                        return False
                return True

            def locs():
                # type: () -> Iterable[Loc]
                # `loc_set` here is the intersection computed so far; the
                # generator is fully consumed before `loc_set` is rebound.
                for def_loc in ssa_val.def_loc_set_before_spread:
                    if not allowed_by_uses(def_loc):
                        continue
                    start = def_loc.start - cur_offset + self.offset
                    merged_loc = Loc.try_make(def_loc.kind, start=start,
                                              reg_len=reg_len)
                    if merged_loc is not None and (
                            loc_set is None or merged_loc in loc_set):
                        yield merged_loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = next(iter(loc_set), None)
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        self.first_loc = first_loc
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet
        self.__mergable_check()

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals
        to illegally overlap. This is required to avoid copy merging merging
        things that can't be merged.
        spread arguments are one of the things that can force two values to
        illegally overlap.
        """
        ops = OSet()  # type: OSet[Op]
        for ssa_val in self.ssa_vals:
            ops.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                ops.add(use.op)
        sorted_ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
        # register offset -> the sub-register most recently written there,
        # simulated in program order
        vals = {}  # type: dict[int, SSAValSubReg]
        for op in sorted_ops:
            for inp in op.input_vals:
                try:
                    ssa_val_offset = self.ssa_val_offsets[inp]
                except KeyError:
                    continue
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            # each register offset may be written by at most one output of
            # the current op
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                try:
                    ssa_val_offset = self.ssa_val_offsets[out]
                except KeyError:
                    continue
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal("attempted to merge two outputs "
                                              "of the same instruction")
                    vals[reg_offset] = reg

    @cached_property
    def __hash(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """smallest offset in `ssa_val_offsets_before_spread`; the `Loc`s in
        `loc_set` start at this offset."""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """type covering every member `SSAVal` (before spread), measured
        from `self.offset`. Raises `BadMergedSSAVal` on `BaseTy` mismatch."""
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """each `SSAVal`'s offset minus its defining descriptor's
        `reg_offset_in_unspread`."""
        return FMap({
            ssa_val: offset - ssa_val.defining_descriptor.reg_offset_in_unspread
            for ssa_val, offset in self.ssa_val_offsets.items()
        })

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a copy with every offset shifted by `amount`."""
        shifted = {ssa_val: offset + amount
                   for ssa_val, offset in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=shifted)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a copy shifted so that `offset == 0`."""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """return a copy shifted so that the first `SSAVal` shared with
        `target` (a bare `SSAVal` counts as offset 0) gets `target`'s offset
        plus `additional_offset`.

        Raises `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of `self` and `others` (offsets are kept as-is).

        Raises `ValueError` on `fn_analysis` mismatch and `BadMergedSSAVal`
        if an `SSAVal` appears with two different offsets.
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """smallest `ProgramRange` covering every member's live range."""
        live_ranges = [self.fn_analysis.live_ranges[ssa_val]
                       for ssa_val in self.ssa_vals]
        return ProgramRange(start=min(r.start for r in live_ranges),
                            stop=max(r.stop for r in live_ranges))

    def __repr__(self):
        # type: () -> str
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """union of `fn_analysis.copy_related_ssa_vals` over every member."""
        sets = OSet()  # type: OSet[OFSet[SSAVal]]
        # avoid merging the same sets multiple times
        for ssa_val in self.ssa_vals:
            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
        return OFSet(v for s in sets for v in s)
265
266
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """read-only mapping from each `SSAVal` to the `MergedSSAVal` that
    contains it.

    The backing dict is shared with the companion `MergedSSAValToIGNodeMap`
    (exposed as `ig_node_map`), which is what populates it.
    """

    def __init__(self):
        # type: (...) -> None
        backing = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__map = backing
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=backing)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # NOTE: lists the `MergedSSAVal`s (the IGNode map's keys), not the
        # individual SSAVal -> MergedSSAVal entries.
        body = ",\n".join(map(repr, self.__ig_node_map))
        return f"SSAValToMergedSSAValMap({{{body}}})"
296
297
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """mapping from each `MergedSSAVal` to its interference-graph `IGNode`.

    Also keeps the shared `SSAVal -> MergedSSAVal` dict (the one backing
    `SSAValToMergedSSAValMap`) in sync as nodes are added or merged.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """return the node for `merged_ssa_val`, creating an edgeless,
        unallocated one if needed.

        Raises `ValueError` if any of its `SSAVal`s already belongs to a
        different `MergedSSAVal`; in that case nothing is left modified.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        # number of SSAVals registered so far; set to None on success so the
        # `finally` clause knows not to roll back
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
            self.__map[merged_ssa_val] = retval
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """replace every node holding one of `final_merged_ssa_val`'s
        `SSAVal`s with a single node for `final_merged_ssa_val`, combining
        their edges and (pre-allocated) `loc`s.

        Raises `ValueError` (before modifying anything) if a source node
        holds an `SSAVal` not in `final_merged_ssa_val`, or if source nodes
        have different non-`None` `loc`s. Raises `KeyError` if one of the
        `SSAVal`s has no node yet.
        """
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = {}  # type: dict[IGNode, IGEdge]
        loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if loc is None:
                loc = source_node.loc
            elif source_node.loc is not None and loc != source_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {loc} != {source_node.loc}")
            for n, edge in source_node.edges.items():
                if n in edges:
                    edge = edge.merged(edges[n])
                edges[n] = edge
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        for n in source_nodes:
            # drop edges between the nodes being merged
            edges.pop(n, None)
        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                        loc=loc)
        for node in edges:
            # `node` neighbors at least one, but not necessarily every, source
            # node, so only take the edges that actually exist -- popping a
            # missing one would raise KeyError with the graph half-updated.
            edge = reduce(IGEdge.merged,
                          (node.edges.pop(n) for n in source_nodes
                           if n in node.edges))
            node.edges[retval] = edge
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
395
396
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """interference graph whose nodes are `MergedSSAVal`s.

    `merged_ssa_val_map` maps each `SSAVal` to its `MergedSSAVal`;
    `nodes` maps each `MergedSSAVal` to its `IGNode`.
    """
    __slots__ = ("fn_analysis", "merged_ssa_val_map", "nodes")

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for initial in merged_ssa_vals:
            self.nodes.add_node(initial)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """merge the nodes containing `ssa_val1` and `ssa_val2` so that
        `ssa_val2` sits `additional_offset` registers after `ssa_val1`."""
        group1 = self.merged_ssa_val_map[ssa_val1]
        group2 = self.merged_ssa_val_map[ssa_val2]
        aligned1 = group1.with_offset_to_match(ssa_val1)
        aligned2 = group2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        return self.nodes.merge_into_one_node(aligned1.merged(aligned2))

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """build a graph with one node per output, merging only what must
        share registers: spread pieces with their unspread start, and tied
        outputs with their tied input."""
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                if use.unspread_start != use:
                    graph.merge(use.unspread_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for output in op.outputs:
                graph.nodes.add_node(MergedSSAVal(fn_analysis, output))
                if output.unspread_start != output:
                    graph.merge(output.unspread_start, output,
                                additional_offset=output.reg_offset_in_unspread)
                tied = output.tied_input
                if tied is not None:
                    graph.merge(tied.ssa_val, output)
        return graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        nodes_repr = self.nodes.__repr__(state)
        return f"InterferenceGraph(nodes={nodes_repr}, <...>)"
443
444
@plain_data(repr=False)
class IGNodeReprState:
    """state threaded through the `__repr__` methods of `IGNode` and the
    objects containing them: each node gets a stable `#id`, and its full
    form is printed at most once (later mentions are abbreviated)."""
    __slots__ = ("node_ids", "did_full_repr")

    def __init__(self):
        super().__init__()
        # IGNode -> small integer id shown in its repr
        self.node_ids = {}  # type: dict[IGNode, int]
        # IGNodes whose full (non-abbreviated) repr was already produced
        self.did_full_repr = OSet()  # type: OSet[IGNode]
453
454
@plain_data(frozen=True, unsafe_hash=True)
@final
class IGEdge:
    """ interference graph edge """
    __slots__ = ("is_copy_related",)

    def __init__(self, is_copy_related):
        # type: (bool) -> None
        self.is_copy_related = is_copy_related

    def merged(self, other):
        # type: (IGEdge) -> IGEdge
        """combine two parallel edges into one; the result is copy-related
        if either input is."""
        return IGEdge(
            is_copy_related=self.is_copy_related | other.is_copy_related)
469
470
@final
class IGNode:
    """ interference graph node

    `edges` maps each neighboring `IGNode` to the `IGEdge` between them;
    `loc` is the allocated `Loc`, or `None` if not yet allocated.
    Equality and hashing are by `merged_ssa_val` only.
    """
    __slots__ = ("merged_ssa_val", "edges", "loc")

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = edges
        self.loc = loc

    def add_edge(self, other, edge):
        # type: (IGNode, IGEdge) -> None
        """record `edge` symmetrically on both endpoints."""
        self.edges[other] = edge
        other.edges[self] = edge

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        node_id = state.node_ids.get(self, None)
        if node_id is None:
            node_id = len(state.node_ids)
            state.node_ids[self] = node_id
        if short or self in state.did_full_repr:
            return f"<IGNode #{node_id}>"
        state.did_full_repr.add(self)
        edges = ", ".join(
            f"{neighbor.__repr__(state, True)}: {edge}"
            for neighbor, edge in self.edges.items())
        return (f"IGNode(#{node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """true if any already-allocated neighbor's `loc` conflicts with
        `loc`."""
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
527
528
class AllocationFailedError(Exception):
    """raised by `allocate_registers` when an `IGNode` can't be given a `Loc`
    that doesn't conflict with its neighbors.

    Attributes:
    node: the `IGNode` that failed.
    interference_graph: the `InterferenceGraph` being allocated.
    """

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        # use the runtime class so subclasses report their own name
        return (f"{type(self).__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
548
549
def allocate_registers(fn, debug_out=None):
    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
    """allocate a `Loc` for every `SSAVal` in `fn`.

    Steps:
    1. `fn.pre_ra_insert_copies()` (presumably modifies `fn` in place).
    2. build a minimally-merged `InterferenceGraph`, then add an edge between
       every pair of simultaneously-live `MergedSSAVal`s whose `LocSet`s can
       conflict.
    3. repeatedly remove the most locally-colorable remaining node, pushing
       it onto a stack.
    4. pop nodes off the stack, giving each unallocated node the first `Loc`
       in its `LocSet` that doesn't conflict with an allocated neighbor.

    Arguments:
    fn: the function to allocate registers for.
    debug_out: if not `None`, intermediate state is printed to it.

    Returns a dict mapping each `SSAVal` to its allocated `Loc`.

    Raises `AllocationFailedError` if a pre-allocated node conflicts with a
    neighbor, or no non-conflicting `Loc` is available for a node.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            node_i = interference_graph.nodes[i]
            node_j = interference_graph.nodes[j]
            if node_j in node_i.edges:
                # edge already added at an earlier program point, skip the
                # repeated conflict check
                continue
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                # can't use:
                # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
                #     j.copy_related_ssa_vals)
                # since it is too coarse

                # TODO: fill in is_copy_related afterwards
                # using fn_analysis.copies
                node_i.add_edge(node_j, edge=IGEdge(is_copy_related=False))
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)

    nodes_remaining = OSet(interference_graph.nodes.values())

    local_colorability_score_cache = {}  # type: dict[IGNode, int]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = local_colorability_score_cache.get(node, None)
        if retval is not None:
            return retval
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        local_colorability_score_cache[node] = retval
        return retval

    # TODO: implement copy-merging

    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)
        # removing best_node changes its neighbors' scores
        local_colorability_score_cache.pop(best_node, None)
        for neighbor in best_node.edges:
            local_colorability_score_cache.pop(neighbor, None)

    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in node.loc_set, since
            # it is treated as ordered from most preferred to least
            # preferred Loc.
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        merged_ssa_val = node.merged_ssa_val
        for ssa_val, offset in merged_ssa_val.ssa_val_offsets.items():
            # per MergedSSAVal's definition, `node.loc` starts at
            # `merged_ssa_val.offset`, which need not be zero (e.g. after
            # `with_offset_to_match`), so rebase to an offset within node.loc
            retval[ssa_val] = node.loc.get_subloc_at_offset(
                ssa_val.ty, offset - merged_ssa_val.offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval