00e62219c100fa8a26e542f9a72ca8fb1587a850
2 from abc
import abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
5 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
6 Sequence
, TypeVar
, overload
)
7 from weakref
import WeakValueDictionary
as _WeakVDict
9 from cached_property
import cached_property
10 from nmutil
.plain_data
import fields
, plain_data
12 from bigint_presentation_code
.type_util
import Self
, assert_never
, final
13 from bigint_presentation_code
.util
import BitSet
, FBitSet
, FMap
, OFSet
19 self
.ops
= [] # type: list[Op]
20 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
21 self
.__next
_name
_suffix
= 2
23 def _add_op_with_unused_name(self
, op
, name
=""):
24 # type: (Op, str) -> str
26 raise ValueError("can't add Op to wrong Fn")
27 if hasattr(op
, "name"):
28 raise ValueError("Op already named")
31 if name
!= "" and name
not in self
.__op
_names
:
32 self
.__op
_names
[name
] = op
34 name
= orig_name
+ str(self
.__next
_name
_suffix
)
35 self
.__next
_name
_suffix
+= 1
41 def append_op(self
, op
):
44 raise ValueError("can't add Op to wrong Fn")
47 def append_new_op(self
, kind
, inputs
=(), immediates
=(), name
="", maxvl
=1):
48 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
49 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
50 inputs
=inputs
, immediates
=immediates
, name
=name
)
51 self
.append_op(retval
)
54 def pre_ra_sim(self
, state
):
55 # type: (PreRASimState) -> None
65 VL_MAXVL
= enum
.auto()
68 def only_scalar(self
):
70 if self
is BaseTy
.I64
:
72 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
78 def max_reg_len(self
):
80 if self
is BaseTy
.I64
:
82 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
88 return "BaseTy." + self
._name
_
91 @plain_data(frozen
=True, unsafe_hash
=True)
94 __slots__
= "base_ty", "reg_len"
97 def validate(base_ty
, reg_len
):
98 # type: (BaseTy, int) -> str | None
99 """ return a string with the error if the combination is invalid,
100 otherwise return None
102 if base_ty
.only_scalar
and reg_len
!= 1:
103 return f
"can't create a vector of an only-scalar type: {base_ty}"
104 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
105 return "reg_len out of range"
108 def __init__(self
, base_ty
, reg_len
):
109 # type: (BaseTy, int) -> None
110 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
112 raise ValueError(msg
)
113 self
.base_ty
= base_ty
114 self
.reg_len
= reg_len
121 StackI64
= enum
.auto()
123 VL_MAXVL
= enum
.auto()
128 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
130 if self
is LocKind
.CA
:
132 if self
is LocKind
.VL_MAXVL
:
133 return BaseTy
.VL_MAXVL
140 if self
is LocKind
.StackI64
:
142 if self
is LocKind
.GPR
or self
is LocKind
.CA \
143 or self
is LocKind
.VL_MAXVL
:
144 return self
.base_ty
.max_reg_len
149 return "LocKind." + self
._name
_
154 class LocSubKind(Enum
):
155 BASE_GPR
= enum
.auto()
156 SV_EXTRA2_VGPR
= enum
.auto()
157 SV_EXTRA2_SGPR
= enum
.auto()
158 SV_EXTRA3_VGPR
= enum
.auto()
159 SV_EXTRA3_SGPR
= enum
.auto()
160 StackI64
= enum
.auto()
162 VL_MAXVL
= enum
.auto()
166 # type: () -> LocKind
167 # pyright fails typechecking when using `in` here:
168 # reported: https://github.com/microsoft/pyright/issues/4102
169 if self
is LocSubKind
.BASE_GPR
or self
is LocSubKind
.SV_EXTRA2_VGPR \
170 or self
is LocSubKind
.SV_EXTRA2_SGPR \
171 or self
is LocSubKind
.SV_EXTRA3_VGPR \
172 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
174 if self
is LocSubKind
.StackI64
:
175 return LocKind
.StackI64
176 if self
is LocSubKind
.CA
:
178 if self
is LocSubKind
.VL_MAXVL
:
179 return LocKind
.VL_MAXVL
184 return self
.kind
.base_ty
187 def allocatable_locs(self
, ty
):
188 # type: (Ty) -> LocSet
189 if ty
.base_ty
!= self
.base_ty
:
190 raise ValueError("type mismatch")
191 if self
is LocSubKind
.BASE_GPR
:
193 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
194 starts
= range(0, 128, 2)
195 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
197 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
198 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
200 elif self
is LocSubKind
.StackI64
:
201 starts
= range(LocKind
.StackI64
.loc_count
)
202 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
203 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
206 retval
= [] # type: list[Loc]
208 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
212 for special_loc
in SPECIAL_GPRS
:
213 if loc
.conflicts(special_loc
):
218 return LocSet(retval
)
221 return "LocSubKind." + self
._name
_
224 @plain_data(frozen
=True, unsafe_hash
=True)
227 __slots__
= "base_ty", "is_vec"
229 def __init__(self
, base_ty
, is_vec
):
230 # type: (BaseTy, bool) -> None
231 self
.base_ty
= base_ty
232 if base_ty
.only_scalar
and is_vec
:
233 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
236 def instantiate(self
, maxvl
):
238 # here's where subvl and elwid would be accounted for
240 return Ty(self
.base_ty
, maxvl
)
241 return Ty(self
.base_ty
, 1)
243 def can_instantiate_to(self
, ty
):
245 if self
.base_ty
!= ty
.base_ty
:
249 return ty
.reg_len
== 1
252 @plain_data(frozen
=True, unsafe_hash
=True)
255 __slots__
= "kind", "start", "reg_len"
258 def validate(kind
, start
, reg_len
):
259 # type: (LocKind, int, int) -> str | None
260 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
263 if reg_len
> kind
.loc_count
:
264 return "invalid reg_len"
265 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
266 return "start not in valid range"
270 def try_make(kind
, start
, reg_len
):
271 # type: (LocKind, int, int) -> Loc | None
272 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
275 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
277 def __init__(self
, kind
, start
, reg_len
):
278 # type: (LocKind, int, int) -> None
279 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
281 raise ValueError(msg
)
283 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Return whether `self` and `other` have equal `kind` and their
    ranges from `start` up to (but not including) `stop` overlap."""
    if not (self.kind == other.kind):
        return False
    starts_before_other_ends = self.start < other.stop
    return starts_before_other_ends and other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """Build the `Ty` whose base type is `kind.base_ty` and whose
    register length is `reg_len`."""
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
299 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
304 return self
.start
+ self
.reg_len
306 def try_concat(self
, *others
):
307 # type: (*Loc | None) -> Loc | None
308 reg_len
= self
.reg_len
311 if other
is None or other
.kind
!= self
.kind
:
313 if stop
!= other
.start
:
316 reg_len
+= other
.reg_len
317 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
321 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
322 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
323 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
324 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
328 @plain_data(frozen
=True, eq
=False)
330 class LocSet(AbstractSet
[Loc
]):
331 __slots__
= "starts", "ty"
333 def __init__(self
, __locs
=()):
334 # type: (Iterable[Loc]) -> None
335 if isinstance(__locs
, LocSet
):
336 self
.starts
= __locs
.starts
# type: FMap[LocKind, FBitSet]
337 self
.ty
= __locs
.ty
# type: Ty | None
339 starts
= {i
: BitSet() for i
in LocKind
}
345 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
346 starts
[loc
.kind
].add(loc
.start
)
348 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
353 # type: () -> FMap[LocKind, FBitSet]
358 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
362 # type: () -> AbstractSet[LocKind]
363 return self
.starts
.keys()
367 # type: () -> int | None
370 return self
.ty
.reg_len
374 # type: () -> BaseTy | None
377 return self
.ty
.base_ty
379 def concat(self
, *others
):
380 # type: (*LocSet) -> LocSet
383 base_ty
= self
.ty
.base_ty
384 reg_len
= self
.ty
.reg_len
385 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
389 if other
.ty
.base_ty
!= base_ty
:
391 for kind
, other_starts
in other
.starts
.items():
392 if kind
not in starts
:
394 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
395 if starts
[kind
] == 0:
399 reg_len
+= other
.ty
.reg_len
402 # type: () -> Iterable[Loc]
403 for kind
, v
in starts
.items():
405 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
408 return LocSet(locs())
410 def __contains__(self
, loc
):
411 # type: (Loc | Any) -> bool
412 if not isinstance(loc
, Loc
) or loc
.ty
!= self
.ty
:
414 if loc
.kind
not in self
.starts
:
416 return loc
.start
in self
.starts
[loc
.kind
]
419 # type: () -> Iterator[Loc]
422 for kind
, starts
in self
.starts
.items():
424 yield Loc(kind
=kind
, start
=start
, reg_len
=self
.ty
.reg_len
)
428 return sum((len(v
) for v
in self
.starts
.values()), 0)
435 return super()._hash
()
441 @plain_data(frozen
=True, unsafe_hash
=True)
443 class GenericOperandDesc
:
444 """generic Op operand descriptor"""
445 __slots__
= "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"
448 self
, ty
, # type: GenericTy
449 sub_kinds
, # type: Iterable[LocSubKind]
451 fixed_loc
=None, # type: Loc | None
452 tied_input_index
=None, # type: int | None
453 spread
=False, # type: bool
455 # type: (...) -> None
457 self
.sub_kinds
= OFSet(sub_kinds
)
458 if len(self
.sub_kinds
) == 0:
459 raise ValueError("sub_kinds can't be empty")
460 self
.fixed_loc
= fixed_loc
461 if fixed_loc
is not None:
462 if tied_input_index
is not None:
463 raise ValueError("operand can't be both tied and fixed")
464 if not ty
.can_instantiate_to(fixed_loc
.ty
):
466 f
"fixed_loc has incompatible type for given generic "
467 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
468 if len(self
.sub_kinds
) != 1:
470 "multiple sub_kinds not allowed for fixed operand")
471 for sub_kind
in self
.sub_kinds
:
472 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
474 f
"fixed_loc not in given sub_kind: "
475 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
476 for sub_kind
in self
.sub_kinds
:
477 if sub_kind
.base_ty
!= ty
.base_ty
:
478 raise ValueError(f
"sub_kind is incompatible with type: "
479 f
"sub_kind={sub_kind} ty={ty}")
480 if tied_input_index
is not None and tied_input_index
< 0:
481 raise ValueError("invalid tied_input_index")
482 self
.tied_input_index
= tied_input_index
485 if self
.tied_input_index
is not None:
486 raise ValueError("operand can't be both spread and tied")
487 if self
.fixed_loc
is not None:
488 raise ValueError("operand can't be both spread and fixed")
490 raise ValueError("operand can't be both spread and vector")
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Create a new GenericOperandDesc from this one's `ty` and
    `sub_kinds`, with `tied_input_index` set to the given index."""
    ty = self.ty
    sub_kinds = self.sub_kinds
    return GenericOperandDesc(
        ty, sub_kinds, tied_input_index=tied_input_index)
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Create a new GenericOperandDesc from this one's `ty` and
    `sub_kinds`, with `fixed_loc` set to the given location."""
    ty = self.ty
    sub_kinds = self.sub_kinds
    return GenericOperandDesc(ty, sub_kinds, fixed_loc=fixed_loc)
