b8269e4bc354f6c31b5433b10d19b71a39a38c1e
"""
Register Allocator for Toom-Cook algorithm generator for SVP64

this uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
from itertools import combinations
from typing import Generic, Iterable, Mapping, TypeVar

from nmutil.plain_data import plain_data

from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
                                                  RegLoc, RegType, SSAVal)
from bigint_presentation_code.util import OFSet, OSet, final
# type variable ranging over RegType subclasses -- lets the register-set
# containers below be generic over the kind of register they hold
_RegType = TypeVar("_RegType", bound=RegType)
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """the live range of a value, represented as the half-open interval of
    op indexes [first_write, last_use)"""
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        if last_use is None:
            # a value that is written but never read is live only at its write
            last_use = first_write
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """return True if self and other are ever live at the same time"""
        if self.first_write == other.first_write:
            return True
        return self.last_use > other.first_write \
            and other.last_use > self.first_write

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a copy of self extended to also cover a use at op index
        `use` (intervals are frozen, so extension builds a new one)"""
        last_use = max(self.last_use, use)
        return LiveInterval(first_write=self.first_write, last_use=last_use)

    @property
    def live_after_op_range(self):
        """the range of op indexes where self is live immediately after the
        corresponding Op"""
        return range(self.first_write, self.last_use)
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """a set of SSA values that must be allocated as one unit, mapping each
    SSAVal to its offset inside the merged register range"""

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        self.__items = {}  # type: dict[SSAVal[_RegType], int]
        if isinstance(reg_set, SSAVal):
            # a bare SSAVal is a singleton set at offset 0
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            if ssa_val in self.__items:
                other = self.__items[ssa_val]
                if offset != other:
                    raise ValueError(
                        f"can't merge register sets: conflicting offsets: "
                        f"for {ssa_val}: {offset} != {other}")
                continue
            self.__items[ssa_val] = offset
        # grab the first item to seed the type/range computation
        first_item = None
        for i in self.__items.items():
            first_item = i
            break
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        ty = first_ssa_val.ty
        if isinstance(ty, GPRRangeType):
            # GPR ranges may merge at arbitrary offsets; compute the union
            # [start, stop) of all members' ranges
            stop = start + ty.length
            for ssa_val, offset in self.__items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            # non-GPR types must all be identical and sit at offset 0
            stop = start + 1
            for ssa_val, offset in self.__items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for non-GPR type: {ssa_val}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType
        # contents never change afterwards, so the hash can be precomputed
        self.__hash = hash(OFSet(self.items()))

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """build a MergedRegSet by laying the values out consecutively"""
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        retval = []  # type: list[tuple[SSAVal[_RegType], int]]
        offset = 0
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            retval.append((val, offset))
            offset += val.ty.length
        return MergedRegSet(retval)

    @property
    def ty(self):
        # type: () -> RegType
        return self.__ty

    @property
    def start(self):
        # type: () -> int
        return self.__start

    @property
    def stop(self):
        # type: () -> int
        return self.__stop

    @property
    def range(self):
        # type: () -> range
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        return MergedRegSet((k, v + amount) for k, v in self.items())

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """return an equivalent set shifted so its range starts at 0"""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """return self shifted so values shared with `target` line up"""
        for ssa_val, offset in self.items():
            if ssa_val in target:
                return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        return self.__hash

    def __repr__(self):
        return f"MergedRegSet({list(self.__items.items())})"
@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
    """map from each SSAVal used by `ops` to the MergedRegSet it belongs to,
    unioning sets that are joined by equality constraints"""

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet[_RegType]]
        for op in ops:
            # every input/output starts out in its own singleton set
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                # union every existing set touching lhs or rhs into one set,
                # shifting each to the common offset basis first
                items = []  # type: list[tuple[SSAVal, int]]
                for i in lhs_set.keys():
                    s = merged_sets[i].with_offset_to_match(lhs_set)
                    items.extend(s.items())
                for i in rhs_set.keys():
                    s = merged_sets[i].with_offset_to_match(rhs_set)
                    items.extend(s.items())
                full_set = MergedRegSet(items)
                for val in full_set.keys():
                    merged_sets[val] = full_set

        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)

    def __repr__(self):
        return f"MergedRegSets(data={self.__map})"
