b04848b53a37be825fc16d2d07692edf5a01931b
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 import enum
2 from abc import abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache
5 from typing import (AbstractSet, Any, Generic, Iterable, Iterator, Sequence,
6 TypeVar, overload)
7 from weakref import WeakValueDictionary as _WeakVDict
8
9 from cached_property import cached_property
10 from nmutil.plain_data import plain_data
11
12 from bigint_presentation_code.type_util import Self, assert_never, final
13 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
14
15
@final
class Fn:
    """Owner of a list of `Op`s that also hands out unique `Op` names.

    Names are tracked through a weak-value dictionary, so a name is only
    considered taken while the `Op` registered under it is still alive.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # every name handed out so far, mapped to the Op it was given to
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix tried next on a clash; shared by all base names
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under an unused name derived from `name`.

        `name` itself is used if it is free, otherwise the running numeric
        suffix is appended until an unused name is found.

        Returns the name that was registered.
        Raises `ValueError` if `op` belongs to a different `Fn` or already
        has a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"
40
41
@unique
@final
class BaseTy(Enum):
    """Base value types, independent of register count."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if this base type can only have a `reg_len` of 1."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest `reg_len` a `Ty` with this base type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
71
72
@plain_data(frozen=True, unsafe_hash=True)
@final
class Ty:
    """A concrete type: a `BaseTy` together with a register count."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ check a `base_ty`/`reg_len` combination without constructing it.

        returns a description of the problem if the combination is invalid,
        otherwise returns None
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        too_small = reg_len < 1
        if too_small or reg_len > base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        """raises `ValueError` if `validate` rejects the combination."""
        problem = self.validate(base_ty=base_ty, reg_len=reg_len)
        if problem is not None:
            raise ValueError(problem)
        self.base_ty = base_ty
        self.reg_len = reg_len
97
98
@unique
@final
class LocKind(Enum):
    """Kinds of location that a value can be stored in."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """`BaseTy` of the values stored in locations of this kind."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of individual locations of this kind."""
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            # one location per register the base type can span
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 1024
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
132
133
@final
@unique
class LocSubKind(Enum):
    """Finer-grained subdivision of `LocKind`; each maps to one `LocKind`."""
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind belongs to."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        if self is LocSubKind.BASE_GPR \
                or self is LocSubKind.SV_EXTRA2_VGPR \
                or self is LocSubKind.SV_EXTRA2_SGPR \
                or self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        """`BaseTy` of this sub-kind's `LocKind`."""
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Set of locations of this sub-kind usable for a value of type `ty`.

        Raises `ValueError` if `ty` has a different base type than this
        sub-kind. Not implemented yet for matching types.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        raise NotImplementedError  # FIXME: finish
174
175
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """A type that is generic over vector length: scalar or vector."""
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        """raises `ValueError` if a vector of an only-scalar type is asked for
        """
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.base_ty = base_ty
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete `Ty`: `maxvl` registers for vectors, else 1."""
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some `instantiate` call could produce `ty`."""
        if self.base_ty != ty.base_ty:
            return False
        # a vector can take any length, a scalar is always exactly 1 register
        return True if self.is_vec else ty.reg_len == 1
