e6ffe4e606076577b57bb64a208b8b70a42c7954
2 from abc
import abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
5 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
6 Sequence
, TypeVar
, overload
)
7 from weakref
import WeakValueDictionary
as _WeakVDict
9 from cached_property
import cached_property
10 from nmutil
.plain_data
import fields
, plain_data
12 from bigint_presentation_code
.type_util
import Self
, assert_never
, final
13 from bigint_presentation_code
.util
import BitSet
, FBitSet
, FMap
, OFSet
19 self
.ops
= [] # type: list[Op]
20 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
21 self
.__next
_name
_suffix
= 2
23 def _add_op_with_unused_name(self
, op
, name
=""):
24 # type: (Op, str) -> str
26 raise ValueError("can't add Op to wrong Fn")
27 if hasattr(op
, "name"):
28 raise ValueError("Op already named")
31 if name
!= "" and name
not in self
.__op
_names
:
32 self
.__op
_names
[name
] = op
34 name
= orig_name
+ str(self
.__next
_name
_suffix
)
35 self
.__next
_name
_suffix
+= 1
41 def append_op(self
, op
):
44 raise ValueError("can't add Op to wrong Fn")
47 def append_new_op(self
, kind
, inputs
=(), immediates
=(), name
="", maxvl
=1):
48 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
49 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
50 inputs
=inputs
, immediates
=immediates
, name
=name
)
51 self
.append_op(retval
)
54 def pre_ra_sim(self
, state
):
55 # type: (PreRASimState) -> None
65 VL_MAXVL
= enum
.auto()
68 def only_scalar(self
):
70 if self
is BaseTy
.I64
:
72 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
78 def max_reg_len(self
):
80 if self
is BaseTy
.I64
:
82 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
88 return "BaseTy." + self
._name
_
91 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
94 __slots__
= "base_ty", "reg_len"
97 def validate(base_ty
, reg_len
):
98 # type: (BaseTy, int) -> str | None
99 """ return a string with the error if the combination is invalid,
100 otherwise return None
102 if base_ty
.only_scalar
and reg_len
!= 1:
103 return f
"can't create a vector of an only-scalar type: {base_ty}"
104 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
105 return "reg_len out of range"
108 def __init__(self
, base_ty
, reg_len
):
109 # type: (BaseTy, int) -> None
110 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
112 raise ValueError(msg
)
113 self
.base_ty
= base_ty
114 self
.reg_len
= reg_len
118 if self
.reg_len
!= 1:
119 reg_len
= f
"*{self.reg_len}"
122 return f
"<{self.base_ty._name_}{reg_len}>"
129 StackI64
= enum
.auto()
131 VL_MAXVL
= enum
.auto()
136 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
138 if self
is LocKind
.CA
:
140 if self
is LocKind
.VL_MAXVL
:
141 return BaseTy
.VL_MAXVL
148 if self
is LocKind
.StackI64
:
150 if self
is LocKind
.GPR
or self
is LocKind
.CA \
151 or self
is LocKind
.VL_MAXVL
:
152 return self
.base_ty
.max_reg_len
157 return "LocKind." + self
._name
_
162 class LocSubKind(Enum
):
163 BASE_GPR
= enum
.auto()
164 SV_EXTRA2_VGPR
= enum
.auto()
165 SV_EXTRA2_SGPR
= enum
.auto()
166 SV_EXTRA3_VGPR
= enum
.auto()
167 SV_EXTRA3_SGPR
= enum
.auto()
168 StackI64
= enum
.auto()
170 VL_MAXVL
= enum
.auto()
174 # type: () -> LocKind
175 # pyright fails typechecking when using `in` here:
176 # reported: https://github.com/microsoft/pyright/issues/4102
177 if self
is LocSubKind
.BASE_GPR
or self
is LocSubKind
.SV_EXTRA2_VGPR \
178 or self
is LocSubKind
.SV_EXTRA2_SGPR \
179 or self
is LocSubKind
.SV_EXTRA3_VGPR \
180 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
182 if self
is LocSubKind
.StackI64
:
183 return LocKind
.StackI64
184 if self
is LocSubKind
.CA
:
186 if self
is LocSubKind
.VL_MAXVL
:
187 return LocKind
.VL_MAXVL
192 return self
.kind
.base_ty
195 def allocatable_locs(self
, ty
):
196 # type: (Ty) -> LocSet
197 if ty
.base_ty
!= self
.base_ty
:
198 raise ValueError("type mismatch")
199 if self
is LocSubKind
.BASE_GPR
:
201 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
202 starts
= range(0, 128, 2)
203 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
205 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
206 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
208 elif self
is LocSubKind
.StackI64
:
209 starts
= range(LocKind
.StackI64
.loc_count
)
210 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
211 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
214 retval
= [] # type: list[Loc]
216 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
220 for special_loc
in SPECIAL_GPRS
:
221 if loc
.conflicts(special_loc
):
226 return LocSet(retval
)
229 return "LocSubKind." + self
._name
_
232 @plain_data(frozen
=True, unsafe_hash
=True)
235 __slots__
= "base_ty", "is_vec"
237 def __init__(self
, base_ty
, is_vec
):
238 # type: (BaseTy, bool) -> None
239 self
.base_ty
= base_ty
240 if base_ty
.only_scalar
and is_vec
:
241 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
244 def instantiate(self
, maxvl
):
246 # here's where subvl and elwid would be accounted for
248 return Ty(self
.base_ty
, maxvl
)
249 return Ty(self
.base_ty
, 1)
251 def can_instantiate_to(self
, ty
):
253 if self
.base_ty
!= ty
.base_ty
:
257 return ty
.reg_len
== 1
260 @plain_data(frozen
=True, unsafe_hash
=True)
263 __slots__
= "kind", "start", "reg_len"
266 def validate(kind
, start
, reg_len
):
267 # type: (LocKind, int, int) -> str | None
268 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
271 if reg_len
> kind
.loc_count
:
272 return "invalid reg_len"
273 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
274 return "start not in valid range"
278 def try_make(kind
, start
, reg_len
):
279 # type: (LocKind, int, int) -> Loc | None
280 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
283 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
285 def __init__(self
, kind
, start
, reg_len
):
286 # type: (LocKind, int, int) -> None
287 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
289 raise ValueError(msg
)
291 self
.reg_len
= reg_len
294 def conflicts(self
, other
):
295 # type: (Loc) -> bool
296 return (self
.kind
== other
.kind
297 and self
.start
< other
.stop
and other
.start
< self
.stop
)
300 def make_ty(kind
, reg_len
):
301 # type: (LocKind, int) -> Ty
302 return Ty(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
307 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
312 return self
.start
+ self
.reg_len
314 def try_concat(self
, *others
):
315 # type: (*Loc | None) -> Loc | None
316 reg_len
= self
.reg_len
319 if other
is None or other
.kind
!= self
.kind
:
321 if stop
!= other
.start
:
324 reg_len
+= other
.reg_len
325 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
329 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
330 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
331 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
332 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
336 @plain_data(frozen
=True, eq
=False)
338 class LocSet(AbstractSet
[Loc
]):
339 __slots__
= "starts", "ty"
341 def __init__(self
, __locs
=()):
342 # type: (Iterable[Loc]) -> None
343 if isinstance(__locs
, LocSet
):
344 self
.starts
= __locs
.starts
# type: FMap[LocKind, FBitSet]
345 self
.ty
= __locs
.ty
# type: Ty | None
347 starts
= {i
: BitSet() for i
in LocKind
}
353 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
354 starts
[loc
.kind
].add(loc
.start
)
356 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
361 # type: () -> FMap[LocKind, FBitSet]
366 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
370 # type: () -> AbstractSet[LocKind]
371 return self
.starts
.keys()
375 # type: () -> int | None
378 return self
.ty
.reg_len
382 # type: () -> BaseTy | None
385 return self
.ty
.base_ty
387 def concat(self
, *others
):
388 # type: (*LocSet) -> LocSet
391 base_ty
= self
.ty
.base_ty
392 reg_len
= self
.ty
.reg_len
393 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
397 if other
.ty
.base_ty
!= base_ty
:
399 for kind
, other_starts
in other
.starts
.items():
400 if kind
not in starts
:
402 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
403 if starts
[kind
] == 0:
407 reg_len
+= other
.ty
.reg_len
410 # type: () -> Iterable[Loc]
411 for kind
, v
in starts
.items():
413 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
416 return LocSet(locs())
418 def __contains__(self
, loc
):
419 # type: (Loc | Any) -> bool
420 if not isinstance(loc
, Loc
) or loc
.ty
!= self
.ty
:
422 if loc
.kind
not in self
.starts
:
424 return loc
.start
in self
.starts
[loc
.kind
]
427 # type: () -> Iterator[Loc]
430 for kind
, starts
in self
.starts
.items():
432 yield Loc(kind
=kind
, start
=start
, reg_len
=self
.ty
.reg_len
)
436 return sum((len(v
) for v
in self
.starts
.values()), 0)
443 return super()._hash
()
449 @plain_data(frozen
=True, unsafe_hash
=True)
451 class GenericOperandDesc
:
452 """generic Op operand descriptor"""
453 __slots__
= "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"
456 self
, ty
, # type: GenericTy
457 sub_kinds
, # type: Iterable[LocSubKind]
459 fixed_loc
=None, # type: Loc | None
460 tied_input_index
=None, # type: int | None
461 spread
=False, # type: bool
463 # type: (...) -> None
465 self
.sub_kinds
= OFSet(sub_kinds
)
466 if len(self
.sub_kinds
) == 0:
467 raise ValueError("sub_kinds can't be empty")
468 self
.fixed_loc
= fixed_loc
469 if fixed_loc
is not None:
470 if tied_input_index
is not None:
471 raise ValueError("operand can't be both tied and fixed")
472 if not ty
.can_instantiate_to(fixed_loc
.ty
):
474 f
"fixed_loc has incompatible type for given generic "
475 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
476 if len(self
.sub_kinds
) != 1:
478 "multiple sub_kinds not allowed for fixed operand")
479 for sub_kind
in self
.sub_kinds
:
480 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
482 f
"fixed_loc not in given sub_kind: "
483 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
484 for sub_kind
in self
.sub_kinds
:
485 if sub_kind
.base_ty
!= ty
.base_ty
:
486 raise ValueError(f
"sub_kind is incompatible with type: "
487 f
"sub_kind={sub_kind} ty={ty}")
488 if tied_input_index
is not None and tied_input_index
< 0:
489 raise ValueError("invalid tied_input_index")
490 self
.tied_input_index
= tied_input_index
493 if self
.tied_input_index
is not None:
494 raise ValueError("operand can't be both spread and tied")
495 if self
.fixed_loc
is not None:
496 raise ValueError("operand can't be both spread and fixed")
498 raise ValueError("operand can't be both spread and vector")
500 def tied_to_input(self
, tied_input_index
):
501 # type: (int) -> Self
502 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
503 tied_input_index
=tied_input_index
)
505 def with_fixed_loc(self
, fixed_loc
):
506 # type: (Loc) -> Self
507 return GenericOperandDesc(self
.ty
, self
.sub_kinds
, fixed_loc
=fixed_loc
)
509 def instantiate(self
, maxvl
):
510 # type: (int) -> Iterable[OperandDesc]
515 ty
= self
.ty
.instantiate(maxvl
=maxvl
)
518 # type: () -> Iterable[Loc]
519 if self
.fixed_loc
is not None:
520 if ty
!= self
.fixed_loc
.ty
:
522 f
"instantiation failed: type mismatch with fixed_loc: "
523 f
"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
526 for sub_kind
in self
.sub_kinds
:
527 yield from sub_kind
.allocatable_locs(ty
)
528 loc_set_before_spread
= LocSet(locs())
529 for idx
in range(rep_count
):
532 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
533 tied_input_index
=self
.tied_input_index
,
537 @plain_data(frozen
=True, unsafe_hash
=True)
540 """Op operand descriptor"""
541 __slots__
= "loc_set_before_spread", "tied_input_index", "spread_index"
543 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
):
544 # type: (LocSet, int | None, int | None) -> None
545 if len(loc_set_before_spread
) == 0:
546 raise ValueError("loc_set_before_spread must not be empty")
547 self
.loc_set_before_spread
= loc_set_before_spread
548 self
.tied_input_index
= tied_input_index
549 if self
.tied_input_index
is not None and self
.spread_index
is not None:
550 raise ValueError("operand can't be both spread and tied")
551 self
.spread_index
= spread_index
554 def ty_before_spread(self
):
556 ty
= self
.loc_set_before_spread
.ty
557 assert ty
is not None, (
558 "__init__ checked that the LocSet isn't empty, "
559 "non-empty LocSets should always have ty set")
564 """ Ty after any spread is applied """
565 if self
.spread_index
is not None:
566 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
567 return self
.ty_before_spread
570 OD_BASE_SGPR
= GenericOperandDesc(
571 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
572 sub_kinds
=[LocSubKind
.BASE_GPR
])
573 OD_EXTRA3_SGPR
= GenericOperandDesc(
574 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
575 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
])
576 OD_EXTRA3_VGPR
= GenericOperandDesc(
577 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
578 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
])
579 OD_EXTRA2_SGPR
= GenericOperandDesc(
580 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
581 sub_kinds
=[LocSubKind
.SV_EXTRA2_SGPR
])
582 OD_EXTRA2_VGPR
= GenericOperandDesc(
583 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
584 sub_kinds
=[LocSubKind
.SV_EXTRA2_VGPR
])
585 OD_CA
= GenericOperandDesc(
586 ty
=GenericTy(base_ty
=BaseTy
.CA
, is_vec
=False),
587 sub_kinds
=[LocSubKind
.CA
])
588 OD_VL
= GenericOperandDesc(
589 ty
=GenericTy(base_ty
=BaseTy
.VL_MAXVL
, is_vec
=False),
590 sub_kinds
=[LocSubKind
.VL_MAXVL
])
593 @plain_data(frozen
=True, unsafe_hash
=True)
595 class GenericOpProperties
:
596 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
597 "is_copy", "is_load_immediate", "has_side_effects")
600 self
, demo_asm
, # type: str
601 inputs
, # type: Iterable[GenericOperandDesc]
602 outputs
, # type: Iterable[GenericOperandDesc]
603 immediates
=(), # type: Iterable[range]
604 is_copy
=False, # type: bool
605 is_load_immediate
=False, # type: bool
606 has_side_effects
=False, # type: bool
608 # type: (...) -> None
609 self
.demo_asm
= demo_asm
610 self
.inputs
= tuple(inputs
)
611 for inp
in self
.inputs
:
612 if inp
.tied_input_index
is not None:
614 f
"tied_input_index is not allowed on inputs: {inp}")
615 self
.outputs
= tuple(outputs
)
616 fixed_locs
= [] # type: list[tuple[Loc, int]]
617 for idx
, out
in enumerate(self
.outputs
):
618 if out
.tied_input_index
is not None \
619 and out
.tied_input_index
>= len(self
.inputs
):
620 raise ValueError(f
"tied_input_index out of range: {out}")
621 if out
.fixed_loc
is not None:
622 for other_fixed_loc
, other_idx
in fixed_locs
:
623 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
626 f
"conflicting fixed_locs: outputs[{idx}] and "
627 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
628 f
"with {other_fixed_loc}")
629 fixed_locs
.append((out
.fixed_loc
, idx
))
630 self
.immediates
= tuple(immediates
)
631 self
.is_copy
= is_copy
632 self
.is_load_immediate
= is_load_immediate
633 self
.has_side_effects
= has_side_effects
636 @plain_data(frozen
=True, unsafe_hash
=True)
639 __slots__
= "kind", "inputs", "outputs", "maxvl"
641 def __init__(self
, kind
, maxvl
):
642 # type: (OpKind, int) -> None
644 inputs
= [] # type: list[OperandDesc]
645 for inp
in self
.generic
.inputs
:
646 inputs
.extend(inp
.instantiate(maxvl
=maxvl
))
647 self
.inputs
= tuple(inputs
)
648 outputs
= [] # type: list[OperandDesc]
649 for out
in self
.generic
.outputs
:
650 outputs
.extend(out
.instantiate(maxvl
=maxvl
))
651 self
.outputs
= tuple(outputs
)
656 # type: () -> GenericOpProperties
657 return self
.kind
.properties
660 def immediates(self
):
661 # type: () -> tuple[range, ...]
662 return self
.generic
.immediates
667 return self
.generic
.demo_asm
672 return self
.generic
.is_copy
675 def is_load_immediate(self
):
677 return self
.generic
.is_load_immediate
680 def has_side_effects(self
):
682 return self
.generic
.has_side_effects
685 IMM_S16
= range(-1 << 15, 1 << 15)
687 _PRE_RA_SIM_FN
= Callable
[["Op", "PreRASimState"], None]
688 _PRE_RA_SIM_FN2
= Callable
[[], _PRE_RA_SIM_FN
]
689 _PRE_RA_SIMS
= {} # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
695 def __init__(self
, properties
):
696 # type: (GenericOpProperties) -> None
698 self
.__properties
= properties
701 def properties(self
):
702 # type: () -> GenericOpProperties
703 return self
.__properties
705 def instantiate(self
, maxvl
):
706 # type: (int) -> OpProperties
707 return OpProperties(self
, maxvl
=maxvl
)
710 return "OpKind." + self
._name
_
713 def pre_ra_sim(self
):
714 # type: () -> _PRE_RA_SIM_FN
715 return _PRE_RA_SIMS
[self
.properties
]()
718 def __clearca_pre_ra_sim(op
, state
):
719 # type: (Op, PreRASimState) -> None
720 state
.ssa_vals
[op
.outputs
[0]] = False,
721 ClearCA
= GenericOpProperties(
722 demo_asm
="addic 0, 0, 0",
726 _PRE_RA_SIMS
[ClearCA
] = lambda: OpKind
.__clearca
_pre
_ra
_sim
729 def __setca_pre_ra_sim(op
, state
):
730 # type: (Op, PreRASimState) -> None
731 state
.ssa_vals
[op
.outputs
[0]] = True,
732 SetCA
= GenericOpProperties(
733 demo_asm
="subfc 0, 0, 0",
737 _PRE_RA_SIMS
[SetCA
] = lambda: OpKind
.__setca
_pre
_ra
_sim
740 def __svadde_pre_ra_sim(op
, state
):
741 # type: (Op, PreRASimState) -> None
742 RA
= state
.ssa_vals
[op
.inputs
[0]]
743 RB
= state
.ssa_vals
[op
.inputs
[1]]
744 carry
, = state
.ssa_vals
[op
.inputs
[2]]
745 VL
, = state
.ssa_vals
[op
.inputs
[3]]
746 RT
= [] # type: list[int]
748 v
= RA
[i
] + RB
[i
] + carry
749 RT
.append(v
& GPR_VALUE_MASK
)
750 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
751 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
752 state
.ssa_vals
[op
.outputs
[1]] = carry
,
753 SvAddE
= GenericOpProperties(
754 demo_asm
="sv.adde *RT, *RA, *RB",
755 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
756 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
758 _PRE_RA_SIMS
[SvAddE
] = lambda: OpKind
.__svadde
_pre
_ra
_sim
761 def __svsubfe_pre_ra_sim(op
, state
):
762 # type: (Op, PreRASimState) -> None
763 RA
= state
.ssa_vals
[op
.inputs
[0]]
764 RB
= state
.ssa_vals
[op
.inputs
[1]]
765 carry
, = state
.ssa_vals
[op
.inputs
[2]]
766 VL
, = state
.ssa_vals
[op
.inputs
[3]]
767 RT
= [] # type: list[int]
769 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
770 RT
.append(v
& GPR_VALUE_MASK
)
771 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
772 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
773 state
.ssa_vals
[op
.outputs
[1]] = carry
,
774 SvSubFE
= GenericOpProperties(
775 demo_asm
="sv.subfe *RT, *RA, *RB",
776 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
777 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
779 _PRE_RA_SIMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_pre
_ra
_sim
782 def __svmaddedu_pre_ra_sim(op
, state
):
783 # type: (Op, PreRASimState) -> None
784 RA
= state
.ssa_vals
[op
.inputs
[0]]
785 RB
, = state
.ssa_vals
[op
.inputs
[1]]
786 carry
, = state
.ssa_vals
[op
.inputs
[2]]
787 VL
, = state
.ssa_vals
[op
.inputs
[3]]
788 RT
= [] # type: list[int]
790 v
= RA
[i
] * RB
+ carry
791 RT
.append(v
& GPR_VALUE_MASK
)
792 carry
= v
>> GPR_SIZE_IN_BITS
793 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
794 state
.ssa_vals
[op
.outputs
[1]] = carry
,
795 SvMAddEDU
= GenericOpProperties(
796 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
797 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
798 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
800 _PRE_RA_SIMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_pre
_ra
_sim
803 def __setvli_pre_ra_sim(op
, state
):
804 # type: (Op, PreRASimState) -> None
805 state
.ssa_vals
[op
.outputs
[0]] = op
.immediates
[0],
806 SetVLI
= GenericOpProperties(
807 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
810 immediates
=[range(1, 65)],
811 is_load_immediate
=True,
813 _PRE_RA_SIMS
[SetVLI
] = lambda: OpKind
.__setvli
_pre
_ra
_sim
816 def __svli_pre_ra_sim(op
, state
):
817 # type: (Op, PreRASimState) -> None
818 VL
, = state
.ssa_vals
[op
.inputs
[0]]
819 imm
= op
.immediates
[0] & GPR_VALUE_MASK
820 state
.ssa_vals
[op
.outputs
[0]] = (imm
,) * VL
821 SvLI
= GenericOpProperties(
822 demo_asm
="sv.addi *RT, 0, imm",
824 outputs
=[OD_EXTRA3_VGPR
],
825 immediates
=[IMM_S16
],
826 is_load_immediate
=True,
828 _PRE_RA_SIMS
[SvLI
] = lambda: OpKind
.__svli
_pre
_ra
_sim
831 def __li_pre_ra_sim(op
, state
):
832 # type: (Op, PreRASimState) -> None
833 imm
= op
.immediates
[0] & GPR_VALUE_MASK
834 state
.ssa_vals
[op
.outputs
[0]] = imm
,
835 LI
= GenericOpProperties(
836 demo_asm
="addi RT, 0, imm",
838 outputs
=[OD_BASE_SGPR
],
839 immediates
=[IMM_S16
],
840 is_load_immediate
=True,
842 _PRE_RA_SIMS
[LI
] = lambda: OpKind
.__li
_pre
_ra
_sim
845 def __veccopytoreg_pre_ra_sim(op
, state
):
846 # type: (Op, PreRASimState) -> None
847 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
848 VecCopyToReg
= GenericOpProperties(
849 demo_asm
="sv.mv dest, src",
850 inputs
=[GenericOperandDesc(
851 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
852 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
854 outputs
=[OD_EXTRA3_VGPR
],
857 _PRE_RA_SIMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_pre
_ra
_sim
860 def __veccopyfromreg_pre_ra_sim(op
, state
):
861 # type: (Op, PreRASimState) -> None
862 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
863 VecCopyFromReg
= GenericOpProperties(
864 demo_asm
="sv.mv dest, src",
865 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
866 outputs
=[GenericOperandDesc(
867 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
868 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
872 _PRE_RA_SIMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_pre
_ra
_sim
875 def __copytoreg_pre_ra_sim(op
, state
):
876 # type: (Op, PreRASimState) -> None
877 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
878 CopyToReg
= GenericOpProperties(
879 demo_asm
="mv dest, src",
880 inputs
=[GenericOperandDesc(
881 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
882 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
883 LocSubKind
.StackI64
],
885 outputs
=[GenericOperandDesc(
886 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
887 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
891 _PRE_RA_SIMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_pre
_ra
_sim
894 def __copyfromreg_pre_ra_sim(op
, state
):
895 # type: (Op, PreRASimState) -> None
896 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
897 CopyFromReg
= GenericOpProperties(
898 demo_asm
="mv dest, src",
899 inputs
=[GenericOperandDesc(
900 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
901 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
903 outputs
=[GenericOperandDesc(
904 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
905 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
906 LocSubKind
.StackI64
],
910 _PRE_RA_SIMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_pre
_ra
_sim
913 def __concat_pre_ra_sim(op
, state
):
914 # type: (Op, PreRASimState) -> None
915 state
.ssa_vals
[op
.outputs
[0]] = tuple(
916 state
.ssa_vals
[i
][0] for i
in op
.inputs
[:-1])
917 Concat
= GenericOpProperties(
918 demo_asm
="sv.mv dest, src",
919 inputs
=[GenericOperandDesc(
920 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
921 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
924 outputs
=[OD_EXTRA3_VGPR
],
927 _PRE_RA_SIMS
[Concat
] = lambda: OpKind
.__concat
_pre
_ra
_sim
930 def __spread_pre_ra_sim(op
, state
):
931 # type: (Op, PreRASimState) -> None
932 for idx
, inp
in enumerate(state
.ssa_vals
[op
.inputs
[0]]):
933 state
.ssa_vals
[op
.outputs
[idx
]] = inp
,
934 Spread
= GenericOpProperties(
935 demo_asm
="sv.mv dest, src",
936 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
937 outputs
=[GenericOperandDesc(
938 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
939 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
944 _PRE_RA_SIMS
[Spread
] = lambda: OpKind
.__spread
_pre
_ra
_sim
947 def __svld_pre_ra_sim(op
, state
):
948 # type: (Op, PreRASimState) -> None
949 RA
, = state
.ssa_vals
[op
.inputs
[0]]
950 VL
, = state
.ssa_vals
[op
.inputs
[1]]
951 addr
= RA
+ op
.immediates
[0]
952 RT
= [] # type: list[int]
954 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
955 RT
.append(v
& GPR_VALUE_MASK
)
956 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
957 SvLd
= GenericOpProperties(
958 demo_asm
="sv.ld *RT, imm(RA)",
959 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
960 outputs
=[OD_EXTRA3_VGPR
],
961 immediates
=[IMM_S16
],
963 _PRE_RA_SIMS
[SvLd
] = lambda: OpKind
.__svld
_pre
_ra
_sim
966 def __ld_pre_ra_sim(op
, state
):
967 # type: (Op, PreRASimState) -> None
968 RA
, = state
.ssa_vals
[op
.inputs
[0]]
969 addr
= RA
+ op
.immediates
[0]
971 state
.ssa_vals
[op
.outputs
[0]] = v
& GPR_VALUE_MASK
,
972 Ld
= GenericOpProperties(
973 demo_asm
="ld RT, imm(RA)",
974 inputs
=[OD_BASE_SGPR
],
975 outputs
=[OD_BASE_SGPR
],
976 immediates
=[IMM_S16
],
978 _PRE_RA_SIMS
[Ld
] = lambda: OpKind
.__ld
_pre
_ra
_sim
981 def __svstd_pre_ra_sim(op
, state
):
982 # type: (Op, PreRASimState) -> None
983 RS
= state
.ssa_vals
[op
.inputs
[0]]
984 RA
, = state
.ssa_vals
[op
.inputs
[1]]
985 VL
, = state
.ssa_vals
[op
.inputs
[2]]
986 addr
= RA
+ op
.immediates
[0]
988 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
989 SvStd
= GenericOpProperties(
990 demo_asm
="sv.std *RS, imm(RA)",
991 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
993 immediates
=[IMM_S16
],
994 has_side_effects
=True,
996 _PRE_RA_SIMS
[SvStd
] = lambda: OpKind
.__svstd
_pre
_ra
_sim
999 def __std_pre_ra_sim(op
, state
):
1000 # type: (Op, PreRASimState) -> None
1001 RS
, = state
.ssa_vals
[op
.inputs
[0]]
1002 RA
, = state
.ssa_vals
[op
.inputs
[1]]
1003 addr
= RA
+ op
.immediates
[0]
1004 state
.store(addr
, value
=RS
)
1005 Std
= GenericOpProperties(
1006 demo_asm
="std RT, imm(RA)",
1007 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1009 immediates
=[IMM_S16
],
1010 has_side_effects
=True,
1012 _PRE_RA_SIMS
[Std
] = lambda: OpKind
.__std
_pre
_ra
_sim
1015 def __funcargr3_pre_ra_sim(op
, state
):
1016 # type: (Op, PreRASimState) -> None
1017 pass # return value set before simulation
1018 FuncArgR3
= GenericOpProperties(
1021 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1022 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1024 _PRE_RA_SIMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_pre
_ra
_sim
1027 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1030 __slots__
= "op", "output_idx"
1032 def __init__(self
, op
, output_idx
):
1033 # type: (Op, int) -> None
1035 if output_idx
< 0 or output_idx
>= len(op
.properties
.outputs
):
1036 raise ValueError("invalid output_idx")
1037 self
.output_idx
= output_idx
1041 return f
"<{self.op.name}#{self.output_idx}: {self.ty}>"
1044 def defining_descriptor(self
):
1045 # type: () -> OperandDesc
1046 return self
.op
.properties
.outputs
[self
.output_idx
]
1049 def loc_set_before_spread(self
):
1050 # type: () -> LocSet
1051 return self
.defining_descriptor
.loc_set_before_spread
1056 return self
.defining_descriptor
.ty
1059 def ty_before_spread(self
):
1061 return self
.defining_descriptor
.ty_before_spread
1065 _Desc
= TypeVar("_Desc")
1068 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1070 def _verify_write_with_desc(self
, idx
, item
, desc
):
1071 # type: (int, _T | Any, _Desc) -> None
1072 raise NotImplementedError
1075 def _verify_write(self
, idx
, item
):
1076 # type: (int | Any, _T | Any) -> int
1077 if not isinstance(idx
, int):
1078 if isinstance(idx
, slice):
1080 f
"can't write to slice of {self.__class__.__name__}")
1081 raise TypeError(f
"can't write with index {idx!r}")
1082 # normalize idx, raising IndexError if it is out of range
1083 idx
= range(len(self
.descriptors
))[idx
]
1084 desc
= self
.descriptors
[idx
]
1085 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1089 def _get_descriptors(self
):
1090 # type: () -> tuple[_Desc, ...]
1091 raise NotImplementedError
1095 def descriptors(self
):
1096 # type: () -> tuple[_Desc, ...]
1097 return self
._get
_descriptors
()
1104 def __init__(self
, items
, op
):
1105 # type: (Iterable[_T], Op) -> None
1107 self
.__items
= [] # type: list[_T]
1108 for idx
, item
in enumerate(items
):
1109 if idx
>= len(self
.descriptors
):
1110 raise ValueError("too many items")
1111 self
._verify
_write
(idx
, item
)
1112 self
.__items
.append(item
)
1113 if len(self
.__items
) < len(self
.descriptors
):
1114 raise ValueError("not enough items")
1118 # type: () -> Iterator[_T]
1119 yield from self
.__items
1122 def __getitem__(self
, idx
):
1127 def __getitem__(self
, idx
):
1128 # type: (slice) -> list[_T]
1132 def __getitem__(self
, idx
):
1133 # type: (int | slice) -> _T | list[_T]
1134 return self
.__items
[idx
]
1137 def __setitem__(self
, idx
, item
):
1138 # type: (int, _T) -> None
1139 idx
= self
._verify
_write
(idx
, item
)
1140 self
.__items
[idx
] = item
1145 return len(self
.__items
)
1148 return f
"{self.__class__.__name__}({self.__items}, op=...)"
1152 class OpInputs(OpInputSeq
[SSAVal
, OperandDesc
]):
1153 def _get_descriptors(self
):
1154 # type: () -> tuple[OperandDesc, ...]
1155 return self
.op
.properties
.inputs
1157 def _verify_write_with_desc(self
, idx
, item
, desc
):
1158 # type: (int, SSAVal | Any, OperandDesc) -> None
1159 if not isinstance(item
, SSAVal
):
1160 raise TypeError("expected value of type SSAVal")
1161 if item
.ty
!= desc
.ty
:
1162 raise ValueError(f
"assigned item's type {item.ty!r} doesn't match "
1163 f
"corresponding input's type {desc.ty!r}")
1165 def __init__(self
, items
, op
):
1166 # type: (Iterable[SSAVal], Op) -> None
1167 if hasattr(op
, "inputs"):
1168 raise ValueError("Op.inputs already set")
1169 super().__init
__(items
, op
)
1173 class OpImmediates(OpInputSeq
[int, range]):
1174 def _get_descriptors(self
):
1175 # type: () -> tuple[range, ...]
1176 return self
.op
.properties
.immediates
1178 def _verify_write_with_desc(self
, idx
, item
, desc
):
1179 # type: (int, int | Any, range) -> None
1180 if not isinstance(item
, int):
1181 raise TypeError("expected value of type int")
1182 if item
not in desc
:
1183 raise ValueError(f
"immediate value {item!r} not in {desc!r}")
1185 def __init__(self
, items
, op
):
1186 # type: (Iterable[int], Op) -> None
1187 if hasattr(op
, "immediates"):
1188 raise ValueError("Op.immediates already set")
1189 super().__init
__(items
, op
)
1192 @plain_data(frozen
=True, eq
=False, repr=False)
1195 __slots__
= "fn", "properties", "inputs", "immediates", "outputs", "name"
1197 def __init__(self
, fn
, properties
, inputs
, immediates
, name
=""):
1198 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1200 self
.properties
= properties
1201 self
.inputs
= OpInputs(inputs
, op
=self
)
1202 self
.immediates
= OpImmediates(immediates
, op
=self
)
1203 outputs_len
= len(self
.properties
.outputs
)
1204 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1205 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
1209 return self
.properties
.kind
1211 def __eq__(self
, other
):
1212 # type: (Op | Any) -> bool
1213 if isinstance(other
, Op
):
1214 return self
is other
1215 return NotImplemented
1218 return object.__hash
__(self
)
1222 field_vals
= [] # type: list[str]
1223 for name
in fields(self
):
1224 if name
== "properties":
1229 value
= getattr(self
, name
)
1230 except AttributeError:
1231 field_vals
.append(f
"{name}=<not set>")
1233 if isinstance(value
, OpInputSeq
):
1234 value
= list(value
) # type: ignore
1235 field_vals
.append(f
"{name}={value!r}")
1236 field_vals_str
= ", ".join(field_vals
)
1237 return f
"Op({field_vals_str})"
1239 def pre_ra_sim(self
, state
):
1240 # type: (PreRASimState) -> None
1241 for inp
in self
.inputs
:
1242 if inp
not in state
.ssa_vals
:
1243 raise ValueError(f
"SSAVal {inp} not yet assigned when "
1245 if len(state
.ssa_vals
[inp
]) != inp
.ty
.reg_len
:
1247 f
"value of SSAVal {inp} has wrong number of elements: "
1248 f
"expected {inp.ty.reg_len} found "
1249 f
"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
1250 for out
in self
.outputs
:
1251 if out
in state
.ssa_vals
:
1252 if self
.kind
is OpKind
.FuncArgR3
:
1254 raise ValueError(f
"SSAVal {out} already assigned before "
1256 self
.kind
.pre_ra_sim(self
, state
)
1257 for out
in self
.outputs
:
1258 if out
not in state
.ssa_vals
:
1259 raise ValueError(f
"running {self} failed to assign to {out}")
1260 if len(state
.ssa_vals
[out
]) != out
.ty
.reg_len
:
1262 f
"value of SSAVal {out} has wrong number of elements: "
1263 f
"expected {out.ty.reg_len} found "
1264 f
"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1267 GPR_SIZE_IN_BYTES
= 8
1269 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
1270 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
1273 @plain_data(frozen
=True, repr=False)
1275 class PreRASimState
:
1276 __slots__
= "ssa_vals", "memory"
1278 def __init__(self
, ssa_vals
, memory
):
1279 # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
1280 self
.ssa_vals
= ssa_vals
1281 self
.memory
= memory
1283 def load_byte(self
, addr
):
1284 # type: (int) -> int
1285 addr
&= GPR_VALUE_MASK
1286 return self
.memory
.get(addr
, 0) & 0xFF
1288 def store_byte(self
, addr
, value
):
1289 # type: (int, int) -> None
1290 addr
&= GPR_VALUE_MASK
1292 self
.memory
[addr
] = value
1294 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
1295 # type: (int, int, bool) -> int
1296 if addr
% size_in_bytes
!= 0:
1297 raise ValueError(f
"address not aligned: {hex(addr)} "
1298 f
"required alignment: {size_in_bytes}")
1300 for i
in range(size_in_bytes
):
1301 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
1302 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
1303 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
1306 def store(self
, addr
, value
, size_in_bytes
=GPR_SIZE_IN_BYTES
):
1307 # type: (int, int, int) -> None
1308 if addr
% size_in_bytes
!= 0:
1309 raise ValueError(f
"address not aligned: {hex(addr)} "
1310 f
"required alignment: {size_in_bytes}")
1311 for i
in range(size_in_bytes
):
1312 self
.store_byte(addr
+ i
, (value
>> i
* BITS_IN_BYTE
) & 0xFF)
1314 def _memory__repr(self
):
1316 if len(self
.memory
) == 0:
1318 keys
= sorted(self
.memory
.keys(), reverse
=True)
1319 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
1320 items
= [] # type: list[str]
1321 while len(keys
) != 0:
1323 if (len(keys
) >= CHUNK_SIZE
1324 and addr
% CHUNK_SIZE
== 0
1325 and keys
[-CHUNK_SIZE
:]
1326 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
1327 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
1328 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
1329 keys
[-CHUNK_SIZE
:] = ()
1331 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
1333 return f
"{{{items[0]}}}"
1334 items_str
= ",\n".join(items
)
1335 return f
"{{\n{items_str}}}"
1337 def _ssa_vals__repr(self
):
1339 if len(self
.ssa_vals
) == 0:
1341 items
= [] # type: list[str]
1343 for k
, v
in self
.ssa_vals
.items():
1344 element_strs
= [] # type: list[str]
1345 for i
, el
in enumerate(v
):
1346 if i
% CHUNK_SIZE
!= 0:
1347 element_strs
.append(" " + hex(el
))
1349 element_strs
.append("\n " + hex(el
))
1350 if len(element_strs
) <= CHUNK_SIZE
:
1351 element_strs
[0] = element_strs
[0].lstrip()
1352 if len(element_strs
) == 1:
1353 element_strs
.append("")
1354 v_str
= ",".join(element_strs
)
1355 items
.append(f
"{k!r}: ({v_str})")
1356 if len(items
) == 1 and "\n" not in items
[0]:
1357 return f
"{{{items[0]}}}"
1358 items_str
= ",\n".join(items
)
1359 return f
"{{\n{items_str},\n}}"
1363 field_vals
= [] # type: list[str]
1364 for name
in fields(self
):
1366 value
= getattr(self
, name
)
1367 except AttributeError:
1368 field_vals
.append(f
"{name}=<not set>")
1370 repr_fn
= getattr(self
, f
"_{name}__repr", None)
1371 if callable(repr_fn
):
1372 field_vals
.append(f
"{name}={repr_fn()}")
1374 field_vals
.append(f
"{name}={value!r}")
1375 field_vals_str
= ", ".join(field_vals
)
1376 return f
"PreRASimState({field_vals_str})"