Register Allocator for Toom-Cook algorithm generator for SVP64

This uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
8 from itertools
import combinations
9 from typing
import Any
, Iterable
, Iterator
, Mapping
, MutableMapping
, MutableSet
, Dict
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import plain_data
14 from bigint_presentation_code
.compiler_ir2
import (BaseTy
, FnAnalysis
, Loc
,
15 LocSet
, Op
, ProgramRange
,
17 from bigint_presentation_code
.type_util
import final
18 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
21 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
23 __slots__
= "first_write", "last_use"
25 def __init__(self
, first_write
, last_use
=None):
26 # type: (int, int | None) -> None
29 last_use
= first_write
30 if last_use
< first_write
:
31 raise ValueError("uses must be after first_write")
32 if first_write
< 0 or last_use
< 0:
33 raise ValueError("indexes must be nonnegative")
34 self
.first_write
= first_write
35 self
.last_use
= last_use
37 def overlaps(self
, other
):
38 # type: (LiveInterval) -> bool
39 if self
.first_write
== other
.first_write
:
41 return self
.last_use
> other
.first_write \
42 and other
.last_use
> self
.first_write
def __add__(self, use):
    # type: (int) -> LiveInterval
    """Return a new `LiveInterval` extended to also cover index `use`.

    `first_write` is kept as-is; `last_use` becomes the larger of the
    current `last_use` and `use`. `self` is left unmodified.
    """
    # same comparison `max(self.last_use, use)` performs internally
    new_last_use = use if use > self.last_use else self.last_use
    return LiveInterval(first_write=self.first_write, last_use=new_last_use)
50 def live_after_op_range(self
):
51 """the range of op indexes where self is live immediately after the
54 return range(self
.first_write
, self
.last_use
)
57 class BadMergedSSAVal(ValueError):
61 @plain_data(frozen
=True)
64 """a set of `SSAVal`s along with their offsets, all register allocated as
67 Definition of the term `offset` for this class:
69 Let `locs[x]` be the `Loc` that `x` is assigned to after register
70 allocation and let `msv` be a `MergedSSAVal` instance, then the offset
71 for each `SSAVal` `ssa_val` in `msv` is defined as:
74 msv.ssa_val_offsets[ssa_val] = (msv.offset
75 + locs[ssa_val].start - locs[msv].start)
83 msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
86 if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
87 * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
88 * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
89 * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
91 __slots__
= "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
93 def __init__(self
, fn_analysis
, ssa_val_offsets
):
94 # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
95 self
.fn_analysis
= fn_analysis
96 if isinstance(ssa_val_offsets
, SSAVal
):
97 ssa_val_offsets
= {ssa_val_offsets
: 0}
98 self
.ssa_val_offsets
= FMap(ssa_val_offsets
) # type: FMap[SSAVal, int]
100 for ssa_val
in self
.ssa_vals
:
101 first_ssa_val
= ssa_val
103 if first_ssa_val
is None:
104 raise BadMergedSSAVal("MergedSSAVal can't be empty")
105 self
.first_ssa_val
= first_ssa_val
# type: SSAVal
106 # self.ty checks for mismatched base_ty
107 reg_len
= self
.ty
.reg_len
108 loc_set
= None # type: None | LocSet
109 for ssa_val
, cur_offset
in self
.ssa_val_offsets_before_spread
.items():
110 def_spread_idx
= ssa_val
.defining_descriptor
.spread_index
or 0
113 # type: () -> Iterable[Loc]
114 for loc
in ssa_val
.def_loc_set_before_spread
:
115 disallowed_by_use
= False
116 for use
in fn_analysis
.uses
[ssa_val
]:
118 use
.defining_descriptor
.spread_index
or 0
119 # calculate the start for the use's Loc before spread
120 # e.g. if the def's Loc before spread starts at r6
121 # and the def's spread_index is 5
122 # and the use's spread_index is 3
123 # then the use's Loc before spread starts at r8
124 # because 8 == 6 + 5 - 3
125 start
= loc
.start
+ def_spread_idx
- use_spread_idx
126 use_loc
= Loc
.try_make(
127 loc
.kind
, start
=start
,
128 reg_len
=use
.ty_before_spread
.reg_len
)
129 if (use_loc
is None or
130 use_loc
not in use
.use_loc_set_before_spread
):
131 disallowed_by_use
= True
133 if disallowed_by_use
:
135 # FIXME: add spread consistency check
136 start
= loc
.start
- cur_offset
+ self
.offset
137 loc
= Loc
.try_make(loc
.kind
, start
=start
, reg_len
=reg_len
)
138 if loc
is not None and (loc_set
is None or loc
in loc_set
):
140 loc_set
= LocSet(locs())
141 assert loc_set
is not None, "already checked that self isn't empty"
142 if loc_set
.ty
is None:
143 raise BadMergedSSAVal("there are no valid Locs left")
144 assert loc_set
.ty
== self
.ty
, "logic error somewhere"
145 self
.loc_set
= loc_set
# type: LocSet
150 return hash((self
.fn_analysis
, self
.ssa_val_offsets
))
159 return min(self
.ssa_val_offsets_before_spread
.values())
164 return self
.first_ssa_val
.base_ty
168 # type: () -> OFSet[SSAVal]
169 return OFSet(self
.ssa_val_offsets
.keys())
175 for ssa_val
, offset
in self
.ssa_val_offsets_before_spread
.items():
176 cur_ty
= ssa_val
.ty_before_spread
177 if self
.base_ty
!= cur_ty
.base_ty
:
178 raise BadMergedSSAVal(
179 f
"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
180 reg_len
= max(reg_len
, cur_ty
.reg_len
+ offset
- self
.offset
)
181 return Ty(base_ty
=self
.base_ty
, reg_len
=reg_len
)
184 def ssa_val_offsets_before_spread(self
):
185 # type: () -> FMap[SSAVal, int]
186 retval
= {} # type: dict[SSAVal, int]
187 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
189 offset
- ssa_val
.defining_descriptor
.reg_offset_in_unspread
)
def offset_by(self, amount):
    # type: (int) -> MergedSSAVal
    """Return a new `MergedSSAVal` whose offsets are all shifted by `amount`.

    The set of `SSAVal`s and `fn_analysis` are unchanged and `self` is not
    modified. Building the result re-runs `MergedSSAVal` construction, which
    can raise `BadMergedSSAVal`.
    """
    shifted = {}  # type: dict[SSAVal, int]
    for ssa_val, offset in self.ssa_val_offsets.items():
        shifted[ssa_val] = offset + amount
    return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=shifted)