501 def instantiate(self
, maxvl
):
502 # type: (int) -> Iterable[OperandDesc]
507 ty
= self
.ty
.instantiate(maxvl
=maxvl
)
510 # type: () -> Iterable[Loc]
511 if self
.fixed_loc
is not None:
512 if ty
!= self
.fixed_loc
.ty
:
514 f
"instantiation failed: type mismatch with fixed_loc: "
515 f
"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
518 for sub_kind
in self
.sub_kinds
:
519 yield from sub_kind
.allocatable_locs(ty
)
520 loc_set_before_spread
= LocSet(locs())
521 for idx
in range(rep_count
):
524 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
525 tied_input_index
=self
.tied_input_index
,
529 @plain_data(frozen
=True, unsafe_hash
=True)
532 """Op operand descriptor"""
533 __slots__
= "loc_set_before_spread", "tied_input_index", "spread_index"
def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
    # type: (LocSet, int | None, int | None) -> None
    """Store the operand's location set, tied-input index and spread index.

    Raises ValueError if `loc_set_before_spread` is empty, or if both
    `tied_input_index` and `spread_index` are given (not None).
    """
    if len(loc_set_before_spread) == 0:
        raise ValueError("loc_set_before_spread must not be empty")
    # Check the arguments, not the attributes: `spread_index` is a slot
    # that has not been assigned yet at this point, so reading
    # `self.spread_index` here would raise AttributeError instead of
    # performing the intended validation.
    if tied_input_index is not None and spread_index is not None:
        raise ValueError("operand can't be both spread and tied")
    self.loc_set_before_spread = loc_set_before_spread
    self.tied_input_index = tied_input_index
    self.spread_index = spread_index
