03b623dafbf5f80ebdc96e487088552a42ffe040
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 import enum
2 from abc import ABCMeta, abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache
5 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
6 Sequence, TypeVar, overload)
7 from weakref import WeakValueDictionary as _WeakVDict
8
9 from cached_property import cached_property
10 from nmutil.plain_data import fields, plain_data
11
12 from bigint_presentation_code.type_util import Self, assert_never, final
13 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
14
15
@final
class Fn:
    """A function: an ordered list of `Op`s plus the bookkeeping that keeps
    the names of its `Op`s unique."""

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]
        # name -> Op; weak values, so an entry vanishes with its Op
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix tried when a requested name is empty or taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under `name` -- or under `name` followed by a numeric
        suffix when `name` is empty or already used -- and return the name
        that was chosen.

        Raises ValueError if `op` belongs to a different Fn or already has a
        `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            # the suffix counter is shared by every name in this Fn
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op`, which must have been created for this Fn."""
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an Op of `kind` instantiated for `maxvl`, append it to this
        Fn and return it."""
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties, input_vals=input_vals,
                    immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Run `pre_ra_sim` of every Op, in order, against `state`."""
        for op in self.ops:
            op.pre_ra_sim(state)

    def __append_copy(self, val, vec_kind, scalar_kind, prefix):
        # type: (SSAVal, OpKind, OpKind, str) -> SSAVal
        """Append Ops copying the I64 value `val`; return the copy's output.

        A single-register value is copied with `scalar_kind`. A vector is
        copied with `vec_kind`, fed by a SetVLI whose immediate is the
        vector's reg_len. The new Ops are named `<prefix>.setvl` and
        `<prefix>.copy`.
        """
        reg_len = val.ty.reg_len
        if reg_len == 1:
            copy_op = self.append_new_op(
                scalar_kind, input_vals=[val], name=f"{prefix}.copy")
        else:
            setvl_op = self.append_new_op(
                OpKind.SetVLI, immediates=[reg_len], name=f"{prefix}.setvl")
            copy_op = self.append_new_op(
                vec_kind, input_vals=[val, setvl_op.outputs[0]],
                maxvl=reg_len, name=f"{prefix}.copy")
        return copy_op.outputs[0]

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `self.ops` with explicit copies around every I64 value.

        Before each Op, each I64 input is copied (CopyToReg / VecCopyToReg)
        and the Op is rewired to read the copy. After each Op, each I64
        output is copied (CopyFromReg / VecCopyFromReg), and later users are
        rewired to read that copy. CA and VL_MAXVL values are not copied.
        `op.input_vals` is modified in place.

        Assumes every input was produced by an earlier Op in `self.ops`;
        otherwise a KeyError is raised.
        """
        old_ops = list(self.ops)
        self.ops.clear()
        # maps each original output to the value later users should read
        replacement = {}  # type: dict[SSAVal, SSAVal]
        for op in old_ops:
            for idx in range(len(op.input_vals)):
                src = replacement[op.input_vals[idx]]
                base_ty = src.ty.base_ty
                if base_ty is BaseTy.I64:
                    op.input_vals[idx] = self.__append_copy(
                        src, OpKind.VecCopyToReg, OpKind.CopyToReg,
                        f"{op.name}.inp{idx}")
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    op.input_vals[idx] = src
                else:
                    assert_never(base_ty)
            self.ops.append(op)
            for idx, dest in enumerate(op.outputs):
                base_ty = dest.ty.base_ty
                if base_ty is BaseTy.I64:
                    replacement[dest] = self.__append_copy(
                        dest, OpKind.VecCopyFromReg, OpKind.CopyFromReg,
                        f"{op.name}.out{idx}")
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    replacement[dest] = dest
                else:
                    assert_never(base_ty)
112
113
@plain_data(frozen=True, eq=False)
@final
class FnWithUses:
    """A `Fn` together with the uses of every value its Ops produce.

    `uses` maps each output SSAVal of an Op in `fn.ops` to an OFSet of the
    `SSAUse(op, input_index)` entries that read it. Equality and hashing
    consider only `fn`.
    """
    __slots__ = "fn", "uses"

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        uses_of = {}  # type: dict[SSAVal, OSet[SSAUse]]
        for op in fn.ops:
            # an input must already have been registered as the output of an
            # earlier Op (KeyError otherwise); outputs are registered after
            for inp_idx, inp in enumerate(op.input_vals):
                uses_of[inp].add(SSAUse(op, inp_idx))
            uses_of.update((out, OSet()) for out in op.outputs)
        self.uses = FMap(
            (val, OFSet(val_uses)) for val, val_uses in uses_of.items())

    def __eq__(self, other):
        # type: (FnWithUses | Any) -> bool
        if not isinstance(other, FnWithUses):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)
