b04848b53a37be825fc16d2d07692edf5a01931b
2 from abc
import abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
5 from typing
import (AbstractSet
, Any
, Generic
, Iterable
, Iterator
, Sequence
,
7 from weakref
import WeakValueDictionary
as _WeakVDict
9 from cached_property
import cached_property
10 from nmutil
.plain_data
import plain_data
12 from bigint_presentation_code
.type_util
import Self
, assert_never
, final
13 from bigint_presentation_code
.util
import BitSet
, FBitSet
, FMap
, OFSet
19 self
.ops
= [] # type: list[Op]
20 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
21 self
.__next
_name
_suffix
= 2
23 def _add_op_with_unused_name(self
, op
, name
=""):
24 # type: (Op, str) -> str
26 raise ValueError("can't add Op to wrong Fn")
27 if hasattr(op
, "name"):
28 raise ValueError("Op already named")
31 if name
not in self
.__op
_names
:
32 self
.__op
_names
[name
] = op
34 name
= orig_name
+ str(self
.__next
_name
_suffix
)
35 self
.__next
_name
_suffix
+= 1
47 VL_MAXVL
= enum
.auto()
50 def only_scalar(self
):
52 if self
is BaseTy
.I64
:
54 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
60 def max_reg_len(self
):
62 if self
is BaseTy
.I64
:
64 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
70 return "BaseTy." + self
._name
_
73 @plain_data(frozen
=True, unsafe_hash
=True)
76 __slots__
= "base_ty", "reg_len"
79 def validate(base_ty
, reg_len
):
80 # type: (BaseTy, int) -> str | None
81 """ return a string with the error if the combination is invalid,
84 if base_ty
.only_scalar
and reg_len
!= 1:
85 return f
"can't create a vector of an only-scalar type: {base_ty}"
86 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
87 return "reg_len out of range"
90 def __init__(self
, base_ty
, reg_len
):
91 # type: (BaseTy, int) -> None
92 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
95 self
.base_ty
= base_ty
96 self
.reg_len
= reg_len
103 StackI64
= enum
.auto()
105 VL_MAXVL
= enum
.auto()
110 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
112 if self
is LocKind
.CA
:
114 if self
is LocKind
.VL_MAXVL
:
115 return BaseTy
.VL_MAXVL
122 if self
is LocKind
.StackI64
:
124 if self
is LocKind
.GPR
or self
is LocKind
.CA \
125 or self
is LocKind
.VL_MAXVL
:
126 return self
.base_ty
.max_reg_len
131 return "LocKind." + self
._name
_
136 class LocSubKind(Enum
):
137 BASE_GPR
= enum
.auto()
138 SV_EXTRA2_VGPR
= enum
.auto()
139 SV_EXTRA2_SGPR
= enum
.auto()
140 SV_EXTRA3_VGPR
= enum
.auto()
141 SV_EXTRA3_SGPR
= enum
.auto()
142 StackI64
= enum
.auto()
144 VL_MAXVL
= enum
.auto()
148 # type: () -> LocKind
149 # pyright fails typechecking when using `in` here:
150 # reported: https://github.com/microsoft/pyright/issues/4102
151 if self
is LocSubKind
.BASE_GPR
or self
is LocSubKind
.SV_EXTRA2_VGPR \
152 or self
is LocSubKind
.SV_EXTRA2_SGPR \
153 or self
is LocSubKind
.SV_EXTRA3_VGPR \
154 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
156 if self
is LocSubKind
.StackI64
:
157 return LocKind
.StackI64
158 if self
is LocSubKind
.CA
:
160 if self
is LocSubKind
.VL_MAXVL
:
161 return LocKind
.VL_MAXVL
166 return self
.kind
.base_ty
def allocatable_locs(self, ty):
    # type: (Ty) -> LocSet
    """Intended (per the signature) to give the `LocSet` usable for `ty`.

    The implementation is unfinished: once the base type of `ty` has been
    checked against this sub-kind's base type, it always raises.

    Raises:
        ValueError: `ty.base_ty` differs from `self.base_ty`.
        NotImplementedError: in every other case, until this is finished.
    """
    requested_base, own_base = ty.base_ty, self.base_ty
    if requested_base != own_base:
        raise ValueError("type mismatch")
    raise NotImplementedError  # FIXME: finish
176 @plain_data(frozen
=True, unsafe_hash
=True)
179 __slots__
= "base_ty", "is_vec"
181 def __init__(self
, base_ty
, is_vec
):
182 # type: (BaseTy, bool) -> None
183 self
.base_ty
= base_ty
184 if base_ty
.only_scalar
and is_vec
:
185 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
188 def instantiate(self
, maxvl
):
190 # here's where subvl and elwid would be accounted for
192 return Ty(self
.base_ty
, maxvl
)
193 return Ty(self
.base_ty
, 1)
195 def can_instantiate_to(self
, ty
):
197 if self
.base_ty
!= ty
.base_ty
:
201 return ty
.reg_len
== 1
204 @plain_data(frozen
=True, unsafe_hash
=True)
207 __slots__
= "kind", "start", "reg_len"
210 def validate(kind
, start
, reg_len
):
211 # type: (LocKind, int, int) -> str | None
212 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
215 if reg_len
> kind
.loc_count
:
216 return "invalid reg_len"
217 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
218 return "start not in valid range"
222 def try_make(kind
, start
, reg_len
):
223 # type: (LocKind, int, int) -> Loc | None
224 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
227 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
229 def __init__(self
, kind
, start
, reg_len
):
230 # type: (LocKind, int, int) -> None
231 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
233 raise ValueError(msg
)
235 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Return True if `self` and `other` overlap within the same kind.

    Two locations conflict only when they have the same `kind` and their
    `[start, stop)` ranges intersect. Ranges that merely touch (one's
    `stop` equal to the other's `start`) do not conflict.
    """
    # locations of different kinds can never overlap, whatever their
    # start/stop values are
    if self.kind != other.kind:
        return False
    return self.start < other.stop and other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """Build the `Ty` with `kind`'s base type and the given `reg_len`.

    Takes no `self`; any validation is left to the `Ty` constructor.
    """
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
251 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
256 return self
.start
+ self
.reg_len
258 def try_concat(self
, *others
):
259 # type: (*Loc | None) -> Loc | None
260 reg_len
= self
.reg_len
263 if other
is None or other
.kind
!= self
.kind
:
265 if stop
!= other
.start
:
268 reg_len
+= other
.reg_len
269 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
272 @plain_data(frozen
=True, eq
=False, repr=False)
274 class LocSet(AbstractSet
[Loc
]):
275 __slots__
= "starts", "ty"
277 def __init__(self
, __locs
=()):
278 # type: (Iterable[Loc]) -> None
279 if isinstance(__locs
, LocSet
):
280 self
.starts
= __locs
.starts
# type: FMap[LocKind, FBitSet]
281 self
.ty
= __locs
.ty
# type: Ty | None
283 starts
= {i
: BitSet() for i
in LocKind
}
289 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
290 starts
[loc
.kind
].add(loc
.start
)
292 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
297 # type: () -> FMap[LocKind, FBitSet]
302 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
306 # type: () -> AbstractSet[LocKind]
307 return self
.starts
.keys()
311 # type: () -> int | None
314 return self
.ty
.reg_len
318 # type: () -> BaseTy | None
321 return self
.ty
.base_ty
323 def concat(self
, *others
):
324 # type: (*LocSet) -> LocSet
327 base_ty
= self
.ty
.base_ty
328 reg_len
= self
.ty
.reg_len
329 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
333 if other
.ty
.base_ty
!= base_ty
:
335 for kind
, other_starts
in other
.starts
.items():
336 if kind
not in starts
:
338 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
339 if starts
[kind
] == 0:
343 reg_len
+= other
.ty
.reg_len
346 # type: () -> Iterable[Loc]
347 for kind
, v
in starts
.items():
349 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
352 return LocSet(locs())
354 def __contains__(self
, loc
):
355 # type: (Loc | Any) -> bool
356 if not isinstance(loc
, Loc
) or loc
.ty
== self
.ty
:
358 if loc
.kind
not in self
.starts
:
360 return loc
.start
in self
.starts
[loc
.kind
]
363 # type: () -> Iterator[Loc]
366 for kind
, starts
in self
.starts
.items():
368 yield Loc(kind
=kind
, start
=start
, reg_len
=self
.ty
.reg_len
)
372 return sum((len(v
) for v
in self
.starts
.values()), 0)
379 return super()._hash
()
385 @plain_data(frozen
=True, unsafe_hash
=True)
387 class GenericOperandDesc
:
388 """generic Op operand descriptor"""
389 __slots__
= "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"
392 self
, ty
, # type: GenericTy
393 sub_kinds
, # type: Iterable[LocSubKind]
395 fixed_loc
=None, # type: Loc | None
396 tied_input_index
=None, # type: int | None
397 spread
=False, # type: bool
399 # type: (...) -> None
401 self
.sub_kinds
= OFSet(sub_kinds
)
402 if len(self
.sub_kinds
) == 0:
403 raise ValueError("sub_kinds can't be empty")
404 self
.fixed_loc
= fixed_loc
405 if fixed_loc
is not None:
406 if tied_input_index
is not None:
407 raise ValueError("operand can't be both tied and fixed")
408 if not ty
.can_instantiate_to(fixed_loc
.ty
):
410 f
"fixed_loc has incompatible type for given generic "
411 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
412 if len(self
.sub_kinds
) != 1:
414 "multiple sub_kinds not allowed for fixed operand")
415 for sub_kind
in self
.sub_kinds
:
416 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
418 f
"fixed_loc not in given sub_kind: "
419 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
420 for sub_kind
in self
.sub_kinds
:
421 if sub_kind
.base_ty
!= ty
.base_ty
:
422 raise ValueError(f
"sub_kind is incompatible with type: "
423 f
"sub_kind={sub_kind} ty={ty}")
424 if tied_input_index
is not None and tied_input_index
< 0:
425 raise ValueError("invalid tied_input_index")
426 self
.tied_input_index
= tied_input_index
429 if self
.tied_input_index
is not None:
430 raise ValueError("operand can't be both spread and tied")
431 if self
.fixed_loc
is not None:
432 raise ValueError("operand can't be both spread and fixed")
434 raise ValueError("operand can't be both spread and vector")
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Return a new descriptor like this one, tied to the given input.

    Only `ty` and `sub_kinds` are carried over; `fixed_loc` and `spread`
    are not forwarded, so the new descriptor gets the constructor's
    defaults for them.
    """
    tied_desc = GenericOperandDesc(
        self.ty,
        self.sub_kinds,
        tied_input_index=tied_input_index,
    )
    return tied_desc
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Return a new descriptor like this one, pinned to `fixed_loc`.

    Only `ty` and `sub_kinds` are carried over; `tied_input_index` and
    `spread` are not forwarded, so the new descriptor gets the
    constructor's defaults for them.
    """
    pinned_desc = GenericOperandDesc(
        self.ty,
        self.sub_kinds,
        fixed_loc=fixed_loc,
    )
    return pinned_desc