def normalized(self):
    # type: () -> MergedSSAVal
    """Return a copy of `self` with every offset shifted by `-self.offset`."""
    shift = -self.offset
    return self.offset_by(shift)
201 def with_offset_to_match(self
, target
, additional_offset
=0):
202 # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
203 if isinstance(target
, MergedSSAVal
):
204 ssa_val_offsets
= target
.ssa_val_offsets
206 ssa_val_offsets
= {target
: 0}
207 for ssa_val
, offset
in self
.ssa_val_offsets
.items():
208 if ssa_val
in ssa_val_offsets
:
209 return self
.offset_by(
210 ssa_val_offsets
[ssa_val
] + additional_offset
- offset
)
211 raise ValueError("can't change offset to match unrelated MergedSSAVal")
213 def merged(self
, *others
):
214 # type: (*MergedSSAVal) -> MergedSSAVal
215 retval
= dict(self
.ssa_val_offsets
)
217 if other
.fn_analysis
!= self
.fn_analysis
:
218 raise ValueError("fn_analysis mismatch")
219 for ssa_val
, offset
in other
.ssa_val_offsets
.items():
220 if ssa_val
in retval
and retval
[ssa_val
] != offset
:
221 raise BadMergedSSAVal(f
"offset mismatch for {ssa_val}: "
222 f
"{retval[ssa_val]} != {offset}")
223 retval
[ssa_val
] = offset
224 return MergedSSAVal(fn_analysis
=self
.fn_analysis
,
225 ssa_val_offsets
=retval
)
def live_interval(self):
    # type: () -> ProgramRange
    """Return the smallest `ProgramRange` covering the live range of every
    `SSAVal` in this `MergedSSAVal`, as recorded in
    `fn_analysis.live_ranges`.
    """
    live_ranges = self.fn_analysis.live_ranges
    # seed the hull with the first SSAVal's range, then widen it
    seed = live_ranges[self.first_ssa_val]
    start, stop = seed.start, seed.stop
    for ssa_val in self.ssa_vals:
        cur_range = live_ranges[ssa_val]
        start = min(start, cur_range.start)
        stop = max(stop, cur_range.stop)
    return ProgramRange(start=start, stop=stop)