202
203
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """A contiguous run of `reg_len` locations of one `LocKind`.

    The run covers indexes `start` (inclusive) to `stop` (exclusive).
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ return the requested `Loc`, or None if the arguments are invalid
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        # a message means validation failed, so there is no Loc to return
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """ raises `ValueError` if `validate` rejects the arguments """
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` share at least one location, i.e.
        they have the same kind and their index ranges overlap
        """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ `Ty` of a value occupying `reg_len` locations of `kind` """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ `Ty` of a value occupying exactly this `Loc` """
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ index just past the last location covered """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ join `self` and `others` into one `Loc` starting at `self.start`.

        returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the previous `Loc` stops.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
270
271
@plain_data(frozen=True, eq=False, repr=False)
@final
class LocSet(AbstractSet[Loc]):
    """Immutable set of `Loc`s that all share one `Ty`.

    Stored as one bit-set of start indexes per `LocKind`; `ty` is None
    exactly when the set was built from no locations.
    """
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """ raises `ValueError` if the given locations differ in type """
        if isinstance(__locs, LocSet):
            # already validated, share its (immutable) contents
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only kinds that actually have locations are kept as keys
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per-kind bit-sets of the `stop` index of every member """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        """ the `LocKind`s that have at least one member """
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        """ `reg_len` shared by all members, None if `ty` is None """
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        """ `BaseTy` shared by all members, None if `ty` is None """
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ set of all `Loc`s formed by a member of `self` immediately
        followed by a member of each of `others`, in order.

        returns an empty `LocSet` if no such concatenation exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot since kinds are deleted while looping
            for kind in tuple(starts):
                if kind not in other.starts:
                    # `other` has nothing of this kind to append, so no
                    # concatenation of this kind can exist
                    del starts[kind]
                    continue
                # keep a start only if `other` has a member beginning
                # exactly `reg_len` locations after it
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if len(starts[kind]) == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        # only a Loc of this set's own type can be a member
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    @cached_property
    def __len(self):
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        return self.__len

    @cached_property
    def __hash(self):
        return super()._hash()

    def __hash__(self):
        return self.__hash
