"""
Register Allocator for Toom-Cook algorithm generator for SVP64

this uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
8 from itertools
import combinations
9 from typing
import Generic
, Iterable
, Mapping
, TypeVar
11 from nmutil
.plain_data
import plain_data
13 from bigint_presentation_code
.compiler_ir
import (GPRRangeType
, Op
, RegClass
,
14 RegLoc
, RegType
, SSAVal
)
15 from bigint_presentation_code
.type_util
import final
16 from bigint_presentation_code
.util
import OFSet
, OSet
# type variable for the register type carried by SSA values; bounded so the
# generic containers below only accept RegType subclasses
_RegType = TypeVar("_RegType", bound=RegType)
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """live range of a value: the op index that first writes it through the
    op index that last uses it (both inclusive).
    """
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        """create an interval; ``last_use`` defaults to ``first_write`` for a
        value that is written but never read.

        Raises ValueError if ``last_use`` precedes ``first_write`` or either
        index is negative.
        """
        if last_use is None:
            last_use = first_write
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """return True if the two live ranges overlap, i.e. the two values
        can't share a register.
        """
        if self.first_write == other.first_write:
            # both written at the same op -- always an overlap
            return True
        return self.last_use > other.first_write \
            and other.last_use > self.first_write

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a new interval extended to also cover the use at op index
        ``use`` (self is frozen, so extension returns a new instance).
        """
        last_use = max(self.last_use, use)
        return LiveInterval(first_write=self.first_write, last_use=last_use)

    @property
    def live_after_op_range(self):
        # type: () -> range
        """the range of op indexes where self is live immediately after the
        Op at that index
        """
        return range(self.first_write, self.last_use)
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """an immutable set of `SSAVal`s that must be allocated to overlapping
    registers, expressed as a mapping from each `SSAVal` to its offset (in
    registers) from the start of the merged set.

    `GPRRangeType` members may sit at arbitrary offsets; members of any other
    register type must all be at offset 0 with identical types.
    """

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        """build from a single `SSAVal` (placed at offset 0) or an iterable of
        ``(ssa_val, offset)`` pairs.

        Raises ValueError on conflicting offsets for the same `SSAVal`, on an
        empty set, or on incompatible member types.
        """
        self.__items = {}  # type: dict[SSAVal[_RegType], int]
        if isinstance(reg_set, SSAVal):
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            if ssa_val in self.__items:
                other = self.__items[ssa_val]
                if offset != other:
                    raise ValueError(
                        f"can't merge register sets: conflicting offsets: "
                        f"for {ssa_val}: {offset} != {other}")
            else:
                self.__items[ssa_val] = offset
        # grab an arbitrary member to seed the type/extent computation
        first_item = None
        for i in self.__items.items():
            first_item = i
            break
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty
        if isinstance(ty, GPRRangeType):
            # GPR ranges may overlap: compute the covering [start, stop) span
            # and synthesize a GPRRangeType for the whole merged extent
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            # non-GPR types occupy a single location: every member must be at
            # offset 0 and have the exact same type
            stop = 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for non-GPR type: {ssa_val}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        # hash is cached because instances are immutable once constructed
        self.__hash = hash(OFSet(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """build a MergedRegSet from an equality constraint: consecutive
        values are placed end-to-end at increasing offsets.
        """
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        offset = 0
        retval = []  # type: list[tuple[SSAVal[_RegType], int]]
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            retval.append((val, offset))
            offset += val.ty.length
        return MergedRegSet(retval)

    @property
    def ty(self):
        # type: () -> RegType
        """the merged register type covering every member"""
        return self.__ty

    @property
    def start(self):
        # type: () -> int
        """smallest member offset"""
        return self.__start

    @property
    def stop(self):
        # type: () -> int
        """one past the largest register index covered by any member"""
        return self.__stop

    @property
    def range(self):
        # type: () -> range
        """the covered [start, stop) span as a range"""
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        """return a copy with every member's offset shifted by ``amount``"""
        return MergedRegSet((k, v + amount) for k, v in self.items())

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """return a copy shifted so the smallest offset is 0"""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """return a copy shifted so that members shared with ``target`` have
        matching offsets; raises ValueError if no member is shared.
        """
        for ssa_val, offset in self.items():
            if ssa_val in target:
                return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        # Mapping sets __hash__ = None when it supplies __eq__, so restore
        # hashability with the cached value
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
    """mapping from every `SSAVal` used by ``ops`` to the (normalized)
    `MergedRegSet` it belongs to, built by unioning the ops' equality
    constraints.
    """

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
        for op in ops:
            # every value starts in its own singleton set
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            # union the sets on each side of every equality constraint,
            # aligning offsets to the constraint's layout
            for e in op.get_equality_constraints():
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                items = []  # type: list[tuple[SSAVal, int]]
                for i in lhs_set.keys():
                    s = merged_sets[i].with_offset_to_match(lhs_set)
                    items.extend(s.items())
                for i in rhs_set.keys():
                    s = merged_sets[i].with_offset_to_match(rhs_set)
                    items.extend(s.items())
                full_set = MergedRegSet(items)
                # point every member at the new union
                for val in full_set.keys():
                    merged_sets[val] = full_set
        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """mapping from each `MergedRegSet` in ``ops`` to its `LiveInterval`,
    plus, per op index, the set of merged register sets live immediately
    after that op.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # inputs must already be live: extend their interval
                # (assumes every input was written by an earlier op --
                # otherwise this raises KeyError)
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # per-op-index sets of reg sets live immediately after that op
        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
        live_after += (OSet() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [OFSet(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        """the merged register sets live immediately after ``ops[op_index]``"""
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
class IGNode(Generic[_RegType]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg  # assigned register, None until colored

    def add_edge(self, other):
        # type: (IGNode) -> None
        """record a symmetric interference between self and other"""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.merged_reg_set == other.merged_reg_set
        return NotImplemented

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # ``nodes`` memoizes already-printed nodes so the mutually-linked
        # edge structure doesn't recurse forever
        if nodes is None:
            nodes = {}
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """return True if assigning ``reg`` to self would conflict with any
        already-colored neighbor.
        """
        for neighbor in self.edges:
            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
                return True
        return False
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """interference graph: one `IGNode` per `MergedRegSet`; edges are added
    by the allocator as interferences are discovered.
    """

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        # share one memo dict across nodes so cross-references print as
        # compact <IGNode #n> back-references
        nodes = {}  # type: dict[IGNode, int]
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
class AllocationFailed:
    """diagnostic bundle describing why register allocation failed: the node
    that couldn't be colored plus the live intervals and interference graph
    at the point of failure.
    """
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        # fix: ``node`` is declared in __slots__ but was never assigned
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
class AllocationFailedError(Exception):
    """raised when register allocation cannot complete.

    The `AllocationFailed` diagnostic is exposed both on
    ``self.allocation_failed`` and in ``args`` (via the ``Exception``
    constructor).
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        self.allocation_failed = allocation_failed
        super().__init__(msg, allocation_failed)
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """attempt graph-coloring register allocation for ``ops`` without
    spilling.

    Returns a mapping from each `SSAVal` to its assigned `RegLoc` on
    success, or an `AllocationFailed` diagnostic when some node cannot be
    colored.
    """
    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    # build interference edges: values simultaneously live after an op
    # interfere when their register classes can conflict
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        the node is
        """
        if node not in nodes_remaining:
            raise ValueError("node already removed from graph")
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly remove the most-colorable node, pushing it
    # on a stack for later coloring in reverse order
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break
        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes and assign registers
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # pre-assigned register: just verify it doesn't conflict
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        # expand the merged set's register assignment to each member SSAVal
        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    """allocate registers for ``ops``, raising `AllocationFailedError` if
    allocation is impossible without spilling.
    """
    retval = try_allocate_registers_without_spilling(ops)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    # fix: the successful allocation must be returned to the caller
    return retval