"""
Register Allocator for Toom-Cook algorithm generator for SVP64

this uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
8 from itertools
import combinations
9 from functools
import reduce
10 from typing
import Generic
, Iterable
, Mapping
11 from cached_property
import cached_property
14 from nmutil
.plain_data
import plain_data
16 from bigint_presentation_code
.compiler_ir2
import (
17 Op
, LocSet
, Ty
, SSAVal
, BaseTy
, Loc
, FnWithUses
)
18 from bigint_presentation_code
.type_util
import final
, Self
19 from bigint_presentation_code
.util
import OFSet
, OSet
, FMap
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """half-open range ``[first_write, last_use)`` of op indexes during which
    a value is live.

    NOTE(review): reconstructed from a mangled source; the class-header line
    was not visible -- confirm the exact decorator stack against upstream.
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        if last_use is None:
            # a value that is written but never used is live only at its write
            last_use = first_write
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """return True if `self` and `other` are ever live at the same time"""
        if self.first_write == other.first_write:
            return True
        return self.last_use > other.first_write \
            and other.last_use > self.first_write

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a copy of `self` extended to also cover a use at op index
        `use` (self is frozen, so a new instance is returned)"""
        last_use = max(self.last_use, use)
        return LiveInterval(first_write=self.first_write, last_use=last_use)

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        Op at each op index
        """
        return range(self.first_write, self.last_use)
class BadMergedSSAVal(ValueError):
    """raised when a `MergedSSAVal` can't be constructed -- e.g. it is empty,
    its members' `BaseTy`s mismatch, or no valid `Loc`s remain."""
