75b342243fa95cd7223647b6204f0d33d7e36cf0
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
10
11 from nmutil.plain_data import plain_data
12
13 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
14 RegLoc, RegType, SSAVal)
15
if not TYPE_CHECKING:
    def final(v):
        """Runtime stand-in for ``typing_extensions.final``.

        Returns its argument unchanged, so decorating a class with it has
        no runtime effect; type checkers see the real decorator instead.
        """
        return v
else:
    from typing_extensions import Self, final
21
22
# Type variable for the register type of SSAVals and the containers built
# from them below; restricted to subclasses of RegType.
_RegType = TypeVar(
    "_RegType",
    bound=RegType,
)
24
25
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """Span of op indexes during which a value is live.

    ``first_write`` is the index of the first op writing the value and
    ``last_use`` is the index of the last op touching it (defaults to
    ``first_write``). Both must be nonnegative, and
    ``last_use >= first_write``.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        end = first_write if last_use is None else last_use
        # validation order matters: the ordering check comes first
        if end < first_write:
            raise ValueError("uses must be after first_write")
        if min(first_write, end) < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """Return whether ``self`` and ``other`` overlap.

        Intervals are compared as half-open ``[first_write, last_use)``
        ranges. Two intervals starting at the same index always overlap,
        even when one or both are empty.
        """
        if self.first_write == other.first_write:
            return True
        return (other.first_write < self.last_use
                and self.first_write < other.last_use)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """Return a new interval extended, if needed, to cover op ``use``."""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """Op indexes after which the value is still live.

        This is ``range(first_write, last_use)``, so it is empty when
        ``last_use == first_write``.
        """
        return range(self.first_write, self.last_use)
59
60
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """Immutable, hashable mapping from SSAVal to register offset.

    The SSAVals in one set must be allocated together, each one starting
    ``offset`` registers from the start of the set. If the first member has
    a ``GPRRangeType``, every member must too. ``ty`` is then a
    ``GPRRangeType`` covering ``start`` (smallest offset) to ``stop``
    (largest ``offset + length``). For any other type, every member must
    have offset 0 and an equal type, ``ty`` is that type, and ``stop`` is 1.
    """

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        # a bare SSAVal is shorthand for a single member at offset 0
        pairs = [(reg_set, 0)] if isinstance(reg_set, SSAVal) else reg_set
        items = {}  # type: dict[SSAVal[_RegType], int]
        for ssa_val, offset in pairs:
            existing = items.setdefault(ssa_val, offset)
            if offset != existing:
                raise ValueError(
                    f"can't merge register sets: conflicting offsets: "
                    f"for {ssa_val}: {offset} != {existing}")
        if not items:
            raise ValueError("can't have empty MergedRegs")
        self.__items = items
        first_ssa_val, start = next(iter(items.items()))
        ty = first_ssa_val.ty  # type: RegType
        if isinstance(ty, GPRRangeType):
            for ssa_val in items:
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
            start = min(items.values())
            stop = max(offset + ssa_val.ty.length
                       for ssa_val, offset in items.items())
            ty = GPRRangeType(stop - start)
        else:
            stop = 1
            for ssa_val, offset in items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        # the mapping never changes after construction, so hash it once
        self.__hash = hash(frozenset(items.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """Build a set with the given SSAVals laid out back-to-back.

        A one-element sequence may have any type. Longer sequences must be
        all ``GPRRangeType``, each placed right after the previous one,
        starting at offset 0.
        """
        if len(constraint_sequence) == 1:
            return MergedRegSet(constraint_sequence[0])
        pairs = []  # type: list[tuple[SSAVal[_RegType], int]]
        next_offset = 0
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            pairs.append((val, next_offset))
            next_offset += val.ty.length
        return MergedRegSet(pairs)

    @property
    def ty(self):
        """Combined register type of the whole set."""
        return self.__ty

    @property
    def stop(self):
        """One past the last register offset used by the set."""
        return self.__stop

    @property
    def start(self):
        """First register offset used by the set."""
        return self.__start

    @property
    def range(self):
        """``range(start, stop)`` of register offsets used by the set."""
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        """Return a copy with every offset shifted by ``amount``."""
        return MergedRegSet([(ssa_val, offset + amount)
                             for ssa_val, offset in self.__items.items()])

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """Return a copy shifted so that ``start`` is 0."""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """Return a copy shifted to agree with ``target``.

        The first SSAVal (in insertion order) that is also in ``target``
        ends up at the offset ``target`` gives it. Raises ValueError if
        the sets share no SSAVal.
        """
        anchor = next((ssa_val for ssa_val in self.__items
                       if ssa_val in target), None)
        if anchor is None:
            raise ValueError(
                "can't change offset to match unrelated MergedRegSet")
        return self.offset_by(target[anchor] - self.__items[anchor])

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
169 return f"MergedRegSet({list(self.__items.items())})"
170
171
172 @final
173 class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
174 def __init__(self, ops):
175 # type: (Iterable[Op]) -> None
176 merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegType]]
177 for op in ops:
178 for val in (*op.inputs().values(), *op.outputs().values()):
179 if val not in merged_sets:
180 merged_sets[val] = MergedRegSet(val)
181 for e in op.get_equality_constraints():
182 lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
183 rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
184 lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
185 rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
186 full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
187 for val in full_set.keys():
188 merged_sets[val] = full_set
189
190 self.__map = {k: v.normalized() for k, v in merged_sets.items()}
191
192 def __getitem__(self, key):
193 # type: (SSAVal) -> MergedRegSet
194 return self.__map[key]
195
196 def __iter__(self):
197 return iter(self.__map)
198
199 def __len__(self):
200 return len(self.__map)
201
202 def __repr__(self):
203 return f"MergedRegSets(data={self.__map})"
204
205
206 @final
207 class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
208 def __init__(self, ops):
209 # type: (list[Op]) -> None
210 self.__merged_reg_sets = MergedRegSets(ops)
211 live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]
212 for op_idx, op in enumerate(ops):
213 for val in op.inputs().values():
214 live_intervals[self.__merged_reg_sets[val]] += op_idx
215 for val in op.outputs().values():
216 reg_set = self.__merged_reg_sets[val]
217 if reg_set not in live_intervals:
218 live_intervals[reg_set] = LiveInterval(op_idx)
219 else:
220 live_intervals[reg_set] += op_idx
221 self.__live_intervals = live_intervals
222 live_after = [] # type: list[set[MergedRegSet[_RegType]]]
223 live_after += (set() for _ in ops)
224 for reg_set, live_interval in self.__live_intervals.items():
225 for i in live_interval.live_after_op_range:
226 live_after[i].add(reg_set)
227 self.__live_after = [frozenset(i) for i in live_after]
228
229 @property
230 def merged_reg_sets(self):
231 return self.__merged_reg_sets
232
233 def __getitem__(self, key):
234 # type: (MergedRegSet[_RegType]) -> LiveInterval
235 return self.__live_intervals[key]
236
237 def __iter__(self):
238 return iter(self.__live_intervals)
239
240 def reg_sets_live_after(self, op_index):
241 # type: (int) -> frozenset[MergedRegSet[_RegType]]
242 return self.__live_after[op_index]
243
244 def __repr__(self):
245 reg_sets_live_after = dict(enumerate(self.__live_after))
246 return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
247 f"merged_reg_sets={self.merged_reg_sets}, "
248 f"reg_sets_live_after={reg_sets_live_after})")
249
250
251 @final
252 class IGNode(Generic[_RegType]):
253 """ interference graph node """
254 __slots__ = "merged_reg_set", "edges", "reg"
255
256 def __init__(self, merged_reg_set, edges=(), reg=None):
257 # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
258 self.merged_reg_set = merged_reg_set
259 self.edges = set(edges)
260 self.reg = reg
261
262 def add_edge(self, other):
263 # type: (IGNode) -> None
264 self.edges.add(other)
265 other.edges.add(self)
266
267 def __eq__(self, other):
268 # type: (object) -> bool
269 if isinstance(other, IGNode):
270 return self.merged_reg_set == other.merged_reg_set
271 return NotImplemented
272
273 def __hash__(self):
274 return hash(self.merged_reg_set)
275
276 def __repr__(self, nodes=None):
277 # type: (None | dict[IGNode, int]) -> str
278 if nodes is None:
279 nodes = {}
280 if self in nodes:
281 return f"<IGNode #{nodes[self]}>"
282 nodes[self] = len(nodes)
283 edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
284 return (f"IGNode(#{nodes[self]}, "
285 f"merged_reg_set={self.merged_reg_set}, "
286 f"edges={edges}, "
287 f"reg={self.reg})")
288
289 @property
290 def reg_class(self):
291 # type: () -> RegClass
292 return self.merged_reg_set.ty.reg_class
293
294 def reg_conflicts_with_neighbors(self, reg):
295 # type: (RegLoc) -> bool
296 for neighbor in self.edges:
297 if neighbor.reg is not None and neighbor.reg.conflicts(reg):
298 return True
299 return False
300
301
302 @final
303 class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
304 def __init__(self, merged_reg_sets):
305 # type: (Iterable[MergedRegSet[_RegType]]) -> None
306 self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
307
308 def __getitem__(self, key):
309 # type: (MergedRegSet[_RegType]) -> IGNode
310 return self.__nodes[key]
311
312 def __iter__(self):
313 return iter(self.__nodes)
314
315 def __repr__(self):
316 nodes = {}
317 nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
318 nodes_text = ", ".join(nodes_text)
319 return f"InterferenceGraph(nodes={{{nodes_text}}})"
320
321
322 @plain_data()
323 class AllocationFailed:
324 __slots__ = "node", "live_intervals", "interference_graph"
325
326 def __init__(self, node, live_intervals, interference_graph):
327 # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
328 self.node = node
329 self.live_intervals = live_intervals
330 self.interference_graph = interference_graph
331
332
333 def try_allocate_registers_without_spilling(ops):
334 # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
335
336 live_intervals = LiveIntervals(ops)
337 merged_reg_sets = live_intervals.merged_reg_sets
338 interference_graph = InterferenceGraph(merged_reg_sets.values())
339 for op_idx, op in enumerate(ops):
340 reg_sets = live_intervals.reg_sets_live_after(op_idx)
341 for i, j in combinations(reg_sets, 2):
342 if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
343 interference_graph[i].add_edge(interference_graph[j])
344 for i, j in op.get_extra_interferences():
345 i = merged_reg_sets[i]
346 j = merged_reg_sets[j]
347 if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
348 interference_graph[i].add_edge(interference_graph[j])
349
350 nodes_remaining = set(interference_graph.values())
351
352 def local_colorability_score(node):
353 # type: (IGNode) -> int
354 """ returns a positive integer if node is locally colorable, returns
355 zero or a negative integer if node isn't known to be locally
356 colorable, the more negative the value, the less colorable
357 """
358 if node not in nodes_remaining:
359 raise ValueError()
360 retval = len(node.reg_class)
361 for neighbor in node.edges:
362 if neighbor in nodes_remaining:
363 retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
364 return retval
365
366 node_stack = [] # type: list[IGNode]
367 while True:
368 best_node = None # type: None | IGNode
369 best_score = 0
370 for node in nodes_remaining:
371 score = local_colorability_score(node)
372 if best_node is None or score > best_score:
373 best_node = node
374 best_score = score
375 if best_score > 0:
376 # it's locally colorable, no need to find a better one
377 break
378
379 if best_node is None:
380 break
381 node_stack.append(best_node)
382 nodes_remaining.remove(best_node)
383
384 retval = {} # type: dict[SSAVal, RegLoc]
385
386 while len(node_stack) > 0:
387 node = node_stack.pop()
388 if node.reg is not None:
389 if node.reg_conflicts_with_neighbors(node.reg):
390 return AllocationFailed(node=node,
391 live_intervals=live_intervals,
392 interference_graph=interference_graph)
393 else:
394 # pick the first non-conflicting register in node.reg_class, since
395 # register classes are ordered from most preferred to least
396 # preferred register.
397 for reg in node.reg_class:
398 if not node.reg_conflicts_with_neighbors(reg):
399 node.reg = reg
400 break
401 if node.reg is None:
402 return AllocationFailed(node=node,
403 live_intervals=live_intervals,
404 interference_graph=interference_graph)
405
406 for ssa_val, offset in node.merged_reg_set.items():
407 retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
408
409 return retval
410
411
412 def allocate_registers(ops):
413 # type: (list[Op]) -> None
414 raise NotImplementedError