register_allocator2.py works!
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator2.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Iterable, Iterator, Mapping
10
11 from cached_property import cached_property
12 from nmutil.plain_data import plain_data
13
14 from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
15 LocSet, ProgramRange,
16 SSAVal, Ty)
17 from bigint_presentation_code.type_util import final
18 from bigint_presentation_code.util import FMap, OFSet, OSet
19
20
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """range of op indexes from the op that first writes a value
    (`first_write`) up to the op that last uses it (`last_use`).

    Both indexes must be nonnegative and `last_use >= first_write`; when
    `last_use` isn't given it defaults to `first_write`.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        super().__init__()
        last_use = first_write if last_use is None else last_use
        if first_write > last_use:
            raise ValueError("uses must be after first_write")
        if min(first_write, last_use) < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """True if both intervals start at the same op, or if each one
        starts strictly before the other one's `last_use`.
        """
        if self.first_write == other.first_write:
            return True
        return (other.first_write < self.last_use
                and self.first_write < other.last_use)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a new interval with the same `first_write`, whose
        `last_use` is extended to `use` if `use` is later.
        """
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)
55
56
class BadMergedSSAVal(ValueError):
    """raised when a `MergedSSAVal` can't be built or combined: it would be
    empty, its members' `BaseTy`s differ, two merged values disagree on a
    shared member's offset, or no `Loc` is acceptable to every member.
    """
