"""
Register Allocator for Toom-Cook algorithm generator for SVP64

this uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""

from itertools import combinations
from typing import Generic, Iterable, Mapping, TypeVar

from nmutil.plain_data import plain_data

from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                  RegLoc, RegType, SSAVal)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import OFSet, OSet

_RegType = TypeVar("_RegType", bound=RegType)


@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        if last_use is None:
            last_use = first_write
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        if self.first_write == other.first_write:
            return True
        return self.last_use > other.first_write \
            and other.last_use > self.first_write

    def __add__(self, use):
        # type: (int) -> LiveInterval
        last_use = max(self.last_use, use)
        return LiveInterval(first_write=self.first_write, last_use=last_use)

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        return range(self.first_write, self.last_use)


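# A minimal, hypothetical sketch (not used by the allocator itself) of how
# LiveInterval overlap detection and extension behave:
def _live_interval_example():
    # type: () -> None
    a = LiveInterval(first_write=2, last_use=5)
    b = LiveInterval(first_write=4, last_use=9)
    assert a.overlaps(b)  # (2..5) and (4..9) intersect
    assert not a.overlaps(LiveInterval(first_write=5, last_use=6))
    assert (a + 7).last_use == 7  # recording a use at op 7 extends the interval
    assert a.live_after_op_range == range(2, 5)

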
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        self.__items = {}  # type: dict[SSAVal[_RegType], int]
        if isinstance(reg_set, SSAVal):
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            if ssa_val in self.__items:
                other = self.__items[ssa_val]
                if offset != other:
                    raise ValueError(
                        f"can't merge register sets: conflicting offsets: "
                        f"for {ssa_val}: {offset} != {other}")
            else:
                self.__items[ssa_val] = offset
        first_item = None
        for i in self.__items.items():
            first_item = i
            break
        if first_item is None:
            raise ValueError("can't have empty MergedRegSet")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty
        if isinstance(ty, GPRRangeType):
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            stop = 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        self.__hash = hash(OFSet(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        offset = 0
        retval = []
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            retval.append((val, offset))
            offset += val.ty.length
        return MergedRegSet(retval)
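
    # For example (hypothetical SSA values): calling from_equality_constraint
    # on [a, b, c], where a and c have type GPRRangeType(2) and b has type
    # GPRRangeType(1), yields a MergedRegSet mapping a -> 0, b -> 2 and
    # c -> 3, i.e. the three values are laid out back-to-back in a single
    # contiguous 5-register GPR range.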

    @property
    def ty(self):
        return self.__ty

    @property
    def stop(self):
        return self.__stop

    @property
    def start(self):
        return self.__start

    @property
    def range(self):
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        return MergedRegSet((k, v + amount) for k, v in self.items())

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        for ssa_val, offset in self.items():
            if ssa_val in target:
                return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"


@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
        for op in ops:
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                items = []  # type: list[tuple[SSAVal, int]]
                for i in e.lhs:
                    s = merged_sets[i].with_offset_to_match(lhs_set)
                    items.extend(s.items())
                for i in e.rhs:
                    s = merged_sets[i].with_offset_to_match(rhs_set)
                    items.extend(s.items())
                full_set = MergedRegSet(items)
                for val in full_set.keys():
                    merged_sets[val] = full_set

        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"

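# A hypothetical illustration of the constraint merging done by MergedRegSets
# above: if some op has an equality constraint with lhs [a, b] and rhs [c]
# (as returned by get_equality_constraints()), then a, b and c all end up
# mapped to one shared MergedRegSet with offsets {a: 0, b: a.ty.length, c: 0},
# so the allocator treats the concatenation of a and b and the value c as a
# single allocation unit.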

@final
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
        live_after += (OSet() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [OFSet(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")


@final
class IGNode(Generic[_RegType]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.merged_reg_set == other.merged_reg_set
        return NotImplemented

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        if nodes is None:
            nodes = {}
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, "
                f"reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        for neighbor in self.edges:
            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
                return True
        return False


@final
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        nodes = {}
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"


@plain_data()
class AllocationFailed:
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph


class AllocationFailedError(Exception):
    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        super().__init__(msg, allocation_failed)
        self.allocation_failed = allocation_failed


def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed

    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
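        # e.g. (hypothetical numbers) a node whose register class holds 32
        # registers and which has 3 still-uncolored neighbors, each able to
        # conflict with it in at most one register, scores 32 - 3 = 29 > 0,
        # so it is guaranteed to get a register no matter how those
        # neighbors are colored later.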
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

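    # Simplify phase: repeatedly pick the node that currently looks most
    # colorable, push it onto node_stack and remove it from the remaining
    # graph; the select phase below pops the stack and assigns registers in
    # reverse removal order.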
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

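    # Select phase: pop nodes in reverse removal order; keep a pre-assigned
    # register if it doesn't conflict with already-colored neighbors,
    # otherwise assign the first non-conflicting register from the node's
    # register class, and report failure if none fits (spilling is left to
    # the caller and is not yet implemented).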
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval


def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    retval = try_allocate_registers_without_spilling(ops)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    return retval
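

# Hypothetical usage sketch (`build_toom_cook_ops` is a stand-in, not defined
# anywhere in this repository; a real caller would supply its own list of
# compiler_ir.Op instances):
#
#     ops = build_toom_cook_ops(...)
#     ssa_val_to_reg = allocate_registers(ops)
#     for ssa_val, reg_loc in ssa_val_to_reg.items():
#         print(f"{ssa_val} -> {reg_loc}")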