75b342243fa95cd7223647b6204f0d33d7e36cf0
2 Register Allocator for Toom-Cook algorithm generator for SVP64
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from itertools
import combinations
9 from typing
import TYPE_CHECKING
, Generic
, Iterable
, Mapping
, TypeVar
11 from nmutil
.plain_data
import plain_data
13 from bigint_presentation_code
.compiler_ir
import (GPRRangeType
, Op
, RegClass
,
14 RegLoc
, RegType
, SSAVal
)
17 from typing_extensions
import Self
, final
23 _RegType
= TypeVar("_RegType", bound
=RegType
)
26 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
28 __slots__
= "first_write", "last_use"
def __init__(self, first_write, last_use=None):
    # type: (int, int | None) -> None
    """Create a live interval from `first_write` through `last_use`.

    Both arguments are op indexes. When `last_use` is omitted (`None`) it
    defaults to `first_write`.

    Raises `ValueError` when `last_use` comes before `first_write`, or when
    either index is negative.
    """
    if last_use is None:
        # only fall back to `first_write` when no last use was supplied,
        # otherwise an explicitly passed `last_use` would be thrown away
        last_use = first_write
    if last_use < first_write:
        raise ValueError("uses must be after first_write")
    if first_write < 0 or last_use < 0:
        raise ValueError("indexes must be nonnegative")
    self.first_write = first_write
    self.last_use = last_use
def overlaps(self, other):
    # type: (LiveInterval) -> bool
    """Return whether this interval and `other` overlap.

    Two intervals sharing the same `first_write` always overlap. Otherwise
    they overlap only when each interval's `last_use` is strictly after the
    other interval's `first_write`.
    """
    if self.first_write == other.first_write:
        return True
    return (self.last_use > other.first_write
            and other.last_use > self.first_write)
def __add__(self, use):
    # type: (int) -> LiveInterval
    """Return a new interval whose `last_use` also covers op index `use`.

    `first_write` is carried over unchanged; `last_use` becomes the larger
    of the current `last_use` and `use`.
    """
    extended_last_use = use if use > self.last_use else self.last_use
    return LiveInterval(self.first_write, extended_last_use)
def live_after_op_range(self):
    # type: () -> range
    """the range of op indexes where self is live immediately after the
    Op at each index.

    Starts at `first_write` and stops just before `last_use`.
    """
    return range(self.first_write, self.last_use)
