"""
Register Allocator for Toom-Cook algorithm generator for SVP64

This uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
8 from itertools
import combinations
9 from typing
import Iterable
, Iterator
, Mapping
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import plain_data
14 from bigint_presentation_code
.compiler_ir2
import (BaseTy
, Fn
, FnAnalysis
, Loc
,
17 from bigint_presentation_code
.type_util
import final
18 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
21 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
23 __slots__
= "first_write", "last_use"
25 def __init__(self
, first_write
, last_use
=None):
26 # type: (int, int | None) -> None
29 last_use
= first_write
30 if last_use
< first_write
:
31 raise ValueError("uses must be after first_write")
32 if first_write
< 0 or last_use
< 0:
33 raise ValueError("indexes must be nonnegative")
34 self
.first_write
= first_write
35 self
.last_use
= last_use
37 def overlaps(self
, other
):
38 # type: (LiveInterval) -> bool
39 if self
.first_write
== other
.first_write
:
41 return self
.last_use
> other
.first_write \
42 and other
.last_use
> self
.first_write
def __add__(self, use):
    # type: (int) -> LiveInterval
    """Return a new `LiveInterval` that also covers a use at op index `use`.

    `first_write` is carried over unchanged; `last_use` becomes the larger of
    the current `last_use` and `use`, so adding an earlier use never shrinks
    the interval.
    """
    last_use = max(self.last_use, use)
    return LiveInterval(first_write=self.first_write, last_use=last_use)
50 def live_after_op_range(self
):
51 """the range of op indexes where self is live immediately after the
54 return range(self
.first_write
, self
.last_use
)
class BadMergedSSAVal(ValueError):
    """Raised when a group of `SSAVal`s can't form a valid `MergedSSAVal`.

    Within this file it is raised for an empty group, a `BaseTy` mismatch,
    conflicting offsets for the same `SSAVal`, or when no valid `Loc` is left.
    """
