# working on refactoring register allocator to use new IR
# source: [bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator2.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from functools import reduce
10 from typing import Generic, Iterable, Mapping
11 from cached_property import cached_property
12 import operator
13
14 from nmutil.plain_data import plain_data
15
16 from bigint_presentation_code.compiler_ir2 import (
17 Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses)
18 from bigint_presentation_code.type_util import final, Self
19 from bigint_presentation_code.util import OFSet, OSet, FMap
20
21
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """range of op indexes over which a value is live: written at
    `first_write`, last read at `last_use`. `first_write == last_use`
    means the value is never live after any op (see
    `live_after_op_range`, which is empty in that case).
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        # a value with no uses is live only at its write
        last_use = first_write if last_use is None else last_use
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """True if `self` and `other` are ever live at the same time"""
        # two values written by the same op always overlap, even when
        # one of them is never used afterwards
        if self.first_write == other.first_write:
            return True
        return (other.first_write < self.last_use
                and self.first_write < other.last_use)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """extend this interval to cover the use at op index `use`"""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)
55
56
class BadMergedSSAVal(ValueError):
    """raised when a `MergedSSAVal` can't be constructed -- e.g. it would
    be empty, would mix `BaseTy`s, or has no valid `Loc`s left.
    """
59
60
@plain_data(frozen=True, unsafe_hash=True)
@final
class MergedSSAVal:
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    v1.ty == <I64*4>
    v2.ty == <I64*2>
    v3.ty == <I64>
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    msv.ty == <I64*6>
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"

    def __init__(self, fn_with_uses, ssa_val_offsets):
        # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
        self.fn_with_uses = fn_with_uses
        # a bare SSAVal is shorthand for a singleton at offset 0
        if isinstance(ssa_val_offsets, SSAVal):
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        # take the BaseTy from any member; self.ty later verifies that
        # all members agree
        base_ty = None
        for ssa_val in self.ssa_val_offsets.keys():
            base_ty = ssa_val.base_ty
            break
        if base_ty is None:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.base_ty = base_ty  # type: BaseTy
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # intersect the Loc constraints of every member: loc_set starts
        # as None (meaning "unconstrained") and is narrowed by each
        # member in turn
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            def locs():
                # type: () -> Iterable[Loc]
                # yields every Loc for the whole merged value compatible
                # with this member's def and all of its uses.
                # NOTE(review): closes over ssa_val/cur_offset/loc_set
                # from the enclosing loop -- safe assuming LocSet(locs())
                # below consumes the generator within the same iteration
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_with_uses.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate from this member's frame into the merged
                    # value's frame
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    # keep only Locs still allowed by every earlier member
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            # presumably LocSet.ty is None iff the set is empty -- all
            # candidate Locs were filtered out
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @cached_property
    def offset(self):
        # type: () -> int
        """the minimum pre-spread offset of any member"""
        return min(self.ssa_val_offsets_before_spread.values())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the merged value's `Ty`; raises `BadMergedSSAVal` on
        mismatched `BaseTy`s among members
        """
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            # total length reaches the farthest end of any member,
            # measured from the minimum offset
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """member offsets translated back to before any spread operation"""
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            offset_before_spread = offset
            spread_index = ssa_val.defining_descriptor.spread_index
            if spread_index is not None:
                assert ssa_val.ty.reg_len == 1, (
                    "this function assumes spreading always converts a vector "
                    "to a contiguous sequence of scalars, if that's changed "
                    "in the future, then this function needs to be adjusted")
                offset_before_spread -= spread_index
            retval[ssa_val] = offset_before_spread
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """a copy of `self` with every member's offset shifted by `amount`"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """a copy of `self` shifted so `self.offset` becomes zero"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target):
        # type: (MergedSSAVal) -> MergedSSAVal
        """a copy of `self` shifted so that the first `SSAVal` shared with
        `target` gets the same offset it has in `target`
        """
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in target.ssa_val_offsets:
                return self.offset_by(target.ssa_val_offsets[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")
195
196
@final
class MergedSSAVals(OFSet[MergedSSAVal]):
    """a frozen set of non-overlapping `MergedSSAVal`s -- no `SSAVal` may
    belong to more than one element.
    """

    def __init__(self, merged_ssa_vals=()):
        # type: (Iterable[MergedSSAVal]) -> None
        super().__init__(merged_ssa_vals)
        # build the SSAVal -> containing-MergedSSAVal index, rejecting
        # any SSAVal claimed by two different elements
        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
        for merged_ssa_val in self:
            for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
                if ssa_val in merge_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and {merge_map[ssa_val]}")
                merge_map[ssa_val] = merged_ssa_val
        self.__merge_map = FMap(merge_map)

    @cached_property
    def merge_map(self):
        # type: () -> FMap[SSAVal, MergedSSAVal]
        """map from each `SSAVal` to the `MergedSSAVal` containing it"""
        return self.__merge_map

    @staticmethod
    def __consecutive_offsets(ssa_vals):
        # type: (Iterable[SSAVal]) -> dict[SSAVal, int]
        """lay `ssa_vals` out consecutively starting at offset 0 and
        return each one's offset.
        """
        retval = {}  # type: dict[SSAVal, int]
        offset = 0
        for ssa_val in ssa_vals:
            retval[ssa_val] = offset
            # TODO(review): confirm `ty` vs. `ty_before_spread` is the
            # right length once equality constraints are fully ported
            offset += ssa_val.ty.reg_len
        return retval

    @staticmethod
    def minimally_merged(fn_with_uses):
        # type: (FnWithUses) -> MergedSSAVals
        """merge the smallest groups of `SSAVal`s in `fn_with_uses` such
        that every equality constraint is satisfied, returning the
        resulting `MergedSSAVals`.
        """
        # fixes the unfinished old-IR port: the previous draft had a
        # syntax error, used old `MergedRegSet` names, referenced `self`
        # from a @staticmethod, and never returned anything
        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
        for op in fn_with_uses.fn.ops:
            # start every value in its own singleton MergedSSAVal
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merge_map:
                    merge_map[val] = MergedSSAVal(fn_with_uses, val)
            for e in op.get_equality_constraints():
                # both sides of an equality constraint occupy the same
                # contiguous span of registers, so laying each side out
                # consecutively from offset 0 puts all of their members
                # in one common frame.
                # TODO(review): assumes `e.lhs`/`e.rhs` are sequences of
                # `SSAVal`s like in the old IR -- confirm against
                # compiler_ir2's equality-constraint type
                items = {}  # type: dict[SSAVal, int]
                for side in (e.lhs, e.rhs):
                    target = MergedSSAVal(
                        fn_with_uses,
                        MergedSSAVals.__consecutive_offsets(side))
                    for val in side:
                        # shift the set currently containing `val` into
                        # the common frame and accumulate its members
                        merged = merge_map[val].with_offset_to_match(target)
                        items.update(merged.ssa_val_offsets)
                full_set = MergedSSAVal(fn_with_uses, items).normalized()
                # point every member at the newly merged set
                for val in full_set.ssa_val_offsets.keys():
                    merge_map[val] = full_set
        # merge_map holds one entry per member SSAVal, so the same
        # MergedSSAVal appears repeatedly; the OFSet base deduplicates
        return MergedSSAVals(merge_map.values())
243
244
@final
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """map from each `MergedRegSet` in `ops` to its `LiveInterval`, plus a
    precomputed table of which reg-sets are live after each op index.

    NOTE(review): still written against the old IR -- `MergedRegSet`,
    `MergedRegSets` and `_RegType` are not defined in this file after the
    refactor to `compiler_ir2`; needs porting like `MergedSSAVals` above.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # extend the interval of the set containing this input;
                # a KeyError here would mean a value is read before any
                # write -- presumably impossible in well-formed SSA code
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    # first write starts the live interval
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    # later writes (to other members of the same merged
                    # set) just extend it
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # invert the intervals into a per-op-index table of live reg-sets
        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
        live_after += (OSet() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [OFSet(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        # the MergedRegSets computed from `ops` at construction time
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        """the `MergedRegSet`s live immediately after the op at `op_index`"""
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
291
292
@final
class IGNode(Generic[_RegType]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        # interference is symmetric: record it on both endpoints
        other.edges.add(self)
        self.edges.add(other)

    def __eq__(self, other):
        # type: (object) -> bool
        # identity of a node is its merged_reg_set
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_reg_set == other.merged_reg_set

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` assigns each node a number as it's first printed, so a
        # cyclic graph prints back-references instead of recursing forever
        nodes = {} if nodes is None else nodes
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = ", ".join(i.__repr__(nodes) for i in self.edges)
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={{{edges}}}, "
                f"reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """True if any already-colored neighbor's register conflicts
        with `reg`
        """
        return any(neighbor.reg is not None and neighbor.reg.conflicts(reg)
                   for neighbor in self.edges)
342
343
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """maps each `MergedRegSet` to its interference-graph node"""

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        nodes = {}
        for reg_set in merged_reg_sets:
            nodes[reg_set] = IGNode(reg_set)
        self.__nodes = nodes

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        # share one numbering dict across all nodes so cross-references
        # print as <IGNode #n> rather than repeating whole nodes
        nodes = {}
        parts = (f"...: {node.__repr__(nodes)}" for node in self.values())
        return f"InterferenceGraph(nodes={{{', '.join(parts)}}})"
365
366
@plain_data()
class AllocationFailed:
    """failure result from `try_allocate_registers_without_spilling` --
    carries the node that couldn't be assigned a register plus the
    analysis state, so the caller can decide how to spill.
    """
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        # the interference-graph node that couldn't be colored
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
376
377
class AllocationFailedError(Exception):
    """raised when register allocation fails and spilling can't help
    (or isn't implemented)
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        self.allocation_failed = allocation_failed
        # include both pieces in `args` so they show up in str()/repr()
        super().__init__(msg, allocation_failed)
383
384
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """try to allocate a register for every `SSAVal` in `ops` by graph
    coloring, without spilling.

    returns the allocation map on success, or an `AllocationFailed`
    describing the uncolorable node otherwise.

    NOTE(review): still written against the old IR (`RegLoc`,
    `reg_class`, `MergedRegSet`-based `LiveIntervals`); needs porting to
    `compiler_ir2` like the classes above.
    """
    # build the interference graph: two merged reg-sets get an edge when
    # they're live at the same time and their register classes can
    # actually conflict
    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        # ops can also demand interference beyond liveness overlap
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        # available registers minus the worst-case conflicts from
        # not-yet-removed neighbors
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly remove the most-colorable remaining node
    # and push it; nodes pushed last are colored first in the next phase
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes and assign each a register that doesn't
    # conflict with its already-colored neighbors
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # register was pre-assigned; just check it's still usable
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        # expand the merged set's register into per-SSAVal sub-registers
        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
462
463
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    """allocate a register for every `SSAVal` in `ops`, raising
    `AllocationFailedError` when that can't be done without spilling
    """
    result = try_allocate_registers_without_spilling(ops)
    if not isinstance(result, AllocationFailed):
        return result
    # TODO: implement spilling
    raise AllocationFailedError(
        "spilling required but not yet implemented", result)