2 Register Allocator for Toom-Cook algorithm generator for SVP64
4 this uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from itertools
import combinations
9 from typing
import Any
, Iterable
, Iterator
, Mapping
, MutableSet
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import plain_data
14 from bigint_presentation_code
.compiler_ir2
import (BaseTy
, FnAnalysis
, Loc
,
15 LocSet
, Op
, ProgramRange
,
17 from bigint_presentation_code
.type_util
import final
18 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
21 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
23 __slots__
= "first_write", "last_use"
25 def __init__(self
, first_write
, last_use
=None):
26 # type: (int, int | None) -> None
29 last_use
= first_write
30 if last_use
< first_write
:
31 raise ValueError("uses must be after first_write")
32 if first_write
< 0 or last_use
< 0:
33 raise ValueError("indexes must be nonnegative")
34 self
.first_write
= first_write
35 self
.last_use
= last_use
37 def overlaps(self
, other
):
38 # type: (LiveInterval) -> bool
39 if self
.first_write
== other
.first_write
:
41 return self
.last_use
> other
.first_write \
42 and other
.last_use
> self
.first_write
def __add__(self, use):
    # type: (int) -> LiveInterval
    """Return a new `LiveInterval` that also covers a use at index `use`.

    `first_write` is carried over unchanged; `last_use` becomes the larger
    of the current `last_use` and `use`. `self` is not modified.
    """
    return LiveInterval(
        first_write=self.first_write,
        last_use=max(self.last_use, use),
    )
50 def live_after_op_range(self
):
51 """the range of op indexes where self is live immediately after the
54 return range(self
.first_write
, self
.last_use
)
57 class BadMergedSSAVal(ValueError):
61 @plain_data(frozen
=True)
64 """a set of `SSAVal`s along with their offsets, all register allocated as
67 Definition of the term `offset` for this class:
69 Let `locs[x]` be the `Loc` that `x` is assigned to after register
70 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
71 for each `SSAVal` `ssa_val` in `msv` is defined as:
74 msv.ssa_val_offsets[ssa_val] = (msv.offset
75 + locs[ssa_val].start - locs[msv].start)
83 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
86 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
87 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
88 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
89 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
91 __slots__
= "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
93 def __init__(self
, fn_analysis
, ssa_val_offsets
):
94 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
95 self
.fn_analysis
= fn_analysis
96 if isinstance(ssa_val_offsets
, SSAVal
):
97 ssa_val_offsets
= {ssa_val_offsets
: 0}
98 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
100 for ssa_val
in self
.ssa_vals
:
101 first_ssa_val
= ssa_val
103 if first_ssa_val
is None:
104 raise BadMergedSSAVal("MergedSSAVal can't be empty")
105 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
106 # self.ty checks for mismatched base_ty
107 reg_len
= self
.ty
.reg_len
108 loc_set
= None # type: None | LocSet
109 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
110 def_spread_idx
= ssa_val
.defining_descriptor
.spread_index
or 0
113 # type: () -> Iterable[Loc]
114 for loc
in ssa_val
.def_loc_set_before_spread
:
115 disallowed_by_use
= False
116 for use
in fn_analysis
.uses
[ssa_val
]:
118 use
.defining_descriptor
.spread_index
or 0
119 # calculate the start for the use's Loc before spread
120 # e.g. if the def's Loc before spread starts at r6
121 # and the def's spread_index is 5
122 # and the use's spread_index is 3
123 # then the use's Loc before spread starts at r8
124 # because 8 == 6 + 5 - 3
125 start
= loc
.start
+ def_spread_idx
- use_spread_idx
126 use_loc
= Loc
.try_make(
127 loc
.kind
, start
=start
,
128 reg_len
=use
.ty_before_spread
.reg_len
)
129 if (use_loc
is None or
130 use_loc
not in use
.use_loc_set_before_spread
):
131 disallowed_by_use
= True
133 if disallowed_by_use
:
135 # FIXME: add spread consistency check
136 start
= loc
.start
- cur_offset
+ self
.offset
137 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
138 if loc
is not None and (loc_set
is None or loc
in loc_set
):
140 loc_set
= LocSet(locs())
141 assert loc_set
is not None, "already checked that self isn't empty"
142 if loc_set
.ty
is None:
143 raise BadMergedSSAVal("there are no valid Locs left")
144 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
145 self
.loc_set
= loc_set
# type: LocSet
150 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
159 return min(self
.ssa_val_offsets_before_spread
.values())
164 return self
.first_ssa_val
.base_ty
168 # type: () -> OFSet[SSAVal]
169 return OFSet(self
.ssa_val_offsets
.keys())
175 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
176 cur_ty
= ssa_val
.ty_before_spread
177 if self
.base_ty
!= cur_ty
.base_ty
:
178 raise BadMergedSSAVal(
179 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
180 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
181 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
184 def ssa_val_offsets_before_spread(self
):
185 # type: () -> FMap[SSAVal, int]
186 retval
= {} # type: dict[SSAVal, int]
187 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
189 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """Return a new `MergedSSAVal` with every offset increased by `amount`.

    The result shares `self.fn_analysis`; `self` is left untouched.
    """
    shifted = {}  # type: dict[SSAVal, int]
    for ssa_val, offset in self.ssa_val_offsets.items():
        shifted[ssa_val] = offset + amount
    return MergedSSAVal(
        fn_analysis=self.fn_analysis,
        ssa_val_offsets=shifted,
    )
def normalized(self):
    # type: () -> MergedSSAVal
    """Return `self` shifted by the negation of `self.offset`.

    NOTE(review): presumably this makes the smallest offset zero -- confirm
    against the definition of `offset`.
    """
    shift = -self.offset
    return self.offset_by(shift)
def with_offset_to_match(self, target, additional_offset=0):
    # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
    """Return `self` shifted so a shared `SSAVal` lines up with `target`.

    `target` is either a `MergedSSAVal` (its offsets are used as-is) or a
    single `SSAVal` (treated as having offset 0). The first `SSAVal` of
    `self`, in iteration order, that also occurs in `target` decides the
    shift: afterwards its offset equals the target's offset for it plus
    `additional_offset`.

    Raises `ValueError` when `self` and `target` share no `SSAVal`.
    """
    target_offsets = (target.ssa_val_offsets
                      if isinstance(target, MergedSSAVal)
                      else {target: 0})
    for ssa_val, offset in self.ssa_val_offsets.items():
        if ssa_val not in target_offsets:
            continue
        delta = target_offsets[ssa_val] + additional_offset - offset
        return self.offset_by(delta)
    raise ValueError("can't change offset to match unrelated MergedSSAVal")
def merged(self, *others):
    # type: (*MergedSSAVal) -> MergedSSAVal
    """Return the union of `self` with every `MergedSSAVal` in `others`.

    Offsets are taken verbatim from each operand, so an `SSAVal` that
    occurs in more than one operand must have the same offset everywhere.

    Raises `ValueError` if an operand has a different `fn_analysis`, and
    `BadMergedSSAVal` if two operands disagree on an `SSAVal`'s offset.
    """
    retval = dict(self.ssa_val_offsets)
    for other in others:
        if other.fn_analysis != self.fn_analysis:
            raise ValueError("fn_analysis mismatch")
        for ssa_val, offset in other.ssa_val_offsets.items():
            # setdefault inserts new entries and hands back the existing
            # offset for already-known ones, so one comparison covers both
            if retval.setdefault(ssa_val, offset) != offset:
                raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                      f"{retval[ssa_val]} != {offset}")
    return MergedSSAVal(
        fn_analysis=self.fn_analysis,
        ssa_val_offsets=retval,
    )
def live_interval(self):
    # type: () -> ProgramRange
    """Return the `ProgramRange` spanning the live ranges of all members.

    The result starts at the earliest `start` and stops at the latest
    `stop` among `fn_analysis.live_ranges` of every `SSAVal` in `self`.
    """
    live_ranges = self.fn_analysis.live_ranges
    # seed with the first member's range, then fold in every member
    spans = [live_ranges[self.first_ssa_val]]
    spans.extend(live_ranges[ssa_val] for ssa_val in self.ssa_vals)
    return ProgramRange(
        start=min(span.start for span in spans),
        stop=max(span.stop for span in spans),
    )
241 class MergedSSAValsMap(Mapping
[SSAVal
, MergedSSAVal
]):
243 # type: (...) -> None
244 self
.__merge
_map
= {} # type: dict[SSAVal, MergedSSAVal]
245 self
.__values
_set
= MergedSSAValsSet(
246 _private_merge_map
=self
.__merge
_map
,
247 _private_values_set
=OSet())
def __getitem__(self, __key):
    # type: (SSAVal) -> MergedSSAVal
    """Return the `MergedSSAVal` recorded for the `SSAVal` `__key`.

    The lookup goes straight to the private merge map, so a missing key
    raises whatever that mapping raises.
    """
    merge_map = self.__merge_map
    return merge_map[__key]
254 # type: () -> Iterator[SSAVal]
255 return iter(self
.__merge
_map
)
259 return len(self
.__merge
_map
)
def values_set(self):
    # type: () -> MergedSSAValsSet
    """Return the `MergedSSAValsSet` held by this map.

    The stored set object itself is returned, not a copy.
    """
    values = self.__values_set
    return values
268 s
= ",\n".join(repr(v
) for v
in self
.__values
_set
)
269 return f
"MergedSSAValsMap({{{s}}})"
273 class MergedSSAValsSet(MutableSet
[MergedSSAVal
]):
def __init__(self, *,
             _private_merge_map,  # type: dict[SSAVal, MergedSSAVal]
             _private_values_set,  # type: OSet[MergedSSAVal]
             ):
    # type: (...) -> None
    """Store the merge map and values set handed in by the creator.

    Both arguments are keyword-only and kept by reference (no copy), so
    the creator and this set share the same underlying containers.
    """
    self.__values_set = _private_values_set
    self.__merge_map = _private_merge_map
283 def _from_iterable(cls
, it
):
284 # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
def __contains__(self, value):
    # type: (MergedSSAVal | Any) -> bool
    """Return whether `value` is one of the stored `MergedSSAVal`s."""
    members = self.__values_set
    return value in members
292 # type: () -> Iterator[MergedSSAVal]
293 return iter(self
.__values
_set
)
297 return len(self
.__values
_set
)
299 def add(self
, value
):
300 # type: (MergedSSAVal) -> None
303 added
= 0 # type: int | None
305 for ssa_val
in value
.ssa_vals
:
306 if ssa_val
in self
.__merge
_map
:
308 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
309 f
"{value} and {self.__merge_map[ssa_val]}")
310 self
.__merge
_map
[ssa_val
] = value
312 self
.__values
_set
.add(value
)
315 if added
is not None:
316 # remove partially added stuff
317 for idx
, ssa_val
in enumerate(value
.ssa_vals
):
320 del self
.__merge
_map
[ssa_val
]
322 def discard(self
, value
):
323 # type: (MergedSSAVal) -> None
324 if value
not in self
:
326 self
.__values
_set
.discard(value
)
327 for ssa_val
in value
.ssa_val_offsets
.keys():
328 del self
.__merge
_map
[ssa_val
]
332 s
= ",\n".join(repr(v
) for v
in self
.__values
_set
)
333 return f
"MergedSSAValsSet({{{s}}})"
336 @plain_data(frozen
=True)
339 __slots__
= "fn_analysis", "merge_map", "merged_ssa_vals"
def __init__(self, fn_analysis, merged_ssa_vals):
    # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
    """Create the collection and fill it from `merged_ssa_vals`.

    A fresh `MergedSSAValsMap` is created; `self.merged_ssa_vals` is that
    map's `values_set`, and every item of the `merged_ssa_vals` argument is
    added to it in iteration order.
    """
    merge_map = MergedSSAValsMap()
    values = merge_map.values_set
    self.fn_analysis = fn_analysis
    self.merge_map = merge_map
    self.merged_ssa_vals = values
    for merged_ssa_val in merged_ssa_vals:
        values.add(merged_ssa_val)
