working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator2.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Any, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Dict
10
11 from cached_property import cached_property
12 from nmutil.plain_data import plain_data
13
14 from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc,
15 LocSet, Op, ProgramRange,
16 SSAVal, Ty)
17 from bigint_presentation_code.type_util import final
18 from bigint_presentation_code.util import FMap, OFSet, OSet
19
20
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """The span of op indexes from `first_write` (the op that first writes a
    value) to `last_use` (the last op that uses it).

    Both bounds must be nonnegative and `last_use` may not come before
    `first_write`; omitting `last_use` makes it equal to `first_write`.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        super().__init__()
        end = first_write if last_use is None else last_use
        if end < first_write:
            raise ValueError("uses must be after first_write")
        if min(first_write, end) < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """Return True if `self` and `other` overlap.

        Intervals that start at the same op always overlap. Otherwise each
        must end strictly after the other starts, so an interval whose
        `last_use` equals the other's `first_write` does not overlap it.
        """
        starts_together = self.first_write == other.first_write
        return starts_together or (other.first_write < self.last_use
                                   and self.first_write < other.last_use)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """Return a copy whose `last_use` is extended to at least `use`."""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)
55
56
class BadMergedSSAVal(ValueError):
    """Raised when a `MergedSSAVal` is invalid: empty, mixing `BaseTy`s,
    combining members at conflicting offsets, or left with no usable `Loc`.
    """
59
60
@plain_data(frozen=True)
@final
class MergedSSAVal:
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal(fn_analysis, {v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"

    def __init__(self, fn_analysis, ssa_val_offsets):
        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
        """Build a `MergedSSAVal` and compute the `LocSet` it may occupy.

        `ssa_val_offsets` maps each member `SSAVal` to its offset (see the
        class docstring); a bare `SSAVal` is shorthand for `{ssa_val: 0}`.

        Raises `BadMergedSSAVal` if `ssa_val_offsets` is empty, if the
        members' `BaseTy`s differ (checked by `self.ty`), or if no `Loc`
        for the whole merged value is acceptable to every member's def and
        uses.
        """
        self.fn_analysis = fn_analysis
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        first_ssa_val = next(iter(self.ssa_vals), None)
        if first_ssa_val is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.first_ssa_val = first_ssa_val  # type: SSAVal
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # Intersect, over all members, the Locs of the whole merged value
        # that each member's def (and all of its uses) can accept.
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            def locs():
                # type: () -> Iterable[Loc]
                # NOTE: closes over the loop variables and over the previous
                # `loc_set`. That is safe because the generator is fully
                # consumed by `LocSet(...)` before the loop advances and
                # before `loc_set` is rebound.
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_analysis.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        use_start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=use_start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate this member's Loc into the whole value's Loc
                    merged_start = loc.start - cur_offset + self.offset
                    merged_loc = Loc.try_make(
                        loc.kind, start=merged_start, reg_len=reg_len)
                    if merged_loc is not None and (
                            loc_set is None or merged_loc in loc_set):
                        yield merged_loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @cached_property
    def _hash_value(self):
        # type: () -> int
        # NOTE: this must not use a `__`-prefixed name. The `cached_property`
        # package caches under the wrapped function's `__name__`, which is
        # not name-mangled, while attribute lookups use the mangled name, so
        # a `__hash` property would be recomputed on every access.
        return hash((self.fn_analysis, self.ssa_val_offsets))

    def __hash__(self):
        # type: () -> int
        return self._hash_value

    @cached_property
    def offset(self):
        # type: () -> int
        """the smallest of `ssa_val_offsets_before_spread`'s values"""
        return min(self.ssa_val_offsets_before_spread.values())

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """the `BaseTy` of `first_ssa_val` (`ty` checks the rest match)"""
        return self.first_ssa_val.base_ty

    @cached_property
    def ssa_vals(self):
        # type: () -> OFSet[SSAVal]
        """the member `SSAVal`s, in `ssa_val_offsets` order"""
        return OFSet(self.ssa_val_offsets.keys())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the `Ty` just large enough to hold every member (before spread)
        at its offset relative to `self.offset`.

        Raises `BadMergedSSAVal` if any member's `BaseTy` differs.
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """`ssa_val_offsets` with each member's
        `defining_descriptor.reg_offset_in_unspread` subtracted.
        """
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            retval[ssa_val] = (
                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """Return a copy with every member's offset increased by `amount`."""
        new_offsets = {ssa_val: offset + amount
                       for ssa_val, offset in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=new_offsets)

    def normalized(self):
        # type: () -> MergedSSAVal
        """Return a copy shifted so that its `offset` is 0."""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target, additional_offset=0):
        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
        """Return a copy shifted so that the first member (in
        `ssa_val_offsets` order) that also appears in `target` ends up at
        `target`'s offset for it plus `additional_offset`.

        A bare `SSAVal` target is treated as `{target: 0}`. Raises
        `ValueError` if `self` and `target` share no `SSAVal`.
        """
        if isinstance(target, MergedSSAVal):
            ssa_val_offsets = target.ssa_val_offsets
        else:
            ssa_val_offsets = {target: 0}
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in ssa_val_offsets:
                return self.offset_by(
                    ssa_val_offsets[ssa_val] + additional_offset - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")

    def merged(self, *others):
        # type: (*MergedSSAVal) -> MergedSSAVal
        """Return the union of `self` and `others`, keeping all offsets as-is.

        Raises `ValueError` if any `fn_analysis` differs, and
        `BadMergedSSAVal` if an `SSAVal` appears with two different offsets
        (or if the combined value is otherwise invalid).
        """
        retval = dict(self.ssa_val_offsets)
        for other in others:
            if other.fn_analysis != self.fn_analysis:
                raise ValueError("fn_analysis mismatch")
            for ssa_val, offset in other.ssa_val_offsets.items():
                if ssa_val in retval and retval[ssa_val] != offset:
                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                          f"{retval[ssa_val]} != {offset}")
                retval[ssa_val] = offset
        return MergedSSAVal(fn_analysis=self.fn_analysis,
                            ssa_val_offsets=retval)

    @cached_property
    def live_interval(self):
        # type: () -> ProgramRange
        """the smallest `ProgramRange` covering every member's
        `fn_analysis.live_ranges` entry
        """
        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
        start = live_range.start
        stop = live_range.stop
        for ssa_val in self.ssa_vals:
            live_range = self.fn_analysis.live_ranges[ssa_val]
            start = min(start, live_range.start)
            stop = max(stop, live_range.stop)
        return ProgramRange(start=start, stop=stop)