546 def ty_before_spread(self
):
548 ty
= self
.loc_set_before_spread
.ty
549 assert ty
is not None, (
550 "__init__ checked that the LocSet isn't empty, "
551 "non-empty LocSets should always have ty set")
556 """ Ty after any spread is applied """
557 if self
.spread_index
is not None:
558 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
559 return self
.ty_before_spread
# Pre-built generic operand descriptors for I64 GPR, CA and VL_MAXVL
# operands. The *_SGPR, CA and VL descriptors use is_vec=False; the
# *_VGPR descriptors use is_vec=True.
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.BASE_GPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
    ty=GenericTy(BaseTy.I64, is_vec=True))
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
    ty=GenericTy(BaseTy.I64, is_vec=True))
OD_CA = GenericOperandDesc(
    sub_kinds=[LocSubKind.CA],
    ty=GenericTy(BaseTy.CA, is_vec=False))
OD_VL = GenericOperandDesc(
    sub_kinds=[LocSubKind.VL_MAXVL],
    ty=GenericTy(BaseTy.VL_MAXVL, is_vec=False))
585 @plain_data(frozen
=True, unsafe_hash
=True)
587 class GenericOpProperties
:
588 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
589 "is_copy", "is_load_immediate", "has_side_effects")
592 self
, demo_asm
, # type: str
593 inputs
, # type: Iterable[GenericOperandDesc]
594 outputs
, # type: Iterable[GenericOperandDesc]
595 immediates
=(), # type: Iterable[range]
596 is_copy
=False, # type: bool
597 is_load_immediate
=False, # type: bool
598 has_side_effects
=False, # type: bool
600 # type: (...) -> None
601 self
.demo_asm
= demo_asm
602 self
.inputs
= tuple(inputs
)
603 for inp
in self
.inputs
:
604 if inp
.tied_input_index
is not None:
606 f
"tied_input_index is not allowed on inputs: {inp}")
607 self
.outputs
= tuple(outputs
)
608 fixed_locs
= [] # type: list[tuple[Loc, int]]
609 for idx
, out
in enumerate(self
.outputs
):
610 if out
.tied_input_index
is not None \
611 and out
.tied_input_index
>= len(self
.inputs
):
612 raise ValueError(f
"tied_input_index out of range: {out}")
613 if out
.fixed_loc
is not None:
614 for other_fixed_loc
, other_idx
in fixed_locs
:
615 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
618 f
"conflicting fixed_locs: outputs[{idx}] and "
619 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
620 f
"with {other_fixed_loc}")
621 fixed_locs
.append((out
.fixed_loc
, idx
))
622 self
.immediates
= tuple(immediates
)
623 self
.is_copy
= is_copy
624 self
.is_load_immediate
= is_load_immediate
625 self
.has_side_effects
= has_side_effects
628 @plain_data(frozen
=True, unsafe_hash
=True)
631 __slots__
= "kind", "inputs", "outputs", "maxvl"
633 def __init__(self
, kind
, maxvl
):
634 # type: (OpKind, int) -> None
636 inputs
= [] # type: list[OperandDesc]
637 for inp
in self
.generic
.inputs
:
638 inputs
.extend(inp
.instantiate(maxvl
=maxvl
))
639 self
.inputs
= tuple(inputs
)
640 outputs
= [] # type: list[OperandDesc]
641 for out
in self
.generic
.outputs
:
642 outputs
.extend(out
.instantiate(maxvl
=maxvl
))
643 self
.outputs
= tuple(outputs
)
648 # type: () -> GenericOpProperties
649 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """Return the `immediates` of the generic properties."""
    generic = self.generic
    return generic.immediates
