2 Register Allocator for Toom-Cook algorithm generator for SVP64
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from functools
import reduce
9 from itertools
import combinations
, count
10 from typing
import Callable
, Container
, Iterable
, Iterator
, Mapping
, TextIO
, Tuple
12 from cached_property
import cached_property
13 from nmutil
.plain_data
import plain_data
, replace
15 from bigint_presentation_code
.compiler_ir
import (BaseTy
, Fn
, FnAnalysis
, Loc
,
16 LocSet
, Op
, ProgramRange
,
17 SSAVal
, SSAValSubReg
, Ty
)
18 from bigint_presentation_code
.type_util
import final
19 from bigint_presentation_code
.util
import FMap
, InternedMeta
, OFSet
, OSet
22 class BadMergedSSAVal(ValueError):
26 _CopyRelation
= Tuple
[SSAValSubReg
, SSAValSubReg
]
29 @plain_data(frozen
=True, repr=False)
31 class MergedSSAVal(metaclass
=InternedMeta
):
32 """a set of `SSAVal`s along with their offsets, all register allocated as
35 Definition of the term `offset` for this class:
37 Let `locs[x]` be the `Loc` that `x` is assigned to after register
38 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
39 for each `SSAVal` `ssa_val` in `msv` is defined as:
42 msv.ssa_val_offsets[ssa_val] = (msv.offset
43 + locs[ssa_val].start - locs[msv].start)
51 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
54 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
55 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
56 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
57 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
59 __slots__
= ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
62 def __init__(self
, fn_analysis
, ssa_val_offsets
, loc_set
=None):
63 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
64 self
.fn_analysis
= fn_analysis
65 if isinstance(ssa_val_offsets
, SSAVal
):
66 ssa_val_offsets
= {ssa_val_offsets
: 0}
67 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
69 for ssa_val
in self
.ssa_vals
:
70 first_ssa_val
= ssa_val
72 if first_ssa_val
is None:
73 raise BadMergedSSAVal("MergedSSAVal can't be empty")
74 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
75 # self.ty checks for mismatched base_ty
76 reg_len
= self
.ty
.reg_len
77 if loc_set
is not None and loc_set
.ty
!= self
.ty
:
79 f
"invalid loc_set, type doesn't match: "
80 f
"{loc_set.ty} != {self.ty}")
81 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
83 # type: () -> Iterable[Loc]
84 for loc
in ssa_val
.def_loc_set_before_spread
:
85 disallowed_by_use
= False
86 for use
in fn_analysis
.uses
[ssa_val
]:
87 # calculate the start for the use's Loc before spread
88 # e.g. if the def's Loc before spread starts at r6
89 # and the def's reg_offset_in_unspread is 5
90 # and the use's reg_offset_in_unspread is 3
91 # then the use's Loc before spread starts at r8
92 # because 8 == 6 + 5 - 3
93 start
= (loc
.start
+ ssa_val
.reg_offset_in_unspread
94 - use
.reg_offset_in_unspread
)
95 use_loc
= Loc
.try_make(
96 loc
.kind
, start
=start
,
97 reg_len
=use
.ty_before_spread
.reg_len
)
98 if (use_loc
is None or
99 use_loc
not in use
.use_loc_set_before_spread
):
100 disallowed_by_use
= True
102 if disallowed_by_use
:
104 start
= loc
.start
- cur_offset
+ self
.offset
105 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
106 if loc
is not None and (loc_set
is None or loc
in loc_set
):
108 loc_set
= LocSet(locs())
109 assert loc_set
is not None, "already checked that self isn't empty"
114 if first_loc
is None:
115 raise BadMergedSSAVal("there are no valid Locs left")
116 self
.first_loc
= first_loc
117 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
118 self
.loc_set
= loc_set
# type: LocSet
119 self
.__mergable
_check
()
121 def __mergable_check(self
):
123 """ checks that nothing is forcing two independent SSAVals
124 to illegally overlap. This is required to avoid copy merging merging
125 things that can't be merged.
126 spread arguments are one of the things that can force two values to
129 ops
= OSet() # type: Iterable[Op]
130 for ssa_val
in self
.ssa_vals
:
132 for use
in self
.fn_analysis
.uses
[ssa_val
]:
134 ops
= sorted(ops
, key
=self
.fn_analysis
.op_indexes
.__getitem
__)
135 vals
= {} # type: dict[int, SSAValSubReg]
137 for inp
in op
.input_vals
:
139 ssa_val_offset
= self
.ssa_val_offsets
[inp
]
142 for orig_reg
in inp
.ssa_val_sub_regs
:
143 reg_offset
= ssa_val_offset
+ orig_reg
.reg_idx
144 replaced_reg
= vals
[reg_offset
]
145 if not self
.fn_analysis
.is_always_equal(
146 orig_reg
, replaced_reg
):
147 raise BadMergedSSAVal(
148 f
"attempting to merge values that aren't known to "
149 f
"be always equal: {orig_reg} != {replaced_reg}")
150 output_offsets
= dict.fromkeys(range(
151 self
.offset
, self
.offset
+ self
.ty
.reg_len
))
152 for out
in op
.outputs
:
154 ssa_val_offset
= self
.ssa_val_offsets
[out
]
157 for reg
in out
.ssa_val_sub_regs
:
158 reg_offset
= ssa_val_offset
+ reg
.reg_idx
160 del output_offsets
[reg_offset
]
162 raise BadMergedSSAVal("attempted to merge two outputs "
163 "of the same instruction")
164 vals
[reg_offset
] = reg
169 return hash((self
.fn_analysis
, self
.ssa_val_offsets
, self
.loc_set
))
177 # type: () -> Loc | None
178 return self
.loc_set
.only_loc
183 return min(self
.ssa_val_offsets_before_spread
.values())
188 return self
.first_ssa_val
.base_ty
192 # type: () -> OFSet[SSAVal]
193 return OFSet(self
.ssa_val_offsets
.keys())
199 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
200 cur_ty
= ssa_val
.ty_before_spread
201 if self
.base_ty
!= cur_ty
.base_ty
:
202 raise BadMergedSSAVal(
203 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
204 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
205 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
208 def ssa_val_offsets_before_spread(self
):
209 # type: () -> FMap[SSAVal, int]
210 retval
= {} # type: dict[SSAVal, int]
211 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
213 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
216 def offset_by(self
, amount
):
217 # type: (int) -> MergedSSAVal
218 v
= {k
: v
+ amount
for k
, v
in self
.ssa_val_offsets
.items()}
219 return MergedSSAVal(fn_analysis
=self
.fn_analysis
, ssa_val_offsets
=v
)
221 def normalized(self
):
222 # type: () -> MergedSSAVal
223 return self
.offset_by(-self
.offset
)
225 def with_offset_to_match(self
, target
, additional_offset
=0):
226 # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
227 if isinstance(target
, MergedSSAVal
):
228 ssa_val_offsets
= target
.ssa_val_offsets
230 ssa_val_offsets
= {target
: 0}
231 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
232 if ssa_val
in ssa_val_offsets
:
233 return self
.offset_by(
234 ssa_val_offsets
[ssa_val
] + additional_offset
- offset
)
235 raise ValueError("can't change offset to match unrelated MergedSSAVal")
237 def with_loc(self
, loc
):
238 # type: (Loc) -> MergedSSAVal
239 if loc
not in self
.loc_set
:
241 f
"Loc is not allowed -- not a member of `self.loc_set`: "
242 f
"{loc} not in {self.loc_set}")
243 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
244 ssa_val_offsets
=self
.ssa_val_offsets
,
245 loc_set
=LocSet([loc
]))
247 def merged(self
, *others
):
248 # type: (*MergedSSAVal) -> MergedSSAVal
249 retval
= dict(self
.ssa_val_offsets
)
251 if other
.fn_analysis
!= self
.fn_analysis
:
252 raise ValueError("fn_analysis mismatch")
253 for ssa_val
, offset
in other
.ssa_val_offsets
.items():
254 if ssa_val
in retval
and retval
[ssa_val
] != offset
:
255 raise BadMergedSSAVal(f
"offset mismatch for {ssa_val}: "
256 f
"{retval[ssa_val]} != {offset}")
257 retval
[ssa_val
] = offset
258 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
259 ssa_val_offsets
=retval
)
262 def live_interval(self
):
263 # type: () -> ProgramRange
264 live_range
= self
.fn_analysis
.live_ranges
[self
.first_ssa_val
]
265 start
= live_range
.start
266 stop
= live_range
.stop
267 for ssa_val
in self
.ssa_vals
:
268 live_range
= self
.fn_analysis
.live_ranges
[ssa_val
]
269 start
= min(start
, live_range
.start
)
270 stop
= max(stop
, live_range
.stop
)
271 return ProgramRange(start
=start
, stop
=stop
)
274 return (f
"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
275 f
"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
276 f
"live_interval={self.live_interval})")
279 def copy_related_ssa_vals(self
):
280 # type: () -> OFSet[SSAVal]
281 sets
= OSet() # type: OSet[OFSet[SSAVal]]
282 # avoid merging the same sets multiple times
283 for ssa_val
in self
.ssa_vals
:
284 sets
.add(self
.fn_analysis
.copy_related_ssa_vals
[ssa_val
])
285 return OFSet(v
for s
in sets
for v
in s
)
287 def get_copy_relation(self
, other
):
288 # type: (MergedSSAVal) -> None | _CopyRelation
289 for lhs_ssa_val
in self
.ssa_vals
:
290 for rhs_ssa_val
in other
.ssa_vals
:
291 for lhs
in lhs_ssa_val
.ssa_val_sub_regs
:
292 for rhs
in rhs_ssa_val
.ssa_val_sub_regs
:
293 lhs_src
= self
.fn_analysis
.copies
.get(lhs
, lhs
)
294 rhs_src
= self
.fn_analysis
.copies
.get(rhs
, rhs
)
295 if lhs_src
== rhs_src
:
299 def copy_merged(self
, lhs_loc
, rhs
, rhs_loc
, copy_relation
):
300 # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
301 cr_lhs
, cr_rhs
= copy_relation
302 if cr_lhs
.ssa_val
not in self
.ssa_vals
:
303 cr_lhs
, cr_rhs
= cr_rhs
, cr_lhs
304 lhs_merged
= self
.with_offset_to_match(
305 cr_lhs
.ssa_val
, additional_offset
=-cr_lhs
.reg_idx
)
306 if lhs_loc
is not None:
307 lhs_merged
= lhs_merged
.with_loc(lhs_loc
)
308 rhs_merged
= rhs
.with_offset_to_match(
309 cr_rhs
.ssa_val
, additional_offset
=-cr_rhs
.reg_idx
)
310 if rhs_loc
is not None:
311 rhs_merged
= rhs_merged
.with_loc(rhs_loc
)
312 return lhs_merged
.merged(rhs_merged
).normalized()
316 class SSAValToMergedSSAValMap(Mapping
[SSAVal
, MergedSSAVal
]):
318 # type: (...) -> None
319 self
.__map
= {} # type: dict[SSAVal, MergedSSAVal]
320 self
.__ig
_node
_map
= MergedSSAValToIGNodeMap(
321 _private_merged_ssa_val_map
=self
.__map
)
323 def __getitem__(self
, __key
):
324 # type: (SSAVal) -> MergedSSAVal
325 return self
.__map
[__key
]
328 # type: () -> Iterator[SSAVal]
329 return iter(self
.__map
)
333 return len(self
.__map
)
336 def ig_node_map(self
):
337 # type: () -> MergedSSAValToIGNodeMap
338 return self
.__ig
_node
_map
342 s
= ",\n".join(repr(v
) for v
in self
.__ig
_node
_map
)
343 return f
"SSAValToMergedSSAValMap({{{s}}})"
347 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, "IGNode"]):
350 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
352 # type: (...) -> None
353 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
354 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
355 self
.__next
_node
_id
= 0
357 def __getitem__(self
, __key
):
358 # type: (MergedSSAVal) -> IGNode
359 return self
.__map
[__key
]
362 # type: () -> Iterator[MergedSSAVal]
363 return iter(self
.__map
)
367 return len(self
.__map
)
369 def add_node(self
, merged_ssa_val
):
370 # type: (MergedSSAVal) -> IGNode
371 node
= self
.__map
.get(merged_ssa_val
, None)
374 added
= 0 # type: int | None
376 for ssa_val
in merged_ssa_val
.ssa_vals
:
377 if ssa_val
in self
.__merged
_ssa
_val
_map
:
379 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
380 f
"{merged_ssa_val} and "
381 f
"{self.__merged_ssa_val_map[ssa_val]}")
382 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
385 node_id
=self
.__next
_node
_id
, merged_ssa_val
=merged_ssa_val
,
386 edges
={}, loc
=merged_ssa_val
.only_loc
, ignored
=False)
387 self
.__map
[merged_ssa_val
] = retval
388 self
.__next
_node
_id
+= 1
392 if added
is not None:
393 # remove partially added stuff
394 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
397 del self
.__merged
_ssa
_val
_map
[ssa_val
]
399 def merge_into_one_node(self
, final_merged_ssa_val
):
400 # type: (MergedSSAVal) -> IGNode
401 source_nodes
= OSet() # type: OSet[IGNode]
402 edges
= {} # type: dict[IGNode, IGEdge]
403 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
404 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
405 source_node
= self
.__map
[merged_ssa_val
]
406 if source_node
.ignored
:
407 raise ValueError(f
"can't merge ignored nodes: {source_node}")
408 source_nodes
.add(source_node
)
409 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
411 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
412 f
"but not in merged IGNode's merged_ssa_val: "
413 f
"source_node={source_node} "
414 f
"final_merged_ssa_val={final_merged_ssa_val}")
415 if source_node
.loc
!= source_node
.merged_ssa_val
.only_loc
:
417 f
"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
418 f
"{source_node.loc} != "
419 f
"{source_node.merged_ssa_val.only_loc}")
420 for n
, edge
in source_node
.edges
.items():
422 edge
= edge
.merged(edges
[n
])
424 if len(source_nodes
) == 1:
425 return source_nodes
.pop() # merging a single node is a no-op
426 # we're finished checking validity, now we can modify stuff
427 for n
in source_nodes
:
429 loc
= final_merged_ssa_val
.only_loc
430 for n
, edge
in edges
.items():
431 if edge
.copy_relation
is None or not edge
.interferes
:
434 # if merging works, then the edge can't interfere
435 _
= final_merged_ssa_val
.copy_merged(
436 lhs_loc
=loc
, rhs
=n
.merged_ssa_val
, rhs_loc
=n
.loc
,
437 copy_relation
=edge
.copy_relation
)
438 except BadMergedSSAVal
:
440 edges
[n
] = replace(edge
, interferes
=False)
442 node_id
=self
.__next
_node
_id
, merged_ssa_val
=final_merged_ssa_val
,
443 edges
=edges
, loc
=loc
, ignored
=False)
444 self
.__next
_node
_id
+= 1
447 edge
= reduce(IGEdge
.merged
,
448 (node
.edges
.pop(n
, empty_e
) for n
in source_nodes
))
450 node
.edges
.pop(retval
, None)
452 node
.edges
[retval
] = edge
453 for node
in source_nodes
:
454 del self
.__map
[node
.merged_ssa_val
]
455 self
.__map
[final_merged_ssa_val
] = retval
456 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
457 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
460 def __repr__(self
, repr_state
=None):
461 # type: (None | IGNodeReprState) -> str
462 if repr_state
is None:
463 repr_state
= IGNodeReprState()
464 s
= ",\n".join(v
.__repr
__(repr_state
) for v
in self
.__map
.values())
465 return f
"MergedSSAValToIGNodeMap({{{s}}})"
468 @plain_data(frozen
=True, repr=False)
470 class InterferenceGraph
:
471 __slots__
= "fn_analysis", "merged_ssa_val_map", "nodes"
473 def __init__(self
, fn_analysis
, merged_ssa_vals
):
474 # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
475 self
.fn_analysis
= fn_analysis
476 self
.merged_ssa_val_map
= SSAValToMergedSSAValMap()
477 self
.nodes
= self
.merged_ssa_val_map
.ig_node_map
478 for i
in merged_ssa_vals
:
479 self
.nodes
.add_node(i
)
481 def merge_preview(self
, ssa_val1
, ssa_val2
, additional_offset
=0):
482 # type: (SSAVal, SSAVal, int) -> MergedSSAVal
483 merged1
= self
.merged_ssa_val_map
[ssa_val1
]
484 merged2
= self
.merged_ssa_val_map
[ssa_val2
]
485 merged
= merged1
.with_offset_to_match(ssa_val1
)
486 return merged
.merged(merged2
.with_offset_to_match(
487 ssa_val2
, additional_offset
=additional_offset
)).normalized()
489 def merge(self
, ssa_val1
, ssa_val2
, additional_offset
=0):
490 # type: (SSAVal, SSAVal, int) -> IGNode
491 return self
.nodes
.merge_into_one_node(self
.merge_preview(
492 ssa_val1
=ssa_val1
, ssa_val2
=ssa_val2
,
493 additional_offset
=additional_offset
))
495 def copy_merge(self
, node1
, node2
):
496 # type: (IGNode, IGNode) -> IGNode
497 return self
.nodes
.merge_into_one_node(node1
.copy_merge_preview(node2
))
499 def local_colorability_score(self
, node
, merged_in_copy
=None):
500 # type: (IGNode, None | IGNode) -> int
501 """ returns a positive integer if node is locally colorable, returns
502 zero or a negative integer if node isn't known to be locally
503 colorable, the more negative the value, the less colorable.
505 if `merged_in_copy` is not `None`, then the node used is what would be
506 the result of `self.copy_merge(node, merged_in_copy)`.
510 "can't get local_colorability_score of ignored node")
511 loc_set
= node
.loc_set
513 if merged_in_copy
is not None:
514 if merged_in_copy
.ignored
:
516 "can't get local_colorability_score of ignored node")
517 loc_set
= node
.copy_merge_preview(merged_in_copy
).loc_set
519 for neighbor
, edge
in merged_in_copy
.edges
.items():
520 edges
[neighbor
] = edge
.merged(edges
.get(neighbor
))
521 retval
= len(loc_set
)
522 for neighbor
, edge
in edges
.items():
523 if neighbor
.ignored
or not edge
.interferes
:
525 if neighbor
== merged_in_copy
or neighbor
== node
:
527 retval
-= loc_set
.max_conflicts_with(neighbor
.loc_set
)
531 def minimally_merged(fn_analysis
):
532 # type: (FnAnalysis) -> InterferenceGraph
533 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
534 for op
in fn_analysis
.fn
.ops
:
535 for inp
in op
.input_uses
:
536 if inp
.unspread_start
!= inp
:
537 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
538 additional_offset
=inp
.reg_offset_in_unspread
)
539 for out
in op
.outputs
:
540 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
541 if out
.unspread_start
!= out
:
542 retval
.merge(out
.unspread_start
, out
,
543 additional_offset
=out
.reg_offset_in_unspread
)
544 if out
.tied_input
is not None:
545 retval
.merge(out
.tied_input
.ssa_val
, out
)
548 def __repr__(self
, repr_state
=None):
549 # type: (None | IGNodeReprState) -> str
550 if repr_state
is None:
551 repr_state
= IGNodeReprState()
552 s
= self
.nodes
.__repr
__(repr_state
)
553 return f
"InterferenceGraph(nodes={s}, <...>)"
556 self
, highlighted_nodes
=(), # type: Container[IGNode]
557 node_scores
=None, # type: None | dict[IGNode, int]
558 edge_scores
=None, # type: None | dict[tuple[IGNode, IGNode], int]
563 # type: (object) -> str
565 s
= s
.replace('\\', r
'\\')
566 s
= s
.replace('"', r
'\"')
567 s
= s
.replace('\n', r
'\n')
570 if node_scores
is None:
572 if edge_scores
is None:
575 edges
= {} # type: dict[tuple[IGNode, IGNode], IGEdge]
576 node_ids
= {} # type: dict[IGNode, str]
577 for node
in self
.nodes
.values():
578 node_ids
[node
] = quote(node
.node_id
)
579 for neighbor
, edge
in node
.edges
.items():
580 edge_key
= (node
, neighbor
)
581 # ensure we only insert each edge once by checking for
583 if edge_key
not in edges
and edge_key
[::-1] not in edges
:
584 edges
[edge_key
] = edge
587 " graph [pack = true]",
589 for node
, node_id
in node_ids
.items():
590 label_lines
= [] # type: list[str]
591 score
= node_scores
.get(node
)
592 if score
is not None:
593 label_lines
.append(f
"score={score}")
594 for k
, v
in node
.merged_ssa_val
.ssa_val_offsets
.items():
595 label_lines
.append(f
"{k}: {v}")
596 label
= quote("\n".join(label_lines
))
597 style
= "dotted" if node
.ignored
else "solid"
599 if node
in highlighted_nodes
:
604 lines
.append(f
" {node_id} ["
609 def append_edge(node1
, node2
, label
, color
, style
):
610 # type: (IGNode, IGNode, str, str, str) -> None
614 lines
.append(f
" {node_ids[node1]} -- {node_ids[node2]} ["
619 for (node1
, node2
), edge
in edges
.items():
620 score
= edge_scores
.get((node1
, node2
))
622 score
= edge_scores
.get((node2
, node1
))
624 if score
is not None:
625 label_prefix
= f
"score={score}\n"
627 append_edge(node1
, node2
, label
=label_prefix
+ "interferes",
628 color
="darkred", style
="bold")
629 if edge
.copy_relation
is not None:
630 append_edge(node1
, node2
, label
=label_prefix
+ "copy related",
631 color
="blue", style
="dashed")
633 return "\n".join(lines
)
636 @plain_data(repr=False)
637 class IGNodeReprState
:
638 __slots__
= "did_full_repr",
642 self
.did_full_repr
= OSet() # type: OSet[IGNode]
645 @plain_data(frozen
=True, unsafe_hash
=True)
648 """ interference graph edge """
649 __slots__
= "interferes", "copy_relation"
651 def __init__(self
, interferes
=False, copy_relation
=None):
652 # type: (bool, None | _CopyRelation) -> None
653 self
.interferes
= interferes
654 self
.copy_relation
= copy_relation
656 def merged(self
, other
):
657 # type: (IGEdge | None) -> IGEdge
660 copy_relation
= self
.copy_relation
661 if copy_relation
is None:
662 copy_relation
= other
.copy_relation
663 interferes
= self
.interferes | other
.interferes
664 return IGEdge(interferes
=interferes
, copy_relation
=copy_relation
)
669 """ interference graph node """
670 __slots__
= "node_id", "merged_ssa_val", "edges", "loc", "ignored"
672 def __init__(self
, node_id
, merged_ssa_val
, edges
, loc
, ignored
):
673 # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
674 self
.node_id
= node_id
675 self
.merged_ssa_val
= merged_ssa_val
678 self
.ignored
= ignored
680 def merge_edge(self
, other
, edge
):
681 # type: (IGNode, IGEdge) -> None
683 raise ValueError("can't have self-loops")
684 old_edge
= self
.edges
.get(other
, None)
685 assert old_edge
is other
.edges
.get(self
, None), "inconsistent edges"
686 edge
= edge
.merged(old_edge
)
688 self
.edges
.pop(other
, None)
689 other
.edges
.pop(self
, None)
691 self
.edges
[other
] = edge
692 other
.edges
[self
] = edge
694 def __eq__(self
, other
):
695 # type: (object) -> bool
696 if isinstance(other
, IGNode
):
697 return self
.node_id
== other
.node_id
698 return NotImplemented
702 return hash(self
.node_id
)
704 def __repr__(self
, repr_state
=None, short
=False):
705 # type: (None | IGNodeReprState, bool) -> str
709 rs
= IGNodeReprState()
710 if short
or self
in rs
.did_full_repr
:
711 return f
"<IGNode #{self.node_id}>"
712 rs
.did_full_repr
.add(self
)
714 f
"{k.__repr__(rs, True)}: {v}" for k
, v
in self
.edges
.items())
715 return (f
"IGNode(#{self.node_id}, "
716 f
"merged_ssa_val={self.merged_ssa_val}, "
717 f
"edges={{{edges}}}, "
719 f
"ignored={self.ignored})")
724 return self
.merged_ssa_val
.loc_set
726 def loc_conflicts_with_neighbors(self
, loc
):
727 # type: (Loc) -> bool
728 for neighbor
, edge
in self
.edges
.items():
729 if not edge
.interferes
:
731 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
735 def copy_merge_preview(self
, rhs_node
):
736 # type: (IGNode) -> MergedSSAVal
738 copy_relation
= self
.edges
[rhs_node
].copy_relation
740 raise ValueError("nodes aren't copy related")
741 if copy_relation
is None:
742 raise ValueError("nodes aren't copy related")
743 return self
.merged_ssa_val
.copy_merged(
745 rhs
=rhs_node
.merged_ssa_val
, rhs_loc
=rhs_node
.loc
,
746 copy_relation
=copy_relation
)
749 class AllocationFailedError(Exception):
750 def __init__(self
, msg
, node
, interference_graph
):
751 # type: (str, IGNode, InterferenceGraph) -> None
752 super().__init
__(msg
, node
, interference_graph
)
754 self
.interference_graph
= interference_graph
756 def __repr__(self
, repr_state
=None):
757 # type: (None | IGNodeReprState) -> str
758 if repr_state
is None:
759 repr_state
= IGNodeReprState()
760 return (f
"{__class__.__name__}({self.args[0]!r}, "
761 f
"node={self.node.__repr__(repr_state, True)}, "
762 f
"interference_graph="
763 f
"{self.interference_graph.__repr__(repr_state)})")
767 return self
.__repr
__()
770 def allocate_registers(
772 debug_out
=None, # type: TextIO | None
773 dump_graph
=None, # type: Callable[[str, str], None] | None
775 # type: (...) -> dict[SSAVal, Loc]
777 # inserts enough copies that no manual spilling is necessary, all
778 # spilling is done by the register allocator naturally allocating SSAVals
780 fn
.pre_ra_insert_copies()
782 if debug_out
is not None:
783 print(f
"After pre_ra_insert_copies():\n{fn.ops}",
784 file=debug_out
, flush
=True)
786 fn_analysis
= FnAnalysis(fn
)
787 interference_graph
= InterferenceGraph
.minimally_merged(fn_analysis
)
789 if debug_out
is not None:
790 print(f
"After InterferenceGraph.minimally_merged():\n"
791 f
"{interference_graph}", file=debug_out
, flush
=True)
793 for i
, j
in combinations(interference_graph
.nodes
.values(), 2):
794 copy_relation
= i
.merged_ssa_val
.get_copy_relation(j
.merged_ssa_val
)
795 i
.merge_edge(j
, IGEdge(copy_relation
=copy_relation
))
797 for pp
, ssa_vals
in fn_analysis
.live_at
.items():
798 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
799 for ssa_val
in ssa_vals
:
800 live_merged_ssa_vals
.add(
801 interference_graph
.merged_ssa_val_map
[ssa_val
])
802 for i
, j
in combinations(live_merged_ssa_vals
, 2):
803 if i
.loc_set
.max_conflicts_with(j
.loc_set
) == 0:
805 node_i
= interference_graph
.nodes
[i
]
806 node_j
= interference_graph
.nodes
[j
]
807 if node_j
in node_i
.edges
:
808 if node_i
.edges
[node_j
].copy_relation
is not None:
810 _
= node_i
.copy_merge_preview(node_j
)
811 continue # doesn't interfere if copy merging succeeds
812 except BadMergedSSAVal
:
814 node_i
.merge_edge(node_j
, edge
=IGEdge(interferes
=True))
815 if debug_out
is not None:
816 print(f
"processed {pp} out of {fn_analysis.all_program_points}",
817 file=debug_out
, flush
=True)
819 if debug_out
is not None:
820 print(f
"After adding interference graph edges:\n"
821 f
"{interference_graph}", file=debug_out
, flush
=True)
822 if dump_graph
is not None:
823 dump_graph("initial", interference_graph
.dump_to_dot())
825 node_stack
= [] # type: list[IGNode]
827 debug_node_scores
= {} # type: dict[IGNode, int]
828 debug_edge_scores
= {} # type: dict[tuple[IGNode, IGNode], int]
830 def find_best_node(has_copy_relation
):
831 # type: (bool) -> None | IGNode
832 best_node
= None # type: None | IGNode
834 for node
in interference_graph
.nodes
.values():
837 node_has_copy_relation
= False
838 for neighbor
, edge
in node
.edges
.items():
841 if edge
.copy_relation
is not None:
842 node_has_copy_relation
= True
844 if node_has_copy_relation
!= has_copy_relation
:
846 score
= interference_graph
.local_colorability_score(node
)
847 debug_node_scores
[node
] = score
848 if best_node
is None or score
> best_score
:
852 # it's locally colorable, no need to find a better one
854 if debug_out
is not None:
855 print(f
"find_best_node(has_copy_relation={has_copy_relation}):\n"
856 f
"{best_node}", file=debug_out
, flush
=True)
858 # copy-merging algorithm based on Iterated Register Coalescing, section 5:
859 # https://dl.acm.org/doi/pdf/10.1145/229542.229546
860 # Build step is above.
862 debug_node_scores
.clear()
863 debug_edge_scores
.clear()
865 best_node
= find_best_node(has_copy_relation
=False)
866 if best_node
is not None:
867 if dump_graph
is not None:
869 f
"step_{step}_simplify", interference_graph
.dump_to_dot(
870 highlighted_nodes
=[best_node
],
871 node_scores
=debug_node_scores
,
872 edge_scores
=debug_edge_scores
))
873 node_stack
.append(best_node
)
874 best_node
.ignored
= True
876 # Coalesce (aka. do copy-merges):
877 did_any_copy_merges
= False
878 for node
in interference_graph
.nodes
.values():
881 for neighbor
, edge
in node
.edges
.items():
884 if edge
.copy_relation
is None:
887 score
= interference_graph
.local_colorability_score(
888 node
, merged_in_copy
=neighbor
)
889 except BadMergedSSAVal
:
891 if (neighbor
, node
) in debug_edge_scores
:
892 debug_edge_scores
[(neighbor
, node
)] = score
894 debug_edge_scores
[(node
, neighbor
)] = score
895 if score
> 0: # merged node is locally colorable
896 if dump_graph
is not None:
898 f
"step_{step}_copy_merge",
899 interference_graph
.dump_to_dot(
900 highlighted_nodes
=[node
, neighbor
],
901 node_scores
=debug_node_scores
,
902 edge_scores
=debug_edge_scores
))
903 if debug_out
is not None:
904 print(f
"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
905 file=debug_out
, flush
=True)
906 merged_node
= interference_graph
.copy_merge(node
, neighbor
)
907 if dump_graph
is not None:
909 f
"step_{step}_copy_merge_result",
910 interference_graph
.dump_to_dot(
911 highlighted_nodes
=[merged_node
]))
912 if debug_out
is not None:
913 print(f
"merged_node:\n"
914 f
"{merged_node}", file=debug_out
, flush
=True)
915 did_any_copy_merges
= True
917 if did_any_copy_merges
:
919 if did_any_copy_merges
:
922 best_node
= find_best_node(has_copy_relation
=True)
923 if best_node
is not None:
924 if dump_graph
is not None:
925 dump_graph(f
"step_{step}_freeze",
926 interference_graph
.dump_to_dot(
927 highlighted_nodes
=[best_node
],
928 node_scores
=debug_node_scores
,
929 edge_scores
=debug_edge_scores
))
930 # no need to clear copy relations since best_node won't be
931 # considered since it's now ignored.
932 node_stack
.append(best_node
)
933 best_node
.ignored
= True
937 if dump_graph
is not None:
938 dump_graph("final", interference_graph
.dump_to_dot())
939 if debug_out
is not None:
940 print(f
"After deciding node allocation order:\n"
941 f
"{node_stack}", file=debug_out
, flush
=True)
943 retval
= {} # type: dict[SSAVal, Loc]
945 while len(node_stack
) > 0:
946 node
= node_stack
.pop()
947 if node
.loc
is not None:
948 if node
.loc_conflicts_with_neighbors(node
.loc
):
949 raise AllocationFailedError(
950 "IGNode is pre-allocated to a conflicting Loc",
951 node
=node
, interference_graph
=interference_graph
)
953 # Locs to try allocating, ordered from most preferred to least
956 # prefer eliminating copies
957 for neighbor
, edge
in node
.edges
.items():
958 if neighbor
.loc
is None or edge
.copy_relation
is None:
961 merged
= node
.copy_merge_preview(neighbor
)
962 except BadMergedSSAVal
:
964 # get merged_loc if merged.loc_set has a single Loc
965 merged_loc
= merged
.only_loc
966 if merged_loc
is None:
968 ssa_val
= node
.merged_ssa_val
.first_ssa_val
969 ssa_val_loc
= merged_loc
.get_subloc_at_offset(
970 subloc_ty
=ssa_val
.ty
,
971 offset
=merged
.ssa_val_offsets
[ssa_val
])
972 node_loc
= ssa_val_loc
.get_superloc_with_self_at_offset(
973 superloc_ty
=node
.merged_ssa_val
.ty
,
974 offset
=node
.merged_ssa_val
.ssa_val_offsets
[ssa_val
])
975 assert node_loc
in node
.merged_ssa_val
.loc_set
, "logic error"
977 # add node's allowed Locs as fallback
978 for loc
in node
.loc_set
:
979 # TODO: add in order of preference
981 # pick the first non-conflicting register in locs, since locs is
982 # ordered from most preferred to least preferred register.
984 if not node
.loc_conflicts_with_neighbors(loc
):
988 raise AllocationFailedError(
989 "failed to allocate Loc for IGNode",
990 node
=node
, interference_graph
=interference_graph
)
992 if debug_out
is not None:
993 print(f
"After allocating Loc for node:\n{node}",
994 file=debug_out
, flush
=True)
996 for ssa_val
, offset
in node
.merged_ssa_val
.ssa_val_offsets
.items():
997 retval
[ssa_val
] = node
.loc
.get_subloc_at_offset(ssa_val
.ty
, offset
)
999 if debug_out
is not None:
1000 print(f
"final Locs for all SSAVals:\n{retval}",
1001 file=debug_out
, flush
=True)