445 def instantiate(self
, maxvl
):
446 # type: (int) -> Iterable[OperandDesc]
451 ty
= self
.ty
.instantiate(maxvl
=maxvl
)
454 # type: () -> Iterable[Loc]
455 if self
.fixed_loc
is not None:
456 if ty
!= self
.fixed_loc
.ty
:
458 f
"instantiation failed: type mismatch with fixed_loc: "
459 f
"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
462 for sub_kind
in self
.sub_kinds
:
463 yield from sub_kind
.allocatable_locs(ty
)
464 loc_set_before_spread
= LocSet(locs())
465 for idx
in range(rep_count
):
468 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
469 tied_input_index
=self
.tied_input_index
,
473 @plain_data(frozen
=True, unsafe_hash
=True)
476 """Op operand descriptor"""
477 __slots__
= "loc_set_before_spread", "tied_input_index", "spread_index"
def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
    # type: (LocSet, int | None, int | None) -> None
    """Create an operand descriptor.

    Args:
        loc_set_before_spread: non-empty set of locations for the operand
            before any spread is applied.
        tied_input_index: index of the input this operand is tied to, or
            None if it is not tied.
        spread_index: the operand's spread index, or None if it is not
            spread.

    Raises:
        ValueError: if `loc_set_before_spread` is empty, or if both
            `tied_input_index` and `spread_index` are given.
    """
    if len(loc_set_before_spread) == 0:
        raise ValueError("loc_set_before_spread must not be empty")
    # validate the arguments themselves: `self.spread_index` is not
    # assigned yet at this point, so reading it would raise AttributeError
    # for every tied operand instead of performing the intended check.
    if tied_input_index is not None and spread_index is not None:
        raise ValueError("operand can't be both spread and tied")
    self.loc_set_before_spread = loc_set_before_spread
    self.tied_input_index = tied_input_index
    self.spread_index = spread_index
