2 Register Allocator for Toom-Cook algorithm generator for SVP64
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from dataclasses
import dataclass
9 from functools
import lru_cache
, reduce
10 from itertools
import combinations
, count
11 from typing
import Any
, Callable
, Container
, Iterable
, Iterator
, Mapping
, TextIO
, Tuple
13 from cached_property
import cached_property
14 from nmutil
.plain_data
import plain_data
, replace
16 from bigint_presentation_code
.compiler_ir
import (BaseTy
, Fn
, FnAnalysis
, Loc
,
17 LocSet
, Op
, ProgramRange
,
18 SSAVal
, SSAValSubReg
, Ty
)
19 from bigint_presentation_code
.type_util
import final
20 from bigint_presentation_code
.util
import FMap
, Interned
, OFSet
, OSet
23 class BadMergedSSAVal(ValueError):
# A copy relation is a pair of `SSAValSubReg`s, one from each of two merged
# values, that `get_copy_relation` found to share the same copy source.
# `copy_merged` unpacks it into its two sub-registers.
_CopyRelation = Tuple[
    SSAValSubReg,
    SSAValSubReg,
]
30 @dataclass(frozen
=True, repr=False, eq
=False)
32 class MergedSSAVal(Interned
):
33 """a set of `SSAVal`s along with their offsets, all register allocated as
36 Definition of the term `offset` for this class:
38 Let `locs[x]` be the `Loc` that `x` is assigned to after register
39 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
40 for each `SSAVal` `ssa_val` in `msv` is defined as:
43 msv.ssa_val_offsets[ssa_val] = (msv.offset
44 + locs[ssa_val].start - locs[msv].start)
52 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
55 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
56 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
57 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
58 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
60 fn_analysis
: FnAnalysis
61 ssa_val_offsets
: "FMap[SSAVal, int]"
66 def __init__(self
, fn_analysis
, ssa_val_offsets
, loc_set
=None):
67 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
68 object.__setattr
__(self
, "fn_analysis", fn_analysis
)
69 if isinstance(ssa_val_offsets
, SSAVal
):
70 ssa_val_offsets
= {ssa_val_offsets
: 0}
71 object.__setattr
__(self
, "ssa_val_offsets", FMap(ssa_val_offsets
))
73 for ssa_val
in self
.ssa_vals
:
74 first_ssa_val
= ssa_val
76 if first_ssa_val
is None:
77 raise BadMergedSSAVal("MergedSSAVal can't be empty")
78 object.__setattr
__(self
, "first_ssa_val", first_ssa_val
)
79 # self.ty checks for mismatched base_ty
80 reg_len
= self
.ty
.reg_len
81 if loc_set
is not None and loc_set
.ty
!= self
.ty
:
83 f
"invalid loc_set, type doesn't match: "
84 f
"{loc_set.ty} != {self.ty}")
85 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
87 # type: () -> Iterable[Loc]
88 for loc
in ssa_val
.def_loc_set_before_spread
:
89 disallowed_by_use
= False
90 for use
in fn_analysis
.uses
[ssa_val
]:
91 # calculate the start for the use's Loc before spread
92 # e.g. if the def's Loc before spread starts at r6
93 # and the def's reg_offset_in_unspread is 5
94 # and the use's reg_offset_in_unspread is 3
95 # then the use's Loc before spread starts at r8
96 # because 8 == 6 + 5 - 3
97 start
= (loc
.start
+ ssa_val
.reg_offset_in_unspread
98 - use
.reg_offset_in_unspread
)
99 use_loc
= Loc
.try_make(
100 loc
.kind
, start
=start
,
101 reg_len
=use
.ty_before_spread
.reg_len
)
102 if (use_loc
is None or
103 use_loc
not in use
.use_loc_set_before_spread
):
104 disallowed_by_use
= True
106 if disallowed_by_use
:
108 start
= loc
.start
- cur_offset
+ self
.offset
109 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
110 if loc
is not None and (loc_set
is None or loc
in loc_set
):
112 loc_set
= LocSet(locs())
113 assert loc_set
is not None, "already checked that self isn't empty"
118 if first_loc
is None:
119 raise BadMergedSSAVal("there are no valid Locs left")
120 object.__setattr
__(self
, "first_loc", first_loc
)
121 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
122 object.__setattr
__(self
, "loc_set", loc_set
)
123 self
.__mergable
_check
()
125 def __mergable_check(self
):
127 """ checks that nothing is forcing two independent SSAVals
128 to illegally overlap. This is required to avoid copy merging merging
129 things that can't be merged.
130 spread arguments are one of the things that can force two values to
133 ops
= OSet() # type: Iterable[Op]
134 for ssa_val
in self
.ssa_vals
:
136 for use
in self
.fn_analysis
.uses
[ssa_val
]:
138 ops
= sorted(ops
, key
=self
.fn_analysis
.op_indexes
.__getitem
__)
139 vals
= {} # type: dict[int, SSAValSubReg]
141 for inp
in op
.input_vals
:
143 ssa_val_offset
= self
.ssa_val_offsets
[inp
]
146 for orig_reg
in inp
.ssa_val_sub_regs
:
147 reg_offset
= ssa_val_offset
+ orig_reg
.reg_idx
148 replaced_reg
= vals
[reg_offset
]
149 if not self
.fn_analysis
.is_always_equal(
150 orig_reg
, replaced_reg
):
151 raise BadMergedSSAVal(
152 f
"attempting to merge values that aren't known to "
153 f
"be always equal: {orig_reg} != {replaced_reg}")
154 output_offsets
= dict.fromkeys(range(
155 self
.offset
, self
.offset
+ self
.ty
.reg_len
))
156 for out
in op
.outputs
:
158 ssa_val_offset
= self
.ssa_val_offsets
[out
]
161 for reg
in out
.ssa_val_sub_regs
:
162 reg_offset
= ssa_val_offset
+ reg
.reg_idx
164 del output_offsets
[reg_offset
]
166 raise BadMergedSSAVal("attempted to merge two outputs "
167 "of the same instruction")
168 vals
[reg_offset
] = reg
172 return hash((self
.fn_analysis
, self
.ssa_val_offsets
, self
.loc_set
))
def __eq__(self, other):
    # type: (MergedSSAVal | Any) -> bool
    """Structural equality for `MergedSSAVal`.

    Two instances are equal when `fn_analysis`, `ssa_val_offsets` and
    `loc_set` all compare equal (checked in that order, stopping at the
    first mismatch). Non-`MergedSSAVal` operands yield `NotImplemented` so
    Python can try the reflected comparison.
    """
    if not isinstance(other, MergedSSAVal):
        return NotImplemented
    same_analysis = self.fn_analysis == other.fn_analysis
    if not same_analysis:
        return same_analysis
    same_offsets = self.ssa_val_offsets == other.ssa_val_offsets
    if not same_offsets:
        return same_offsets
    return self.loc_set == other.loc_set