383
384
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """generic Op operand descriptor"""
    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
    ):
        # type: (...) -> None
        """raises `ValueError` for an empty `sub_kinds`, for a `sub_kind` or
        `fixed_loc` that doesn't fit `ty`, for a negative `tied_input_index`,
        and for any two of fixed/tied/spread (or spread/vector) together.
        """
        kinds = OFSet(sub_kinds)
        if len(kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in kinds:
                allocatable = sub_kind.allocatable_locs(fixed_loc.ty)
                if fixed_loc not in allocatable:
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        if spread:
            if tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        # everything checked out, now fill in the fields
        self.ty = ty
        self.sub_kinds = kinds
        self.fixed_loc = fixed_loc
        self.tied_input_index = tied_input_index
        self.spread = spread

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """new descriptor with the same `ty` and `sub_kinds`, tied to the
        given input index
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds, tied_input_index=tied_input_index)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """new descriptor with the same `ty` and `sub_kinds`, fixed to the
        given `Loc`
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds, fixed_loc=fixed_loc)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """yield the concrete `OperandDesc`s for the given `maxvl`.

        a spread operand yields `maxvl` descriptors with `spread_index`
        0, 1, ...; any other operand yields one with `spread_index=None`.
        """
        # a spread operand is instantiated as if maxvl were 1
        ty = self.ty.instantiate(maxvl=1 if self.spread else maxvl)

        def candidate_locs():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty)
                return
            if ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
            yield self.fixed_loc

        loc_set_before_spread = LocSet(candidate_locs())
        spread_indexes = range(maxvl) if self.spread else (None,)
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index)
471
472
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor"""
    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
        # type: (LocSet, int | None, int | None) -> None
        """raises `ValueError` if `loc_set_before_spread` is empty or if the
        operand is both tied and spread.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # check the arguments rather than the fields: `self.spread_index`
        # isn't assigned yet at this point, so reading it would fail
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ Ty shared by all locations in `loc_set_before_spread` """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            # each spread piece is a single register of the same base type
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread
504
505
# Shared operand descriptors. Each one pairs a generic type with the single
# LocSubKind it may be allocated in.

# scalar I64 operands
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, False), (LocSubKind.BASE_GPR,))
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, False), (LocSubKind.SV_EXTRA3_SGPR,))
# vector I64 operand
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, True), (LocSubKind.SV_EXTRA3_VGPR,))
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, False), (LocSubKind.SV_EXTRA2_SGPR,))
# vector I64 operand
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, True), (LocSubKind.SV_EXTRA2_VGPR,))
# only-scalar base types
OD_CA = GenericOperandDesc(
    GenericTy(BaseTy.CA, False), (LocSubKind.CA,))
OD_VL = GenericOperandDesc(
    GenericTy(BaseTy.VL_MAXVL, False), (LocSubKind.VL_MAXVL,))
527
528
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """Properties of a kind of Op, before a `maxvl` has been chosen."""
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(self, demo_asm,  # type: str
                 inputs,  # type: Iterable[GenericOperandDesc]
                 outputs,  # type: Iterable[GenericOperandDesc]
                 immediates=(),  # type: Iterable[range]
                 is_copy=False,  # type: bool
                 is_load_immediate=False,  # type: bool
                 has_side_effects=False,  # type: bool
                 ):
        # type: (...) -> None
        """raises `ValueError` if an input is tied, if an output is tied to
        an input index that doesn't exist, or if two outputs have
        conflicting fixed locations.
        """
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)
        input_count = len(self.inputs)
        # fixed locations of the outputs checked so far, with their indexes
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied = out.tied_input_index
            if tied is not None and tied >= input_count:
                raise ValueError(f"tied_input_index out of range: {out}")
            if out.fixed_loc is None:
                continue
            for other_fixed_loc, other_idx in fixed_locs:
                if other_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
            fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
569
570
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """An `OpKind`'s properties instantiated for a specific `maxvl`.

    Everything that doesn't depend on `maxvl` is forwarded from the
    kind's `GenericOpProperties`.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind
        generic = self.generic
        # one generic operand may expand to several concrete operands
        self.inputs = tuple(
            desc
            for inp in generic.inputs
            for desc in inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for out in generic.outputs
            for desc in out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """the `GenericOpProperties` of `kind`"""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects
618
619
@unique
@final
class OpKind(Enum):
    """The kinds of Op; each member's value is its `GenericOpProperties`."""

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """the `GenericOpProperties` this member was defined with"""
        return self.__properties

    # ops with vector GPR operands

    SvAddE = GenericOpProperties(
        "sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA],
    )
    SvSubFE = GenericOpProperties(
        "sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA],
    )
    SvMAddEDU = GenericOpProperties(
        "sv.maddedu *RT, *RA, RB, RC",
        inputs=[
            OD_EXTRA2_VGPR,
            OD_EXTRA2_VGPR,
            OD_EXTRA2_SGPR,
            OD_EXTRA2_SGPR,
            OD_VL,
        ],
        # the second output is tied to inputs[3]
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(3)],
    )

    # immediate loads

    SetVLI = GenericOpProperties(
        "setvl 0, 0, imm, 0, 1, 1",
        inputs=[],
        outputs=[OD_VL],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    SvLI = GenericOpProperties(
        "sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[range(-(1 << 15), 1 << 15)],
        is_load_immediate=True,
    )
    LI = GenericOpProperties(
        "addi RT, 0, imm",
        inputs=[],
        outputs=[OD_BASE_SGPR],
        immediates=[range(-(1 << 15), 1 << 15)],
        is_load_immediate=True,
    )

    # copies

    VecCopyToReg = GenericOpProperties(
        "sv.mv dest, src",
        inputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=True),
                [LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            ),
            OD_VL,
        ],
        outputs=[OD_EXTRA3_VGPR],
        is_copy=True,
    )
    VecCopyFromReg = GenericOpProperties(
        "sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=True),
                [LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            ),
        ],
        is_copy=True,
    )
    CopyToReg = GenericOpProperties(
        "mv dest, src",
        inputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                 LocSubKind.StackI64],
            ),
        ],
        outputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            ),
        ],
        is_copy=True,
    )
    CopyFromReg = GenericOpProperties(
        "mv dest, src",
        inputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            ),
        ],
        outputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                 LocSubKind.StackI64],
            ),
        ],
        is_copy=True,
    )

    # copies between spread scalar operands and one vector operand

    Concat = GenericOpProperties(
        "sv.mv dest, src",
        inputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_VGPR],
                spread=True,
            ),
            OD_VL,
        ],
        outputs=[OD_EXTRA3_VGPR],
        is_copy=True,
    )
    Spread = GenericOpProperties(
        "sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[
            GenericOperandDesc(
                GenericTy(BaseTy.I64, is_vec=False),
                [LocSubKind.SV_EXTRA3_VGPR],
                spread=True,
            ),
        ],
        is_copy=True,
    )