238
239
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
    """Read-only mapping from each `SSAVal` to the `MergedSSAVal` holding it.

    This class never mutates its dict itself. The same dict is handed to the
    `MergedSSAValToIGNodeMap` exposed as `ig_node_map`, whose `add_node`
    fills it in.
    """

    def __init__(self):
        # type: (...) -> None
        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
        # share the dict so both mappings always agree
        self.__ig_node_map = MergedSSAValToIGNodeMap(
            _private_merged_ssa_val_map=self.__map)

    def __getitem__(self, __key):
        # type: (SSAVal) -> MergedSSAVal
        merged = self.__map[__key]
        return merged

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        return iter(self.__map.keys())

    def __len__(self):
        # type: () -> int
        return len(self.__map.keys())

    @property
    def ig_node_map(self):
        # type: () -> MergedSSAValToIGNodeMap
        """the node map that shares this mapping's underlying dict"""
        return self.__ig_node_map

    def __repr__(self):
        # type: () -> str
        # list each distinct MergedSSAVal (the node map's keys) once,
        # rather than once per SSAVal it contains
        body = ",\n".join(map(repr, self.__ig_node_map))
        return "SSAValToMergedSSAValMap({" + body + "})"
269
270
# NOTE: `IGNode` is defined further down this module, so the base class
# refers to it as a string forward reference; a bare `IGNode` here would be
# evaluated at class-creation time and raise NameError on import.
@final
class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
    """Mapping from each `MergedSSAVal` to its interference-graph `IGNode`.

    As nodes are added it also records every member `SSAVal` in the
    `SSAVal -> MergedSSAVal` dict it was constructed with, so each `SSAVal`
    belongs to at most one node.
    """

    def __init__(
        self, *,
        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
    ):
        # type: (...) -> None
        self.__merged_ssa_val_map = _private_merged_ssa_val_map
        self.__map = {}  # type: dict[MergedSSAVal, IGNode]

    def __getitem__(self, __key):
        # type: (MergedSSAVal) -> IGNode
        return self.__map[__key]

    def __iter__(self):
        # type: () -> Iterator[MergedSSAVal]
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        return len(self.__map)

    def add_node(self, merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Return the node for `merged_ssa_val`, creating it if needed.

        Raises `ValueError` if one of its `SSAVal`s is already registered
        for another `MergedSSAVal`. On any failure the `SSAVal` entries
        registered by this call are removed again, leaving both maps
        unchanged.
        """
        node = self.__map.get(merged_ssa_val, None)
        if node is not None:
            return node
        added = []  # type: list[SSAVal]
        try:
            for ssa_val in merged_ssa_val.ssa_vals:
                if ssa_val in self.__merged_ssa_val_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and "
                        f"{self.__merged_ssa_val_map[ssa_val]}")
                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                added.append(ssa_val)
            node = IGNode(merged_ssa_val)
            self.__map[merged_ssa_val] = node
        except BaseException:
            # remove partially added stuff
            for ssa_val in added:
                del self.__merged_ssa_val_map[ssa_val]
            raise
        return node

    def merge_into_one_node(self, final_merged_ssa_val):
        # type: (MergedSSAVal) -> IGNode
        """Combine the nodes covering `final_merged_ssa_val` into one node.

        Not implemented yet. Currently this only validates the request and
        then raises `NotImplementedError`.

        Raises:
            KeyError: if a member `SSAVal` isn't in any node yet.
            ValueError: if a source node contains an `SSAVal` that is not
                in `final_merged_ssa_val`.
            NotImplementedError: always, once validation passes.
        """
        # the existing nodes that the merged node would replace
        source_nodes = {}  # type: dict[MergedSSAVal, IGNode]
        for ssa_val in final_merged_ssa_val.ssa_vals:
            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
            source_nodes[merged_ssa_val] = self.__map[merged_ssa_val]
            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                raise ValueError(
                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                    f"but not in merged IGNode's merged_ssa_val: "
                    f"source_node={self.__map[merged_ssa_val]} "
                    f"final_merged_ssa_val={final_merged_ssa_val}")
        # FIXME: work on function from here -- presumably still needs to
        # drop `source_nodes` from `self.__map`, create one IGNode for
        # `final_merged_ssa_val` carrying over their edges, and re-point
        # every member SSAVal at `final_merged_ssa_val`.
        raise NotImplementedError("merge_into_one_node")

    def __repr__(self):
        # type: () -> str
        s = ",\n".join(repr(v) for v in self.__map.values())
        return f"MergedSSAValToIGNodeMap({{{s}}})"
342
343
@plain_data(frozen=True)
@final
class InterferenceGraph:
    """Interference graph over the `MergedSSAVal`s of one `FnAnalysis`.

    `merged_ssa_val_map` maps each `SSAVal` to its `MergedSSAVal`, and
    `nodes` (sharing that map's storage) maps each `MergedSSAVal` to its
    `IGNode`.
    """
    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"

    def __init__(self, fn_analysis, merged_ssa_vals):
        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
        ssa_val_map = SSAValToMergedSSAValMap()
        self.fn_analysis = fn_analysis
        self.merged_ssa_val_map = ssa_val_map
        self.nodes = ssa_val_map.ig_node_map
        for merged_ssa_val in merged_ssa_vals:
            ssa_val_map.ig_node_map.add_node(merged_ssa_val)

    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
        # type: (SSAVal, SSAVal, int) -> IGNode
        """Merge the nodes containing `ssa_val1` and `ssa_val2` so that
        `ssa_val2` sits `additional_offset` registers after `ssa_val1`.
        """
        # look up both groups before re-anchoring either one
        lhs = self.merged_ssa_val_map[ssa_val1]
        rhs = self.merged_ssa_val_map[ssa_val2]
        # put ssa_val1 at offset 0 and ssa_val2 at additional_offset
        lhs_anchored = lhs.with_offset_to_match(ssa_val1)
        rhs_anchored = rhs.with_offset_to_match(
            ssa_val2, additional_offset=additional_offset)
        combined = lhs_anchored.merged(rhs_anchored)
        return self.nodes.merge_into_one_node(combined)

    @staticmethod
    def minimally_merged(fn_analysis):
        # type: (FnAnalysis) -> InterferenceGraph
        """Build a graph with one node per op output, then merge the nodes
        that must share registers: pieces of an unspread value (for inputs
        and outputs) and outputs tied to an input.
        """
        graph = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
        for op in fn_analysis.fn.ops:
            for use in op.input_uses:
                use_start = use.unspread_start
                if use_start != use:
                    graph.merge(use_start.ssa_val, use.ssa_val,
                                additional_offset=use.reg_offset_in_unspread)
            for out in op.outputs:
                graph.nodes.add_node(MergedSSAVal(fn_analysis, out))
                out_start = out.unspread_start
                if out_start != out:
                    graph.merge(out_start, out,
                                additional_offset=out.reg_offset_in_unspread)
                tied = out.tied_input
                if tied is not None:
                    graph.merge(tied.ssa_val, out)
        return graph
383
384
@final
class IGNode:
    """Interference graph node for one `MergedSSAVal`.

    `edges` holds the interfering nodes (`add_edge` links both directions)
    and `loc` is the `Loc` assigned to this node, or None while unassigned.
    Equality and hashing use only `merged_ssa_val`.
    """
    __slots__ = "merged_ssa_val", "edges", "loc"

    def __init__(self, merged_ssa_val, edges=(), loc=None):
        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
        self.merged_ssa_val = merged_ssa_val
        self.edges = OSet(edges)
        self.loc = loc

    def add_edge(self, other):
        # type: (IGNode) -> None
        """Link `self` and `other` in both directions."""
        own_edges, other_edges = self.edges, other.edges
        own_edges.add(other)
        other_edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_ssa_val == other.merged_ssa_val

    def __hash__(self):
        # type: () -> int
        return hash(self.merged_ssa_val)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` numbers already-printed nodes so that cycles through
        # `edges` print as back-references instead of recursing forever
        seen = {} if nodes is None else nodes
        if self in seen:
            return f"<IGNode #{seen[self]}>"
        index = len(seen)
        seen[self] = index
        edge_reprs = ", ".join(edge.__repr__(seen) for edge in self.edges)
        return (f"IGNode(#{index}, "
                f"merged_ssa_val={self.merged_ssa_val}, "
                f"edges={{{edge_reprs}}}, "
                f"loc={self.loc})")

    @property
    def loc_set(self):
        # type: () -> LocSet
        """the `Loc`s this node may be assigned (from its `MergedSSAVal`)"""
        return self.merged_ssa_val.loc_set

    def loc_conflicts_with_neighbors(self, loc):
        # type: (Loc) -> bool
        """Return True if `loc` conflicts with any assigned neighbor's `loc`."""
        return any(neighbor.loc is not None and neighbor.loc.conflicts(loc)
                   for neighbor in self.edges)
434
435
@plain_data()
class AllocationFailed:
    """Result of a failed allocation attempt: the `IGNode` that couldn't be
    given a `Loc`, the allocator's input, and the interference graph built
    from it.
    """
    __slots__ = "node", "merged_ssa_vals", "interference_graph"

    def __init__(self, node, merged_ssa_vals, interference_graph):
        # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
        # NOTE(review): `MergedSSAVals` is not defined or imported in this
        # module -- TODO confirm the intended type.
        super().__init__()
        self.node, self.merged_ssa_vals, self.interference_graph = (
            node, merged_ssa_vals, interference_graph)
446
447
class AllocationFailedError(Exception):
    """Exception wrapping an `AllocationFailed` result.

    `args` is `(msg, allocation_failed)`; the result is also available as
    the `allocation_failed` attribute.
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        self.allocation_failed = allocation_failed
        super().__init__(msg, allocation_failed)
453
454
def try_allocate_registers_without_spilling(merged_ssa_vals):
    # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
    """Try to assign a `Loc` to every `SSAVal` without spilling.

    Works in three phases:
    1. build an interference graph with one `IGNode` per `MergedSSAVal`,
       adding an edge between two nodes when their `SSAVal`s appear together
       in one of `fn_analysis.live_at`'s entries and their `LocSet`s can
       conflict;
    2. simplify: repeatedly move a node from the graph onto a stack, taking
       the first node found with a positive local colorability score, or
       else the highest-scoring node;
    3. select: pop nodes off the stack and give each the first `Loc` in its
       `loc_set` that doesn't conflict with an already-assigned neighbor.

    NOTE(review): `MergedSSAVals` (providing `merged_ssa_vals`,
    `fn_analysis` and `merge_map`) is not defined or imported in this
    module, so this function still targets an older container API --
    TODO port it to `InterferenceGraph`.

    Returns a dict mapping every member `SSAVal` to its `Loc`, or an
    `AllocationFailed` for the first popped node that couldn't get a `Loc`.
    """
    interference_graph = {
        i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals}
    fn_analysis = merged_ssa_vals.fn_analysis
    for ssa_vals in fn_analysis.live_at.values():
        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
        for ssa_val in ssa_vals:
            live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val])
        for i, j in combinations(live_merged_ssa_vals, 2):
            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    # FIXME: work on code from here

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.loc_set)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
        return retval

    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, Loc]

    while node_stack:
        node = node_stack.pop()
        if node.loc is not None:
            # already has a Loc: only check it against the neighbors
            if node.loc_conflicts_with_neighbors(node.loc):
                return AllocationFailed(node=node,
                                        merged_ssa_vals=merged_ssa_vals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting Loc in `node.loc_set`'s
            # iteration order (presumably most-preferred first -- TODO
            # confirm for LocSet)
            for loc in node.loc_set:
                if not node.loc_conflicts_with_neighbors(loc):
                    node.loc = loc
                    break
            if node.loc is None:
                return AllocationFailed(node=node,
                                        merged_ssa_vals=merged_ssa_vals,
                                        interference_graph=interference_graph)

        merged_ssa_val = node.merged_ssa_val
        for ssa_val, offset in merged_ssa_val.ssa_val_offsets.items():
            # from the `MergedSSAVal` offset definition:
            # locs[ssa_val].start == locs[msv].start + offset - msv.offset
            start = node.loc.start + offset - merged_ssa_val.offset
            retval[ssa_val] = Loc(kind=node.loc.kind, start=start,
                                  reg_len=ssa_val.ty.reg_len)

    return retval
531
532
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, Loc]
    """Allocate a `Loc` for every `SSAVal`, without spilling.

    Returns the `SSAVal -> Loc` mapping produced by
    `try_allocate_registers_without_spilling`.

    Raises `AllocationFailedError` (wrapping the `AllocationFailed` result)
    when allocation without spilling fails, since spilling isn't
    implemented yet.

    NOTE(review): `ops` is passed straight to
    `try_allocate_registers_without_spilling`, whose parameter is typed as
    `MergedSSAVals`, not `list[Op]` -- TODO reconcile the two.
    """
    retval = try_allocate_registers_without_spilling(ops)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    return retval