349 def merge(self
, ssa_val1
, ssa_val2
, additional_offset
=0):
350 # type: (SSAVal, SSAVal, int) -> MergedSSAVal
351 merged1
= self
.merge_map
[ssa_val1
]
352 merged2
= self
.merge_map
[ssa_val2
]
353 merged
= merged1
.with_offset_to_match(ssa_val1
)
354 merged
= merged
.merged(merged2
.with_offset_to_match(
355 ssa_val2
, additional_offset
=additional_offset
))
356 self
.merged_ssa_vals
.remove(merged1
)
357 self
.merged_ssa_vals
.remove(merged2
)
358 self
.merged_ssa_vals
.add(merged
)
362 def minimally_merged(fn_analysis
):
363 # type: (FnAnalysis) -> MergedSSAVals
364 retval
= MergedSSAVals(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
365 for op
in fn_analysis
.fn
.ops
:
366 for inp
in op
.input_uses
:
367 if inp
.unspread_start
!= inp
:
368 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
369 additional_offset
=inp
.reg_offset_in_unspread
)
370 for out
in op
.outputs
:
371 if out
.unspread_start
!= out
:
372 retval
.merge(out
.unspread_start
, out
,
373 additional_offset
=out
.reg_offset_in_unspread
)
374 if out
.tied_input
is not None:
375 retval
.merge(out
.tied_input
.ssa_val
, out
)
381 """ interference graph node """
382 __slots__
= "merged_ssa_val", "edges", "loc"
384 def __init__(self
, merged_ssa_val
, edges
=(), loc
=None):
385 # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
386 self
.merged_ssa_val
= merged_ssa_val
387 self
.edges
= OSet(edges
)
def add_edge(self, other):
    # type: (IGNode) -> None
    """Connect `self` and `other` with an undirected edge.

    Each node is added to the other's `edges` set, `self`'s set first.
    """
    for src, dst in ((self, other), (other, self)):
        src.edges.add(dst)
def __eq__(self, other):
    # type: (object) -> bool
    """Compare two `IGNode`s by their `merged_ssa_val`.

    Returns `NotImplemented` for anything that isn't an `IGNode`, so Python
    can fall back to the other operand's comparison.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_ssa_val == other.merged_ssa_val
402 return hash(self
.merged_ssa_val
)
404 def __repr__(self
, nodes
=None):
405 # type: (None | dict[IGNode, int]) -> str
409 return f
"<IGNode #{nodes[self]}>"
410 nodes
[self
] = len(nodes
)
411 edges
= "{" + ", ".join(i
.__repr
__(nodes
) for i
in self
.edges
) + "}"
412 return (f
"IGNode(#{nodes[self]}, "
413 f
"merged_ssa_val={self.merged_ssa_val}, "
420 return self
.merged_ssa_val
.loc_set
422 def loc_conflicts_with_neighbors(self
, loc
):
423 # type: (Loc) -> bool
424 for neighbor
in self
.edges
:
425 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
431 class AllocationFailed
:
432 __slots__
= "node", "merged_ssa_vals", "interference_graph"
434 def __init__(self
, node
, merged_ssa_vals
, interference_graph
):
435 # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
438 self
.merged_ssa_vals
= merged_ssa_vals
439 self
.interference_graph
= interference_graph
class AllocationFailedError(Exception):
    """Exception that carries the `AllocationFailed` describing the failure.

    The payload is available both as the `allocation_failed` attribute and
    as the second element of `args`.
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        self.allocation_failed = allocation_failed
        super().__init__(msg, allocation_failed)