184 # type: () -> Loc | None
185 return self
.loc_set
.only_loc
190 return min(self
.ssa_val_offsets_before_spread
.values())
195 return self
.first_ssa_val
.base_ty
199 # type: () -> OFSet[SSAVal]
200 return OFSet(self
.ssa_val_offsets
.keys())
206 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
207 cur_ty
= ssa_val
.ty_before_spread
208 if self
.base_ty
!= cur_ty
.base_ty
:
209 raise BadMergedSSAVal(
210 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
211 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
212 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
215 def ssa_val_offsets_before_spread(self
):
216 # type: () -> FMap[SSAVal, int]
217 retval
= {} # type: dict[SSAVal, int]
218 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
220 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """Return a new `MergedSSAVal` with every offset shifted by `amount`.

    Only `fn_analysis` and the shifted `ssa_val_offsets` are passed to the
    constructor. Any `loc_set` restriction on `self` is therefore not
    carried over: the new value's `loc_set` is recomputed by the
    constructor with no restriction.
    """
    shifted = {}  # type: dict[SSAVal, int]
    for ssa_val, old_offset in self.ssa_val_offsets.items():
        shifted[ssa_val] = old_offset + amount
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=shifted)
def normalized(self):
    # type: () -> MergedSSAVal
    """Return a copy shifted by `-self.offset`.

    After the shift, the result's `offset` is presumably zero.
    """
    shift = -self.offset
    return self.offset_by(shift)
232 def with_offset_to_match(self
, target
, additional_offset
=0):
233 # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
234 if isinstance(target
, MergedSSAVal
):
235 ssa_val_offsets
= target
.ssa_val_offsets
237 ssa_val_offsets
= {target
: 0}
238 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
239 if ssa_val
in ssa_val_offsets
:
240 return self
.offset_by(
241 ssa_val_offsets
[ssa_val
] + additional_offset
- offset
)
242 raise ValueError("can't change offset to match unrelated MergedSSAVal")
244 def with_loc(self
, loc
):
245 # type: (Loc) -> MergedSSAVal
246 if loc
not in self
.loc_set
:
248 f
"Loc is not allowed -- not a member of `self.loc_set`: "
249 f
"{loc} not in {self.loc_set}")
250 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
251 ssa_val_offsets
=self
.ssa_val_offsets
,
252 loc_set
=LocSet([loc
]))
254 def merged(self
, *others
):
255 # type: (*MergedSSAVal) -> MergedSSAVal
256 retval
= dict(self
.ssa_val_offsets
)
258 if other
.fn_analysis
!= self
.fn_analysis
:
259 raise ValueError("fn_analysis mismatch")
260 for ssa_val
, offset
in other
.ssa_val_offsets
.items():
261 if ssa_val
in retval
and retval
[ssa_val
] != offset
:
262 raise BadMergedSSAVal(f
"offset mismatch for {ssa_val}: "
263 f
"{retval[ssa_val]} != {offset}")
264 retval
[ssa_val
] = offset
265 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
266 ssa_val_offsets
=retval
)
def live_interval(self):
    # type: () -> ProgramRange
    """Smallest `ProgramRange` covering the live ranges of every `SSAVal`
    in this merged value, as recorded in `fn_analysis.live_ranges`."""
    live_ranges = self.fn_analysis.live_ranges
    seed = live_ranges[self.first_ssa_val]
    starts = [seed.start]
    stops = [seed.stop]
    for ssa_val in self.ssa_vals:
        current = live_ranges[ssa_val]
        starts.append(current.start)
        stops.append(current.stop)
    return ProgramRange(start=min(starts), stop=max(stops))