139
140
@unique
@final
class BaseTy(Enum):
    """The base type of a value, independent of how many registers it uses."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if values of this type can never be vectors."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest allowed `Ty.reg_len` for this base type."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
170
171
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty:
    """A concrete type: a BaseTy plus the number of registers (`reg_len`)
    a value of that type occupies."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """Check whether `base_ty` and `reg_len` can form a Ty.

        Returns a message describing the problem, or None when valid.
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        in_range = 1 <= reg_len <= base_ty.max_reg_len
        return None if in_range else "reg_len out of range"

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        problem = Ty.validate(base_ty=base_ty, reg_len=reg_len)
        if problem is not None:
            raise ValueError(problem)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        count = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{count}>"
204
205
@unique
@final
class LocKind(Enum):
    """The storage a Loc lives in; each kind provides `loc_count` slots."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The BaseTy of values stored in this kind of location."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of slots of this kind (valid starts are 0..loc_count-1)."""
        if self is LocKind.StackI64:
            return 1024
        if (self is LocKind.GPR or self is LocKind.CA
                or self is LocKind.VL_MAXVL):
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
239
240
@final
@unique
class LocSubKind(Enum):
    """A subset of the Locs of a LocKind (`kind`) that an operand may be
    allocated to; `allocatable_locs` gives the exact set for a given Ty."""
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The LocKind this sub-kind is a subset of."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
                or self is LocSubKind.SV_EXTRA2_SGPR \
                or self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    # The key space is bounded: at most 6 I64 sub-kinds x 128 reg_lens plus
    # the CA and VL_MAXVL entries (~770 keys; mismatched types raise and are
    # not cached). The previous default `maxsize=128` kept evicting entries
    # and rebuilding LocSets (the StackI64 case scans 1024 starts), so cache
    # everything instead.
    @lru_cache(maxsize=None)
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Return the LocSet of every Loc of type `ty` this sub-kind allows.

        Locs overlapping any of SPECIAL_GPRS are excluded.

        Raises ValueError if `ty.base_ty` doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # single-slot kinds: only-scalar types always have reg_len 1
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        else:
            assert_never(self)
        retval = []  # type: list[Loc]
        for start in starts:
            # skip starts where a Loc of this length doesn't fit
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            retval.append(loc)
        return LocSet(retval)

    def __repr__(self):
        return "LocSubKind." + self._name_