659 return self
.generic
.demo_asm
664 return self
.generic
.is_copy
667 def is_load_immediate(self
):
669 return self
.generic
.is_load_immediate
672 def has_side_effects(self
):
674 return self
.generic
.has_side_effects
# all values of a signed 16-bit immediate: -0x8000 through 0x7FFF
IMM_S16 = range(-0x8000, 0x8000)

# a simulation function: called with the Op and the simulation state
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# a zero-argument callable that returns a simulation function
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# maps an Op's generic properties to the callable yielding its simulation
# function
_PRE_RA_SIMS = dict()  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
687 def __init__(self
, properties
):
688 # type: (GenericOpProperties) -> None
690 self
.__properties
= properties
def properties(self):
    # type: () -> GenericOpProperties
    """Return the GenericOpProperties stored by `__init__`."""
    stored = self.__properties
    return stored
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """Create the OpProperties for this kind at the given `maxvl`."""
    instantiated = OpProperties(self, maxvl=maxvl)
    return instantiated
702 return "OpKind." + self
._name
_
def pre_ra_sim(self):
    # type: () -> _PRE_RA_SIM_FN
    """Look up the entry registered in `_PRE_RA_SIMS` for this kind's
    properties and call it to obtain the simulation function."""
    get_sim_fn = _PRE_RA_SIMS[self.properties]
    return get_sim_fn()
