working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Iterable, Iterator, Mapping, TextIO
10
11 from cached_property import cached_property
12 from nmutil.plain_data import plain_data
13
14 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
15 LocSet, ProgramRange, SSAVal,
16 Ty)
17 from bigint_presentation_code.type_util import final
18 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
19
20
class BadMergedSSAVal(ValueError):
    """Raised when a `MergedSSAVal` is invalid: it would be empty, its
    members have different `BaseTy`s, merging gives one `SSAVal` two
    different offsets, or no valid `Loc` is left for it."""
23
24
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Equivalently, each member's `Loc` starts at
    `locs[msv].start + msv.ssa_val_offsets[ssa_val] - msv.offset`. Offsets
    are only meaningful relative to `msv.offset`; shifting every offset by the
    same amount (see `offset_by`) describes the same layout.

    Example (assuming none of `v1`..`v3` are parts of a spread value):
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal(fn_analysis, {v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """Build the merged value and compute `loc_set`, the `Loc`s the
        whole unit may be placed in.

        `ssa_val_offsets` maps each member `SSAVal` to its offset (see the
        class docstring); a bare `SSAVal` is shorthand for `{ssa_val: 0}`.

        Raises `BadMergedSSAVal` if there are no members, if the members'
        `BaseTy`s differ (raised via `self.ty`), or if no `Loc` satisfies the
        def and use constraints of every member.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        # grab an arbitrary (first) member; used by `base_ty` and
        # `live_interval`, and its absence means the mapping is empty
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # Intersect the candidate Locs contributed by each member: on every
        # iteration `locs()` only yields Locs that were already in the
        # previous iteration's `loc_set` (or any Loc on the first iteration).
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def locs():
                # type: () -> Iterable[Loc]
                # NOTE(review): closes over `ssa_val`, `cur_offset` and
                # `loc_set`; this is only correct if `LocSet(locs())` below
                # consumes the generator eagerly, before the loop variables
                # are rebound -- TODO confirm LocSet is eager.
                for loc in ssa_val.def_loc_set_before_spread:
                    # reject def Locs that would put some use of `ssa_val`
                    # into a Loc that the use doesn't allow
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's reg_offset_in_unspread is 5
                        # and the use's reg_offset_in_unspread is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = (loc.start + ssa_val.reg_offset_in_unspread
                                 - use.reg_offset_in_unspread)
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate the member's before-spread Loc into the Loc
                    # of the whole merged unit, which starts at `self.offset`
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            # presumably an empty LocSet reports `ty is None` -- TODO confirm
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @cached_property
    def __hash(self):
        # type: () -> int
        # computed once per instance from the two identity-defining fields
        # (name-mangled so it can't clash with anything else)
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """smallest before-spread member offset; `ty` and `loc_set` are
        measured from this offset"""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """`BaseTy` shared by all members (taken from `first_ssa_val`)"""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, in `ssa_val_offsets` order"""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """type of the whole unit: spans from `self.offset` to the end of the
        furthest-reaching member's before-spread type.

        Raises `BadMergedSSAVal` if any member's `BaseTy` differs from
        `self.base_ty`.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """each member's offset minus its defining descriptor's
        `reg_offset_in_unspread`, i.e. presumably the offset of the start of
        the unspread value the member was defined as part of"""
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a new `MergedSSAVal` with every offset increased by
        `amount`"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a copy shifted so that `offset` is 0"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """return a copy of `self` shifted so that the first member (in
        iteration order) that `self` shares with `target` gets the offset it
        has in `target`, plus `additional_offset`.

        A bare `SSAVal` target is treated as `{target: 0}`.
        Raises `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of `self` and `others`, keeping every offset
        unchanged.

        Raises `ValueError` if any of `others` has a different `fn_analysis`,
        and `BadMergedSSAVal` if an `SSAVal` appears with two different
        offsets (or if the union has no valid `Loc`s).
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """smallest `ProgramRange` covering every member's live range"""
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                f"live_interval={self.live_interval})")
204
205
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """Read-only mapping from each `SSAVal` to the `MergedSSAVal` that
    currently contains it.

    This class never mutates its backing dict itself. The dict is shared with
    the `MergedSSAValToIGNodeMap` exposed as `ig_node_map`, which updates it
    whenever nodes are added or merged.
    """

    def __init__(self):
        # type: (...) -> None
        backing = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__map = backing
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=backing)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map.keys())

    def __len__(self):
        # type: () -> int
        return self.__map.__len__()

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the companion map from each `MergedSSAVal` to its `IGNode`"""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # lists each distinct MergedSSAVal once, rather than once per SSAVal
        merged_reprs = [repr(merged) for merged in self.__ig_node_map]
        return "SSAValToMergedSSAValMap({" + ",\n".join(merged_reprs) + "})"
235
236
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """Maps each `MergedSSAVal` to its interference-graph node (`IGNode`).

    It is also the only code in this file that writes to the shared
    `SSAVal -> MergedSSAVal` dict it is constructed with (the dict backing
    `SSAValToMergedSSAValMap`). Both maps are updated together in `add_node`
    and `merge_into_one_node`.
    """
    def __init__(
            self, *,
            _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        # shared with the owning SSAValToMergedSSAValMap and mutated here
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Return the node for `merged_ssa_val`, creating it if needed.

        A new node starts with no edges and no `loc`, and every member
        `SSAVal` is registered in the shared `SSAVal -> MergedSSAVal` dict.

        Raises `ValueError` if any member `SSAVal` is already registered to
        another `MergedSSAVal`. In that case every `SSAVal` this call had
        already registered is removed again, leaving both maps unchanged.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        # number of SSAVals registered so far by this call; set to None on
        # success so the `finally` below knows not to roll anything back
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
            self.__map[merged_ssa_val] = retval
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff: the first `added` SSAVals,
                # relying on `ssa_vals` iterating in the same order again
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Replace the existing nodes that contain any of
        `final_merged_ssa_val`'s `SSAVal`s with a single node for
        `final_merged_ssa_val`.

        The new node's edges are the union of the source nodes' edges,
        excluding the source nodes themselves. Its `loc` is the common
        non-None `loc` of the source nodes, if any.

        All checks run before anything is modified:
        * `KeyError` if one of the `SSAVal`s isn't registered yet.
        * `ValueError` if a source node contains an `SSAVal` that is not in
          `final_merged_ssa_val`.
        * `ValueError` if two source nodes have different non-None `loc`s.

        If only one source node is involved, that node is returned as-is.
        It keeps its existing `merged_ssa_val`, whose offsets may differ
        from `final_merged_ssa_val`'s.
        """
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = OSet()  # type: OSet[IGNode]
        loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            source_nodes.add(source_node)
            # any SSAVal left over after the difference means the source
            # node isn't fully contained in the result
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if loc is None:
                loc = source_node.loc
            elif source_node.loc is not None and loc != source_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {loc} != {source_node.loc}")
            edges |= source_node.edges
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        edges -= source_nodes  # drop edges between the nodes being merged
        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                        loc=loc)
        # redirect the neighbors' back-edges to the new node
        for node in edges:
            node.edges -= source_nodes
            node.edges.add(retval)
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        # a shared repr_state keeps node ids consistent across all nodes
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
329
330
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """Interference graph over `MergedSSAVal`s.

    `nodes` maps each `MergedSSAVal` to its `IGNode`.
    `merged_ssa_val_map` maps each individual `SSAVal` to the `MergedSSAVal`
    that contains it.
    """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
        self.nodes = self.merged_ssa_val_map.ig_node_map
        for merged_ssa_val in merged_ssa_vals:
            self.nodes.add_node(merged_ssa_val)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """Merge the nodes containing `ssa_val1` and `ssa_val2` into one
        node. In the result, `ssa_val2` sits at offset `additional_offset`
        relative to `ssa_val1`. Returns the resulting node."""
        lookup = self.merged_ssa_val_map
        first, second = lookup[ssa_val1], lookup[ssa_val2]
        # anchor ssa_val1 at offset 0 and ssa_val2 at additional_offset
        aligned_first = first.with_offset_to_match(ssa_val1)
        aligned_second = second.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        combined = aligned_first.merged(aligned_second)
        return self.nodes.merge_into_one_node(combined)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """Build a graph with no edges where the only merges are those the
        program forces:
        * parts of a spread value are merged with their unspread start;
        * each output is merged with its tied input, if any.
        """
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                if use.unspread_start != use:
                    graph.merge(use.unspread_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for output in op.outputs:
                graph.nodes.add_node(MergedSSAVal(fn_analysis, output))
                if output.unspread_start != output:
                    graph.merge(output.unspread_start, output,
                                additional_offset=output.reg_offset_in_unspread)
                if output.tied_input is not None:
                    graph.merge(output.tied_input.ssa_val, output)
        return graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        return ("InterferenceGraph(nodes=" + self.nodes.__repr__(state)
                + ", <...>)")
377
378
@plain_data(repr=False)
class IGNodeReprState:
    """State shared across one repr pass over a graph of `IGNode`s.

    `node_ids` gives each node a stable `#n` id. `did_full_repr` records
    which nodes were already printed in full, so later (or cyclic)
    references are printed in the short `<IGNode #n>` form instead of
    recursing.
    """
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        super().__init__()
        self.did_full_repr = OSet()  # type: OSet[IGNode]
        self.node_ids = {}  # type: dict[IGNode, int]
387
388
@final
class IGNode:
    """ interference graph node

    Equality and hashing depend only on `merged_ssa_val`. `edges` holds the
    interfering neighbors. `loc` is the assigned `Loc`, or `None` while
    unassigned.
    """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc

    def add_edge(self, other):
        # type: (IGNode) -> None
        """add an undirected edge: each node records the other"""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        # first sighting of a node assigns it the next sequential id
        node_id = state.node_ids.setdefault(self, len(state.node_ids))
        already_printed = self in state.did_full_repr
        if short or already_printed:
            return f"<IGNode #{node_id}>"
        state.did_full_repr.add(self)
        # neighbors are always printed in short form to avoid recursion
        neighbor_reprs = [edge.__repr__(state, True) for edge in self.edges]
        edges = ", ".join(neighbor_reprs)
        return (f"IGNode(#{node_id}, merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the candidate `Loc`s, taken from `merged_ssa_val`"""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """True if `loc` conflicts with the `loc` of any already-assigned
        neighbor"""
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
442
443
class AllocationFailedError(Exception):
    """Raised by `allocate_registers` when an `IGNode` can't be given a
    `Loc`. Carries the failing node and the whole interference graph for
    debugging."""

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = repr_state if repr_state is not None else IGNodeReprState()
        message = self.args[0]
        # the node is printed short; its full repr appears inside the graph
        node_text = self.node.__repr__(state, True)
        graph_text = self.interference_graph.__repr__(state)
        return (f"{__class__.__name__}({message!r}, node={node_text}, "
                f"interference_graph={graph_text})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
463
464
def allocate_registers(fn, debug_out=None):
    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
    """Assign a `Loc` to every `SSAVal` in the interference graph of `fn`
    using graph coloring.

    Steps:
    1. `fn.pre_ra_insert_copies()` is called (mutating `fn`).
    2. The minimally-merged interference graph is built. Then an edge is
       added between every pair of `MergedSSAVal`s that are live at the same
       program point and whose `LocSet`s can conflict.
    3. Simplify: repeatedly remove the remaining node with the best local
       colorability score and push it on a stack.
    4. Select: pop nodes off the stack and give each one the first `Loc` in
       its `LocSet` that doesn't conflict with an already-assigned neighbor.

    If `debug_out` is given, progress is printed to it.

    Returns a dict mapping each `SSAVal` to its assigned `Loc`.

    Raises `AllocationFailedError` if a node is pre-allocated to a
    conflicting `Loc`, or if no non-conflicting `Loc` is left for a node.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    if debug_out is not None:
        print(f"After pre_ra_insert_copies():\n{fn.ops}",
              file=debug_out, flush=True)

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    if debug_out is not None:
        print(f"After InterferenceGraph.minimally_merged():\n"
              f"{interference_graph}", file=debug_out, flush=True)

    # two MergedSSAVals interfere if they are live at the same program point
    # and their LocSets can conflict (`max_conflicts_with` is nonzero)
    for pp, ssa_vals in fn_analysis.live_at.items():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph.nodes[i].add_edge(
                    interference_graph.nodes[j])
        if debug_out is not None:
            print(f"processed {pp} out of {fn_analysis.all_program_points}",
                  file=debug_out, flush=True)

    if debug_out is not None:
        print(f"After adding interference graph edges:\n"
              f"{interference_graph}", file=debug_out, flush=True)

    nodes_remaining = OSet(interference_graph.nodes.values())

    # scores only depend on which neighbors remain, so a node's entry is
    # invalidated whenever it or one of its neighbors is removed
    local_colorability_score_cache = {}  # type: dict[IGNode, int]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError(
                "node was already removed from nodes_remaining")
        retval = local_colorability_score_cache.get(node, None)
        if retval is not None:
            return retval
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        local_colorability_score_cache[node] = retval
        return retval

    # TODO: implement copy-merging

    # simplify phase: decide the order in which nodes get Locs
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break  # every node has been pushed
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)
        local_colorability_score_cache.pop(best_node, None)
        for neighbor in best_node.edges:
            local_colorability_score_cache.pop(neighbor, None)

    if debug_out is not None:
        print(f"After deciding node allocation order:\n"
              f"{node_stack}", file=debug_out, flush=True)

    retval = {}  # type: dict[SSAVal, Loc]

    # select phase: nodes removed last get their Locs first
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in node.loc_set, since
            # LocSets are presumably ordered from most preferred to least
            # preferred Loc.
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        if debug_out is not None:
            print(f"After allocating Loc for node:\n{node}",
                  file=debug_out, flush=True)

        # `ssa_val_offsets` are relative to the merged value's `offset` (see
        # the `MergedSSAVal` docstring): a member's Loc starts at
        # `node.loc.start + ssa_val_offsets[ssa_val] - merged.offset`. Merges
        # via `with_offset_to_match` can leave offsets non-normalized (even
        # negative), so convert to an offset from the start of `node.loc`.
        merged = node.merged_ssa_val
        for ssa_val, offset in merged.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(
                ssa_val.ty, offset - merged.offset)

    if debug_out is not None:
        print(f"final Locs for all SSAVals:\n{retval}",
              file=debug_out, flush=True)

    return retval