281 return (f
"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
282 f
"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
283 f
"live_interval={self.live_interval})")
def copy_related_ssa_vals(self):
    # type: () -> OFSet[SSAVal]
    """Union of `fn_analysis.copy_related_ssa_vals[v]` over every `SSAVal`
    `v` in this merged value, flattened into one `OFSet`."""
    # gather each distinct set only once before flattening, so identical
    # sets shared by several members aren't walked repeatedly
    unique_groups = OSet()  # type: OSet[OFSet[SSAVal]]
    related = self.fn_analysis.copy_related_ssa_vals
    for ssa_val in self.ssa_vals:
        unique_groups.add(related[ssa_val])
    flattened = (member for group in unique_groups for member in group)
    return OFSet(flattened)
294 def get_copy_relation(self
, other
):
295 # type: (MergedSSAVal) -> None | _CopyRelation
296 for lhs_ssa_val
in self
.ssa_vals
:
297 for rhs_ssa_val
in other
.ssa_vals
:
298 for lhs
in lhs_ssa_val
.ssa_val_sub_regs
:
299 for rhs
in rhs_ssa_val
.ssa_val_sub_regs
:
300 lhs_src
= self
.fn_analysis
.copies
.get(lhs
, lhs
)
301 rhs_src
= self
.fn_analysis
.copies
.get(rhs
, rhs
)
302 if lhs_src
== rhs_src
:
@lru_cache(maxsize=None, typed=True)
def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
    # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
    """Merge `self` with `rhs` so the two copy-related sub-registers coincide.

    `copy_relation` may name its sub-registers in either order. The one
    whose `ssa_val` belongs to `self` is used to align `self`, and the
    other to align `rhs`. Each side is shifted so its sub-register sits at
    offset 0. If `lhs_loc` / `rhs_loc` is not None, that side is then
    restricted with `with_loc`. The two sides are merged and the result is
    normalized. Results are memoized by `lru_cache`.
    """
    first, second = copy_relation
    if first.ssa_val in self.ssa_vals:
        own_reg, other_reg = first, second
    else:
        own_reg, other_reg = second, first

    def _align(msv, sub_reg, loc):
        # shift `msv` so `sub_reg` lands at offset 0, then optionally pin it
        aligned = msv.with_offset_to_match(
            sub_reg.ssa_val, additional_offset=-sub_reg.reg_idx)
        if loc is None:
            return aligned
        return aligned.with_loc(loc)

    lhs_side = _align(self, own_reg, lhs_loc)
    rhs_side = _align(rhs, other_reg, rhs_loc)
    return lhs_side.merged(rhs_side).normalized()
324 class SSAValToMergedSSAValMap(Mapping
[SSAVal
, MergedSSAVal
]):
326 # type: (...) -> None
327 self
.__map
= {} # type: dict[SSAVal, MergedSSAVal]
328 self
.__ig
_node
_map
= MergedSSAValToIGNodeMap(
329 _private_merged_ssa_val_map
=self
.__map
)
def __getitem__(self, __key):
    # type: (SSAVal) -> MergedSSAVal
    """Return the `MergedSSAVal` that `__key` is currently assigned to.

    Raises `KeyError` if `__key` is not in the map.
    """
    backing = self.__map
    return backing[__key]
336 # type: () -> Iterator[SSAVal]
337 return iter(self
.__map
)
341 return len(self
.__map
)
def ig_node_map(self):
    # type: () -> MergedSSAValToIGNodeMap
    """The `MergedSSAValToIGNodeMap` built by this map's constructor.

    It shares this map's backing `SSAVal -> MergedSSAVal` dict.
    """
    node_map = self.__ig_node_map
    return node_map
350 s
= ",\n".join(repr(v
) for v
in self
.__ig
_node
_map
)
351 return f
"SSAValToMergedSSAValMap({{{s}}})"
355 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, "IGNode"]):
358 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
360 # type: (...) -> None
361 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
362 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
363 self
.__next
_node
_id
= 0
def __getitem__(self, __key):
    # type: (MergedSSAVal) -> IGNode
    """Return the `IGNode` registered for the `MergedSSAVal` `__key`.

    Raises `KeyError` if there is no node for it, e.g. after
    `merge_into_one_node` has replaced that node.
    """
    backing = self.__map
    return backing[__key]
