"""
Register Allocator for Toom-Cook algorithm generator for SVP64

This uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
from itertools import combinations
from typing import Iterable, Iterator, Mapping, TextIO

from cached_property import cached_property
from nmutil.plain_data import plain_data

from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
                                                  LocSet, Op, ProgramRange,
                                                  SSAVal, SSAValSubReg, Ty)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
class BadMergedSSAVal(ValueError):
    """ raised when a set of `SSAVal`s can't be combined into a valid
    `MergedSSAVal` (empty set, mismatched types, no valid `Loc`s left, or
    values that would illegally overlap).
    """
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example, where `v1`, `v2` and `v3` have a `reg_len` of 4, 2 and 1
    respectively:
    ```
    msv = MergedSSAVal(fn_analysis, {v1: 0, v2: 4, v3: 1})
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`

    Construction raises `BadMergedSSAVal` if the set is empty, the members
    have mismatched `BaseTy`s, no `Loc` satisfies every member's def/use
    constraints, or the merge would force values that aren't known to be
    always equal into the same registers.
    """
    __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                 "first_loc")

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        self.fn_analysis = fn_analysis
        # a bare `SSAVal` is shorthand for that single value at offset 0
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(
            ssa_val_offsets)  # type: FMap[SSAVal, int]
        first_ssa_val = next(iter(self.ssa_vals), None)
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # intersect the candidate `Loc`s permitted by every member
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            loc_set = LocSet(self.__candidate_locs(
                ssa_val, cur_offset, reg_len, loc_set))
        assert loc_set is not None, "already checked that self isn't empty"
        first_loc = next(iter(loc_set), None)
        if first_loc is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        self.first_loc = first_loc
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet
        self.__mergable_check()

    def __disallowed_by_use(self, ssa_val, def_loc):
        # type: (SSAVal, Loc) -> bool
        """ True if allocating `ssa_val`'s definition (before spread) to
        `def_loc` would put any of its uses in a `Loc` that use doesn't allow.
        """
        for use in self.fn_analysis.uses[ssa_val]:
            # calculate the start for the use's Loc before spread
            # e.g. if the def's Loc before spread starts at r6
            # and the def's reg_offset_in_unspread is 5
            # and the use's reg_offset_in_unspread is 3
            # then the use's Loc before spread starts at r8
            # because 8 == 6 + 5 - 3
            start = (def_loc.start + ssa_val.reg_offset_in_unspread
                     - use.reg_offset_in_unspread)
            use_loc = Loc.try_make(
                def_loc.kind, start=start,
                reg_len=use.ty_before_spread.reg_len)
            if (use_loc is None or
                    use_loc not in use.use_loc_set_before_spread):
                return True
        return False

    def __candidate_locs(self, ssa_val, cur_offset, reg_len, prev_loc_set):
        # type: (SSAVal, int, int, None | LocSet) -> Iterator[Loc]
        """ yield every `Loc` (of length `reg_len`) that this whole
        `MergedSSAVal` could be allocated to, as permitted by `ssa_val`'s
        definition and uses, restricted to `prev_loc_set` unless it is `None`.
        `cur_offset` is `ssa_val`'s offset before spread.
        """
        for def_loc in ssa_val.def_loc_set_before_spread:
            if self.__disallowed_by_use(ssa_val, def_loc):
                continue
            # translate from ssa_val's Loc to the Loc of the whole merged value
            start = def_loc.start - cur_offset + self.offset
            loc = Loc.try_make(def_loc.kind, start=start, reg_len=reg_len)
            if loc is not None and (prev_loc_set is None
                                    or loc in prev_loc_set):
                yield loc

    def __mergable_check(self):
        # type: () -> None
        """ checks that nothing is forcing two independent SSAVals
        to illegally overlap. This is required to prevent copy merging from
        merging things that can't be merged.
        spread arguments are one of the things that can force two values to
        illegally overlap.

        Raises `BadMergedSSAVal` if the check fails.
        """
        op_set = OSet()  # type: OSet[Op]
        for ssa_val in self.ssa_vals:
            op_set.add(ssa_val.op)
            for use in self.fn_analysis.uses[ssa_val]:
                op_set.add(use.op)
        # process ops ordered by `fn_analysis.op_indexes` so each defining op
        # records its registers in `vals` before later ops read them
        ops = sorted(op_set, key=self.fn_analysis.op_indexes.__getitem__)
        # register offset within self -> sub-register last written there
        vals = {}  # type: dict[int, SSAValSubReg]
        for op in ops:
            for inp in op.input_vals:
                if inp not in self.ssa_val_offsets:
                    continue
                ssa_val_offset = self.ssa_val_offsets[inp]
                for orig_reg in inp.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + orig_reg.reg_idx
                    replaced_reg = vals[reg_offset]
                    if not self.fn_analysis.is_always_equal(
                            orig_reg, replaced_reg):
                        raise BadMergedSSAVal(
                            f"attempting to merge values that aren't known to "
                            f"be always equal: {orig_reg} != {replaced_reg}")
            # each register offset may be written by at most one output of
            # this op
            output_offsets = dict.fromkeys(range(
                self.offset, self.offset + self.ty.reg_len))
            for out in op.outputs:
                if out not in self.ssa_val_offsets:
                    continue
                ssa_val_offset = self.ssa_val_offsets[out]
                for reg in out.ssa_val_sub_regs:
                    reg_offset = ssa_val_offset + reg.reg_idx
                    try:
                        del output_offsets[reg_offset]
                    except KeyError:
                        raise BadMergedSSAVal(
                            "attempted to merge two outputs "
                            "of the same instruction") from None
                    vals[reg_offset] = reg

    def __hash__(self):
        # type: () -> int
        return hash((self.fn_analysis, self.ssa_val_offsets))

    @cached_property
    def offset(self):
        # type: () -> int
        """ the smallest before-spread offset of any member `SSAVal` """
        return min(self.ssa_val_offsets_before_spread.values())

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ the `Ty` spanning every member's before-spread registers.
        Raises `BadMergedSSAVal` on a `BaseTy` mismatch.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """ each member's offset, shifted back by its defining descriptor's
        `reg_offset_in_unspread`
        """
        return FMap({
            ssa_val: (offset
                      - ssa_val.defining_descriptor.reg_offset_in_unspread)
            for ssa_val, offset in self.ssa_val_offsets.items()})

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """ return a copy of self with every offset increased by `amount` """
        new_offsets = {ssa_val: offset + amount
                       for ssa_val, offset in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=new_offsets)

    def normalized(self):
        # type: () -> MergedSSAVal
        """ return a copy of self shifted so that `offset == 0` """
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """ return a copy of self shifted so the first shared `SSAVal` has
        the same offset as in `target` (an `SSAVal` counts as offset 0), plus
        `additional_offset`.
        Raises `ValueError` if self and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets  # type: Mapping[SSAVal, int]
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """ return the union of self and `others` (offsets used as-is).
        Raises `ValueError` on an `fn_analysis` mismatch and
        `BadMergedSSAVal` if a shared `SSAVal` has conflicting offsets.
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """ the smallest `ProgramRange` covering every member's live range """
        live_ranges = [self.fn_analysis.live_ranges[ssa_val]
                       for ssa_val in self.ssa_vals]
        return ProgramRange(start=min(r.start for r in live_ranges),
                            stop=max(r.stop for r in live_ranges))

    def __repr__(self):
        # type: () -> str
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """ read-only `Mapping` from each `SSAVal` to the `MergedSSAVal` that
    contains it.

    The underlying dict is shared with (and only modified through) the
    `MergedSSAValToIGNodeMap` exposed as `ig_node_map`.
    """

    def __init__(self):
        # type: () -> None
        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @cached_property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # lists the distinct `MergedSSAVal`s (keys of the IGNode map)
        s = ",\n".join(repr(v) for v in self.__ig_node_map)
        return f"SSAValToMergedSSAValMap({{{s}}})"
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """ maps each `MergedSSAVal` to its `IGNode`, keeping the shared
    `SSAVal` -> `MergedSSAVal` dict (owned by `SSAValToMergedSSAValMap`)
    consistent with it.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """ add a new edge-less, unallocated `IGNode` for `merged_ssa_val`,
        or return the existing node if it already has one.

        Raises `ValueError` (with nothing modified) if any of its `SSAVal`s
        already belongs to another `MergedSSAVal`.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        added = []  # type: list[SSAVal]
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added.append(ssa_val)
            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
            self.__map[merged_ssa_val] = retval
        except BaseException:
            # remove partially added stuff
            for ssa_val in added:
                del self.__merged_ssa_val_map[ssa_val]
            raise
        return retval

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """ replace every node containing one of `final_merged_ssa_val`'s
        `SSAVal`s with a single node for `final_merged_ssa_val`, whose edges
        are the union of the replaced nodes' edges.

        Raises `ValueError` before modifying anything if a replaced node has
        an `SSAVal` not in `final_merged_ssa_val`, or if replaced nodes have
        different non-`None` `loc`s.
        """
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = OSet()  # type: OSet[IGNode]
        loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            source_nodes.add(source_node)
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if loc is None:
                loc = source_node.loc
            elif source_node.loc is not None and loc != source_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {loc} != {source_node.loc}")
            edges |= source_node.edges
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        edges -= source_nodes  # no self-edges on the merged node
        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                        loc=loc)
        # redirect the neighbors' edges from the source nodes to retval
        for node in edges:
            node.edges -= source_nodes
            node.edges.add(retval)
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """ interference graph with one `IGNode` per `MergedSSAVal` """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for merged_ssa_val in merged_ssa_vals:
            self.nodes.add_node(merged_ssa_val)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """ merge the nodes containing `ssa_val1` and `ssa_val2` so that
        `ssa_val2` is placed `additional_offset` registers after `ssa_val1`.
        Returns the resulting node.
        """
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        # ssa_val1 ends up at offset 0, ssa_val2 at additional_offset
        merged = merged1.with_offset_to_match(ssa_val1)
        merged = merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset))
        return self.nodes.merge_into_one_node(merged)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """ build a graph with only the merges the program forces: spread
        inputs/outputs are merged with their unspread start at their
        `reg_offset_in_unspread`, and tied outputs with their tied input.
        No edges are added.
        """
        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = self.nodes.__repr__(repr_state)
        return f"InterferenceGraph(nodes={s}, <...>)"