449 def try_allocate_registers_without_spilling(merged_ssa_vals
):
450 # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
452 interference_graph
= {
453 i
: IGNode(i
) for i
in merged_ssa_vals
.merged_ssa_vals
}
454 fn_analysis
= merged_ssa_vals
.fn_analysis
455 for ssa_vals
in fn_analysis
.live_at
.values():
456 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
457 for ssa_val
in ssa_vals
:
458 live_merged_ssa_vals
.add(merged_ssa_vals
.merge_map
[ssa_val
])
459 for i
, j
in combinations(live_merged_ssa_vals
, 2):
460 if i
.loc_set
.max_conflicts_with(j
.loc_set
) != 0:
461 interference_graph
[i
].add_edge(interference_graph
[j
])
463 nodes_remaining
= OSet(interference_graph
.values())
465 # FIXME: work on code from here
467 def local_colorability_score(node
):
468 # type: (IGNode) -> int
469 """ returns a positive integer if node is locally colorable, returns
470 zero or a negative integer if node isn't known to be locally
471 colorable, the more negative the value, the less colorable
473 if node
not in nodes_remaining
:
475 retval
= len(node
.loc_set
)
476 for neighbor
in node
.edges
:
477 if neighbor
in nodes_remaining
:
478 retval
-= node
.reg_class
.max_conflicts_with(neighbor
.reg_class
)
481 node_stack
= [] # type: list[IGNode]
483 best_node
= None # type: None | IGNode
485 for node
in nodes_remaining
:
486 score
= local_colorability_score(node
)
487 if best_node
is None or score
> best_score
:
491 # it's locally colorable, no need to find a better one
494 if best_node
is None:
496 node_stack
.append(best_node
)
497 nodes_remaining
.remove(best_node
)
499 retval
= {} # type: dict[SSAVal, RegLoc]
501 while len(node_stack
) > 0:
502 node
= node_stack
.pop()
503 if node
.reg
is not None:
504 if node
.reg_conflicts_with_neighbors(node
.reg
):
505 return AllocationFailed(node
=node
,
506 live_intervals
=live_intervals
,
507 interference_graph
=interference_graph
)
509 # pick the first non-conflicting register in node.reg_class, since
510 # register classes are ordered from most preferred to least
511 # preferred register.
512 for reg
in node
.reg_class
:
513 if not node
.reg_conflicts_with_neighbors(reg
):
517 return AllocationFailed(node
=node
,
518 live_intervals
=live_intervals
,
519 interference_graph
=interference_graph
)
521 for ssa_val
, offset
in node
.merged_reg_set
.items():
522 retval
[ssa_val
] = node
.reg
.get_subreg_at_offset(ssa_val
.ty
, offset
)
527 def allocate_registers(ops
):
528 # type: (list[Op]) -> dict[SSAVal, RegLoc]
529 retval
= try_allocate_registers_without_spilling(ops
)
530 if isinstance(retval
, AllocationFailed
):
531 # TODO: implement spilling
532 raise AllocationFailedError(
533 "spilling required but not yet implemented", retval
)