370 # type: () -> Iterator[MergedSSAVal]
371 return iter(self
.__map
)
375 return len(self
.__map
)
377 def add_node(self
, merged_ssa_val
):
378 # type: (MergedSSAVal) -> IGNode
379 node
= self
.__map
.get(merged_ssa_val
, None)
382 added
= 0 # type: int | None
384 for ssa_val
in merged_ssa_val
.ssa_vals
:
385 if ssa_val
in self
.__merged
_ssa
_val
_map
:
387 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
388 f
"{merged_ssa_val} and "
389 f
"{self.__merged_ssa_val_map[ssa_val]}")
390 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
393 node_id
=self
.__next
_node
_id
, merged_ssa_val
=merged_ssa_val
,
394 edges
={}, loc
=merged_ssa_val
.only_loc
, ignored
=False)
395 self
.__map
[merged_ssa_val
] = retval
396 self
.__next
_node
_id
+= 1
400 if added
is not None:
401 # remove partially added stuff
402 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
405 del self
.__merged
_ssa
_val
_map
[ssa_val
]
407 def merge_into_one_node(self
, final_merged_ssa_val
):
408 # type: (MergedSSAVal) -> IGNode
409 source_nodes
= OSet() # type: OSet[IGNode]
410 edges
= {} # type: dict[IGNode, IGEdge]
411 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
412 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
413 source_node
= self
.__map
[merged_ssa_val
]
414 if source_node
.ignored
:
415 raise ValueError(f
"can't merge ignored nodes: {source_node}")
416 source_nodes
.add(source_node
)
417 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
419 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
420 f
"but not in merged IGNode's merged_ssa_val: "
421 f
"source_node={source_node} "
422 f
"final_merged_ssa_val={final_merged_ssa_val}")
423 if source_node
.loc
!= source_node
.merged_ssa_val
.only_loc
:
425 f
"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
426 f
"{source_node.loc} != "
427 f
"{source_node.merged_ssa_val.only_loc}")
428 for n
, edge
in source_node
.edges
.items():
430 edge
= edge
.merged(edges
[n
])
432 if len(source_nodes
) == 1:
433 return source_nodes
.pop() # merging a single node is a no-op
434 # we're finished checking validity, now we can modify stuff
435 for n
in source_nodes
:
437 loc
= final_merged_ssa_val
.only_loc
438 for n
, edge
in edges
.items():
439 if edge
.copy_relation
is None or not edge
.interferes
:
442 # if merging works, then the edge can't interfere
443 _
= final_merged_ssa_val
.copy_merged(
444 lhs_loc
=loc
, rhs
=n
.merged_ssa_val
, rhs_loc
=n
.loc
,
445 copy_relation
=edge
.copy_relation
)
446 except BadMergedSSAVal
:
448 edges
[n
] = replace(edge
, interferes
=False)
450 node_id
=self
.__next
_node
_id
, merged_ssa_val
=final_merged_ssa_val
,
451 edges
=edges
, loc
=loc
, ignored
=False)
452 self
.__next
_node
_id
+= 1
455 edge
= reduce(IGEdge
.merged
,
456 (node
.edges
.pop(n
, empty_e
) for n
in source_nodes
))
458 node
.edges
.pop(retval
, None)
460 node
.edges
[retval
] = edge
461 for node
in source_nodes
:
462 del self
.__map
[node
.merged_ssa_val
]
463 self
.__map
[final_merged_ssa_val
] = retval
464 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
465 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """Render every `IGNode` in the map, one per line.

    All nodes share one `IGNodeReprState`, so a node that has already been
    printed in full is shown in short form afterwards.
    """
    state = IGNodeReprState() if repr_state is None else repr_state
    rendered = [node.__repr__(state) for node in self.__map.values()]
    s = ",\n".join(rendered)
    return f"MergedSSAValToIGNodeMap({{{s}}})"
476 @plain_data(frozen
=True, repr=False)
478 class InterferenceGraph
:
479 __slots__
= "fn_analysis", "merged_ssa_val_map", "nodes"
def __init__(self, fn_analysis, merged_ssa_vals):
    # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
    """Create a graph for `fn_analysis` and add each of `merged_ssa_vals`
    as a node via `add_node`.

    `nodes` is the node map owned by `merged_ssa_val_map`.
    """
    self.fn_analysis = fn_analysis
    ssa_map = SSAValToMergedSSAValMap()
    self.merged_ssa_val_map = ssa_map
    self.nodes = ssa_map.ig_node_map
    for merged_ssa_val in merged_ssa_vals:
        self.nodes.add_node(merged_ssa_val)