241 class SSAValToMergedSSAValMap(Mapping
[SSAVal
, MergedSSAVal
]):
243 # type: (...) -> None
244 self
.__map
= {} # type: dict[SSAVal, MergedSSAVal]
245 self
.__ig
_node
_map
= MergedSSAValToIGNodeMap(
246 _private_merged_ssa_val_map
=self
.__map
)
def __getitem__(self, __key):
    # type: (SSAVal) -> MergedSSAVal
    """Return the `MergedSSAVal` that contains the `SSAVal` `__key`.

    Raises `KeyError` if `__key` has not been added.
    """
    merged_by_ssa_val = self.__map
    return merged_by_ssa_val[__key]
253 # type: () -> Iterator[SSAVal]
254 return iter(self
.__map
)
258 return len(self
.__map
)
def ig_node_map(self):
    # type: () -> MergedSSAValToIGNodeMap
    """Return the `MergedSSAValToIGNodeMap` created together with this map.

    That node map was given this map's internal `SSAVal` -> `MergedSSAVal`
    dict, so nodes added through it are reflected here.
    """
    node_map = self.__ig_node_map
    return node_map
267 s
= ",\n".join(repr(v
) for v
in self
.__ig
_node
_map
)
268 return f
"SSAValToMergedSSAValMap({{{s}}})"
272 class MergedSSAValToIGNodeMap(Mapping
[MergedSSAVal
, IGNode
]):
275 _private_merged_ssa_val_map
, # type: dict[SSAVal, MergedSSAVal]
277 # type: (...) -> None
278 self
.__merged
_ssa
_val
_map
= _private_merged_ssa_val_map
279 self
.__map
= {} # type: dict[MergedSSAVal, IGNode]
def __getitem__(self, __key):
    # type: (MergedSSAVal) -> IGNode
    """Return the `IGNode` registered for the `MergedSSAVal` `__key`.

    Raises `KeyError` if no node has been added for it.
    """
    nodes_by_merged = self.__map
    return nodes_by_merged[__key]
286 # type: () -> Iterator[MergedSSAVal]
287 return iter(self
.__map
)
291 return len(self
.__map
)
293 def add_node(self
, merged_ssa_val
):
294 # type: (MergedSSAVal) -> IGNode
295 node
= self
.__map
.get(merged_ssa_val
, None)
298 added
= 0 # type: int | None
300 for ssa_val
in merged_ssa_val
.ssa_vals
:
301 if ssa_val
in self
.__merged
_ssa
_val
_map
:
303 f
"overlapping `MergedSSAVal`s: {ssa_val} is in both "
304 f
"{merged_ssa_val} and "
305 f
"{self.__merged_ssa_val_map[ssa_val]}")
306 self
.__merged
_ssa
_val
_map
[ssa_val
] = merged_ssa_val
308 retval
= IGNode(merged_ssa_val
)
309 self
.__map
[merged_ssa_val
] = retval
313 if added
is not None:
314 # remove partially added stuff
315 for idx
, ssa_val
in enumerate(merged_ssa_val
.ssa_vals
):
318 del self
.__merged
_ssa
_val
_map
[ssa_val
]
320 def merge_into_one_node(self
, final_merged_ssa_val
):
321 # type: (MergedSSAVal) -> IGNode
322 source_nodes
= {} # type: dict[MergedSSAVal, IGNode]
323 for ssa_val
in final_merged_ssa_val
.ssa_vals
:
324 merged_ssa_val
= self
.__merged
_ssa
_val
_map
[ssa_val
]
325 source_nodes
[merged_ssa_val
] = self
.__map
[merged_ssa_val
]
326 for i
in merged_ssa_val
.ssa_vals
- final_merged_ssa_val
.ssa_vals
:
328 f
"SSAVal {i} appears in source IGNode's merged_ssa_val "
329 f
"but not in merged IGNode's merged_ssa_val: "
330 f
"source_node={self.__map[merged_ssa_val]} "
331 f
"final_merged_ssa_val={final_merged_ssa_val}")
332 # FIXME: work on function from here
333 raise NotImplementedError
334 self
.__values
_set
.discard(value
)
335 for ssa_val
in value
.ssa_val_offsets
.keys():
336 del self
.__merge
_map
[ssa_val
]
340 s
= ",\n".join(repr(v
) for v
in self
.__map
.values())
341 return f
"MergedSSAValToIGNodeMap({{{s}}})"
344 @plain_data(frozen
=True)
346 class InterferenceGraph
:
347 __slots__
= "fn_analysis", "merged_ssa_val_map", "nodes"
def __init__(self, fn_analysis, merged_ssa_vals):
    # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
    """Create an interference graph for `fn_analysis` with one node for each
    element of `merged_ssa_vals`.

    `nodes` is the `MergedSSAValToIGNodeMap` owned by `merged_ssa_val_map`,
    so adding a node also records its `SSAVal`s in `merged_ssa_val_map`.
    """
    self.fn_analysis = fn_analysis
    self.merged_ssa_val_map = ssa_val_map = SSAValToMergedSSAValMap()
    self.nodes = node_map = ssa_val_map.ig_node_map
    for merged_ssa_val in merged_ssa_vals:
        node_map.add_node(merged_ssa_val)
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
    # type: (SSAVal, SSAVal, int) -> IGNode
    """Merge the nodes containing `ssa_val1` and `ssa_val2` into one node.

    The `MergedSSAVal` containing `ssa_val1` is re-offset so `ssa_val1` sits
    at offset 0. The one containing `ssa_val2` is re-offset so `ssa_val2`
    sits at `additional_offset`. The two are then combined and handed to
    `nodes.merge_into_one_node`, whose result is returned.
    """
    ssa_val_map = self.merged_ssa_val_map
    # look both up before re-offsetting, so a missing key fails first
    first = ssa_val_map[ssa_val1]
    second = ssa_val_map[ssa_val2]
    aligned_first = first.with_offset_to_match(ssa_val1)
    aligned_second = second.with_offset_to_match(
        ssa_val2, additional_offset=additional_offset)
    combined = aligned_first.merged(aligned_second)
    return self.nodes.merge_into_one_node(combined)
