"""
Register Allocator for Toom-Cook algorithm generator for SVP64

this uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
from functools import reduce
from itertools import combinations
from typing import Callable, Iterable, Iterator, Mapping, TextIO

from cached_property import cached_property
from nmutil.plain_data import plain_data

from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
                                                  LocSet, Op, ProgramRange,
                                                  SSAVal, SSAValSubReg, Ty)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
class BadMergedSSAVal(ValueError):
    """raised when a set of `SSAVal`s can't legally be merged into a single
    `MergedSSAVal` (mismatched types, conflicting offsets, no valid `Loc`s
    left, etc.)"""
    # NOTE(review): the class body was lost in the mangled source; the name
    # is only ever raised/caught, so an empty ValueError subclass restores
    # the interface.
@plain_data(frozen=True, repr=False)
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)

    Example:

    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})

    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`

    NOTE(review): parts of this docstring and of `__init__` were lost in the
    mangled source and have been reconstructed; inferred lines are marked.
    """
    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                 "first_loc")  # NOTE(review): tail of __slots__ inferred

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            # a bare SSAVal is shorthand for that value at offset 0
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(
            ssa_val_offsets)  # type: FMap[SSAVal, int]
        # grab an arbitrary member as the representative value
        first_ssa_val = None  # type: None | SSAVal
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break  # NOTE(review): inferred -- line lost in mangled source
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break  # NOTE(review): inferred -- line lost
                    if disallowed_by_use:
                        continue  # NOTE(review): inferred -- lines lost
                    # translate from this member's Loc to the Loc of the
                    # whole merged value
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc  # NOTE(review): inferred -- line lost
            # intersect candidate Locs across all members
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        # NOTE(review): the following 4 lines were lost; reconstructed as
        # "grab an arbitrary element to prove loc_set isn't empty"
        first_loc = None  # type: None | Loc
        for first_loc in loc_set:
            break
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        self.first_loc = first_loc
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet
        self.__mergable_check()