def merge_preview(self, ssa_val1, ssa_val2, additional_offset=0):
    # type: (SSAVal, SSAVal, int) -> MergedSSAVal
    """Compute, without modifying the graph, the normalized `MergedSSAVal`
    that merges the values currently containing `ssa_val1` and `ssa_val2`.

    `ssa_val1` is aligned to offset 0 and `ssa_val2` to offset
    `additional_offset`.
    """
    lookup = self.merged_ssa_val_map
    first = lookup[ssa_val1]
    second = lookup[ssa_val2]
    aligned_first = first.with_offset_to_match(ssa_val1)
    aligned_second = second.with_offset_to_match(
        ssa_val2, additional_offset=additional_offset)
    return aligned_first.merged(aligned_second).normalized()
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
    # type: (SSAVal, SSAVal, int) -> IGNode
    """Merge the nodes containing `ssa_val1` and `ssa_val2` into one node.

    Uses the layout computed by `merge_preview` and returns the resulting
    `IGNode`.
    """
    preview = self.merge_preview(ssa_val1=ssa_val1, ssa_val2=ssa_val2,
                                 additional_offset=additional_offset)
    return self.nodes.merge_into_one_node(preview)
def copy_merge(self, node1, node2):
    # type: (IGNode, IGNode) -> IGNode
    """Copy-merge two nodes into a single node and return it.

    The merged value comes from `node1.copy_merge_preview(node2)`, so this
    raises whatever that raises, e.g. `ValueError` when the nodes aren't
    copy related.
    """
    merged_ssa_val = node1.copy_merge_preview(node2)
    return self.nodes.merge_into_one_node(merged_ssa_val)