311
312
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """A maxvl-independent type: a scalar of `base_ty`, or (when `is_vec`)
    a vector of `base_ty` whose length is fixed by `instantiate`."""
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete Ty: `maxvl` registers for a vector, else 1."""
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if `ty` has this base type and a reg_len this generic type
        can produce (any reg_len for a vector, only 1 for a scalar)."""
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
339
340
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """A concrete location: `reg_len` consecutive slots of `kind`, starting
    at slot `start`."""
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return a message describing why (kind, start, reg_len) is not a
        valid Loc, or None if it is valid."""
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like the constructor, but return None instead of raising."""
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if `self` and `other` share at least one slot."""
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """One past the last slot used by this Loc."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Return the Loc formed by `self` followed by each of `others`.

        Return None if any of `others` is None, of a different kind, not
        immediately following the previous piece, or if the combined
        length can't form a valid Loc.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the combined reg_len can exceed the base type's max_reg_len even
        # though every piece fits (e.g. StackI64 has 1024 slots but I64's
        # max_reg_len is 128); report that as None instead of letting
        # Loc.__init__ raise ValueError
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)
407
408
# Single GPRs r0, r1, r2 and r13: LocSubKind.allocatable_locs skips any
# candidate Loc that overlaps one of these.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg, reg_len=1) for reg in (0, 1, 2, 13))
415
416
@plain_data(frozen=True, eq=False)
@final
class LocSet(AbstractSet[Loc]):
    """An immutable set of Locs that all share one Ty (`ty`, None when the
    set is empty), stored as a bitset of start slots per LocKind."""
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        if isinstance(__locs, LocSet):
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only keep kinds that actually have Locs
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Per kind, the bitset of stop slots (start + reg_len) of the Locs."""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return the LocSet of every Loc that is a Loc from `self` followed,
        in order, by a Loc from each of `others` (same kind, each piece
        starting where the previous one stops).

        Returns an empty LocSet if any operand is empty, the base types
        differ, or no combination fits.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        # candidate starts of the combined Loc, per kind
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None or other.ty.base_ty != base_ty:
                return LocSet()
            # a combined Loc starting at `s` needs `other` to contain a Loc
            # starting at `s + reg_len`, so shift other's starts down
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no Locs of this kind at all
                    del starts[kind]
                    continue
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # drop starts whose combined Loc doesn't fit
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    # Cache names deliberately avoid the `__name` form: `cached_property`
    # stores the value under `func.__name__` (unmangled), while
    # `self.__name` looks up the mangled `_LocSet__name`, so the cache
    # would never be hit.
    @cached_property
    def _cached_len(self):
        # type: () -> int
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        # type: () -> int
        return self._cached_len

    @cached_property
    def _cached_hash(self):
        # type: () -> int
        # Call Set._hash explicitly rather than via zero-argument super(),
        # which relies on the `__class__` cell.
        # NOTE(review): this matters if the class decorator re-creates the
        # class; confirm against nmutil.plain_data.
        import collections.abc
        return collections.abc.Set._hash(self)

    def __hash__(self):
        # type: () -> int
        return self._cached_hash
528
529
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """generic Op operand descriptor

    Independent of `maxvl`; `instantiate(maxvl)` yields the concrete
    `OperandDesc`s. A `spread` operand instantiates to `maxvl` separate
    single-register operands that together form one vector.
    """
    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
    ):
        # type: (...) -> None
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Copy of this descriptor tied to input `tied_input_index`
        (`fixed_loc` and `spread` are not carried over)."""
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Copy of this descriptor fixed to `fixed_loc`
        (`tied_input_index` and `spread` are not carried over)."""
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Yield the concrete OperandDesc(s) for `maxvl`.

        A normal operand yields one OperandDesc. A spread operand yields
        `maxvl` of them (spread_index 0..maxvl-1). All of them share one
        `loc_set_before_spread`: the Locs the whole, unspread value may
        occupy.
        """
        if self.spread:
            # Each spread operand is one element of an unspread vector of
            # `maxvl` registers. The loc set must describe that whole vector:
            # `OperandDesc.ty` narrows each element back to reg_len=1, and
            # `reg_offset_in_unspread` indexes into it (assumes all spread
            # operands have ty.reg_len = 1).
            rep_count = maxvl
            ty_before_spread = Ty(base_ty=self.ty.base_ty, reg_len=maxvl)
        else:
            rep_count = 1
            ty_before_spread = self.ty.instantiate(maxvl=maxvl)

        def locs():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is not None:
                if ty_before_spread != self.fixed_loc.ty:
                    raise ValueError(
                        f"instantiation failed: type mismatch with fixed_loc: "
                        f"instantiated type: {ty_before_spread} "
                        f"fixed_loc: {self.fixed_loc}")
                yield self.fixed_loc
                return
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(ty_before_spread)
        loc_set_before_spread = LocSet(locs())
        for idx in range(rep_count):
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=idx if self.spread else None)
617
618
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor

    `loc_set_before_spread` holds the Locs the whole (unspread) value may
    occupy. For a spread operand, `spread_index` says which single-register
    element of that value this operand is.
    """
    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
        # type: (LocSet, int | None, int | None) -> None
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # Validate the arguments themselves. `self.spread_index` hasn't been
        # assigned yet here, so reading it would raise AttributeError for
        # every tied operand.
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """Ty of the whole value before any spread is applied."""
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            # assumes all spread operands have ty.reg_len = 1
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is None:
            return 0
        return self.spread_index * self.ty.reg_len
664
665
# Commonly used generic operand descriptors: a scalar or vector I64
# restricted to a single LocSubKind, plus the CA and VL_MAXVL operands.
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.BASE_GPR,))
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA3_SGPR,))
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA3_VGPR,))
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA2_SGPR,))
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA2_VGPR,))
OD_CA = GenericOperandDesc(
    GenericTy(BaseTy.CA, is_vec=False), (LocSubKind.CA,))
OD_VL = GenericOperandDesc(
    GenericTy(BaseTy.VL_MAXVL, is_vec=False), (LocSubKind.VL_MAXVL,))
687
688
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """The maxvl-independent description of an Op kind.

    Checked on construction:
    * no input is tied (`tied_input_index` only makes sense on outputs);
    * a tied output refers to an existing input and equals that input's
      descriptor re-tied to the same index;
    * no two outputs have conflicting `fixed_loc`s.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        tied_inputs = [inp for inp in self.inputs
                       if inp.tied_input_index is not None]
        if tied_inputs:
            raise ValueError(
                f"tied_input_index is not allowed on inputs: {tied_inputs[0]}")
        self.outputs = tuple(outputs)
        claimed = []  # type: list[tuple[Loc, int]]
        for out_idx, out in enumerate(self.outputs):
            tied_idx = out.tied_input_index
            if tied_idx is not None:
                if tied_idx >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_idx]
                if tied_inp.tied_to_input(tied_idx) != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            fixed = out.fixed_loc
            if fixed is None:
                continue
            for prev_fixed, prev_idx in claimed:
                if prev_fixed.conflicts(fixed):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{out_idx}] and "
                        f"outputs[{prev_idx}]: {fixed} conflicts "
                        f"with {prev_fixed}")
            claimed.append((fixed, out_idx))
        self.immediates = tuple(immediates)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