61 @plain_data(frozen
=True, repr=False)
64 """a set of `SSAVal`s along with their offsets, all register allocated as
67 Definition of the term `offset` for this class:
69 Let `locs[x]` be the `Loc` that `x` is assigned to after register
70 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
71 for each `SSAVal` `ssa_val` in `msv` is defined as:
74 msv.ssa_val_offsets[ssa_val] = (msv.offset
75 + locs[ssa_val].start - locs[msv].start)
83 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
86 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
87 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
88 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
89 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
91 __slots__
= "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
93 def __init__(self
, fn_analysis
, ssa_val_offsets
):
94 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
95 self
.fn_analysis
= fn_analysis
96 if isinstance(ssa_val_offsets
, SSAVal
):
97 ssa_val_offsets
= {ssa_val_offsets
: 0}
98 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
100 for ssa_val
in self
.ssa_vals
:
101 first_ssa_val
= ssa_val
103 if first_ssa_val
is None:
104 raise BadMergedSSAVal("MergedSSAVal can't be empty")
105 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
106 # self.ty checks for mismatched base_ty
107 reg_len
= self
.ty
.reg_len
108 loc_set
= None # type: None | LocSet
109 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
110 def_spread_idx
= ssa_val
.defining_descriptor
.spread_index
or 0
113 # type: () -> Iterable[Loc]
114 for loc
in ssa_val
.def_loc_set_before_spread
:
115 disallowed_by_use
= False
116 for use
in fn_analysis
.uses
[ssa_val
]:
118 use
.defining_descriptor
.spread_index
or 0
119 # calculate the start for the use's Loc before spread
120 # e.g. if the def's Loc before spread starts at r6
121 # and the def's spread_index is 5
122 # and the use's spread_index is 3
123 # then the use's Loc before spread starts at r8
124 # because 8 == 6 + 5 - 3
125 start
= loc
.start
+ def_spread_idx
- use_spread_idx
126 use_loc
= Loc
.try_make(
127 loc
.kind
, start
=start
,
128 reg_len
=use
.ty_before_spread
.reg_len
)
129 if (use_loc
is None or
130 use_loc
not in use
.use_loc_set_before_spread
):
131 disallowed_by_use
= True
133 if disallowed_by_use
:
135 # FIXME: add spread consistency check
136 start
= loc
.start
- cur_offset
+ self
.offset
137 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
138 if loc
is not None and (loc_set
is None or loc
in loc_set
):
140 loc_set
= LocSet(locs())
141 assert loc_set
is not None, "already checked that self isn't empty"
142 if loc_set
.ty
is None:
143 raise BadMergedSSAVal("there are no valid Locs left")
144 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
145 self
.loc_set
= loc_set
# type: LocSet
150 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
159 return min(self
.ssa_val_offsets_before_spread
.values())
164 return self
.first_ssa_val
.base_ty
168 # type: () -> OFSet[SSAVal]
169 return OFSet(self
.ssa_val_offsets
.keys())
175 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
176 cur_ty
= ssa_val
.ty_before_spread
177 if self
.base_ty
!= cur_ty
.base_ty
:
178 raise BadMergedSSAVal(
179 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
180 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
181 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
184 def ssa_val_offsets_before_spread(self
):
185 # type: () -> FMap[SSAVal, int]
186 retval
= {} # type: dict[SSAVal, int]
187 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
189 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """Return a new `MergedSSAVal` with `amount` added to every offset.

    `self` is left untouched; the result shares `self.fn_analysis`.
    """
    shifted = {ssa_val: offset + amount
               for ssa_val, offset in self.ssa_val_offsets.items()}
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=shifted)
def normalized(self):
    # type: () -> MergedSSAVal
    """Return `self` shifted by `-self.offset` (see `offset_by`)."""
    return self.offset_by(-self.offset)
def with_offset_to_match(self, target, additional_offset=0):
    # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
    """Return `self` re-offset so a shared `SSAVal` lines up with `target`.

    `target` is either a `MergedSSAVal` (its `ssa_val_offsets` are used) or a
    single `SSAVal` (treated as being at offset 0).  The first `SSAVal` of
    `self` that also appears in `target` decides the shift: afterwards its
    offset equals its offset in `target` plus `additional_offset`.

    Raises `ValueError` if `self` and `target` share no `SSAVal`.
    """
    if isinstance(target, MergedSSAVal):
        ssa_val_offsets = target.ssa_val_offsets
    else:
        ssa_val_offsets = {target: 0}
    for ssa_val, offset in self.ssa_val_offsets.items():
        if ssa_val in ssa_val_offsets:
            return self.offset_by(
                ssa_val_offsets[ssa_val] + additional_offset - offset)
    raise ValueError("can't change offset to match unrelated MergedSSAVal")
def merged(self, *others):
    # type: (*MergedSSAVal) -> MergedSSAVal
    """Return the union of `self` and all of `others` as one `MergedSSAVal`.

    Raises `ValueError` if any of `others` has a different `fn_analysis`, and
    `BadMergedSSAVal` if an `SSAVal` appears with two different offsets.
    """
    retval = dict(self.ssa_val_offsets)
    for other in others:
        if other.fn_analysis != self.fn_analysis:
            raise ValueError("fn_analysis mismatch")
        for ssa_val, offset in other.ssa_val_offsets.items():
            if ssa_val in retval and retval[ssa_val] != offset:
                raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                      f"{retval[ssa_val]} != {offset}")
            retval[ssa_val] = offset
    return MergedSSAVal(fn_analysis=self.fn_analysis,
                        ssa_val_offsets=retval)
def live_interval(self):
    # type: () -> ProgramRange
    """Return the smallest `ProgramRange` covering every member's live range.

    The result starts at the minimum `start` and stops at the maximum `stop`
    of `fn_analysis.live_ranges[ssa_val]` over all `ssa_val` in
    `self.ssa_vals`, seeded with the range of `self.first_ssa_val`.
    """
    live_ranges = self.fn_analysis.live_ranges
    first_range = live_ranges[self.first_ssa_val]
    start = first_range.start
    stop = first_range.stop
    for ssa_val in self.ssa_vals:
        live_range = live_ranges[ssa_val]
        start = min(start, live_range.start)
        stop = max(stop, live_range.stop)
    return ProgramRange(start=start, stop=stop)
240 return (f
"MergedSSAVal({self.fn_analysis}, "
241 f
"ssa_val_offsets={self.ssa_val_offsets})")
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """Read-only mapping from an `SSAVal` to the `MergedSSAVal` containing it.

    The backing dict is shared with the `MergedSSAValToIGNodeMap` returned by
    `ig_node_map`, which is the object that actually fills it in.
    """

    def __init__(self):
        # type: (...) -> None
        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
        # the node map receives the very same dict so both views stay in sync
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """The `MergedSSAValToIGNodeMap` sharing this map's storage."""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # iterating the node map yields its `MergedSSAVal` keys
        s = ",\n".join(repr(v) for v in self.__ig_node_map)
        return f"SSAValToMergedSSAValMap({{{s}}})"
