382d66c457f5dacbe45561f262e464b1b9fa84f6
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 import enum
2 from abc import abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache
5 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
6 Sequence, TypeVar, overload)
7 from weakref import WeakValueDictionary as _WeakVDict
8
9 from cached_property import cached_property
10 from nmutil.plain_data import fields, plain_data
11
12 from bigint_presentation_code.type_util import Self, assert_never, final
13 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
14
15
@final
class Fn:
    """A function: an ordered list of `Op`s plus a registry of op names.

    Op names are unique among the Ops still alive; the registry holds its
    Ops weakly so names of discarded Ops can be reused.
    """

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]
        # weak values: an entry disappears once its Op is no longer referenced
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # one counter shared by every name; used to build unique suffixes
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Reserve and return a name for `op` that no live Op here uses.

        `name` is used as-is when it is non-empty and free; otherwise the
        next value of the suffix counter is appended to `name` until a free
        name is found.

        Raises ValueError if `op` belongs to a different Fn or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op`, which must belong to this Fn, to the end of `ops`."""
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, inputs=(), immediates=(), name="", maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an Op of `kind` instantiated for `maxvl`, append it and
        return it."""
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties, inputs=inputs,
                    immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Run every op's pre-register-allocation simulation, in order."""
        for current in self.ops:
            current.pre_ra_sim(state)

    def _append_i64_copy(self, src, scalar_kind, vector_kind):
        # type: (SSAVal, OpKind, OpKind) -> SSAVal
        """Append ops copying the I64 value `src`; return the copy.

        A scalar (reg_len == 1) uses a single `scalar_kind` op. A vector
        first gets a SetVLI for its reg_len, then a `vector_kind` op that
        takes the resulting VL as its second input.
        """
        reg_len = src.ty.reg_len
        if reg_len == 1:
            copy_op = self.append_new_op(scalar_kind, inputs=[src])
        else:
            set_vl = self.append_new_op(OpKind.SetVLI, immediates=[reg_len])
            copy_op = self.append_new_op(
                vector_kind, inputs=[src, set_vl.outputs[0]], maxvl=reg_len)
        return copy_op.outputs[0]

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `ops` so every I64 operand goes through explicit copies.

        Each I64 output is followed by a CopyFromReg/VecCopyFromReg, and
        each I64 input use reads a fresh CopyToReg/VecCopyToReg of that
        copied value. CA and VL_MAXVL operands are left as they are. Every
        input must be defined by an earlier op in `ops`.
        """
        original_ops = list(self.ops)
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        self.ops.clear()
        for op in original_ops:
            for idx in range(len(op.inputs)):
                src = copied_outputs[op.inputs[idx]]
                base_ty = src.ty.base_ty
                if base_ty is BaseTy.I64:
                    op.inputs[idx] = self._append_i64_copy(
                        src, OpKind.CopyToReg, OpKind.VecCopyToReg)
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    op.inputs[idx] = src
                else:
                    assert_never(base_ty)
            self.ops.append(op)
            for out in op.outputs:
                base_ty = out.ty.base_ty
                if base_ty is BaseTy.I64:
                    copied_outputs[out] = self._append_i64_copy(
                        out, OpKind.CopyFromReg, OpKind.VecCopyFromReg)
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(base_ty)
104
105
@unique
@final
class BaseTy(Enum):
    """Element type of a value, independent of how many registers it spans."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if a value of this type must always have reg_len == 1."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest reg_len a value of this type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
135
136
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty:
    """Concrete type: a BaseTy plus the number of registers (reg_len)."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """Check whether `base_ty` can be used with `reg_len` registers.

        Returns an error message when the combination is invalid, or None
        when it is valid.
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars print as e.g. "<I64>", vectors as e.g. "<I64*4>"
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
169
170
@unique
@final
class LocKind(Enum):
    """The storage space a Loc lives in."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """BaseTy of the values that can be stored in this kind of Loc."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of register-sized slots of this kind.

        Every Loc of this kind must fit inside [0, loc_count).
        """
        if (self is LocKind.GPR or self is LocKind.CA
                or self is LocKind.VL_MAXVL):
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 1024
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
204
205
@final
@unique
class LocSubKind(Enum):
    """Finer classification of Locs: which registers an operand field may
    name (base GPR field, SV EXTRA2/EXTRA3 scalar/vector fields, ...)."""
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The LocKind this sub-kind belongs to."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        if (self is LocSubKind.BASE_GPR
                or self is LocSubKind.SV_EXTRA2_VGPR
                or self is LocSubKind.SV_EXTRA2_SGPR
                or self is LocSubKind.SV_EXTRA3_VGPR
                or self is LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Every Loc of type `ty` that this sub-kind can address.

        GPR Locs that overlap any entry of SPECIAL_GPRS are excluded.
        Raises ValueError if `ty`'s base type doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        elif self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            # only even starting registers are addressable here
            starts = range(0, 128, 2)
        elif (self is LocSubKind.SV_EXTRA3_VGPR
              or self is LocSubKind.SV_EXTRA3_SGPR):
            starts = range(128)
        else:
            assert_never(self)

        def usable(loc):
            # type: (Loc | None) -> bool
            return loc is not None and not any(
                loc.conflicts(special) for special in SPECIAL_GPRS)

        candidates = (
            Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            for start in starts)
        return LocSet(loc for loc in candidates if usable(loc))

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
276
277
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """Type pattern of an operand; becomes a concrete Ty once maxvl is known.

    Vector patterns (is_vec) instantiate with reg_len == maxvl; scalar
    patterns always instantiate with reg_len == 1.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some maxvl makes `instantiate` produce `ty`."""
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
304
305
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """A run of `reg_len` consecutive slots of one LocKind, from `start`."""
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return an error message if the arguments don't form a valid Loc,
        otherwise None."""
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like the constructor, but return None instead of raising."""
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if `self` and `other` share at least one slot."""
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """One past the last slot (exclusive end)."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Join `self` with `others`, which must follow it contiguously.

        Returns None if any of `others` is None, has a different kind, does
        not start exactly where the previous Loc stops, or if the combined
        Loc would be invalid (e.g. longer than the base type's max_reg_len).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the combined reg_len can exceed kind.base_ty.max_reg_len (e.g. for
        # StackI64, whose loc_count is larger than I64's max_reg_len), so use
        # try_make rather than raising from the constructor
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)
372
373
# single-register GPR Locs that LocSubKind.allocatable_locs never hands out.
# NOTE(review): presumably the registers reserved by the Power ABI
# (r0, r1 stack pointer, r2 TOC, r13 thread pointer) -- confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg, reg_len=1) for reg in (0, 1, 2, 13))
380
381
@plain_data(frozen=True, eq=False)
@final
class LocSet(AbstractSet[Loc]):
    """Immutable set of Locs that all share a single Ty.

    Stored as one bitset of start positions per LocKind (`starts`) plus the
    shared `ty`. An empty LocSet has `ty` set to None.
    """
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        if isinstance(__locs, LocSet):
            # copying another LocSet: its fields are immutable, so share them
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only keep kinds that actually have Locs, so `kinds` is accurate
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Per-kind bitsets of stop positions (start + reg_len)."""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return every Loc formed by joining, contiguously and in order, a
        Loc from `self` with one Loc (of the same kind) from each of `others`.

        A combined Loc starting at `s` exists only if `self` has a Loc at
        `s`, the first of `others` has one starting right where that Loc
        stops, and so on. Returns an empty LocSet if any operand is empty or
        has a different base type than `self`; combined Locs that would be
        invalid (e.g. too long) are dropped.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no Locs of this kind, so no combined Loc of
                    # this kind can exist
                    del starts[kind]
                    continue
                # a combined Loc starting at s needs `other` to start at
                # s + reg_len; shift other's starts down to line them up
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if len(starts[kind]) == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    @cached_property
    def __len(self):
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        return self.__len

    @cached_property
    def __hash(self):
        # Set._hash: hash consistent with the set-based equality from
        # AbstractSet (plain_data was told eq=False)
        return super()._hash()

    def __hash__(self):
        return self.__hash
493
494
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """Operand descriptor of a generic Op (before maxvl is chosen).

    `sub_kinds` lists where the operand may be allocated. The operand may
    instead be pinned to one `fixed_loc`, tied to an input by index
    (`tied_input_index`), or `spread` into one scalar operand per vector
    element. These three options are mutually exclusive.
    """
    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
    ):
        # type: (...) -> None
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            self._check_fixed_loc(fixed_loc, tied_input_index)
        bad_kind = next(
            (k for k in self.sub_kinds if k.base_ty != ty.base_ty), None)
        if bad_kind is not None:
            raise ValueError(f"sub_kind is incompatible with type: "
                             f"sub_kind={bad_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            self._check_spread()

    def _check_fixed_loc(self, fixed_loc, tied_input_index):
        # type: (Loc, int | None) -> None
        """Validate a fixed operand (needs `ty` and `sub_kinds` set)."""
        if tied_input_index is not None:
            raise ValueError("operand can't be both tied and fixed")
        if not self.ty.can_instantiate_to(fixed_loc.ty):
            raise ValueError(
                f"fixed_loc has incompatible type for given generic "
                f"type: fixed_loc={fixed_loc} generic ty={self.ty}")
        if len(self.sub_kinds) != 1:
            raise ValueError(
                "multiple sub_kinds not allowed for fixed operand")
        (sub_kind,) = self.sub_kinds
        if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
            raise ValueError(
                f"fixed_loc not in given sub_kind: "
                f"fixed_loc={fixed_loc} sub_kind={sub_kind}")

    def _check_spread(self):
        # type: () -> None
        """Validate a spread operand (needs every other field set)."""
        if self.tied_input_index is not None:
            raise ValueError("operand can't be both spread and tied")
        if self.fixed_loc is not None:
            raise ValueError("operand can't be both spread and fixed")
        if self.ty.is_vec:
            raise ValueError("operand can't be both spread and vector")

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Copy of this descriptor (ty, sub_kinds) tied to input
        `tied_input_index`; fixed_loc and spread are not carried over."""
        return GenericOperandDesc(ty=self.ty, sub_kinds=self.sub_kinds,
                                  tied_input_index=tied_input_index)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Copy of this descriptor (ty, sub_kinds) pinned to `fixed_loc`;
        tied_input_index and spread are not carried over."""
        return GenericOperandDesc(ty=self.ty, sub_kinds=self.sub_kinds,
                                  fixed_loc=fixed_loc)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Yield the concrete OperandDesc(s) of this operand for `maxvl`.

        A non-spread operand yields one descriptor. A spread operand yields
        `maxvl` descriptors with spread_index 0..maxvl-1, all sharing one
        LocSet.
        """
        # NOTE(review): for spread operands maxvl is forced to 1 and spread
        # operands can't be vectors, so the shared LocSet has reg_len 1 and
        # OperandDesc.ty_before_spread equals OperandDesc.ty -- confirm that
        # is intended.
        if self.spread:
            rep_count, maxvl = maxvl, 1
        else:
            rep_count = 1
        ty = self.ty.instantiate(maxvl=maxvl)
        if self.fixed_loc is not None:
            if ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
            candidates = (self.fixed_loc,)  # type: Iterable[Loc]
        else:
            candidates = (
                loc
                for sub_kind in self.sub_kinds
                for loc in sub_kind.allocatable_locs(ty))
        loc_set_before_spread = LocSet(candidates)
        for idx in range(rep_count):
            yield OperandDesc(
                loc_set_before_spread=loc_set_before_spread,
                tied_input_index=self.tied_input_index,
                spread_index=idx if self.spread else None)
581
582
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor, instantiated for a specific maxvl.

    `loc_set_before_spread` holds the candidate Locs. `tied_input_index`, if
    set, is the index of the input this operand is tied to. `spread_index`,
    if set, selects one element of a spread operand. An operand can't be
    both tied and spread.
    """
    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
        # type: (LocSet, int | None, int | None) -> None
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # validate using the arguments: reading `self.spread_index` here
        # would hit a not-yet-assigned slot and raise AttributeError for
        # every tied operand
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """Ty of the Locs in `loc_set_before_spread`."""
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread
614
615
# Reusable generic operand descriptors. The name gives the LocSubKind
# limiting which registers may be used; *_VGPR operands are vectors
# (reg_len == maxvl once instantiated), the others are scalars.
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.BASE_GPR,))
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA3_SGPR,))
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA3_VGPR,))
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), (LocSubKind.SV_EXTRA2_SGPR,))
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), (LocSubKind.SV_EXTRA2_VGPR,))
# carry flag operand
OD_CA = GenericOperandDesc(
    GenericTy(BaseTy.CA, is_vec=False), (LocSubKind.CA,))
# vector length (VL) operand
OD_VL = GenericOperandDesc(
    GenericTy(BaseTy.VL_MAXVL, is_vec=False), (LocSubKind.VL_MAXVL,))
637
638
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """Description of an Op kind, independent of maxvl.

    Holds the demo assembly text, the generic input/output operand
    descriptors, the allowed range of each immediate, and flags classifying
    the op. Inputs may not be tied. A tied output must refer to an existing
    input, and fixed output Locs must not overlap.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        for generic_inp in self.inputs:
            if generic_inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: "
                    f"{generic_inp}")
        self.outputs = tuple(outputs)
        input_count = len(self.inputs)
        seen_fixed = []  # type: list[tuple[int, Loc]]
        for out_idx, out in enumerate(self.outputs):
            tied = out.tied_input_index
            if tied is not None and tied >= input_count:
                raise ValueError(f"tied_input_index out of range: {out}")
            fixed = out.fixed_loc
            if fixed is None:
                continue
            for prev_idx, prev_fixed in seen_fixed:
                if prev_fixed.conflicts(fixed):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{out_idx}] and "
                        f"outputs[{prev_idx}]: {fixed} conflicts "
                        f"with {prev_fixed}")
            seen_fixed.append((out_idx, fixed))
        self.immediates = tuple(immediates)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
680
681
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """Operand layout of an OpKind instantiated for a specific `maxvl`.

    Each generic operand expands into one or more OperandDescs (spread
    operands expand into `maxvl` of them). The remaining properties are
    forwarded from the kind's GenericOpProperties.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind
        generic = self.generic
        self.inputs = tuple(
            desc
            for generic_inp in generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The maxvl-independent properties of `kind`."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Allowed range of each immediate."""
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects
729
730
# signed 16-bit immediate: [-2**15, 2**15)
IMM_S16 = range(-(1 << 15), 1 << 15)

# simulates one Op before register allocation, updating the state
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# zero-argument factory returning a simulation function; OpKind registers
# lambdas so its own private staticmethods are looked up only after the
# OpKind class exists
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# GenericOpProperties of each OpKind -> simulation factory, filled in by
# the OpKind class body
_PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
736
737
738 @unique
739 @final
740 class OpKind(Enum):
741 def __init__(self, properties):
742 # type: (GenericOpProperties) -> None
743 super().__init__()
744 self.__properties = properties
745
746 @property
747 def properties(self):
748 # type: () -> GenericOpProperties
749 return self.__properties
750
751 def instantiate(self, maxvl):
752 # type: (int) -> OpProperties
753 return OpProperties(self, maxvl=maxvl)
754
755 def __repr__(self):
756 return "OpKind." + self._name_
757
758 @cached_property
759 def pre_ra_sim(self):
760 # type: () -> _PRE_RA_SIM_FN
761 return _PRE_RA_SIMS[self.properties]()
762
763 @staticmethod
764 def __clearca_pre_ra_sim(op, state):
765 # type: (Op, PreRASimState) -> None
766 state.ssa_vals[op.outputs[0]] = False,
767 ClearCA = GenericOpProperties(
768 demo_asm="addic 0, 0, 0",
769 inputs=[],
770 outputs=[OD_CA],
771 )
772 _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
773
774 @staticmethod
775 def __setca_pre_ra_sim(op, state):
776 # type: (Op, PreRASimState) -> None
777 state.ssa_vals[op.outputs[0]] = True,
778 SetCA = GenericOpProperties(
779 demo_asm="subfc 0, 0, 0",
780 inputs=[],
781 outputs=[OD_CA],
782 )
783 _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
784
785 @staticmethod
786 def __svadde_pre_ra_sim(op, state):
787 # type: (Op, PreRASimState) -> None
788 RA = state.ssa_vals[op.inputs[0]]
789 RB = state.ssa_vals[op.inputs[1]]
790 carry, = state.ssa_vals[op.inputs[2]]
791 VL, = state.ssa_vals[op.inputs[3]]
792 RT = [] # type: list[int]
793 for i in range(VL):
794 v = RA[i] + RB[i] + carry
795 RT.append(v & GPR_VALUE_MASK)
796 carry = (v >> GPR_SIZE_IN_BITS) != 0
797 state.ssa_vals[op.outputs[0]] = tuple(RT)
798 state.ssa_vals[op.outputs[1]] = carry,
799 SvAddE = GenericOpProperties(
800 demo_asm="sv.adde *RT, *RA, *RB",
801 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
802 outputs=[OD_EXTRA3_VGPR, OD_CA],
803 )
804 _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
805
806 @staticmethod
807 def __svsubfe_pre_ra_sim(op, state):
808 # type: (Op, PreRASimState) -> None
809 RA = state.ssa_vals[op.inputs[0]]
810 RB = state.ssa_vals[op.inputs[1]]
811 carry, = state.ssa_vals[op.inputs[2]]
812 VL, = state.ssa_vals[op.inputs[3]]
813 RT = [] # type: list[int]
814 for i in range(VL):
815 v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
816 RT.append(v & GPR_VALUE_MASK)
817 carry = (v >> GPR_SIZE_IN_BITS) != 0
818 state.ssa_vals[op.outputs[0]] = tuple(RT)
819 state.ssa_vals[op.outputs[1]] = carry,
820 SvSubFE = GenericOpProperties(
821 demo_asm="sv.subfe *RT, *RA, *RB",
822 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
823 outputs=[OD_EXTRA3_VGPR, OD_CA],
824 )
825 _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
826
827 @staticmethod
828 def __svmaddedu_pre_ra_sim(op, state):
829 # type: (Op, PreRASimState) -> None
830 RA = state.ssa_vals[op.inputs[0]]
831 RB, = state.ssa_vals[op.inputs[1]]
832 carry, = state.ssa_vals[op.inputs[2]]
833 VL, = state.ssa_vals[op.inputs[3]]
834 RT = [] # type: list[int]
835 for i in range(VL):
836 v = RA[i] * RB + carry
837 RT.append(v & GPR_VALUE_MASK)
838 carry = v >> GPR_SIZE_IN_BITS
839 state.ssa_vals[op.outputs[0]] = tuple(RT)
840 state.ssa_vals[op.outputs[1]] = carry,
841 SvMAddEDU = GenericOpProperties(
842 demo_asm="sv.maddedu *RT, *RA, RB, RC",
843 inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
844 outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
845 )
846 _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
847
848 @staticmethod
849 def __setvli_pre_ra_sim(op, state):
850 # type: (Op, PreRASimState) -> None
851 state.ssa_vals[op.outputs[0]] = op.immediates[0],
852 SetVLI = GenericOpProperties(
853 demo_asm="setvl 0, 0, imm, 0, 1, 1",
854 inputs=(),
855 outputs=[OD_VL],
856 immediates=[range(1, 65)],
857 is_load_immediate=True,
858 )
859 _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
860
861 @staticmethod
862 def __svli_pre_ra_sim(op, state):
863 # type: (Op, PreRASimState) -> None
864 VL, = state.ssa_vals[op.inputs[0]]
865 imm = op.immediates[0] & GPR_VALUE_MASK
866 state.ssa_vals[op.outputs[0]] = (imm,) * VL
867 SvLI = GenericOpProperties(
868 demo_asm="sv.addi *RT, 0, imm",
869 inputs=[OD_VL],
870 outputs=[OD_EXTRA3_VGPR],
871 immediates=[IMM_S16],
872 is_load_immediate=True,
873 )
874 _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
875
876 @staticmethod
877 def __li_pre_ra_sim(op, state):
878 # type: (Op, PreRASimState) -> None
879 imm = op.immediates[0] & GPR_VALUE_MASK
880 state.ssa_vals[op.outputs[0]] = imm,
881 LI = GenericOpProperties(
882 demo_asm="addi RT, 0, imm",
883 inputs=(),
884 outputs=[OD_BASE_SGPR],
885 immediates=[IMM_S16],
886 is_load_immediate=True,
887 )
888 _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
889
890 @staticmethod
891 def __veccopytoreg_pre_ra_sim(op, state):
892 # type: (Op, PreRASimState) -> None
893 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
894 VecCopyToReg = GenericOpProperties(
895 demo_asm="sv.mv dest, src",
896 inputs=[GenericOperandDesc(
897 ty=GenericTy(BaseTy.I64, is_vec=True),
898 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
899 ), OD_VL],
900 outputs=[OD_EXTRA3_VGPR],
901 is_copy=True,
902 )
903 _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
904
905 @staticmethod
906 def __veccopyfromreg_pre_ra_sim(op, state):
907 # type: (Op, PreRASimState) -> None
908 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
909 VecCopyFromReg = GenericOpProperties(
910 demo_asm="sv.mv dest, src",
911 inputs=[OD_EXTRA3_VGPR, OD_VL],
912 outputs=[GenericOperandDesc(
913 ty=GenericTy(BaseTy.I64, is_vec=True),
914 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
915 )],
916 is_copy=True,
917 )
918 _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
919
920 @staticmethod
921 def __copytoreg_pre_ra_sim(op, state):
922 # type: (Op, PreRASimState) -> None
923 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
924 CopyToReg = GenericOpProperties(
925 demo_asm="mv dest, src",
926 inputs=[GenericOperandDesc(
927 ty=GenericTy(BaseTy.I64, is_vec=False),
928 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
929 LocSubKind.StackI64],
930 )],
931 outputs=[GenericOperandDesc(
932 ty=GenericTy(BaseTy.I64, is_vec=False),
933 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
934 )],
935 is_copy=True,
936 )
937 _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
938
939 @staticmethod
940 def __copyfromreg_pre_ra_sim(op, state):
941 # type: (Op, PreRASimState) -> None
942 state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
943 CopyFromReg = GenericOpProperties(
944 demo_asm="mv dest, src",
945 inputs=[GenericOperandDesc(
946 ty=GenericTy(BaseTy.I64, is_vec=False),
947 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
948 )],
949 outputs=[GenericOperandDesc(
950 ty=GenericTy(BaseTy.I64, is_vec=False),
951 sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
952 LocSubKind.StackI64],
953 )],
954 is_copy=True,
955 )
956 _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
957
958 @staticmethod
959 def __concat_pre_ra_sim(op, state):
960 # type: (Op, PreRASimState) -> None
961 state.ssa_vals[op.outputs[0]] = tuple(
962 state.ssa_vals[i][0] for i in op.inputs[:-1])
963 Concat = GenericOpProperties(
964 demo_asm="sv.mv dest, src",
965 inputs=[GenericOperandDesc(
966 ty=GenericTy(BaseTy.I64, is_vec=False),
967 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
968 spread=True,
969 ), OD_VL],
970 outputs=[OD_EXTRA3_VGPR],
971 is_copy=True,
972 )
973 _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
974
975 @staticmethod
976 def __spread_pre_ra_sim(op, state):
977 # type: (Op, PreRASimState) -> None
978 for idx, inp in enumerate(state.ssa_vals[op.inputs[0]]):
979 state.ssa_vals[op.outputs[idx]] = inp,
980 Spread = GenericOpProperties(
981 demo_asm="sv.mv dest, src",
982 inputs=[OD_EXTRA3_VGPR, OD_VL],
983 outputs=[GenericOperandDesc(
984 ty=GenericTy(BaseTy.I64, is_vec=False),
985 sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
986 spread=True,
987 )],
988 is_copy=True,
989 )
990 _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
991
992 @staticmethod
993 def __svld_pre_ra_sim(op, state):
994 # type: (Op, PreRASimState) -> None
995 RA, = state.ssa_vals[op.inputs[0]]
996 VL, = state.ssa_vals[op.inputs[1]]
997 addr = RA + op.immediates[0]
998 RT = [] # type: list[int]
999 for i in range(VL):
1000 v = state.load(addr + GPR_SIZE_IN_BYTES * i)
1001 RT.append(v & GPR_VALUE_MASK)
1002 state.ssa_vals[op.outputs[0]] = tuple(RT)
1003 SvLd = GenericOpProperties(
1004 demo_asm="sv.ld *RT, imm(RA)",
1005 inputs=[OD_EXTRA3_SGPR, OD_VL],
1006 outputs=[OD_EXTRA3_VGPR],
1007 immediates=[IMM_S16],
1008 )
1009 _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
1010
1011 @staticmethod
1012 def __ld_pre_ra_sim(op, state):
1013 # type: (Op, PreRASimState) -> None
1014 RA, = state.ssa_vals[op.inputs[0]]
1015 addr = RA + op.immediates[0]
1016 v = state.load(addr)
1017 state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
1018 Ld = GenericOpProperties(
1019 demo_asm="ld RT, imm(RA)",
1020 inputs=[OD_BASE_SGPR],
1021 outputs=[OD_BASE_SGPR],
1022 immediates=[IMM_S16],
1023 )
1024 _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
1025
1026 @staticmethod
1027 def __svstd_pre_ra_sim(op, state):
1028 # type: (Op, PreRASimState) -> None
1029 RS = state.ssa_vals[op.inputs[0]]
1030 RA, = state.ssa_vals[op.inputs[1]]
1031 VL, = state.ssa_vals[op.inputs[2]]
1032 addr = RA + op.immediates[0]
1033 for i in range(VL):
1034 state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
1035 SvStd = GenericOpProperties(
1036 demo_asm="sv.std *RS, imm(RA)",
1037 inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
1038 outputs=[],
1039 immediates=[IMM_S16],
1040 has_side_effects=True,
1041 )
1042 _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
1043
1044 @staticmethod
1045 def __std_pre_ra_sim(op, state):
1046 # type: (Op, PreRASimState) -> None
1047 RS, = state.ssa_vals[op.inputs[0]]
1048 RA, = state.ssa_vals[op.inputs[1]]
1049 addr = RA + op.immediates[0]
1050 state.store(addr, value=RS)
    # Scalar store: RS and base address RA in, no SSA outputs; flagged
    # has_side_effects since its only effect is the memory write.
    # Simulated by __std_pre_ra_sim, which stores RS to RA + imm.
    # NOTE(review): demo_asm names the stored register "RT", while SvStd uses
    # "RS" for the same role -- looks like a typo; confirm nothing (reprs,
    # tests) depends on this exact text before changing it.
    Std = GenericOpProperties(
        demo_asm="std RT, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    # lazy lookup: OpKind isn't bound until this class body finishes
    _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
1059
1060 @staticmethod
1061 def __funcargr3_pre_ra_sim(op, state):
1062 # type: (Op, PreRASimState) -> None
1063 pass # return value set before simulation
    # Pseudo-op with no assembly text and no inputs, whose single scalar
    # output is pinned to GPR 3 (presumably the incoming function argument
    # register -- confirm). Its output value is provided before simulation
    # starts; Op.pre_ra_sim tolerates that output being pre-assigned.
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    # lazy lookup: OpKind isn't bound until this class body finishes
    _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
1071
1072
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """A single SSA value: output number ``output_idx`` of ``op``.

    Type and location information come from the defining op's output
    descriptor (``op.properties.outputs[output_idx]``).
    """
    __slots__ = "op", "output_idx"

    def __init__(self, op, output_idx):
        # type: (Op, int) -> None
        self.op = op
        # output_idx must name one of op's outputs
        if not 0 <= output_idx < len(op.properties.outputs):
            raise ValueError("invalid output_idx")
        self.output_idx = output_idx

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}#{self.output_idx}: {self.ty}>"

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """The output ``OperandDesc`` of ``op`` that defines this value."""
        outputs = self.op.properties.outputs
        return outputs[self.output_idx]

    @cached_property
    def loc_set_before_spread(self):
        # type: () -> LocSet
        """The defining descriptor's ``loc_set_before_spread``."""
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @cached_property
    def ty(self):
        # type: () -> Ty
        """The defining descriptor's ``ty``."""
        desc = self.defining_descriptor
        return desc.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """The defining descriptor's ``ty_before_spread``."""
        desc = self.defining_descriptor
        return desc.ty_before_spread
1108
1109
# type variables used by OpInputSeq: _T is the element type and _Desc the
# type of the per-slot descriptor each element is validated against
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")
1112
1113
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """Fixed-length sequence of values attached to an `Op`, one per slot.

    Subclasses supply the per-slot descriptors via ``_get_descriptors`` and
    validate values via ``_verify_write_with_desc``. Validation runs both at
    construction and on item assignment. The number of items always equals
    the number of descriptors.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """Raise if ``item`` may not be stored in slot ``idx`` (``desc``)."""
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """Validate writing ``item`` at ``idx``; return the normalized index.

        Raises TypeError for slice or other non-int indexes, and IndexError
        for out-of-range indexes.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range normalizes negative indexes and raises
        # IndexError when idx is out of range
        normalized = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            normalized, item, self.descriptors[normalized])
        return normalized

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """Return the per-slot descriptors (cached by ``descriptors``)."""
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """Per-slot descriptors, computed once by ``_get_descriptors``."""
        return self._get_descriptors()

    @property
    @final
    def op(self):
        """The `Op` this sequence belongs to."""
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, value in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, value)
            self.__items.append(value)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        # slicing yields a plain list, not another OpInputSeq
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        self.__items[self._verify_write(idx, item)] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        return f"{type(self).__name__}({self.__items}, op=...)"
1195
1196
@final
class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
    """The `SSAVal` inputs of an `Op`, type-checked per input slot."""

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        # one OperandDesc per input slot of the op's properties
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        actual_ty = item.ty
        expected_ty = desc.ty
        if actual_ty != expected_ty:
            raise ValueError(f"assigned item's type {actual_ty!r} doesn't match "
                             f"corresponding input's type {expected_ty!r}")

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # an Op gets exactly one OpInputs
        if hasattr(op, "inputs"):
            raise ValueError("Op.inputs already set")
        super().__init__(items, op)
1216
1217
@final
class OpImmediates(OpInputSeq[int, range]):
    """The integer immediates of an `Op`; each must lie in its slot's range."""

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        # one range of allowed values per immediate slot
        return self.op.properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # an Op gets exactly one OpImmediates
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)
1236
1237
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """One operation in a `Fn`.

    Holds its ``properties``, validated ``inputs`` and ``immediates``, one
    `SSAVal` per output, and a name chosen by ``fn`` so it doesn't clash
    with ``fn``'s other ops. Ops compare and hash by identity.
    """
    __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"

    def __init__(self, fn, properties, inputs, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        # the sequences and SSAVals below read self.properties, so it must
        # be assigned first
        self.inputs = OpInputs(inputs, op=self)
        self.immediates = OpImmediates(immediates, op=self)
        self.outputs = tuple(
            SSAVal(self, idx)
            for idx in range(len(self.properties.outputs)))
        # naming comes last: Fn refuses ops that already have a name
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        """Shorthand for ``self.properties.kind``."""
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # identity-based, consistent with __eq__
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        parts = []  # type: list[str]
        for field_name in fields(self):
            if field_name == "fn":
                # the owning Fn is left out of the repr
                continue
            if field_name == "properties":
                # shown through the `kind` property instead
                field_name = "kind"
            try:
                value = getattr(self, field_name)
            except AttributeError:
                # e.g. a partially-constructed Op
                parts.append(f"{field_name}=<not set>")
                continue
            if isinstance(value, OpInputSeq):
                value = list(value)  # type: ignore
            parts.append(f"{field_name}={value!r}")
        return "Op(" + ", ".join(parts) + ")"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Run this op's simulator on ``state``, checking values around it.

        Before: every input must be assigned with ``ty.reg_len`` elements,
        and no output may be assigned yet (except for FuncArgR3, whose output
        is set before simulation). After: every output must be assigned with
        ``ty.reg_len`` elements.

        Raises ValueError if any of those checks fail.
        """
        def check_len(ssa_val):
            # type: (SSAVal) -> None
            found = state.ssa_vals[ssa_val]
            if len(found) != ssa_val.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {ssa_val.ty.reg_len} found "
                    f"{len(found)}: {found!r}")

        for inp in self.inputs:
            if inp not in state.ssa_vals:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            check_len(inp)
        for out in self.outputs:
            if out in state.ssa_vals and self.kind is not OpKind.FuncArgR3:
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")
        self.kind.pre_ra_sim(self, state)
        for out in self.outputs:
            if out not in state.ssa_vals:
                raise ValueError(f"running {self} failed to assign to {out}")
            check_len(out)
1311
1312
# size of a general-purpose register
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask of GPR width, used to truncate values/addresses to a GPR
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
1317
1318
@plain_data(frozen=True, repr=False)
@final
class PreRASimState:
    """Simulator state: SSA values plus a sparse byte-addressed memory.

    ``ssa_vals`` maps each assigned `SSAVal` to a tuple of its element
    values. ``memory`` maps byte addresses (wrapped to GPR width) to byte
    values; addresses missing from it read as 0.
    """
    __slots__ = "ssa_vals", "memory"

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        self.ssa_vals = ssa_vals
        self.memory = memory

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at ``addr`` (wrapped to GPR width); unset is 0."""
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low 8 bits of ``value`` at ``addr`` (wrapped)."""
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Load a little-endian ``size_in_bytes``-byte integer at ``addr``.

        The result is zero-extended, or sign-extended when ``signed``.

        Raises ValueError if ``addr`` isn't a multiple of ``size_in_bytes``.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        for i in range(size_in_bytes):
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        # sign bit set -> convert to the negative two's complement value
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Store the low ``size_in_bytes`` bytes of ``value`` little-endian.

        Raises ValueError if ``addr`` isn't a multiple of ``size_in_bytes``.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """Format ``memory`` in ascending address order.

        Each run of ``GPR_SIZE_IN_BYTES`` consecutive set bytes starting at
        an address aligned to that size is shown as one loaded word in
        angle brackets; all other bytes are shown individually.
        """
        if len(self.memory) == 0:
            return "{}"
        # sorted descending so the lowest remaining address is keys[-1]
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            # the CHUNK_SIZE lowest remaining keys must be exactly
            # addr .. addr + CHUNK_SIZE - 1 (stored in descending order)
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def _ssa_vals__repr(self):
        # type: () -> str
        """Format ``ssa_vals`` as a dict mapping SSAVals to hex tuples.

        Values with more than ``CHUNK_SIZE`` elements are wrapped onto new
        lines of ``CHUNK_SIZE`` elements each; one-element values keep the
        tuple trailing comma, e.g. ``(0x1,)``; empty values print as ``()``.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # start a new line every CHUNK_SIZE elements
                    element_strs.append("\n    " + hex(el))
            # short values fit on one line; an empty value has no first
            # element to strip (indexing it would raise IndexError)
            if element_strs and len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # produces a trailing comma, like a Python 1-tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __repr__(self):
        # type: () -> str
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            # fields with a `_<name>__repr` method get custom formatting
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"PreRASimState({field_vals_str})"