367 def minimally_merged(fn_analysis
):
368 # type: (FnAnalysis) -> InterferenceGraph
369 retval
= InterferenceGraph(fn_analysis
=fn_analysis
, merged_ssa_vals
=())
370 for op
in fn_analysis
.fn
.ops
:
371 for inp
in op
.input_uses
:
372 if inp
.unspread_start
!= inp
:
373 retval
.merge(inp
.unspread_start
.ssa_val
, inp
.ssa_val
,
374 additional_offset
=inp
.reg_offset_in_unspread
)
375 for out
in op
.outputs
:
376 retval
.nodes
.add_node(MergedSSAVal(fn_analysis
, out
))
377 if out
.unspread_start
!= out
:
378 retval
.merge(out
.unspread_start
, out
,
379 additional_offset
=out
.reg_offset_in_unspread
)
380 if out
.tied_input
is not None:
381 retval
.merge(out
.tied_input
.ssa_val
, out
)
387 """ interference graph node """
388 __slots__
= "merged_ssa_val", "edges", "loc"
390 def __init__(self
, merged_ssa_val
, edges
=(), loc
=None):
391 # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
392 self
.merged_ssa_val
= merged_ssa_val
393 self
.edges
= OSet(edges
)
def add_edge(self, other):
    # type: (IGNode) -> None
    """Connect `self` and `other` with an interference edge.

    The edge is recorded symmetrically, in both nodes' `edges` sets.
    """
    for src, dst in ((self, other), (other, self)):
        src.edges.add(dst)
def __eq__(self, other):
    # type: (object) -> bool
    """Two `IGNode`s are equal when their `merged_ssa_val`s are equal.

    Comparison with any non-`IGNode` defers to the other operand.
    """
    if not isinstance(other, IGNode):
        return NotImplemented
    return self.merged_ssa_val == other.merged_ssa_val
408 return hash(self
.merged_ssa_val
)
410 def __repr__(self
, nodes
=None):
411 # type: (None | dict[IGNode, int]) -> str
415 return f
"<IGNode #{nodes[self]}>"
416 nodes
[self
] = len(nodes
)
417 edges
= "{" + ", ".join(i
.__repr
__(nodes
) for i
in self
.edges
) + "}"
418 return (f
"IGNode(#{nodes[self]}, "
419 f
"merged_ssa_val={self.merged_ssa_val}, "
426 return self
.merged_ssa_val
.loc_set
428 def loc_conflicts_with_neighbors(self
, loc
):
429 # type: (Loc) -> bool
430 for neighbor
in self
.edges
:
431 if neighbor
.loc
is not None and neighbor
.loc
.conflicts(loc
):
437 class AllocationFailed
:
438 __slots__
= "node", "merged_ssa_vals", "interference_graph"
440 def __init__(self
, node
, merged_ssa_vals
, interference_graph
):
441 # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
444 self
.merged_ssa_vals
= merged_ssa_vals
445 self
.interference_graph
= interference_graph
class AllocationFailedError(Exception):
    """Raised by `allocate_registers` when allocation fails (currently
    whenever spilling would be required).

    The `AllocationFailed` description is kept in `allocation_failed`. Both
    `msg` and `allocation_failed` are also passed to `Exception`, so they
    appear in `args`.
    """

    def __init__(self, msg, allocation_failed):
        # type: (str, AllocationFailed) -> None
        super(AllocationFailedError, self).__init__(msg, allocation_failed)
        self.allocation_failed = allocation_failed
455 def try_allocate_registers_without_spilling(merged_ssa_vals
):
456 # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
458 interference_graph
= {
459 i
: IGNode(i
) for i
in merged_ssa_vals
.merged_ssa_vals
}
460 fn_analysis
= merged_ssa_vals
.fn_analysis
461 for ssa_vals
in fn_analysis
.live_at
.values():
462 live_merged_ssa_vals
= OSet() # type: OSet[MergedSSAVal]
463 for ssa_val
in ssa_vals
:
464 live_merged_ssa_vals
.add(merged_ssa_vals
.merge_map
[ssa_val
])
465 for i
, j
in combinations(live_merged_ssa_vals
, 2):
466 if i
.loc_set
.max_conflicts_with(j
.loc_set
) != 0:
467 interference_graph
[i
].add_edge(interference_graph
[j
])
469 nodes_remaining
= OSet(interference_graph
.values())
471 # FIXME: work on code from here
473 def local_colorability_score(node
):
474 # type: (IGNode) -> int
475 """ returns a positive integer if node is locally colorable, returns
476 zero or a negative integer if node isn't known to be locally
477 colorable, the more negative the value, the less colorable
479 if node
not in nodes_remaining
:
481 retval
= len(node
.loc_set
)
482 for neighbor
in node
.edges
:
483 if neighbor
in nodes_remaining
:
484 retval
-= node
.reg_class
.max_conflicts_with(neighbor
.reg_class
)
487 node_stack
= [] # type: list[IGNode]
489 best_node
= None # type: None | IGNode
491 for node
in nodes_remaining
:
492 score
= local_colorability_score(node
)
493 if best_node
is None or score
> best_score
:
497 # it's locally colorable, no need to find a better one
500 if best_node
is None:
502 node_stack
.append(best_node
)
503 nodes_remaining
.remove(best_node
)
505 retval
= {} # type: dict[SSAVal, RegLoc]
507 while len(node_stack
) > 0:
508 node
= node_stack
.pop()
509 if node
.reg
is not None:
510 if node
.reg_conflicts_with_neighbors(node
.reg
):
511 return AllocationFailed(node
=node
,
512 live_intervals
=live_intervals
,
513 interference_graph
=interference_graph
)
515 # pick the first non-conflicting register in node.reg_class, since
516 # register classes are ordered from most preferred to least
517 # preferred register.
518 for reg
in node
.reg_class
:
519 if not node
.reg_conflicts_with_neighbors(reg
):
523 return AllocationFailed(node
=node
,
524 live_intervals
=live_intervals
,
525 interference_graph
=interference_graph
)
527 for ssa_val
, offset
in node
.merged_reg_set
.items():
528 retval
[ssa_val
] = node
.reg
.get_subreg_at_offset(ssa_val
.ty
, offset
)
533 def allocate_registers(ops
):
534 # type: (list[Op]) -> dict[SSAVal, RegLoc]
535 retval
= try_allocate_registers_without_spilling(ops
)
536 if isinstance(retval
, AllocationFailed
):
537 # TODO: implement spilling
538 raise AllocationFailedError(
539 "spilling required but not yet implemented", retval
)