115 def __mergable_check(self
):
117 """ checks that nothing is forcing two independent SSAVals
118 to illegally overlap. This is required to avoid copy merging merging
119 things that can't be merged.
120 spread arguments are one of the things that can force two values to
123 ops
= OSet() # type: Iterable[Op]
124 for ssa_val
in self
.ssa_vals
:
126 for use
in self
.fn_analysis
.uses
[ssa_val
]:
128 ops
= sorted(ops
, key
=self
.fn_analysis
.op_indexes
.__getitem
__)
129 vals
= {} # type: dict[int, SSAValSubReg]
131 for inp
in op
.input_vals
:
133 ssa_val_offset
= self
.ssa_val_offsets
[inp
]
136 for orig_reg
in inp
.ssa_val_sub_regs
:
137 reg_offset
= ssa_val_offset
+ orig_reg
.reg_idx
138 replaced_reg
= vals
[reg_offset
]
139 if not self
.fn_analysis
.is_always_equal(
140 orig_reg
, replaced_reg
):
141 raise BadMergedSSAVal(
142 f
"attempting to merge values that aren't known to "
143 f
"be always equal: {orig_reg} != {replaced_reg}")
144 output_offsets
= dict.fromkeys(range(
145 self
.offset
, self
.offset
+ self
.ty
.reg_len
))
146 for out
in op
.outputs
:
148 ssa_val_offset
= self
.ssa_val_offsets
[out
]
151 for reg
in out
.ssa_val_sub_regs
:
152 reg_offset
= ssa_val_offset
+ reg
.reg_idx
154 del output_offsets
[reg_offset
]
156 raise BadMergedSSAVal("attempted to merge two outputs "
157 "of the same instruction")
158 vals
[reg_offset
] = reg
163 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
172 return min(self
.ssa_val_offsets_before_spread
.values())
177 return self
.first_ssa_val
.base_ty
181 # type: () -> OFSet[SSAVal]
182 return OFSet(self
.ssa_val_offsets
.keys())
188 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
189 cur_ty
= ssa_val
.ty_before_spread
190 if self
.base_ty
!= cur_ty
.base_ty
:
191 raise BadMergedSSAVal(
192 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
193 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
194 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
197 def ssa_val_offsets_before_spread(self
):
198 # type: () -> FMap[SSAVal, int]
199 retval
= {} # type: dict[SSAVal, int]
200 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
202 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
205 def offset_by(self
, amount
):
206 # type: (int) -> MergedSSAVal
207 v
= {k
: v
+ amount
for k
, v
in self
.ssa_val_offsets
.items()}
208 return MergedSSAVal(fn_analysis
=self
.fn_analysis
, ssa_val_offsets
=v
)
210 def normalized(self
):
211 # type: () -> MergedSSAVal
212 return self
.offset_by(-self
.offset
)
214 def with_offset_to_match(self
, target
, additional_offset
=0):
215 # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
216 if isinstance(target
, MergedSSAVal
):
217 ssa_val_offsets
= target
.ssa_val_offsets
219 ssa_val_offsets
= {target
: 0}
220 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
221 if ssa_val
in ssa_val_offsets
:
222 return self
.offset_by(
223 ssa_val_offsets
[ssa_val
] + additional_offset
- offset
)
224 raise ValueError("can't change offset to match unrelated MergedSSAVal")
226 def merged(self
, *others
):
227 # type: (*MergedSSAVal) -> MergedSSAVal
228 retval
= dict(self
.ssa_val_offsets
)
230 if other
.fn_analysis
!= self
.fn_analysis
:
231 raise ValueError("fn_analysis mismatch")
232 for ssa_val
, offset
in other
.ssa_val_offsets
.items():
233 if ssa_val
in retval
and retval
[ssa_val
] != offset
:
234 raise BadMergedSSAVal(f
"offset mismatch for {ssa_val}: "
235 f
"{retval[ssa_val]} != {offset}")
236 retval
[ssa_val
] = offset
237 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
238 ssa_val_offsets
=retval
)
241 def live_interval(self
):
242 # type: () -> ProgramRange
243 live_range
= self
.fn_analysis
.live_ranges
[self
.first_ssa_val
]
244 start
= live_range
.start
245 stop
= live_range
.stop
246 for ssa_val
in self
.ssa_vals
:
247 live_range
= self
.fn_analysis
.live_ranges
[ssa_val
]
248 start
= min(start
, live_range
.start
)
249 stop
= max(stop
, live_range
.stop
)
250 return ProgramRange(start
=start
, stop
=stop
)
253 return (f
"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
254 f
"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
255 f
"live_interval={self.live_interval})")
258 def copy_related_ssa_vals(self
):
259 # type: () -> OFSet[SSAVal]
260 sets
= OSet() # type: OSet[OFSet[SSAVal]]
261 # avoid merging the same sets multiple times
262 for ssa_val
in self
.ssa_vals
:
263 sets
.add(self
.fn_analysis
.copy_related_ssa_vals
[ssa_val
])
264 return OFSet(v
for s
in sets
for v
in s
)
266 def is_copy_related(self
, other
):
267 # type: (MergedSSAVal) -> bool
268 for lhs_ssa_val
in self
.ssa_vals
:
269 for rhs_ssa_val
in other
.ssa_vals
:
270 for lhs
in lhs_ssa_val
.ssa_val_sub_regs
:
271 for rhs
in rhs_ssa_val
.ssa_val_sub_regs
:
272 lhs
= self
.fn_analysis
.copies
.get(lhs
, lhs
)
273 rhs
= self
.fn_analysis
.copies
.get(rhs
, rhs
)
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """read-only `Mapping` from each `SSAVal` to the `MergedSSAVal` that
    contains it; mutated only through the paired `ig_node_map`, which
    shares the underlying dict"""

    def __init__(self):
        # type: (...) -> None
        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        # the companion map sharing state with this one (attribute-style
        # access at the call sites shows this is a property)
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        s = ",\n".join(repr(v) for v in self.__ig_node_map)
        return f"SSAValToMergedSSAValMap({{{s}}})"
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """`Mapping` from `MergedSSAVal`s to their interference-graph nodes;
    constructed by `SSAValToMergedSSAValMap`, with which it shares the
    SSAVal->MergedSSAVal dict"""

    def __init__(
            self, *,
            _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)
332 def add_node(self
, merged_ssa_val
):
333 # type: (MergedSSAVal) -> IGNode
334 node
= self
.__map
.get(merged_ssa_val
, None)
337 added
= 0 # type: int | None
339 for ssa_val
in merged_ssa_val
.ssa_vals
:
340 if ssa_val
in self
.__merged
_ssa
_val
_map
:
342 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
343 f
"{merged_ssa_val} and "
344 f
"{self.__merged_ssa_val_map[ssa_val]}")
345 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
347 retval
= IGNode(merged_ssa_val
=merged_ssa_val
, edges
={}, loc
=None)
348 self
.__map
[merged_ssa_val
] = retval
352 if added
is not None:
353 # remove partially added stuff
354 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
357 del self
.__merged
_ssa
_val
_map
[ssa_val
]
359 def merge_into_one_node(self
, final_merged_ssa_val
):
360 # type: (MergedSSAVal) -> IGNode
361 source_nodes
= OSet() # type: OSet[IGNode]
362 edges
= {} # type: dict[IGNode, IGEdge]
363 loc
= None # type: Loc | None
364 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
365 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
366 source_node
= self
.__map
[merged_ssa_val
]
367 source_nodes
.add(source_node
)
368 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
370 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
371 f
"but not in merged IGNode's merged_ssa_val: "
372 f
"source_node={source_node} "
373 f
"final_merged_ssa_val={final_merged_ssa_val}")
375 loc
= source_node
.loc
376 elif source_node
.loc
is not None and loc
!= source_node
.loc
:
377 raise ValueError(f
"can't merge IGNodes with mismatched `loc` "
378 f
"values: {loc} != {source_node.loc}")
379 for n
, edge
in source_node
.edges
.items():
381 edge
= edge
.merged(edges
[n
])
383 if len(source_nodes
) == 1:
384 return source_nodes
.pop() # merging a single node is a no-op
385 # we're finished checking validity, now we can modify stuff
386 for n
in source_nodes
:
388 retval
= IGNode(merged_ssa_val
=final_merged_ssa_val
, edges
=edges
,
391 edge
= reduce(IGEdge
.merged
,
392 (node
.edges
.pop(n
) for n
in source_nodes
))
393 node
.edges
[retval
] = edge
394 for node
in source_nodes
:
395 del self
.__map
[node
.merged_ssa_val
]
396 self
.__map
[final_merged_ssa_val
] = retval
397 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
398 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
401 def __repr__(self
, repr_state
=None):
402 # type: (None | IGNodeReprState) -> str
403 if repr_state
is None:
404 repr_state
= IGNodeReprState()
405 s
= ",\n".join(v
.__repr
__(repr_state
) for v
in self
.__map
.values())
406 return f
"MergedSSAValToIGNodeMap({{{s}}})"
@plain_data(frozen=True, repr=False)
class InterferenceGraph:
    """the interference graph for one function: `nodes` maps each
    `MergedSSAVal` to its `IGNode`, and `merged_ssa_val_map` shares state
    with it"""
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for i in merged_ssa_vals:
            self.nodes.add_node(i)
422 def merge(self
, ssa_val1
, ssa_val2
, additional_offset
=0):
423 # type: (SSAVal, SSAVal, int) -> IGNode
424 merged1
= self
.merged_ssa_val_map
[ssa_val1
]
425 merged2
= self
.merged_ssa_val_map
[ssa_val2
]
426 merged
= merged1
.with_offset_to_match(ssa_val1
)
427 merged
= merged
.merged(merged2
.with_offset_to_match(
428 ssa_val2
, additional_offset
=additional_offset
))
429 return self
.nodes
.merge_into_one_node(merged
)
432 def minimally_merged(fn_analysis
):
433 # type: (FnAnalysis) -> InterferenceGraph
434 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
435 for op
in fn_analysis
.fn
.ops
:
436 for inp
in op
.input_uses
:
437 if inp
.unspread_start
!= inp
:
438 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
439 additional_offset
=inp
.reg_offset_in_unspread
)
440 for out
in op
.outputs
:
441 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
442 if out
.unspread_start
!= out
:
443 retval
.merge(out
.unspread_start
, out
,
444 additional_offset
=out
.reg_offset_in_unspread
)
445 if out
.tied_input
is not None:
446 retval
.merge(out
.tied_input
.ssa_val
, out
)
449 def __repr__(self
, repr_state
=None):
450 # type: (None | IGNodeReprState) -> str
451 if repr_state
is None:
452 repr_state
= IGNodeReprState()
453 s
= self
.nodes
.__repr
__(repr_state
)
454 return f
"InterferenceGraph(nodes={s}, <...>)"
456 def dump_to_dot(self
):
460 # type: (object) -> str
462 s
= s
.replace('\\', r
'\\')
463 s
= s
.replace('"', r
'\"')
464 s
= s
.replace('\n', r
'\n')
467 edges
= {} # type: dict[tuple[IGNode, IGNode], IGEdge]
468 node_ids
= {} # type: dict[IGNode, str]
469 for node
in self
.nodes
.values():
470 node_ids
[node
] = quote(len(node_ids
))
471 for neighbor
, edge
in node
.edges
.items():
472 edge_key
= (node
, neighbor
)
473 # ensure we only insert each edge once by checking for
475 if edge_key
not in edges
and edge_key
[::-1] not in edges
:
476 edges
[edge_key
] = edge
478 for node
, node_id
in node_ids
.items():
479 label_lines
= [] # type: list[str]
480 for k
, v
in node
.merged_ssa_val
.ssa_val_offsets
.items():
481 label_lines
.append(f
"{k}: {v}")
482 label
= quote("\n".join(label_lines
))
483 lines
.append(f
" {node_id} [label = {label}]")
484 for (node1
, node2
), edge
in edges
.items():
485 label
= quote(repr(edge
))
486 lines
.append(f
" {node_ids[node1]} -- {node_ids[node2]} "
487 f
"[label = {label}]")
489 return "\n".join(lines
)
@plain_data(repr=False)
class IGNodeReprState:
    """shared state for `IGNode.__repr__`: gives each node a stable short
    id and records which nodes were already fully printed, so mutually
    referencing `edges` don't recurse forever"""
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        # type: (...) -> None
        self.node_ids = {}  # type: dict[IGNode, int]
        self.did_full_repr = OSet()  # type: OSet[IGNode]
