working on new ir
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 from collections import defaultdict
2 import enum
3 from enum import Enum, unique
4 from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
5 from weakref import WeakValueDictionary as _WeakVDict
6
7 from cached_property import cached_property
8 from nmutil.plain_data import plain_data
9
10 from bigint_presentation_code.type_util import Self, assert_never, final
11 from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
12 OSet, FMap)
13 from functools import lru_cache
14
15
@final
class Fn:
    """A function in the IR: owns its `Op`s and gives each a unique name."""

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]
        # name -> Op, held weakly: once an Op is garbage-collected its
        # entry disappears and the name can be handed out again
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix tried next when a requested name is taken; one
        # counter shared by all names in this Fn
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under `name`, or `name` plus a numeric suffix if
        `name` is already in use, and return the name actually chosen.

        Raises ValueError if `op.fn` is not this Fn or if `op` already has a
        `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"
40
41
@unique
@final
class BaseTy(Enum):
    """Element type of an IR value, independent of its register length."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if a `Ty` of this base type must have `reg_len == 1`."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest `reg_len` a `Ty` of this base type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
71
72
@plain_data(frozen=True, unsafe_hash=True)
@final
class Ty:
    """Concrete IR type: a `BaseTy` together with a register length."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """Check whether `base_ty` and `reg_len` form a valid `Ty`.

        Returns a message describing the problem, or None when valid.
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        out_of_range = reg_len < 1 or reg_len > base_ty.max_reg_len
        return "reg_len out of range" if out_of_range else None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        """Raises ValueError if `Ty.validate` rejects the arguments."""
        error = Ty.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len
97
98
@unique
@final
class LocKind(Enum):
    """Kind of storage location that values can be allocated to."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The `BaseTy` of values held in locations of this kind."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of distinct locations of this kind.

        StackI64 has a fixed 1024 slots; every other kind has as many
        locations as its base type's `max_reg_len`.
        """
        if self is LocKind.StackI64:
            return 1024
        if (self is LocKind.GPR or self is LocKind.CA
                or self is LocKind.VL_MAXVL):
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
132
133
@final
@unique
class LocSubKind(Enum):
    """Sub-kind of a `LocKind`; `kind` maps each member to its `LocKind`.

    NOTE(review): the GPR sub-kinds presumably distinguish which registers
    each SVP64 EXTRA2/EXTRA3 operand encoding can reach -- confirm once
    `allocatable_locs` is implemented.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind belongs to."""
        # identity chains instead of `in`, since pyright fails typechecking
        # when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        if (self is LocSubKind.BASE_GPR
                or self is LocSubKind.SV_EXTRA2_VGPR
                or self is LocSubKind.SV_EXTRA2_SGPR
                or self is LocSubKind.SV_EXTRA3_VGPR
                or self is LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """Shorthand for `self.kind.base_ty`."""
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Return the locations of type `ty` available in this sub-kind.

        Raises ValueError if `ty` has a different base type. Not implemented
        yet: otherwise always raises NotImplementedError.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        raise NotImplementedError  # FIXME: finish
174
175
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """A type whose register length is not fixed yet: either scalar, or a
    vector whose length is chosen by `instantiate`."""
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        """Raises ValueError if `is_vec` is set for an only-scalar type."""
        self.base_ty = base_ty
        forbidden_vector = base_ty.only_scalar and is_vec
        if forbidden_vector:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete `Ty`: `maxvl` long for vectors, else 1."""
        # here's where subvl and elwid would be accounted for
        return Ty(self.base_ty, maxvl if self.is_vec else 1)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some `instantiate` call could produce `ty`."""
        if self.base_ty != ty.base_ty:
            return False
        # vectors accept any length; scalars only a length of 1
        return True if self.is_vec else ty.reg_len == 1
202
203
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """A contiguous run of `reg_len` locations of one `LocKind`, covering
    indexes `start` up to (not including) `stop`."""
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return an error message if the arguments don't describe a valid
        `Loc`, otherwise return None."""
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Return the described `Loc`, or None if the arguments are
        invalid (instead of raising like the constructor)."""
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """Raises ValueError if `Loc.validate` rejects the arguments."""
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if `self` and `other` share at least one location: they are
        of the same kind and their [start, stop) ranges overlap."""
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """One past the last location index covered."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Concatenate `self` with `others`, in order.

        Returns None if any of `others` is None, has a different kind, does
        not start exactly where the previous piece stops, or if the combined
        length isn't a valid `Loc` (e.g. longer than the base type allows).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than the constructor: the combined length can
        # exceed the base type's max_reg_len (e.g. long StackI64 runs)
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)
270
271
@plain_data(frozen=True, eq=False, repr=False)
@final
class LocSet(AbstractSet[Loc]):
    """Immutable set of `Loc`s that all share one `Ty`.

    Stored as a map from each used `LocKind` to a bit set of start indexes;
    every member has length `ty.reg_len`. `ty` is None when the set was
    built from no locations.
    """
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """Build from an iterable of `Loc`s, or copy another LocSet.

        Raises ValueError if the locations don't all have the same `ty`.
        """
        if isinstance(__locs, LocSet):
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only keep kinds that actually have members
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Per-kind bit sets of the members' stop (one-past-end) indexes."""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return every `Loc` formed by concatenating one member of `self`
        with one member of each of `others`, in order, where each piece
        starts exactly where the previous one stops.

        Returns an empty LocSet if any operand is empty, the base types
        differ, or no such concatenation exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # snapshot the keys since kinds are deleted inside the loop
            for kind in list(starts):
                if kind not in other.starts:
                    # nothing of this kind in `other` can follow
                    del starts[kind]
                    continue
                # keep start `s` only if `other` has a member starting at
                # `s + reg_len`, i.e. right where the prefix so far stops
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if len(starts[kind]) == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # drops results that aren't valid `Loc`s (e.g. too long)
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        """True if `loc` is a `Loc` with this set's `ty` that is a member."""
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    @cached_property
    def __len(self):
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        return self.__len

    @cached_property
    def __hash(self):
        return super()._hash()

    def __hash__(self):
        return self.__hash
383
384
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """generic Op operand descriptor

    Describes one operand of an `OpKind` before `maxvl` is known;
    `instantiate` produces the concrete `OperandDesc`(s).
    """
    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
    ):
        # type: (...) -> None
        """
        ty: the operand's generic type.
        sub_kinds: location sub-kinds the operand may be allocated in; must be
            non-empty, and each must have `ty`'s base type.
        fixed_loc: if not None, the single `Loc` the operand must use; needs
            exactly one sub-kind and excludes tying and spreading.
        tied_input_index: if not None, a non-negative input index this
            operand is tied to; excludes fixing and spreading.
        spread: if True, the operand must be scalar and `instantiate`
            repeats it `maxvl` times; excludes tying and fixing.

        Raises ValueError on any invalid combination.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) < 1:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # exactly one sub-kind is left at this point
            (only_sub_kind,) = self.sub_kinds
            if fixed_loc not in only_sub_kind.allocatable_locs(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={only_sub_kind}")
        mismatched = next(
            (sk for sk in self.sub_kinds if sk.base_ty != ty.base_ty), None)
        if mismatched is not None:
            raise ValueError(f"sub_kind is incompatible with type: "
                             f"sub_kind={mismatched} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if not spread:
            return
        if self.tied_input_index is not None:
            raise ValueError("operand can't be both spread and tied")
        if self.fixed_loc is not None:
            raise ValueError("operand can't be both spread and fixed")
        if self.ty.is_vec:
            raise ValueError("operand can't be both spread and vector")

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Return a descriptor with the same `ty` and `sub_kinds`, tied to
        input `tied_input_index` (not fixed, not spread)."""
        return GenericOperandDesc(ty=self.ty, sub_kinds=self.sub_kinds,
                                  tied_input_index=tied_input_index)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Return a descriptor with the same `ty` and `sub_kinds`, fixed to
        `fixed_loc` (not tied, not spread)."""
        return GenericOperandDesc(ty=self.ty, sub_kinds=self.sub_kinds,
                                  fixed_loc=fixed_loc)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Yield the concrete `OperandDesc`s for an op with this `maxvl`.

        A spread operand yields `maxvl` scalar descriptors with
        `spread_index` 0..maxvl-1; any other operand yields exactly one
        with `spread_index=None`. When iterated, raises ValueError if
        `fixed_loc`'s type differs from the instantiated type.
        """
        if self.spread:
            rep_count, elem_maxvl = maxvl, 1
        else:
            rep_count, elem_maxvl = 1, maxvl
        ty = self.ty.instantiate(maxvl=elem_maxvl)

        def candidate_locs():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty)
                return
            if ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
            yield self.fixed_loc

        loc_set = LocSet(candidate_locs())
        for rep in range(rep_count):
            yield OperandDesc(loc_set=loc_set,
                              tied_input_index=self.tied_input_index,
                              spread_index=rep if self.spread else None)
471
472
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor, instantiated for a specific `maxvl`."""
    __slots__ = "loc_set", "tied_input_index", "spread_index"

    def __init__(self, loc_set, tied_input_index, spread_index):
        # type: (LocSet, int | None, int | None) -> None
        """
        loc_set: locations the operand may be allocated to; must be
            non-empty.
        tied_input_index: index of the input this operand is tied to, or
            None.
        spread_index: index of this operand within a spread operand's
            repetitions, or None if not spread.

        Raises ValueError if `loc_set` is empty or the operand is both tied
        and spread.
        """
        if len(loc_set) == 0:
            raise ValueError("loc_set must not be empty")
        self.loc_set = loc_set
        self.tied_input_index = tied_input_index
        # must be assigned before the check below reads it
        self.spread_index = spread_index
        if self.tied_input_index is not None and self.spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
487
488
# Shared generic operand descriptors, each allowing exactly one location
# sub-kind. The *_VGPR ones use a vector GenericTy; the others are scalar.
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.BASE_GPR,))
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA3_SGPR,))
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA3_VGPR,))
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA2_SGPR,))
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA2_VGPR,))
OD_CA = GenericOperandDesc(
    GenericTy(BaseTy.CA, is_vec=False), (LocSubKind.CA,))
OD_VL = GenericOperandDesc(
    GenericTy(BaseTy.VL_MAXVL, is_vec=False), (LocSubKind.VL_MAXVL,))
510
511
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """Properties shared by every `Op` of one `OpKind`, independent of
    `maxvl`."""
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(self, demo_asm,  # type: str
                 inputs,  # type: Iterable[GenericOperandDesc]
                 outputs,  # type: Iterable[GenericOperandDesc]
                 immediates=(),  # type: Iterable[range]
                 is_copy=False,  # type: bool
                 is_load_immediate=False,  # type: bool
                 has_side_effects=False,  # type: bool
                 ):
        # type: (...) -> None
        """
        Raises ValueError if an input is tied, an output's tied_input_index
        is not a valid input index, or two outputs' fixed_locs conflict.
        """
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        # only outputs may be tied (to an input)
        for desc in self.inputs:
            if desc.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {desc}")
        self.outputs = tuple(outputs)
        input_count = len(self.inputs)
        seen_fixed = []  # type: list[tuple[int, Loc]]
        for out_idx, out in enumerate(self.outputs):
            tied = out.tied_input_index
            if tied is not None and tied >= input_count:
                raise ValueError(f"tied_input_index out of range: {out}")
            fixed = out.fixed_loc
            if fixed is None:
                continue
            for prev_idx, prev_fixed in seen_fixed:
                if prev_fixed.conflicts(fixed):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{out_idx}] and "
                        f"outputs[{prev_idx}]: {fixed} conflicts "
                        f"with {prev_fixed}")
            seen_fixed.append((out_idx, fixed))
        self.immediates = tuple(immediates)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
552
553
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """An `OpKind`'s generic properties instantiated for one `maxvl`."""
    __slots__ = "kind", "inputs", "outputs"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind
        # each generic operand can expand to several concrete ones
        # (spread operands yield one per repetition)
        self.inputs = tuple(
            desc
            for generic_desc in self.generic.inputs
            for desc in generic_desc.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_desc in self.generic.outputs
            for desc in generic_desc.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The `GenericOpProperties` of `kind`."""
        return self.kind.properties

    # the remaining properties forward to `generic`

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects
600
601
@unique
@final
class OpKind(Enum):
    """Every kind of `Op`. Each member's value is that kind's
    `GenericOpProperties`, also available as `properties`.

    Operand lists are ordered: `tied_to_input(i)` refers to `inputs[i]`.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        return self.__properties

    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=(OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL),
        outputs=(OD_EXTRA3_VGPR, OD_CA),
    )
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=(OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL),
        outputs=(OD_EXTRA3_VGPR, OD_CA),
    )
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        # one input per register operand read in demo_asm -- *RA (vector),
        # RB and RC (scalars) -- plus VL; the scalar output is tied to RC,
        # which is inputs[2]
        inputs=(OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL),
        outputs=(OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)),
    )
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=(OD_VL,),
        immediates=(range(1, 65),),
        is_load_immediate=True,
    )
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=(OD_VL,),
        outputs=(OD_EXTRA3_VGPR,),
        immediates=(range(-2 ** 15, 2 ** 15),),
        is_load_immediate=True,
    )
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=(OD_BASE_SGPR,),
        immediates=(range(-2 ** 15, 2 ** 15),),
        is_load_immediate=True,
    )
    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=(LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64),
        ), OD_VL),
        outputs=(OD_EXTRA3_VGPR,),
        is_copy=True,
    )
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=(OD_EXTRA3_VGPR, OD_VL),
        outputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=(LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64),
        ),),
        is_copy=True,
    )
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64),
        ),),
        outputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR),
        ),),
        is_copy=True,
    )
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR),
        ),),
        outputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64),
        ),),
        is_copy=True,
    )
    # Concat gathers `maxvl` scalar (spread) inputs into one vector output;
    # Spread does the reverse
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,),
            spread=True,
        ), OD_VL),
        outputs=(OD_EXTRA3_VGPR,),
        is_copy=True,
    )
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=(OD_EXTRA3_VGPR, OD_VL),
        outputs=(GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,),
            spread=True,
        ),),
        is_copy=True,
    )
716
717
718 # FIXME: rewrite from here
719
720
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """An SSA value: a concatenation of slices of `Op` outputs.

    `sliced_op_outputs` is a tuple of `(op, output_index, range_)` entries
    with step-1, non-empty ranges; adjacent entries that continue the same
    op output are merged. `+` concatenates, and indexing/iteration work
    per element.

    NOTE(review): several methods call
    `op.properties.outputs[i].get_length(op.maxvl)`, but `op.properties`
    yields `GenericOpProperties`, whose outputs are `GenericOperandDesc`s,
    which define no `get_length` in this file. Until this (rewrite-pending)
    section is reworked, building an SSAVal from `(op, index...)` tuples or
    calling `repr` is expected to raise AttributeError.
    """
    __slots__ = "sliced_op_outputs",

    # accepted element forms: (op, output_index, int | range | slice),
    # (op, output_index) meaning the whole output, or another SSAVal
    _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
                              "tuple[Op, int]", "SSAVal"]

    @staticmethod
    def __process_sliced_op_outputs(inp):
        # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
        """Normalize each input element into `(op, output_index, range_)`.

        Raises ValueError for an invalid output_index or a slice step other
        than 1; out-of-range int indexes/slices raise via `range` indexing.
        Empty ranges are skipped.
        """
        for v in inp:
            if isinstance(v, SSAVal):
                yield from v.sliced_op_outputs
                continue
            op = v[0]
            output_index = v[1]
            if output_index < 0 or output_index >= len(op.properties.outputs):
                raise ValueError("invalid output_index")
            # NOTE(review): get_length is not defined on GenericOperandDesc
            cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
            slice_ = slice(None) if len(v) == 2 else v[2]
            if isinstance(slice_, range):
                slice_ = slice(slice_.start, slice_.stop, slice_.step)
            if isinstance(slice_, int):
                # raise exception for out-of-range values
                idx = range(cur_len)[slice_]
                range_ = range(idx, idx + 1)
            else:
                # raise exception for out-of-range values
                range_ = range(cur_len)[slice_]
            if range_.step != 1:
                raise ValueError("slice step must be 1")
            if len(range_) == 0:
                continue
            yield op, output_index, range_

    def __init__(self, sliced_op_outputs):
        # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
        # NOTE(review): the next comment is stale -- there is no length
        # argument, and the local `length` below is computed but never used
        # we have length arg so plain_data.replace works
        if isinstance(sliced_op_outputs, SSAVal):
            inp = sliced_op_outputs.sliced_op_outputs
        else:
            inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
        processed = []  # type: list[tuple[Op, int, range]]
        length = 0
        for op, output_index, range_ in inp:
            length += len(range_)
            if len(processed) == 0:
                processed.append((op, output_index, range_))
                continue
            last_op, last_output_index, last_range_ = processed[-1]
            if last_op == op and last_output_index == output_index \
                    and last_range_.stop == range_.start:
                # merge slices: this piece continues the previous one
                range_ = range(last_range_.start, range_.stop)
                processed[-1] = op, output_index, range_
            else:
                processed.append((op, output_index, range_))
        self.sliced_op_outputs = tuple(processed)

    def __add__(self, other):
        # type: (SSAVal | Any) -> SSAVal
        """Concatenate two SSAVals."""
        if not isinstance(other, SSAVal):
            return NotImplemented
        return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)

    def __radd__(self, other):
        # type: (SSAVal | Any) -> SSAVal
        if isinstance(other, SSAVal):
            return other.__add__(self)
        return NotImplemented

    @cached_property
    def expanded_sliced_op_outputs(self):
        # type: () -> tuple[tuple[Op, int, int], ...]
        """One `(op, output_index, element_index)` entry per element."""
        retval = []  # type: list[tuple[Op, int, int]]
        for op, output_index, range_ in self.sliced_op_outputs:
            for i in range_:
                retval.append((op, output_index, i))
        # must be tuple to not be modifiable since it's cached
        return tuple(retval)

    def __getitem__(self, idx):
        # type: (int | slice) -> SSAVal
        """Return the element (int) or sub-range (slice) as a new SSAVal."""
        if isinstance(idx, int):
            return SSAVal([self.expanded_sliced_op_outputs[idx]])
        return SSAVal(self.expanded_sliced_op_outputs[idx])

    def __len__(self):
        return len(self.expanded_sliced_op_outputs)

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        """Yield each element as a length-1 SSAVal."""
        for v in self.expanded_sliced_op_outputs:
            yield SSAVal([v])

    def __repr__(self):
        # type: () -> str
        # renders as `<name#index>` pieces joined by " + ", with a
        # `[start:stop]` suffix on pieces that don't cover a whole output
        if len(self.sliced_op_outputs) == 0:
            return "SSAVal([])"
        parts = []  # type: list[str]
        for op, output_index, range_ in self.sliced_op_outputs:
            out_len = op.properties.outputs[output_index].get_length(op.maxvl)
            parts.append(f"<{op.name}#{output_index}>")
            if range_ != range(out_len):
                parts[-1] += f"[{range_.start}:{range_.stop}]"
        return " + ".join(parts)
828
829
@plain_data(frozen=True, eq=False)
@final
class Op:
    """An operation in a `Fn`. Ops compare and hash by identity."""
    __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"

    def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
        # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
        self.fn = fn
        self.kind = kind
        self.inputs = list(inputs)
        self.immediates = list(immediates)
        self.maxvl = maxvl
        # one SSAVal per output, each covering that whole output
        self.outputs = tuple(
            SSAVal([(self, out_idx)])
            for out_idx in range(len(self.properties.outputs)))
        # assigned last: `_add_op_with_unused_name` checks `op.fn` and
        # rejects ops that already have a `name`
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def properties(self):
        """The `GenericOpProperties` of this op's kind."""
        return self.kind.properties

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        return object.__hash__(self)