490 def ty_before_spread(self
):
492 ty
= self
.loc_set_before_spread
.ty
493 assert ty
is not None, (
494 "__init__ checked that the LocSet isn't empty, "
495 "non-empty LocSets should always have ty set")
500 """ Ty after any spread is applied """
501 if self
.spread_index
is not None:
502 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
503 return self
.ty_before_spread
# Reusable generic operand descriptors. Each one pairs a generic type
# (base type + whether it is a vector) with the single location sub-kind
# the operand may be allocated in.

# non-vector I64 operand restricted to LocSubKind.BASE_GPR
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# non-vector I64 operand restricted to LocSubKind.SV_EXTRA3_SGPR
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector I64 operand restricted to LocSubKind.SV_EXTRA3_VGPR
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# non-vector I64 operand restricted to LocSubKind.SV_EXTRA2_SGPR
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector I64 operand restricted to LocSubKind.SV_EXTRA2_VGPR
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# non-vector CA operand restricted to LocSubKind.CA
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# non-vector VL_MAXVL operand restricted to LocSubKind.VL_MAXVL
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
529 @plain_data(frozen
=True, unsafe_hash
=True)
531 class GenericOpProperties
:
532 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
533 "is_copy", "is_load_immediate", "has_side_effects")
535 def __init__(self
, demo_asm
, # type: str
536 inputs
, # type: Iterable[GenericOperandDesc]
537 outputs
, # type: Iterable[GenericOperandDesc]
538 immediates
=(), # type: Iterable[range]
539 is_copy
=False, # type: bool
540 is_load_immediate
=False, # type: bool
541 has_side_effects
=False, # type: bool
543 # type: (...) -> None
544 self
.demo_asm
= demo_asm
545 self
.inputs
= tuple(inputs
)
546 for inp
in self
.inputs
:
547 if inp
.tied_input_index
is not None:
549 f
"tied_input_index is not allowed on inputs: {inp}")
550 self
.outputs
= tuple(outputs
)
551 fixed_locs
= [] # type: list[tuple[Loc, int]]
552 for idx
, out
in enumerate(self
.outputs
):
553 if out
.tied_input_index
is not None \
554 and out
.tied_input_index
>= len(self
.inputs
):
555 raise ValueError(f
"tied_input_index out of range: {out}")
556 if out
.fixed_loc
is not None:
557 for other_fixed_loc
, other_idx
in fixed_locs
:
558 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
561 f
"conflicting fixed_locs: outputs[{idx}] and "
562 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
563 f
"with {other_fixed_loc}")
564 fixed_locs
.append((out
.fixed_loc
, idx
))
565 self
.immediates
= tuple(immediates
)
566 self
.is_copy
= is_copy
567 self
.is_load_immediate
= is_load_immediate
568 self
.has_side_effects
= has_side_effects
571 @plain_data(frozen
=True, unsafe_hash
=True)
574 __slots__
= "kind", "inputs", "outputs", "maxvl"
576 def __init__(self
, kind
, maxvl
):
577 # type: (OpKind, int) -> None
579 inputs
= [] # type: list[OperandDesc]
580 for inp
in self
.generic
.inputs
:
581 inputs
.extend(inp
.instantiate(maxvl
=maxvl
))
582 self
.inputs
= tuple(inputs
)
583 outputs
= [] # type: list[OperandDesc]
584 for out
in self
.generic
.outputs
:
585 outputs
.extend(out
.instantiate(maxvl
=maxvl
))
586 self
.outputs
= tuple(outputs
)
591 # type: () -> GenericOpProperties
592 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """The immediates of the generic properties, forwarded unchanged."""
    generic_props = self.generic
    return generic_props.immediates