734
735
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """The properties of an `OpKind` instantiated for a specific `maxvl`.

    `inputs`/`outputs` hold concrete OperandDescs (a spread generic operand
    contributes several); the remaining properties forward to
    `kind.properties`.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind
        # `self.generic` reads `self.kind`, so kind is assigned first
        generic = self.generic
        self.inputs = tuple(
            desc
            for generic_inp in generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl))
        self.outputs = tuple(
            desc
            for generic_out in generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl))
        self.maxvl = maxvl

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects
783
784
# signed 16-bit immediate range: -32768 <= imm <= 32767
IMM_S16 = range(-2 ** 15, 2 ** 15)

# a pre-RA simulation function: (op, state) -> None
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# zero-argument factory returning a _PRE_RA_SIM_FN; OpKind registers
# lambdas so its staticmethods are looked up only after the class exists
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# GenericOpProperties of each OpKind -> its simulation-function factory
_PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
790
791
792 @unique
793 @final
794 class OpKind(Enum):
795 def __init__(self, properties):
796 # type: (GenericOpProperties) -> None
797 super().__init__()
798 self.__properties = properties
799
800 @property
801 def properties(self):
802 # type: () -> GenericOpProperties
803 return self.__properties
804
805 def instantiate(self, maxvl):
806 # type: (int) -> OpProperties
807 return OpProperties(self, maxvl=maxvl)
808
809 def __repr__(self):
810 return "OpKind." + self._name_
811
812 @cached_property
813 def pre_ra_sim(self):
814 # type: () -> _PRE_RA_SIM_FN
815 return _PRE_RA_SIMS[self.properties]()
816
817 @staticmethod
818 def __clearca_pre_ra_sim(op, state):
819 # type: (Op, PreRASimState) -> None
820 state.ssa_vals[op.outputs[0]] = False,
821 ClearCA = GenericOpProperties(
822 demo_asm="addic 0, 0, 0",
823 inputs=[],
824 outputs=[OD_CA],
825 )
826 _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
827
828 @staticmethod
829 def __setca_pre_ra_sim(op, state):
830 # type: (Op, PreRASimState) -> None
831 state.ssa_vals[op.outputs[0]] = True,
832 SetCA = GenericOpProperties(
833 demo_asm="subfc 0, 0, 0",
834 inputs=[],
835 outputs=[OD_CA],
836 )
837 _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
838
839 @staticmethod
840 def __svadde_pre_ra_sim(op, state):
841 # type: (Op, PreRASimState) -> None
842 RA = state.ssa_vals[op.input_vals[0]]
843 RB = state.ssa_vals[op.input_vals[1]]
844 carry, = state.ssa_vals[op.input_vals[2]]
845 VL, = state.ssa_vals[op.input_vals[3]]
846 RT = [] # type: list[int]
847 for i in range(VL):
848 v = RA[i] + RB[i] + carry
849 RT.append(v & GPR_VALUE_MASK)
850 carry = (v >> GPR_SIZE_IN_BITS) != 0
851 state.ssa_vals[op.outputs[0]] = tuple(RT)
852 state.ssa_vals[op.outputs[1]] = carry,
853 SvAddE = GenericOpProperties(
854 demo_asm="sv.adde *RT, *RA, *RB",
855 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
856 outputs=[OD_EXTRA3_VGPR, OD_CA],
857 )
858 _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
859
860 @staticmethod
861 def __svsubfe_pre_ra_sim(op, state):
862 # type: (Op, PreRASimState) -> None
863 RA = state.ssa_vals[op.input_vals[0]]
864 RB = state.ssa_vals[op.input_vals[1]]
865 carry, = state.ssa_vals[op.input_vals[2]]
866 VL, = state.ssa_vals[op.input_vals[3]]
867 RT = [] # type: list[int]
868 for i in range(VL):
869 v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
870 RT.append(v & GPR_VALUE_MASK)
871 carry = (v >> GPR_SIZE_IN_BITS) != 0
872 state.ssa_vals[op.outputs[0]] = tuple(RT)
873 state.ssa_vals[op.outputs[1]] = carry,
874 SvSubFE = GenericOpProperties(
875 demo_asm="sv.subfe *RT, *RA, *RB",
876 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
877 outputs=[OD_EXTRA3_VGPR, OD_CA],
878 )
879 _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
880
881 @staticmethod
882 def __svmaddedu_pre_ra_sim(op, state):
883 # type: (Op, PreRASimState) -> None
884 RA = state.ssa_vals[op.input_vals[0]]
885 RB, = state.ssa_vals[op.input_vals[1]]
886 carry, = state.ssa_vals[op.input_vals[2]]
887 VL, = state.ssa_vals[op.input_vals[3]]
888 RT = [] # type: list[int]
889 for i in range(VL):
890 v = RA[i] * RB + carry
891 RT.append(v & GPR_VALUE_MASK)
892 carry = v >> GPR_SIZE_IN_BITS
893 state.ssa_vals[op.outputs[0]] = tuple(RT)
894 state.ssa_vals[op.outputs[1]] = carry,
895 SvMAddEDU = GenericOpProperties(
896 demo_asm="sv.maddedu *RT, *RA, RB, RC",
897 inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
898 outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
899 )
900 _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
901
902 @staticmethod
903 def __setvli_pre_ra_sim(op, state):
904 # type: (Op, PreRASimState) -> None
905 state.ssa_vals[op.outputs[0]] = op.immediates[0],
906 SetVLI = GenericOpProperties(
907 demo_asm="setvl 0, 0, imm, 0, 1, 1",
908 inputs=(),
909 outputs=[OD_VL],
910 immediates=[range(1, 65)],
911 is_load_immediate=True,
912 )
913 _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
914
915 @staticmethod
916 def __svli_pre_ra_sim(op, state):
917 # type: (Op, PreRASimState) -> None
918 VL, = state.ssa_vals[op.input_vals[0]]
919 imm = op.immediates[0] & GPR_VALUE_MASK
920 state.ssa_vals[op.outputs[0]] = (imm,) * VL
921 SvLI = GenericOpProperties(
922 demo_asm="sv.addi *RT, 0, imm",
923 inputs=[OD_VL],
924 outputs=[OD_EXTRA3_VGPR],
925 immediates=[IMM_S16],
926 is_load_immediate=True,
927 )
928 _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
929
930 @staticmethod
931 def __li_pre_ra_sim(op, state):
932 # type: (Op, PreRASimState) -> None
933 imm = op.immediates[0] & GPR_VALUE_MASK
934 state.ssa_vals[op.outputs[0]] = imm,
935 LI = GenericOpProperties(
936 demo_asm="addi RT, 0, imm",
937 inputs=(),
938 outputs=[OD_BASE_SGPR],
939 immediates=[IMM_S16],
940 is_load_immediate=True,
941 )
942 _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
943
944 @staticmethod
945 def __veccopytoreg_pre_ra_sim(op, state):
946 # type: (Op, PreRASimState) -> None
947 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
948 VecCopyToReg = GenericOpProperties(
949 demo_asm="sv.mv dest, src",
950 inputs=[GenericOperandDesc(
951 ty=GenericTy(BaseTy.I64, is_vec=True),
952 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
953 ), OD_VL],
954 outputs=[OD_EXTRA3_VGPR],
955 is_copy=True,
956 )
957 _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
958
959 @staticmethod
960 def __veccopyfromreg_pre_ra_sim(op, state):
961 # type: (Op, PreRASimState) -> None
962 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
963 VecCopyFromReg = GenericOpProperties(
964 demo_asm="sv.mv dest, src",
965 inputs=[OD_EXTRA3_VGPR, OD_VL],
966 outputs=[GenericOperandDesc(
967 ty=GenericTy(BaseTy.I64, is_vec=True),
968 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
969 )],
970 is_copy=True,
971 )
972 _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
973
974 @staticmethod
975 def __copytoreg_pre_ra_sim(op, state):
976 # type: (Op, PreRASimState) -> None
977 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
978 CopyToReg = GenericOpProperties(
979 demo_asm="mv dest, src",
980 inputs=[GenericOperandDesc(
981 ty=GenericTy(BaseTy.I64, is_vec=False),
982 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
983 LocSubKind.StackI64],
984 )],
985 outputs=[GenericOperandDesc(
986 ty=GenericTy(BaseTy.I64, is_vec=False),
987 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
988 )],
989 is_copy=True,
990 )
991 _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
992
993 @staticmethod
994 def __copyfromreg_pre_ra_sim(op, state):
995 # type: (Op, PreRASimState) -> None
996 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
997 CopyFromReg = GenericOpProperties(
998 demo_asm="mv dest, src",
999 inputs=[GenericOperandDesc(
1000 ty=GenericTy(BaseTy.I64, is_vec=False),
1001 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
1002 )],
1003 outputs=[GenericOperandDesc(
1004 ty=GenericTy(BaseTy.I64, is_vec=False),
1005 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
1006 LocSubKind.StackI64],
1007 )],
1008 is_copy=True,
1009 )
1010 _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
1011
1012 @staticmethod
1013 def __concat_pre_ra_sim(op, state):
1014 # type: (Op, PreRASimState) -> None
1015 state.ssa_vals[op.outputs[0]] = tuple(
1016 state.ssa_vals[i][0] for i in op.input_vals[:-1])
1017 Concat = GenericOpProperties(
1018 demo_asm="sv.mv dest, src",
1019 inputs=[GenericOperandDesc(
1020 ty=GenericTy(BaseTy.I64, is_vec=False),
1021 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
1022 spread=True,
1023 ), OD_VL],
1024 outputs=[OD_EXTRA3_VGPR],
1025 is_copy=True,
1026 )
1027 _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
1028
1029 @staticmethod
1030 def __spread_pre_ra_sim(op, state):
1031 # type: (Op, PreRASimState) -> None
1032 for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
1033 state.ssa_vals[op.outputs[idx]] = inp,
1034 Spread = GenericOpProperties(
1035 demo_asm="sv.mv dest, src",
1036 inputs=[OD_EXTRA3_VGPR, OD_VL],
1037 outputs=[GenericOperandDesc(
1038 ty=GenericTy(BaseTy.I64, is_vec=False),
1039 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
1040 spread=True,
1041 )],
1042 is_copy=True,
1043 )
1044 _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
1045
@staticmethod
def __svld_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate `sv.ld`: load VL consecutive GPR-sized words starting at
    address RA + immediate into the vector output
    """
    base_reg, = state.ssa_vals[op.input_vals[0]]
    vl, = state.ssa_vals[op.input_vals[1]]
    start = base_reg + op.immediates[0]
    state.ssa_vals[op.outputs[0]] = tuple(
        state.load(start + GPR_SIZE_IN_BYTES * i) & GPR_VALUE_MASK
        for i in range(vl))