507 def local_colorability_score(self
, node
, merged_in_copy
=None):
508 # type: (IGNode, None | IGNode) -> int
509 """ returns a positive integer if node is locally colorable, returns
510 zero or a negative integer if node isn't known to be locally
511 colorable, the more negative the value, the less colorable.
513 if `merged_in_copy` is not `None`, then the node used is what would be
514 the result of `self.copy_merge(node, merged_in_copy)`.
518 "can't get local_colorability_score of ignored node")
519 loc_set
= node
.loc_set
521 if merged_in_copy
is not None:
522 if merged_in_copy
.ignored
:
524 "can't get local_colorability_score of ignored node")
525 loc_set
= node
.copy_merge_preview(merged_in_copy
).loc_set
527 for neighbor
, edge
in merged_in_copy
.edges
.items():
528 edges
[neighbor
] = edge
.merged(edges
.get(neighbor
))
529 retval
= len(loc_set
)
530 for neighbor
, edge
in edges
.items():
531 if neighbor
.ignored
or not edge
.interferes
:
533 if neighbor
== merged_in_copy
or neighbor
== node
:
535 retval
-= loc_set
.max_conflicts_with(neighbor
.loc_set
)
539 def minimally_merged(fn_analysis
):
540 # type: (FnAnalysis) -> InterferenceGraph
541 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
542 for op
in fn_analysis
.fn
.ops
:
543 for inp
in op
.input_uses
:
544 if inp
.unspread_start
!= inp
:
545 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
546 additional_offset
=inp
.reg_offset_in_unspread
)
547 for out
in op
.outputs
:
548 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
549 if out
.unspread_start
!= out
:
550 retval
.merge(out
.unspread_start
, out
,
551 additional_offset
=out
.reg_offset_in_unspread
)
552 if out
.tied_input
is not None:
553 retval
.merge(out
.tied_input
.ssa_val
, out
)
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """Short representation showing only `nodes`.

    The remaining fields are elided as `<...>`. A fresh `IGNodeReprState`
    is used when none is supplied.
    """
    state = repr_state if repr_state is not None else IGNodeReprState()
    s = self.nodes.__repr__(state)
    return f"InterferenceGraph(nodes={s}, <...>)"
564 self
, highlighted_nodes
=(), # type: Container[IGNode]
565 node_scores
=None, # type: None | dict[IGNode, int]
566 edge_scores
=None, # type: None | dict[tuple[IGNode, IGNode], int]
571 # type: (object) -> str
573 s
= s
.replace('\\', r
'\\')
574 s
= s
.replace('"', r
'\"')
575 s
= s
.replace('\n', r
'\n')
578 if node_scores
is None:
580 if edge_scores
is None:
583 edges
= {} # type: dict[tuple[IGNode, IGNode], IGEdge]
584 node_ids
= {} # type: dict[IGNode, str]
585 for node
in self
.nodes
.values():
586 node_ids
[node
] = quote(node
.node_id
)
587 for neighbor
, edge
in node
.edges
.items():
588 edge_key
= (node
, neighbor
)
589 # ensure we only insert each edge once by checking for
591 if edge_key
not in edges
and edge_key
[::-1] not in edges
:
592 edges
[edge_key
] = edge
595 " graph [pack = true]",
597 for node
, node_id
in node_ids
.items():
598 label_lines
= [] # type: list[str]
599 score
= node_scores
.get(node
)
600 if score
is not None:
601 label_lines
.append(f
"score={score}")
602 for k
, v
in node
.merged_ssa_val
.ssa_val_offsets
.items():
603 label_lines
.append(f
"{k}: {v}")
604 label
= quote("\n".join(label_lines
))
605 style
= "dotted" if node
.ignored
else "solid"
607 if node
in highlighted_nodes
:
612 lines
.append(f
" {node_id} ["
617 def append_edge(node1
, node2
, label
, color
, style
):
618 # type: (IGNode, IGNode, str, str, str) -> None
622 lines
.append(f
" {node_ids[node1]} -- {node_ids[node2]} ["
627 for (node1
, node2
), edge
in edges
.items():
628 score
= edge_scores
.get((node1
, node2
))
630 score
= edge_scores
.get((node2
, node1
))
632 if score
is not None:
633 label_prefix
= f
"score={score}\n"
635 append_edge(node1
, node2
, label
=label_prefix
+ "interferes",
636 color
="darkred", style
="bold")
637 if edge
.copy_relation
is not None:
638 append_edge(node1
, node2
, label
=label_prefix
+ "copy related",
639 color
="blue", style
="dashed")
641 return "\n".join(lines
)
644 @plain_data(repr=False)
645 class IGNodeReprState
:
646 __slots__
= "did_full_repr",
650 self
.did_full_repr
= OSet() # type: OSet[IGNode]
653 @plain_data(frozen
=True, unsafe_hash
=True)
656 """ interference graph edge """
657 __slots__
= "interferes", "copy_relation"
def __init__(self, interferes=False, copy_relation=None):
    # type: (bool, None | _CopyRelation) -> None
    """Interference-graph edge.

    `interferes`: True when the endpoints were found live at the same
    program point with possibly conflicting `Loc`s.
    `copy_relation`: the pair of sub-registers that copy-relates the two
    endpoints, or None if they aren't copy related.
    """
    self.interferes = interferes
    self.copy_relation = copy_relation
664 def merged(self
, other
):
665 # type: (IGEdge | None) -> IGEdge
668 copy_relation
= self
.copy_relation
669 if copy_relation
is None:
670 copy_relation
= other
.copy_relation
671 interferes
= self
.interferes | other
.interferes
672 return IGEdge(interferes
=interferes
, copy_relation
=copy_relation
)
677 """ interference graph node """
678 __slots__
= "node_id", "merged_ssa_val", "edges", "loc", "ignored"
680 def __init__(self
, node_id
, merged_ssa_val
, edges
, loc
, ignored
):
681 # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None
682 self
.node_id
= node_id
683 self
.merged_ssa_val
= merged_ssa_val
686 self
.ignored
= ignored
688 def merge_edge(self
, other
, edge
):
689 # type: (IGNode, IGEdge) -> None
691 raise ValueError("can't have self-loops")
692 old_edge
= self
.edges
.get(other
, None)
693 assert old_edge
is other
.edges
.get(self
, None), "inconsistent edges"
694 edge
= edge
.merged(old_edge
)
696 self
.edges
.pop(other
, None)
697 other
.edges
.pop(self
, None)
699 self
.edges
[other
] = edge
700 other
.edges
[self
] = edge
def __eq__(self, other):
    # type: (object) -> bool
    """`IGNode`s are equal iff their `node_id`s are equal.

    Non-`IGNode` operands yield `NotImplemented`.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.node_id == other.node_id
710 return hash(self
.node_id
)
712 def __repr__(self
, repr_state
=None, short
=False):
713 # type: (None | IGNodeReprState, bool) -> str
717 rs
= IGNodeReprState()
718 if short
or self
in rs
.did_full_repr
:
719 return f
"<IGNode #{self.node_id}>"
720 rs
.did_full_repr
.add(self
)
722 f
"{k.__repr__(rs, True)}: {v}" for k
, v
in self
.edges
.items())
723 return (f
"IGNode(#{self.node_id}, "
724 f
"merged_ssa_val={self.merged_ssa_val}, "
725 f
"edges={{{edges}}}, "
727 f
"ignored={self.ignored})")
732 return self
.merged_ssa_val
.loc_set
734 def loc_conflicts_with_neighbors(self
, loc
):
735 # type: (Loc) -> bool
736 for neighbor
, edge
in self
.edges
.items():
737 if not edge
.interferes
:
739 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
743 def copy_merge_preview(self
, rhs_node
):
744 # type: (IGNode) -> MergedSSAVal
746 copy_relation
= self
.edges
[rhs_node
].copy_relation
748 raise ValueError("nodes aren't copy related")
749 if copy_relation
is None:
750 raise ValueError("nodes aren't copy related")
751 return self
.merged_ssa_val
.copy_merged(
753 rhs
=rhs_node
.merged_ssa_val
, rhs_loc
=rhs_node
.loc
,
754 copy_relation
=copy_relation
)
757 class AllocationFailedError(Exception):
758 def __init__(self
, msg
, node
, interference_graph
):
759 # type: (str, IGNode, InterferenceGraph) -> None
760 super().__init
__(msg
, node
, interference_graph
)
762 self
.interference_graph
= interference_graph
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """Representation with the message, the failing node and the graph.

    The node is shown in short form and the whole interference graph in
    full. Both share one `IGNodeReprState`, and a fresh one is created when
    none is supplied.
    """
    if repr_state is None:
        repr_state = IGNodeReprState()
    head = f"{__class__.__name__}({self.args[0]!r}, "
    node_part = f"node={self.node.__repr__(repr_state, True)}, "
    graph_part = (f"interference_graph="
                  f"{self.interference_graph.__repr__(repr_state)})")
    return head + node_part + graph_part
