Register allocator for the Toom-Cook algorithm generator for SVP64.

This uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from itertools
import combinations
9 from typing
import Iterable
, Iterator
, Mapping
, TextIO
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import plain_data
14 from bigint_presentation_code
.compiler_ir
import (BaseTy
, Fn
, FnAnalysis
, Loc
,
15 LocSet
, ProgramRange
, SSAVal
,
17 from bigint_presentation_code
.type_util
import final
18 from bigint_presentation_code
.util
import FMap
, InternedMeta
, OFSet
, OSet
21 class BadMergedSSAVal(ValueError):
25 @plain_data(frozen
=True, repr=False)
27 class MergedSSAVal(metaclass
=InternedMeta
):
28 """a set of `SSAVal`s along with their offsets, all register allocated as
31 Definition of the term `offset` for this class:
33 Let `locs[x]` be the `Loc` that `x` is assigned to after register
34 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
35 for each `SSAVal` `ssa_val` in `msv` is defined as:
38 msv.ssa_val_offsets[ssa_val] = (msv.offset
39 + locs[ssa_val].start - locs[msv].start)
47 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
50 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
51 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
52 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
53 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
55 __slots__
= "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
57 def __init__(self
, fn_analysis
, ssa_val_offsets
):
58 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
59 self
.fn_analysis
= fn_analysis
60 if isinstance(ssa_val_offsets
, SSAVal
):
61 ssa_val_offsets
= {ssa_val_offsets
: 0}
62 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
64 for ssa_val
in self
.ssa_vals
:
65 first_ssa_val
= ssa_val
67 if first_ssa_val
is None:
68 raise BadMergedSSAVal("MergedSSAVal can't be empty")
69 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
70 # self.ty checks for mismatched base_ty
71 reg_len
= self
.ty
.reg_len
72 loc_set
= None # type: None | LocSet
73 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
75 # type: () -> Iterable[Loc]
76 for loc
in ssa_val
.def_loc_set_before_spread
:
77 disallowed_by_use
= False
78 for use
in fn_analysis
.uses
[ssa_val
]:
79 # calculate the start for the use's Loc before spread
80 # e.g. if the def's Loc before spread starts at r6
81 # and the def's reg_offset_in_unspread is 5
82 # and the use's reg_offset_in_unspread is 3
83 # then the use's Loc before spread starts at r8
84 # because 8 == 6 + 5 - 3
85 start
= (loc
.start
+ ssa_val
.reg_offset_in_unspread
86 - use
.reg_offset_in_unspread
)
87 use_loc
= Loc
.try_make(
88 loc
.kind
, start
=start
,
89 reg_len
=use
.ty_before_spread
.reg_len
)
90 if (use_loc
is None or
91 use_loc
not in use
.use_loc_set_before_spread
):
92 disallowed_by_use
= True
96 # FIXME: add spread consistency check
97 start
= loc
.start
- cur_offset
+ self
.offset
98 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
99 if loc
is not None and (loc_set
is None or loc
in loc_set
):
101 loc_set
= LocSet(locs())
102 assert loc_set
is not None, "already checked that self isn't empty"
103 if loc_set
.ty
is None:
104 raise BadMergedSSAVal("there are no valid Locs left")
105 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
106 self
.loc_set
= loc_set
# type: LocSet
111 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
120 return min(self
.ssa_val_offsets_before_spread
.values())
125 return self
.first_ssa_val
.base_ty
129 # type: () -> OFSet[SSAVal]
130 return OFSet(self
.ssa_val_offsets
.keys())
136 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
137 cur_ty
= ssa_val
.ty_before_spread
138 if self
.base_ty
!= cur_ty
.base_ty
:
139 raise BadMergedSSAVal(
140 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
141 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
142 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
145 def ssa_val_offsets_before_spread(self
):
146 # type: () -> FMap[SSAVal, int]
147 retval
= {} # type: dict[SSAVal, int]
148 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
150 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """return a new `MergedSSAVal` over the same `fn_analysis` whose
    `ssa_val_offsets` are this instance's offsets each shifted by `amount`.
    """
    shifted = {}  # type: dict[SSAVal, int]
    for ssa_val, offset in self.ssa_val_offsets.items():
        shifted[ssa_val] = offset + amount
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=shifted)
def normalized(self):
    # type: () -> MergedSSAVal
    """return a copy of this `MergedSSAVal` with every offset shifted by
    `-self.offset`."""
    shift = -self.offset
    return self.offset_by(shift)
162 def with_offset_to_match(self
, target
, additional_offset
=0):
163 # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
164 if isinstance(target
, MergedSSAVal
):
165 ssa_val_offsets
= target
.ssa_val_offsets
167 ssa_val_offsets
= {target
: 0}
168 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
169 if ssa_val
in ssa_val_offsets
:
170 return self
.offset_by(
171 ssa_val_offsets
[ssa_val
] + additional_offset
- offset
)
172 raise ValueError("can't change offset to match unrelated MergedSSAVal")
174 def merged(self
, *others
):
175 # type: (*MergedSSAVal) -> MergedSSAVal
176 retval
= dict(self
.ssa_val_offsets
)
178 if other
.fn_analysis
!= self
.fn_analysis
:
179 raise ValueError("fn_analysis mismatch")
180 for ssa_val
, offset
in other
.ssa_val_offsets
.items():
181 if ssa_val
in retval
and retval
[ssa_val
] != offset
:
182 raise BadMergedSSAVal(f
"offset mismatch for {ssa_val}: "
183 f
"{retval[ssa_val]} != {offset}")
184 retval
[ssa_val
] = offset
185 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
186 ssa_val_offsets
=retval
)
def live_interval(self):
    # type: () -> ProgramRange
    """the smallest `ProgramRange` that covers the live range of every
    `SSAVal` in this `MergedSSAVal` (per `fn_analysis.live_ranges`)."""
    live_ranges = self.fn_analysis.live_ranges
    seed = live_ranges[self.first_ssa_val]
    lo, hi = seed.start, seed.stop
    for ssa_val in self.ssa_vals:
        cur = live_ranges[ssa_val]
        # widen the interval to include this SSAVal's live range
        lo = min(lo, cur.start)
        hi = max(hi, cur.stop)
    return ProgramRange(start=lo, stop=hi)
201 return (f
"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
202 f
"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
203 f
"live_interval={self.live_interval})")
207 class SSAValToMergedSSAValMap(Mapping
[SSAVal
, MergedSSAVal
]):
209 # type: (...) -> None
210 self
.__map
= {} # type: dict[SSAVal, MergedSSAVal]
211 self
.__ig
_node
_map
= MergedSSAValToIGNodeMap(
212 _private_merged_ssa_val_map
=self
.__map
)
214 def __getitem__(self
, __key
):
215 # type: (SSAVal) -> MergedSSAVal
216 return self
.__map
[__key
]
219 # type: () -> Iterator[SSAVal]
220 return iter(self
.__map
)
224 return len(self
.__map
)
227 def ig_node_map(self
):
228 # type: () -> MergedSSAValToIGNodeMap
229 return self
.__ig
_node
_map
233 s
= ",\n".join(repr(v
) for v
in self
.__ig
_node
_map
)
234 return f
"SSAValToMergedSSAValMap({{{s}}})"
238 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, "IGNode"]):
241 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
243 # type: (...) -> None
244 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
245 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
247 def __getitem__(self
, __key
):
248 # type: (MergedSSAVal) -> IGNode
249 return self
.__map
[__key
]
252 # type: () -> Iterator[MergedSSAVal]
253 return iter(self
.__map
)
257 return len(self
.__map
)
259 def add_node(self
, merged_ssa_val
):
260 # type: (MergedSSAVal) -> IGNode
261 node
= self
.__map
.get(merged_ssa_val
, None)
264 added
= 0 # type: int | None
266 for ssa_val
in merged_ssa_val
.ssa_vals
:
267 if ssa_val
in self
.__merged
_ssa
_val
_map
:
269 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
270 f
"{merged_ssa_val} and "
271 f
"{self.__merged_ssa_val_map[ssa_val]}")
272 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
274 retval
= IGNode(merged_ssa_val
=merged_ssa_val
, edges
=(), loc
=None)
275 self
.__map
[merged_ssa_val
] = retval
279 if added
is not None:
280 # remove partially added stuff
281 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
284 del self
.__merged
_ssa
_val
_map
[ssa_val
]
286 def merge_into_one_node(self
, final_merged_ssa_val
):
287 # type: (MergedSSAVal) -> IGNode
288 source_nodes
= OSet() # type: OSet[IGNode]
289 edges
= OSet() # type: OSet[IGNode]
290 loc
= None # type: Loc | None
291 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
292 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
293 source_node
= self
.__map
[merged_ssa_val
]
294 source_nodes
.add(source_node
)
295 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
297 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
298 f
"but not in merged IGNode's merged_ssa_val: "
299 f
"source_node={source_node} "
300 f
"final_merged_ssa_val={final_merged_ssa_val}")
302 loc
= source_node
.loc
303 elif source_node
.loc
is not None and loc
!= source_node
.loc
:
304 raise ValueError(f
"can't merge IGNodes with mismatched `loc` "
305 f
"values: {loc} != {source_node.loc}")
306 edges |
= source_node
.edges
307 if len(source_nodes
) == 1:
308 return source_nodes
.pop() # merging a single node is a no-op
309 # we're finished checking validity, now we can modify stuff
310 edges
-= source_nodes
311 retval
= IGNode(merged_ssa_val
=final_merged_ssa_val
, edges
=edges
,
314 node
.edges
-= source_nodes
315 node
.edges
.add(retval
)
316 for node
in source_nodes
:
317 del self
.__map
[node
.merged_ssa_val
]
318 self
.__map
[final_merged_ssa_val
] = retval
319 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
320 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
323 def __repr__(self
, repr_state
=None):
324 # type: (None | IGNodeReprState) -> str
325 if repr_state
is None:
326 repr_state
= IGNodeReprState()
327 s
= ",\n".join(v
.__repr
__(repr_state
) for v
in self
.__map
.values())
328 return f
"MergedSSAValToIGNodeMap({{{s}}})"
331 @plain_data(frozen
=True, repr=False)
333 class InterferenceGraph
:
334 __slots__
= "fn_analysis", "merged_ssa_val_map", "nodes"
def __init__(self, fn_analysis, merged_ssa_vals):
    # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
    """create an interference graph for `fn_analysis`, registering one node
    per entry of `merged_ssa_vals` via `add_node`."""
    self.fn_analysis = fn_analysis
    ssa_val_map = SSAValToMergedSSAValMap()
    self.merged_ssa_val_map = ssa_val_map
    # `nodes` is the IGNode map owned by `merged_ssa_val_map`; both are
    # backed by the same SSAVal -> MergedSSAVal dict.
    ig_nodes = ssa_val_map.ig_node_map
    self.nodes = ig_nodes
    for merged_ssa_val in merged_ssa_vals:
        ig_nodes.add_node(merged_ssa_val)
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
    # type: (SSAVal, SSAVal, int) -> IGNode
    """merge the interference-graph nodes holding `ssa_val1` and `ssa_val2`
    into a single node.

    `ssa_val1`'s `MergedSSAVal` is shifted so `ssa_val1` sits at offset 0,
    and `ssa_val2`'s is shifted so `ssa_val2` sits at `additional_offset`.
    The two are then combined and handed to `self.nodes.merge_into_one_node`.
    """
    by_ssa_val = self.merged_ssa_val_map
    merged1, merged2 = by_ssa_val[ssa_val1], by_ssa_val[ssa_val2]
    aligned1 = merged1.with_offset_to_match(ssa_val1)
    aligned2 = merged2.with_offset_to_match(
        ssa_val2, additional_offset=additional_offset)
    combined = aligned1.merged(aligned2)
    return self.nodes.merge_into_one_node(combined)
354 def minimally_merged(fn_analysis
):
355 # type: (FnAnalysis) -> InterferenceGraph
356 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
357 for op
in fn_analysis
.fn
.ops
:
358 for inp
in op
.input_uses
:
359 if inp
.unspread_start
!= inp
:
360 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
361 additional_offset
=inp
.reg_offset_in_unspread
)
362 for out
in op
.outputs
:
363 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
364 if out
.unspread_start
!= out
:
365 retval
.merge(out
.unspread_start
, out
,
366 additional_offset
=out
.reg_offset_in_unspread
)
367 if out
.tied_input
is not None:
368 retval
.merge(out
.tied_input
.ssa_val
, out
)
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """abbreviated repr that shows only `nodes`; `repr_state` is passed
    through so node ids stay consistent across one printing session."""
    state = IGNodeReprState() if repr_state is None else repr_state
    nodes_repr = self.nodes.__repr__(state)
    return f"InterferenceGraph(nodes={nodes_repr}, <...>)"
379 @plain_data(repr=False)
380 class IGNodeReprState
:
381 __slots__
= "node_ids", "did_full_repr"
385 self
.node_ids
= {} # type: dict[IGNode, int]
386 self
.did_full_repr
= OSet() # type: OSet[IGNode]
391 """ interference graph node """
392 __slots__
= "merged_ssa_val", "edges", "loc"
394 def __init__(self
, merged_ssa_val
, edges
, loc
):
395 # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
396 self
.merged_ssa_val
= merged_ssa_val
397 self
.edges
= OSet(edges
)
def add_edge(self, other):
    # type: (IGNode) -> None
    """record an undirected interference edge: `other` is added to
    `self.edges`, then `self` is added to `other.edges`."""
    for src, dst in ((self, other), (other, self)):
        src.edges.add(dst)
def __eq__(self, other):
    # type: (object) -> bool
    """two `IGNode`s are equal when their `merged_ssa_val`s are equal;
    comparison with any other type is deferred via `NotImplemented`."""
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_ssa_val == other.merged_ssa_val
413 return hash(self
.merged_ssa_val
)
415 def __repr__(self
, repr_state
=None, short
=False):
416 # type: (None | IGNodeReprState, bool) -> str
417 if repr_state
is None:
418 repr_state
= IGNodeReprState()
419 node_id
= repr_state
.node_ids
.get(self
, None)
421 repr_state
.node_ids
[self
] = node_id
= len(repr_state
.node_ids
)
422 if short
or self
in repr_state
.did_full_repr
:
423 return f
"<IGNode #{node_id}>"
424 repr_state
.did_full_repr
.add(self
)
425 edges
= ", ".join(i
.__repr
__(repr_state
, True) for i
in self
.edges
)
426 return (f
"IGNode(#{node_id}, "
427 f
"merged_ssa_val={self.merged_ssa_val}, "
428 f
"edges={{{edges}}}, "
434 return self
.merged_ssa_val
.loc_set
436 def loc_conflicts_with_neighbors(self
, loc
):
437 # type: (Loc) -> bool
438 for neighbor
in self
.edges
:
439 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
444 class AllocationFailedError(Exception):
445 def __init__(self
, msg
, node
, interference_graph
):
446 # type: (str, IGNode, InterferenceGraph) -> None
447 super().__init
__(msg
, node
, interference_graph
)
449 self
.interference_graph
= interference_graph
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """debugging repr showing the message, the `node` this error refers to
    (in short form), and the whole interference graph.

    A single `repr_state` is shared by the node and graph reprs so the node
    ids printed for both agree.
    """
    if repr_state is None:
        repr_state = IGNodeReprState()
    # Use the dynamic type so subclasses report their own name; the implicit
    # `__class__` cell always names the class the method was defined in.
    class_name = type(self).__name__
    return (f"{class_name}({self.args[0]!r}, "
            f"node={self.node.__repr__(repr_state, True)}, "
            f"interference_graph="
            f"{self.interference_graph.__repr__(repr_state)})")
462 return self
.__repr
__()
465 def allocate_registers(fn
, debug_out
=None):
466 # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
468 # inserts enough copies that no manual spilling is necessary, all
469 # spilling is done by the register allocator naturally allocating SSAVals
471 fn
.pre_ra_insert_copies()
473 if debug_out
is not None:
474 print(f
"After pre_ra_insert_copies():\n{fn.ops}",
475 file=debug_out
, flush
=True)
477 fn_analysis
= FnAnalysis(fn
)
478 interference_graph
= InterferenceGraph
.minimally_merged(fn_analysis
)
480 if debug_out
is not None:
481 print(f
"After InterferenceGraph.minimally_merged():\n"
482 f
"{interference_graph}", file=debug_out
, flush
=True)
484 for pp
, ssa_vals
in fn_analysis
.live_at
.items():
485 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
486 for ssa_val
in ssa_vals
:
487 live_merged_ssa_vals
.add(
488 interference_graph
.merged_ssa_val_map
[ssa_val
])
489 for i
, j
in combinations(live_merged_ssa_vals
, 2):
490 if i
.loc_set
.max_conflicts_with(j
.loc_set
) != 0:
491 interference_graph
.nodes
[i
].add_edge(
492 interference_graph
.nodes
[j
])
493 if debug_out
is not None:
494 print(f
"processed {pp} out of {fn_analysis.all_program_points}",
495 file=debug_out
, flush
=True)
497 if debug_out
is not None:
498 print(f
"After adding interference graph edges:\n"
499 f
"{interference_graph}", file=debug_out
, flush
=True)
501 nodes_remaining
= OSet(interference_graph
.nodes
.values())
503 local_colorability_score_cache
= {} # type: dict[IGNode, int]
505 def local_colorability_score(node
):
506 # type: (IGNode) -> int
507 """ returns a positive integer if node is locally colorable, returns
508 zero or a negative integer if node isn't known to be locally
509 colorable, the more negative the value, the less colorable
511 if node
not in nodes_remaining
:
513 retval
= local_colorability_score_cache
.get(node
, None)
514 if retval
is not None:
516 retval
= len(node
.loc_set
)
517 for neighbor
in node
.edges
:
518 if neighbor
in nodes_remaining
:
519 retval
-= node
.loc_set
.max_conflicts_with(neighbor
.loc_set
)
520 local_colorability_score_cache
[node
] = retval
523 # TODO: implement copy-merging
525 node_stack
= [] # type: list[IGNode]
527 best_node
= None # type: None | IGNode
529 for node
in nodes_remaining
:
530 score
= local_colorability_score(node
)
531 if best_node
is None or score
> best_score
:
535 # it's locally colorable, no need to find a better one
538 if best_node
is None:
540 node_stack
.append(best_node
)
541 nodes_remaining
.remove(best_node
)
542 local_colorability_score_cache
.pop(best_node
, None)
543 for neighbor
in best_node
.edges
:
544 local_colorability_score_cache
.pop(neighbor
, None)
546 if debug_out
is not None:
547 print(f
"After deciding node allocation order:\n"
548 f
"{node_stack}", file=debug_out
, flush
=True)
550 retval
= {} # type: dict[SSAVal, Loc]
552 while len(node_stack
) > 0:
553 node
= node_stack
.pop()
554 if node
.loc
is not None:
555 if node
.loc_conflicts_with_neighbors(node
.loc
):
556 raise AllocationFailedError(
557 "IGNode is pre-allocated to a conflicting Loc",
558 node
=node
, interference_graph
=interference_graph
)
560 # pick the first non-conflicting register in node.reg_class, since
561 # register classes are ordered from most preferred to least
562 # preferred register.
563 for loc
in node
.loc_set
:
564 if not node
.loc_conflicts_with_neighbors(loc
):
568 raise AllocationFailedError(
569 "failed to allocate Loc for IGNode",
570 node
=node
, interference_graph
=interference_graph
)
572 if debug_out
is not None:
573 print(f
"After allocating Loc for node:\n{node}",
574 file=debug_out
, flush
=True)
576 for ssa_val
, offset
in node
.merged_ssa_val
.ssa_val_offsets
.items():
577 retval
[ssa_val
] = node
.loc
.get_subloc_at_offset(ssa_val
.ty
, offset
)
579 if debug_out
is not None:
580 print(f
"final Locs for all SSAVals:\n{retval}",
581 file=debug_out
, flush
=True)