@plain_data(frozen=True, unsafe_hash=True)
@final
class MergedSSAVal:
    """a set of `SSAVal`s along with their offsets, all register allocated as
    a single unit.

    Definition of the term `offset` for this class:

    Let `locs[x]` be the `Loc` that `x` is assigned to after register
    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
    for each `SSAVal` `ssa_val` in `msv` is defined as:

    ```
    msv.ssa_val_offsets[ssa_val] = (msv.offset
                                    + locs[ssa_val].start - locs[msv].start)
    ```

    Example:
    ```
    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
    ```
    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
    """
    __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"

    def __init__(self, fn_with_uses, ssa_val_offsets):
        # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
        self.fn_with_uses = fn_with_uses
        if isinstance(ssa_val_offsets, SSAVal):
            # accept a lone SSAVal as shorthand for {ssa_val: 0}
            ssa_val_offsets = {ssa_val_offsets: 0}
        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
        # grab the base_ty of any member; `self.ty` later checks that all
        # members agree on it
        for ssa_val in self.ssa_val_offsets.keys():
            base_ty = ssa_val.base_ty
            break
        else:
            raise BadMergedSSAVal("MergedSSAVal can't be empty")
        self.base_ty = base_ty  # type: BaseTy
        # self.ty checks for mismatched base_ty
        reg_len = self.ty.reg_len
        # intersect the candidate Locs contributed by every member
        loc_set = None  # type: None | LocSet
        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0

            def locs():
                # type: () -> Iterable[Loc]
                for loc in ssa_val.def_loc_set_before_spread:
                    disallowed_by_use = False
                    for use in fn_with_uses.uses[ssa_val]:
                        use_spread_idx = \
                            use.defining_descriptor.spread_index or 0
                        # calculate the start for the use's Loc before spread
                        # e.g. if the def's Loc before spread starts at r6
                        # and the def's spread_index is 5
                        # and the use's spread_index is 3
                        # then the use's Loc before spread starts at r8
                        # because 8 == 6 + 5 - 3
                        start = loc.start + def_spread_idx - use_spread_idx
                        use_loc = Loc.try_make(
                            loc.kind, start=start,
                            reg_len=use.ty_before_spread.reg_len)
                        if (use_loc is None or
                                use_loc not in use.use_loc_set_before_spread):
                            disallowed_by_use = True
                            break
                    if disallowed_by_use:
                        continue
                    # FIXME: add spread consistency check
                    # translate loc from this member's frame to the merged
                    # unit's frame
                    start = loc.start - cur_offset + self.offset
                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
                    if loc is not None and (loc_set is None or loc in loc_set):
                        yield loc
            loc_set = LocSet(locs())
        assert loc_set is not None, "already checked that self isn't empty"
        if loc_set.ty is None:
            raise BadMergedSSAVal("there are no valid Locs left")
        assert loc_set.ty == self.ty, "logic error somewhere"
        self.loc_set = loc_set  # type: LocSet

    @property
    def offset(self):
        # type: () -> int
        """the minimum pre-spread offset of any member"""
        return min(self.ssa_val_offsets_before_spread.values())

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the merged unit's type; raises BadMergedSSAVal on base_ty
        mismatch between members"""
        reg_len = 0
        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
            cur_ty = ssa_val.ty_before_spread
            if self.base_ty != cur_ty.base_ty:
                raise BadMergedSSAVal(
                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
        return Ty(base_ty=self.base_ty, reg_len=reg_len)

    @cached_property
    def ssa_val_offsets_before_spread(self):
        # type: () -> FMap[SSAVal, int]
        """member offsets translated back to before the spread operation"""
        retval = {}  # type: dict[SSAVal, int]
        for ssa_val, offset in self.ssa_val_offsets.items():
            offset_before_spread = offset
            spread_index = ssa_val.defining_descriptor.spread_index
            if spread_index is not None:
                assert ssa_val.ty.reg_len == 1, (
                    "this function assumes spreading always converts a vector "
                    "to a contiguous sequence of scalars, if that's changed "
                    "in the future, then this function needs to be adjusted")
                offset_before_spread -= spread_index
            retval[ssa_val] = offset_before_spread
        return FMap(retval)

    def offset_by(self, amount):
        # type: (int) -> MergedSSAVal
        """return a copy of self with every member's offset shifted by
        `amount`"""
        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
        return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)

    def normalized(self):
        # type: () -> MergedSSAVal
        """return a copy of self shifted so its minimum offset is zero"""
        return self.offset_by(-self.offset)

    def with_offset_to_match(self, target):
        # type: (MergedSSAVal) -> MergedSSAVal
        """return a copy of self shifted so any shared member has the same
        offset as in `target`"""
        for ssa_val, offset in self.ssa_val_offsets.items():
            if ssa_val in target.ssa_val_offsets:
                return self.offset_by(
                    target.ssa_val_offsets[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedSSAVal")
@final
class MergedSSAVals(OFSet[MergedSSAVal]):
    """frozen set of `MergedSSAVal`s in which no `SSAVal` appears in more
    than one member"""

    def __init__(self, merged_ssa_vals=()):
        # type: (Iterable[MergedSSAVal]) -> None
        super().__init__(merged_ssa_vals)
        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
        for merged_ssa_val in self:
            for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
                if ssa_val in merge_map:
                    raise ValueError(
                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
                        f"{merged_ssa_val} and {merge_map[ssa_val]}")
                merge_map[ssa_val] = merged_ssa_val
        self.__merge_map = FMap(merge_map)

    @property
    def merge_map(self):
        # type: () -> FMap[SSAVal, MergedSSAVal]
        """map from each `SSAVal` to the `MergedSSAVal` that contains it"""
        return self.__merge_map
217 # FIXME: work on code from here
220 def minimally_merged(fn_with_uses
):
221 # type: (FnWithUses) -> MergedSSAVals
222 merge_map
= {} # type: dict[SSAVal, MergedSSAVal]
223 for op
in fn_with_uses
.fn
.ops
:
225 for val
in (*op
.inputs().values(), *op
.outputs().values()):
226 if val
not in merged_sets
:
227 merged_sets
[val
] = MergedRegSet(val
)
228 for e
in op
.get_equality_constraints():
229 lhs_set
= MergedRegSet
.from_equality_constraint(e
.lhs
)
230 rhs_set
= MergedRegSet
.from_equality_constraint(e
.rhs
)
231 items
= [] # type: list[tuple[SSAVal, int]]
233 s
= merged_sets
[i
].with_offset_to_match(lhs_set
)
234 items
.extend(s
.items())
236 s
= merged_sets
[i
].with_offset_to_match(rhs_set
)
237 items
.extend(s
.items())
238 full_set
= MergedRegSet(items
)
239 for val
in full_set
.keys():
240 merged_sets
[val
] = full_set
242 self
.__map
= {k
: v
.normalized() for k
, v
in merged_sets
.items()}
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """computes the `LiveInterval` of every `MergedRegSet` in `ops`, plus
    which reg-sets are live immediately after each op.

    NOTE(review): legacy pre-FIXME code -- `MergedRegSet`, `MergedRegSets`
    and `_RegType` are not defined/imported in the visible file, so this
    class can't currently be constructed; kept as-is pending the rework
    flagged by the FIXME marker.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # an input must already have an interval -- extend it to
                # cover this use (LiveInterval.__add__ returns a new value)
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # live_after[i] is the set of reg-sets live immediately after ops[i]
        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
        live_after += (OSet() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [OFSet(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
@final
class IGNode(Generic[_RegType]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        """add an interference edge between self and other (symmetric)"""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.merged_reg_set == other.merged_reg_set
        return NotImplemented

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` assigns stable per-print ids so cyclic edge references
        # don't recurse forever
        if nodes is None:
            nodes = {}
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
        # NOTE(review): the trailing repr fields were dropped from the
        # mangled source and are reconstructed -- confirm upstream
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, "
                f"reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """return True if any already-colored neighbor holds a register
        conflicting with `reg`"""
        for neighbor in self.edges:
            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
                return True
        return False
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """mapping from each `MergedRegSet` to its interference-graph node.

    NOTE(review): legacy pre-FIXME code -- `MergedRegSet`/`_RegType` are not
    defined in the visible file anymore.
    """

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        # one shared `nodes` id-dict so nested IGNode reprs stay consistent
        nodes = {}  # type: dict[IGNode, int]
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
class AllocationFailed:
    """result returned when register allocation can't proceed without
    spilling; carries the node that couldn't be colored plus the analysis
    state for diagnostics."""
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        # the `self.node` assignment was dropped from the mangled source;
        # restored here since `node` is declared in __slots__ and callers
        # pass it as a keyword argument
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
class AllocationFailedError(Exception):
    """raised when allocation fails and spilling (not yet implemented)
    would be required; wraps the `AllocationFailed` diagnostic object"""

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        # pass both to Exception so they show up in args/traceback text
        super().__init__(msg, allocation_failed)
        self.allocation_failed = allocation_failed
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """graph-coloring register allocation (simplify/select, per the paper
    linked in the module docstring): returns a full SSAVal -> RegLoc
    assignment, or an `AllocationFailed` describing the node that couldn't
    be colored.

    NOTE(review): legacy pre-FIXME code -- depends on the old
    `LiveIntervals`/`MergedRegSet` API above.
    """
    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    # build interference edges: two reg-sets interfere when they're live at
    # the same time and their register classes can conflict
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly push the most-colorable remaining node
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break
        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes, assigning each the first register in its
    # class that doesn't conflict with already-colored neighbors
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # pre-assigned register: just verify it's still legal
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            else:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    """allocate a register location for every SSAVal in `ops`, raising
    `AllocationFailedError` if allocation would require spilling (which is
    not implemented yet)."""
    retval = try_allocate_registers_without_spilling(ops)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    # success-path return was dropped in the mangled source; without it the
    # function would implicitly return None on success
    return retval