710 def __clearca_pre_ra_sim(op
, state
):
711 # type: (Op, PreRASimState) -> None
712 state
.ssa_vals
[op
.outputs
[0]] = False,
713 ClearCA
= GenericOpProperties(
714 demo_asm
="addic 0, 0, 0",
718 _PRE_RA_SIMS
[ClearCA
] = lambda: OpKind
.__clearca
_pre
_ra
_sim
721 def __setca_pre_ra_sim(op
, state
):
722 # type: (Op, PreRASimState) -> None
723 state
.ssa_vals
[op
.outputs
[0]] = True,
724 SetCA
= GenericOpProperties(
725 demo_asm
="subfc 0, 0, 0",
729 _PRE_RA_SIMS
[SetCA
] = lambda: OpKind
.__setca
_pre
_ra
_sim
732 def __svadde_pre_ra_sim(op
, state
):
733 # type: (Op, PreRASimState) -> None
734 RA
= state
.ssa_vals
[op
.inputs
[0]]
735 RB
= state
.ssa_vals
[op
.inputs
[1]]
736 carry
, = state
.ssa_vals
[op
.inputs
[2]]
737 VL
, = state
.ssa_vals
[op
.inputs
[3]]
738 RT
= [] # type: list[int]
740 v
= RA
[i
] + RB
[i
] + carry
741 RT
.append(v
& GPR_VALUE_MASK
)
742 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
743 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
744 state
.ssa_vals
[op
.outputs
[1]] = carry
,
745 SvAddE
= GenericOpProperties(
746 demo_asm
="sv.adde *RT, *RA, *RB",
747 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
748 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
750 _PRE_RA_SIMS
[SvAddE
] = lambda: OpKind
.__svadde
_pre
_ra
_sim
753 def __svsubfe_pre_ra_sim(op
, state
):
754 # type: (Op, PreRASimState) -> None
755 RA
= state
.ssa_vals
[op
.inputs
[0]]
756 RB
= state
.ssa_vals
[op
.inputs
[1]]
757 carry
, = state
.ssa_vals
[op
.inputs
[2]]
758 VL
, = state
.ssa_vals
[op
.inputs
[3]]
759 RT
= [] # type: list[int]
761 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
762 RT
.append(v
& GPR_VALUE_MASK
)
763 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
764 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
765 state
.ssa_vals
[op
.outputs
[1]] = carry
,
766 SvSubFE
= GenericOpProperties(
767 demo_asm
="sv.subfe *RT, *RA, *RB",
768 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
769 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
771 _PRE_RA_SIMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_pre
_ra
_sim
774 def __svmaddedu_pre_ra_sim(op
, state
):
775 # type: (Op, PreRASimState) -> None
776 RA
= state
.ssa_vals
[op
.inputs
[0]]
777 RB
, = state
.ssa_vals
[op
.inputs
[1]]
778 carry
, = state
.ssa_vals
[op
.inputs
[2]]
779 VL
, = state
.ssa_vals
[op
.inputs
[3]]
780 RT
= [] # type: list[int]
782 v
= RA
[i
] * RB
+ carry
783 RT
.append(v
& GPR_VALUE_MASK
)
784 carry
= v
>> GPR_SIZE_IN_BITS
785 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
786 state
.ssa_vals
[op
.outputs
[1]] = carry
,
787 SvMAddEDU
= GenericOpProperties(
788 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
789 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
790 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
792 _PRE_RA_SIMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_pre
_ra
_sim
795 def __setvli_pre_ra_sim(op
, state
):
796 # type: (Op, PreRASimState) -> None
797 state
.ssa_vals
[op
.outputs
[0]] = op
.immediates
[0],
798 SetVLI
= GenericOpProperties(
799 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
802 immediates
=[range(1, 65)],
803 is_load_immediate
=True,
805 _PRE_RA_SIMS
[SetVLI
] = lambda: OpKind
.__setvli
_pre
_ra
_sim
def __svli_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Simulate SvLI: set output 0 to immediate 0 (masked with
    `GPR_VALUE_MASK`) repeated VL times, where VL is the single element
    of input 0's value."""
    (VL,) = state.ssa_vals[op.inputs[0]]
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state.ssa_vals[op.outputs[0]] = tuple([masked_imm] * VL)
813 SvLI
= GenericOpProperties(
814 demo_asm
="sv.addi *RT, 0, imm",
816 outputs
=[OD_EXTRA3_VGPR
],
817 immediates
=[IMM_S16
],
818 is_load_immediate
=True,
820 _PRE_RA_SIMS
[SvLI
] = lambda: OpKind
.__svli
_pre
_ra
_sim
def __li_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Simulate LI: set output 0 to a 1-tuple holding immediate 0 masked
    with `GPR_VALUE_MASK`."""
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state.ssa_vals[op.outputs[0]] = (masked_imm,)
827 LI
= GenericOpProperties(
828 demo_asm
="addi RT, 0, imm",
830 outputs
=[OD_BASE_SGPR
],
831 immediates
=[IMM_S16
],
832 is_load_immediate
=True,
834 _PRE_RA_SIMS
[LI
] = lambda: OpKind
.__li
_pre
_ra
_sim
837 def __veccopytoreg_pre_ra_sim(op
, state
):
838 # type: (Op, PreRASimState) -> None
839 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
840 VecCopyToReg
= GenericOpProperties(
841 demo_asm
="sv.mv dest, src",
842 inputs
=[GenericOperandDesc(
843 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
844 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
846 outputs
=[OD_EXTRA3_VGPR
],
849 _PRE_RA_SIMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_pre
_ra
_sim
852 def __veccopyfromreg_pre_ra_sim(op
, state
):
853 # type: (Op, PreRASimState) -> None
854 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
855 VecCopyFromReg
= GenericOpProperties(
856 demo_asm
="sv.mv dest, src",
857 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
858 outputs
=[GenericOperandDesc(
859 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
860 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
864 _PRE_RA_SIMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_pre
_ra
_sim
867 def __copytoreg_pre_ra_sim(op
, state
):
868 # type: (Op, PreRASimState) -> None
869 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
870 CopyToReg
= GenericOpProperties(
871 demo_asm
="mv dest, src",
872 inputs
=[GenericOperandDesc(
873 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
874 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
875 LocSubKind
.StackI64
],
877 outputs
=[GenericOperandDesc(
878 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
879 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
883 _PRE_RA_SIMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_pre
_ra
_sim
886 def __copyfromreg_pre_ra_sim(op
, state
):
887 # type: (Op, PreRASimState) -> None
888 state
.ssa_vals
[op
.outputs
[0]] = state
.ssa_vals
[op
.inputs
[0]]
889 CopyFromReg
= GenericOpProperties(
890 demo_asm
="mv dest, src",
891 inputs
=[GenericOperandDesc(
892 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
893 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
895 outputs
=[GenericOperandDesc(
896 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
897 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
898 LocSubKind
.StackI64
],
902 _PRE_RA_SIMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_pre
_ra
_sim
905 def __concat_pre_ra_sim(op
, state
):
906 # type: (Op, PreRASimState) -> None
907 state
.ssa_vals
[op
.outputs
[0]] = tuple(
908 state
.ssa_vals
[i
][0] for i
in op
.inputs
[:-1])
909 Concat
= GenericOpProperties(
910 demo_asm
="sv.mv dest, src",
911 inputs
=[GenericOperandDesc(
912 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
913 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
916 outputs
=[OD_EXTRA3_VGPR
],
919 _PRE_RA_SIMS
[Concat
] = lambda: OpKind
.__concat
_pre
_ra
_sim
922 def __spread_pre_ra_sim(op
, state
):
923 # type: (Op, PreRASimState) -> None
924 for idx
, inp
in enumerate(state
.ssa_vals
[op
.inputs
[0]]):
925 state
.ssa_vals
[op
.outputs
[idx
]] = inp
,
926 Spread
= GenericOpProperties(
927 demo_asm
="sv.mv dest, src",
928 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
929 outputs
=[GenericOperandDesc(
930 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
931 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
936 _PRE_RA_SIMS
[Spread
] = lambda: OpKind
.__spread
_pre
_ra
_sim
939 def __svld_pre_ra_sim(op
, state
):
940 # type: (Op, PreRASimState) -> None
941 RA
, = state
.ssa_vals
[op
.inputs
[0]]
942 VL
, = state
.ssa_vals
[op
.inputs
[1]]
943 addr
= RA
+ op
.immediates
[0]
944 RT
= [] # type: list[int]
946 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
947 RT
.append(v
& GPR_VALUE_MASK
)
948 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
949 SvLd
= GenericOpProperties(
950 demo_asm
="sv.ld *RT, imm(RA)",
951 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
952 outputs
=[OD_EXTRA3_VGPR
],
953 immediates
=[IMM_S16
],
955 _PRE_RA_SIMS
[SvLd
] = lambda: OpKind
.__svld
_pre
_ra
_sim
958 def __ld_pre_ra_sim(op
, state
):
959 # type: (Op, PreRASimState) -> None
960 RA
, = state
.ssa_vals
[op
.inputs
[0]]
961 addr
= RA
+ op
.immediates
[0]
963 state
.ssa_vals
[op
.outputs
[0]] = v
& GPR_VALUE_MASK
,
964 Ld
= GenericOpProperties(
965 demo_asm
="ld RT, imm(RA)",
966 inputs
=[OD_BASE_SGPR
],
967 outputs
=[OD_BASE_SGPR
],
968 immediates
=[IMM_S16
],
970 _PRE_RA_SIMS
[Ld
] = lambda: OpKind
.__ld
_pre
_ra
_sim
973 def __svstd_pre_ra_sim(op
, state
):
974 # type: (Op, PreRASimState) -> None
975 RS
= state
.ssa_vals
[op
.inputs
[0]]
976 RA
, = state
.ssa_vals
[op
.inputs
[1]]
977 VL
, = state
.ssa_vals
[op
.inputs
[2]]
978 addr
= RA
+ op
.immediates
[0]
980 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
981 SvStd
= GenericOpProperties(
982 demo_asm
="sv.std *RS, imm(RA)",
983 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
985 immediates
=[IMM_S16
],
986 has_side_effects
=True,
988 _PRE_RA_SIMS
[SvStd
] = lambda: OpKind
.__svstd
_pre
_ra
_sim
991 def __std_pre_ra_sim(op
, state
):
992 # type: (Op, PreRASimState) -> None
993 RS
, = state
.ssa_vals
[op
.inputs
[0]]
994 RA
, = state
.ssa_vals
[op
.inputs
[1]]
995 addr
= RA
+ op
.immediates
[0]
996 state
.store(addr
, value
=RS
)
997 Std
= GenericOpProperties(
998 demo_asm
="std RT, imm(RA)",
999 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1001 immediates
=[IMM_S16
],
1002 has_side_effects
=True,
1004 _PRE_RA_SIMS
[Std
] = lambda: OpKind
.__std
_pre
_ra
_sim
1007 def __funcargr3_pre_ra_sim(op
, state
):
1008 # type: (Op, PreRASimState) -> None
1009 pass # return value set before simulation
1010 FuncArgR3
= GenericOpProperties(
1013 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1014 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1016 _PRE_RA_SIMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_pre
_ra
_sim
1019 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1022 __slots__
= "op", "output_idx"
1024 def __init__(self
, op
, output_idx
):
1025 # type: (Op, int) -> None
1027 if output_idx
< 0 or output_idx
>= len(op
.properties
.outputs
):
1028 raise ValueError("invalid output_idx")
1029 self
.output_idx
= output_idx
1033 return f
"<{self.op.name}#{self.output_idx}>"
def defining_descriptor(self):
    # type: () -> OperandDesc
    """Return the output descriptor at `output_idx` in the properties of
    the Op that this value belongs to."""
    output_descs = self.op.properties.outputs
    return output_descs[self.output_idx]
def loc_set_before_spread(self):
    # type: () -> LocSet
    """Return the `loc_set_before_spread` of the defining descriptor."""
    desc = self.defining_descriptor
    return desc.loc_set_before_spread
1048 return self
.defining_descriptor
.ty
1051 def ty_before_spread(self
):
1053 return self
.defining_descriptor
.ty_before_spread
1057 _Desc
= TypeVar("_Desc")
1060 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1062 def _verify_write_with_desc(self
, idx
, item
, desc
):
1063 # type: (int, _T | Any, _Desc) -> None
1064 raise NotImplementedError
1067 def _verify_write(self
, idx
, item
):
1068 # type: (int | Any, _T | Any) -> int
1069 if not isinstance(idx
, int):
1070 if isinstance(idx
, slice):
1072 f
"can't write to slice of {self.__class__.__name__}")
1073 raise TypeError(f
"can't write with index {idx!r}")
1074 # normalize idx, raising IndexError if it is out of range
1075 idx
= range(len(self
.descriptors
))[idx
]
1076 desc
= self
.descriptors
[idx
]
1077 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1081 def _get_descriptors(self
):
1082 # type: () -> tuple[_Desc, ...]
1083 raise NotImplementedError
def descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """Return the descriptors provided by `_get_descriptors()`."""
    descs = self._get_descriptors()
    return descs
1096 def __init__(self
, items
, op
):
1097 # type: (Iterable[_T], Op) -> None
1099 self
.__items
= [] # type: list[_T]
1100 for idx
, item
in enumerate(items
):
1101 if idx
>= len(self
.descriptors
):
1102 raise ValueError("too many items")
1103 self
._verify
_write
(idx
, item
)
1104 self
.__items
.append(item
)
1105 if len(self
.__items
) < len(self
.descriptors
):
1106 raise ValueError("not enough items")
1110 # type: () -> Iterator[_T]
1111 yield from self
.__items
1114 def __getitem__(self
, idx
):
1119 def __getitem__(self
, idx
):
1120 # type: (slice) -> list[_T]
1124 def __getitem__(self
, idx
):
1125 # type: (int | slice) -> _T | list[_T]
1126 return self
.__items
[idx
]
1129 def __setitem__(self
, idx
, item
):
1130 # type: (int, _T) -> None
1131 idx
= self
._verify
_write
(idx
, item
)
1132 self
.__items
[idx
] = item
1137 return len(self
.__items
)
1140 return f
"{self.__class__.__name__}({self.__items}, op=...)"
1144 class OpInputs(OpInputSeq
[SSAVal
, OperandDesc
]):
1145 def _get_descriptors(self
):
1146 # type: () -> tuple[OperandDesc, ...]
1147 return self
.op
.properties
.inputs
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, SSAVal | Any, OperandDesc) -> None
    """Check that `item` is an SSAVal whose type equals `desc.ty`.

    Raises TypeError if `item` is not an SSAVal, and ValueError if the
    types differ.
    """
    if not isinstance(item, SSAVal):
        raise TypeError("expected value of type SSAVal")
    if item.ty != desc.ty:
        message = (f"assigned item's type {item.ty!r} doesn't match "
                   f"corresponding input's type {desc.ty!r}")
        raise ValueError(message)
1157 def __init__(self
, items
, op
):
1158 # type: (Iterable[SSAVal], Op) -> None
1159 if hasattr(op
, "inputs"):
1160 raise ValueError("Op.inputs already set")
1161 super().__init
__(items
, op
)
1165 class OpImmediates(OpInputSeq
[int, range]):
1166 def _get_descriptors(self
):
1167 # type: () -> tuple[range, ...]
1168 return self
.op
.properties
.immediates
1170 def _verify_write_with_desc(self
, idx
, item
, desc
):
1171 # type: (int, int | Any, range) -> None
1172 if not isinstance(item
, int):
1173 raise TypeError("expected value of type int")
1174 if item
not in desc
:
1175 raise ValueError(f
"immediate value {item!r} not in {desc!r}")
1177 def __init__(self
, items
, op
):
1178 # type: (Iterable[int], Op) -> None
1179 if hasattr(op
, "immediates"):
1180 raise ValueError("Op.immediates already set")
1181 super().__init
__(items
, op
)
1184 @plain_data(frozen
=True, eq
=False)
1187 __slots__
= "fn", "properties", "inputs", "immediates", "outputs", "name"
1189 def __init__(self
, fn
, properties
, inputs
, immediates
, name
=""):
1190 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1192 self
.properties
= properties
1193 self
.inputs
= OpInputs(inputs
, op
=self
)
1194 self
.immediates
= OpImmediates(immediates
, op
=self
)
1195 outputs_len
= len(self
.properties
.outputs
)
1196 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1197 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
1201 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: an Op equals another Op only if both are the
    same object. Returns NotImplemented when `other` is not an Op."""
    if not isinstance(other, Op):
        return NotImplemented
    return self is other
1210 return object.__hash
__(self
)
1212 def pre_ra_sim(self
, state
):
1213 # type: (PreRASimState) -> None
1214 for inp
in self
.inputs
:
1215 if inp
not in state
.ssa_vals
:
1216 raise ValueError(f
"SSAVal {inp} not yet assigned when "
1218 if len(state
.ssa_vals
[inp
]) != inp
.ty
.reg_len
:
1220 f
"value of SSAVal {inp} has wrong number of elements: "
1221 f
"expected {inp.ty.reg_len} found "
1222 f
"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
1223 for out
in self
.outputs
:
1224 if out
in state
.ssa_vals
:
1225 if self
.kind
is OpKind
.FuncArgR3
:
1227 raise ValueError(f
"SSAVal {out} already assigned before "
1229 self
.kind
.pre_ra_sim(self
, state
)
1230 for out
in self
.outputs
:
1231 if out
not in state
.ssa_vals
:
1232 raise ValueError(f
"running {self} failed to assign to {out}")
1233 if len(state
.ssa_vals
[out
]) != out
.ty
.reg_len
:
1235 f
"value of SSAVal {out} has wrong number of elements: "
1236 f
"expected {out.ty.reg_len} found "
1237 f
"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1240 GPR_SIZE_IN_BYTES
= 8
1242 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
1243 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
1246 @plain_data(frozen
=True, repr=False)
1248 class PreRASimState
:
1249 __slots__
= "ssa_vals", "memory"
def __init__(self, ssa_vals, memory):
    # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
    """Store the given dicts as-is (no copies are made).

    `ssa_vals` maps each SSAVal to its tuple of element values; `memory`
    maps addresses to the byte values read by `load_byte`.
    """
    self.memory = memory
    self.ssa_vals = ssa_vals
def load_byte(self, addr):
    # type: (int) -> int
    """Return the low 8 bits of the value stored at `addr` (masked with
    `GPR_VALUE_MASK`), or 0 if nothing is stored there."""
    wrapped_addr = addr & GPR_VALUE_MASK
    stored = self.memory.get(wrapped_addr, 0)
    return stored & 0xFF
1261 def store_byte(self
, addr
, value
):
1262 # type: (int, int) -> None
1263 addr
&= GPR_VALUE_MASK
1265 self
.memory
[addr
] = value
1267 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
1268 # type: (int, int, bool) -> int
1269 if addr
% size_in_bytes
!= 0:
1270 raise ValueError(f
"address not aligned: {hex(addr)} "
1271 f
"required alignment: {size_in_bytes}")
1273 for i
in range(size_in_bytes
):
1274 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
1275 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
1276 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
    # type: (int, int, int) -> None
    """Write `value` to memory as `size_in_bytes` bytes starting at
    `addr`, least-significant byte first.

    Raises ValueError if `addr` is not a multiple of `size_in_bytes`.
    """
    misalignment = addr % size_in_bytes
    if misalignment != 0:
        raise ValueError(f"address not aligned: {hex(addr)} "
                         f"required alignment: {size_in_bytes}")
    for offset in range(size_in_bytes):
        shift = offset * BITS_IN_BYTE
        byte = (value >> shift) & 0xFF
        self.store_byte(addr + offset, byte)
1287 def _memory__repr(self
):
1289 if len(self
.memory
) == 0:
1291 keys
= sorted(self
.memory
.keys(), reverse
=True)
1292 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
1293 items
= [] # type: list[str]
1294 while len(keys
) != 0:
1296 if (len(keys
) >= CHUNK_SIZE
1297 and addr
% CHUNK_SIZE
== 0
1298 and keys
[-CHUNK_SIZE
:]
1299 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
1300 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
1301 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
1302 keys
[-CHUNK_SIZE
:] = ()
1304 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
1306 return f
"{{{items[0]}}}"
1307 items_str
= ",\n".join(items
)
1308 return f
"{{\n{items_str}}}"
1310 def _ssa_vals__repr(self
):
1312 if len(self
.ssa_vals
) == 0:
1314 items
= [] # type: list[str]
1316 for k
, v
in self
.ssa_vals
.items():
1317 element_strs
= [] # type: list[str]
1318 for i
, el
in enumerate(v
):
1319 if i
% CHUNK_SIZE
!= 0:
1320 element_strs
.append(" " + hex(el
))
1322 element_strs
.append("\n " + hex(el
))
1323 if len(element_strs
) <= CHUNK_SIZE
:
1324 element_strs
[0] = element_strs
[0].lstrip()
1325 if len(element_strs
) == 1:
1326 element_strs
.append("")
1327 v_str
= ",".join(element_strs
)
1328 items
.append(f
"{k!r}: ({v_str})")
1329 if len(items
) == 1 and "\n" not in items
[0]:
1330 return f
"{{{items[0]}}}"
1331 items_str
= ",\n".join(items
)
1332 return f
"{{\n{items_str},\n}}"
1336 field_vals
= [] # type: list[str]
1337 for name
in fields(self
):
1339 value
= getattr(self
, name
)
1340 except AttributeError:
1341 field_vals
.append(f
"{name}=<not set>")
1343 repr_fn
= getattr(self
, f
"_{name}__repr", None)
1344 if callable(repr_fn
):
1345 field_vals
.append(f
"{name}={repr_fn()}")
1347 field_vals
.append(f
"{name}={value!r}")
1348 field_vals_str
= ", ".join(field_vals
)
1349 return f
"PreRASimState({field_vals_str})"