7109e776d23ce08fc70d1ec2e44b01cde4e77d44
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 import enum
2 from abc import abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache
5 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
6 Sequence, TypeVar, overload)
7 from weakref import WeakValueDictionary as _WeakVDict
8
9 from cached_property import cached_property
10 from nmutil.plain_data import fields, plain_data
11
12 from bigint_presentation_code.type_util import Self, assert_never, final
13 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
14
15
@final
class Fn:
    """A function: an ordered list of `Op`s plus a registry of `Op` names."""

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak values: a name stops being reserved once its Op is collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Reserve a name that is unused in this `Fn` for `op` and return it.

        `name` is used as-is when it is non-empty and free; otherwise a
        numeric suffix taken from a per-`Fn` counter is appended until a
        free name is found.

        Raises `ValueError` if `op` belongs to a different `Fn` or already
        has a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """Append an existing `op` to `self.ops`.

        Raises `ValueError` if `op.fn` is not this `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, inputs=(), immediates=(), name="", maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an `Op` of `kind` instantiated for `maxvl`, append it, and
        return it.
        """
        new_op = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    inputs=inputs, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Run every op's `pre_ra_sim` on `state`, in list order."""
        for op in self.ops:
            op.pre_ra_sim(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `self.ops` with copy ops around every I64 value.

        Each I64 input is read through a new `CopyToReg`/`VecCopyToReg` and
        each I64 output is followed by a new `CopyFromReg`/`VecCopyFromReg`;
        vector copies are preceded by a `SetVLI` supplying their VL. CA and
        VL_MAXVL values are passed through unchanged.
        """
        old_ops = list(self.ops)
        # maps each original output to the value later ops must read instead
        replacements = {}  # type: dict[SSAVal, SSAVal]
        self.ops.clear()
        for op in old_ops:
            for idx in range(len(op.inputs)):
                src = replacements[op.inputs[idx]]
                src_base_ty = src.ty.base_ty
                if src_base_ty is BaseTy.I64:
                    reg_len = src.ty.reg_len
                    if reg_len == 1:
                        copy = self.append_new_op(
                            OpKind.CopyToReg, inputs=[src])
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len])
                        copy = self.append_new_op(
                            OpKind.VecCopyToReg,
                            inputs=[src, setvl.outputs[0]], maxvl=reg_len)
                    op.inputs[idx] = copy.outputs[0]
                elif src_base_ty is BaseTy.CA \
                        or src_base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    op.inputs[idx] = src
                else:
                    assert_never(src_base_ty)
            self.ops.append(op)
            for out in op.outputs:
                out_base_ty = out.ty.base_ty
                if out_base_ty is BaseTy.I64:
                    reg_len = out.ty.reg_len
                    if reg_len == 1:
                        copy = self.append_new_op(
                            OpKind.CopyFromReg, inputs=[out])
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len])
                        copy = self.append_new_op(
                            OpKind.VecCopyFromReg,
                            inputs=[out, setvl.outputs[0]], maxvl=reg_len)
                    replacements[out] = copy.outputs[0]
                elif out_base_ty is BaseTy.CA \
                        or out_base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    replacements[out] = out
                else:
                    assert_never(out_base_ty)
104
105
@unique
@final
class BaseTy(Enum):
    """Base value types, before any register count is attached."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if this base type can only have a `reg_len` of 1."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest `reg_len` a `Ty` of this base type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
135
136
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty:
    """Concrete value type: a `BaseTy` together with a register count."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        """Raises `ValueError` when `validate` reports a problem."""
        problem = self.validate(base_ty=base_ty, reg_len=reg_len)
        if problem is not None:
            raise ValueError(problem)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars print as `<I64>`, vectors as `<I64*4>`
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
169
170
@unique
@final
class LocKind(Enum):
    """Kinds of storage location a value can be assigned to."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The `BaseTy` of values stored in this kind of location."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of distinct start indexes available for this kind."""
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 1024
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
204
205
@final
@unique
class LocSubKind(Enum):
    """Sub-divisions of `LocKind`, each with its own set of usable starts."""
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind belongs to."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
                or self is LocSubKind.SV_EXTRA2_SGPR \
                or self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        """`BaseTy` of the underlying `LocKind`."""
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Return every `Loc` of type `ty` that this sub-kind may use.

        Locations that fail `Loc.try_make` or that overlap any entry of
        `SPECIAL_GPRS` are left out. Raises `ValueError` when `ty` has a
        different base type than this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # a single location at index 0
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            candidate_starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            candidate_starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            candidate_starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            candidate_starts = range(128)
        elif self is LocSubKind.StackI64:
            candidate_starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        usable = []  # type: list[Loc]
        for start in candidate_starts:
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            usable.append(loc)
        return LocSet(usable)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
276
277
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """A type whose register count is not fixed yet: base type + is-vector."""
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        """Raises `ValueError` for a vector of an only-scalar base type."""
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.base_ty = base_ty
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete `Ty`: `maxvl` registers for a vector, else 1."""
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some `maxvl` makes `instantiate` produce a type with
        `ty`'s base type and register count.
        """
        if self.base_ty != ty.base_ty:
            return False
        # vectors accept any length, scalars only a single register
        return True if self.is_vec else ty.reg_len == 1
304
305
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """A contiguous run of `reg_len` locations of one `LocKind`, beginning
    at index `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return an error message if the arguments don't describe a valid
        `Loc`, otherwise return None.
        """
        ty_problem = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if ty_problem is not None:
            return ty_problem
        limit = kind.loc_count
        if reg_len > limit:
            return "invalid reg_len"
        if start < 0 or start + reg_len > limit:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like the constructor, but return None instead of raising."""
        if Loc.validate(kind=kind, start=start, reg_len=reg_len) is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """Raises `ValueError` when `validate` reports a problem."""
        problem = self.validate(kind=kind, start=start, reg_len=reg_len)
        if problem is not None:
            raise ValueError(problem)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if both are the same kind and their index ranges overlap."""
        if self.kind != other.kind:
            return False
        return self.start < other.stop and other.start < self.stop

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """Return the `Ty` of a `Loc` with the given kind and length."""
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        """The `Ty` of values held in this location."""
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """One past the last index covered: `start + reg_len`."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Join `self` and `others` into one `Loc`.

        Returns None unless every entry of `others` is a `Loc` of the same
        kind that starts exactly where the previous one stops.
        """
        total_len = self.reg_len
        expected_start = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if other.start != expected_start:
                return None
            expected_start = other.stop
            total_len += other.reg_len
        return Loc(kind=self.kind, start=self.start, reg_len=total_len)
372
373
# Single-register GPR locations that `LocSubKind.allocatable_locs` never
# hands out.
# NOTE(review): presumably these are registers reserved by the ABI -- confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=gpr_index, reg_len=1)
    for gpr_index in (0, 1, 2, 13)
)
380
381
@plain_data(frozen=True, eq=False)
@final
class LocSet(AbstractSet[Loc]):
    """Immutable set of `Loc`s that all have the same `Ty`.

    Stored as `starts`, a map from `LocKind` to the bit-set of start
    indexes present for that kind, together with the shared `ty` (None for
    an empty set).
    """
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """Build the set from `__locs`.

        Raises `ValueError` if the locations don't all have the same `Ty`.
        """
        if isinstance(__locs, LocSet):
            # share the other set's fields instead of rebuilding them
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # kinds without any location are left out of the map
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Per kind, the bit-set of stop indexes (`start + reg_len`)."""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        # shifting the start bits left by reg_len adds reg_len to each index
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        """The `LocKind`s that have at least one location in this set."""
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        """`reg_len` of the shared type, or None for an empty set."""
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        """`base_ty` of the shared type, or None for an empty set."""
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return the set of locations formed by a location from `self`
        immediately followed by one from each of `others`, in order.

        The result is empty if any operand is empty or has a different
        base type than `self`.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot since kinds are deleted along the way
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no location of this kind, so nothing of
                    # this kind can be continued
                    del starts[kind]
                    continue
                # keep a start only if `other` has a location beginning
                # exactly `reg_len` indexes after it
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # skip combinations that aren't valid locations
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    @cached_property
    def __len(self):
        # total number of start indexes over all kinds
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        return self.__len

    @cached_property
    def __hash(self):
        # `eq=False` is passed to `plain_data`, so hashing uses the
        # `_hash` helper inherited from the abstract set base class
        return super()._hash()

    def __hash__(self):
        return self.__hash
493
494
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """generic Op operand descriptor"""
    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
    ):
        # type: (...) -> None
        """Validate and store the descriptor.

        Raises `ValueError` when `sub_kinds` is empty or doesn't match
        `ty`, when `fixed_loc` is incompatible with `ty` or `sub_kinds`,
        when `tied_input_index` is negative, or when more than one of
        fixed / tied / spread is requested (spread also excludes vectors).
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # exactly one sub_kind at this point
            only_sub_kind, = self.sub_kinds
            allowed = only_sub_kind.allocatable_locs(fixed_loc.ty)
            if fixed_loc not in allowed:
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={only_sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if not spread:
            return
        if self.tied_input_index is not None:
            raise ValueError("operand can't be both spread and tied")
        if self.fixed_loc is not None:
            raise ValueError("operand can't be both spread and fixed")
        if self.ty.is_vec:
            raise ValueError("operand can't be both spread and vector")

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Return a new descriptor with this `ty` and `sub_kinds`, tied to
        the input at `tied_input_index`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Return a new descriptor with this `ty` and `sub_kinds`, fixed to
        `fixed_loc`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Yield the concrete `OperandDesc`s for the given `maxvl`.

        A spread operand yields `maxvl` descriptors with `spread_index`
        0, 1, ...; any other operand yields one descriptor whose
        `spread_index` is None.
        """
        if self.spread:
            copies = maxvl
            maxvl = 1
        else:
            copies = 1
        ty = self.ty.instantiate(maxvl=maxvl)

        def locs():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty)
                return
            if ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
            yield self.fixed_loc

        shared_loc_set = LocSet(locs())
        for idx in range(copies):
            yield OperandDesc(loc_set_before_spread=shared_loc_set,
                              tied_input_index=self.tied_input_index,
                              spread_index=idx if self.spread else None)
581
582
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor"""
    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
        # type: (LocSet, int | None, int | None) -> None
        """Validate and store the descriptor.

        Raises `ValueError` if `loc_set_before_spread` is empty or if both
        `tied_input_index` and `spread_index` are given.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # check the arguments, not `self.spread_index`: that attribute is
        # not assigned yet, so reading it here raises AttributeError for
        # every tied operand
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """The `Ty` shared by all locations in `loc_set_before_spread`."""
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            # a spread operand is a single register of the same base type
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread
614
615
# Shared generic operand descriptors used when declaring `OpKind` members.

# scalar I64 operands, one per GPR sub-kind
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.BASE_GPR,),
)
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.SV_EXTRA3_SGPR,),
)
# vector I64 operand
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,),
)
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.SV_EXTRA2_SGPR,),
)
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=(LocSubKind.SV_EXTRA2_VGPR,),
)
# the CA operand
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=(LocSubKind.CA,),
)
# the VL_MAXVL operand
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=(LocSubKind.VL_MAXVL,),
)
637
638
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """Properties of an op kind that don't depend on `maxvl`: example
    assembly, generic operand descriptors, immediate ranges, and flags.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Validate and store the properties.

        Raises `ValueError` if an input is tied, if an output is tied to a
        missing or non-equivalent input, or if two outputs have
        overlapping fixed locations.
        """
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)
        # fixed locations of the outputs checked so far, with their index
        seen_fixed = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied_idx = out.tied_input_index
            if tied_idx is not None:
                if tied_idx >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_idx]
                # the output must be exactly the tied form of that input
                if tied_inp.tied_to_input(tied_idx) != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            for other_fixed_loc, other_idx in seen_fixed:
                if other_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
            seen_fixed.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
684
685
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """An `OpKind` instantiated for one `maxvl`: concrete input and output
    `OperandDesc`s; everything else is read from the kind's generic
    properties.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind
        generic = kind.properties
        # one generic operand can expand to several concrete operands
        self.inputs = tuple(
            desc
            for generic_inp in generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl))
        self.outputs = tuple(
            desc
            for generic_out in generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl))
        self.maxvl = maxvl

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The `maxvl`-independent properties of `self.kind`."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Allowed range of each immediate."""
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """`demo_asm` of the generic properties."""
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """`is_copy` flag of the generic properties."""
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """`is_load_immediate` flag of the generic properties."""
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """`has_side_effects` flag of the generic properties."""
        return self.generic.has_side_effects
733
734
# allowed values of a signed 16-bit immediate: -0x8000 ..= 0x7FFF
IMM_S16 = range(-0x8000, 0x8000)

# signature of a pre-register-allocation simulation function
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# zero-argument factory returning such a function
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# simulation-function factories, keyed by an op kind's generic properties
_PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
740
741
742 @unique
743 @final
744 class OpKind(Enum):
745 def __init__(self, properties):
746 # type: (GenericOpProperties) -> None
747 super().__init__()
748 self.__properties = properties
749
750 @property
751 def properties(self):
752 # type: () -> GenericOpProperties
753 return self.__properties
754
755 def instantiate(self, maxvl):
756 # type: (int) -> OpProperties
757 return OpProperties(self, maxvl=maxvl)
758
759 def __repr__(self):
760 return "OpKind." + self._name_
761
762 @cached_property
763 def pre_ra_sim(self):
764 # type: () -> _PRE_RA_SIM_FN
765 return _PRE_RA_SIMS[self.properties]()
766
767 @staticmethod
768 def __clearca_pre_ra_sim(op, state):
769 # type: (Op, PreRASimState) -> None
770 state.ssa_vals[op.outputs[0]] = False,
771 ClearCA = GenericOpProperties(
772 demo_asm="addic 0, 0, 0",
773 inputs=[],
774 outputs=[OD_CA],
775 )
776 _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
777
778 @staticmethod
779 def __setca_pre_ra_sim(op, state):
780 # type: (Op, PreRASimState) -> None
781 state.ssa_vals[op.outputs[0]] = True,
782 SetCA = GenericOpProperties(
783 demo_asm="subfc 0, 0, 0",
784 inputs=[],
785 outputs=[OD_CA],
786 )
787 _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
788
789 @staticmethod
790 def __svadde_pre_ra_sim(op, state):
791 # type: (Op, PreRASimState) -> None
792 RA = state.ssa_vals[op.inputs[0]]
793 RB = state.ssa_vals[op.inputs[1]]
794 carry, = state.ssa_vals[op.inputs[2]]
795 VL, = state.ssa_vals[op.inputs[3]]
796 RT = [] # type: list[int]
797 for i in range(VL):
798 v = RA[i] + RB[i] + carry
799 RT.append(v & GPR_VALUE_MASK)
800 carry = (v >> GPR_SIZE_IN_BITS) != 0
801 state.ssa_vals[op.outputs[0]] = tuple(RT)
802 state.ssa_vals[op.outputs[1]] = carry,
803 SvAddE = GenericOpProperties(
804 demo_asm="sv.adde *RT, *RA, *RB",
805 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
806 outputs=[OD_EXTRA3_VGPR, OD_CA],
807 )
808 _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
809
810 @staticmethod
811 def __svsubfe_pre_ra_sim(op, state):
812 # type: (Op, PreRASimState) -> None
813 RA = state.ssa_vals[op.inputs[0]]
814 RB = state.ssa_vals[op.inputs[1]]
815 carry, = state.ssa_vals[op.inputs[2]]
816 VL, = state.ssa_vals[op.inputs[3]]
817 RT = [] # type: list[int]
818 for i in range(VL):
819 v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
820 RT.append(v & GPR_VALUE_MASK)
821 carry = (v >> GPR_SIZE_IN_BITS) != 0
822 state.ssa_vals[op.outputs[0]] = tuple(RT)
823 state.ssa_vals[op.outputs[1]] = carry,
824 SvSubFE = GenericOpProperties(
825 demo_asm="sv.subfe *RT, *RA, *RB",
826 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
827 outputs=[OD_EXTRA3_VGPR, OD_CA],
828 )
829 _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
830
831 @staticmethod
832 def __svmaddedu_pre_ra_sim(op, state):
833 # type: (Op, PreRASimState) -> None
834 RA = state.ssa_vals[op.inputs[0]]
835 RB, = state.ssa_vals[op.inputs[1]]
836 carry, = state.ssa_vals[op.inputs[2]]
837 VL, = state.ssa_vals[op.inputs[3]]
838 RT = [] # type: list[int]
839 for i in range(VL):
840 v = RA[i] * RB + carry
841 RT.append(v & GPR_VALUE_MASK)
842 carry = v >> GPR_SIZE_IN_BITS
843 state.ssa_vals[op.outputs[0]] = tuple(RT)
844 state.ssa_vals[op.outputs[1]] = carry,
845 SvMAddEDU = GenericOpProperties(
846 demo_asm="sv.maddedu *RT, *RA, RB, RC",
847 inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
848 outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
849 )
850 _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
851
852 @staticmethod
853 def __setvli_pre_ra_sim(op, state):
854 # type: (Op, PreRASimState) -> None
855 state.ssa_vals[op.outputs[0]] = op.immediates[0],
856 SetVLI = GenericOpProperties(
857 demo_asm="setvl 0, 0, imm, 0, 1, 1",
858 inputs=(),
859 outputs=[OD_VL],
860 immediates=[range(1, 65)],
861 is_load_immediate=True,
862 )
863 _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
864
865 @staticmethod
866 def __svli_pre_ra_sim(op, state):
867 # type: (Op, PreRASimState) -> None
868 VL, = state.ssa_vals[op.inputs[0]]
869 imm = op.immediates[0] & GPR_VALUE_MASK
870 state.ssa_vals[op.outputs[0]] = (imm,) * VL
871 SvLI = GenericOpProperties(
872 demo_asm="sv.addi *RT, 0, imm",
873 inputs=[OD_VL],
874 outputs=[OD_EXTRA3_VGPR],
875 immediates=[IMM_S16],
876 is_load_immediate=True,
877 )
878 _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
879
880 @staticmethod
881 def __li_pre_ra_sim(op, state):
882 # type: (Op, PreRASimState) -> None
883 imm = op.immediates[0] & GPR_VALUE_MASK
884 state.ssa_vals[op.outputs[0]] = imm,
885 LI = GenericOpProperties(
886 demo_asm="addi RT, 0, imm",
887 inputs=(),
888 outputs=[OD_BASE_SGPR],
889 immediates=[IMM_S16],
890 is_load_immediate=True,
891 )
892 _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
893
894 @staticmethod
895 def __veccopytoreg_pre_ra_sim(op, state):
896 # type: (Op, PreRASimState) -> None
897 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
898 VecCopyToReg = GenericOpProperties(
899 demo_asm="sv.mv dest, src",
900 inputs=[GenericOperandDesc(
901 ty=GenericTy(BaseTy.I64, is_vec=True),
902 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
903 ), OD_VL],
904 outputs=[OD_EXTRA3_VGPR],
905 is_copy=True,
906 )
907 _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
908
909 @staticmethod
910 def __veccopyfromreg_pre_ra_sim(op, state):
911 # type: (Op, PreRASimState) -> None
912 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
913 VecCopyFromReg = GenericOpProperties(
914 demo_asm="sv.mv dest, src",
915 inputs=[OD_EXTRA3_VGPR, OD_VL],
916 outputs=[GenericOperandDesc(
917 ty=GenericTy(BaseTy.I64, is_vec=True),
918 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
919 )],
920 is_copy=True,
921 )
922 _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
923
924 @staticmethod
925 def __copytoreg_pre_ra_sim(op, state):
926 # type: (Op, PreRASimState) -> None
927 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
928 CopyToReg = GenericOpProperties(
929 demo_asm="mv dest, src",
930 inputs=[GenericOperandDesc(
931 ty=GenericTy(BaseTy.I64, is_vec=False),
932 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
933 LocSubKind.StackI64],
934 )],
935 outputs=[GenericOperandDesc(
936 ty=GenericTy(BaseTy.I64, is_vec=False),
937 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
938 )],
939 is_copy=True,
940 )
941 _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
942
943 @staticmethod
944 def __copyfromreg_pre_ra_sim(op, state):
945 # type: (Op, PreRASimState) -> None
946 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
947 CopyFromReg = GenericOpProperties(
948 demo_asm="mv dest, src",
949 inputs=[GenericOperandDesc(
950 ty=GenericTy(BaseTy.I64, is_vec=False),
951 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
952 )],
953 outputs=[GenericOperandDesc(
954 ty=GenericTy(BaseTy.I64, is_vec=False),
955 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
956 LocSubKind.StackI64],
957 )],
958 is_copy=True,
959 )
960 _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
961
962 @staticmethod
963 def __concat_pre_ra_sim(op, state):
964 # type: (Op, PreRASimState) -> None
965 state.ssa_vals[op.outputs[0]] = tuple(
966 state.ssa_vals[i][0] for i in op.inputs[:-1])
967 Concat = GenericOpProperties(
968 demo_asm="sv.mv dest, src",
969 inputs=[GenericOperandDesc(
970 ty=GenericTy(BaseTy.I64, is_vec=False),
971 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
972 spread=True,
973 ), OD_VL],
974 outputs=[OD_EXTRA3_VGPR],
975 is_copy=True,
976 )
977 _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
978
979 @staticmethod
980 def __spread_pre_ra_sim(op, state):
981 # type: (Op, PreRASimState) -> None
982 for idx, inp in enumerate(state.ssa_vals[op.inputs[0]]):
983 state.ssa_vals[op.outputs[idx]] = inp,
984 Spread = GenericOpProperties(
985 demo_asm="sv.mv dest, src",
986 inputs=[OD_EXTRA3_VGPR, OD_VL],
987 outputs=[GenericOperandDesc(
988 ty=GenericTy(BaseTy.I64, is_vec=False),
989 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
990 spread=True,
991 )],
992 is_copy=True,
993 )
994 _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
995
996 @staticmethod
997 def __svld_pre_ra_sim(op, state):
998 # type: (Op, PreRASimState) -> None
999 RA, = state.ssa_vals[op.inputs[0]]
1000 VL, = state.ssa_vals[op.inputs[1]]
1001 addr = RA + op.immediates[0]
1002 RT = [] # type: list[int]
1003 for i in range(VL):
1004 v = state.load(addr + GPR_SIZE_IN_BYTES * i)
1005 RT.append(v & GPR_VALUE_MASK)
1006 state.ssa_vals[op.outputs[0]] = tuple(RT)
1007 SvLd = GenericOpProperties(
1008 demo_asm="sv.ld *RT, imm(RA)",
1009 inputs=[OD_EXTRA3_SGPR, OD_VL],
1010 outputs=[OD_EXTRA3_VGPR],
1011 immediates=[IMM_S16],
1012 )
1013 _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
1014
1015 @staticmethod
1016 def __ld_pre_ra_sim(op, state):
1017 # type: (Op, PreRASimState) -> None
1018 RA, = state.ssa_vals[op.inputs[0]]
1019 addr = RA + op.immediates[0]
1020 v = state.load(addr)
1021 state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
1022 Ld = GenericOpProperties(
1023 demo_asm="ld RT, imm(RA)",
1024 inputs=[OD_BASE_SGPR],
1025 outputs=[OD_BASE_SGPR],
1026 immediates=[IMM_S16],
1027 )
1028 _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
1029
1030 @staticmethod
1031 def __svstd_pre_ra_sim(op, state):
1032 # type: (Op, PreRASimState) -> None
1033 RS = state.ssa_vals[op.inputs[0]]
1034 RA, = state.ssa_vals[op.inputs[1]]
1035 VL, = state.ssa_vals[op.inputs[2]]
1036 addr = RA + op.immediates[0]
1037 for i in range(VL):
1038 state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
# Properties of the vector store op kind (demo assembly:
# `sv.std *RS, imm(RA)`).
# Inputs are listed in the order the simulator reads them: the values to
# store (RS), the base address (RA), then VL.  There are no outputs, so the
# op is marked `has_side_effects=True`; the single immediate is the offset
# that gets added to RA.
# NOTE(review): `has_side_effects` presumably keeps output-less ops from
# being treated as dead code -- confirm where the flag is consumed.
SvStd = GenericOpProperties(
    demo_asm="sv.std *RS, imm(RA)",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
# Register the simulator for this op kind; the lambda postpones the lookup of
# `OpKind.__svstd_pre_ra_sim` until the lambda is actually called.
_PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
1047
@staticmethod
def __std_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Simulate the scalar store: write one GPR-sized word to memory.

    Input 0 is the value to store and input 1 the base address; the word is
    written at base address plus immediate 0.
    """
    (stored_value,) = state.ssa_vals[op.inputs[0]]
    (base_reg,) = state.ssa_vals[op.inputs[1]]
    dest_addr = base_reg + op.immediates[0]
    state.store(dest_addr, value=stored_value)
# Properties of the scalar store op kind (demo assembly: `std RS, imm(RA)`).
# Inputs are listed in the order the simulator reads them: the value to store
# (RS) first, then the base address (RA).  There are no outputs, so the op is
# marked `has_side_effects=True`; the single immediate is the offset that
# gets added to RA.
# The demo assembly names the stored register RS (the store's *source*
# register), matching the simulator's naming and the vector store's
# `sv.std *RS, imm(RA)`; RT is the name used for load *targets*.
Std = GenericOpProperties(
    demo_asm="std RS, imm(RA)",
    inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
# Register the simulator for this op kind; the lambda postpones the lookup of
# `OpKind.__std_pre_ra_sim` until the lambda is actually called.
_PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
1063
@staticmethod
def __funcargr3_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Do nothing: the return value is set before simulation."""
# Properties of the "FuncArgR3" op kind: no inputs, no demo assembly, and a
# single scalar output whose location is pinned with `with_fixed_loc` to a
# GPR `Loc` starting at register 3 with length 1.
# NOTE(review): presumably this models a function argument arriving in r3 --
# confirm against the calling convention the compiler targets.
FuncArgR3 = GenericOpProperties(
    demo_asm="",
    inputs=[],
    outputs=[OD_BASE_SGPR.with_fixed_loc(
        Loc(kind=LocKind.GPR, start=3, reg_len=1))],
)
# Register the simulator for this op kind; the lambda postpones the lookup of
# `OpKind.__funcargr3_pre_ra_sim` until the lambda is actually called.
_PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
1075
1076
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """A single SSA value: output number `output_idx` of the defining `op`."""
    __slots__ = "op", "output_idx"

    def __init__(self, op, output_idx):
        # type: (Op, int) -> None
        """Bind to output `output_idx` of `op`.

        Raises ValueError when `output_idx` is negative or not smaller than
        the number of outputs declared by `op.properties`.
        """
        self.op = op
        if output_idx < 0:
            raise ValueError("invalid output_idx")
        if output_idx >= len(op.properties.outputs):
            raise ValueError("invalid output_idx")
        self.output_idx = output_idx

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}#{self.output_idx}: {self.ty}>"

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """The output descriptor of the defining op for this value."""
        output_descs = self.op.properties.outputs
        return output_descs[self.output_idx]

    @cached_property
    def loc_set_before_spread(self):
        # type: () -> LocSet
        """`loc_set_before_spread` of the defining descriptor."""
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @cached_property
    def ty(self):
        # type: () -> Ty
        """`ty` of the defining descriptor."""
        desc = self.defining_descriptor
        return desc.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """`ty_before_spread` of the defining descriptor."""
        desc = self.defining_descriptor
        return desc.ty_before_spread
1112
1113
# Element type stored in an OpInputSeq (SSAVal for OpInputs, int for
# OpImmediates).
_T = TypeVar("_T")
# Type of the per-element descriptor that written items are validated against
# (OperandDesc for OpInputs, range for OpImmediates).
_Desc = TypeVar("_Desc")
1116
1117
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """Fixed-length sequence of per-`Op` items validated against descriptors.

    The length always equals the number of descriptors returned by
    `_get_descriptors()`.  Individual items may be overwritten by integer
    index; every write is first checked by `_verify_write_with_desc()`.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """Raise if `item` is not acceptable at position `idx` given `desc`."""
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """Validate a write of `item` at `idx`; return the normalized index.

        Raises TypeError for slice or non-int indexes and IndexError for
        out-of-range ones; negative indexes count from the end.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        all_descs = self.descriptors
        # indexing a range normalizes negative indexes and raises IndexError
        # for indexes that are out of range
        normalized = range(len(all_descs))[idx]
        self._verify_write_with_desc(normalized, item, all_descs[normalized])
        return normalized

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """Return the descriptors, one per element of this sequence."""
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @property
    @final
    def op(self):
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """Validate and store `items`, which must match the descriptors 1:1.

        Raises ValueError when there are more or fewer items than
        descriptors.
        """
        self.__op = op
        accepted = []  # type: list[_T]
        self.__items = accepted
        for position, candidate in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, candidate)
            accepted.append(candidate)
        if len(accepted) < len(self.descriptors):
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for element in self.__items:
            yield element

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        stored = self.__items
        return stored[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        normalized = self._verify_write(idx, item)
        self.__items[normalized] = item

    @final
    def __len__(self):
        # type: () -> int
        stored = self.__items
        return len(stored)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.__items}, op=...)"
1199
1200
@final
class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
    """The SSA-value inputs of an `Op`, type-checked against its properties."""

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        properties = self.op.properties
        return properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Require `item` to be an `SSAVal` whose type equals `desc.ty`."""
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty == desc.ty:
            return
        raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                         f"corresponding input's type {desc.ty!r}")

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """Build the inputs of `op`; refuses to run if `op.inputs` exists."""
        already_set = hasattr(op, "inputs")
        if already_set:
            raise ValueError("Op.inputs already set")
        super().__init__(items, op)
1219 super().__init__(items, op)
1220
1221
@final
class OpImmediates(OpInputSeq[int, range]):
    """The integer immediates of an `Op`, each limited to a `range`."""

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """Require `item` to be an `int` that is contained in `desc`."""
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """Build the immediates of `op`; refuses to run if they exist."""
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)
1240
1241
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """One operation of a `Fn`: properties, inputs, immediates and outputs.

    Ops compare and hash by identity.
    """
    __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"

    def __init__(self, fn, properties, inputs, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        """Create the op, validate its operands and obtain a name from `fn`.

        One `SSAVal` is created per output declared in `properties`.
        """
        self.fn = fn
        self.properties = properties
        self.inputs = OpInputs(inputs, op=self)
        self.immediates = OpImmediates(immediates, op=self)
        output_count = len(self.properties.outputs)
        self.outputs = tuple(
            SSAVal(self, output_idx) for output_idx in range(output_count))
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        properties = self.properties
        return properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        parts = []  # type: list[str]
        for field_name in fields(self):
            if field_name == "fn":
                # the owning Fn is deliberately left out of the repr
                continue
            # show the `kind` property in place of the `properties` field
            shown = "kind" if field_name == "properties" else field_name
            try:
                value = getattr(self, shown)
            except AttributeError:
                parts.append(f"{shown}=<not set>")
                continue
            if isinstance(value, OpInputSeq):
                value = list(value)  # type: ignore
            parts.append(f"{shown}={value!r}")
        joined = ", ".join(parts)
        return f"Op({joined})"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate this op on `state`, checking values before and after.

        Raises ValueError when an input is unassigned or has the wrong
        number of elements, when an output is already assigned (permitted
        only for `OpKind.FuncArgR3`), or when the simulation leaves an
        output unassigned or with the wrong number of elements.
        """
        for inp in self.inputs:
            if inp not in state.ssa_vals:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            if len(state.ssa_vals[inp]) == inp.ty.reg_len:
                continue
            raise ValueError(
                f"value of SSAVal {inp} has wrong number of elements: "
                f"expected {inp.ty.reg_len} found "
                f"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
        for out in self.outputs:
            if out not in state.ssa_vals:
                continue
            if self.kind is OpKind.FuncArgR3:
                continue
            raise ValueError(f"SSAVal {out} already assigned before "
                             f"running {self}")
        self.kind.pre_ra_sim(self, state)
        for out in self.outputs:
            if out not in state.ssa_vals:
                raise ValueError(f"running {self} failed to assign to {out}")
            if len(state.ssa_vals[out]) == out.ty.reg_len:
                continue
            raise ValueError(
                f"value of SSAVal {out} has wrong number of elements: "
                f"expected {out.ty.reg_len} found "
                f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1314 f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1315
1316
# Width of one general-purpose register, in bytes and in bits.
BITS_IN_BYTE = 8
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# All-ones mask that keeps exactly one GPR's worth of bits.
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
1321
1322
@plain_data(frozen=True, repr=False)
@final
class PreRASimState:
    """State operated on by the `pre_ra_sim` simulation methods.

    `ssa_vals` maps each `SSAVal` to the tuple of integers it holds and
    `memory` maps byte addresses to byte values; bytes that were never
    stored read back as 0.
    """
    __slots__ = "ssa_vals", "memory"

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        self.ssa_vals = ssa_vals
        self.memory = memory

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at `addr` (wrapped to GPR width), 0 if unset."""
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low 8 bits of `value` at `addr` (wrapped to GPR width)."""
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Load a little-endian integer of `size_in_bytes` bytes from `addr`.

        With `signed=True` the result is interpreted as two's complement.
        Raises ValueError when `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        for i in range(size_in_bytes):
            # byte `i` supplies bits [8 * i, 8 * i + 8) of the result
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            # the top bit is set, so the value is negative
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Store `value` as `size_in_bytes` little-endian bytes at `addr`.

        Raises ValueError when `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """Format `memory`, merging aligned runs of 8 bytes into words.

        A run is shown as `<0x...>` only when all `CHUNK_SIZE` bytes of an
        aligned chunk are present; every other byte is shown on its own.
        """
        if len(self.memory) == 0:
            return "{}"
        # sorted in descending order so the lowest address sits at the end
        # of the list, where it can be popped cheaply
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def _ssa_vals__repr(self):
        # type: () -> str
        """Format `ssa_vals` with hex elements, `CHUNK_SIZE` per line.

        Values with at most `CHUNK_SIZE` elements stay on one line; an empty
        value is shown as `()`.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    element_strs.append("\n " + hex(el))
            # short values fit on one line, so drop the leading line break;
            # the lower bound keeps an empty value from raising IndexError
            if 0 < len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # makes the join emit the trailing comma of a 1-tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __repr__(self):
        # type: () -> str
        """Format every field, preferring a `_<field>__repr` helper if any."""
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"PreRASimState({field_vals_str})"