602 return self
.generic
.demo_asm
607 return self
.generic
.is_copy
610 def is_load_immediate(self
):
612 return self
.generic
.is_load_immediate
615 def has_side_effects(self
):
617 return self
.generic
.has_side_effects
623 def __init__(self
, properties
):
624 # type: (GenericOpProperties) -> None
626 self
.__properties
= properties
def properties(self):
    # type: () -> GenericOpProperties
    """The `GenericOpProperties` stored in the private `__properties`."""
    stored = self.__properties
    return stored
633 SvAddE
= GenericOpProperties(
634 demo_asm
="sv.adde *RT, *RA, *RB",
635 inputs
=(OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
),
636 outputs
=(OD_EXTRA3_VGPR
, OD_CA
),
638 SvSubFE
= GenericOpProperties(
639 demo_asm
="sv.subfe *RT, *RA, *RB",
640 inputs
=(OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
),
641 outputs
=(OD_EXTRA3_VGPR
, OD_CA
),
643 SvMAddEDU
= GenericOpProperties(
644 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
645 inputs
=(OD_EXTRA2_VGPR
, OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
,
646 OD_EXTRA2_SGPR
, OD_VL
),
647 outputs
=(OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(3)),
649 SetVLI
= GenericOpProperties(
650 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
653 immediates
=(range(1, 65),),
654 is_load_immediate
=True,
656 SvLI
= GenericOpProperties(
657 demo_asm
="sv.addi *RT, 0, imm",
659 outputs
=(OD_EXTRA3_VGPR
,),
660 immediates
=(range(-2 ** 15, 2 ** 15),),
661 is_load_immediate
=True,
663 LI
= GenericOpProperties(
664 demo_asm
="addi RT, 0, imm",
666 outputs
=(OD_BASE_SGPR
,),
667 immediates
=(range(-2 ** 15, 2 ** 15),),
668 is_load_immediate
=True,
670 VecCopyToReg
= GenericOpProperties(
671 demo_asm
="sv.mv dest, src",
672 inputs
=(GenericOperandDesc(
673 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
674 sub_kinds
=(LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
),
676 outputs
=(OD_EXTRA3_VGPR
,),
679 VecCopyFromReg
= GenericOpProperties(
680 demo_asm
="sv.mv dest, src",
681 inputs
=(OD_EXTRA3_VGPR
, OD_VL
),
682 outputs
=(GenericOperandDesc(
683 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
684 sub_kinds
=(LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
),
688 CopyToReg
= GenericOpProperties(
689 demo_asm
="mv dest, src",
690 inputs
=(GenericOperandDesc(
691 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
692 sub_kinds
=(LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
693 LocSubKind
.StackI64
),
695 outputs
=(GenericOperandDesc(
696 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
697 sub_kinds
=(LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
),
701 CopyFromReg
= GenericOpProperties(
702 demo_asm
="mv dest, src",
703 inputs
=(GenericOperandDesc(
704 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
705 sub_kinds
=(LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
),
707 outputs
=(GenericOperandDesc(
708 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
709 sub_kinds
=(LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
710 LocSubKind
.StackI64
),
714 Concat
= GenericOpProperties(
715 demo_asm
="sv.mv dest, src",
716 inputs
=(GenericOperandDesc(
717 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
718 sub_kinds
=(LocSubKind
.SV_EXTRA3_VGPR
,),
721 outputs
=(OD_EXTRA3_VGPR
,),
724 Spread
= GenericOpProperties(
725 demo_asm
="sv.mv dest, src",
726 inputs
=(OD_EXTRA3_VGPR
, OD_VL
),
727 outputs
=(GenericOperandDesc(
728 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
729 sub_kinds
=(LocSubKind
.SV_EXTRA3_VGPR
,),
736 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
739 __slots__
= "op", "output_idx"
741 def __init__(self
, op
, output_idx
):
742 # type: (Op, int) -> None
744 if output_idx
< 0 or output_idx
>= len(op
.properties
.outputs
):
745 raise ValueError("invalid output_idx")
746 self
.output_idx
= output_idx
750 return f
"<{self.op.name}#{self.output_idx}>"
def defining_descriptor(self):
    # type: () -> OperandDesc
    """The output descriptor of `self.op` selected by `self.output_idx`."""
    op_outputs = self.op.properties.outputs
    return op_outputs[self.output_idx]
758 def loc_set_before_spread(self
):
760 return self
.defining_descriptor
.loc_set_before_spread
765 return self
.defining_descriptor
.ty
768 def ty_before_spread(self
):
770 return self
.defining_descriptor
.ty_before_spread
774 _Desc
= TypeVar("_Desc")
777 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, _T | Any, _Desc) -> None
    """Subclass hook: check that `item` may be stored at `idx` given `desc`.

    `_verify_write` calls this with the descriptor found at `idx`. This
    base version has no implementation and always raises.
    """
    raise NotImplementedError()
784 def _verify_write(self
, idx
, item
):
785 # type: (int | Any, _T | Any) -> int
786 if not isinstance(idx
, int):
787 if isinstance(idx
, slice):
789 f
"can't write to slice of {self.__class__.__name__}")
790 raise TypeError(f
"can't write with index {idx!r}")
791 # normalize idx, raising IndexError if it is out of range
792 idx
= range(len(self
.descriptors
))[idx
]
793 desc
= self
.descriptors
[idx
]
794 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
def _get_descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """Subclass hook: supply the tuple of descriptors for this sequence.

    This base version has no implementation and always raises.
    """
    raise NotImplementedError()
def descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """The descriptors for this sequence, as given by `_get_descriptors`."""
    result = self._get_descriptors()
    return result
813 def __init__(self
, items
, op
):
814 # type: (Iterable[_T], Op) -> None
816 self
.__items
= [] # type: list[_T]
817 for idx
, item
in enumerate(items
):
818 if idx
>= len(self
.descriptors
):
819 raise ValueError("too many items")
820 self
._verify
_write
(idx
, item
)
821 self
.__items
.append(item
)
822 if len(self
.__items
) < len(self
.descriptors
):
823 raise ValueError("not enough items")
827 # type: () -> Iterator[_T]
828 yield from self
.__items
831 def __getitem__(self
, idx
):
836 def __getitem__(self
, idx
):
837 # type: (slice) -> list[_T]
def __getitem__(self, idx):
    # type: (int | slice) -> _T | list[_T]
    """Read the item at `idx`, or the list of items for a slice `idx`.

    Indexing is delegated to the private backing list.
    """
    backing = self.__items
    return backing[idx]
def __setitem__(self, idx, item):
    # type: (int, _T) -> None
    """Validate and store `item` at `idx`.

    The write is first checked by `_verify_write`; the item is then stored
    at the index that call returns.
    """
    checked_idx = self._verify_write(idx, item)
    self.__items[checked_idx] = item
854 return len(self
.__items
)
class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
    """Input sequence of an `Op`: `SSAVal` items checked against the
    `OperandDesc`s in the op's `properties.inputs`."""

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """Build the inputs of `op` from `items`.

        Raises:
            ValueError: if `op` already has an `inputs` attribute.
        """
        inputs_already_set = hasattr(op, "inputs")
        if inputs_already_set:
            raise ValueError("Op.inputs already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        """Descriptors are the input descriptors of the op's properties."""
        op_properties = self.op.properties
        return op_properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Check that `item` is an `SSAVal` whose type equals `desc.ty`.

        Raises:
            TypeError: if `item` is not an `SSAVal`.
            ValueError: if `item.ty` differs from `desc.ty`.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        actual_ty, expected_ty = item.ty, desc.ty
        if actual_ty != expected_ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")
879 class OpImmediates(OpInputSeq
[int, range]):
880 def _get_descriptors(self
):
881 # type: () -> tuple[range, ...]
882 return self
.op
.properties
.immediates
884 def _verify_write_with_desc(self
, idx
, item
, desc
):
885 # type: (int, int | Any, range) -> None
886 if not isinstance(item
, int):
887 raise TypeError("expected value of type int")
889 raise ValueError(f
"immediate value {item!r} not in {desc!r}")
891 def __init__(self
, items
, op
):
892 # type: (Iterable[int], Op) -> None
893 if hasattr(op
, "immediates"):
894 raise ValueError("Op.immediates already set")
895 super().__init
__(items
, op
)
898 @plain_data(frozen
=True, eq
=False)
901 __slots__
= "fn", "properties", "inputs", "immediates", "outputs", "name"
903 def __init__(self
, fn
, properties
, inputs
, immediates
, name
=""):
904 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
906 self
.properties
= properties
907 self
.inputs
= OpInputs(inputs
, op
=self
)
908 self
.immediates
= OpImmediates(immediates
, op
=self
)
909 outputs_len
= len(self
.properties
.outputs
)
910 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
911 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
915 return self
.properties
.kind
917 def __eq__(self
, other
):
918 # type: (Op | Any) -> bool
919 if isinstance(other
, Op
):
921 return NotImplemented
924 return object.__hash
__(self
)