2 Register Allocator for Toom-Cook algorithm generator for SVP64
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from itertools
import combinations
9 from typing
import Iterable
, Iterator
, Mapping
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import plain_data
14 from bigint_presentation_code
.compiler_ir2
import (BaseTy
, Fn
, FnAnalysis
, Loc
,
17 from bigint_presentation_code
.type_util
import final
18 from bigint_presentation_code
.util
import FMap
, InternedMeta
, OFSet
, OSet
21 class BadMergedSSAVal(ValueError):
25 @plain_data(frozen
=True, repr=False)
27 class MergedSSAVal(metaclass
=InternedMeta
):
28 """a set of `SSAVal`s along with their offsets, all register allocated as
31 Definition of the term `offset` for this class:
33 Let `locs[x]` be the `Loc` that `x` is assigned to after register
34 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
35 for each `SSAVal` `ssa_val` in `msv` is defined as:
38 msv.ssa_val_offsets[ssa_val] = (msv.offset
39 + locs[ssa_val].start - locs[msv].start)
47 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
50 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
51 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
52 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
53 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
55 __slots__
= "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
57 def __init__(self
, fn_analysis
, ssa_val_offsets
):
58 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
59 self
.fn_analysis
= fn_analysis
60 if isinstance(ssa_val_offsets
, SSAVal
):
61 ssa_val_offsets
= {ssa_val_offsets
: 0}
62 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
64 for ssa_val
in self
.ssa_vals
:
65 first_ssa_val
= ssa_val
67 if first_ssa_val
is None:
68 raise BadMergedSSAVal("MergedSSAVal can't be empty")
69 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
70 # self.ty checks for mismatched base_ty
71 reg_len
= self
.ty
.reg_len
72 loc_set
= None # type: None | LocSet
73 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
74 def_spread_idx
= ssa_val
.defining_descriptor
.spread_index
or 0
77 # type: () -> Iterable[Loc]
78 for loc
in ssa_val
.def_loc_set_before_spread
:
79 disallowed_by_use
= False
80 for use
in fn_analysis
.uses
[ssa_val
]:
82 use
.defining_descriptor
.spread_index
or 0
83 # calculate the start for the use's Loc before spread
84 # e.g. if the def's Loc before spread starts at r6
85 # and the def's spread_index is 5
86 # and the use's spread_index is 3
87 # then the use's Loc before spread starts at r8
88 # because 8 == 6 + 5 - 3
89 start
= loc
.start
+ def_spread_idx
- use_spread_idx
90 use_loc
= Loc
.try_make(
91 loc
.kind
, start
=start
,
92 reg_len
=use
.ty_before_spread
.reg_len
)
93 if (use_loc
is None or
94 use_loc
not in use
.use_loc_set_before_spread
):
95 disallowed_by_use
= True
99 # FIXME: add spread consistency check
100 start
= loc
.start
- cur_offset
+ self
.offset
101 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
102 if loc
is not None and (loc_set
is None or loc
in loc_set
):
104 loc_set
= LocSet(locs())
105 assert loc_set
is not None, "already checked that self isn't empty"
106 if loc_set
.ty
is None:
107 raise BadMergedSSAVal("there are no valid Locs left")
108 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
109 self
.loc_set
= loc_set
# type: LocSet
114 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
123 return min(self
.ssa_val_offsets_before_spread
.values())
128 return self
.first_ssa_val
.base_ty
132 # type: () -> OFSet[SSAVal]
133 return OFSet(self
.ssa_val_offsets
.keys())
139 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
140 cur_ty
= ssa_val
.ty_before_spread
141 if self
.base_ty
!= cur_ty
.base_ty
:
142 raise BadMergedSSAVal(
143 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
144 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
145 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
148 def ssa_val_offsets_before_spread(self
):
149 # type: () -> FMap[SSAVal, int]
150 retval
= {} # type: dict[SSAVal, int]
151 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
153 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """ return a `MergedSSAVal` for the same `fn_analysis` where every
    entry of `ssa_val_offsets` has had `amount` added to it.

    `amount` may be negative; `self` is not modified.
    """
    shifted = {}  # type: dict[SSAVal, int]
    for ssa_val, offset in self.ssa_val_offsets.items():
        shifted[ssa_val] = offset + amount
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=shifted)
def normalized(self):
    # type: () -> MergedSSAVal
    """ return `self` shifted by the negation of `self.offset`, i.e. the
    result of `self.offset_by(-self.offset)`.
    """
    shift = -self.offset
    return self.offset_by(shift)
def with_offset_to_match(self, target, additional_offset=0):
    # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
    """ return `self` shifted so that an `SSAVal` shared with `target` ends
    up at `target`'s offset for it plus `additional_offset`.

    `target` is either a `MergedSSAVal`, whose `ssa_val_offsets` are used,
    or a single `SSAVal`, which is treated as having offset 0.
    The first `SSAVal` of `self` (in `ssa_val_offsets` iteration order) that
    also appears in `target` decides the shift.

    raises `ValueError` if `self` and `target` share no `SSAVal`.
    """
    if isinstance(target, MergedSSAVal):
        reference_offsets = target.ssa_val_offsets
    else:
        reference_offsets = {target: 0}
    for ssa_val, own_offset in self.ssa_val_offsets.items():
        if ssa_val not in reference_offsets:
            continue
        wanted_offset = reference_offsets[ssa_val] + additional_offset
        return self.offset_by(wanted_offset - own_offset)
    raise ValueError("can't change offset to match unrelated MergedSSAVal")
def merged(self, *others):
    # type: (*MergedSSAVal) -> MergedSSAVal
    """ return a `MergedSSAVal` containing the `SSAVal`s of `self` and of
    all `others`, each at the offset it already has.

    raises `ValueError` if any of `others` has a different `fn_analysis`,
    and `BadMergedSSAVal` if an `SSAVal` shows up with two different
    offsets.
    """
    combined = dict(self.ssa_val_offsets)
    for other in others:
        if other.fn_analysis != self.fn_analysis:
            raise ValueError("fn_analysis mismatch")
        for ssa_val, offset in other.ssa_val_offsets.items():
            if ssa_val in combined:
                previous = combined[ssa_val]
                if previous != offset:
                    raise BadMergedSSAVal(
                        f"offset mismatch for {ssa_val}: "
                        f"{previous} != {offset}")
            combined[ssa_val] = offset
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=combined)
def live_interval(self):
    # type: () -> ProgramRange
    """ return the `ProgramRange` running from the smallest `start` to the
    largest `stop` of the `fn_analysis.live_ranges` entries of
    `first_ssa_val` and of every `SSAVal` in `ssa_vals`.
    """
    seed = self.fn_analysis.live_ranges[self.first_ssa_val]
    earliest = seed.start
    latest = seed.stop
    for member in self.ssa_vals:
        member_range = self.fn_analysis.live_ranges[member]
        member_start = member_range.start
        if member_start < earliest:
            earliest = member_start
        member_stop = member_range.stop
        if member_stop > latest:
            latest = member_stop
    return ProgramRange(start=earliest, stop=latest)
204 return (f
"MergedSSAVal({self.fn_analysis}, "
205 f
"ssa_val_offsets={self.ssa_val_offsets})")
209 class SSAValToMergedSSAValMap(Mapping
[SSAVal
, MergedSSAVal
]):
211 # type: (...) -> None
212 self
.__map
= {} # type: dict[SSAVal, MergedSSAVal]
213 self
.__ig
_node
_map
= MergedSSAValToIGNodeMap(
214 _private_merged_ssa_val_map
=self
.__map
)
def __getitem__(self, __key):
    # type: (SSAVal) -> MergedSSAVal
    """ look up the `MergedSSAVal` that `__key` is currently recorded
    under; a missing key propagates the underlying dict's `KeyError`.
    """
    merged_ssa_val = self.__map[__key]
    return merged_ssa_val
221 # type: () -> Iterator[SSAVal]
222 return iter(self
.__map
)
226 return len(self
.__map
)
def ig_node_map(self):
    # type: () -> MergedSSAValToIGNodeMap
    """ the `MergedSSAValToIGNodeMap` stored on this map. """
    node_map = self.__ig_node_map
    return node_map
235 s
= ",\n".join(repr(v
) for v
in self
.__ig
_node
_map
)
236 return f
"SSAValToMergedSSAValMap({{{s}}})"
240 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, "IGNode"]):
243 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
245 # type: (...) -> None
246 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
247 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
def __getitem__(self, __key):
    # type: (MergedSSAVal) -> IGNode
    """ return the `IGNode` stored for the `MergedSSAVal` `__key`; a
    missing key propagates the underlying dict's `KeyError`.
    """
    nodes_by_merged_ssa_val = self.__map
    return nodes_by_merged_ssa_val[__key]
254 # type: () -> Iterator[MergedSSAVal]
255 return iter(self
.__map
)
259 return len(self
.__map
)
261 def add_node(self
, merged_ssa_val
):
262 # type: (MergedSSAVal) -> IGNode
263 node
= self
.__map
.get(merged_ssa_val
, None)
266 added
= 0 # type: int | None
268 for ssa_val
in merged_ssa_val
.ssa_vals
:
269 if ssa_val
in self
.__merged
_ssa
_val
_map
:
271 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
272 f
"{merged_ssa_val} and "
273 f
"{self.__merged_ssa_val_map[ssa_val]}")
274 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
276 retval
= IGNode(merged_ssa_val
=merged_ssa_val
, edges
=(), loc
=None)
277 self
.__map
[merged_ssa_val
] = retval
281 if added
is not None:
282 # remove partially added stuff
283 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
286 del self
.__merged
_ssa
_val
_map
[ssa_val
]
288 def merge_into_one_node(self
, final_merged_ssa_val
):
289 # type: (MergedSSAVal) -> IGNode
290 source_nodes
= OSet() # type: OSet[IGNode]
291 edges
= OSet() # type: OSet[IGNode]
292 loc
= None # type: Loc | None
293 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
294 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
295 source_node
= self
.__map
[merged_ssa_val
]
296 source_nodes
.add(source_node
)
297 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
299 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
300 f
"but not in merged IGNode's merged_ssa_val: "
301 f
"source_node={source_node} "
302 f
"final_merged_ssa_val={final_merged_ssa_val}")
304 loc
= source_node
.loc
305 elif source_node
.loc
is not None and loc
!= source_node
.loc
:
306 raise ValueError(f
"can't merge IGNodes with mismatched `loc` "
307 f
"values: {loc} != {source_node.loc}")
308 edges |
= source_node
.edges
309 if len(source_nodes
) == 1:
310 return source_nodes
.pop() # merging a single node is a no-op
311 # we're finished checking validity, now we can modify stuff
312 edges
-= source_nodes
313 retval
= IGNode(merged_ssa_val
=final_merged_ssa_val
, edges
=edges
,
316 node
.edges
-= source_nodes
317 node
.edges
.add(retval
)
318 for node
in source_nodes
:
319 del self
.__map
[node
.merged_ssa_val
]
320 self
.__map
[final_merged_ssa_val
] = retval
321 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
322 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """ return a textual form listing every stored node, one per line.

    `repr_state` is handed to each node's `__repr__` so that all nodes in
    one output share it; a fresh `IGNodeReprState` is created when it is
    `None`.
    """
    state = IGNodeReprState() if repr_state is None else repr_state
    node_reprs = [node.__repr__(state) for node in self.__map.values()]
    s = ",\n".join(node_reprs)
    return f"MergedSSAValToIGNodeMap({{{s}}})"
333 @plain_data(frozen
=True, repr=False)
335 class InterferenceGraph
:
336 __slots__
= "fn_analysis", "merged_ssa_val_map", "nodes"
def __init__(self, fn_analysis, merged_ssa_vals):
    # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
    """ create a graph with one node per entry of `merged_ssa_vals`.

    `nodes` is the `ig_node_map` of the freshly created
    `merged_ssa_val_map`; every `MergedSSAVal` is registered through
    `nodes.add_node`, in iteration order.
    """
    self.fn_analysis = fn_analysis
    ssa_val_map = SSAValToMergedSSAValMap()
    self.merged_ssa_val_map = ssa_val_map
    self.nodes = self.merged_ssa_val_map.ig_node_map
    for merged_ssa_val in merged_ssa_vals:
        self.nodes.add_node(merged_ssa_val)
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
    # type: (SSAVal, SSAVal, int) -> IGNode
    """ merge the `MergedSSAVal`s containing `ssa_val1` and `ssa_val2` into
    a single node and return that node.

    The first one is re-offset to match `ssa_val1`, the second to match
    `ssa_val2` plus `additional_offset`; the two results are combined with
    `MergedSSAVal.merged` and handed to `nodes.merge_into_one_node`.
    """
    lhs_group = self.merged_ssa_val_map[ssa_val1]
    rhs_group = self.merged_ssa_val_map[ssa_val2]
    lhs_aligned = lhs_group.with_offset_to_match(ssa_val1)
    rhs_aligned = rhs_group.with_offset_to_match(
        ssa_val2, additional_offset=additional_offset)
    combined = lhs_aligned.merged(rhs_aligned)
    return self.nodes.merge_into_one_node(combined)
356 def minimally_merged(fn_analysis
):
357 # type: (FnAnalysis) -> InterferenceGraph
358 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
359 for op
in fn_analysis
.fn
.ops
:
360 for inp
in op
.input_uses
:
361 if inp
.unspread_start
!= inp
:
362 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
363 additional_offset
=inp
.reg_offset_in_unspread
)
364 for out
in op
.outputs
:
365 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
366 if out
.unspread_start
!= out
:
367 retval
.merge(out
.unspread_start
, out
,
368 additional_offset
=out
.reg_offset_in_unspread
)
369 if out
.tied_input
is not None:
370 retval
.merge(out
.tied_input
.ssa_val
, out
)
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """ return a textual form of the graph showing only `nodes`.

    `repr_state` is forwarded to `nodes.__repr__`; a fresh
    `IGNodeReprState` is created when it is `None`.
    """
    state = IGNodeReprState() if repr_state is None else repr_state
    s = self.nodes.__repr__(state)
    return f"InterferenceGraph(nodes={s}, <...>)"
381 @plain_data(repr=False)
382 class IGNodeReprState
:
383 __slots__
= "node_ids", "did_full_repr"
387 self
.node_ids
= {} # type: dict[IGNode, int]
388 self
.did_full_repr
= OSet() # type: OSet[IGNode]
393 """ interference graph node """
394 __slots__
= "merged_ssa_val", "edges", "loc"
396 def __init__(self
, merged_ssa_val
, edges
, loc
):
397 # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
398 self
.merged_ssa_val
= merged_ssa_val
399 self
.edges
= OSet(edges
)
def add_edge(self, other):
    # type: (IGNode) -> None
    """ record an undirected edge: `other` is added to `self.edges` and
    `self` is added to `other.edges`.
    """
    for source, destination in ((self, other), (other, self)):
        source.edges.add(destination)
def __eq__(self, other):
    # type: (object) -> bool
    """ two `IGNode`s are equal when their `merged_ssa_val`s are equal;
    `edges` and `loc` are not compared.

    returns `NotImplemented` for anything that isn't an `IGNode`.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_ssa_val == other.merged_ssa_val
415 return hash(self
.merged_ssa_val
)
417 def __repr__(self
, repr_state
=None, short
=False):
418 # type: (None | IGNodeReprState, bool) -> str
419 if repr_state
is None:
420 repr_state
= IGNodeReprState()
421 node_id
= repr_state
.node_ids
.get(self
, None)
423 repr_state
.node_ids
[self
] = node_id
= len(repr_state
.node_ids
)
424 if short
or self
in repr_state
.did_full_repr
:
425 return f
"<IGNode #{node_id}>"
426 repr_state
.did_full_repr
.add(self
)
427 edges
= ", ".join(i
.__repr
__(repr_state
, True) for i
in self
.edges
)
428 return (f
"IGNode(#{node_id}, "
429 f
"merged_ssa_val={self.merged_ssa_val}, "
430 f
"edges={{{edges}}}, "
436 return self
.merged_ssa_val
.loc_set
438 def loc_conflicts_with_neighbors(self
, loc
):
439 # type: (Loc) -> bool
440 for neighbor
in self
.edges
:
441 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
446 class AllocationFailedError(Exception):
447 def __init__(self
, msg
, node
, interference_graph
):
448 # type: (str, IGNode, InterferenceGraph) -> None
449 super().__init
__(msg
, node
, interference_graph
)
451 self
.interference_graph
= interference_graph
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """ return a textual form holding the message (`args[0]`), the short
    form of `node`, and the full `interference_graph`.

    The same `repr_state` is used for the node and for the graph; a fresh
    `IGNodeReprState` is created when `repr_state` is `None`.
    """
    state = IGNodeReprState() if repr_state is None else repr_state
    class_name = __class__.__name__
    message = repr(self.args[0])
    # the node is rendered before the graph, both with the shared state
    node_text = self.node.__repr__(state, True)
    graph_text = self.interference_graph.__repr__(state)
    return (f"{class_name}({message}, "
            f"node={node_text}, "
            f"interference_graph="
            f"{graph_text})")
464 return self
.__repr
__()
467 def allocate_registers(fn
):
468 # type: (Fn) -> dict[SSAVal, Loc]
470 # inserts enough copies that no manual spilling is necessary, all
471 # spilling is done by the register allocator naturally allocating SSAVals
473 fn
.pre_ra_insert_copies()
475 fn_analysis
= FnAnalysis(fn
)
476 interference_graph
= InterferenceGraph
.minimally_merged(fn_analysis
)
478 for ssa_vals
in fn_analysis
.live_at
.values():
479 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
480 for ssa_val
in ssa_vals
:
481 live_merged_ssa_vals
.add(
482 interference_graph
.merged_ssa_val_map
[ssa_val
])
483 for i
, j
in combinations(live_merged_ssa_vals
, 2):
484 if i
.loc_set
.max_conflicts_with(j
.loc_set
) != 0:
485 interference_graph
.nodes
[i
].add_edge(
486 interference_graph
.nodes
[j
])
488 nodes_remaining
= OSet(interference_graph
.nodes
.values())
490 def local_colorability_score(node
):
491 # type: (IGNode) -> int
492 """ returns a positive integer if node is locally colorable, returns
493 zero or a negative integer if node isn't known to be locally
494 colorable, the more negative the value, the less colorable
496 if node
not in nodes_remaining
:
498 retval
= len(node
.loc_set
)
499 for neighbor
in node
.edges
:
500 if neighbor
in nodes_remaining
:
501 retval
-= node
.loc_set
.max_conflicts_with(neighbor
.loc_set
)
504 # TODO: implement copy-merging
506 node_stack
= [] # type: list[IGNode]
508 best_node
= None # type: None | IGNode
510 for node
in nodes_remaining
:
511 score
= local_colorability_score(node
)
512 if best_node
is None or score
> best_score
:
516 # it's locally colorable, no need to find a better one
519 if best_node
is None:
521 node_stack
.append(best_node
)
522 nodes_remaining
.remove(best_node
)
524 retval
= {} # type: dict[SSAVal, Loc]
526 while len(node_stack
) > 0:
527 node
= node_stack
.pop()
528 if node
.loc
is not None:
529 if node
.loc_conflicts_with_neighbors(node
.loc
):
530 raise AllocationFailedError(
531 "IGNode is pre-allocated to a conflicting Loc",
532 node
=node
, interference_graph
=interference_graph
)
# pick the first non-conflicting Loc in node.loc_set; this assumes a
# LocSet iterates from the most preferred to the least preferred Loc
# (TODO confirm against LocSet).
537 for loc
in node
.loc_set
:
538 if not node
.loc_conflicts_with_neighbors(loc
):
542 raise AllocationFailedError(
543 "failed to allocate Loc for IGNode",
544 node
=node
, interference_graph
=interference_graph
)
546 for ssa_val
, offset
in node
.merged_ssa_val
.ssa_val_offsets
.items():
547 retval
[ssa_val
] = node
.loc
.get_subloc_at_offset(ssa_val
.ty
, offset
)