SvLd = GenericOpProperties(
    demo_asm="sv.ld *RT, imm(RA)",
    inputs=[OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
)
# the lambda defers looking up `OpKind` until after the class body finishes
_PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
1064
@staticmethod
def __ld_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate scalar `ld`: load one GPR-sized word from RA + immediate """
    base_reg, = state.ssa_vals[op.input_vals[0]]
    loaded = state.load(base_reg + op.immediates[0])
    state.ssa_vals[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
Ld = GenericOpProperties(
    demo_asm="ld RT, imm(RA)",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR],
    immediates=[IMM_S16],
)
# the lambda defers looking up `OpKind` until after the class body finishes
_PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
1079
@staticmethod
def __svstd_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate `sv.std`: store elements 0..VL-1 of the vector RS to
    consecutive GPR-sized words starting at address RA + immediate
    """
    src_elements = state.ssa_vals[op.input_vals[0]]
    base_reg, = state.ssa_vals[op.input_vals[1]]
    vl, = state.ssa_vals[op.input_vals[2]]
    start = base_reg + op.immediates[0]
    for element_idx in range(vl):
        state.store(start + GPR_SIZE_IN_BYTES * element_idx,
                    value=src_elements[element_idx])
SvStd = GenericOpProperties(
    demo_asm="sv.std *RS, imm(RA)",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
# the lambda defers looking up `OpKind` until after the class body finishes
_PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
1097
@staticmethod
def __std_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate scalar `std`: store RS to the word at address RA + immediate

    Inputs are (RS, RA), each a one-element value.
    """
    RS, = state.ssa_vals[op.input_vals[0]]
    RA, = state.ssa_vals[op.input_vals[1]]
    addr = RA + op.immediates[0]
    state.store(addr, value=RS)
Std = GenericOpProperties(
    # `std` stores its *source* register, named RS (as in the Power ISA
    # form `std RS,DS(RA)` and in `SvStd`'s "sv.std *RS, imm(RA)")
    demo_asm="std RS, imm(RA)",
    inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
# the lambda defers looking up `OpKind` until after the class body finishes
_PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
1113
@staticmethod
def __funcargr3_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ intentionally does nothing: this op's output value is set in
    `state.ssa_vals` before simulation starts (`Op.pre_ra_sim` permits a
    pre-assigned output for this kind)
    """
FuncArgR3 = GenericOpProperties(
    demo_asm="",
    inputs=[],
    # the single output is pinned to GPR 3
    outputs=[
        OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1)),
    ],
)
# the lambda defers looking up `OpKind` until after the class body finishes
_PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
1125
1126
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=ABCMeta):
    """ common base of `SSAVal` and `SSAUse`: names operand number
    `operand_idx` of `op`, indexing into the subclass-provided
    `descriptor_array` of operand descriptors.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        # `op` is assigned first because subclasses' `descriptor_array`
        # reads `self.op`
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the operand descriptors that `operand_idx` indexes into """
        ...

    @property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the descriptor of this particular operand """
        return self.descriptor_array[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.defining_descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` minus the descriptor's `spread_index` (a None
        `spread_index` counts as 0) -- presumably the index of the first
        operand of the spread group containing this operand
        """
        return self.operand_idx - (self.defining_descriptor.spread_index or 0)

    @property
    def unspread_start(self):
        # type: () -> Self
        """ an instance of the same class naming operand `unspread_start_idx`
        of the same op
        """
        return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
