remove unused code
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 import dataclasses
4 from abc import ABCMeta, abstractmethod
5 from enum import Enum, unique
6 from functools import lru_cache, total_ordering
7 from io import StringIO
8 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
9 Mapping, Sequence, TypeVar, Union, overload)
10 from weakref import WeakValueDictionary as _WeakVDict
11
12 from cached_property import cached_property
13 from nmutil import plain_data # type: ignore
14
15 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
16 final)
17 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
18 OFSet, OSet, bit_count)
19
# width of a general-purpose register, in bytes and in bits
BITS_IN_BYTE = 8
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# mask with the low GPR_SIZE_IN_BITS bits set, i.e. every value a GPR holds
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
24
25
@final
class Fn:
    """ a function: an ordered list of `Op`s (`self.ops`) plus the
    bookkeeping used to give each Op a unique name.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak values: a name stops being reserved once its Op is collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix appended to names that are empty or already taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ reserve a name for `op` that no other live Op in `self` uses.

        `name` is used unchanged if it is non-empty and free; otherwise a
        numeric suffix (from a counter shared by all names in `self`) is
        appended to it until a free name is found.

        Returns the reserved name; it does not assign it to `op`.
        Raises ValueError if `op` belongs to a different Fn or already has a
        `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name != "" and name not in self.__op_names:
                self.__op_names[name] = op
                return name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ return the repr of every Op in `self.ops`, one per line.

        If `as_python_literal` is set, the text is instead rendered as a
        sequence of adjacent Python string literals, one per output line,
        each prefixed with `python_indent`.
        """
        l = []  # type: list[str]
        for op in self.ops:
            l.append(op.__repr__(wrap_width=wrap_width, indent=indent))
        retval = "\n".join(l)
        if as_python_literal:
            l = [python_indent + "\""]
            for ch in retval:
                if ch == "\n":
                    # end the current literal and start a new one
                    l.append(f"\\n\"\n{python_indent}\"")
                elif ch in "\"\\":
                    l.append("\\" + ch)
                elif ch.isascii() and ch.isprintable():
                    l.append(ch)
                else:
                    # use repr's escape sequence, minus the surrounding quotes
                    l.append(repr(ch).strip("\"'"))
            l.append("\"")
            retval = "".join(l)
            # if the text ended in a newline, the output ends with a useless
            # empty literal: `"\n<python_indent>""`.
            empty_end = f"\"\n{python_indent}\"\""
            if retval.endswith(empty_end):
                # drop the empty literal but keep the leading `"` of
                # `empty_end`, which terminates the preceding literal.
                retval = retval[:-len(empty_end) + 1]
        return retval

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to `self.ops`; raises ValueError if `op.fn` isn't
        `self`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new Op of `kind` (instantiated for `maxvl`), append it
        to `self.ops`, and return it.
        """
        retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(retval)
        return retval

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate every Op in `self.ops`, in order, on `state` """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for every Op in `self.ops`, in order, into
        `state`
        """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops`, inserting copy Ops around every Op.

        Before each Op, each I64 input is copied with `CopyToReg`, or with
        `SetVLI` + `VecCopyToReg` when its `reg_len != 1`. The Op is rewired
        to read the copy. CA / VL_MAXVL inputs are not copied. If such an
        input was produced by a `SetVLI` Op, a new `SetVLI` with the same
        immediates is emitted right before the Op, and the Op reads that
        instead.

        After each Op, each I64 output is copied (`CopyFromReg` or
        `SetVLI` + `VecCopyFromReg`). Later Ops read that copy instead of
        the original output.

        Every input must be produced by an earlier Op in `self.ops`;
        otherwise the `copied_outputs` lookup raises KeyError.
        """
        orig_ops = list(self.ops)
        # original output -> the value that later Ops should read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # outputs of SetVLI Ops -> the SetVLI Op that produced them
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the op that reads their VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
163
164
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the two stages of an Op's execution, used to decide which operands
    may share registers. Stages are ordered: `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        int_value = int(value)
        if not 0 <= int_value <= 1:
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = int_value
        return member

    # early stage of Op execution, where all input reads occur.
    # all output writes with `write_stage == Early` occur here too, and
    # therefore conflict with input reads, telling the compiler that it can't
    # share that output's register with any inputs that the output isn't
    # tied to.
    #
    # All outputs, even unused outputs, can't share registers with any other
    # outputs, independent of `write_stage` settings.
    Early = 0

    # late stage of Op execution, where all output writes with
    # `write_stage == Late` occur, and therefore don't conflict with input
    # reads, telling the compiler that any inputs can safely use the same
    # register as those outputs.
    #
    # All outputs, even unused outputs, can't share registers with any other
    # outputs, independent of `write_stage` settings.
    Late = 1

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
211
212
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramPoint(Interned):
    """ a point in program order: stage `stage` of the Op at index
    `op_index` in `Fn.ops`.

    ProgramPoints are ordered by `op_index`, then by `stage`, which is
    exactly the ordering of their `int_value`s.
    """
    op_index: int
    stage: OpStage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering and
        successor/predecessor relations.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `ProgramPoint.int_value` """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the ProgramPoint `steps` positions later """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the ProgramPoint `steps` positions earlier """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return self.int_value < other.int_value
        return NotImplemented

    def __gt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return self.int_value > other.int_value
        return NotImplemented

    def __le__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return self.int_value <= other.int_value
        return NotImplemented

    def __ge__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return self.int_value >= other.int_value
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return "<ops[{}]:{}>".format(self.op_index, self.stage._name_)
270
271
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], Interned):
    """ the half-open range `[start, stop)` of `ProgramPoint`s, in
    `ProgramPoint.int_value` order. Empty when `stop <= start`.
    """
    start: ProgramPoint
    stop: ProgramPoint

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `ProgramPoint.int_value`s covered by `self` """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; raises ValueError unless
        `int_value_range.step == 1`.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        return ProgramRange(
            start=ProgramPoint.from_int_value(int_value_range.start),
            stop=ProgramPoint.from_int_value(int_value_range.stop))

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # slicing a range yields a range (and ValueError below if the
        # slice has a step other than 1)
        v = self.int_value_range[__idx]
        if isinstance(v, int):
            return ProgramPoint.from_int_value(v)
        return ProgramRange.from_int_value_range(v)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __contains__(self, value):
        # type: (object) -> bool
        """ O(1) membership test. The inherited `Sequence.__contains__`
        would iterate over every point in the range.
        """
        if not isinstance(value, ProgramPoint):
            return False
        return value.int_value in self.int_value_range

    def __repr__(self):
        # type: () -> str
        start = repr(self.start).lstrip("<").rstrip(">")
        stop = repr(self.stop).lstrip("<").rstrip(">")
        return f"<range:{start}..{stop}>"
322
323
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(Interned):
    """ the single register at index `reg_idx` within `ssa_val`, which may
    span several registers (`ssa_val.ty.reg_len`).
    """
    ssa_val: "SSAVal"
    reg_idx: int

    def __post_init__(self):
        # chained comparison short-circuits, so `ssa_val.ty` is only read
        # when reg_idx is non-negative
        if not 0 <= self.reg_idx < self.ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")

    def __repr__(self):
        # type: () -> str
        return "{}[{}]".format(self.ssa_val, self.reg_idx)
337
338
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ analysis results for a `Fn`: Op indexes, SSA uses, definition and
    use program points, live ranges, copy relationships and known
    constant values.

    Equality and hashing are based only on `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        # position of each Op within `fn.ops`
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        # from the Early stage of ops[0] up to (exclusive) the Early stage of
        # the index just past the last Op, i.e. through its Late stage
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range, extended by each use
        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            for use in op.input_uses:
                # raises KeyError if the used SSAVal isn't an output of an
                # earlier Op in `fn.ops`
                uses[use.ssa_val].add(use)
                use_program_point = self.__get_use_program_point(use)
                use_program_points[use] = use_program_point
                # keep the value live through this use (stop is exclusive)
                live_range_stops[use.ssa_val] = max(
                    live_range_stops[use.ssa_val], use_program_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_program_range = self.__get_def_program_range(out)
                def_program_ranges[out] = def_program_range
                live_range_stops[out] = def_program_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_program_ranges)
        self.use_program_points = FMap(use_program_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses.keys():
            # live from its definition until the later of the end of its
            # defining Op and its last use
            live_ranges[ssa_val] = live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_range_stops[ssa_val])
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        # eagerly compute these cached properties.
        # NOTE(review): these cached_property values aren't listed in
        # `__slots__`; confirm instances have somewhere to store them.
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the ProgramRange during which `ssa_val` is being written: from
        its descriptor's `write_stage` in the defining Op through that Op's
        Late stage.
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        start = ProgramPoint(
            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ inputs are read at the Early stage of the using Op """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if isinstance(other, FnAnalysis):
            return self.fn == other.fn
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        # relies on `op_indexes` iterating Ops in program order, so a copy's
        # source is already resolved when a later copy reads it
        for op in self.op_indexes.keys():
            if not op.properties.is_copy:
                continue
            copy_reg_len = op.properties.copy_reg_len
            copy_inputs = []  # type: list[SSAValSubReg]
            for inp in op.input_vals[:op.properties.copy_inputs_len]:
                for inp_sub_reg in inp.ssa_val_sub_regs:
                    # propagate copies of copies
                    inp_sub_reg = retval.get(inp_sub_reg, inp_sub_reg)
                    copy_inputs.append(inp_sub_reg)
            assert len(copy_inputs) == copy_reg_len, "logic error"
            copy_outputs = []  # type: list[SSAValSubReg]
            for out in op.outputs[:op.properties.copy_outputs_len]:
                copy_outputs.extend(out.ssa_val_sub_regs)
            assert len(copy_outputs) == copy_reg_len, "logic error"
            for inp, out in zip(copy_inputs, copy_outputs):
                retval[out] = inp
        return FMap(retval)

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
        """ map from SSAVals to the full set of SSAVals that are related by
        being sources/destinations of copies, transitively looking through all
        copies.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        # every SSAVal starts in its own set; sets are merged in place and
        # shared by all their members
        sets_map = {i: OSet([i]) for i in self.uses.keys()}
        for k, v in self.copies.items():
            k_set = sets_map[k.ssa_val]
            v_set = sets_map[v.ssa_val]
            # merge k_set and v_set
            if k_set is v_set:
                continue
            k_set |= v_set
            for i in k_set:
                sets_map[i] = k_set
        # this way we construct each OFSet only once rather than
        # for each SSAVal
        sets_set = {id(i): i for i in sets_map.values()}
        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
        for v in sets_set.values():
            v = OFSet(v)
            for k in v:
                retval[k] = v
        return FMap(retval)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """ per-register values of SSAVals, as collected by simulating `fn`
        with a `ConstPropagationState` (presumably only the SSAVals whose
        value is known at compile time -- TODO confirm).
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """ `const_ssa_vals` split into one entry per SSAValSubReg """
        retval = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, v in enumerate(const_val):
                retval[SSAValSubReg(ssa_val, reg_idx)] = v
        return FMap(retval)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value; a KeyError means at
        # least one of them isn't a known constant
        try:
            a_const_val = self.const_ssa_val_sub_regs[a]
            b_const_val = self.const_ssa_val_sub_regs[b]
            if a_const_val == b_const_val:
                return True
        except KeyError:
            pass
        return False
520
521
522 @unique
523 @final
524 class BaseTy(Enum):
525 I64 = enum.auto()
526 CA = enum.auto()
527 VL_MAXVL = enum.auto()
528
529 @cached_property
530 def only_scalar(self):
531 # type: () -> bool
532 if self is BaseTy.I64:
533 return False
534 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
535 return True
536 else:
537 assert_never(self)
538
539 @cached_property
540 def max_reg_len(self):
541 # type: () -> int
542 if self is BaseTy.I64:
543 return 128
544 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
545 return 1
546 else:
547 assert_never(self)
548
549 def __repr__(self):
550 return "BaseTy." + self._name_
551
552
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(Interned):
    """ a concrete type: a `BaseTy` together with the number of registers
    (`reg_len`) a value of this type occupies
    """
    base_ty: BaseTy
    reg_len: int

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ describe why `(base_ty, reg_len)` is not a valid `Ty`, or return
        None if it is valid
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __post_init__(self):
        error = Ty.validate(base_ty=self.base_ty, reg_len=self.reg_len)
        if error is not None:
            raise ValueError(error)

    def __repr__(self):
        # type: () -> str
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
583
584
@unique
@final
class LocKind(Enum):
    """ a kind of storage location that values can be allocated to """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of the values stored in locations of this kind """
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ how many locations of this kind exist """
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
618
619
@final
@unique
class LocSubKind(Enum):
    """ a subset of the locations of a `LocKind`; see `allocatable_locs`
    for which locations each sub-kind allows
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` that this sub-kind's locations belong to """
        gpr_sub_kinds = (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
                         LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
                         LocSubKind.SV_EXTRA3_SGPR)
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self in gpr_sub_kinds:
            return LocKind.GPR
        elif self is LocSubKind.StackI64:
            return LocKind.StackI64
        elif self is LocSubKind.CA:
            return LocKind.CA
        elif self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        else:
            assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ every Loc of type `ty` within this sub-kind that doesn't overlap
        any of `SPECIAL_GPRS`.

        Raises ValueError if `ty.base_ty` doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # only one location exists for these kinds
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        # starts whose Loc would run past the end of the kind are skipped
        candidates = (
            Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            for start in starts)
        return LocSet(
            loc for loc in candidates
            if loc is not None
            and not any(loc.conflicts(special) for special in SPECIAL_GPRS))

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
689
690
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericTy(Interned):
    """ a type that is either scalar or a vector whose length is chosen when
    it is instantiated with a specific `maxvl`
    """
    base_ty: BaseTy
    is_vec: bool

    def __post_init__(self):
        if self.is_vec and self.base_ty.only_scalar:
            raise ValueError(f"base_ty={self.base_ty} requires is_vec=False")

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty`: `maxvl` registers for vectors, else 1 """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` makes `self.instantiate(maxvl)` have the
        same base type as `ty`, with scalars requiring `ty.reg_len == 1`
        """
        return self.base_ty == ty.base_ty and (self.is_vec or ty.reg_len == 1)
715
716
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class Loc(Interned):
    """ a contiguous run of `reg_len` locations of kind `kind`, starting at
    location index `start`
    """
    kind: LocKind
    start: int
    reg_len: int

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ return the requested Loc, or None if it would be invalid """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __post_init__(self):
        msg = self.validate(kind=self.kind, start=self.start,
                            reg_len=self.reg_len)
        if msg is not None:
            raise ValueError(msg)

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` share at least one location """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @cached_property
    def stop(self):
        # type: () -> int
        """ exclusive end index of the locations covered by `self` """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ concatenate `self` with `others` into one Loc.

        Returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the previous Loc stopped. Also returns
        None if the combined Loc would be invalid.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the combined reg_len can exceed the base type's max_reg_len even
        # though each piece is valid (e.g. StackI64 has 512 locations but I64
        # values are limited to 128 registers); `try_make` reports that as
        # None instead of raising ValueError.
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the Loc of type `subloc_ty` starting `offset` locations into
        `self`; raises ValueError if it doesn't fit within `self`
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)

    def get_superloc_with_self_at_offset(self, superloc_ty, offset):
        # type: (Ty, int) -> Loc
        """get the Loc containing `self` such that:
        `retval.get_subloc_at_offset(self.ty, offset) == self`
        and `retval.ty == superloc_ty`

        raises ValueError if no such Loc exists
        """
        if superloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + self.reg_len > superloc_ty.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "superloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start - offset, reg_len=superloc_ty.reg_len)
