"""
Register Allocator for Toom-Cook algorithm generator for SVP64

This uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
8 from itertools
import combinations
9 from typing
import TYPE_CHECKING
, Generic
, Iterable
, Mapping
, TypeVar
11 from nmutil
.plain_data
import plain_data
13 from bigint_presentation_code
.compiler_ir
import (GPRRangeType
, Op
, RegClass
,
14 RegLoc
, RegType
, SSAVal
)
15 from bigint_presentation_code
.ordered_set
import OFSet
, OSet
18 from typing_extensions
import final
# Type variable used to parameterize the generic classes in this module
# (e.g. `Generic[_RegType]`, `SSAVal[_RegType]`); bounded by `RegType`.
_RegType = TypeVar("_RegType", bound=RegType)
27 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
29 __slots__
= "first_write", "last_use"
31 def __init__(self
, first_write
, last_use
=None):
32 # type: (int, int | None) -> None
34 last_use
= first_write
35 if last_use
< first_write
:
36 raise ValueError("uses must be after first_write")
37 if first_write
< 0 or last_use
< 0:
38 raise ValueError("indexes must be nonnegative")
39 self
.first_write
= first_write
40 self
.last_use
= last_use
42 def overlaps(self
, other
):
43 # type: (LiveInterval) -> bool
44 if self
.first_write
== other
.first_write
:
46 return self
.last_use
> other
.first_write \
47 and other
.last_use
> self
.first_write
def __add__(self, use):
    # type: (int) -> LiveInterval
    """Return a new interval that also covers a use at op index `use`.

    The first write is carried over unchanged; the last use of the result
    is the later of the current last use and `use`. `self` is not modified.
    """
    return LiveInterval(
        first_write=self.first_write,
        last_use=max(self.last_use, use),
    )
def live_after_op_range(self):
    """Half-open range of op indexes running from `first_write` up to,
    but not including, `last_use`.

    NOTE(review): the original docstring describes these as the op indexes
    where `self` is live immediately after the op -- confirm the intended
    wording against the full file.
    """
    begin = self.first_write
    end = self.last_use
    return range(begin, end)
63 class MergedRegSet(Mapping
[SSAVal
[_RegType
], int]):
64 def __init__(self
, reg_set
):
65 # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
66 self
.__items
= {} # type: dict[SSAVal[_RegType], int]
67 if isinstance(reg_set
, SSAVal
):
68 reg_set
= [(reg_set
, 0)]
69 for ssa_val
, offset
in reg_set
:
70 if ssa_val
in self
.__items
:
71 other
= self
.__items
[ssa_val
]
74 f
"can't merge register sets: conflicting offsets: "
75 f
"for {ssa_val}: {offset} != {other}")
77 self
.__items
[ssa_val
] = offset
79 for i
in self
.__items
.items():
82 if first_item
is None:
83 raise ValueError("can't have empty MergedRegs")
84 first_ssa_val
, start
= first_item
86 if isinstance(ty
, GPRRangeType
):
87 stop
= start
+ ty
.length
88 for ssa_val
, offset
in self
.__items
.items():
89 if not isinstance(ssa_val
.ty
, GPRRangeType
):
90 raise ValueError(f
"can't merge incompatible types: "
91 f
"{ssa_val.ty} and {ty}")
92 stop
= max(stop
, offset
+ ssa_val
.ty
.length
)
93 start
= min(start
, offset
)
94 ty
= GPRRangeType(stop
- start
)
97 for ssa_val
, offset
in self
.__items
.items():
99 raise ValueError(f
"can't have non-zero offset "
102 raise ValueError(f
"can't merge incompatible types: "
103 f
"{ssa_val.ty} and {ty}")
104 self
.__start
= start
# type: int
105 self
.__stop
= stop
# type: int
106 self
.__ty
= ty
# type: RegType
107 self
.__hash
= hash(OFSet(self
.items()))
110 def from_equality_constraint(constraint_sequence
):
111 # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
112 if len(constraint_sequence
) == 1:
113 # any type allowed with len = 1
114 return MergedRegSet(constraint_sequence
[0])
117 for val
in constraint_sequence
:
118 if not isinstance(val
.ty
, GPRRangeType
):
119 raise ValueError("equality constraint sequences must only "
120 "have SSAVal type GPRRangeType")
121 retval
.append((val
, offset
))
122 offset
+= val
.ty
.length
123 return MergedRegSet(retval
)
139 return range(self
.__start
, self
.__stop
)
def offset_by(self, amount):
    # type: (int) -> MergedRegSet[_RegType]
    """Return a new MergedRegSet with `amount` added to every offset.

    The SSA values are kept as-is; only their offsets are shifted.
    `self` is not modified.
    """
    shifted = (
        (ssa_val, offset + amount) for ssa_val, offset in self.items()
    )
    return MergedRegSet(shifted)
def normalized(self):
    # type: () -> MergedRegSet[_RegType]
    """Return this set shifted by the negation of its `start`.

    Presumably this makes the result start at offset 0 -- the `start`
    property itself is not visible here, so confirm against its
    definition.
    """
    shift = -self.start
    return self.offset_by(shift)
def with_offset_to_match(self, target):
    # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
    """Return this set shifted so that it lines up with `target`.

    The first SSA value of `self` (in iteration order) that is also a key
    of `target` decides the shift: the result places that value at the
    same offset it has in `target`.

    Raises ValueError when `self` and `target` share no SSA value.
    """
    for ssa_val, offset in self.items():
        if ssa_val not in target:
            continue
        # first shared value found: shift by the difference in offsets
        delta = target[ssa_val] - offset
        return self.offset_by(delta)
    raise ValueError("can't change offset to match unrelated MergedRegSet")
def __getitem__(self, item):
    # type: (SSAVal[_RegType]) -> int
    """Return the offset stored for the SSA value `item`.

    Raises KeyError if `item` is not a member of this set.
    """
    offsets = self.__items
    return offsets[item]
161 return iter(self
.__items
)
164 return len(self
.__items
)
170 return f
"MergedRegSet({list(self.__items.items())})"
174 class MergedRegSets(Mapping
[SSAVal
, MergedRegSet
[_RegType
]], Generic
[_RegType
]):
175 def __init__(self
, ops
):
176 # type: (Iterable[Op]) -> None
177 merged_sets
= {} # type: dict[SSAVal, MergedRegSet[_RegType]]
179 for val
in (*op
.inputs().values(), *op
.outputs().values()):
180 if val
not in merged_sets
:
181 merged_sets
[val
] = MergedRegSet(val
)
182 for e
in op
.get_equality_constraints():
183 lhs_set
= MergedRegSet
.from_equality_constraint(e
.lhs
)
184 rhs_set
= MergedRegSet
.from_equality_constraint(e
.rhs
)
185 items
= [] # type: list[tuple[SSAVal, int]]
187 s
= merged_sets
[i
].with_offset_to_match(lhs_set
)
188 items
.extend(s
.items())
190 s
= merged_sets
[i
].with_offset_to_match(rhs_set
)
191 items
.extend(s
.items())
192 full_set
= MergedRegSet(items
)
193 for val
in full_set
.keys():
194 merged_sets
[val
] = full_set
196 self
.__map
= {k
: v
.normalized() for k
, v
in merged_sets
.items()}
def __getitem__(self, key):
    # type: (SSAVal) -> MergedRegSet
    """Return the merged register set that the SSA value `key` belongs to.

    Raises KeyError if `key` is not in the map.
    """
    mapping = self.__map
    return mapping[key]
203 return iter(self
.__map
)
206 return len(self
.__map
)
209 return f
"MergedRegSets(data={self.__map})"
213 class LiveIntervals(Mapping
[MergedRegSet
[_RegType
], LiveInterval
]):
214 def __init__(self
, ops
):
215 # type: (list[Op]) -> None
216 self
.__merged
_reg
_sets
= MergedRegSets(ops
)
217 live_intervals
= {} # type: dict[MergedRegSet[_RegType], LiveInterval]
218 for op_idx
, op
in enumerate(ops
):
219 for val
in op
.inputs().values():
220 live_intervals
[self
.__merged
_reg
_sets
[val
]] += op_idx
221 for val
in op
.outputs().values():
222 reg_set
= self
.__merged
_reg
_sets
[val
]
223 if reg_set
not in live_intervals
:
224 live_intervals
[reg_set
] = LiveInterval(op_idx
)
226 live_intervals
[reg_set
] += op_idx
227 self
.__live
_intervals
= live_intervals
228 live_after
= [] # type: list[OSet[MergedRegSet[_RegType]]]
229 live_after
+= (OSet() for _
in ops
)
230 for reg_set
, live_interval
in self
.__live
_intervals
.items():
231 for i
in live_interval
.live_after_op_range
:
232 live_after
[i
].add(reg_set
)
233 self
.__live
_after
= [OFSet(i
) for i
in live_after
]
def merged_reg_sets(self):
    """Return the `MergedRegSets` that the constructor built from the
    list of ops it was given.
    """
    merged = self.__merged_reg_sets
    return merged
def __getitem__(self, key):
    # type: (MergedRegSet[_RegType]) -> LiveInterval
    """Return the live interval recorded for the merged register set `key`.

    Raises KeyError if `key` has no recorded interval.
    """
    intervals = self.__live_intervals
    return intervals[key]
244 return iter(self
.__live
_intervals
)
247 return len(self
.__live
_intervals
)
def reg_sets_live_after(self, op_index):
    # type: (int) -> OFSet[MergedRegSet[_RegType]]
    """Return the precomputed set of merged register sets that are live
    after the op at `op_index`.

    The per-op sets are built once by the constructor; this is a plain
    list lookup, so the usual list indexing rules apply to `op_index`.
    """
    per_op = self.__live_after
    return per_op[op_index]
254 reg_sets_live_after
= dict(enumerate(self
.__live
_after
))
255 return (f
"LiveIntervals(live_intervals={self.__live_intervals}, "
256 f
"merged_reg_sets={self.merged_reg_sets}, "
257 f
"reg_sets_live_after={reg_sets_live_after})")
261 class IGNode(Generic
[_RegType
]):
262 """ interference graph node """
263 __slots__
= "merged_reg_set", "edges", "reg"
265 def __init__(self
, merged_reg_set
, edges
=(), reg
=None):
266 # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
267 self
.merged_reg_set
= merged_reg_set
268 self
.edges
= OSet(edges
)
def add_edge(self, other):
    # type: (IGNode) -> None
    """Record an undirected interference edge between `self` and `other`.

    Each node is added to the other's `edges` set, so the relation is
    stored symmetrically.
    """
    for node, peer in ((self, other), (other, self)):
        node.edges.add(peer)
def __eq__(self, other):
    # type: (object) -> bool
    """Compare two nodes by their `merged_reg_set`.

    Returns NotImplemented for anything that is not an IGNode, which lets
    Python fall back to the other operand's comparison.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_reg_set == other.merged_reg_set
