1 from abc
import ABCMeta
, abstractmethod
3 from collections
import defaultdict
4 from enum
import Enum
, unique
5 from typing
import Iterable
, Mapping
, TYPE_CHECKING
7 from nmutil
.plain_data
import plain_data
10 from typing_extensions
import final
16 @plain_data(frozen
=True, unsafe_hash
=True)
21 @plain_data(frozen
=True, unsafe_hash
=True)
22 class GPROrStackLoc(PhysLoc
):
27 class GPR(GPROrStackLoc
, Enum
):
28 def __init__(self
, reg_num
):
30 self
.reg_num
= reg_num
32 R0
= 0; R1
= 1; R2
= 2; R3
= 3; R4
= 4; R5
= 5
33 R6
= 6; R7
= 7; R8
= 8; R9
= 9; R10
= 10; R11
= 11
34 R12
= 12; R13
= 13; R14
= 14; R15
= 15; R16
= 16; R17
= 17
35 R18
= 18; R19
= 19; R20
= 20; R21
= 21; R22
= 22; R23
= 23
36 R24
= 24; R25
= 25; R26
= 26; R27
= 27; R28
= 28; R29
= 29
37 R30
= 30; R31
= 31; R32
= 32; R33
= 33; R34
= 34; R35
= 35
38 R36
= 36; R37
= 37; R38
= 38; R39
= 39; R40
= 40; R41
= 41
39 R42
= 42; R43
= 43; R44
= 44; R45
= 45; R46
= 46; R47
= 47
40 R48
= 48; R49
= 49; R50
= 50; R51
= 51; R52
= 52; R53
= 53
41 R54
= 54; R55
= 55; R56
= 56; R57
= 57; R58
= 58; R59
= 59
42 R60
= 60; R61
= 61; R62
= 62; R63
= 63; R64
= 64; R65
= 65
43 R66
= 66; R67
= 67; R68
= 68; R69
= 69; R70
= 70; R71
= 71
44 R72
= 72; R73
= 73; R74
= 74; R75
= 75; R76
= 76; R77
= 77
45 R78
= 78; R79
= 79; R80
= 80; R81
= 81; R82
= 82; R83
= 83
46 R84
= 84; R85
= 85; R86
= 86; R87
= 87; R88
= 88; R89
= 89
47 R90
= 90; R91
= 91; R92
= 92; R93
= 93; R94
= 94; R95
= 95
48 R96
= 96; R97
= 97; R98
= 98; R99
= 99; R100
= 100; R101
= 101
49 R102
= 102; R103
= 103; R104
= 104; R105
= 105; R106
= 106; R107
= 107
50 R108
= 108; R109
= 109; R110
= 110; R111
= 111; R112
= 112; R113
= 113
51 R114
= 114; R115
= 115; R116
= 116; R117
= 117; R118
= 118; R119
= 119
52 R120
= 120; R121
= 121; R122
= 122; R123
= 123; R124
= 124; R125
= 125
53 R126
= 126; R127
= 127
59 SPECIAL_GPRS
= GPR
.R0
, GPR
.SP
, GPR
.TOC
, GPR
.R13
64 class XERBit(Enum
, PhysLoc
):
70 class GlobalMem(Enum
, PhysLoc
):
71 """singleton representing all non-StackSlot memory"""
72 GlobalMem
= "GlobalMem"
77 class StackSlot(GPROrStackLoc
):
78 """a stack slot. Use OpCopy to load from/store into this stack slot."""
81 def __init__(self
, offset
=None):
82 # type: (int | None) -> None
87 class SSAVal(metaclass
=ABCMeta
):
90 def __init__(self
, id=None):
91 # type: (int | None) -> None
93 id = builtins
.id(self
)
96 def __eq__(self
, rhs
):
97 if isinstance(rhs
, SSAVal
):
98 return self
.id == rhs
.id
104 def _get_phys_loc(self
, phys_loc_in
, value_assignments
=None):
105 # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
106 if phys_loc_in
is not None:
108 if value_assignments
is not None:
109 return value_assignments
.get(self
)
113 def get_phys_loc(self
, value_assignments
=None):
114 # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
118 @plain_data(eq
=False)
120 class SSAGPRVal(SSAVal
):
121 __slots__
= "phys_loc",
123 def __init__(self
, phys_loc
=None):
124 # type: (GPROrStackLoc | None) -> None
125 self
.phys_loc
= phys_loc
131 def get_phys_loc(self
, value_assignments
=None):
132 # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
133 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
134 if isinstance(loc
, GPROrStackLoc
):
138 def get_reg_num(self
, value_assignments
=None):
139 # type: (dict[SSAVal, PhysLoc] | None) -> int | None
140 reg
= self
.get_reg(value_assignments
)
145 def get_reg(self
, value_assignments
=None):
146 # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
147 loc
= self
.get_phys_loc(value_assignments
)
148 if isinstance(loc
, GPR
):
152 def get_stack_slot(self
, value_assignments
=None):
153 # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
154 loc
= self
.get_phys_loc(value_assignments
)
155 if isinstance(loc
, StackSlot
):
159 def possible_reg_assignments(self
, value_assignments
,
160 conflicting_regs
=set()):
161 # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR]
162 if self
.get_phys_loc(value_assignments
) is not None:
163 raise ValueError("can't assign a already-assigned SSA value")
165 if reg
not in conflicting_regs
:
169 @plain_data(eq
=False)
171 class SSAXERBitVal(SSAVal
):
172 __slots__
= "phys_loc",
174 def __init__(self
, phys_loc
=None):
175 # type: (XERBit | None) -> None
176 self
.phys_loc
= phys_loc
178 def get_phys_loc(self
, value_assignments
=None):
179 # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
180 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
181 if isinstance(loc
, XERBit
):
186 @plain_data(eq
=False)
188 class SSAMemory(SSAVal
):
189 __slots__
= "phys_loc",
191 def __init__(self
, phys_loc
=GlobalMem
.GlobalMem
):
192 # type: (GlobalMem) -> None
193 self
.phys_loc
= phys_loc
195 def get_phys_loc(self
, value_assignments
=None):
196 # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None
197 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
198 if isinstance(loc
, GlobalMem
):
203 @plain_data(unsafe_hash
=True, frozen
=True)
208 def __init__(self
, regs
):
209 # type: (Iterable[SSAGPRVal]) -> None
210 self
.regs
= tuple(regs
)
213 return len(self
.regs
)
215 def is_unassigned(self
, value_assignments
=None):
216 # type: (dict[SSAVal, PhysLoc] | None) -> bool
217 for val
in self
.regs
:
218 if val
.get_phys_loc(value_assignments
) is not None:
222 def try_get_range(self
, value_assignments
=None, allow_unassigned
=False,
223 raise_if_invalid
=False):
224 # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
225 if len(self
.regs
) == 0:
228 retval
= None # type: range | None
229 for i
, val
in enumerate(self
.regs
):
230 if val
.get_phys_loc(value_assignments
) is None:
231 if not allow_unassigned
:
233 raise ValueError("not a valid register range: "
234 "unassigned SSA value encountered")
237 reg
= val
.get_reg_num(value_assignments
)
240 raise ValueError("not a valid register range: "
241 "non-register encountered")
243 expected_range
= range(reg
- i
, reg
- i
+ len(self
.regs
))
245 retval
= expected_range
246 elif retval
!= expected_range
:
248 raise ValueError("not a valid register range: "
249 "register out of sequence")
253 def possible_reg_assignments(
256 value_assignments
, # type: dict[SSAVal, PhysLoc] | None
257 conflicting_regs
=set(), # type: set[GPR]
258 ): # type: (...) -> Iterable[GPR]
259 index
= self
.regs
.index(val
)
261 while alignment
< len(self
.regs
):
263 r
= self
.try_get_range(value_assignments
)
264 if r
is not None and r
.start
% alignment
!= 0:
265 raise ValueError("must be a ascending aligned range of GPRs")
267 for i
in range(0, len(GPR
), alignment
):
268 r
= range(i
, i
+ len(self
.regs
))
269 if any(GPR(reg
) in conflicting_regs
for reg
in r
):
277 @plain_data(unsafe_hash
=True, frozen
=True)
278 class EqualityConstraint
:
279 __slots__
= "lhs", "rhs"
281 def __init__(self
, lhs
, rhs
):
282 # type: (SSAVal, SSAVal) -> None
287 @plain_data(unsafe_hash
=True, frozen
=True)
288 class Op(metaclass
=ABCMeta
):
291 def input_ssa_vals(self
):
292 # type: () -> Iterable[SSAVal]
293 for arg
in self
.inputs().values():
294 if isinstance(arg
, VecArg
):
299 def output_ssa_vals(self
):
300 # type: () -> Iterable[SSAVal]
301 for arg
in self
.outputs().values():
302 if isinstance(arg
, VecArg
):
309 # type: () -> dict[str, VecArg | SSAVal]
314 # type: () -> dict[str, VecArg | SSAVal]
318 def possible_reg_assignments(self
, val
, value_assignments
):
319 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
322 def get_equality_constraints(self
):
323 # type: () -> Iterable[EqualityConstraint]
331 @plain_data(unsafe_hash
=True, frozen
=True)
334 __slots__
= "dest", "src"
337 # type: () -> dict[str, VecArg | SSAVal]
338 return {"src": self
.src
}
341 # type: () -> dict[str, VecArg | SSAVal]
342 return {"dest": self
.dest
}
344 def __init__(self
, dest
, src
):
345 # type: (VecArg | SSAVal, VecArg | SSAVal) -> None
346 if isinstance(dest
, VecArg
) and isinstance(src
, VecArg
):
347 if len(src
.regs
) != len(dest
.regs
):
348 raise TypeError(f
"source length must match dest "
349 f
"length: {src} doesn't match {dest}")
350 elif type(dest
) != type(src
):
351 raise TypeError(f
"source argument type must match dest "
352 f
"argument type: {src} doesn't match {dest}")
356 def possible_reg_assignments(self
, val
, value_assignments
):
357 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
358 if val
not in self
.input_ssa_vals() \
359 and val
not in self
.output_ssa_vals():
360 raise ValueError(f
"{val} must be an operand of {self}")
361 if val
.get_phys_loc(value_assignments
) is not None:
362 raise ValueError(f
"{val} already assigned a physical location")
363 conflicting_regs
= set() # type: set[GPR]
364 if val
in self
.output_ssa_vals() and isinstance(self
.dest
, VecArg
):
365 # OpCopy is the only op that can write to physical locations in
366 # any order, it handles figuring out the right instruction sequence
367 dest_locs
= {} # type: dict[GPROrStackLoc, SSAVal]
368 for val
in self
.dest
.regs
:
369 loc
= val
.get_phys_loc(value_assignments
)
374 f
"duplicate destination location not allowed: "
375 f
"{val} is assigned to {loc} which is also "
376 f
"written by {dest_locs[loc]}")
378 if isinstance(loc
, GPR
):
379 conflicting_regs
.add(loc
)
380 if not isinstance(val
, SSAGPRVal
):
381 raise ValueError("invalid operand type")
382 return val
.possible_reg_assignments(value_assignments
,
386 def range_overlaps(range1
, range2
):
387 # type: (range, range) -> bool
388 if len(range1
) == 0 or len(range2
) == 0:
390 range1_last
= range1
[-1]
391 range2_last
= range2
[-1]
392 return (range1
.start
in range2
or range1_last
in range2
or
393 range2
.start
in range1
or range2_last
in range1
)
396 @plain_data(unsafe_hash
=True, frozen
=True)
399 __slots__
= "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
402 # type: () -> dict[str, VecArg | SSAVal]
403 return {"RA": self
.RA
, "RB": self
.RB
, "CY_in": self
.CY_in
}
406 # type: () -> dict[str, VecArg | SSAVal]
407 return {"RT": self
.RT
, "CY_out": self
.CY_out
}
409 def __init__(self
, RT
, RA
, RB
, CY_in
, CY_out
, is_sub
):
410 # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None
411 if len(RA
.regs
) != len(RT
.regs
):
412 raise TypeError(f
"source length must match dest "
413 f
"length: {RA} doesn't match {RT}")
414 if len(RB
.regs
) != len(RT
.regs
):
415 raise TypeError(f
"source length must match dest "
416 f
"length: {RB} doesn't match {RT}")
424 def possible_reg_assignments(self
, val
, value_assignments
):
425 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
426 if val
not in self
.input_ssa_vals() \
427 and val
not in self
.output_ssa_vals():
428 raise ValueError(f
"{val} must be an operand of {self}")
429 if val
.get_phys_loc(value_assignments
) is not None:
430 raise ValueError(f
"{val} already assigned a physical location")
431 if self
.CY_in
== val
or self
.CY_out
== val
:
433 elif val
in self
.RT
.regs
:
434 # since possible_reg_assignments only returns aligned
435 # vectors, all possible assignments either are the same as an
436 # input or don't overlap with an input and we avoid the incorrect
437 # results caused by partial overlaps overwriting input elements
438 # before they're read
439 yield from self
.RT
.possible_reg_assignments(val
, value_assignments
)
440 elif val
in self
.RA
.regs
:
441 yield from self
.RA
.possible_reg_assignments(val
, value_assignments
)
443 yield from self
.RB
.possible_reg_assignments(val
, value_assignments
)
445 def get_equality_constraints(self
):
446 # type: () -> Iterable[EqualityConstraint]
447 yield EqualityConstraint(self
.CY_in
, self
.CY_out
)
451 # type: (None | GPR | range) -> set[GPR]
454 if isinstance(v
, range):
455 return set(map(GPR
, v
))
459 @plain_data(unsafe_hash
=True, frozen
=True)
461 class OpBigIntMulDiv(Op
):
462 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div"
465 # type: () -> dict[str, VecArg | SSAVal]
466 return {"RA": self
.RA
, "RB": self
.RB
, "RC": self
.RC
}
469 # type: () -> dict[str, VecArg | SSAVal]
470 return {"RT": self
.RT
, "RS": self
.RS
}
472 def __init__(self
, RT
, RA
, RB
, RC
, RS
, is_div
):
473 # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None
474 if len(RA
.regs
) != len(RT
.regs
):
475 raise TypeError(f
"source length must match dest "
476 f
"length: {RA} doesn't match {RT}")
484 def possible_reg_assignments(self
, val
, value_assignments
):
485 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
486 if val
not in self
.input_ssa_vals() \
487 and val
not in self
.output_ssa_vals():
488 raise ValueError(f
"{val} must be an operand of {self}")
489 if val
.get_phys_loc(value_assignments
) is not None:
490 raise ValueError(f
"{val} already assigned a physical location")
492 RT_range
= self
.RT
.try_get_range(value_assignments
,
493 allow_unassigned
=True,
494 raise_if_invalid
=True)
495 RA_range
= self
.RA
.try_get_range(value_assignments
,
496 allow_unassigned
=True,
497 raise_if_invalid
=True)
498 RC_RS_reg
= self
.RC
.get_reg(value_assignments
)
499 if RC_RS_reg
is None:
500 RC_RS_reg
= self
.RS
.get_reg(value_assignments
)
502 if self
.RC
== val
or self
.RS
== val
:
503 if RC_RS_reg
is not None:
506 conflicting_regs
= to_reg_set(RT_range
) |
to_reg_set(RA_range
)
507 yield from self
.RC
.possible_reg_assignments(value_assignments
,
509 elif val
in self
.RT
.regs
:
510 # since possible_reg_assignments only returns aligned
511 # vectors, all possible assignments either are the same as
512 # RA or don't overlap with RA and we avoid the incorrect
513 # results caused by partial overlaps overwriting input elements
514 # before they're read
515 yield from self
.RT
.possible_reg_assignments(
516 val
, value_assignments
,
517 conflicting_regs
=to_reg_set(RA_range
) |
to_reg_set(RC_RS_reg
))
519 yield from self
.RA
.possible_reg_assignments(
520 val
, value_assignments
,
521 conflicting_regs
=to_reg_set(RT_range
) |
to_reg_set(RC_RS_reg
))
523 def get_equality_constraints(self
):
524 # type: () -> Iterable[EqualityConstraint]
525 yield EqualityConstraint(self
.RC
, self
.RS
)
530 class ShiftKind(Enum
):
536 @plain_data(unsafe_hash
=True, frozen
=True)
538 class OpBigIntShift(Op
):
539 __slots__
= "RT", "inp", "sh", "kind"
542 # type: () -> dict[str, VecArg | SSAVal]
543 return {"inp": self
.inp
, "sh": self
.sh
}
546 # type: () -> dict[str, VecArg | SSAVal]
547 return {"RT": self
.RT
}
549 def __init__(self
, RT
, inp
, sh
, kind
):
550 # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None
551 if len(inp
.regs
) != len(RT
.regs
):
552 raise TypeError(f
"source length must match dest "
553 f
"length: {inp} doesn't match {RT}")
559 def possible_reg_assignments(self
, val
, value_assignments
):
560 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
561 if val
not in self
.input_ssa_vals() \
562 and val
not in self
.output_ssa_vals():
563 raise ValueError(f
"{val} must be an operand of {self}")
564 if val
.get_phys_loc(value_assignments
) is not None:
565 raise ValueError(f
"{val} already assigned a physical location")
567 RT_range
= self
.RT
.try_get_range(value_assignments
,
568 allow_unassigned
=True,
569 raise_if_invalid
=True)
570 inp_range
= self
.inp
.try_get_range(value_assignments
,
571 allow_unassigned
=True,
572 raise_if_invalid
=True)
573 sh_reg
= self
.sh
.get_reg(value_assignments
)
576 conflicting_regs
= to_reg_set(RT_range
)
577 yield from self
.sh
.possible_reg_assignments(value_assignments
,
579 elif val
in self
.RT
.regs
:
580 # since possible_reg_assignments only returns aligned
581 # vectors, all possible assignments either are the same as
582 # RA or don't overlap with RA and we avoid the incorrect
583 # results caused by partial overlaps overwriting input elements
584 # before they're read
585 yield from self
.RT
.possible_reg_assignments(
586 val
, value_assignments
,
587 conflicting_regs
=to_reg_set(inp_range
) |
to_reg_set(sh_reg
))
589 yield from self
.inp
.possible_reg_assignments(
590 val
, value_assignments
,
591 conflicting_regs
=to_reg_set(RT_range
))
594 @plain_data(unsafe_hash
=True, frozen
=True)
597 __slots__
= "out", "value"
600 # type: () -> dict[str, VecArg | SSAVal]
604 # type: () -> dict[str, VecArg | SSAVal]
605 return {"out": self
.out
}
607 def __init__(self
, out
, value
):
608 # type: (VecArg | SSAGPRVal, int) -> None
612 def possible_reg_assignments(self
, val
, value_assignments
):
613 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
614 if val
not in self
.input_ssa_vals() \
615 and val
not in self
.output_ssa_vals():
616 raise ValueError(f
"{val} must be an operand of {self}")
617 if val
.get_phys_loc(value_assignments
) is not None:
618 raise ValueError(f
"{val} already assigned a physical location")
620 if isinstance(self
.out
, VecArg
):
621 yield from self
.out
.possible_reg_assignments(val
,
624 yield from self
.out
.possible_reg_assignments(value_assignments
)
627 @plain_data(unsafe_hash
=True, frozen
=True)
633 # type: () -> dict[str, VecArg | SSAVal]
637 # type: () -> dict[str, VecArg | SSAVal]
638 return {"out": self
.out
}
640 def __init__(self
, out
):
641 # type: (SSAXERBitVal) -> None
644 def possible_reg_assignments(self
, val
, value_assignments
):
645 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
646 if val
not in self
.input_ssa_vals() \
647 and val
not in self
.output_ssa_vals():
648 raise ValueError(f
"{val} must be an operand of {self}")
649 if val
.get_phys_loc(value_assignments
) is not None:
650 raise ValueError(f
"{val} already assigned a physical location")
655 @plain_data(unsafe_hash
=True, frozen
=True)
658 __slots__
= "RT", "RA", "offset", "mem"
661 # type: () -> dict[str, VecArg | SSAVal]
662 return {"RA": self
.RA
, "mem": self
.mem
}
665 # type: () -> dict[str, VecArg | SSAVal]
666 return {"RT": self
.RT
}
668 def __init__(self
, RT
, RA
, offset
, mem
):
669 # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
675 def possible_reg_assignments(self
, val
, value_assignments
):
676 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
677 if val
not in self
.input_ssa_vals() \
678 and val
not in self
.output_ssa_vals():
679 raise ValueError(f
"{val} must be an operand of {self}")
680 if val
.get_phys_loc(value_assignments
) is not None:
681 raise ValueError(f
"{val} already assigned a physical location")
683 RA_reg
= self
.RA
.get_reg(value_assignments
)
686 yield GlobalMem
.GlobalMem
688 if isinstance(self
.RT
, VecArg
):
689 conflicting_regs
= to_reg_set(self
.RT
.try_get_range(
690 value_assignments
, allow_unassigned
=True,
691 raise_if_invalid
=True))
693 conflicting_regs
= set()
694 yield from self
.RA
.possible_reg_assignments(value_assignments
,
696 elif isinstance(self
.RT
, VecArg
):
697 yield from self
.RT
.possible_reg_assignments(
698 val
, value_assignments
,
699 conflicting_regs
=to_reg_set(RA_reg
))
701 yield from self
.RT
.possible_reg_assignments(value_assignments
)
704 @plain_data(unsafe_hash
=True, frozen
=True)
707 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out"
710 # type: () -> dict[str, VecArg | SSAVal]
711 return {"RS": self
.RS
, "RA": self
.RA
, "mem_in": self
.mem_in
}
714 # type: () -> dict[str, VecArg | SSAVal]
715 return {"mem_out": self
.mem_out
}
717 def __init__(self
, RS
, RA
, offset
, mem_in
, mem_out
):
718 # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None
723 self
.mem_out
= mem_out
725 def possible_reg_assignments(self
, val
, value_assignments
):
726 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
727 if val
not in self
.input_ssa_vals() \
728 and val
not in self
.output_ssa_vals():
729 raise ValueError(f
"{val} must be an operand of {self}")
730 if val
.get_phys_loc(value_assignments
) is not None:
731 raise ValueError(f
"{val} already assigned a physical location")
733 if self
.mem_in
== val
or self
.mem_out
== val
:
734 yield GlobalMem
.GlobalMem
736 yield from self
.RA
.possible_reg_assignments(value_assignments
)
737 elif isinstance(self
.RS
, VecArg
):
738 yield from self
.RS
.possible_reg_assignments(val
, value_assignments
)
740 yield from self
.RS
.possible_reg_assignments(value_assignments
)
742 def get_equality_constraints(self
):
743 # type: () -> Iterable[EqualityConstraint]
744 yield EqualityConstraint(self
.mem_in
, self
.mem_out
)
747 @plain_data(unsafe_hash
=True, frozen
=True)
753 # type: () -> dict[str, VecArg | SSAVal]
757 # type: () -> dict[str, VecArg | SSAVal]
758 return {"out": self
.out
}
760 def __init__(self
, out
):
761 # type: (VecArg | SSAGPRVal) -> None
764 def possible_reg_assignments(self
, val
, value_assignments
):
765 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
766 if val
not in self
.input_ssa_vals() \
767 and val
not in self
.output_ssa_vals():
768 raise ValueError(f
"{val} must be an operand of {self}")
769 if val
.get_phys_loc(value_assignments
) is not None:
770 raise ValueError(f
"{val} already assigned a physical location")
772 if isinstance(self
.out
, VecArg
):
773 yield from self
.out
.possible_reg_assignments(val
,
776 yield from self
.out
.possible_reg_assignments(value_assignments
)
779 @plain_data(unsafe_hash
=True, frozen
=True)
781 class OpInputMem(Op
):
785 # type: () -> dict[str, VecArg | SSAVal]
789 # type: () -> dict[str, VecArg | SSAVal]
790 return {"out": self
.out
}
792 def __init__(self
, out
):
793 # type: (SSAMemory) -> None
796 def possible_reg_assignments(self
, val
, value_assignments
):
797 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
798 if val
not in self
.input_ssa_vals() \
799 and val
not in self
.output_ssa_vals():
800 raise ValueError(f
"{val} must be an operand of {self}")
801 if val
.get_phys_loc(value_assignments
) is not None:
802 raise ValueError(f
"{val} already assigned a physical location")
804 yield GlobalMem
.GlobalMem
807 def op_set_to_list(ops
):
808 # type: (Iterable[Op]) -> list[Op]
809 worklists
= [set()] # type: list[set[Op]]
810 input_vals_to_ops_map
= defaultdict(set) # type: dict[SSAVal, set[Op]]
811 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
814 for val
in op
.input_ssa_vals():
816 input_vals_to_ops_map
[val
].add(op
)
817 while len(worklists
) <= input_count
:
818 worklists
.append(set())
819 ops_to_pending_input_count_map
[op
] = input_count
820 worklists
[input_count
].add(op
)
821 retval
= [] # type: list[Op]
822 ready_vals
= set() # type: set[SSAVal]
823 while len(worklists
[0]) != 0:
824 writing_op
= worklists
[0].pop()
825 retval
.append(writing_op
)
826 for val
in writing_op
.output_ssa_vals():
827 if val
in ready_vals
:
828 raise ValueError(f
"multiple instructions must not write "
829 f
"to the same SSA value: {val}")
831 for reading_op
in input_vals_to_ops_map
[val
]:
832 pending
= ops_to_pending_input_count_map
[reading_op
]
833 worklists
[pending
].remove(reading_op
)
835 worklists
[pending
].add(reading_op
)
836 ops_to_pending_input_count_map
[reading_op
] = pending
837 for worklist
in worklists
:
839 raise ValueError(f
"instruction is part of a dependency loop or "
840 f
"its inputs are never written: {op}")
844 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
846 __slots__
= "assignment", "last_use"
848 def __init__(self
, assignment
, last_use
=None):
849 # type: (int, int | None) -> None
851 last_use
= assignment
852 if last_use
< assignment
:
853 raise ValueError("uses must be after assignment")
854 if assignment
< 0 or last_use
< 0:
855 raise ValueError("indexes must be nonnegative")
856 self
.assignment
= assignment
857 self
.last_use
= last_use
859 def overlaps(self
, other
):
860 # type: (LiveInterval) -> bool
861 if self
.assignment
== other
.assignment
:
863 return self
.last_use
> other
.assignment \
864 and other
.last_use
> self
.assignment
866 def __add__(self
, use
):
867 # type: (int) -> LiveInterval
868 last_use
= max(self
.last_use
, use
)
869 return LiveInterval(assignment
=self
.assignment
, last_use
=last_use
)
872 class LiveIntervals(Mapping
[SSAVal
, LiveInterval
]):
873 def __init__(self
, ops
):
874 # type: (list[Op]) -> None
875 live_intervals
= {} # type: dict[SSAVal, LiveInterval]
876 for op_idx
, op
in enumerate(ops
):
877 for val
in op
.input_ssa_vals():
878 live_intervals
[val
] += op_idx
879 for val
in op
.output_ssa_vals():
880 if val
in live_intervals
:
881 raise ValueError(f
"multiple instructions must not write "
882 f
"to the same SSA value: {val}")
883 live_intervals
[val
] = LiveInterval(op_idx
)
884 self
.__live
_intervals
= live_intervals
886 def __getitem__(self
, key
):
887 # type: (SSAVal) -> LiveInterval
888 return self
.__live
_intervals
[key
]
891 return iter(self
.__live
_intervals
)
895 class AllocationFailed
:
896 __slots__
= "op_idx", "arg", "live_intervals", "free_regs"
898 def __init__(self
, op_idx
, arg
, live_intervals
, free_regs
):
899 # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None
902 self
.live_intervals
= live_intervals
903 self
.free_regs
= free_regs
906 def try_allocate_registers_without_spilling(ops
):
907 # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
908 live_intervals
= LiveIntervals(ops
)
909 free_regs
= set() # type: set[GPR | XERBit]
910 free_regs
.update(GPR
)
911 free_regs
.difference_update(SPECIAL_GPRS
)
912 free_regs
.update(XERBit
)
913 raise NotImplementedError
916 def allocate_registers(ops
):
917 # type: (list[Op]) -> None
918 raise NotImplementedError