try_allocate_registers_without_spilling works!
[bigint-presentation-code.git] / src / bigint_presentation_code / register_allocator.py
1 """
2 Register Allocator for Toom-Cook algorithm generator for SVP64
3
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from itertools import combinations
9 from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
10
11 from nmutil.plain_data import plain_data
12
13 from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
14 RegLoc, RegType, SSAVal)
15 from bigint_presentation_code.ordered_set import OFSet, OSet
16
if not TYPE_CHECKING:
    def final(v):
        """Runtime stand-in for the `final` decorator: returns `v` unchanged.

        Static type checkers take the other branch and see the real
        `typing_extensions.final` instead.
        """
        return v
else:
    from typing_extensions import final
22
23
# type variable used to parameterize SSAVal and the generic containers in
# this module; constrained to subclasses of RegType
_RegType = TypeVar("_RegType", bound=RegType)
25
26
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """Span of op indexes from a value's first write to its last use.

    Both `first_write` and `last_use` are nonnegative op indexes with
    `first_write <= last_use`.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        """Create an interval; `last_use` defaults to `first_write`.

        Raises ValueError if `last_use` is before `first_write` or if either
        index is negative.
        """
        end = first_write if last_use is None else last_use
        if end < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or end < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """Return True if `self` and `other` overlap.

        Two intervals written by the same op always overlap; otherwise each
        interval must end strictly after the other one starts.
        """
        if other.first_write == self.first_write:
            return True
        if self.last_use <= other.first_write:
            return False
        return self.first_write < other.last_use

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """Return a new interval extended (if needed) to cover a use at op
        index `use`. `first_write` is left unchanged.
        """
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each index
        """
        # half-open: the value is no longer live after the op at last_use
        return range(self.first_write, self.last_use)
60
61
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """Hashable mapping from each member `SSAVal` to its integer offset
    inside a set of values that are merged together.

    `start` and `stop` bound the offsets covered by the members, and `ty` is
    the type describing the whole set.
    """

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        """Build the set from `(ssa_val, offset)` pairs, or from a single
        `SSAVal` (which is placed at offset 0).

        Raises ValueError when the same value appears with two different
        offsets, when no items are given, when member types are
        incompatible, or when a non-`GPRRangeType` value has a non-zero
        offset.
        """
        if isinstance(reg_set, SSAVal):
            reg_set = [(reg_set, 0)]
        self.__items = {}  # type: dict[SSAVal[_RegType], int]
        for ssa_val, offset in reg_set:
            # setdefault returns the previously stored offset if there is one
            other = self.__items.setdefault(ssa_val, offset)
            if offset != other:
                raise ValueError(
                    f"can't merge register sets: conflicting offsets: "
                    f"for {ssa_val}: {offset} != {other}")

        first_item = next(iter(self.__items.items()), None)
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty

        if not isinstance(ty, GPRRangeType):
            # every member must have the exact same type and sit at offset 0
            stop = 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        else:
            # the merged range is the hull of all the members' ranges
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)

        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        # computed once up front and returned by __hash__
        self.__hash = hash(OFSet(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """Lay the values of `constraint_sequence` out back-to-back, each one
        starting at the offset where the previous one ended.

        A single-element sequence may have any type; longer sequences must
        consist only of `GPRRangeType` values, otherwise ValueError is
        raised.
        """
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        pairs = []
        next_offset = 0
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            pairs.append((val, next_offset))
            next_offset += val.ty.length
        return MergedRegSet(pairs)

    @property
    def ty(self):
        """type of the whole merged set"""
        return self.__ty

    @property
    def stop(self):
        """one past the highest offset covered by any member"""
        return self.__stop

    @property
    def start(self):
        """lowest offset covered by any member"""
        return self.__start

    @property
    def range(self):
        """`range(start, stop)`"""
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        """Return a copy with `amount` added to every member's offset."""
        shifted = ((k, v + amount) for k, v in self.items())
        return MergedRegSet(shifted)

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """Return a copy shifted so that `start` becomes 0."""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """Return a copy shifted so that the first member also found in
        `target` has the same offset as it has in `target`.

        Raises ValueError if no member of `self` is in `target`.
        """
        for ssa_val, offset in self.items():
            if ssa_val not in target:
                continue
            return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