1189
1190
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ an output of an `Op`: operand `operand_idx` of
    `op.properties.outputs`
    """
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the `SSAUse` of the same op selected by this output's
        `tied_input_index`, or None when that index is None
        """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1217
1218
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ an input of an `Op`: operand `operand_idx` of
    `op.properties.inputs`
    """
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently stored at this position of
        `op.input_vals`; assigning writes through to `op.input_vals`
        """
        return self.op.input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, new_val):
        # type: (SSAVal) -> None
        self.op.input_vals[self.operand_idx] = new_val

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1247
1248
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of an `Op`'s inputs, validated on every write.

    There is exactly one item per descriptor returned by
    `_get_descriptors`, and the length never changes after construction.
    Subclasses implement `_get_descriptors` and `_verify_write_with_desc`.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at `idx` (described by `desc`) """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ validate writing `item` at `idx`, returning `idx` normalized to a
        non-negative index.

        Raises TypeError for slice or other non-int indexes, IndexError for
        out-of-range indexes, plus whatever `_verify_write_with_desc` raises.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range maps negative indexes the same way list indexing
        # does, and raises IndexError when out of range
        normalized_idx = range(len(self.descriptors))[idx]
        desc = self.descriptors[normalized_idx]
        self._verify_write_with_desc(normalized_idx, item, desc)
        return normalized_idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ subclass hook; a no-op here.

        NOTE(review): neither `__init__` nor `__setitem__` of this class
        calls `_on_set`.
        """
        pass

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ one descriptor per item; the result is cached by `descriptors` """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @property
    @final
    def op(self):
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        # `op` is stored before `descriptors` is first read, since
        # subclasses' `_get_descriptors` look at `self.op`
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, value in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, value)
            self.__items.append(value)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for value in self.__items:
            yield value

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        # slicing returns a plain list copy, not another OpInputSeq
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        self.__items[self._verify_write(idx, item)] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.__items}, op=...)"
1338
1339
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the input `SSAVal`s of an `Op`, one per entry of
    `op.properties.inputs`; each must have exactly its descriptor's type
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        # NOTE(review): `OpInputSeq` never calls `_on_set`, and `SSAUses`
        # is not defined alongside this class -- confirm before relying on
        # this hook.
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` keeps this sequence in its `input_vals` slot (there is no
        # `inputs` attribute on `Op`, so checking that name could never
        # fire); refuse to build a second one for the same op.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1363
1364
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the integer immediates of an `Op`; entry i must lie in the `range`
    given by `op.properties.immediates[i]`
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # `Op` keeps this sequence in its `immediates` slot; refuse to build
        # a second one for the same op
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        return self.op.properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
1383
1384
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation inside a `Fn`.

    Holds its `properties` (exposed as `kind`), its input `SSAVal`s
    (`input_vals`, with one matching `SSAUse` per input in `input_uses`),
    its integer `immediates`, the `SSAVal`s it defines (`outputs`), and a
    `name` picked by `fn`. Ops compare and hash by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        self.input_uses = tuple(
            SSAUse(self, use_idx)
            for use_idx in range(len(properties.inputs)))
        self.immediates = OpImmediates(immediates, op=self)
        self.outputs = tuple(
            SSAVal(self, out_idx)
            for out_idx in range(len(properties.outputs)))
        # `fn` returns a name not used by its other ops; an empty or taken
        # `name` gets a numeric suffix appended
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        parts = []  # type: list[str]
        for field_name in fields(self):
            if field_name == "fn":
                continue  # omitted; `Fn`'s repr is just "<Fn>"
            # show the op's kind instead of its whole properties object
            shown = "kind" if field_name == "properties" else field_name
            try:
                value = getattr(self, shown)
            except AttributeError:
                parts.append(f"{shown}=<not set>")
                continue
            if isinstance(value, OpInputSeq):
                value = list(value)  # type: ignore
            parts.append(f"{shown}={value!r}")
        return "Op(" + ", ".join(parts) + ")"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """ simulate this op on `state`, before register allocation.

        Checks that each input already has a value of the right length and
        that no output has a value yet (outputs of `OpKind.FuncArgR3` are
        allowed to be pre-assigned). It then runs the kind's simulator and
        checks that every output received a value of the right length.
        Violations raise ValueError.
        """
        def check_len(ssa_val):
            # type: (SSAVal) -> None
            value = state.ssa_vals[ssa_val]
            expected_len = ssa_val.ty.reg_len
            if len(value) != expected_len:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {expected_len} found "
                    f"{len(value)}: {value!r}")

        for inp in self.input_vals:
            if inp not in state.ssa_vals:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            check_len(inp)
        for out in self.outputs:
            if out in state.ssa_vals and self.kind is not OpKind.FuncArgR3:
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")
        self.kind.pre_ra_sim(self, state)
        for out in self.outputs:
            if out not in state.ssa_vals:
                raise ValueError(f"running {self} failed to assign to {out}")
            check_len(out)
1461
1462
# width of one general-purpose register
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask covering exactly the bits of one GPR
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
1467
1468
@plain_data(frozen=True, repr=False)
@final
class PreRASimState:
    """ state used when simulating ops before register allocation.

    `ssa_vals` maps each assigned `SSAVal` to the tuple of integers it
    holds; `memory` is a sparse byte store mapping address -> byte, where
    addresses never written read as 0.
    """
    __slots__ = "ssa_vals", "memory"

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        self.ssa_vals = ssa_vals
        self.memory = memory

    def load_byte(self, addr):
        # type: (int) -> int
        """ read the byte at `addr` (masked with GPR_VALUE_MASK); unwritten
        bytes read as 0
        """
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value` to `addr` (masked with
        GPR_VALUE_MASK)
        """
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian `size_in_bytes`-byte integer at `addr`,
        sign-extending it when `signed` is true.

        Raises ValueError if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        for i in range(size_in_bytes):
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value` to `addr`,
        little-endian.

        Raises ValueError if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ format `memory` for `__repr__`.

        A run of GPR_SIZE_IN_BYTES present bytes starting at an aligned
        address is shown as one word in `<...>`; any other byte is shown
        individually.
        """
        if len(self.memory) == 0:
            return "{}"
        # sorted descending so the lowest remaining address is at the end
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def _ssa_vals__repr(self):
        # type: () -> str
        """ format `ssa_vals` for `__repr__`.

        Elements are shown in hex. Values with more than CHUNK_SIZE elements
        are split over several lines, CHUNK_SIZE elements per line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    element_strs.append("\n " + hex(el))
            # short values stay on one line. The `0 <` guard makes an empty
            # value format as "()" instead of raising IndexError on
            # `element_strs[0]`.
            if 0 < len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # trailing comma, like the repr of a 1-tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __repr__(self):
        # type: () -> str
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            # use a field-specific formatter `_<field>__repr` when defined
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"PreRASimState({field_vals_str})"