@final
class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
    """map from each MergedRegSet used by `ops` to its LiveInterval"""

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # an input must have been written by an earlier op, so its
                # merged set already has an interval; extend it to op_idx
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # for each op index, the set of MergedRegSets live just after that op
        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
        live_after += (OSet() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [OFSet(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> OFSet[MergedRegSet[_RegType]]
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
@final
class IGNode(Generic[_RegType]):
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges", "reg"

    def __init__(self, merged_reg_set, edges=(), reg=None):
        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
        self.merged_reg_set = merged_reg_set
        self.edges = OSet(edges)
        self.reg = reg

    def add_edge(self, other):
        # type: (IGNode) -> None
        """record that self and other interfere (symmetric)"""
        self.edges.add(other)
        other.edges.add(self)

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, IGNode):
            return self.merged_reg_set == other.merged_reg_set
        return NotImplemented

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        if nodes is None:
            nodes = {}
        if self in nodes:
            # already printed in this repr -- emit a back-reference instead
            # of recursing forever through the edge cycle
            return f"<IGNode #{nodes[self]}>"
        nodes[self] = len(nodes)
        edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
        return (f"IGNode(#{nodes[self]}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges}, "
                f"reg={self.reg})")

    @property
    def reg_class(self):
        # type: () -> RegClass
        return self.merged_reg_set.ty.reg_class

    def reg_conflicts_with_neighbors(self, reg):
        # type: (RegLoc) -> bool
        """return True if assigning `reg` to self would conflict with a
        register already assigned to any neighbor"""
        for neighbor in self.edges:
            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
                return True
        return False
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
    """interference graph: one IGNode per MergedRegSet"""

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegType]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegType]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        return len(self.__nodes)

    def __repr__(self):
        # one shared numbering dict lets IGNode.__repr__ break edge cycles
        nodes = {}  # type: dict[IGNode, int]
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
@final
class AllocationFailed:
    """result returned when allocation can't proceed without spilling"""
    __slots__ = "node", "live_intervals", "interference_graph"

    def __init__(self, node, live_intervals, interference_graph):
        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
        # the node that couldn't be assigned a non-conflicting register
        self.node = node
        self.live_intervals = live_intervals
        self.interference_graph = interference_graph
class AllocationFailedError(Exception):
    """raised when register allocation would require spilling; the failing
    AllocationFailed result is kept for diagnosis"""

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        # pass both to Exception so str()/args carry the full context
        super().__init__(msg, allocation_failed)
        self.allocation_failed = allocation_failed
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """allocate a RegLoc for every SSAVal in `ops` by graph coloring,
    returning an AllocationFailed instead if spilling would be required"""
    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    # build edges: reg-sets simultaneously live after an op interfere when
    # their register classes can conflict, plus op-specific extras
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = OSet(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable it is
        """
        if node not in nodes_remaining:
            raise ValueError("node already removed")
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly remove the most-colorable remaining node,
    # pushing it on a stack for the select phase below
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
                if best_score > 0:
                    # it's locally colorable, no need to find a better one
                    break
        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes and assign each the first register that
    # doesn't conflict with its already-colored neighbors
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # pre-assigned register: only verify it's still usable
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
def allocate_registers(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc]
    """allocate registers for `ops`, raising AllocationFailedError if
    allocation can't be done without spilling"""
    retval = try_allocate_registers_without_spilling(ops)
    if isinstance(retval, AllocationFailed):
        # TODO: implement spilling
        raise AllocationFailedError(
            "spilling required but not yet implemented", retval)
    return retval