806
807
# single GPRs (r0, r1, r2, r13) that `LocSubKind.allocatable_locs` never
# hands out; presumably reserved by the ABI (stack pointer, TOC, thread
# pointer, ...) -- TODO confirm
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=gpr_index, reg_len=1)
    for gpr_index in (0, 1, 2, 13))
814
815
@final
class LocSet(OFSet[Loc], Interned):
    """ an immutable, ordered set of `Loc`s that all have the same `Ty`
    (`ty` is None only when the set is empty).

    Also keeps, per `LocKind`, a bitset of the contained Locs' `start`s so
    that conflict queries can be answered with bit operations.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse its precomputed data
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per-`LocKind` bitsets of the contained Locs' `start`s; kinds
        with no Locs are omitted
        """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per-`LocKind` bitsets of the contained Locs' `stop`s
        (`start + reg_len`)
        """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        """ the LocKinds that have at least one Loc in `self` """
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with.

        Returns 0 if `self` is empty, or if `other` is an empty LocSet.
        """
        if isinstance(other, LocSet):
            # `default=0`: an empty `other` has no Loc that could conflict,
            # and a bare `max()` would raise ValueError on it
            return max((self.max_conflicts_with(i) for i in other),
                       default=0)
        # now we do the equivalent of:
        # return sum(other.conflicts(i) for i in self)
        reg_len = self.reg_len
        if reg_len is None:
            return 0
        starts = self.starts.get(other.kind)
        if starts is None:
            return 0
        # now we do the equivalent of:
        # return sum(other.start < start + reg_len
        #            and start < other.start + other.reg_len
        #            for start in starts)
        # which, in terms of each Loc's stop (`start + reg_len`), is:
        #   other.start < stop < other.start + other.reg_len + reg_len
        stops = starts.bits << reg_len

        # bit indexes `i` where `i < other.start + 1`
        lt_other_start_plus_1 = ~(~0 << (other.start + 1))

        # bit indexes `i` where `i < other.start + other.reg_len + reg_len`
        lt_other_start_plus_other_reg_len_plus_reg_len = (
            ~(~0 << (other.start + other.reg_len + reg_len)))
        included = (stops & ~lt_other_start_plus_1
                    & lt_other_start_plus_other_reg_len_plus_reg_len)
        return bit_count(included)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"

    @cached_property
    def only_loc(self):
        # type: () -> Loc | None
        """if len(self) == 1 then return the Loc in self, otherwise None"""
        only_loc = None
        for i in self:
            if only_loc is None:
                only_loc = i
            else:
                return None  # len(self) > 1
        return only_loc