@plain_data(repr=False)
class IGNodeReprState:
    """ state shared across nested `__repr__` calls, so each `IGNode` gets a
    stable `#id` and is printed in full at most once.
    """
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        # type: () -> None
        self.node_ids = {}  # type: dict[IGNode, int]
        self.did_full_repr = OSet()  # type: OSet[IGNode]
@final
class IGNode:
    """ interference graph node """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc  # None until allocated (or pre-allocated)

    def add_edge(self, other):
        # type: (IGNode) -> None
        """ add an undirected edge between self and `other` """
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.merged_ssa_val == other.merged_ssa_val
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        node_id = repr_state.node_ids.get(self, None)
        if node_id is None:
            repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
        # neighbors are printed in short form to avoid infinite recursion
        if short or self in repr_state.did_full_repr:
            return f"<IGNode #{node_id}>"
        repr_state.did_full_repr.add(self)
        edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
        return (f"IGNode(#{node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """ True if any already-allocated neighbor's `loc` conflicts with
        `loc`.
        """
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
class AllocationFailedError(Exception):
    """ raised by `allocate_registers` when an `IGNode` can't be given a
    `Loc`; carries the failing node and the whole interference graph.
    """

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        # type(self) rather than __class__ so subclasses report their own name
        return (f"{type(self).__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
def allocate_registers(fn, debug_out=None):
    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
    """ allocate a `Loc` for every `SSAVal` in `fn` and return the mapping.

    Calls `fn.pre_ra_insert_copies()` first, so `fn` is modified.
    If `debug_out` is not `None`, progress is printed to it.

    Raises `AllocationFailedError` if some `IGNode` can't be given a `Loc`.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to whichever `Loc`s their `LocSet`s permit
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    # add an edge between every pair of `MergedSSAVal`s that are live at the
    # same program point and whose `LocSet`s can conflict
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph.nodes[i].add_edge(
                    interference_graph.nodes[j])
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)

    nodes_remaining = OSet(interference_graph.nodes.values())

    # scores are invalidated whenever a neighbor is removed from
    # nodes_remaining
    local_colorability_score_cache = {}  # type: dict[IGNode, int]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError("node is not in nodes_remaining")
        retval = local_colorability_score_cache.get(node, None)
        if retval is not None:
            return retval
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        local_colorability_score_cache[node] = retval
        return retval

    # TODO: implement copy-merging

    # simplify: repeatedly remove the most colorable remaining node, pushing
    # it on a stack; nodes not known to be colorable are still pushed and
    # may get a Loc anyway when popped
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break  # no nodes remaining
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)
        local_colorability_score_cache.pop(best_node, None)
        for neighbor in best_node.edges:
            local_colorability_score_cache.pop(neighbor, None)

    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    # select: pop nodes in reverse removal order and give each a Loc
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval