working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator2.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Iterable, Iterator, Mapping
10
11 from cached_property import cached_property
12 from nmutil.plain_data import plain_data
13
14 from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
15 LocSet, ProgramRange,
16 SSAVal, Ty)
17 from bigint_presentation_code.type_util import final
18 from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
19
20
class BadMergedSSAVal(ValueError):
    """Raised when a group of `SSAVal`s can't form a valid `MergedSSAVal`.

    `MergedSSAVal` raises this when it would be empty, when its members
    disagree on `BaseTy` or on an offset, or when no `Loc` is left that
    satisfies every member.
    """
23
24
@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal(metaclass=InternedMeta):
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Consequently, adding the same constant to every offset (see `offset_by`)
    describes the same relative layout.

    Example (the `fn_analysis` argument is omitted for brevity):
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    # NOTE(review): the `cached_property` members below usually store their
    # value in the instance `__dict__` -- confirm that this works together
    # with `__slots__` / `plain_data` / `InternedMeta` here.
    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """Build the merged value and compute the `Loc`s it may occupy.

        `ssa_val_offsets` is either a mapping from each member `SSAVal` to
        its offset, or a single `SSAVal`, which is treated as
        `{ssa_val: 0}`.

        Raises `BadMergedSSAVal` if there are no members, if the members
        don't all share one `BaseTy` (checked by `self.ty`), or if no `Loc`
        is acceptable to every member and all of its uses.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        # grab the first member (in iteration order), if there is one
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # `loc_set` is narrowed once per member: each pass only keeps `Loc`s
        # that were also accepted for all previous members.
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            # NOTE(review): `locs` reads the loop variables `ssa_val`,
            # `cur_offset` and the previous `loc_set` through its closure, so
            # it relies on `LocSet(...)` below consuming it before the next
            # iteration rebinds them -- presumably `LocSet` is eager; confirm.
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate the member's Loc into the Loc of the whole
                    # merged value: shift the start back by the member's
                    # offset relative to `self.offset` and widen to the
                    # merged `reg_len`.
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    # NOTE(review): inside the class body this name is mangled to
    # `_MergedSSAVal__hash` while the function's `__name__` stays `__hash`;
    # confirm that `cached_property` still caches the value under the name
    # that attribute lookup uses.
    @cached_property
    def __hash(self):
        # type: () -> int
        """hash of `(fn_analysis, ssa_val_offsets)`."""
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """the smallest before-spread offset of any member `SSAVal`."""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """the `BaseTy` of the first member; `ty` checks all others match."""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, i.e. the keys of `ssa_val_offsets`."""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the `Ty` covering every member at its before-spread offset.

        Raises `BadMergedSSAVal` if a member's `BaseTy` differs from
        `self.base_ty`.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            # the merged value must reach the end of the furthest member,
            # measured from the lowest offset (`self.offset`)
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """`ssa_val_offsets` with each member's
        `defining_descriptor.reg_offset_in_unspread` subtracted.
        """
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a `MergedSSAVal` with `amount` added to every offset."""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return `self` shifted by `-self.offset`."""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """return `self` shifted so a shared `SSAVal` lines up with `target`.

        The first member of `self` that is also in `target` (a bare `SSAVal`
        is treated as `{target: 0}`) ends up at `target`'s offset for it plus
        `additional_offset`.

        Raises `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of `self` and `others`, keeping all offsets as is.

        Raises `ValueError` if an element of `others` has a different
        `fn_analysis`, and `BadMergedSSAVal` if the same `SSAVal` appears
        with two different offsets.
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """the smallest `ProgramRange` that covers the live range of every
        member `SSAVal`.
        """
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)

    def __repr__(self):
        return (f"MergedSSAVal({self.fn_analysis}, "
                f"ssa_val_offsets={self.ssa_val_offsets})")
206
207
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """read-only view mapping each `SSAVal` to the `MergedSSAVal` holding it.

    The backing dict is shared with the `MergedSSAValToIGNodeMap` exposed as
    `ig_node_map`, which updates it as nodes are added and merged.
    """

    def __init__(self):
        # type: (...) -> None
        backing = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__map = backing
        # hand the very same dict to the node map so both stay in sync
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=backing)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the `MergedSSAVal` -> `IGNode` map sharing this map's storage."""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # iterating the node map yields its `MergedSSAVal` keys
        s = ",\n".join(map(repr, self.__ig_node_map))
        return f"SSAValToMergedSSAValMap({{{s}}})"