62 class MergedRegSet(Mapping
[SSAVal
[_RegType
], int]):
63 def __init__(self
, reg_set
):
64 # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
65 self
.__items
= {} # type: dict[SSAVal[_RegType], int]
66 if isinstance(reg_set
, SSAVal
):
67 reg_set
= [(reg_set
, 0)]
68 for ssa_val
, offset
in reg_set
:
69 if ssa_val
in self
.__items
:
70 other
= self
.__items
[ssa_val
]
73 f
"can't merge register sets: conflicting offsets: "
74 f
"for {ssa_val}: {offset} != {other}")
76 self
.__items
[ssa_val
] = offset
78 for i
in self
.__items
.items():
81 if first_item
is None:
82 raise ValueError("can't have empty MergedRegs")
83 first_ssa_val
, start
= first_item
85 if isinstance(ty
, GPRRangeType
):
86 stop
= start
+ ty
.length
87 for ssa_val
, offset
in self
.__items
.items():
88 if not isinstance(ssa_val
.ty
, GPRRangeType
):
89 raise ValueError(f
"can't merge incompatible types: "
90 f
"{ssa_val.ty} and {ty}")
91 stop
= max(stop
, offset
+ ssa_val
.ty
.length
)
92 start
= min(start
, offset
)
93 ty
= GPRRangeType(stop
- start
)
96 for ssa_val
, offset
in self
.__items
.items():
98 raise ValueError(f
"can't have non-zero offset "
101 raise ValueError(f
"can't merge incompatible types: "
102 f
"{ssa_val.ty} and {ty}")
103 self
.__start
= start
# type: int
104 self
.__stop
= stop
# type: int
105 self
.__ty
= ty
# type: RegType
106 self
.__hash
= hash(frozenset(self
.items()))
def from_equality_constraint(constraint_sequence):
    # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
    """Build the `MergedRegSet` for one side of an equality constraint.

    A single-element sequence may have any type and is placed at offset 0.
    Longer sequences must consist solely of `GPRRangeType` values; they are
    laid out back-to-back, each value starting where the previous one ended.

    Raises `ValueError` when a multi-element sequence contains a value whose
    type isn't `GPRRangeType`.
    """
    if len(constraint_sequence) == 1:
        # any type allowed with len = 1
        return MergedRegSet(constraint_sequence[0])
    # running start offset of the next value, and the (value, offset) pairs
    # collected so far
    offset = 0
    retval = []
    for val in constraint_sequence:
        if not isinstance(val.ty, GPRRangeType):
            raise ValueError("equality constraint sequences must only "
                             "have SSAVal type GPRRangeType")
        retval.append((val, offset))
        offset += val.ty.length
    return MergedRegSet(retval)
138 return range(self
.__start
, self
.__stop
)
def offset_by(self, amount):
    # type: (int) -> MergedRegSet[_RegType]
    """Return a new `MergedRegSet` with `amount` added to every offset."""
    shifted_items = ((ssa_val, offset + amount)
                     for ssa_val, offset in self.items())
    return MergedRegSet(shifted_items)
def normalized(self):
    # type: () -> MergedRegSet[_RegType]
    """Return this set shifted by the negation of its `start`."""
    shift = -self.start
    return self.offset_by(shift)
def with_offset_to_match(self, target):
    # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
    """Return this set shifted so a shared value lines up with `target`.

    The first value of `self` that is also present in `target` decides the
    shift: afterwards that value has the same offset as it has in `target`.

    Raises `ValueError` when `self` and `target` have no value in common.
    """
    for ssa_val, offset in self.items():
        if ssa_val not in target:
            continue
        delta = target[ssa_val] - offset
        return self.offset_by(delta)
    raise ValueError("can't change offset to match unrelated MergedRegSet")
def __getitem__(self, item):
    # type: (SSAVal[_RegType]) -> int
    """Return the offset stored for `item` (a dict lookup on the items)."""
    offset = self.__items[item]
    return offset
160 return iter(self
.__items
)
163 return len(self
.__items
)
169 return f
"MergedRegSet({list(self.__items.items())})"
173 class MergedRegSets(Mapping
[SSAVal
, MergedRegSet
[_RegType
]], Generic
[_RegType
]):
def __init__(self, ops):
    # type: (Iterable[Op]) -> None
    """Group every SSA value used by `ops` into its merged register set.

    Each input/output value starts out in a set of its own. Every equality
    constraint of every op then joins the sets of its `lhs` and `rhs` into
    one set, which is recorded for all of that set's members. The sets are
    stored in normalized form.
    """
    merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
    for op in ops:
        for val in (*op.inputs().values(), *op.outputs().values()):
            if val not in merged_sets:
                merged_sets[val] = MergedRegSet(val)
        for e in op.get_equality_constraints():
            lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
            rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
            # shift the set each side already belongs to so that its
            # offsets agree with the layout the constraint asks for
            lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
            rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
            full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
            for val in full_set.keys():
                merged_sets[val] = full_set

    self.__map = {k: v.normalized() for k, v in merged_sets.items()}
def __getitem__(self, key):
    # type: (SSAVal) -> MergedRegSet
    """Return the merged register set that `key` belongs to."""
    merged_set = self.__map[key]
    return merged_set
197 return iter(self
.__map
)
200 return len(self
.__map
)
203 return f
"MergedRegSets(data={self.__map})"
207 class LiveIntervals(Mapping
[MergedRegSet
[_RegType
], LiveInterval
]):
def __init__(self, ops):
    # type: (list[Op]) -> None
    """Compute the live interval of every merged register set in `ops`.

    Ops are visited in order. An output starts the interval of its merged
    register set (or extends it when the set was already written), and an
    input extends the interval of its set up to the op's index. Afterwards
    the sets live after each op are precomputed as frozensets.
    """
    self.__merged_reg_sets = MergedRegSets(ops)
    intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
    for op_idx, op in enumerate(ops):
        # an input's set must already have an interval at this point
        for used in op.inputs().values():
            intervals[self.__merged_reg_sets[used]] += op_idx
        for written in op.outputs().values():
            key = self.__merged_reg_sets[written]
            if key in intervals:
                intervals[key] += op_idx
            else:
                intervals[key] = LiveInterval(op_idx)
    self.__live_intervals = intervals

    # one (initially empty) bucket per op
    buckets = [set() for _ in ops]
    # type: list[set[MergedRegSet[_RegType]]]
    for reg_set, interval in intervals.items():
        for op_idx in interval.live_after_op_range:
            buckets[op_idx].add(reg_set)
    self.__live_after = [frozenset(bucket) for bucket in buckets]
def merged_reg_sets(self):
    # type: () -> MergedRegSets[_RegType]
    """Return the `MergedRegSets` stored on this object."""
    merged = self.__merged_reg_sets
    return merged
def __getitem__(self, key):
    # type: (MergedRegSet[_RegType]) -> LiveInterval
    """Return the live interval recorded for the merged register set `key`."""
    interval = self.__live_intervals[key]
    return interval
238 return iter(self
.__live
_intervals
)
def reg_sets_live_after(self, op_index):
    # type: (int) -> frozenset[MergedRegSet[_RegType]]
    """Return the merged register sets live right after op `op_index`."""
    live_sets = self.__live_after[op_index]
    return live_sets
245 reg_sets_live_after
= dict(enumerate(self
.__live
_after
))
246 return (f
"LiveIntervals(live_intervals={self.__live_intervals}, "
247 f
"merged_reg_sets={self.merged_reg_sets}, "
248 f
"reg_sets_live_after={reg_sets_live_after})")
252 class IGNode(Generic
[_RegType
]):
253 """ interference graph node """
254 __slots__
= "merged_reg_set", "edges", "reg"
def __init__(self, merged_reg_set, edges=(), reg=None):
    # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
    """Create an interference graph node for `merged_reg_set`.

    `edges` is copied into a fresh set of neighboring nodes. `reg` is the
    register location assigned to this node, or `None` when it has none.
    """
    self.merged_reg_set = merged_reg_set
    self.edges = set(edges)
    # `reg` is a declared slot that other code reads, so it must always be
    # set -- reading an unset slot raises AttributeError
    self.reg = reg
def add_edge(self, other):
    # type: (IGNode) -> None
    """Record an undirected edge: each node is added to the other's edges."""
    for node, peer in ((self, other), (other, self)):
        node.edges.add(peer)
def __eq__(self, other):
    # type: (object) -> bool
    """Nodes are equal when their merged register sets are equal.

    Comparing against anything that isn't an `IGNode` returns
    `NotImplemented`.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_reg_set == other.merged_reg_set
274 return hash(self
.merged_reg_set
)
276 def __repr__(self
, nodes
=None):
277 # type: (None | dict[IGNode, int]) -> str
281 return f
"<IGNode #{nodes[self]}>"
282 nodes
[self
] = len(nodes
)
283 edges
= "{" + ", ".join(i
.__repr
__(nodes
) for i
in self
.edges
) + "}"
284 return (f
"IGNode(#{nodes[self]}, "
285 f
"merged_reg_set={self.merged_reg_set}, "
291 # type: () -> RegClass
292 return self
.merged_reg_set
.ty
.reg_class
def reg_conflicts_with_neighbors(self, reg):
    # type: (RegLoc) -> bool
    """Return whether `reg` conflicts with a register held by a neighbor.

    Neighbors that have no register assigned yet (`reg is None`) are
    ignored. Returns `False` when no neighbor conflicts, including when
    there are no neighbors at all.
    """
    return any(
        neighbor.reg is not None and neighbor.reg.conflicts(reg)
        for neighbor in self.edges
    )
303 class InterferenceGraph(Mapping
[MergedRegSet
[_RegType
], IGNode
[_RegType
]]):
def __init__(self, merged_reg_sets):
    # type: (Iterable[MergedRegSet[_RegType]]) -> None
    """Create one `IGNode` per merged register set, keyed by that set."""
    nodes = {}  # type: dict[MergedRegSet[_RegType], IGNode[_RegType]]
    for reg_set in merged_reg_sets:
        nodes[reg_set] = IGNode(reg_set)
    self.__nodes = nodes
def __getitem__(self, key):
    # type: (MergedRegSet[_RegType]) -> IGNode
    """Return the graph node created for the merged register set `key`."""
    node = self.__nodes[key]
    return node
313 return iter(self
.__nodes
)
317 nodes_text
= [f
"...: {node.__repr__(nodes)}" for node
in self
.values()]
318 nodes_text
= ", ".join(nodes_text
)
319 return f
"InterferenceGraph(nodes={{{nodes_text}}})"
class AllocationFailed:
    """Describes a register allocation attempt that could not be completed.

    Carries the interference graph node passed by the caller, together with
    the live intervals and the interference graph in use at that point.
    """
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        # every declared slot is filled so none raises AttributeError on read
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
333 def try_allocate_registers_without_spilling(ops
):
334 # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
336 live_intervals
= LiveIntervals(ops
)
337 merged_reg_sets
= live_intervals
.merged_reg_sets
338 interference_graph
= InterferenceGraph(merged_reg_sets
.values())
339 for op_idx
, op
in enumerate(ops
):
340 reg_sets
= live_intervals
.reg_sets_live_after(op_idx
)
341 for i
, j
in combinations(reg_sets
, 2):
342 if i
.ty
.reg_class
.max_conflicts_with(j
.ty
.reg_class
) != 0:
343 interference_graph
[i
].add_edge(interference_graph
[j
])
344 for i
, j
in op
.get_extra_interferences():
345 i
= merged_reg_sets
[i
]
346 j
= merged_reg_sets
[j
]
347 if i
.ty
.reg_class
.max_conflicts_with(j
.ty
.reg_class
) != 0:
348 interference_graph
[i
].add_edge(interference_graph
[j
])
350 nodes_remaining
= set(interference_graph
.values())
352 def local_colorability_score(node
):
353 # type: (IGNode) -> int
354 """ returns a positive integer if node is locally colorable, returns
355 zero or a negative integer if node isn't known to be locally
356 colorable, the more negative the value, the less colorable
358 if node
not in nodes_remaining
:
360 retval
= len(node
.reg_class
)
361 for neighbor
in node
.edges
:
362 if neighbor
in nodes_remaining
:
363 retval
-= node
.reg_class
.max_conflicts_with(neighbor
.reg_class
)
366 node_stack
= [] # type: list[IGNode]
368 best_node
= None # type: None | IGNode
370 for node
in nodes_remaining
:
371 score
= local_colorability_score(node
)
372 if best_node
is None or score
> best_score
:
376 # it's locally colorable, no need to find a better one
379 if best_node
is None:
381 node_stack
.append(best_node
)
382 nodes_remaining
.remove(best_node
)
384 retval
= {} # type: dict[SSAVal, RegLoc]
386 while len(node_stack
) > 0:
387 node
= node_stack
.pop()
388 if node
.reg
is not None:
389 if node
.reg_conflicts_with_neighbors(node
.reg
):
390 return AllocationFailed(node
=node
,
391 live_intervals
=live_intervals
,
392 interference_graph
=interference_graph
)
394 # pick the first non-conflicting register in node.reg_class, since
395 # register classes are ordered from most preferred to least
396 # preferred register.
397 for reg
in node
.reg_class
:
398 if not node
.reg_conflicts_with_neighbors(reg
):
402 return AllocationFailed(node
=node
,
403 live_intervals
=live_intervals
,
404 interference_graph
=interference_graph
)
406 for ssa_val
, offset
in node
.merged_reg_set
.items():
407 retval
[ssa_val
] = node
.reg
.get_subreg_at_offset(ssa_val
.ty
, offset
)
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """Not implemented yet: always raises `NotImplementedError`."""
    raise NotImplementedError