59
60
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal:
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    where `msv.offset` is the smallest value in
    `msv.ssa_val_offsets_before_spread`. So the offset of a member relative
    to the start of `locs[msv]` is `msv.ssa_val_offsets[ssa_val] - msv.offset`.

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal(fn_analysis, {v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    # fn_analysis: the `FnAnalysis` whose `uses` are consulted for members
    # ssa_val_offsets: `FMap` from each member `SSAVal` to its offset
    # first_ssa_val: the first member in iteration order (gives `base_ty`)
    # loc_set: every `Loc` the whole merged value may be allocated to
    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """build a merged value.

        `ssa_val_offsets` maps each member `SSAVal` to its offset; a bare
        `SSAVal` is shorthand for `{ssa_val: 0}`.

        `loc_set` is computed as the `Loc`s (for the whole merged value)
        that are acceptable to every member's definition and to all of that
        member's uses.

        Raises `BadMergedSSAVal` if there are no members, if the members'
        `BaseTy`s differ, or if no `Loc` is left in `loc_set`.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        # take the first member (if any), both to detect emptiness and to
        # have a representative for `base_ty`
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # running intersection of the allowed Locs over all members processed
        # so far; None means "no member processed yet" (no restriction)
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            # NOTE: `locs` closes over the loop variables `ssa_val`,
            # `cur_offset`, `def_spread_idx` and over `loc_set`. This is only
            # correct because the generator is consumed by `LocSet(locs())`
            # within the same loop iteration (assumes `LocSet` consumes its
            # argument eagerly -- TODO confirm). Because of that, `loc_set`
            # as seen inside `locs` is still the previous iteration's set,
            # which is what makes the result an intersection.
            def locs():
                # type: () -> Iterable[Loc]
                """yield the `Loc`s for the whole merged value that this
                member allows and that all previous members allowed."""
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate this member's before-spread Loc into the Loc
                    # of the whole merged value, whose start corresponds to
                    # offset `self.offset`
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    # keep only valid Locs also allowed by earlier members
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        # `ty is None` is treated as "this LocSet has no Locs"
        if loc_set.ty is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @cached_property
    def __hash(self):
        # type: () -> int
        # hash only the defining fields: `first_ssa_val` and `loc_set` are
        # computed from them in `__init__`; cached since self is frozen
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """smallest before-spread offset of any member; the start of every
        `Loc` in `loc_set` corresponds to this offset."""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """`BaseTy` of the members (`ty` checks that they all match)."""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, in `ssa_val_offsets` order."""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """`Ty` spanning every member's before-spread type, measured from
        `self.offset`.

        Raises `BadMergedSSAVal` if the members' `BaseTy`s differ.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """each member's offset minus the `reg_offset_in_unspread` of its
        defining descriptor -- presumably the offset of the start of the
        unspread value the member was spread out of (TODO confirm)."""
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a copy with every member's offset shifted by `amount`."""
        # NOTE: the comprehension's own `v` shadows the result name `v`;
        # harmless because comprehensions have their own scope
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a copy shifted so that its `offset` is 0."""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """return a copy shifted so that the first member (in iteration
        order) that also appears in `target` has the same offset it has in
        `target`, plus `additional_offset`. A bare `SSAVal` target is treated
        as `{target: 0}`.

        Raises `ValueError` if self shares no member with `target`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of self and `others`, without shifting offsets.

        Raises `ValueError` if an `other` has a different `fn_analysis`, and
        `BadMergedSSAVal` if a shared member has different offsets (or if the
        union itself is invalid).
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """smallest `ProgramRange` covering every member's live range in
        `fn_analysis.live_ranges`."""
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal({self.fn_analysis}, "
                f"ssa_val_offsets={self.ssa_val_offsets})")
242
243
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """read-only mapping from each `SSAVal` to the `MergedSSAVal` holding it.

    This class has no mutating methods of its own; entries are written only
    through `ig_node_map`, which shares the same underlying dict.
    """

    def __init__(self):
        # type: (...) -> None
        shared = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__map = shared
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=shared)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the companion `MergedSSAValToIGNodeMap` sharing this map's
        storage."""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # each distinct MergedSSAVal is listed once, rather than once per SSAVal
        body = ",\n".join(map(repr, self.__ig_node_map))
        return "SSAValToMergedSSAValMap({" + body + "})"
273
274
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """mapping from each `MergedSSAVal` to its interference-graph `IGNode`.

    It also keeps the `SSAVal` -> `MergedSSAVal` dict it was constructed with
    (shared with `SSAValToMergedSSAValMap`) in sync when nodes are added or
    merged.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """return the `IGNode` for `merged_ssa_val`, creating one with no
        edges and no `loc` if it doesn't exist yet.

        Raises `ValueError`, leaving both maps unchanged, if any of its
        `SSAVal`s already belongs to another `MergedSSAVal`.
        """
        existing = self.__map.get(merged_ssa_val)
        if existing is not None:
            return existing
        ssa_val_map = self.__merged_ssa_val_map
        # validate everything first so a failure never leaves partial state
        for ssa_val in merged_ssa_val.ssa_vals:
            if ssa_val in ssa_val_map:
                raise ValueError(
                    f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                    f"{merged_ssa_val} and "
                    f"{ssa_val_map[ssa_val]}")
        new_node = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
        for ssa_val in merged_ssa_val.ssa_vals:
            ssa_val_map[ssa_val] = merged_ssa_val
        self.__map[merged_ssa_val] = new_node
        return new_node

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """replace all `IGNode`s holding `final_merged_ssa_val`'s `SSAVal`s
        with a single node for `final_merged_ssa_val`.

        The new node's edges are the union of the old nodes' edges, minus
        the old nodes themselves. Its `loc` is their common non-None `loc`.
        If all the `SSAVal`s already live in one node, that node is returned
        unchanged.

        Raises `KeyError` if some `SSAVal` has no node yet. Raises
        `ValueError` before modifying anything if an old node contains an
        `SSAVal` missing from `final_merged_ssa_val`, or if the old nodes
        have different non-None `loc`s.
        """
        merged_nodes = OSet()  # type: OSet[IGNode]
        neighbors = OSet()  # type: OSet[IGNode]
        merged_loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            old_merged = self.__merged_ssa_val_map[ssa_val]
            old_node = self.__map[old_merged]
            merged_nodes.add(old_node)
            for stray in old_merged.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {stray} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={old_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if merged_loc is None:
                merged_loc = old_node.loc
            elif old_node.loc is not None and merged_loc != old_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {merged_loc} != {old_node.loc}")
            neighbors |= old_node.edges
        if len(merged_nodes) == 1:
            # everything already lives in one node: nothing to merge
            return merged_nodes.pop()
        # all validity checks passed; only mutations from here on
        neighbors -= merged_nodes
        new_node = IGNode(merged_ssa_val=final_merged_ssa_val,
                          edges=neighbors, loc=merged_loc)
        for neighbor in neighbors:
            # re-point each neighbor's edges from the old nodes to new_node
            neighbor.edges -= merged_nodes
            neighbor.edges.add(new_node)
        for old_node in merged_nodes:
            del self.__map[old_node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = new_node
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return new_node

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        entries = ",\n".join(
            node.__repr__(state) for node in self.__map.values())
        return "MergedSSAValToIGNodeMap({" + entries + "})"
367
368
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """interference graph over `MergedSSAVal`s of one function.

    `merged_ssa_val_map` maps `SSAVal` -> `MergedSSAVal`, and `nodes` maps
    `MergedSSAVal` -> `IGNode`; `nodes` is `merged_ssa_val_map.ig_node_map`,
    so the two share storage.
    """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        ssa_val_map = SSAValToMergedSSAValMap()
        self.merged_ssa_val_map = ssa_val_map
        self.nodes = ssa_val_map.ig_node_map
        for merged_ssa_val in merged_ssa_vals:
            self.nodes.add_node(merged_ssa_val)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """merge the nodes holding `ssa_val1` and `ssa_val2` into one node,
        placing `ssa_val2` at `additional_offset` relative to `ssa_val1`.

        Returns the resulting `IGNode`.
        """
        lookup = self.merged_ssa_val_map
        first = lookup[ssa_val1]
        second = lookup[ssa_val2]
        # align both so ssa_val1 sits at offset 0 and ssa_val2 at
        # additional_offset, then union them
        aligned_first = first.with_offset_to_match(ssa_val1)
        aligned_second = second.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        combined = aligned_first.merged(aligned_second)
        return self.nodes.merge_into_one_node(combined)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """build a graph with only the merges that the ops require.

        Each output gets its own node. Spread inputs and outputs are merged
        with the start of their unspread value, and outputs are merged with
        their tied inputs. No interference edges are added here.
        """
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                if use.unspread_start != use:
                    graph.merge(use.unspread_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for out in op.outputs:
                graph.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    graph.merge(out.unspread_start, out,
                                additional_offset=out.reg_offset_in_unspread)
                tied = out.tied_input
                if tied is not None:
                    graph.merge(tied.ssa_val, out)
        return graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        return ("InterferenceGraph(nodes=" + self.nodes.__repr__(state)
                + ", <...>)")
415
416
@plain_data(repr=False)
class IGNodeReprState:
    """state shared across one repr() traversal of `IGNode`s, so that each
    node gets a stable number and its full form is printed only once.
    """
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        super().__init__()
        # IGNodes whose full (non-short) repr has already been emitted
        self.did_full_repr = OSet()  # type: OSet[IGNode]
        # IGNode -> sequential id, assigned the first time it is printed
        self.node_ids = {}  # type: dict[IGNode, int]
425
426
@final
class IGNode:
    """ node of the interference graph: one `MergedSSAVal`, the set of
    neighboring nodes it interferes with, and its assigned `Loc` (None until
    allocated). Equality and hashing depend only on `merged_ssa_val`.
    """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc

    def add_edge(self, other):
        # type: (IGNode) -> None
        """connect self and `other` in both directions."""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        ids = state.node_ids
        if self not in ids:
            ids[self] = len(ids)
        node_id = ids[self]
        # neighbors (and nodes already printed) are shown in short form to
        # avoid infinite recursion through the cyclic edge sets
        if short or self in state.did_full_repr:
            return f"<IGNode #{node_id}>"
        state.did_full_repr.add(self)
        edges = ", ".join(edge.__repr__(state, True) for edge in self.edges)
        return (f"IGNode(#{node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the `Loc`s this node may be allocated to."""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """True if any neighbor already has a `Loc` conflicting with `loc`."""
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
480
481
class AllocationFailedError(Exception):
    """raised by `allocate_registers` when an `IGNode` can't be given a `Loc`
    that doesn't conflict with its neighbors.

    `args` is `(msg, node, interference_graph)`; `node` is the `IGNode` that
    failed and `interference_graph` is the whole graph, for debugging.
    """

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        if repr_state is None:
            repr_state = IGNodeReprState()
        # use the runtime class so subclasses report their own name
        # (the `__class__` cell always names this class)
        return (f"{type(self).__name__}({self.args[0]!r}, "
                f"node={self.node.__repr__(repr_state, True)}, "
                f"interference_graph="
                f"{self.interference_graph.__repr__(repr_state)})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
501
502
def allocate_registers(fn):
    # type: (Fn) -> dict[SSAVal, Loc]
    """allocate a `Loc` for every `SSAVal` in `fn` by graph coloring.

    Steps:
    1. `fn.pre_ra_insert_copies()` inserts copies so that spilling happens
       by allocating `SSAVal`s to stack slots (this mutates `fn`),
    2. build the minimally-merged `InterferenceGraph` and add an edge between
       every pair of `MergedSSAVal`s that are live at the same point and
       whose `LocSet`s can conflict,
    3. repeatedly move a node onto a stack: the first node found with a
       positive local-colorability score, else the node with the highest
       score,
    4. pop nodes off the stack, giving each the first `Loc` in its `loc_set`
       that doesn't conflict with an already-assigned neighbor.

    Returns a dict mapping each `SSAVal` to its assigned `Loc`.

    Raises `AllocationFailedError` if some node can't be assigned a `Loc`.
    Errors from building the graph (e.g. `BadMergedSSAVal`) propagate.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    # add interference edges between every pair of MergedSSAVals that are
    # simultaneously live and whose possible Locs can conflict
    for ssa_vals in fn_analysis.live_at.values():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph.nodes[i].add_edge(
                    interference_graph.nodes[j])

    nodes_remaining = OSet(interference_graph.nodes.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    # TODO: implement copy-merging

    # simplify: push nodes in order; they are colored in reverse order
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, Loc]

    # select: assign Locs, checking only neighbors that already have one
    while node_stack:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in node.loc_set, since
            # LocSets are ordered from most preferred to least preferred Loc
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        merged_ssa_val = node.merged_ssa_val
        for ssa_val, offset in merged_ssa_val.ssa_val_offsets.items():
            # `node.loc.start` corresponds to offset `merged_ssa_val.offset`
            # (see MergedSSAVal's definition of offsets and how its loc_set
            # is built), so rebase the member's offset before taking the
            # subloc; this only differs from `offset` when that isn't 0.
            retval[ssa_val] = node.loc.get_subloc_at_offset(
                ssa_val.ty, offset - merged_ssa_val.offset)

    return retval
584
585 return retval