276 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, "IGNode"]):
279 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
281 # type: (...) -> None
282 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
283 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
def __getitem__(self, __key):
    # type: (MergedSSAVal) -> IGNode
    """Return the `IGNode` registered for the `MergedSSAVal` `__key`.

    Raises `KeyError` (from the underlying dict) when there is no such node.
    """
    return self.__map[__key]
290 # type: () -> Iterator[MergedSSAVal]
291 return iter(self
.__map
)
295 return len(self
.__map
)
297 def add_node(self
, merged_ssa_val
):
298 # type: (MergedSSAVal) -> IGNode
299 node
= self
.__map
.get(merged_ssa_val
, None)
302 added
= 0 # type: int | None
304 for ssa_val
in merged_ssa_val
.ssa_vals
:
305 if ssa_val
in self
.__merged
_ssa
_val
_map
:
307 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
308 f
"{merged_ssa_val} and "
309 f
"{self.__merged_ssa_val_map[ssa_val]}")
310 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
312 retval
= IGNode(merged_ssa_val
=merged_ssa_val
, edges
=(), loc
=None)
313 self
.__map
[merged_ssa_val
] = retval
317 if added
is not None:
318 # remove partially added stuff
319 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
322 del self
.__merged
_ssa
_val
_map
[ssa_val
]
324 def merge_into_one_node(self
, final_merged_ssa_val
):
325 # type: (MergedSSAVal) -> IGNode
326 source_nodes
= OSet() # type: OSet[IGNode]
327 edges
= OSet() # type: OSet[IGNode]
328 loc
= None # type: Loc | None
329 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
330 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
331 source_node
= self
.__map
[merged_ssa_val
]
332 source_nodes
.add(source_node
)
333 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
335 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
336 f
"but not in merged IGNode's merged_ssa_val: "
337 f
"source_node={source_node} "
338 f
"final_merged_ssa_val={final_merged_ssa_val}")
340 loc
= source_node
.loc
341 elif source_node
.loc
is not None and loc
!= source_node
.loc
:
342 raise ValueError(f
"can't merge IGNodes with mismatched `loc` "
343 f
"values: {loc} != {source_node.loc}")
344 edges |
= source_node
.edges
345 if len(source_nodes
) == 1:
346 return source_nodes
.pop() # merging a single node is a no-op
347 # we're finished checking validity, now we can modify stuff
348 edges
-= source_nodes
349 retval
= IGNode(merged_ssa_val
=final_merged_ssa_val
, edges
=edges
,
352 node
.edges
-= source_nodes
353 node
.edges
.add(retval
)
354 for node
in source_nodes
:
355 del self
.__map
[node
.merged_ssa_val
]
356 self
.__map
[final_merged_ssa_val
] = retval
357 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
358 self
.__merged
_ssa
_val
_map
[ssa_val
] = final_merged_ssa_val
def __repr__(self, repr_state=None):
    # type: (None | IGNodeReprState) -> str
    """Return a repr listing every node, one per line.

    `repr_state` is forwarded to each node's `__repr__`; a fresh
    `IGNodeReprState` is created when none is supplied.
    """
    if repr_state is None:
        repr_state = IGNodeReprState()
    s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
    return f"MergedSSAValToIGNodeMap({{{s}}})"
@plain_data(frozen=True, repr=False)
class InterferenceGraph:
    """Interference graph whose nodes are `IGNode`s keyed by `MergedSSAVal`.

    Attributes:
        fn_analysis: the `FnAnalysis` the graph was built from.
        merged_ssa_val_map: maps each `SSAVal` to its `MergedSSAVal`.
        nodes: maps each `MergedSSAVal` to its `IGNode`; this is
            `merged_ssa_val_map.ig_node_map`.
    """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for i in merged_ssa_vals:
            self.nodes.add_node(i)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """Merge the nodes containing `ssa_val1` and `ssa_val2` into one.

        The `MergedSSAVal` holding `ssa_val1` is re-offset so `ssa_val1` sits
        at offset 0, the one holding `ssa_val2` so `ssa_val2` sits at
        `additional_offset`; their union becomes the new node's value.
        Returns the resulting `IGNode`.
        """
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        merged = merged1.with_offset_to_match(ssa_val1)
        merged = merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset))
        return self.nodes.merge_into_one_node(merged)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """Build the graph for `fn_analysis` with only the mandatory merges.

        Every op output gets its own node.  An input use or output that is
        not its own `unspread_start` is merged with that start at
        `reg_offset_in_unspread`, and an output with a `tied_input` is merged
        with that input's `SSAVal`.  No edges are added here.
        """
        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={s}, <...>)"