283 return hash(self
.merged_reg_set
)
285 def __repr__(self
, nodes
=None):
286 # type: (None | dict[IGNode, int]) -> str
290 return f
"<IGNode #{nodes[self]}>"
291 nodes
[self
] = len(nodes
)
292 edges
= "{" + ", ".join(i
.__repr
__(nodes
) for i
in self
.edges
) + "}"
293 return (f
"IGNode(#{nodes[self]}, "
294 f
"merged_reg_set={self.merged_reg_set}, "
300 # type: () -> RegClass
301 return self
.merged_reg_set
.ty
.reg_class
303 def reg_conflicts_with_neighbors(self
, reg
):
304 # type: (RegLoc) -> bool
305 for neighbor
in self
.edges
:
306 if neighbor
.reg
is not None and neighbor
.reg
.conflicts(reg
):
312 class InterferenceGraph(Mapping
[MergedRegSet
[_RegType
], IGNode
[_RegType
]]):
def __init__(self, merged_reg_sets):
    # type: (Iterable[MergedRegSet[_RegType]]) -> None
    """Build a graph with one node per merged register set.

    Each node is created with IGNode's defaults, so the graph starts out
    with no edges; nodes are keyed by their merged register set.
    """
    nodes = {}
    for reg_set in merged_reg_sets:
        nodes[reg_set] = IGNode(reg_set)
    self.__nodes = nodes
def __getitem__(self, key):
    # type: (MergedRegSet[_RegType]) -> IGNode
    """Return the graph node for the merged register set `key`.

    Raises KeyError if `key` has no node in this graph.
    """
    nodes = self.__nodes
    return nodes[key]
322 return iter(self
.__nodes
)
325 return len(self
.__nodes
)
329 nodes_text
= [f
"...: {node.__repr__(nodes)}" for node
in self
.values()]
330 nodes_text
= ", ".join(nodes_text
)
331 return f
"InterferenceGraph(nodes={{{nodes_text}}})"
335 class AllocationFailed
:
336 __slots__
= "node", "live_intervals", "interference_graph"
338 def __init__(self
, node
, live_intervals
, interference_graph
):
339 # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
341 self
.live_intervals
= live_intervals
342 self
.interference_graph
= interference_graph
345 def try_allocate_registers_without_spilling(ops
):
346 # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
348 live_intervals
= LiveIntervals(ops
)
349 merged_reg_sets
= live_intervals
.merged_reg_sets
350 interference_graph
= InterferenceGraph(merged_reg_sets
.values())
351 for op_idx
, op
in enumerate(ops
):
352 reg_sets
= live_intervals
.reg_sets_live_after(op_idx
)
353 for i
, j
in combinations(reg_sets
, 2):
354 if i
.ty
.reg_class
.max_conflicts_with(j
.ty
.reg_class
) != 0:
355 interference_graph
[i
].add_edge(interference_graph
[j
])
356 for i
, j
in op
.get_extra_interferences():
357 i
= merged_reg_sets
[i
]
358 j
= merged_reg_sets
[j
]
359 if i
.ty
.reg_class
.max_conflicts_with(j
.ty
.reg_class
) != 0:
360 interference_graph
[i
].add_edge(interference_graph
[j
])
362 nodes_remaining
= OSet(interference_graph
.values())
364 def local_colorability_score(node
):
365 # type: (IGNode) -> int
366 """ returns a positive integer if node is locally colorable, returns
367 zero or a negative integer if node isn't known to be locally
368 colorable, the more negative the value, the less colorable
370 if node
not in nodes_remaining
:
372 retval
= len(node
.reg_class
)
373 for neighbor
in node
.edges
:
374 if neighbor
in nodes_remaining
:
375 retval
-= node
.reg_class
.max_conflicts_with(neighbor
.reg_class
)
378 node_stack
= [] # type: list[IGNode]
380 best_node
= None # type: None | IGNode
382 for node
in nodes_remaining
:
383 score
= local_colorability_score(node
)
384 if best_node
is None or score
> best_score
:
388 # it's locally colorable, no need to find a better one
391 if best_node
is None:
393 node_stack
.append(best_node
)
394 nodes_remaining
.remove(best_node
)
396 retval
= {} # type: dict[SSAVal, RegLoc]
398 while len(node_stack
) > 0:
399 node
= node_stack
.pop()
400 if node
.reg
is not None:
401 if node
.reg_conflicts_with_neighbors(node
.reg
):
402 return AllocationFailed(node
=node
,
403 live_intervals
=live_intervals
,
404 interference_graph
=interference_graph
)
406 # pick the first non-conflicting register in node.reg_class, since
407 # register classes are ordered from most preferred to least
408 # preferred register.
409 for reg
in node
.reg_class
:
410 if not node
.reg_conflicts_with_neighbors(reg
):
414 return AllocationFailed(node
=node
,
415 live_intervals
=live_intervals
,
416 interference_graph
=interference_graph
)
418 for ssa_val
, offset
in node
.merged_reg_set
.items():
419 retval
[ssa_val
] = node
.reg
.get_subreg_at_offset(ssa_val
.ty
, offset
)
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """Placeholder entry point: not implemented yet.

    Always raises NotImplementedError, whatever `ops` contains.
    """
    raise NotImplementedError()