b8269e4bc354f6c31b5433b10d19b71a39a38c1e
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import Generic, Iterable, Mapping, TypeVar
10
11 from nmutil.plain_data import plain_data
12
13 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
14 RegLoc, RegType, SSAVal)
15 from bigint_presentation_code.util import OFSet, OSet, final
16
# Type variable (bounded by RegType) used to parameterize the generic
# register-set / live-interval / interference-graph classes in this module.
_RegType = TypeVar("_RegType", bound=RegType)
18
19
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """ The span of op indexes a value occupies its register for.

    ``first_write`` is the op index where the interval starts and
    ``last_use`` is the op index where it ends (LiveIntervals extends it to
    every op that reads or writes a member of the register set).
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        # an interval with no later use ends at the op that wrote it
        end = first_write if last_use is None else last_use
        if end < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or end < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """ return True if the two intervals overlap.

        Intervals starting at the same op always overlap; otherwise each
        interval has to extend strictly past the other's first write.
        """
        starts_together = self.first_write == other.first_write
        return starts_together or (other.first_write < self.last_use
                                   and self.first_write < other.last_use)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """ return a copy of self extended (if needed) to cover op ``use`` """
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)
53
54
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """ A hashable mapping from each SSAVal in a group of SSAVals that are
    allocated together to that SSAVal's register offset within the group.

    ``start``/``stop`` bound the offsets covered by the group and ``ty`` is
    the register type of the group as a whole.
    """

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        """ Build from a single SSAVal (placed at offset 0) or from an
        iterable of ``(ssa_val, offset)`` pairs.

        Raises ValueError if:
        * the same SSAVal is given two different offsets;
        * no SSAVals are given;
        * the first SSAVal's type is a GPRRangeType but another one's isn't;
        * the first SSAVal's type isn't a GPRRangeType and some SSAVal has a
          non-zero offset or a type different from the first one's.
        """
        self.__items = {}  # type: dict[SSAVal[_RegType], int]
        if isinstance(reg_set, SSAVal):
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            if ssa_val in self.__items:
                # repeated SSAVals are only allowed if they agree on offset
                other = self.__items[ssa_val]
                if offset != other:
                    raise ValueError(
                        f"can't merge register sets: conflicting offsets: "
                        f"for {ssa_val}: {offset} != {other}")
            else:
                self.__items[ssa_val] = offset
        # grab the first (ssa_val, offset) pair, if there is one
        first_item = None
        for i in self.__items.items():
            first_item = i
            break
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty
        if isinstance(ty, GPRRangeType):
            # GPR ranges: the merged type is one GPRRangeType spanning from
            # the lowest offset to the end of the furthest-reaching member.
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            # any other type: all members must share the first member's type
            # and sit at offset 0, so the set covers just range(0, 1)
            stop = 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        # the hash is computed once here; equality is inherited from Mapping
        self.__hash = hash(OFSet(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """ Lay the SSAVals in ``constraint_sequence`` out back-to-back
        starting at offset 0, each one starting just past the end of the
        previous one.

        A single-element sequence may have any type; otherwise every element
        must be a GPRRangeType, else ValueError is raised.
        """
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        offset = 0
        retval = []
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            retval.append((val, offset))
            offset += val.ty.length
        return MergedRegSet(retval)

    @property
    def ty(self):
        """ register type of the whole set """
        return self.__ty

    @property
    def stop(self):
        """ one past the highest offset covered by the set """
        return self.__stop

    @property
    def start(self):
        """ lowest offset covered by the set """
        return self.__start

    @property
    def range(self):
        """ ``range(start, stop)`` """
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        """ return a copy with every offset shifted by ``amount`` """
        return MergedRegSet((k, v + amount) for k, v in self.items())

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """ return a copy shifted so that ``start`` is 0 """
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """ return a copy shifted so the first SSAVal (in iteration order)
        that is also in ``target`` has the same offset it has in ``target``.

        Raises ValueError if the two sets share no SSAVal.
        """
        for ssa_val, offset in self.items():
            if ssa_val in target:
                return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
164
165
@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
    """ Maps every SSAVal used (as input or output) by ``ops`` to the
    normalized MergedRegSet it belongs to, after merging sets as required by
    each op's equality constraints.
    """

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
        for op in ops:
            # each input/output starts out in its own single-element set
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                # the layout each side of the constraint requires, with both
                # sides starting at offset 0
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                items = []  # type: list[tuple[SSAVal, int]]
                # shift the set each val currently belongs to so that val
                # lands at the offset the constraint requires; the rest of
                # that set's members are carried along
                for i in e.lhs:
                    s = merged_sets[i].with_offset_to_match(lhs_set)
                    items.extend(s.items())
                for i in e.rhs:
                    s = merged_sets[i].with_offset_to_match(rhs_set)
                    items.extend(s.items())
                # MergedRegSet raises ValueError if the shifted sets disagree
                # (e.g. conflicting offsets for the same SSAVal)
                full_set = MergedRegSet(items)
                for val in full_set.keys():
                    merged_sets[val] = full_set

        # shift every set so its lowest offset is 0
        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"
203
204
@final
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """ Maps each MergedRegSet used by a list of ops to its LiveInterval, and
    records which sets are live immediately after each op.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for idx, op in enumerate(ops):
            # a read extends the interval of the set being read up to here
            for inp in op.inputs().values():
                key = self.__merged_reg_sets[inp]
                intervals[key] = intervals[key] + idx
            # a write starts the set's interval, or extends it when another
            # member of the same merged set was already written earlier
            for out in op.outputs().values():
                key = self.__merged_reg_sets[out]
                existing = intervals.get(key)
                if existing is None:
                    intervals[key] = LiveInterval(idx)
                else:
                    intervals[key] = existing + idx
        self.__live_intervals = intervals

        per_op = [OSet() for _ in ops]  # type: list[OSet[MergedRegSet[_RegType]]]
        for key, interval in intervals.items():
            for idx in interval.live_after_op_range:
                per_op[idx].add(key)
        self.__live_after = [OFSet(live) for live in per_op]

    @property
    def merged_reg_sets(self):
        """ the MergedRegSets computed from the ops """
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        """ the register sets live immediately after op ``op_index`` """
        return self.__live_after[op_index]

    def __repr__(self):
        by_op = {idx: live for idx, live in enumerate(self.__live_after)}
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={by_op})")
251
252
@final
class IGNode(Generic[_RegType]):
    """ interference graph node: one MergedRegSet, the nodes it interferes
    with, and the register assigned to it (None while unassigned)
    """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        """ connect self and other with an undirected edge """
        for node, neighbor in ((self, other), (other, self)):
            node.edges.add(neighbor)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_reg_set == other.merged_reg_set

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # ``nodes`` numbers already-printed nodes so cycles print as refs
        seen = {} if nodes is None else nodes
        if self in seen:
            return f"<IGNode #{seen[self]}>"
        index = len(seen)
        seen[self] = index
        edge_reprs = [neighbor.__repr__(seen) for neighbor in self.edges]
        edges = "{" + ", ".join(edge_reprs) + "}"
        return (f"IGNode(#{index}, merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        """ register class of this node's merged register set """
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """ True if any neighbor already has a register conflicting with reg """
        return any(neighbor.reg is not None and neighbor.reg.conflicts(reg)
                   for neighbor in self.edges)
302
303
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """ Maps each MergedRegSet to its IGNode; nodes start with no edges. """

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        nodes = {}  # type: dict[MergedRegSet[_RegType], IGNode[_RegType]]
        for reg_set in merged_reg_sets:
            nodes[reg_set] = IGNode(reg_set)
        self.__nodes = nodes

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        seen = {}  # type: dict[IGNode, int]
        parts = []
        for node in self.values():
            parts.append(f"...: {node.__repr__(seen)}")
        return "InterferenceGraph(nodes={" + ", ".join(parts) + "})"
325
326
@plain_data()
class AllocationFailed:
    """ Result returned when register allocation fails: the node whose
    register assignment failed, plus the live intervals and interference
    graph the allocator was working from.
    """
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        self.interference_graph = interference_graph
        self.live_intervals = live_intervals
        self.node = node
336
337
class AllocationFailedError(Exception):
    """ Raised when registers can't be allocated; the failure details are
    kept in ``allocation_failed``.
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        # both values go to Exception so they appear in ``args``
        super().__init__(msg, allocation_failed)
        self.allocation_failed = allocation_failed
343
344
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """ try to assign a register to every SSAVal used by ``ops`` without
    spilling.

    Builds an interference graph over the merged register sets. It then
    repeatedly moves the most locally-colorable remaining node onto a stack
    (if no node is known to be locally colorable, the best-scoring one is
    pushed anyway). Finally it pops the nodes, giving each the first
    register in its register class that doesn't conflict with an
    already-assigned neighbor.

    Returns a dict mapping each SSAVal to its RegLoc on success, or an
    AllocationFailed describing the node that couldn't be assigned.
    """

    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())

    def add_interference(i, j):
        # type: (MergedRegSet, MergedRegSet) -> None
        # sets whose register classes can never conflict need no edge
        if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
            interference_graph[i].add_edge(interference_graph[j])

    for op_idx, op in enumerate(ops):
        live_after = live_intervals.reg_sets_live_after(op_idx)
        # everything live after the same op must be in distinct registers
        for i, j in combinations(live_after, 2):
            add_interference(i, j)
        # An op clobbers every register it writes, even when the written
        # value is never read afterwards. Such a dead output isn't in
        # `live_after`, so without this it could be given the register of a
        # value that is still live across this op. All outputs of one op are
        # written at the same time, so they also interfere with each other
        # (consistent with LiveInterval.overlaps treating intervals with the
        # same first_write as overlapping).
        written = OSet([merged_reg_sets[val]
                        for val in op.outputs().values()])
        for i in written:
            for j in live_after:
                if i != j:
                    add_interference(i, j)
        for i, j in combinations(written, 2):
            add_interference(i, j)
        for i, j in op.get_extra_interferences():
            add_interference(merged_reg_sets[i], merged_reg_sets[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify: move nodes onto the stack, best-scoring first
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select: assign registers in reverse order of removal
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # node already has a register (InterferenceGraph creates nodes
            # with reg=None, so this only matters for pre-assigned nodes)
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        # each SSAVal gets the part of the set's register at its offset
        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
422
423
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    """ allocate registers for ``ops``, returning a map from each SSAVal to
    its RegLoc.

    Raises AllocationFailedError when allocation without spilling fails,
    since spilling isn't implemented yet.
    """
    result = try_allocate_registers_without_spilling(ops)
    if not isinstance(result, AllocationFailed):
        return result
    # TODO: implement spilling
    raise AllocationFailedError(
        "spilling required but not yet implemented", result)