237
238
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """mapping from each `MergedSSAVal` to its interference graph node.

    Besides its own `MergedSSAVal` -> `IGNode` dict, it also updates the
    `SSAVal` -> `MergedSSAVal` dict passed to the constructor whenever nodes
    are added or merged.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        # kept by reference (not copied), so changes made here are visible
        # through whoever else holds this dict
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """return the node for `merged_ssa_val`, creating it if necessary.

        A new node has no edges and no `loc`.

        Raises `ValueError` if one of `merged_ssa_val`'s `SSAVal`s already
        belongs to a different `MergedSSAVal` in this map; in that case the
        maps are left as they were before the call.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        # number of entries written to `__merged_ssa_val_map` so far;
        # set to `None` once everything succeeded so `finally` keeps them
        added = 0  # type: int | None
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added += 1
            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
            self.__map[merged_ssa_val] = retval
            added = None
            return retval
        finally:
            if added is not None:
                # remove partially added stuff
                # (the first `added` SSAVals in iteration order are exactly
                # the ones inserted above)
                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                    if idx >= added:
                        break
                    del self.__merged_ssa_val_map[ssa_val]

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """replace every node holding one of `final_merged_ssa_val`'s
        `SSAVal`s with a single node for `final_merged_ssa_val`.

        The new node gets the union of the source nodes' edges (minus the
        source nodes themselves) and their common `loc`, if any. If only one
        source node is involved, that node is returned unchanged.

        Raises `ValueError` if a source node contains an `SSAVal` that isn't
        in `final_merged_ssa_val`, or if two source nodes have different
        non-`None` `loc`s. Nothing is modified when an error is raised.
        """
        source_nodes = OSet()  # type: OSet[IGNode]
        edges = OSet()  # type: OSet[IGNode]
        loc = None  # type: Loc | None
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_node = self.__map[merged_ssa_val]
            source_nodes.add(source_node)
            # raises for the first leftover SSAVal, if there is any
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={source_node} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
            if loc is None:
                loc = source_node.loc
            elif source_node.loc is not None and loc != source_node.loc:
                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                 f"values: {loc} != {source_node.loc}")
            edges |= source_node.edges
        if len(source_nodes) == 1:
            return source_nodes.pop()  # merging a single node is a no-op
        # we're finished checking validity, now we can modify stuff
        edges -= source_nodes
        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                        loc=loc)
        # redirect the neighbors' edges from the old nodes to the new one
        for node in edges:
            node.edges -= source_nodes
            node.edges.add(retval)
        for node in source_nodes:
            del self.__map[node.merged_ssa_val]
        self.__map[final_merged_ssa_val] = retval
        for ssa_val in final_merged_ssa_val.ssa_vals:
            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
        return retval

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        # a shared `repr_state` keeps node ids consistent across the nodes
        if repr_state is None:
            repr_state = IGNodeReprState()
        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