171
172
@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
    """Mapping from every `SSAVal` used by a sequence of ops to the
    normalized `MergedRegSet` that contains it.
    """

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        """Start with one single-value set per input/output value, then merge
        sets as required by each op's equality constraints.
        """
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
        for op in ops:
            for val in [*op.inputs().values(), *op.outputs().values()]:
                if val in merged_sets:
                    continue
                merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                items = []  # type: list[tuple[SSAVal, int]]
                # re-offset the existing set of every value on either side so
                # it lines up with the layout of that side, then pool the
                # items of both sides together
                for side, side_set in ((e.lhs, lhs_set), (e.rhs, rhs_set)):
                    for member in side:
                        aligned = merged_sets[member].with_offset_to_match(
                            side_set)
                        items.extend(aligned.items())
                full_set = MergedRegSet(items)
                # every value in the merged result now shares the same set
                for val in full_set.keys():
                    merged_sets[val] = full_set

        self.__map = {}
        for key, reg_set in merged_sets.items():
            self.__map[key] = reg_set.normalized()

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"
210
211
@final
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """Mapping from each `MergedRegSet` to its `LiveInterval`, plus a
    per-op record of which sets are live immediately after that op.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        """Compute the live intervals for `ops`.

        An input whose merged set has no interval yet (i.e. it was never
        produced as an output by this or an earlier op) raises KeyError from
        the dict lookup.
        """
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                key = self.__merged_reg_sets[val]
                # extend the existing interval to cover this use
                live_intervals[key] = live_intervals[key] + op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                previous = live_intervals.get(reg_set)
                if previous is None:
                    # first write of this set
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] = previous + op_idx
        self.__live_intervals = live_intervals

        # one mutable set per op, filled in and then frozen below
        live_after = [OSet() for _ in ops]  # type: list[OSet[MergedRegSet[_RegType]]]
        for reg_set, live_interval in self.__live_intervals.items():
            for op_idx in live_interval.live_after_op_range:
                live_after[op_idx].add(reg_set)
        self.__live_after = [OFSet(entry) for entry in live_after]

    @property
    def merged_reg_sets(self):
        """the `MergedRegSets` computed from the ops passed to `__init__`"""
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        """Return the sets that are live immediately after the op at
        `op_index`.
        """
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
258
259
@final
class IGNode(Generic[_RegType]):
    """ interference graph node

    Holds one `MergedRegSet`, the set of neighboring nodes it interferes
    with, and the register assigned to it (`None` until one is assigned).
    Equality and hashing are based on `merged_reg_set` only.
    """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        """Record an interference between `self` and `other`, in both
        directions.
        """
        for src, dest in ((self, other), (other, self)):
            src.edges.add(dest)

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_reg_set == other.merged_reg_set

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        """Return a text form of this node and, recursively, its neighbors.

        `nodes` maps already-printed nodes to their index so that cycles in
        the graph are printed as short back-references instead of recursing
        forever.
        """
        nodes = {} if nodes is None else nodes
        if self not in nodes:
            nodes[self] = len(nodes)
            edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
            return (f"IGNode(#{nodes[self]}, "
                    f"merged_reg_set={self.merged_reg_set}, "
                    f"edges={edges}, "
                    f"reg={self.reg})")
        return f"<IGNode #{nodes[self]}>"

    @property
    def reg_class(self):
        # type: () -> RegClass
        """register class of this node's merged set"""
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """Return True if any neighbor already has a register assigned that
        conflicts with `reg`.
        """
        return any(neighbor.reg is not None and neighbor.reg.conflicts(reg)
                   for neighbor in self.edges)
309
310
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """Mapping from each `MergedRegSet` to the `IGNode` that represents it."""

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        """Create one edge-less node per merged register set."""
        self.__nodes = {}
        for reg_set in merged_reg_sets:
            self.__nodes[reg_set] = IGNode(reg_set)

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        # shared across all nodes so that each node is printed in full only
        # once and later occurrences become back-references
        nodes = {}
        nodes_text = ", ".join(
            [f"...: {node.__repr__(nodes)}" for node in self.values()])
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
332
333
@plain_data()
class AllocationFailed:
    """Result returned when register allocation could not find a register
    for `node`; carries the live intervals and interference graph that were
    in use at the time.
    """
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        self.node, self.live_intervals, self.interference_graph = (
            node, live_intervals, interference_graph)
343
344
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """Try to assign a register to every SSAVal used by `ops`.

    Builds the interference graph from the live intervals and the ops' extra
    interferences, orders the nodes by repeatedly removing a locally
    colorable node (or the best-scoring one if none is), then assigns
    registers in the reverse of the removal order.

    Returns a dict mapping each SSAVal to its register location, or an
    `AllocationFailed` describing the node that couldn't be given a
    register.
    """
    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())

    def add_edge_if_conflicting(lhs, rhs):
        # type: (MergedRegSet, MergedRegSet) -> None
        """ add an interference edge unless the two register classes can
        never conflict
        """
        if lhs.ty.reg_class.max_conflicts_with(rhs.ty.reg_class) == 0:
            return
        interference_graph[lhs].add_edge(interference_graph[rhs])

    for op_idx, op in enumerate(ops):
        # everything simultaneously live after this op interferes pairwise
        for lhs, rhs in combinations(
                live_intervals.reg_sets_live_after(op_idx), 2):
            add_edge_if_conflicting(lhs, rhs)
        for lhs_val, rhs_val in op.get_extra_interferences():
            add_edge_if_conflicting(merged_reg_sets[lhs_val],
                                    merged_reg_sets[rhs_val])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        reg_class = node.reg_class
        return len(reg_class) - sum(
            reg_class.max_conflicts_with(neighbor.reg_class)
            for neighbor in node.edges
            if neighbor in nodes_remaining)

    node_stack = []  # type: list[IGNode]
    while True:
        chosen = None  # type: None | IGNode
        chosen_score = 0
        for candidate in nodes_remaining:
            candidate_score = local_colorability_score(candidate)
            if chosen is None or candidate_score > chosen_score:
                chosen, chosen_score = candidate, candidate_score
                if chosen_score > 0:
                    # it's locally colorable, no need to find a better one
                    break

        if chosen is None:
            # nodes_remaining was empty: every node has been ordered
            break
        node_stack.append(chosen)
        nodes_remaining.remove(chosen)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # assign registers in the reverse of the order nodes were removed
    for node in reversed(node_stack):
        if node.reg is None:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            else:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        elif node.reg_conflicts_with_neighbors(node.reg):
            # the register this node already had conflicts with a neighbor
            return AllocationFailed(node=node,
                                    live_intervals=live_intervals,
                                    interference_graph=interference_graph)

        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
422
423
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """Placeholder: not implemented yet, always raises NotImplementedError."""
    raise NotImplementedError()