924
925
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(Interned):
    """generic Op operand descriptor

    describes an operand independently of `maxvl`; `instantiate` produces
    the concrete `OperandDesc`s for a specific `maxvl`.
    """
    ty: GenericTy
    sub_kinds: OFSet[LocSubKind]
    fixed_loc: "Loc | None" = None
    tied_input_index: "int | None" = None
    spread: bool = False
    write_stage: OpStage = OpStage.Early

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ raises ValueError for any invalid combination of arguments """
        object.__setattr__(self, "ty", ty)
        object.__setattr__(self, "sub_kinds", OFSet(sub_kinds))
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        # validate sub_kinds before the fixed_loc checks: otherwise an
        # incompatible sub_kind would reach `allocatable_locs` below and fail
        # with its less specific "type mismatch" error
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        object.__setattr__(self, "fixed_loc", fixed_loc)
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        object.__setattr__(self, "tied_input_index", tied_input_index)
        object.__setattr__(self, "spread", spread)
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        object.__setattr__(self, "write_stage", write_stage)

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the type of the whole operand before it is spread into scalar
        parts: a vector of `ty.base_ty` if spread, otherwise `ty`
        """
        if self.spread:
            return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
        return self.ty

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a copy of `self` tied to input `tied_input_index`; `fixed_loc`
        and `spread` are dropped since they can't be combined with tying
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a copy of `self` fixed to `fixed_loc`; `tied_input_index` and
        `spread` are dropped since they can't be combined with a fixed Loc
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with `write_stage` replaced """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ yield the concrete `OperandDesc`s for `maxvl`: one per element
        (`maxvl` of them) for a spread operand, otherwise exactly one.

        This is a generator, so errors (e.g. a `fixed_loc` whose type
        doesn't match the instantiated type) are raised when it is iterated.
        """
        # assumes all spread operands have ty.reg_len = 1
        rep_count = maxvl if self.spread else 1
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is not None:
                if ty_before_spread != self.fixed_loc.ty:
                    raise ValueError(
                        f"instantiation failed: type mismatch with fixed_loc: "
                        f"instantiated type: {ty_before_spread} "
                        f"fixed_loc: {self.fixed_loc}")
                yield self.fixed_loc
                return
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(ty_before_spread)
        loc_set_before_spread = LocSet(locs_before_spread())
        for idx in range(rep_count):
            spread_index = idx if self.spread else None
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
1037
1038
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class OperandDesc(Interned):
    """Op operand descriptor, instantiated for a specific `maxvl`.

    Fields:
    * `loc_set_before_spread`: the (non-empty) `LocSet` of allowed locations
      for the operand before any spread is applied.
    * `tied_input_index`: index of the input this operand is tied to, or
      `None`; must be non-negative when set.
    * `spread_index`: index of this operand within its spread group, or
      `None` if the operand isn't spread; must be non-negative when set.
    * `write_stage`: the operand's `OpStage` (inputs always use
      `OpStage.Early`).

    Raises ValueError from `__post_init__` if any of the above constraints
    are violated, or if the operand is both spread and tied.
    """
    loc_set_before_spread: LocSet
    tied_input_index: "int | None"
    spread_index: "int | None"
    write_stage: "OpStage"

    def __post_init__(self):
        # type: () -> None
        if len(self.loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # mirror GenericOperandDesc's validation of tied_input_index
        if self.tied_input_index is not None and self.tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        if self.spread_index is not None and self.spread_index < 0:
            raise ValueError("invalid spread_index")
        if self.tied_input_index is not None and self.spread_index is not None:
            raise ValueError("operand can't be both spread and tied")

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """the operand's type before any spread is applied"""
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__post_init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            # assumes all spread operands have ty.reg_len = 1
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is None:
            return 0
        return self.spread_index * self.ty.reg_len
1083
1084
# Commonly used generic operand descriptors. The first two positional
# arguments of GenericOperandDesc are `ty` and `sub_kinds`.

# scalar I64 operand allocated from LocSubKind.BASE_GPR
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.I64, is_vec=False),
    [LocSubKind.BASE_GPR])