@plain_data(frozen=True, unsafe_hash=True)
class IGEdge:
    """ interference graph edge """
    # NOTE(review): the `class IGEdge:` line was lost in the mangled
    # source; name reconstructed from call sites
    __slots__ = "is_copy_related",

    def __init__(self, is_copy_related):
        # type: (bool) -> None
        self.is_copy_related = is_copy_related

    def merged(self, other):
        # type: (IGEdge) -> IGEdge
        # combining parallel edges: copy-related if either edge is
        is_copy_related = self.is_copy_related | other.is_copy_related
        return IGEdge(is_copy_related=is_copy_related)
520 """ interference graph node """
521 __slots__
= "merged_ssa_val", "edges", "loc"
523 def __init__(self
, merged_ssa_val
, edges
, loc
):
524 # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
525 self
.merged_ssa_val
= merged_ssa_val
529 def add_edge(self
, other
, edge
):
530 # type: (IGNode, IGEdge) -> None
531 self
.edges
[other
] = edge
532 other
.edges
[self
] = edge
534 def __eq__(self
, other
):
535 # type: (object) -> bool
536 if isinstance(other
, IGNode
):
537 return self
.merged_ssa_val
== other
.merged_ssa_val
538 return NotImplemented
542 return hash(self
.merged_ssa_val
)
544 def __repr__(self
, repr_state
=None, short
=False):
545 # type: (None | IGNodeReprState, bool) -> str
549 rs
= IGNodeReprState()
550 node_id
= rs
.node_ids
.get(self
, None)
552 rs
.node_ids
[self
] = node_id
= len(rs
.node_ids
)
553 if short
or self
in rs
.did_full_repr
:
554 return f
"<IGNode #{node_id}>"
555 rs
.did_full_repr
.add(self
)
557 f
"{k.__repr__(rs, True)}: {v}" for k
, v
in self
.edges
.items())
558 return (f
"IGNode(#{node_id}, "
559 f
"merged_ssa_val={self.merged_ssa_val}, "
560 f
"edges={{{edges}}}, "
566 return self
.merged_ssa_val
.loc_set
568 def loc_conflicts_with_neighbors(self
, loc
):
569 # type: (Loc) -> bool
570 for neighbor
in self
.edges
:
571 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
class AllocationFailedError(Exception):
    """raised when `allocate_registers` can't find a legal `Loc` for a
    node; carries the offending node and the whole interference graph for
    diagnostics"""

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node  # NOTE(review): inferred -- line lost; __repr__
        # below reads self.node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        # NOTE(review): def lines lost -- reconstructed around the visible
        # `return self.__repr__()`
        return self.__repr__()
def allocate_registers(
    fn,  # type: Fn  # NOTE(review): parameter line lost -- reconstructed
    debug_out=None,  # type: TextIO | None
    dump_graph=None,  # type: Callable[[str, str], None] | None
):
    # type: (...) -> dict[SSAVal, Loc]
    """run graph-coloring register allocation on `fn`, returning the Loc
    assigned to every SSAVal; raises AllocationFailedError on failure"""
    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    fn.pre_ra_insert_copies()
    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)
    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)
    # add an edge between every pair of merged values that are live at the
    # same program point and could actually conflict
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                # don't use:
                # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
                #     j.copy_related_ssa_vals)
                # since it is too coarse
                interference_graph.nodes[i].add_edge(
                    interference_graph.nodes[j],
                    edge=IGEdge(is_copy_related=i.is_copy_related(j)))
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)
    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)
    if dump_graph is not None:
        dump_graph("initial", interference_graph.dump_to_dot())

    nodes_remaining = OSet(interference_graph.nodes.values())

    local_colorability_score_cache = {}  # type: dict[IGNode, int]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            # NOTE(review): original error message lost -- reconstructed
            raise ValueError("node not in nodes_remaining")
        retval = local_colorability_score_cache.get(node, None)
        if retval is not None:
            return retval  # NOTE(review): inferred -- line lost
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        local_colorability_score_cache[node] = retval
        return retval  # NOTE(review): inferred -- line lost

    # TODO: implement copy-merging

    # simplify phase: repeatedly remove the most-colorable node; the stack
    # order is the reverse allocation order
    node_stack = []  # type: list[IGNode]
    while True:  # NOTE(review): loop structure inferred -- lines lost
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break
        if best_node is None:
            break  # NOTE(review): inferred -- line lost
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)
        local_colorability_score_cache.pop(best_node, None)
        for neighbor in best_node.edges:
            local_colorability_score_cache.pop(neighbor, None)

    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    # select phase: pop nodes and assign the first legal Loc
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:  # NOTE(review): inferred -- line lost
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc  # NOTE(review): inferred -- lines lost
                    break
            else:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    # NOTE(review): the visible source ends just before the function's
    # final line; the declared return type requires returning the map
    return retval