331
332
@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
    """the interference graph: one `IGNode` per `MergedSSAVal`.

    `merged_ssa_val_map` maps `SSAVal` -> `MergedSSAVal` and `nodes` maps
    `MergedSSAVal` -> `IGNode`; `nodes` is `merged_ssa_val_map.ig_node_map`.
    """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        ssa_val_map = SSAValToMergedSSAValMap()
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = ssa_val_map
        self.nodes = ssa_val_map.ig_node_map
        for msv in merged_ssa_vals:
            self.nodes.add_node(msv)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """merge the nodes containing `ssa_val1` and `ssa_val2` into one node
        in which `ssa_val2` sits `additional_offset` after `ssa_val1`.
        """
        merged1 = self.merged_ssa_val_map[ssa_val1]
        merged2 = self.merged_ssa_val_map[ssa_val2]
        # shift both groups so `ssa_val1` is at offset 0 and `ssa_val2` is at
        # `additional_offset`, then combine them
        lhs = merged1.with_offset_to_match(ssa_val1)
        rhs = merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        return self.nodes.merge_into_one_node(lhs.merged(rhs))

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """build a graph containing only the merges that are required:
        spread inputs/outputs are merged with the start of their unspread
        group, and outputs are merged with their tied input.
        """
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                if use.unspread_start != use:
                    graph.merge(use.unspread_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for out in op.outputs:
                # every output starts out as a node of its own
                graph.nodes.add_node(MergedSSAVal(fn_analysis, out))
                if out.unspread_start != out:
                    graph.merge(out.unspread_start, out,
                                additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    graph.merge(out.tied_input.ssa_val, out)
        return graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        s = self.nodes.__repr__(state)
        return f"InterferenceGraph(nodes={s}, <...>)"
379
380
@plain_data(repr=False)
class IGNodeReprState:
    """bookkeeping shared between nested `__repr__` calls on `IGNode`s.

    `node_ids` assigns each node a small id number and `did_full_repr`
    records the nodes that were already written out in full.
    """
    __slots__ = "node_ids", "did_full_repr"

    def __init__(self):
        super().__init__()
        self.node_ids = dict()  # type: dict[IGNode, int]
        self.did_full_repr = OSet()  # type: OSet[IGNode]
389
390
@final
class IGNode:
    """ interference graph node

    Equality and hashing are based on `merged_ssa_val` only.
    """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges, loc):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc

    def add_edge(self, other):
        # type: (IGNode) -> None
        """connect `self` and `other`; the edge is recorded on both nodes."""
        for src, dest in ((self, other), (other, self)):
            src.edges.add(dest)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, repr_state=None, short=False):
        # type: (None | IGNodeReprState, bool) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        ids = state.node_ids
        node_id = ids.get(self, None)
        if node_id is None:
            # first time this node is seen: hand out the next free id
            node_id = len(ids)
            ids[self] = node_id
        if short or self in state.did_full_repr:
            return f"<IGNode #{node_id}>"
        state.did_full_repr.add(self)
        # neighbors are always abbreviated, which also stops the recursion
        edges = ", ".join(i.__repr__(state, True) for i in self.edges)
        return (f"IGNode(#{node_id}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edges}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the `LocSet` of this node's `merged_ssa_val`."""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """return True if any neighbor that already has a `loc` reports a
        conflict between that `loc` and the given `loc`.
        """
        return any(
            neighbor.loc is not None and neighbor.loc.conflicts(loc)
            for neighbor in self.edges)
444
445
class AllocationFailedError(Exception):
    """raised when register allocation can't give an `IGNode` a usable `Loc`.

    The offending node and the whole interference graph are available as
    `node` and `interference_graph`.
    """

    def __init__(self, msg, node, interference_graph):
        # type: (str, IGNode, InterferenceGraph) -> None
        super().__init__(msg, node, interference_graph)
        self.node = node
        self.interference_graph = interference_graph

    def __repr__(self, repr_state=None):
        # type: (None | IGNodeReprState) -> str
        state = IGNodeReprState() if repr_state is None else repr_state
        # the node comes first (abbreviated), then the graph, both sharing
        # `state` so that node ids agree between the two
        node_text = self.node.__repr__(state, True)
        graph_text = self.interference_graph.__repr__(state)
        return (f"{__class__.__name__}({self.args[0]!r}, "
                f"node={node_text}, "
                f"interference_graph="
                f"{graph_text})")

    def __str__(self):
        # type: () -> str
        return self.__repr__()
465
466
def allocate_registers(fn):
    # type: (Fn) -> dict[SSAVal, Loc]
    """allocate a `Loc` for every `SSAVal` in `fn`.

    Steps:
    1. `fn.pre_ra_insert_copies()` is called (this modifies `fn`).
    2. an `InterferenceGraph` is built with the minimal set of merges, and an
       edge is added between every two merged values that are live at the
       same point and whose `LocSet`s can conflict.
    3. nodes are removed from the graph one at a time, preferring locally
       colorable nodes, and pushed on a stack.
    4. nodes are popped off the stack and each is given the first `Loc` in
       its `loc_set` that doesn't conflict with its already-allocated
       neighbors.

    Returns a dict mapping each `SSAVal` to its allocated `Loc`.

    Raises `AllocationFailedError` if a node is pre-allocated to a
    conflicting `Loc` or if no non-conflicting `Loc` is left for a node.
    """

    # inserts enough copies that no manual spilling is necessary, all
    # spilling is done by the register allocator naturally allocating SSAVals
    # to stack slots
    fn.pre_ra_insert_copies()

    fn_analysis = FnAnalysis(fn)
    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)

    for ssa_vals in fn_analysis.live_at.values():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(
                interference_graph.merged_ssa_val_map[ssa_val])
        # everything that is live at the same time and could end up in
        # overlapping locations interferes
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph.nodes[i].add_edge(
                    interference_graph.nodes[j])

    nodes_remaining = OSet(interference_graph.nodes.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    # TODO: implement copy-merging

    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            # nodes_remaining is empty -- every node is on the stack
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, Loc]

    while node_stack:
        node = node_stack.pop()
        if node.loc is not None:
            if node.loc_conflicts_with_neighbors(node.loc):
                raise AllocationFailedError(
                    "IGNode is pre-allocated to a conflicting Loc",
                    node=node, interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in node.loc_set
            # NOTE(review): this relies on loc_set iterating from most
            # preferred to least preferred Loc -- confirm against LocSet.
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                raise AllocationFailedError(
                    "failed to allocate Loc for IGNode",
                    node=node, interference_graph=interference_graph)

        merged_ssa_val = node.merged_ssa_val
        # By MergedSSAVal's definition of `offset`, node.loc starts where
        # the offset equals merged_ssa_val.offset, not where it equals 0, so
        # each member's position within node.loc is its offset relative to
        # merged_ssa_val.offset. InterferenceGraph.merge doesn't normalize
        # the merged values it creates, so that base offset can be non-zero.
        base_offset = merged_ssa_val.offset
        for ssa_val, offset in merged_ssa_val.ssa_val_offsets.items():
            retval[ssa_val] = node.loc.get_subloc_at_offset(
                ssa_val.ty, offset - base_offset)

    return retval