734
735
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """A value defined by one output of an `Op`."""
    __slots__ = "op", "output_idx"

    def __init__(self, op, output_idx):
        # type: (Op, int) -> None
        """raises `ValueError` if `op` has no output at `output_idx`."""
        self.op = op
        negative = output_idx < 0
        if negative or output_idx >= len(op.properties.outputs):
            raise ValueError("invalid output_idx")
        self.output_idx = output_idx

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}#{self.output_idx}>"

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """descriptor of the `Op` output that defines this value"""
        outputs = self.op.properties.outputs
        return outputs[self.output_idx]

    @cached_property
    def loc_set_before_spread(self):
        # type: () -> LocSet
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @cached_property
    def ty(self):
        # type: () -> Ty
        desc = self.defining_descriptor
        return desc.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        desc = self.defining_descriptor
        return desc.ty_before_spread
771
772
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """Fixed-length sequence of an `Op`'s items, one per descriptor.

    Items can be replaced one at a time but never added or removed; every
    write is checked against the descriptor at the same index.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """raise if `item` isn't acceptable for descriptor `desc`"""
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """check a write of `item` at `idx` and return the normalized index.

        raises `TypeError` if `idx` isn't an int and `IndexError` if it is
        out of range.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        descriptors = self.descriptors
        # normalize idx, raising IndexError if it is out of range
        idx = range(len(descriptors))[idx]
        self._verify_write_with_desc(idx, item, descriptors[idx])
        return idx

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """return the descriptors, one per item"""
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @property
    @final
    def op(self):
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """raises `ValueError` unless there is exactly one item per
        descriptor.
        """
        self.__op = op
        self.__items = []  # type: list[_T]
        for idx, item in enumerate(items):
            if len(self.descriptors) <= idx:
                raise ValueError("too many items")
            self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.descriptors) > len(self.__items):
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        checked_idx = self._verify_write(idx, item)
        self.__items[checked_idx] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)
855
856
@final
class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
    """The `SSAVal` inputs of an `Op`, checked against its input descriptors.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        properties = self.op.properties
        return properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """raises `TypeError` for a non-`SSAVal` item and `ValueError` if
        the item's type differs from the descriptor's type.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty == desc.ty:
            return
        raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                         f"corresponding input's type {desc.ty!r}")

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """raises `ValueError` if `op` already has its `inputs` set"""
        if hasattr(op, "inputs"):
            raise ValueError("Op.inputs already set")
        super().__init__(items, op)
876
877
@final
class OpImmediates(OpInputSeq[int, range]):
    """The immediates of an `Op`, each checked against its allowed `range`.
    """

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """raises `TypeError` for a non-int item and `ValueError` if the
        item isn't in the `desc` range.
        """
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """raises `ValueError` if `op` already has its `immediates` set"""
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)
896
897
@plain_data(frozen=True, eq=False)
@final
class Op:
    """One operation in a `Fn`; compared and hashed by identity."""
    __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"

    def __init__(self, fn, properties, inputs, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        # `fn` and `properties` have to be set first: the input/immediate
        # sequences and the output values read them back from `self`
        self.fn = fn
        self.properties = properties
        self.inputs = OpInputs(inputs, op=self)
        self.immediates = OpImmediates(immediates, op=self)
        # one SSAVal per output descriptor
        self.outputs = tuple(
            SSAVal(self, idx) for idx, _ in enumerate(properties.outputs))
        # `name` goes last, `fn` rejects an Op that already has one
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        """the `kind` of this Op's `properties`"""
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        return object.__hash__(self)