c7d9a88a79b873ed705441be57200867da86924d
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator2.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Any, Iterable, Iterator, Mapping, MutableSet
10
11 from cached_property import cached_property
12 from nmutil.plain_data import plain_data
13
14 from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc,
15 LocSet, Op, ProgramRange,
16 SSAVal, Ty)
17 from bigint_presentation_code.type_util import final
18 from bigint_presentation_code.util import FMap, OFSet, OSet
19
20
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """The span of op indexes from the op that first writes a value
    (`first_write`) to the op that last uses it (`last_use`).

    `last_use` defaults to `first_write`. Both must be nonnegative and
    `last_use` must not precede `first_write`.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        super().__init__()
        end = first_write if last_use is None else last_use
        # ordering is checked before sign so the error raised for inputs that
        # violate both rules stays the same
        if end < first_write:
            raise ValueError("uses must be after first_write")
        if min(first_write, end) < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """True if both intervals start at the same op, or if each one starts
        strictly before the other one's `last_use`."""
        return (self.first_write == other.first_write
                or (other.first_write < self.last_use
                    and self.first_write < other.last_use))

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """Return a copy of this interval extended to also cover `use`."""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index: `first_write` inclusive to `last_use` exclusive.
        """
        return range(self.first_write, self.last_use)
55
56
class BadMergedSSAVal(ValueError):
    """Raised when a `MergedSSAVal` can't be built or merged.

    Causes include: no members, mismatched `BaseTy`s between members,
    conflicting offsets for the same `SSAVal` while merging, or no `Loc`
    that satisfies every member.
    """
59
60
@plain_data(frozen=True)
@final
class MergedSSAVal:
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal(fn_analysis, {v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`

    Fields:
    * `fn_analysis`: the `FnAnalysis` whose `uses` are consulted when
      computing `loc_set`.
    * `ssa_val_offsets`: the offset of each member `SSAVal` (see above).
    * `first_ssa_val`: the first member in `ssa_val_offsets` iteration
      order; `base_ty` is taken from it.
    * `loc_set`: the `Loc`s for the whole merged value that are compatible
      with the def and every use of every member.

    Raises `BadMergedSSAVal` on construction if there are no members, the
    members' `BaseTy`s differ, or no `Loc` is left in `loc_set`.
    """
    # NOTE(review): `cached_property` caches into the instance `__dict__`,
    # but this class declares `__slots__`; confirm `plain_data` (or something
    # else) provides a `__dict__`, otherwise the cached properties below
    # can't work as intended.
    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        self.fn_analysis = fn_analysis
        # a bare SSAVal is shorthand for a single member at offset 0
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        first_ssa_val = None
        for ssa_val in self.ssa_vals:
            first_ssa_val = ssa_val
            break
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # loc_set is narrowed once per member: after each iteration it holds
        # only the whole-value Locs acceptable to every member seen so far
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            # `locs` closes over the loop variables (`ssa_val`, `cur_offset`,
            # `def_spread_idx`) and the previous `loc_set`; that is safe only
            # because `LocSet(locs())` below consumes it immediately, before
            # either is rebound.
            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    # reject def Locs that would put any use of ssa_val in a
                    # Loc that use doesn't allow
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate the member's Loc into the Loc of the whole
                    # merged value, using the offset definition from the
                    # class docstring
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    # intersect with what earlier members allowed
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @cached_property
    def __hash(self):
        # type: () -> int
        # only the defining fields are hashed: `first_ssa_val` and `loc_set`
        # are both derived from `fn_analysis` and `ssa_val_offsets`
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self.__hash

    @cached_property
    def offset(self):
        # type: () -> int
        """the smallest before-spread offset among the members"""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """the `BaseTy` of `first_ssa_val`; `ty` checks the other members
        match it"""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, in `ssa_val_offsets` iteration order"""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the smallest `Ty` covering every member's before-spread type placed
        at its before-spread offset relative to `self.offset`.

        Raises `BadMergedSSAVal` if any member's `BaseTy` differs from
        `base_ty`.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """each member's offset minus its defining descriptor's
        `reg_offset_in_unspread` (presumably the offset of the start of the
        unspread value the member belongs to -- TODO confirm)"""
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a new `MergedSSAVal` with every member's offset shifted by
        `amount`"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a shifted copy whose `offset` is 0"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """return a shifted copy in which the first member (in `self`'s order)
        that also appears in `target` has `target`'s offset for it plus
        `additional_offset`. A bare `SSAVal` target counts as offset 0.

        Raises `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """return the union of `self` and `others`, keeping every offset as-is.

        Raises `ValueError` if `fn_analysis` differs, and `BadMergedSSAVal` if
        an `SSAVal` appears with two different offsets (or if the constructor
        rejects the result).
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """the smallest `ProgramRange` covering every member's entry in
        `fn_analysis.live_ranges`"""
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)
238
239
@final
class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
    """Read-only mapping from each `SSAVal` to the `MergedSSAVal` holding it.

    The mapping itself can't be mutated; changes go through `values_set`,
    which shares this mapping's storage so both views stay consistent.
    """

    def __init__(self):
        # type: (...) -> None
        shared = {}  # type: dict[SSAVal, MergedSSAVal]
        self.__merge_map = shared
        self.__values_set = MergedSSAValsSet(
            _private_merge_map=shared,
            _private_values_set=OSet())

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        merge_map = self.__merge_map
        return merge_map[__key]

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__merge_map.keys())

    def __len__(self):
        # type: () -> int
        return len(self.__merge_map)

    @property
    def values_set(self):
        # type: () -> MergedSSAValsSet
        """the mutable set of `MergedSSAVal`s backing this mapping"""
        return self.__values_set

    def __repr__(self):
        # type: () -> str
        body = ",\n".join(map(repr, self.__values_set))
        return "MergedSSAValsMap({" + body + "})"
270
271
@final
class MergedSSAValsSet(MutableSet[MergedSSAVal]):
    """Mutable set of `MergedSSAVal`s kept in sync with a shared
    `SSAVal` -> `MergedSSAVal` dict.

    Adding a `MergedSSAVal` registers each of its `SSAVal`s in the shared
    dict, and discarding it unregisters them. An `SSAVal` may belong to at
    most one `MergedSSAVal` in the set. Instances are created by
    `MergedSSAValsMap`, which owns the shared dict.
    """

    def __init__(self, *,
                 _private_merge_map,  # type: dict[SSAVal, MergedSSAVal]
                 _private_values_set,  # type: OSet[MergedSSAVal]
                 ):
        # type: (...) -> None
        self.__merge_map = _private_merge_map
        self.__values_set = _private_values_set

    @classmethod
    def _from_iterable(cls, it):
        # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
        # set operators (`|`, `&`, ...) build plain `OSet`s because this
        # class can't be constructed from just an iterable
        return OSet(it)

    def __contains__(self, value):
        # type: (MergedSSAVal | Any) -> bool
        return value in self.__values_set

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__values_set)

    def __len__(self):
        # type: () -> int
        return len(self.__values_set)

    def add(self, value):
        # type: (MergedSSAVal) -> None
        """Add `value`, registering each of its `SSAVal`s.

        Raises `ValueError` if any of its `SSAVal`s already belongs to another
        member. On any exception nothing is left registered.
        """
        if value in self:
            return
        registered = []  # type: list[SSAVal]
        try:
            for ssa_val in value.ssa_vals:
                if ssa_val in self.__merge_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{value} and {self.__merge_map[ssa_val]}")
                self.__merge_map[ssa_val] = value
                registered.append(ssa_val)
            self.__values_set.add(value)
        except BaseException:
            # undo the partial registration before propagating
            for ssa_val in registered:
                del self.__merge_map[ssa_val]
            raise

    def discard(self, value):
        # type: (MergedSSAVal) -> None
        """Remove `value` (if present) and unregister its `SSAVal`s."""
        if value in self:
            self.__values_set.discard(value)
            for ssa_val in value.ssa_val_offsets.keys():
                del self.__merge_map[ssa_val]

    def __repr__(self):
        # type: () -> str
        body = ",\n".join(map(repr, self.__values_set))
        return "MergedSSAValsSet({" + body + "})"
334
335
@plain_data(frozen=True)
@final
class MergedSSAVals:
    """A partition of a function's `SSAVal`s into `MergedSSAVal`s.

    `merge_map` maps each tracked `SSAVal` to the `MergedSSAVal` containing
    it. `merged_ssa_vals` is the set of all `MergedSSAVal`s; it shares storage
    with `merge_map`, so adding or removing members keeps both in sync.
    """
    __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        self.fn_analysis = fn_analysis
        self.merge_map = MergedSSAValsMap()
        self.merged_ssa_vals = self.merge_map.values_set
        for i in merged_ssa_vals:
            self.merged_ssa_vals.add(i)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
        """Merge the groups containing `ssa_val1` and `ssa_val2` into one.

        In the result, `ssa_val2`'s offset is `ssa_val1`'s offset plus
        `additional_offset`. If both are already in the same group, the
        group is kept, provided its offsets agree with that requirement.

        Raises `KeyError` if either `SSAVal` isn't tracked, and
        `BadMergedSSAVal` if the offsets conflict or no `Loc` fits the merged
        group. In both cases `self` is left unchanged.
        """
        merged1 = self.merge_map[ssa_val1]
        merged2 = self.merge_map[ssa_val2]
        # build the result first so a failure leaves self untouched
        merged = merged1.with_offset_to_match(ssa_val1)
        merged = merged.merged(merged2.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset))
        self.merged_ssa_vals.remove(merged1)
        # merge_map stores the same instance for every member of a group, so
        # identity tells whether both were already merged; removing the same
        # group twice would raise KeyError
        if merged2 is not merged1:
            self.merged_ssa_vals.remove(merged2)
        self.merged_ssa_vals.add(merged)
        return merged

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> MergedSSAVals
        """Return the partition that keeps each `SSAVal` alone, except where
        ops require merging.

        Merging is required for spread inputs/outputs, which are merged with
        their unspread start, and for outputs tied to an input.
        """
        # every SSAVal is defined by exactly one op output; start with each
        # in its own group so `merge` can find them in `merge_map`
        retval = MergedSSAVals(
            fn_analysis=fn_analysis,
            merged_ssa_vals=(
                MergedSSAVal(fn_analysis=fn_analysis, ssa_val_offsets=out)
                for op in fn_analysis.fn.ops
                for out in op.outputs))
        for op in fn_analysis.fn.ops:
            for inp in op.input_uses:
                if inp.unspread_start != inp:
                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                 additional_offset=inp.reg_offset_in_unspread)
            for out in op.outputs:
                if out.unspread_start != out:
                    retval.merge(out.unspread_start, out,
                                 additional_offset=out.reg_offset_in_unspread)
                if out.tied_input is not None:
                    retval.merge(out.tied_input.ssa_val, out)
        return retval
377
378
@final
class IGNode:
    """ interference graph node

    Wraps one `MergedSSAVal`. `edges` holds the nodes it interferes with;
    `add_edge` keeps edges symmetric. `loc` is the assigned `Loc`, or `None`
    while unassigned. Equality and hashing use only `merged_ssa_val`.
    """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges=(), loc=None):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc

    def add_edge(self, other):
        # type: (IGNode) -> None
        """Connect `self` and `other` in both directions."""
        for src, dst in ((self, other), (other, self)):
            src.edges.add(dst)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        return hash(self.merged_ssa_val)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` numbers every node already printed, so cycles in the graph
        # are printed as back-references instead of recursing forever
        seen = {} if nodes is None else nodes
        if self in seen:
            return f"<IGNode #{seen[self]}>"
        index = len(seen)
        seen[self] = index
        neighbor_reprs = [neighbor.__repr__(seen) for neighbor in self.edges]
        edges = "{" + ", ".join(neighbor_reprs) + "}"
        return (f"IGNode(#{index}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={edges}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the candidate `Loc`s of the wrapped `MergedSSAVal`"""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """True if any neighbor already has a `Loc` conflicting with `loc`."""
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
428
429
@plain_data()
class AllocationFailed:
    """The result of a failed register allocation attempt.

    * `node`: the interference-graph node that couldn't be given a `Loc`.
    * `merged_ssa_vals`: the partition that was being allocated.
    * `interference_graph`: maps each `MergedSSAVal` to its `IGNode`.
    """
    __slots__ = "node", "merged_ssa_vals", "interference_graph"

    def __init__(self, node, merged_ssa_vals, interference_graph):
        # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
        super().__init__()
        self.node, self.merged_ssa_vals, self.interference_graph = (
            node, merged_ssa_vals, interference_graph)
440
441
class AllocationFailedError(Exception):
    """Raised when registers can't be allocated.

    The `AllocationFailed` details are exposed as `allocation_failed` and
    are also the second element of `args`.
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        self.allocation_failed = allocation_failed
        super().__init__(msg, allocation_failed)
447
448
def try_allocate_registers_without_spilling(merged_ssa_vals):
    # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
    """Try to assign a `Loc` to every `SSAVal` without spilling.

    Builds an interference graph with one `IGNode` per `MergedSSAVal`. Two
    nodes interfere when they are live at the same point and their `loc_set`s
    can conflict. The nodes are then colored in two phases. Simplify
    repeatedly removes the most locally-colorable node onto a stack. Select
    pops nodes and gives each the first `Loc` in its `loc_set` that doesn't
    conflict with an already-assigned neighbor.

    Returns a dict mapping each `SSAVal` to its `Loc`, or an
    `AllocationFailed` naming the first node that couldn't be assigned.
    """
    fn_analysis = merged_ssa_vals.fn_analysis
    interference_graph = {
        i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals
    }  # type: dict[MergedSSAVal, IGNode]
    for ssa_vals in fn_analysis.live_at.values():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())  # type: OSet[IGNode]

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    # simplify: push nodes in order of decreasing local colorability
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    # select: assign Locs in reverse removal order
    retval = {}  # type: dict[SSAVal, Loc]

    while node_stack:
        node = node_stack.pop()
        if node.loc is not None:
            # pre-assigned Loc: only verify it
            if node.loc_conflicts_with_neighbors(node.loc):
                return AllocationFailed(node=node,
                                        merged_ssa_vals=merged_ssa_vals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in node.loc_set
            # (presumably ordered from most to least preferred -- TODO
            # confirm against LocSet's iteration order)
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                return AllocationFailed(node=node,
                                        merged_ssa_vals=merged_ssa_vals,
                                        interference_graph=interference_graph)

        merged_ssa_val = node.merged_ssa_val
        for ssa_val, offset in merged_ssa_val.ssa_val_offsets.items():
            # from MergedSSAVal's offset definition:
            # locs[ssa_val].start == locs[msv].start + offset - msv.offset
            retval[ssa_val] = Loc(
                kind=node.loc.kind,
                start=node.loc.start + offset - merged_ssa_val.offset,
                reg_len=ssa_val.ty.reg_len)

    return retval
525
526
def allocate_registers(ops):
    # type: (FnAnalysis | MergedSSAVals) -> dict[SSAVal, Loc]
    """Allocate a `Loc` for every `SSAVal`, or raise if that needs spilling.

    :param ops: either a `MergedSSAVals` to allocate, or a `FnAnalysis`,
        which is first converted with `MergedSSAVals.minimally_merged`. The
        parameter keeps its historical name for compatibility. A `list[Op]`
        never worked here, because the allocator needs a `MergedSSAVals`.
    :returns: a dict mapping each `SSAVal` to its allocated `Loc`.
    :raises AllocationFailedError: if allocation fails; spilling isn't
        implemented yet.
    """
    if isinstance(ops, FnAnalysis):
        merged_ssa_vals = MergedSSAVals.minimally_merged(ops)
    else:
        merged_ssa_vals = ops
    retval = try_allocate_registers_without_spilling(merged_ssa_vals)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    return retval