@plain_data(repr=False)
class IGNodeReprState:
    """Bookkeeping shared across nested `IGNode.__repr__` calls.

    Attributes:
        node_ids: the id number assigned to each node seen so far.
        did_full_repr: nodes whose full repr has already been produced, so
            later mentions can use the short `<IGNode #id>` form.
    """
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        # type: () -> None
        self.node_ids = {}  # type: dict[IGNode, int]
        self.did_full_repr = OSet()  # type: OSet[IGNode]
429 """ interference graph node """
430 __slots__
= "merged_ssa_val", "edges", "loc"
def __init__(self, merged_ssa_val, edges, loc):
    # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
    """Create a node for `merged_ssa_val`.

    `edges` is copied into a fresh `OSet`; `loc` is the pre-assigned `Loc`,
    or `None` when the node has not been allocated yet.
    """
    self.merged_ssa_val = merged_ssa_val
    self.edges = OSet(edges)
    self.loc = loc
def add_edge(self, other):
    # type: (IGNode) -> None
    """Add an undirected edge: each node is added to the other's `edges`."""
    self.edges.add(other)
    other.edges.add(self)
def __eq__(self, other):
    # type: (object) -> bool
    """Compare by `merged_ssa_val` only; `edges` and `loc` are ignored.

    Returns `NotImplemented` when `other` is not an `IGNode`.
    """
    if isinstance(other, IGNode):
        return self.merged_ssa_val == other.merged_ssa_val
    return NotImplemented