# scalar I64 operand allocated from LocSubKind.SV_EXTRA3_SGPR
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.I64, is_vec=False),
    [LocSubKind.SV_EXTRA3_SGPR])
# vector I64 operand allocated from LocSubKind.SV_EXTRA3_VGPR
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.I64, is_vec=True),
    [LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 operand allocated from LocSubKind.SV_EXTRA2_SGPR
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.I64, is_vec=False),
    [LocSubKind.SV_EXTRA2_SGPR])
# vector I64 operand allocated from LocSubKind.SV_EXTRA2_VGPR
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.I64, is_vec=True),
    [LocSubKind.SV_EXTRA2_VGPR])
# the carry flag
OD_CA = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.CA, is_vec=False),
    [LocSubKind.CA])
# the VL/MAXVL state
OD_VL = GenericOperandDesc(
    GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    [LocSubKind.VL_MAXVL])
1106
1107
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(Interned):
    """maxvl-independent description of an `OpKind`.

    Validates on construction that:
    * no input is tied and every input has `write_stage` `OpStage.Early`;
    * every tied output refers to an existing input and is equivalent to
      that input (apart from being tied and its write stage);
    * no two outputs have conflicting `fixed_loc`s.
    Violations raise ValueError.
    """
    demo_asm: str
    inputs: "tuple[GenericOperandDesc, ...]"
    outputs: "tuple[GenericOperandDesc, ...]"
    immediates: "tuple[range, ...]"
    is_copy: bool
    is_load_immediate: bool
    has_side_effects: bool

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        input_descs = tuple(inputs)
        for desc in input_descs:
            if desc.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {desc}")
            if desc.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {desc}")

        output_descs = tuple(outputs)
        seen_fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(output_descs):
            tied_idx = out.tied_input_index
            if tied_idx is not None:
                if tied_idx >= len(input_descs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = input_descs[tied_idx]
                # a tied output must be the tied input's descriptor, tied
                # and with the output's own write stage
                expected = tied_inp.tied_to_input(tied_idx)
                expected = expected.with_write_stage(out.write_stage)
                if expected != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            conflict = next(
                (pair for pair in seen_fixed_locs
                 if pair[0].conflicts(out.fixed_loc)),
                None)
            if conflict is not None:
                other_fixed_loc, other_idx = conflict
                raise ValueError(
                    f"conflicting fixed_locs: outputs[{idx}] and "
                    f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                    f"with {other_fixed_loc}")
            seen_fixed_locs.append((out.fixed_loc, idx))

        # frozen dataclass: fields must be set via object.__setattr__
        object.__setattr__(self, "demo_asm", demo_asm)
        object.__setattr__(self, "inputs", input_descs)
        object.__setattr__(self, "outputs", output_descs)
        object.__setattr__(self, "immediates", tuple(immediates))
        object.__setattr__(self, "is_copy", is_copy)
        object.__setattr__(self, "is_load_immediate", is_load_immediate)
        object.__setattr__(self, "has_side_effects", has_side_effects)
1163
1164
@plain_data.plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """properties of an `OpKind` instantiated for a particular `maxvl`.

    Every generic input/output descriptor is expanded via
    `GenericOperandDesc.instantiate`. For copies, the total register length
    of the copied inputs must equal that of the copied outputs; that length
    is stored in `copy_reg_len` (0 for non-copies).

    Raises ValueError if a copy's input and output register lengths differ.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        # self.generic depends on self.kind, so kind must be assigned first
        input_descs = tuple(
            desc
            for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl))
        self.inputs = input_descs  # type: tuple[OperandDesc, ...]
        output_descs = tuple(
            desc
            for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl))
        self.outputs = output_descs  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        copy_input_reg_len = sum(
            desc.ty.reg_len for desc in self.inputs[:self.copy_inputs_len])
        copy_output_reg_len = sum(
            desc.ty.reg_len for desc in self.outputs[:self.copy_outputs_len])
        if copy_input_reg_len != copy_output_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{copy_input_reg_len} != {copy_output_reg_len}")
        self.copy_reg_len = copy_input_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """the maxvl-independent properties of `self.kind`"""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects

    @staticmethod
    def _copy_operand_count(is_copy, descs):
        # type: (bool, tuple[OperandDesc, ...]) -> int
        """count the leading operands in `descs` that make up a copy.

        0 if not a copy; 1 if the first operand isn't spread; otherwise the
        length of the leading run whose spread_index counts up from 0.
        """
        if not is_copy:
            return 0
        if descs[0].spread_index is None:
            return 1
        count = 0
        for expected_index, desc in enumerate(descs):
            if desc.spread_index != expected_index:
                break
            count += 1
        return count

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """number of leading inputs that are the source of a copy"""
        return OpProperties._copy_operand_count(self.is_copy, self.inputs)

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """number of leading outputs that are the destination of a copy"""
        return OpProperties._copy_operand_count(self.is_copy, self.outputs)
1251
1252
# every value representable as a signed 16-bit immediate
IMM_S16 = range(-0x8000, 0x8000)

# Simulation functions: _SIM_FNS maps an op's GenericOpProperties to a
# zero-argument factory returning the function that simulates that op.
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]

# Assembly generators: _GEN_ASMS maps an op's GenericOpProperties to a
# zero-argument factory returning the function that emits its assembly.
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1261
1262
@unique
@final
class OpKind(Enum):
    """All kinds of `Op`.

    Each member's value is its `GenericOpProperties`. The simulation and
    assembly-generation functions for each kind are registered in
    `_SIM_FNS` / `_GEN_ASMS`, keyed by those properties, and looked up
    lazily through `OpKind.sim` / `OpKind.gen_asm`.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """return this kind's `OpProperties` for the given `maxvl`"""
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """the function that simulates an `Op` of this kind"""
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """the function that emits assembly for an `Op` of this kind"""
        return _GEN_ASMS[self.properties]()

    # ClearCA: write 0 (False) to CA.
    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    # SetCA: write 1 (True) to CA.
    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    # SvAddE: vector add-extended, carry chained through the VL elements.
    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    # AddZE: scalar RT = RA + CA, updating CA.
    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # addze is a scalar (non-SVP64) instruction with OD_BASE_SGPR
        # operands, so format its registers as scalars
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    # SvSubFE: vector subtract-from-extended (~RA + RB + CA), carry chained.
    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    # SvAndVS: AND each vector element with one scalar.
    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    # SvMAddEDU: vector multiply by scalar RB plus the carried high word,
    # starting from RC; the final high word is written to the tied RC.
    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        # sv.maddedu has four register operands (RT, RA, RB, RC), so all
        # of them, RT included, use the EXTRA2 encoding
        outputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    # SRADI: arithmetic shift right by an immediate, setting CA.
    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret RS as a signed 64-bit value
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # Power ISA: CA is set iff RS is negative and any 1-bits were
        # shifted out; `v << imm != rs` detects nonzero shifted-out bits
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    # SetVLI: set VL/MAXVL to an immediate in 1..64.
    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    # SvLI: load a signed 16-bit immediate into every vector element.
    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    # LI: load a signed 16-bit immediate into a scalar register.
    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    # VecCopyToReg: copy a vector from registers or the stack into registers.
    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """emit a GPR<->GPR, stack->GPR or GPR->stack copy.

        Emits nothing if `src_loc == dest_loc`; raises ValueError for
        non-GPR/non-stack locations or for stack-to-stack copies.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # overlapping copy to a higher start: use the /mrr (reversed) form
        # so source elements are read before they're overwritten
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    # VecCopyFromReg: copy a vector from registers to registers or the stack.
    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    # CopyToReg: copy a scalar from a register or the stack into a register.
    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    # CopyFromReg: copy a scalar from a register to a register or the stack.
    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    # Concat: gather spread scalar inputs (all but the trailing VL input)
    # into one vector.
    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    # Spread: scatter a vector into spread scalar outputs.
    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    # SvLd: vector load of VL consecutive 64-bit words from imm(RA).
    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    # Ld: scalar 64-bit load from imm(RA).
    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    # SvStd: vector store of VL consecutive 64-bit words to imm(RA).
    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    # Std: scalar 64-bit store to imm(RA).
    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    # FuncArgR3: the function argument, already present in r3.
    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1888
1889
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(Interned):
    """common base of `SSAVal` and `SSAUse`.

    Identifies operand `operand_idx` of `op`, indexing into the subclass's
    `descriptor_array` (the Op's output or input descriptors). Raises
    ValueError on construction if `operand_idx` is out of range.
    """
    op: "Op"
    operand_idx: int

    def __post_init__(self):
        if not 0 <= self.operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """the `OperandDesc` describing this operand"""
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """the operand's type after any spread is applied"""
        desc = self.defining_descriptor
        return desc.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """the operand's type before any spread is applied"""
        desc = self.defining_descriptor
        return desc.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        desc = self.defining_descriptor
        return desc.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """operand index of the first operand in this operand's spread group
        (just `operand_idx` when not spread)"""
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """the operand of the same kind at `unspread_start_idx`"""
        cls = type(self)
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1950 return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
1951
1952
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """an SSA value: output `operand_idx` of `op`."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """the defining Op's output descriptors"""
        props = self.op.properties
        return props.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """the allowed locations of this value before any spread"""
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """the input use this output is tied to, or None if untied"""
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is None:
            return None
        return SSAUse(op=self.op, operand_idx=tied_idx)

    @property
    def write_stage(self):
        # type: () -> OpStage
        desc = self.defining_descriptor
        return desc.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debug_state = PreRASimState.get_current_debugging_state()
        return debug_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """one `SSAValSubReg` per register of this value, in order"""
        reg_count = self.ty.reg_len
        return tuple([SSAValSubReg(self, reg_idx)
                      for reg_idx in range(reg_count)])
2001 return tuple(SSAValSubReg(self, i) for i in range(self.ty.reg_len))
2002
2003
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ the `operand_idx`-th input of `op`: one place an `SSAVal` is read """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        # uses are described by the op's input descriptors
        properties = self.op.properties
        return properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ `loc_set_before_spread` of this input's defining descriptor """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently read by this input.

        Assigning stores the new value into `op.input_vals`, which
        type-checks it against the input's descriptor.
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
2032
2033
# generic item type (e.g. the item type of an `OpInputSeq`)
_T = TypeVar("_T")
# generic descriptor type (e.g. the per-position descriptor an `OpInputSeq`
# validates each item against)
_Desc = TypeVar("_Desc")
2036
2037
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length, validated sequence of one kind of `Op` operand.

    Subclasses supply `_get_descriptors` (one descriptor per position) and
    `_verify_write_with_desc` (checks an item against its descriptor). The
    length always equals the number of descriptors; items may be replaced
    by index (and are re-validated on each write) but never inserted or
    removed.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, value in enumerate(items):
            # checked here so an overlong input gets a ValueError rather
            # than the IndexError `_verify_write` would raise
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, value)
            self.__items.append(value)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        # type: () -> Op
        """ the `Op` these operands belong to """
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return one descriptor per position (cached by `descriptors`) """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` is not valid at position `idx`, whose
        descriptor is `desc` """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ validate writing `item` at `idx`; return the normalized
        (non-negative) index. """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range normalizes negative indexes and raises
        # IndexError when out of range
        normalized = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            normalized, item, self.descriptors[normalized])
        return normalized

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        yield from self.__items

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        position = self._verify_write(idx, item)
        self.__items[position] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{type(self).__name__}({self.__items}, op=...)"
2121
2122
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s an `Op` reads, one per entry of
    `op.properties.inputs`; each must have exactly the `ty` of its
    corresponding input descriptor.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` keeps this sequence in its `input_vals` slot (it has no
        # `inputs` attribute), so that is the attribute that must still be
        # unset; checking `inputs` could never trigger.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
2142
2143
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate integer operands of an `Op`; each must lie within the
    `range` given by the matching entry of `op.properties.immediates`.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # immediates may only be attached to an Op once
        if not hasattr(op, "immediates"):
            super().__init__(items, op)
            return
        raise ValueError("Op.immediates already set")

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
2162
2163
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ a single operation belonging to a `Fn`.

    `properties` supplies the op's `kind` and the descriptors for its
    inputs, immediates and outputs. Construction builds `input_vals`
    (validated `SSAVal`s the op reads), `input_uses` (one `SSAUse` per
    input), `immediates` (validated ints) and `outputs` (one `SSAVal` per
    output), then reserves a unique `name` in `fn`.

    `Op`s compare equal and hash by identity only.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        # must be set before `input_vals`/`immediates`: those sequences read
        # their descriptors from `op.properties`
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        inputs_len = len(self.properties.inputs)
        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
        self.immediates = OpImmediates(immediates, op=self)
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
        # `Fn._add_op_with_unused_name` requires `op.fn is fn` and `name`
        # still unset; it returns `name` or, if that is empty or taken,
        # `name` with a numeric suffix appended
        self.name = fn._add_op_with_unused_name(self, name) # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ the `kind` from this op's `properties` """
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        # identity equality: two distinct Ops are never equal
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        # identity hash, consistent with `__eq__`
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ multi-line repr: `name:` on the first line, then
        `(<outputs>) <= Kind(<inputs>, <immediates in hex>)` greedily
        wrapped so lines stay within `wrap_width` columns where possible.
        Every line after the first starts with `indent`.
        """
        # pieces are joined into text where WRAP_POINT marks each place a
        # line may be broken by the wrapping loop below
        WRAP_POINT = "\u200B" # zero-width space
        items = [f"{self.name}:\n"]
        # outputs: wrapped in parens, followed by ` <= `
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        # the kind's name, with `(` only if there are operands to list
        items.append(self.kind._name_)
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
        items[-1] += WRAP_POINT
        # inputs, then immediates; whichever operand comes last closes `(`
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = [] # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            if i != 0:
                line_in = indent + line_in
            # greedily pack WRAP_POINT-separated parts onto the line; when
            # the next part would push the line (ignoring trailing
            # whitespace) past `wrap_width`, start a new `indent`-prefixed
            # line with it instead
            line_out = ""
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this op on `state`.

        Checks that every input has a value with `ty.reg_len` elements,
        runs `self.kind.sim`, then checks that every output was assigned a
        value with `ty.reg_len` elements. Value lookups that raise
        `SimSkipOp` are not checked. If `kind.sim` raises `SimSkipOp`,
        `state.on_skip(self)` is called instead of propagating it.

        Raises ValueError when an input is unassigned, a value has the wrong
        length, an output is unassigned after simulation, or (for a
        `PreRASimState`) an output was already assigned beforehand.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            if len(val) != inp.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {inp} has wrong number of elements: "
                    f"expected {inp.ty.reg_len} found "
                    f"{len(val)}: {val!r}")
        if isinstance(state, PreRASimState):
            # each output may be assigned only once; FuncArgR3 ops are
            # exempt from this check
            for out in self.outputs:
                if out in state.ssa_vals:
                    if self.kind is OpKind.FuncArgR3:
                        continue
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            except SimSkipOp:
                continue
            if len(val) != out.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {out} has wrong number of elements: "
                    f"expected {out.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for this op into `state`.

        First looks up the `Loc` of every input and output (any `LocKind`
        is accepted; `state.loc` raises if an `SSAVal` has no allocated
        `Loc`), then delegates to `self.kind.gen_asm`.
        """
        all_loc_kinds = tuple(LocKind)
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=all_loc_kinds)
        for out in self.outputs:
            state.loc(out, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
2293
2294
@plain_data.plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ abstract base for simulator states.

    Holds byte-addressed `memory` (address -> byte value) and implements
    aligned little-endian loads/stores on it. Subclasses define how `SSAVal`
    values are read (`__getitem__`) and written (`__setitem__`).
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value returned by `load_byte` for a never-written address """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when simulating `op` raised `SimSkipOp`;
        skipping is unsupported here. """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        """ read one byte; `addr` wraps to GPR width """
        try:
            byte = self.memory[addr & GPR_VALUE_MASK]
        except KeyError:
            return self._default_memory_value()
        return byte & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value`; `addr` wraps to GPR width """
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ aligned little-endian load of `size_in_bytes` bytes, optionally
        sign-extended; raises ValueError if `addr` is misaligned. """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        total_bits = size_in_bytes * BITS_IN_BYTE
        retval = 0
        for byte_idx in range(size_in_bytes):
            byte = self.load_byte(addr + byte_idx)
            retval |= byte << (byte_idx * BITS_IN_BYTE)
        if signed and retval >> (total_bits - 1) != 0:
            retval -= 1 << total_bits
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ aligned little-endian store of the low `size_in_bytes` bytes of
        `value`; raises ValueError if `addr` is misaligned. """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            self.store_byte(addr + byte_idx, (value >> shift) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory` used by `__repr__`: every aligned run of
        GPR_SIZE_IN_BYTES present bytes is shown as one `<0x...>` value,
        remaining bytes individually. """
        if not self.memory:
            return "{}"
        chunk = GPR_SIZE_IN_BYTES
        # descending order, so the lowest remaining address is at the end
        pending = sorted(self.memory, reverse=True)
        entries = []  # type: list[str]
        while pending:
            lowest = pending[-1]
            full_chunk = list(range(lowest + chunk - 1, lowest - 1, -1))
            if (len(pending) >= chunk and lowest % chunk == 0
                    and pending[-chunk:] == full_chunk):
                word = self.load(lowest, size_in_bytes=chunk)
                entries.append(f"0x{lowest:05x}: <0x{word:0{chunk * 2}x}>")
                del pending[-chunk:]
            else:
                entries.append(
                    f"0x{lowest:05x}: 0x{self.memory[pending.pop()]:02x}")
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        joined = ",\n".join(entries)
        return f"{{\n{joined}}}"

    def __repr__(self):
        # type: () -> str
        # a field `x` is rendered by `self._x__repr()` when that exists
        parts = []  # type: list[str]
        for field_name in plain_data.fields(self):
            try:
                value = getattr(self, field_name)
            except AttributeError:
                parts.append(f"{field_name}=<not set>")
                continue
            custom_repr = getattr(self, f"_{field_name}__repr", None)
            text = custom_repr() if callable(custom_repr) else repr(value)
            parts.append(f"{field_name}={text}")
        joined = ", ".join(parts)
        return f"{self.__class__.__name__}({joined})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        ...
2395
2396
@plain_data.plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores each `SSAVal`'s value directly in the
    `ssa_vals` dict, as a tuple of `ssa_val.ty.reg_len` ints.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals` used by `BaseSimState.__repr__`; values are
        shown in hex, CHUNK_SIZE elements per line. """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # start a new line every CHUNK_SIZE elements
                    element_strs.append("\n " + hex(el))
            # short values stay on one line. `element_strs` is empty for an
            # empty value (`__init__` doesn't reject those); there's nothing
            # to strip then, and the value renders as `()` instead of
            # raising IndexError.
            if 0 < len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # trailing comma, like a 1-tuple's repr
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        try:
            return self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` for an `SSAVal` with no value; raises
        KeyError here, subclasses may override. """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2447
2448
class SimSkipOp(Exception):
    """ raised during simulation when a value needed to simulate an `Op` is
    unavailable (e.g. an `SSAVal` or memory byte with no known value).

    `Op.sim` catches it from `OpKind.sim` and reports the op via
    `state.on_skip` instead of propagating it.
    """
2451
2452
@plain_data.plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state in which unknown values raise `SimSkipOp`.

    Reading an `SSAVal` with no value, or a never-written memory byte,
    raises `SimSkipOp`; every `Op` that ends up skipped is added to
    `skipped_ops`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        # record instead of failing
        skipped = self.skipped_ops
        skipped.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        # an unset SSAVal is unknown, not an error
        raise SimSkipOp()

    def _default_memory_value(self):
        # type: () -> int
        # never-written memory is unknown, not zero
        raise SimSkipOp()
2474
2475
@plain_data.plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ `PreRABaseSimState` that can additionally be registered as the
    current state for interactive debugging (see
    `set_as_current_debugging_state`).
    """
    __slots__ = ()

    # stack of states registered by `set_as_current_debugging_state`; the
    # most recently entered state is last
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        # type: () -> Iterator[None]
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        # push before entering `try`, so the `finally` only ever pops an
        # entry this call actually pushed
        PreRASimState.__CURRENT_DEBUGGING_STATE.append(self)
        try:
            yield
        finally:
            assert self is PreRASimState.__CURRENT_DEBUGGING_STATE.pop(), \
                "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises ValueError when no state is currently set.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2509
2510
@plain_data.plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state where each `SSAVal`'s value is stored in the `Loc`
    that `ssa_val_to_loc_map` assigns to it.

    `loc_values` maps `Loc`s with `reg_len == 1` to their values; larger
    `Loc`s are read and written one `reg_len == 1` sub-`Loc` at a time, and
    sub-`Loc`s missing from `loc_values` read as 0.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # kept by reference (not copied): `__setitem__` writes into it
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ repr of `loc_values` used by `BaseSimState.__repr__`, ordered by
        Loc kind name, then start. """
        if len(self.loc_values) == 0:
            # without this an empty dict would render as the malformed
            # "{\n,\n}"
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2561
2562
@plain_data.plain_data(frozen=True)
class GenAsmState:
    """ state passed to `Op.gen_asm`.

    `allocated_locs` maps each `SSAVal` to its `Loc` (types must match).
    Lines emitted by `writeln` are appended to `output` when it is a list
    (the default is a new empty list), otherwise written to it with a
    trailing newline (e.g. a `StringIO`).
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for val, allocated in self.allocated_locs.items():
            if ssa_ty_mismatch := (val.ty != allocated.ty):
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{val.ty} != loc.ty:{allocated.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve `ssa_val_or_locs` (an `SSAVal`, a `Loc`, or a sequence
        of them) to a single `Loc` by concatenating their `Loc`s, and
        check its kind is one of `expected_kinds`. Raises ValueError on an
        empty sequence, a failed concatenation, or a kind mismatch. """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [self.allocated_locs[v] if isinstance(v, SSAVal) else v
                for v in ssa_val_or_locs]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if retval.kind in expected_kinds:
            return retval
        if len(expected_kinds) == 1:
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ GPR operand text: the start register, prefixed by `*` if
        `is_vec` """
        start = str(self.loc(ssa_val_or_locs, LocKind.GPR).start)
        return ("*" + start) if is_vec else start

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ stack operand text `<start>(1)` """
        start = self.loc(ssa_val_or_locs, LocKind.StackI64).start
        return f"{start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ emit one line made of `line_segments` joined by spaces """
        line = " ".join(line_segments)
        if not isinstance(self.output, list):
            self.output.write(line + "\n")
            return
        self.output.append(line)