1 from abc
import ABCMeta
, abstractmethod
2 from collections
import defaultdict
3 from enum
import Enum
, unique
4 from typing
import AbstractSet
, Iterable
, Mapping
, TYPE_CHECKING
6 from nmutil
.plain_data
import plain_data
9 from typing_extensions
import final
, Self
15 @plain_data(frozen
=True, unsafe_hash
=True)
20 @plain_data(frozen
=True, unsafe_hash
=True)
21 class GPROrStackLoc(PhysLoc
):
26 class GPR(GPROrStackLoc
, Enum
):
27 def __init__(self
, reg_num
):
29 self
.reg_num
= reg_num
31 R0
= 0; R1
= 1; R2
= 2; R3
= 3; R4
= 4; R5
= 5
32 R6
= 6; R7
= 7; R8
= 8; R9
= 9; R10
= 10; R11
= 11
33 R12
= 12; R13
= 13; R14
= 14; R15
= 15; R16
= 16; R17
= 17
34 R18
= 18; R19
= 19; R20
= 20; R21
= 21; R22
= 22; R23
= 23
35 R24
= 24; R25
= 25; R26
= 26; R27
= 27; R28
= 28; R29
= 29
36 R30
= 30; R31
= 31; R32
= 32; R33
= 33; R34
= 34; R35
= 35
37 R36
= 36; R37
= 37; R38
= 38; R39
= 39; R40
= 40; R41
= 41
38 R42
= 42; R43
= 43; R44
= 44; R45
= 45; R46
= 46; R47
= 47
39 R48
= 48; R49
= 49; R50
= 50; R51
= 51; R52
= 52; R53
= 53
40 R54
= 54; R55
= 55; R56
= 56; R57
= 57; R58
= 58; R59
= 59
41 R60
= 60; R61
= 61; R62
= 62; R63
= 63; R64
= 64; R65
= 65
42 R66
= 66; R67
= 67; R68
= 68; R69
= 69; R70
= 70; R71
= 71
43 R72
= 72; R73
= 73; R74
= 74; R75
= 75; R76
= 76; R77
= 77
44 R78
= 78; R79
= 79; R80
= 80; R81
= 81; R82
= 82; R83
= 83
45 R84
= 84; R85
= 85; R86
= 86; R87
= 87; R88
= 88; R89
= 89
46 R90
= 90; R91
= 91; R92
= 92; R93
= 93; R94
= 94; R95
= 95
47 R96
= 96; R97
= 97; R98
= 98; R99
= 99; R100
= 100; R101
= 101
48 R102
= 102; R103
= 103; R104
= 104; R105
= 105; R106
= 106; R107
= 107
49 R108
= 108; R109
= 109; R110
= 110; R111
= 111; R112
= 112; R113
= 113
50 R114
= 114; R115
= 115; R116
= 116; R117
= 117; R118
= 118; R119
= 119
51 R120
= 120; R121
= 121; R122
= 122; R123
= 123; R124
= 124; R125
= 125
52 R126
= 126; R127
= 127
58 SPECIAL_GPRS
= GPR
.R0
, GPR
.SP
, GPR
.TOC
, GPR
.R13
63 class XERBit(Enum
, PhysLoc
):
69 class GlobalMem(Enum
, PhysLoc
):
70 """singleton representing all non-StackSlot memory"""
71 GlobalMem
= "GlobalMem"
74 ALLOCATABLE_REGS
= frozenset((set(GPR
) - set(SPECIAL_GPRS
))
75 |
set(XERBit
) |
set(GlobalMem
))
80 class StackSlot(GPROrStackLoc
):
81 """a stack slot. Use OpCopy to load from/store into this stack slot."""
84 def __init__(self
, offset
=None):
85 # type: (int | None) -> None
89 class SSAVal(metaclass
=ABCMeta
):
90 __slots__
= "op", "arg_name", "element_index"
92 def __init__(self
, op
, arg_name
, element_index
):
93 # type: (Op, str, int) -> None
95 """the Op that writes this SSAVal"""
97 self
.arg_name
= arg_name
98 self
.element_index
= element_index
101 def __eq__(self
, rhs
):
102 if isinstance(rhs
, SSAVal
):
103 return (self
.op
is rhs
.op
104 and self
.arg_name
== rhs
.arg_name
105 and self
.element_index
== rhs
.element_index
)
110 return hash((id(self
.op
), self
.arg_name
, self
.element_index
))
112 def _get_phys_loc(self
, phys_loc_in
, value_assignments
=None):
113 # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
114 if phys_loc_in
is not None:
116 if value_assignments
is not None:
117 return value_assignments
.get(self
)
121 def get_phys_loc(self
, value_assignments
=None):
122 # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
127 name
= self
.__class
__.__name
__
128 op
= object.__repr
__(self
.op
)
129 phys_loc
= self
.get_phys_loc()
130 return (f
"{name}(op={op}, arg_name={self.arg_name}, "
131 f
"element_index={self.element_index}, phys_loc={phys_loc})")
134 def like(self
, op
, arg_name
):
135 # type: (Op, str) -> Self
136 """create a new SSAVal based off of self's type.
137 has same signature as VecArg.like.
139 return self
.__class
__(op
=op
, arg_name
=arg_name
,
144 class SSAGPRVal(SSAVal
):
145 __slots__
= "phys_loc",
147 def __init__(self
, op
, arg_name
, element_index
, phys_loc
=None):
148 # type: (Op, str, int, GPROrStackLoc | None) -> None
149 super().__init
__(op
, arg_name
, element_index
)
150 self
.phys_loc
= phys_loc
155 def get_phys_loc(self
, value_assignments
=None):
156 # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
157 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
158 if isinstance(loc
, GPROrStackLoc
):
162 def get_reg_num(self
, value_assignments
=None):
163 # type: (dict[SSAVal, PhysLoc] | None) -> int | None
164 reg
= self
.get_reg(value_assignments
)
169 def get_reg(self
, value_assignments
=None):
170 # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
171 loc
= self
.get_phys_loc(value_assignments
)
172 if isinstance(loc
, GPR
):
176 def get_stack_slot(self
, value_assignments
=None):
177 # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
178 loc
= self
.get_phys_loc(value_assignments
)
179 if isinstance(loc
, StackSlot
):
183 def possible_reg_assignments(self
, value_assignments
,
184 conflicting_regs
=set()):
185 # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR]
186 if self
.get_phys_loc(value_assignments
) is not None:
187 raise ValueError("can't assign a already-assigned SSA value")
189 if reg
not in conflicting_regs
:
194 class SSAXERBitVal(SSAVal
):
195 __slots__
= "phys_loc",
197 def __init__(self
, op
, arg_name
, element_index
, phys_loc
=None):
198 # type: (Op, str, int, XERBit | None) -> None
199 super().__init
__(op
, arg_name
, element_index
)
200 self
.phys_loc
= phys_loc
202 def get_phys_loc(self
, value_assignments
=None):
203 # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
204 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
205 if isinstance(loc
, XERBit
):
211 class SSAMemory(SSAVal
):
212 __slots__
= "phys_loc",
214 def __init__(self
, op
, arg_name
, element_index
,
215 phys_loc
=GlobalMem
.GlobalMem
):
216 # type: (Op, str, int, GlobalMem) -> None
217 super().__init
__(op
, arg_name
, element_index
)
218 self
.phys_loc
= phys_loc
220 def get_phys_loc(self
, value_assignments
=None):
221 # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem
222 loc
= self
._get
_phys
_loc
(self
.phys_loc
, value_assignments
)
223 if isinstance(loc
, GlobalMem
):
228 @plain_data(unsafe_hash
=True, frozen
=True)
233 def __init__(self
, regs
):
234 # type: (Iterable[SSAGPRVal]) -> None
235 self
.regs
= tuple(regs
)
238 return len(self
.regs
)
240 def is_unassigned(self
, value_assignments
=None):
241 # type: (dict[SSAVal, PhysLoc] | None) -> bool
242 for val
in self
.regs
:
243 if val
.get_phys_loc(value_assignments
) is not None:
247 def try_get_range(self
, value_assignments
=None, allow_unassigned
=False,
248 raise_if_invalid
=False):
249 # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
250 if len(self
.regs
) == 0:
253 retval
= None # type: range | None
254 for i
, val
in enumerate(self
.regs
):
255 if val
.get_phys_loc(value_assignments
) is None:
256 if not allow_unassigned
:
258 raise ValueError("not a valid register range: "
259 "unassigned SSA value encountered")
262 reg
= val
.get_reg_num(value_assignments
)
265 raise ValueError("not a valid register range: "
266 "non-register encountered")
268 expected_range
= range(reg
- i
, reg
- i
+ len(self
.regs
))
270 retval
= expected_range
271 elif retval
!= expected_range
:
273 raise ValueError("not a valid register range: "
274 "register out of sequence")
278 def possible_reg_assignments(
281 value_assignments
, # type: dict[SSAVal, PhysLoc] | None
282 conflicting_regs
=set(), # type: set[GPR]
283 ): # type: (...) -> Iterable[GPR]
284 index
= self
.regs
.index(val
)
286 while alignment
< len(self
.regs
):
288 r
= self
.try_get_range(value_assignments
)
289 if r
is not None and r
.start
% alignment
!= 0:
290 raise ValueError("must be a ascending aligned range of GPRs")
292 for i
in range(0, len(GPR
), alignment
):
293 r
= range(i
, i
+ len(self
.regs
))
294 if any(GPR(reg
) in conflicting_regs
for reg
in r
):
300 def like(self
, op
, arg_name
):
301 # type: (Op, str) -> VecArg
302 """create a new VecArg based off of self's type.
303 has same signature as SSAVal.like.
306 SSAGPRVal(op
, arg_name
, i
) for i
in range(len(self
.regs
)))
309 def vec_or_scalar_arg(element_count
, op
, arg_name
):
310 # type: (int | None, Op, str) -> VecArg | SSAGPRVal
311 if element_count
is None:
312 return SSAGPRVal(op
, arg_name
, 0)
314 return VecArg(SSAGPRVal(op
, arg_name
, i
) for i
in range(element_count
))
318 @plain_data(unsafe_hash
=True, frozen
=True)
319 class EqualityConstraint
:
320 __slots__
= "lhs", "rhs"
322 def __init__(self
, lhs
, rhs
):
323 # type: (SSAVal, SSAVal) -> None
328 @plain_data(unsafe_hash
=True, frozen
=True)
329 class Op(metaclass
=ABCMeta
):
332 def input_ssa_vals(self
):
333 # type: () -> Iterable[SSAVal]
334 for arg
in self
.inputs().values():
335 if isinstance(arg
, VecArg
):
340 def output_ssa_vals(self
):
341 # type: () -> Iterable[SSAVal]
342 for arg
in self
.outputs().values():
343 if isinstance(arg
, VecArg
):
350 # type: () -> dict[str, VecArg | SSAVal]
355 # type: () -> dict[str, VecArg | SSAVal]
359 def possible_reg_assignments(self
, val
, value_assignments
):
360 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
363 def get_equality_constraints(self
):
364 # type: () -> Iterable[EqualityConstraint]
372 @plain_data(unsafe_hash
=True, frozen
=True)
375 __slots__
= "dest", "src"
378 # type: () -> dict[str, VecArg | SSAVal]
379 return {"src": self
.src
}
382 # type: () -> dict[str, VecArg | SSAVal]
383 return {"dest": self
.dest
}
385 def __init__(self
, src
):
386 # type: (SSAGPRVal) -> None
387 self
.dest
= src
.like(op
=self
, arg_name
="dest")
390 def possible_reg_assignments(self
, val
, value_assignments
):
391 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
392 if val
not in self
.input_ssa_vals() \
393 and val
not in self
.output_ssa_vals():
394 raise ValueError(f
"{val} must be an operand of {self}")
395 if val
.get_phys_loc(value_assignments
) is not None:
396 raise ValueError(f
"{val} already assigned a physical location")
397 if not isinstance(val
, SSAGPRVal
):
398 raise ValueError("invalid operand type")
399 return val
.possible_reg_assignments(value_assignments
)
402 def range_overlaps(range1
, range2
):
403 # type: (range, range) -> bool
404 if len(range1
) == 0 or len(range2
) == 0:
406 range1_last
= range1
[-1]
407 range2_last
= range2
[-1]
408 return (range1
.start
in range2
or range1_last
in range2
or
409 range2
.start
in range1
or range2_last
in range1
)
412 @plain_data(unsafe_hash
=True, frozen
=True)
415 __slots__
= "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
418 # type: () -> dict[str, VecArg | SSAVal]
419 return {"RA": self
.RA
, "RB": self
.RB
, "CY_in": self
.CY_in
}
422 # type: () -> dict[str, VecArg | SSAVal]
423 return {"RT": self
.RT
, "CY_out": self
.CY_out
}
425 def __init__(self
, RA
, RB
, CY_in
, is_sub
):
426 # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None
427 if len(RA
.regs
) != len(RB
.regs
):
428 raise TypeError(f
"source lengths must match: "
429 f
"{RA} doesn't match {RB}")
430 self
.RT
= RA
.like(op
=self
, arg_name
="RT")
434 self
.CY_out
= CY_in
.like(op
=self
, arg_name
="CY_out")
437 def possible_reg_assignments(self
, val
, value_assignments
):
438 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
439 if val
not in self
.input_ssa_vals() \
440 and val
not in self
.output_ssa_vals():
441 raise ValueError(f
"{val} must be an operand of {self}")
442 if val
.get_phys_loc(value_assignments
) is not None:
443 raise ValueError(f
"{val} already assigned a physical location")
444 if self
.CY_in
== val
or self
.CY_out
== val
:
446 elif val
in self
.RT
.regs
:
447 # since possible_reg_assignments only returns aligned
448 # vectors, all possible assignments either are the same as an
449 # input or don't overlap with an input and we avoid the incorrect
450 # results caused by partial overlaps overwriting input elements
451 # before they're read
452 yield from self
.RT
.possible_reg_assignments(val
, value_assignments
)
453 elif val
in self
.RA
.regs
:
454 yield from self
.RA
.possible_reg_assignments(val
, value_assignments
)
456 yield from self
.RB
.possible_reg_assignments(val
, value_assignments
)
458 def get_equality_constraints(self
):
459 # type: () -> Iterable[EqualityConstraint]
460 yield EqualityConstraint(self
.CY_in
, self
.CY_out
)
464 # type: (None | GPR | range) -> set[GPR]
467 if isinstance(v
, range):
468 return set(map(GPR
, v
))
472 @plain_data(unsafe_hash
=True, frozen
=True)
474 class OpBigIntMulDiv(Op
):
475 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div"
478 # type: () -> dict[str, VecArg | SSAVal]
479 return {"RA": self
.RA
, "RB": self
.RB
, "RC": self
.RC
}
482 # type: () -> dict[str, VecArg | SSAVal]
483 return {"RT": self
.RT
, "RS": self
.RS
}
485 def __init__(self
, RA
, RB
, RC
, is_div
):
486 # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None
487 self
.RT
= RA
.like(op
=self
, arg_name
="RT")
491 self
.RS
= RC
.like(op
=self
, arg_name
="RS")
494 def possible_reg_assignments(self
, val
, value_assignments
):
495 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
496 if val
not in self
.input_ssa_vals() \
497 and val
not in self
.output_ssa_vals():
498 raise ValueError(f
"{val} must be an operand of {self}")
499 if val
.get_phys_loc(value_assignments
) is not None:
500 raise ValueError(f
"{val} already assigned a physical location")
502 RT_range
= self
.RT
.try_get_range(value_assignments
,
503 allow_unassigned
=True,
504 raise_if_invalid
=True)
505 RA_range
= self
.RA
.try_get_range(value_assignments
,
506 allow_unassigned
=True,
507 raise_if_invalid
=True)
508 RC_RS_reg
= self
.RC
.get_reg(value_assignments
)
509 if RC_RS_reg
is None:
510 RC_RS_reg
= self
.RS
.get_reg(value_assignments
)
512 if self
.RC
== val
or self
.RS
== val
:
513 if RC_RS_reg
is not None:
516 conflicting_regs
= to_reg_set(RT_range
) |
to_reg_set(RA_range
)
517 yield from self
.RC
.possible_reg_assignments(value_assignments
,
519 elif val
in self
.RT
.regs
:
520 # since possible_reg_assignments only returns aligned
521 # vectors, all possible assignments either are the same as
522 # RA or don't overlap with RA and we avoid the incorrect
523 # results caused by partial overlaps overwriting input elements
524 # before they're read
525 yield from self
.RT
.possible_reg_assignments(
526 val
, value_assignments
,
527 conflicting_regs
=to_reg_set(RA_range
) |
to_reg_set(RC_RS_reg
))
529 yield from self
.RA
.possible_reg_assignments(
530 val
, value_assignments
,
531 conflicting_regs
=to_reg_set(RT_range
) |
to_reg_set(RC_RS_reg
))
533 def get_equality_constraints(self
):
534 # type: () -> Iterable[EqualityConstraint]
535 yield EqualityConstraint(self
.RC
, self
.RS
)
540 class ShiftKind(Enum
):
546 @plain_data(unsafe_hash
=True, frozen
=True)
548 class OpBigIntShift(Op
):
549 __slots__
= "RT", "inp", "sh", "kind"
552 # type: () -> dict[str, VecArg | SSAVal]
553 return {"inp": self
.inp
, "sh": self
.sh
}
556 # type: () -> dict[str, VecArg | SSAVal]
557 return {"RT": self
.RT
}
559 def __init__(self
, inp
, sh
, kind
):
560 # type: (VecArg, SSAGPRVal, ShiftKind) -> None
561 self
.RT
= inp
.like(op
=self
, arg_name
="RT")
566 def possible_reg_assignments(self
, val
, value_assignments
):
567 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
568 if val
not in self
.input_ssa_vals() \
569 and val
not in self
.output_ssa_vals():
570 raise ValueError(f
"{val} must be an operand of {self}")
571 if val
.get_phys_loc(value_assignments
) is not None:
572 raise ValueError(f
"{val} already assigned a physical location")
574 RT_range
= self
.RT
.try_get_range(value_assignments
,
575 allow_unassigned
=True,
576 raise_if_invalid
=True)
577 inp_range
= self
.inp
.try_get_range(value_assignments
,
578 allow_unassigned
=True,
579 raise_if_invalid
=True)
580 sh_reg
= self
.sh
.get_reg(value_assignments
)
583 conflicting_regs
= to_reg_set(RT_range
)
584 yield from self
.sh
.possible_reg_assignments(value_assignments
,
586 elif val
in self
.RT
.regs
:
587 # since possible_reg_assignments only returns aligned
588 # vectors, all possible assignments either are the same as
589 # RA or don't overlap with RA and we avoid the incorrect
590 # results caused by partial overlaps overwriting input elements
591 # before they're read
592 yield from self
.RT
.possible_reg_assignments(
593 val
, value_assignments
,
594 conflicting_regs
=to_reg_set(inp_range
) |
to_reg_set(sh_reg
))
596 yield from self
.inp
.possible_reg_assignments(
597 val
, value_assignments
,
598 conflicting_regs
=to_reg_set(RT_range
))
601 @plain_data(unsafe_hash
=True, frozen
=True)
604 __slots__
= "out", "value"
607 # type: () -> dict[str, VecArg | SSAVal]
611 # type: () -> dict[str, VecArg | SSAVal]
612 return {"out": self
.out
}
614 def __init__(self
, value
, element_count
=None):
615 # type: (int, int | None) -> None
616 self
.out
= vec_or_scalar_arg(element_count
, op
=self
, arg_name
="out")
619 def possible_reg_assignments(self
, val
, value_assignments
):
620 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
621 if val
not in self
.input_ssa_vals() \
622 and val
not in self
.output_ssa_vals():
623 raise ValueError(f
"{val} must be an operand of {self}")
624 if val
.get_phys_loc(value_assignments
) is not None:
625 raise ValueError(f
"{val} already assigned a physical location")
627 if isinstance(self
.out
, VecArg
):
628 yield from self
.out
.possible_reg_assignments(val
,
631 yield from self
.out
.possible_reg_assignments(value_assignments
)
634 @plain_data(unsafe_hash
=True, frozen
=True)
640 # type: () -> dict[str, VecArg | SSAVal]
644 # type: () -> dict[str, VecArg | SSAVal]
645 return {"out": self
.out
}
649 self
.out
= SSAXERBitVal(op
=self
, arg_name
="out", element_index
=0,
652 def possible_reg_assignments(self
, val
, value_assignments
):
653 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
654 if val
not in self
.input_ssa_vals() \
655 and val
not in self
.output_ssa_vals():
656 raise ValueError(f
"{val} must be an operand of {self}")
657 if val
.get_phys_loc(value_assignments
) is not None:
658 raise ValueError(f
"{val} already assigned a physical location")
663 @plain_data(unsafe_hash
=True, frozen
=True)
666 __slots__
= "RT", "RA", "offset", "mem"
669 # type: () -> dict[str, VecArg | SSAVal]
670 return {"RA": self
.RA
, "mem": self
.mem
}
673 # type: () -> dict[str, VecArg | SSAVal]
674 return {"RT": self
.RT
}
676 def __init__(self
, RA
, offset
, mem
, element_count
=None):
677 # type: (SSAGPRVal, int, SSAMemory, int | None) -> None
678 self
.RT
= vec_or_scalar_arg(element_count
, op
=self
, arg_name
="RT")
683 def possible_reg_assignments(self
, val
, value_assignments
):
684 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
685 if val
not in self
.input_ssa_vals() \
686 and val
not in self
.output_ssa_vals():
687 raise ValueError(f
"{val} must be an operand of {self}")
688 if val
.get_phys_loc(value_assignments
) is not None:
689 raise ValueError(f
"{val} already assigned a physical location")
691 RA_reg
= self
.RA
.get_reg(value_assignments
)
694 yield GlobalMem
.GlobalMem
696 if isinstance(self
.RT
, VecArg
):
697 conflicting_regs
= to_reg_set(self
.RT
.try_get_range(
698 value_assignments
, allow_unassigned
=True,
699 raise_if_invalid
=True))
701 conflicting_regs
= set()
702 yield from self
.RA
.possible_reg_assignments(value_assignments
,
704 elif isinstance(self
.RT
, VecArg
):
705 yield from self
.RT
.possible_reg_assignments(
706 val
, value_assignments
,
707 conflicting_regs
=to_reg_set(RA_reg
))
709 yield from self
.RT
.possible_reg_assignments(value_assignments
)
712 @plain_data(unsafe_hash
=True, frozen
=True)
715 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out"
718 # type: () -> dict[str, VecArg | SSAVal]
719 return {"RS": self
.RS
, "RA": self
.RA
, "mem_in": self
.mem_in
}
722 # type: () -> dict[str, VecArg | SSAVal]
723 return {"mem_out": self
.mem_out
}
725 def __init__(self
, RS
, RA
, offset
, mem_in
):
726 # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
731 self
.mem_out
= mem_in
.like(op
=self
, arg_name
="mem_out")
733 def possible_reg_assignments(self
, val
, value_assignments
):
734 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
735 if val
not in self
.input_ssa_vals() \
736 and val
not in self
.output_ssa_vals():
737 raise ValueError(f
"{val} must be an operand of {self}")
738 if val
.get_phys_loc(value_assignments
) is not None:
739 raise ValueError(f
"{val} already assigned a physical location")
741 if self
.mem_in
== val
or self
.mem_out
== val
:
742 yield GlobalMem
.GlobalMem
744 yield from self
.RA
.possible_reg_assignments(value_assignments
)
745 elif isinstance(self
.RS
, VecArg
):
746 yield from self
.RS
.possible_reg_assignments(val
, value_assignments
)
748 yield from self
.RS
.possible_reg_assignments(value_assignments
)
750 def get_equality_constraints(self
):
751 # type: () -> Iterable[EqualityConstraint]
752 yield EqualityConstraint(self
.mem_in
, self
.mem_out
)
755 @plain_data(unsafe_hash
=True, frozen
=True)
761 # type: () -> dict[str, VecArg | SSAVal]
765 # type: () -> dict[str, VecArg | SSAVal]
766 return {"out": self
.out
}
768 def __init__(self
, phys_loc
):
769 # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None
770 if isinstance(phys_loc
, GPROrStackLoc
):
771 self
.out
= SSAGPRVal(self
, "out", 0, phys_loc
)
774 SSAGPRVal(self
, "out", i
, v
) for i
, v
in enumerate(phys_loc
))
776 def possible_reg_assignments(self
, val
, value_assignments
):
777 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
778 if val
not in self
.input_ssa_vals() \
779 and val
not in self
.output_ssa_vals():
780 raise ValueError(f
"{val} must be an operand of {self}")
781 if val
.get_phys_loc(value_assignments
) is not None:
782 raise ValueError(f
"{val} already assigned a physical location")
784 if isinstance(self
.out
, VecArg
):
785 yield from self
.out
.possible_reg_assignments(val
,
788 yield from self
.out
.possible_reg_assignments(value_assignments
)
791 @plain_data(unsafe_hash
=True, frozen
=True)
793 class OpInputMem(Op
):
797 # type: () -> dict[str, VecArg | SSAVal]
801 # type: () -> dict[str, VecArg | SSAVal]
802 return {"out": self
.out
}
806 self
.out
= SSAMemory(op
=self
, arg_name
="out", element_index
=0)
808 def possible_reg_assignments(self
, val
, value_assignments
):
809 # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
810 if val
not in self
.input_ssa_vals() \
811 and val
not in self
.output_ssa_vals():
812 raise ValueError(f
"{val} must be an operand of {self}")
813 if val
.get_phys_loc(value_assignments
) is not None:
814 raise ValueError(f
"{val} already assigned a physical location")
816 yield GlobalMem
.GlobalMem
819 def op_set_to_list(ops
):
820 # type: (Iterable[Op]) -> list[Op]
821 worklists
= [set()] # type: list[set[Op]]
822 input_vals_to_ops_map
= defaultdict(set) # type: dict[SSAVal, set[Op]]
823 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
826 for val
in op
.input_ssa_vals():
828 input_vals_to_ops_map
[val
].add(op
)
829 while len(worklists
) <= input_count
:
830 worklists
.append(set())
831 ops_to_pending_input_count_map
[op
] = input_count
832 worklists
[input_count
].add(op
)
833 retval
= [] # type: list[Op]
834 ready_vals
= set() # type: set[SSAVal]
835 while len(worklists
[0]) != 0:
836 writing_op
= worklists
[0].pop()
837 retval
.append(writing_op
)
838 for val
in writing_op
.output_ssa_vals():
839 if val
in ready_vals
:
840 raise ValueError(f
"multiple instructions must not write "
841 f
"to the same SSA value: {val}")
843 for reading_op
in input_vals_to_ops_map
[val
]:
844 pending
= ops_to_pending_input_count_map
[reading_op
]
845 worklists
[pending
].remove(reading_op
)
847 worklists
[pending
].add(reading_op
)
848 ops_to_pending_input_count_map
[reading_op
] = pending
849 for worklist
in worklists
:
851 raise ValueError(f
"instruction is part of a dependency loop or "
852 f
"its inputs are never written: {op}")
856 @plain_data(unsafe_hash
=True, order
=True, frozen
=True)
858 __slots__
= "first_write", "last_use"
860 def __init__(self
, first_write
, last_use
=None):
861 # type: (int, int | None) -> None
863 last_use
= first_write
864 if last_use
< first_write
:
865 raise ValueError("uses must be after first_write")
866 if first_write
< 0 or last_use
< 0:
867 raise ValueError("indexes must be nonnegative")
868 self
.first_write
= first_write
869 self
.last_use
= last_use
871 def overlaps(self
, other
):
872 # type: (LiveInterval) -> bool
873 if self
.first_write
== other
.first_write
:
875 return self
.last_use
> other
.first_write \
876 and other
.last_use
> self
.first_write
878 def __add__(self
, use
):
879 # type: (int) -> LiveInterval
880 last_use
= max(self
.last_use
, use
)
881 return LiveInterval(first_write
=self
.first_write
, last_use
=last_use
)
885 class EqualitySet(AbstractSet
[SSAVal
]):
886 def __init__(self
, items
):
887 # type: (Iterable[SSAVal]) -> None
888 self
.__items
= frozenset(items
)
890 def __contains__(self
, x
):
891 # type: (object) -> bool
892 return x
in self
.__items
895 return iter(self
.__items
)
898 return len(self
.__items
)
901 return super()._hash
()
905 class EqualitySets(Mapping
[SSAVal
, EqualitySet
]):
906 def __init__(self
, ops
):
907 # type: (Iterable[Op]) -> None
908 indexes
= {} # type: dict[SSAVal, int]
909 sets
= [] # type: list[set[SSAVal]]
911 for val
in (*op
.input_ssa_vals(), *op
.output_ssa_vals()):
912 if val
not in indexes
:
913 indexes
[val
] = len(sets
)
915 for e
in op
.get_equality_constraints():
916 lhs_index
= indexes
[e
.lhs
]
917 rhs_index
= indexes
[e
.rhs
]
918 sets
[lhs_index
] |
= sets
[rhs_index
]
919 for val
in sets
[rhs_index
]:
920 indexes
[val
] = lhs_index
922 equality_sets
= [EqualitySet(i
) for i
in sets
]
923 self
.__map
= {k
: equality_sets
[v
] for k
, v
in indexes
.items()}
925 def __getitem__(self
, key
):
926 # type: (SSAVal) -> EqualitySet
927 return self
.__map
[key
]
930 return iter(self
.__map
)
934 class LiveIntervals(Mapping
[EqualitySet
, LiveInterval
]):
935 def __init__(self
, ops
):
936 # type: (list[Op]) -> None
937 self
.__equality
_sets
= eqsets
= EqualitySets(ops
)
938 live_intervals
= {} # type: dict[EqualitySet, LiveInterval]
939 for op_idx
, op
in enumerate(ops
):
940 for val
in op
.input_ssa_vals():
941 live_intervals
[eqsets
[val
]] += op_idx
942 for val
in op
.output_ssa_vals():
943 if eqsets
[val
] not in live_intervals
:
944 live_intervals
[eqsets
[val
]] = LiveInterval(op_idx
)
946 live_intervals
[eqsets
[val
]] += op_idx
947 self
.__live
_intervals
= live_intervals
950 def equality_sets(self
):
951 return self
.__equality
_sets
953 def __getitem__(self
, key
):
954 # type: (EqualitySet) -> LiveInterval
955 return self
.__live
_intervals
[key
]
958 return iter(self
.__live
_intervals
)
963 """ interference graph node """
964 __slots__
= "equality_set", "edges"
966 def __init__(self
, equality_set
, edges
=()):
967 # type: (EqualitySet, Iterable[IGNode]) -> None
968 self
.equality_set
= equality_set
969 self
.edges
= set(edges
)
971 def add_edge(self
, other
):
972 # type: (IGNode) -> None
973 self
.edges
.add(other
)
974 other
.edges
.add(self
)
976 def __eq__(self
, other
):
977 # type: (object) -> bool
978 if isinstance(other
, IGNode
):
979 return self
.equality_set
== other
.equality_set
980 return NotImplemented
983 return self
.equality_set
.__hash
__()
985 def __repr__(self
, nodes
=None):
986 # type: (None | dict[IGNode, int]) -> str
990 return f
"<IGNode #{nodes[self]}>"
991 nodes
[self
] = len(nodes
)
992 edges
= "{" + ", ".join(i
.__repr
__(nodes
) for i
in self
.edges
) + "}"
993 return (f
"IGNode(#{nodes[self]}, "
994 f
"equality_set={self.equality_set}, "
999 class InterferenceGraph(Mapping
[EqualitySet
, IGNode
]):
1000 def __init__(self
, equality_sets
):
1001 # type: (Iterable[EqualitySet]) -> None
1002 self
.__nodes
= {i
: IGNode(i
) for i
in equality_sets
}
1004 def __getitem__(self
, key
):
1005 # type: (EqualitySet) -> IGNode
1006 return self
.__nodes
[key
]
1009 return iter(self
.__nodes
)
1013 class AllocationFailed
:
1014 __slots__
= "op_idx", "arg", "live_intervals"
1016 def __init__(self
, op_idx
, arg
, live_intervals
):
1017 # type: (int, SSAVal | VecArg, LiveIntervals) -> None
1018 self
.op_idx
= op_idx
1020 self
.live_intervals
= live_intervals
1023 def try_allocate_registers_without_spilling(ops
):
1024 # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
1025 live_intervals
= LiveIntervals(ops
)
1027 def is_constrained(node
):
1028 # type: (EqualitySet) -> bool
1029 raise NotImplementedError
1031 raise NotImplementedError
1034 def allocate_registers(ops
):
1035 # type: (list[Op]) -> None
1036 raise NotImplementedError