451 return hash(self
.merged_ssa_val
)
453 def __repr__(self
, repr_state
=None, short
=False):
454 # type: (None | IGNodeReprState, bool) -> str
455 if repr_state
is None:
456 repr_state
= IGNodeReprState()
457 node_id
= repr_state
.node_ids
.get(self
, None)
459 repr_state
.node_ids
[self
] = node_id
= len(repr_state
.node_ids
)
460 if short
or self
in repr_state
.did_full_repr
:
461 return f
"<IGNode #{node_id}>"
462 repr_state
.did_full_repr
.add(self
)
463 edges
= ", ".join(i
.__repr
__(repr_state
, True) for i
in self
.edges
)
464 return (f
"IGNode(#{node_id}, "
465 f
"merged_ssa_val={self.merged_ssa_val}, "
466 f
"edges={{{edges}}}, "
472 return self
.merged_ssa_val
.loc_set
def loc_conflicts_with_neighbors(self, loc):
    # type: (Loc) -> bool
    """Return True if `loc` conflicts with any neighbor's assigned `Loc`.

    Neighbors whose `loc` is still `None` are skipped.
    """
    for neighbor in self.edges:
        if neighbor.loc is not None and neighbor.loc.conflicts(loc):
            return True
    return False
class AllocationFailedError(Exception):
    """Raised when register allocation can't find a `Loc` for an `IGNode`.

    Attributes:
        node: the `IGNode` that could not be allocated.
        interference_graph: the graph being allocated when the failure
            happened.
    """

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        # the node is shown in short form; the graph repr shares `repr_state`
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
503 def allocate_registers(fn
):
504 # type: (Fn) -> dict[SSAVal, Loc]
506 # inserts enough copies that no manual spilling is necessary, all
507 # spilling is done by the register allocator naturally allocating SSAVals
509 fn
.pre_ra_insert_copies()
511 fn_analysis
= FnAnalysis(fn
)
512 interference_graph
= InterferenceGraph
.minimally_merged(fn_analysis
)
514 for ssa_vals
in fn_analysis
.live_at
.values():
515 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
516 for ssa_val
in ssa_vals
:
517 live_merged_ssa_vals
.add(
518 interference_graph
.merged_ssa_val_map
[ssa_val
])
519 for i
, j
in combinations(live_merged_ssa_vals
, 2):
520 if i
.loc_set
.max_conflicts_with(j
.loc_set
) != 0:
521 interference_graph
.nodes
[i
].add_edge(
522 interference_graph
.nodes
[j
])
524 nodes_remaining
= OSet(interference_graph
.nodes
.values())
526 def local_colorability_score(node
):
527 # type: (IGNode) -> int
528 """ returns a positive integer if node is locally colorable, returns
529 zero or a negative integer if node isn't known to be locally
530 colorable, the more negative the value, the less colorable
532 if node
not in nodes_remaining
:
534 retval
= len(node
.loc_set
)
535 for neighbor
in node
.edges
:
536 if neighbor
in nodes_remaining
:
537 retval
-= node
.loc_set
.max_conflicts_with(neighbor
.loc_set
)
540 # TODO: implement copy-merging
542 node_stack
= [] # type: list[IGNode]
544 best_node
= None # type: None | IGNode
546 for node
in nodes_remaining
:
547 score
= local_colorability_score(node
)
548 if best_node
is None or score
> best_score
:
552 # it's locally colorable, no need to find a better one
555 if best_node
is None:
557 node_stack
.append(best_node
)
558 nodes_remaining
.remove(best_node
)
560 retval
= {} # type: dict[SSAVal, Loc]
562 while len(node_stack
) > 0:
563 node
= node_stack
.pop()
564 if node
.loc
is not None:
565 if node
.loc_conflicts_with_neighbors(node
.loc
):
566 raise AllocationFailedError(
567 "IGNode is pre-allocated to a conflicting Loc",
568 node
=node
, interference_graph
=interference_graph
)
570 # pick the first non-conflicting register in node.reg_class, since
571 # register classes are ordered from most preferred to least
572 # preferred register.
573 for loc
in node
.loc_set
:
574 if not node
.loc_conflicts_with_neighbors(loc
):
578 raise AllocationFailedError(
579 "failed to allocate Loc for IGNode",
580 node
=node
, interference_graph
=interference_graph
)
582 for ssa_val
, offset
in node
.merged_ssa_val
.ssa_val_offsets
.items():
583 retval
[ssa_val
] = node
.loc
.get_subloc_at_offset(ssa_val
.ty
, offset
)