775 return self
.__repr
__()
778 def allocate_registers(
780 debug_out
=None, # type: TextIO | None
781 dump_graph
=None, # type: Callable[[str, str], None] | None
783 # type: (...) -> dict[SSAVal, Loc]
785 # inserts enough copies that no manual spilling is necessary, all
786 # spilling is done by the register allocator naturally allocating SSAVals
788 fn
.pre_ra_insert_copies()
790 if debug_out
is not None:
791 print(f
"After pre_ra_insert_copies():\n{fn.ops}",
792 file=debug_out
, flush
=True)
794 fn_analysis
= FnAnalysis(fn
)
795 interference_graph
= InterferenceGraph
.minimally_merged(fn_analysis
)
797 if debug_out
is not None:
798 print(f
"After InterferenceGraph.minimally_merged():\n"
799 f
"{interference_graph}", file=debug_out
, flush
=True)
801 for i
, j
in combinations(interference_graph
.nodes
.values(), 2):
802 copy_relation
= i
.merged_ssa_val
.get_copy_relation(j
.merged_ssa_val
)
803 i
.merge_edge(j
, IGEdge(copy_relation
=copy_relation
))
805 for pp
, ssa_vals
in fn_analysis
.live_at
.items():
806 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
807 for ssa_val
in ssa_vals
:
808 live_merged_ssa_vals
.add(
809 interference_graph
.merged_ssa_val_map
[ssa_val
])
810 for i
, j
in combinations(live_merged_ssa_vals
, 2):
811 if i
.loc_set
.max_conflicts_with(j
.loc_set
) == 0:
813 node_i
= interference_graph
.nodes
[i
]
814 node_j
= interference_graph
.nodes
[j
]
815 if node_j
in node_i
.edges
:
816 if node_i
.edges
[node_j
].copy_relation
is not None:
818 _
= node_i
.copy_merge_preview(node_j
)
819 continue # doesn't interfere if copy merging succeeds
820 except BadMergedSSAVal
:
822 node_i
.merge_edge(node_j
, edge
=IGEdge(interferes
=True))
823 if debug_out
is not None:
824 print(f
"processed {pp} out of {fn_analysis.all_program_points}",
825 file=debug_out
, flush
=True)
827 if debug_out
is not None:
828 print(f
"After adding interference graph edges:\n"
829 f
"{interference_graph}", file=debug_out
, flush
=True)
830 if dump_graph
is not None:
831 dump_graph("initial", interference_graph
.dump_to_dot())
833 node_stack
= [] # type: list[IGNode]
835 debug_node_scores
= {} # type: dict[IGNode, int]
836 debug_edge_scores
= {} # type: dict[tuple[IGNode, IGNode], int]
838 def find_best_node(has_copy_relation
):
839 # type: (bool) -> None | IGNode
840 best_node
= None # type: None | IGNode
842 for node
in interference_graph
.nodes
.values():
845 node_has_copy_relation
= False
846 for neighbor
, edge
in node
.edges
.items():
849 if edge
.copy_relation
is not None:
850 node_has_copy_relation
= True
852 if node_has_copy_relation
!= has_copy_relation
:
854 score
= interference_graph
.local_colorability_score(node
)
855 debug_node_scores
[node
] = score
856 if best_node
is None or score
> best_score
:
860 # it's locally colorable, no need to find a better one
862 if debug_out
is not None:
863 print(f
"find_best_node(has_copy_relation={has_copy_relation}):\n"
864 f
"{best_node}", file=debug_out
, flush
=True)
866 # copy-merging algorithm based on Iterated Register Coalescing, section 5:
867 # https://dl.acm.org/doi/pdf/10.1145/229542.229546
868 # Build step is above.
870 debug_node_scores
.clear()
871 debug_edge_scores
.clear()
873 best_node
= find_best_node(has_copy_relation
=False)
874 if best_node
is not None:
875 if dump_graph
is not None:
877 f
"step_{step}_simplify", interference_graph
.dump_to_dot(
878 highlighted_nodes
=[best_node
],
879 node_scores
=debug_node_scores
,
880 edge_scores
=debug_edge_scores
))
881 node_stack
.append(best_node
)
882 best_node
.ignored
= True
884 # Coalesce (aka. do copy-merges):
885 did_any_copy_merges
= False
886 for node
in interference_graph
.nodes
.values():
889 for neighbor
, edge
in node
.edges
.items():
892 if edge
.copy_relation
is None:
895 score
= interference_graph
.local_colorability_score(
896 node
, merged_in_copy
=neighbor
)
897 except BadMergedSSAVal
:
899 if (neighbor
, node
) in debug_edge_scores
:
900 debug_edge_scores
[(neighbor
, node
)] = score
902 debug_edge_scores
[(node
, neighbor
)] = score
903 if score
> 0: # merged node is locally colorable
904 if dump_graph
is not None:
906 f
"step_{step}_copy_merge",
907 interference_graph
.dump_to_dot(
908 highlighted_nodes
=[node
, neighbor
],
909 node_scores
=debug_node_scores
,
910 edge_scores
=debug_edge_scores
))
911 if debug_out
is not None:
912 print(f
"\nCopy-merging:\n{node}\nwith:\n{neighbor}",
913 file=debug_out
, flush
=True)
914 merged_node
= interference_graph
.copy_merge(node
, neighbor
)
915 if dump_graph
is not None:
917 f
"step_{step}_copy_merge_result",
918 interference_graph
.dump_to_dot(
919 highlighted_nodes
=[merged_node
]))
920 if debug_out
is not None:
921 print(f
"merged_node:\n"
922 f
"{merged_node}", file=debug_out
, flush
=True)
923 did_any_copy_merges
= True
925 if did_any_copy_merges
:
927 if did_any_copy_merges
:
930 best_node
= find_best_node(has_copy_relation
=True)
931 if best_node
is not None:
932 if dump_graph
is not None:
933 dump_graph(f
"step_{step}_freeze",
934 interference_graph
.dump_to_dot(
935 highlighted_nodes
=[best_node
],
936 node_scores
=debug_node_scores
,
937 edge_scores
=debug_edge_scores
))
938 # no need to clear copy relations since best_node won't be
939 # considered since it's now ignored.
940 node_stack
.append(best_node
)
941 best_node
.ignored
= True
945 if dump_graph
is not None:
946 dump_graph("final", interference_graph
.dump_to_dot())
947 if debug_out
is not None:
948 print(f
"After deciding node allocation order:\n"
949 f
"{node_stack}", file=debug_out
, flush
=True)
951 retval
= {} # type: dict[SSAVal, Loc]
953 while len(node_stack
) > 0:
954 node
= node_stack
.pop()
955 if node
.loc
is not None:
956 if node
.loc_conflicts_with_neighbors(node
.loc
):
957 raise AllocationFailedError(
958 "IGNode is pre-allocated to a conflicting Loc",
959 node
=node
, interference_graph
=interference_graph
)
961 # Locs to try allocating, ordered from most preferred to least
964 # prefer eliminating copies
965 for neighbor
, edge
in node
.edges
.items():
966 if neighbor
.loc
is None or edge
.copy_relation
is None:
969 merged
= node
.copy_merge_preview(neighbor
)
970 except BadMergedSSAVal
:
972 # get merged_loc if merged.loc_set has a single Loc
973 merged_loc
= merged
.only_loc
974 if merged_loc
is None:
976 ssa_val
= node
.merged_ssa_val
.first_ssa_val
977 ssa_val_loc
= merged_loc
.get_subloc_at_offset(
978 subloc_ty
=ssa_val
.ty
,
979 offset
=merged
.ssa_val_offsets
[ssa_val
])
980 node_loc
= ssa_val_loc
.get_superloc_with_self_at_offset(
981 superloc_ty
=node
.merged_ssa_val
.ty
,
982 offset
=node
.merged_ssa_val
.ssa_val_offsets
[ssa_val
])
983 assert node_loc
in node
.merged_ssa_val
.loc_set
, "logic error"
985 # add node's allowed Locs as fallback
986 for loc
in node
.loc_set
:
987 # TODO: add in order of preference
989 # pick the first non-conflicting register in locs, since locs is
990 # ordered from most preferred to least preferred register.
992 if not node
.loc_conflicts_with_neighbors(loc
):
996 raise AllocationFailedError(
997 "failed to allocate Loc for IGNode",
998 node
=node
, interference_graph
=interference_graph
)
1000 if debug_out
is not None:
1001 print(f
"After allocating Loc for node:\n{node}",
1002 file=debug_out
, flush
=True)
1004 for ssa_val
, offset
in node
.merged_ssa_val
.ssa_val_offsets
.items():
1005 retval
[ssa_val
] = node
.loc
.get_subloc_at_offset(ssa_val
.ty
, offset
)
1007 if debug_out
is not None:
1008 print(f
"final Locs for all SSAVals:\n{retval}",
1009 file=debug_out
, flush
=True)