9e4d13d96fc28f021a16860fcef5717910dd2296
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 import dataclasses
4 from abc import ABCMeta, abstractmethod
5 from enum import Enum, unique
6 from functools import lru_cache, total_ordering
7 from io import StringIO
8 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
9 Mapping, Sequence, TypeVar, Union, overload)
10 from weakref import WeakValueDictionary as _WeakVDict
11
12 from cached_property import cached_property
13 from nmutil import plain_data # type: ignore
14
15 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
16 final)
17 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
18 OFSet, OSet)
19
# size of a single general-purpose register (GPR), in bytes
GPR_SIZE_IN_BYTES = 8
# number of bits in one byte
BITS_IN_BYTE = 8
# size of a single GPR in bits (8 bytes * 8 bits == 64)
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
# bit mask with the low GPR_SIZE_IN_BITS bits set
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
24
25
@final
class Fn:
    """ a function in the compiler IR: an ordered list of `Op`s plus the
    bookkeeping needed to give every `Op` a unique name.
    """

    def __init__(self):
        # ops in execution order
        self.ops = []  # type: list[Op]
        # names currently in use; weak values, so a name's entry doesn't keep
        # its Op alive
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix to try when a requested name is empty or taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ register `op` under a name that is not yet used in this `Fn` and
        return that name.

        `name` is used as-is when it is non-empty and unused, otherwise
        numeric suffixes are appended to the requested name until an unused
        one is found. The suffix counter is shared by all names of this `Fn`.

        raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name != "" and name not in self.__op_names:
                self.__op_names[name] = op
                return name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ return the `repr` of all ops, one per line.

        `wrap_width` and `indent` are forwarded to each `Op.__repr__`.
        If `as_python_literal` is set, the text is instead returned as
        Python source: a sequence of adjacent double-quoted string literals,
        one per text line, each prefixed with `python_indent`.
        """
        l = []  # type: list[str]
        for op in self.ops:
            l.append(op.__repr__(wrap_width=wrap_width, indent=indent))
        retval = "\n".join(l)
        if as_python_literal:
            l = [python_indent + "\""]
            for ch in retval:
                if ch == "\n":
                    # end the current literal and start a new one on the
                    # next source line
                    l.append(f"\\n\"\n{python_indent}\"")
                elif ch in "\"\\":
                    l.append("\\" + ch)
                elif ch.isascii() and ch.isprintable():
                    l.append(ch)
                else:
                    # let repr() produce the escape sequence, then drop the
                    # surrounding quotes
                    l.append(repr(ch).strip("\"'"))
            l.append("\"")
            retval = "".join(l)
            empty_end = f"\"\n{python_indent}\"\""
            if retval.endswith(empty_end):
                # drop the trailing empty literal produced when the text
                # ends with a newline. empty_end also covers the closing
                # quote of the previous literal, so put that quote back --
                # otherwise the result is an unterminated string literal.
                retval = retval[:-len(empty_end)] + "\""
        return retval

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to this `Fn`.
        raises ValueError if `op` was created for a different `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new `Op` of `kind` instantiated for `maxvl`, append it
        to this `Fn` and return it.
        """
        retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(retval)
        return retval

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ run `Op.sim` on every op, in order, using `state` """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ run `Op.gen_asm` on every op, in order, using `state` """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops` so every I64 input of every original op is
        read through a freshly inserted copy-to-reg op and every I64 output
        is followed by a copy-from-reg op. Later ops are rewired to use the
        copied outputs.

        CA and VL_MAXVL values are not copied, but an input that is the
        output of a SetVLI op gets that SetVLI re-created right before the
        op using it.
        """
        orig_ops = list(self.ops)
        # maps each original output to the value later ops must read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # maps outputs of SetVLI ops to the op that produced them
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            # index loop: op.input_vals[i] is reassigned below
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the ops VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            # the original op goes after its input copies and before its
            # output copies
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
163
164
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the stage within a single Op's execution at which a read or write
    happens. Ordered: `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        as_int = int(value)
        if as_int != 0 and as_int != 1:
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = as_int
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    All output writes with `write_stage == Early` occur here too, and
    therefore conflict with input reads, telling the compiler that it can't
    share that output's register with any inputs that the output isn't tied
    to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input
    reads, telling the compiler that any inputs can safely use the same
    register as those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # only stages are comparable with each other
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value
208
209
# import-time sanity check of the OpStage ordering: ProgramPoint.int_value and
# ProgramPoint.__lt__ are built on Early sorting before Late
assert OpStage.Early < OpStage.Late, "early must be less than late"
211
212
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramPoint(Interned):
    """ a position in a Fn: the index of an Op plus the stage within it """
    op_index: int
    stage: OpStage

    @property
    def int_value(self):
        # type: () -> int
        """ `self` encoded as a single integer, such that ordering and
        successor/predecessor relations are kept.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions after `self` """
        target = self.int_value + steps
        return ProgramPoint.from_int_value(target)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions before `self` """
        return self.next(steps=-steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            # the stage only decides between points of the same op
            if self.op_index == other.op_index:
                return self.stage < other.stage
            return self.op_index < other.op_index
        return NotImplemented

    def __gt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return other.__lt__(self)
        return NotImplemented

    def __le__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return not self.__gt__(other)
        return NotImplemented

    def __ge__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return not self.__lt__(other)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return "<ops[{}]:{}>".format(self.op_index, self.stage._name_)
270
271
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], Interned):
    """ a half-open range `[start, stop)` of `ProgramPoint`s, usable as a
    sequence of those points.
    """
    start: ProgramPoint
    stop: ProgramPoint

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ `self` as a `range` over `ProgramPoint.int_value`s """
        first = self.start.int_value
        past_last = self.stop.int_value
        return range(first, past_last)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; only step-1 ranges are accepted """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        start = ProgramPoint.from_int_value(int_value_range.start)
        stop = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=start, stop=stop)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # index/slice the integer range, then convert the result back
        int_values = range(self.start.int_value, self.stop.int_value)
        picked = int_values[__idx]
        if not isinstance(picked, int):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the ProgramPoint reprs without their angle brackets
        parts = []
        for point in (self.start, self.stop):
            parts.append(repr(point).lstrip("<").rstrip(">"))
        return f"<range:{parts[0]}..{parts[1]}>"
322
323
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(Interned):
    """ a single register-sized element of an `SSAVal`, selected by index """
    ssa_val: "SSAVal"
    reg_idx: int

    def __post_init__(self):
        # reg_idx must select one of the ssa_val's registers
        in_range = 0 <= self.reg_idx < self.ssa_val.ty.reg_len
        if not in_range:
            raise ValueError("reg_idx out of range")

    def __repr__(self):
        # type: () -> str
        return "{}[{}]".format(self.ssa_val, self.reg_idx)
337
338
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ analysis results for a `Fn`, computed once at construction: uses and
    live ranges of every `SSAVal`, which values are live at every program
    point, copy relations and known-constant values.

    Equality and hashing are delegated to `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        # map from each Op to its index in fn.ops
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        # every program point of fn: from the early stage of the first op up
        # to (excluding) the early stage of the non-existent op after the last
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range found so far
        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs are handled before outputs; the lookups below require
            # every used SSAVal to have been defined by an earlier op
            for use in op.input_uses:
                uses[use.ssa_val].add(use)
                use_program_point = self.__get_use_program_point(use)
                use_program_points[use] = use_program_point
                # extend the live range to just past this use
                live_range_stops[use.ssa_val] = max(
                    live_range_stops[use.ssa_val], use_program_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_program_range = self.__get_def_program_range(out)
                def_program_ranges[out] = def_program_range
                # an unused value is still live for its whole def range
                live_range_stops[out] = def_program_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_program_ranges)
        self.use_program_points = FMap(use_program_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses.keys():
            live_ranges[ssa_val] = live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_range_stops[ssa_val])
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        # evaluate the cached properties now rather than on first access
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the program points at which `ssa_val` is being defined: from its
        write stage through the late stage of its defining op.
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        start = ProgramPoint(
            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the program point at which `ssa_use` reads its value: always the
        early stage of the using op.
        """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if isinstance(other, FnAnalysis):
            return self.fn == other.fn
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        for op in self.op_indexes.keys():
            if not op.properties.is_copy:
                continue
            copy_reg_len = op.properties.copy_reg_len
            copy_inputs = []  # type: list[SSAValSubReg]
            # only the leading copy_inputs_len inputs are actually copied
            for inp in op.input_vals[:op.properties.copy_inputs_len]:
                for inp_sub_reg in inp.ssa_val_sub_regs:
                    # propagate copies of copies
                    inp_sub_reg = retval.get(inp_sub_reg, inp_sub_reg)
                    copy_inputs.append(inp_sub_reg)
            assert len(copy_inputs) == copy_reg_len, "logic error"
            copy_outputs = []  # type: list[SSAValSubReg]
            for out in op.outputs[:op.properties.copy_outputs_len]:
                copy_outputs.extend(out.ssa_val_sub_regs)
            assert len(copy_outputs) == copy_reg_len, "logic error"
            # the i-th copied output sub-register is a copy of the i-th
            # copied input sub-register
            for inp, out in zip(copy_inputs, copy_outputs):
                retval[out] = inp
        return FMap(retval)

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
        """ map from SSAVals to the full set of SSAVals that are related by
        being sources/destinations of copies, transitively looking through all
        copies.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        # every SSAVal starts out in a set of its own
        sets_map = {i: OSet([i]) for i in self.uses.keys()}
        for k, v in self.copies.items():
            k_set = sets_map[k.ssa_val]
            v_set = sets_map[v.ssa_val]
            # merge k_set and v_set
            if k_set is v_set:
                continue
            k_set |= v_set
            # point every member at the merged set object
            for i in k_set:
                sets_map[i] = k_set
            # this way we construct each OFSet only once rather than
            # for each SSAVal
        sets_set = {id(i): i for i in sets_map.values()}
        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
        for v in sets_set.values():
            v = OFSet(v)
            for k in v:
                retval[k] = v
        return FMap(retval)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """ the `ssa_vals` map left in a fresh `ConstPropagationState` after
        running `self.fn.sim` on it.
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """ `const_ssa_vals` split up per sub-register: one entry for every
        element of every value tuple.
        """
        retval = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, v in enumerate(const_val):
                retval[SSAValSubReg(ssa_val, reg_idx)] = v
        return FMap(retval)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value
        try:
            a_const_val = self.const_ssa_val_sub_regs[a]
            b_const_val = self.const_ssa_val_sub_regs[b]
            if a_const_val == b_const_val:
                return True
        except KeyError:
            # at least one of them is not a known constant
            pass
        return False
520
521
@unique
@final
class BaseTy(Enum):
    """ the base type of a value, before any vector length is applied """
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """ True if values of this base type can't be vectors """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """ the largest `reg_len` a `Ty` of this base type can have """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
551
552
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(Interned):
    """ a concrete type: a base type together with a length in registers """
    base_ty: BaseTy
    reg_len: int

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ check a `base_ty`/`reg_len` combination: returns the error
        message when it is invalid, None when it is valid.
        """
        is_vector = reg_len != 1
        if is_vector and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __post_init__(self):
        error = self.validate(base_ty=self.base_ty, reg_len=self.reg_len)
        if error is None:
            return
        raise ValueError(error)

    def __repr__(self):
        # type: () -> str
        # the `*N` suffix is only shown for lengths other than 1
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
583
584
@unique
@final
class LocKind(Enum):
    """ the kind of storage a `Loc` lives in """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the base type of values stored in this kind of location """
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.StackI64 or self is LocKind.GPR:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ how many single-register locations of this kind exist """
        if self is LocKind.StackI64:
            return 512
        register_backed = (self is LocKind.GPR or self is LocKind.CA
                           or self is LocKind.VL_MAXVL)
        if register_backed:
            # as many locations as the longest value of the base type
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
618
619
@final
@unique
class LocSubKind(Enum):
    """ a subdivision of `LocKind` that selects which start positions an
    operand may be allocated at.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` this sub-kind belongs to """
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        gpr_sub_kinds = (
            LocSubKind.BASE_GPR,
            LocSubKind.SV_EXTRA2_VGPR, LocSubKind.SV_EXTRA2_SGPR,
            LocSubKind.SV_EXTRA3_VGPR, LocSubKind.SV_EXTRA3_SGPR,
        )
        if self in gpr_sub_kinds:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ all `Loc`s of type `ty` that may be allocated for this sub-kind.
        Candidates overlapping any of `SPECIAL_GPRS` are left out.

        raises ValueError if `ty`'s base type doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # there's exactly one such location
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        candidates = (
            Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            for start in starts)
        # keep only valid Locs that stay clear of the special registers
        allowed = [
            loc for loc in candidates
            if loc is not None
            and not any(loc.conflicts(special) for special in SPECIAL_GPRS)
        ]
        return LocSet(allowed)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
689
690
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericTy(Interned):
    """ a type whose vector length isn't decided yet: a base type plus
    whether it is a vector.
    """
    base_ty: BaseTy
    is_vec: bool

    def __post_init__(self):
        wants_vector_of_scalar = self.is_vec and self.base_ty.only_scalar
        if wants_vector_of_scalar:
            raise ValueError(f"base_ty={self.base_ty} requires is_vec=False")

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for `maxvl`: vectors get length `maxvl`,
        everything else gets length 1.
        """
        # here's where subvl and elwid would be accounted for
        if not self.is_vec:
            return Ty(self.base_ty, 1)
        return Ty(self.base_ty, maxvl)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` makes `instantiate` compatible with `ty` """
        if self.base_ty != ty.base_ty:
            return False
        # a vector can take any length, a scalar only length 1
        return True if self.is_vec else ty.reg_len == 1
715
716
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class Loc(Interned):
    """ a location values can be allocated to: `reg_len` consecutive
    register-sized slots of `kind`, starting at index `start`.
    """
    kind: LocKind
    start: int
    reg_len: int

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ check a `kind`/`start`/`reg_len` combination: returns the error
        message when it is invalid, None when it is valid.
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like the constructor, but returns None rather than raising when
        the combination is invalid.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __post_init__(self):
        msg = self.validate(kind=self.kind, start=self.start,
                            reg_len=self.reg_len)
        if msg is not None:
            raise ValueError(msg)

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` are of the same kind and overlap in
        at least one slot.
        """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ the `Ty` of a `Loc` with the given `kind` and `reg_len` """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ the type of values that fit this Loc exactly """
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ index one past the last slot covered by this Loc """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ the single Loc covering `self` followed by all of `others`, in
        order.

        returns None if any of `others` is None, is of a different kind, or
        doesn't start exactly where the previous Loc stops, or if the
        combined Loc would be invalid (e.g. longer than the base type
        allows).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than the constructor: a too-long concatenation
        # must give None like every other failure, not raise ValueError
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the Loc of type `subloc_ty` that lies inside `self`, `offset`
        slots after `self.start`.

        raises ValueError on a base type mismatch or if the requested
        sub-Loc doesn't fit inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)

    def get_superloc_with_self_at_offset(self, superloc_ty, offset):
        # type: (Ty, int) -> Loc
        """get the Loc containing `self` such that:
        `retval.get_subloc_at_offset(self.ty, offset) == self`
        and `retval.ty == superloc_ty`

        raises ValueError on a base type mismatch, if `self` doesn't fit at
        `offset` inside a Loc of type `superloc_ty`, or if the resulting Loc
        is invalid (e.g. it would start before slot 0).
        """
        if superloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + self.reg_len > superloc_ty.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "superloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start - offset, reg_len=superloc_ty.reg_len)
806
807
# GPRs that are never handed out by `LocSubKind.allocatable_locs`: every
# candidate Loc that overlaps one of these is skipped.
# NOTE(review): presumably these are reserved by the target ABI (r0, r1 stack
# pointer, r2 TOC pointer, r13 thread pointer) -- confirm.
SPECIAL_GPRS = (
    Loc(kind=LocKind.GPR, start=0, reg_len=1),
    Loc(kind=LocKind.GPR, start=1, reg_len=1),
    Loc(kind=LocKind.GPR, start=2, reg_len=1),
    Loc(kind=LocKind.GPR, start=13, reg_len=1),
)
814
815
@final
class LocSet(OFSet[Loc], Interned):
    """ an immutable set of `Loc`s that all have the same `Ty`, with the
    start positions additionally indexed per `LocKind` as bit sets.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying a LocSet: reuse its already computed index
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only kinds that actually occur get an entry
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ for every kind that occurs, the set of `Loc.start` values """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """ the common type of all Locs, None if the set is empty """
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ for every kind that occurs, the set of `Loc.stop` values """
        if self.ty is None:
            return FMap()
        # stop == start + reg_len, i.e. the start bits shifted by reg_len
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        """ the kinds that occur in this set """
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        """ the common `reg_len` of all Locs, None if the set is empty """
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        """ the common base type of all Locs, None if the set is empty """
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the set of all Locs that consist of a Loc from `self` directly
        followed by a Loc from each of `others`, in order.

        returns an empty LocSet if no such Loc exists, which includes `self`
        or any of `others` being empty or having a different base type.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        # start positions (of the first part) that are still possible
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot: entries are deleted inside the loop
            for kind in list(starts):
                other_starts = other.starts.get(kind)
                if other_starts is None:
                    # `other` has no Loc of this kind at all, so nothing of
                    # this kind can be continued
                    del starts[kind]
                    continue
                # keep a start only if `other` has a Loc starting exactly
                # reg_len slots after it
                starts[kind].bits &= other_starts.bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # skips combined Locs that are invalid (e.g. too long)
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with
        """
        if isinstance(other, LocSet):
            return max(self.max_conflicts_with(i) for i in other)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"

    @cached_property
    def only_loc(self):
        # type: () -> Loc | None
        """if len(self) == 1 then return the Loc in self, otherwise None"""
        only_loc = None
        for i in self:
            if only_loc is None:
                only_loc = i
            else:
                return None  # len(self) > 1
        return only_loc
931
932
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(Interned):
    """generic Op operand descriptor"""
    # generic type of the operand (after any spread is applied)
    ty: GenericTy
    # the sub-kinds of location the operand may be allocated to
    sub_kinds: OFSet[LocSubKind]
    # if not None, the only Loc the operand may be allocated to
    fixed_loc: "Loc | None" = None
    # if not None, index of the input this operand is tied to
    tied_input_index: "int | None" = None
    # if True, `instantiate` produces `maxvl` scalar operands instead of one
    spread: bool = False
    # stage at which the operand is written, see OpStage
    write_stage: OpStage = OpStage.Early

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ validate and store all fields.

        raises ValueError if `sub_kinds` is empty or doesn't match `ty`'s
        base type, if `fixed_loc` is combined with `tied_input_index` or
        doesn't fit `ty`/`sub_kinds`, if `tied_input_index` is negative, or
        if `spread` is combined with a tied, fixed or vector operand.
        """
        # the dataclass is frozen, so fields have to be set through
        # object.__setattr__
        object.__setattr__(self, "ty", ty)
        object.__setattr__(self, "sub_kinds", OFSet(sub_kinds))
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        object.__setattr__(self, "fixed_loc", fixed_loc)
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        object.__setattr__(self, "tied_input_index", tied_input_index)
        object.__setattr__(self, "spread", spread)
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        object.__setattr__(self, "write_stage", write_stage)

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the generic type of the whole operand group: a vector of
        `ty.base_ty` for spread operands, otherwise `ty` itself.
        """
        if self.spread:
            return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
        return self.ty

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is tied to the input at `tied_input_index`. `fixed_loc` and
        `spread` are not carried over.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is fixed to `fixed_loc`. `tied_input_index` and `spread` are
        not carried over.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with only `write_stage` replaced """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ generate the concrete `OperandDesc`s for `maxvl`: one for a
        non-spread operand, `maxvl` of them (with `spread_index` counting up
        from 0) for a spread operand.

        raises ValueError (once iterated) if `fixed_loc` doesn't have the
        instantiated type.
        """
        # assumes all spread operands have ty.reg_len = 1
        rep_count = 1
        if self.spread:
            rep_count = maxvl
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is not None:
                if ty_before_spread != self.fixed_loc.ty:
                    raise ValueError(
                        f"instantiation failed: type mismatch with fixed_loc: "
                        f"instantiated type: {ty_before_spread} "
                        f"fixed_loc: {self.fixed_loc}")
                # a fixed operand has exactly one possible Loc
                yield self.fixed_loc
                return
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(ty_before_spread)
        # shared by all generated OperandDescs
        loc_set_before_spread = LocSet(locs_before_spread())
        for idx in range(rep_count):
            if not self.spread:
                idx = None
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=idx, write_stage=self.write_stage)
1044
1045
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class OperandDesc(Interned):
    """Op operand descriptor"""
    # locations the whole (un-spread) operand may be placed in
    loc_set_before_spread: LocSet
    # index of the input this operand is tied to, or None when not tied
    tied_input_index: "int | None"
    # position of this operand within its spread group, or None when the
    # operand is not spread
    spread_index: "int | None"
    # stage of the op at which this operand is written
    write_stage: "OpStage"

    def __post_init__(self):
        """Reject empty location sets and spread+tied combinations."""
        if not len(self.loc_set_before_spread):
            raise ValueError("loc_set_before_spread must not be empty")
        is_tied = self.tied_input_index is not None
        if is_tied and self.spread_index is not None:
            raise ValueError("operand can't be both spread and tied")

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """Ty of the whole operand, taken from `loc_set_before_spread`."""
        loc_set_ty = self.loc_set_before_spread.ty
        assert loc_set_ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return loc_set_ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        spread_index = self.spread_index
        if spread_index is None:
            return 0
        return spread_index * self.ty.reg_len
1090
1091
# Pre-built generic operand descriptors.  Each one pairs a GenericTy with
# the single LocSubKind the operand may be allocated from.

# scalar I64 in a LocSubKind.BASE_GPR location
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# scalar I64 in a LocSubKind.SV_EXTRA3_SGPR location
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector of I64 in a LocSubKind.SV_EXTRA3_VGPR location
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 in a LocSubKind.SV_EXTRA2_SGPR location
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector of I64 in a LocSubKind.SV_EXTRA2_VGPR location
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# the carry flag (BaseTy.CA)
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# the vector length (BaseTy.VL_MAXVL)
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
1113
1114
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(Interned):
    """Properties of a kind of op, before instantiation for a `maxvl`."""
    # example assembly text for the op
    demo_asm: str
    # generic descriptors of the op's inputs and outputs
    inputs: "tuple[GenericOperandDesc, ...]"
    outputs: "tuple[GenericOperandDesc, ...]"
    # allowed values of each immediate
    immediates: "tuple[range, ...]"
    is_copy: bool
    is_load_immediate: bool
    has_side_effects: bool

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Store the properties after validating the operand descriptors.

        Raises ValueError if:
        * an input is tied or has a `write_stage` other than
          `OpStage.Early`,
        * an output's `tied_input_index` is out of range, or the output
          isn't what tying the input (with the output's `write_stage`)
          would produce,
        * two outputs have conflicting `fixed_loc`s.
        """
        # the class is frozen, so fields are stored via object.__setattr__
        object.__setattr__(self, "demo_asm", demo_asm)

        input_descs = tuple(inputs)
        object.__setattr__(self, "inputs", input_descs)
        for inp in input_descs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")

        output_descs = tuple(outputs)
        object.__setattr__(self, "outputs", output_descs)
        # fixed locs of the outputs checked so far, with their indexes
        seen_fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(output_descs):
            tied_idx = out.tied_input_index
            if tied_idx is not None:
                if tied_idx >= len(input_descs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = input_descs[tied_idx]
                expected_out = tied_inp.tied_to_input(tied_idx)
                expected_out = expected_out.with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            for other_fixed_loc, other_idx in seen_fixed_locs:
                if other_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
            seen_fixed_locs.append((out.fixed_loc, idx))

        object.__setattr__(self, "immediates", tuple(immediates))
        object.__setattr__(self, "is_copy", is_copy)
        object.__setattr__(self, "is_load_immediate", is_load_immediate)
        object.__setattr__(self, "has_side_effects", has_side_effects)
1170
1171
@plain_data.plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """An `OpKind`'s properties instantiated for a particular `maxvl`."""
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        """Instantiate every generic operand of `kind` for `maxvl`.

        Raises ValueError if the op is a copy whose copied inputs and
        copied outputs don't cover the same total number of registers.
        """
        self.kind = kind  # type: OpKind
        input_descs = [
            desc for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)]
        self.inputs = tuple(input_descs)  # type: tuple[OperandDesc, ...]
        output_descs = [
            desc for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)]
        self.outputs = tuple(output_descs)  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        copy_input_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        copy_output_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if copy_input_reg_len != copy_output_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{copy_input_reg_len} != {copy_output_reg_len}")
        self.copy_reg_len = copy_input_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The generic properties these were instantiated from."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Forwarded from `self.generic`."""
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """Forwarded from `self.generic`."""
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """Forwarded from `self.generic`."""
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """Forwarded from `self.generic`."""
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """Forwarded from `self.generic`."""
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """Number of leading inputs that make up the copy's source.

        0 for non-copies, 1 when the first input isn't spread, otherwise
        the length of the leading run of inputs whose `spread_index`
        matches their position.
        """
        if not self.is_copy:
            return 0
        if self.inputs[0].spread_index is None:
            return 1
        count = 0
        while count < len(self.inputs) \
                and self.inputs[count].spread_index == count:
            count += 1
        return count

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """Number of leading outputs that make up the copy's destination.

        0 for non-copies, 1 when the first output isn't spread, otherwise
        the length of the leading run of outputs whose `spread_index`
        matches their position.
        """
        if not self.is_copy:
            return 0
        if self.outputs[0].spread_index is None:
            return 1
        count = 0
        while count < len(self.outputs) \
                and self.outputs[count].spread_index == count:
            count += 1
        return count
1258
1259
# every value of a signed 16-bit immediate
IMM_S16 = range(-(1 << 15), 1 << 15)

# Per-op callbacks are registered through zero-argument factories
# (`_SIM_FN2` / `_GEN_ASM_FN2`), keyed by the op's GenericOpProperties.
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1268
1269
1270 @unique
1271 @final
1272 class OpKind(Enum):
def __init__(self, properties):
    # type: (GenericOpProperties) -> None
    """Remember the member's value as its generic op properties."""
    super().__init__()
    # name-mangled, so it is only reachable through `self.properties`
    self.__properties = properties
1277
@property
def properties(self):
    # type: () -> GenericOpProperties
    """The generic (not yet instantiated) properties of this op kind."""
    generic_properties = self.__properties
    return generic_properties
1282
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """Return this op kind's properties instantiated for `maxvl`."""
    return OpProperties(kind=self, maxvl=maxvl)
1286
def __repr__(self):
    # type: () -> str
    """Return `OpKind.` followed by the member's name."""
    return f"OpKind.{self._name_}"
1290
@cached_property
def sim(self):
    # type: () -> _SIM_FN
    """The simulation function registered for this op kind."""
    # the registry stores zero-argument factories, not the functions
    make_sim_fn = _SIM_FNS[self.properties]
    return make_sim_fn()
1295
@cached_property
def gen_asm(self):
    # type: () -> _GEN_ASM_FN
    """The assembly-generation function registered for this op kind."""
    # the registry stores zero-argument factories, not the functions
    make_gen_asm_fn = _GEN_ASMS[self.properties]
    return make_gen_asm_fn()
1300
@staticmethod
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ClearCA: the op's only output becomes False."""
    state[op.outputs[0]] = (False,)

@staticmethod
def __clearca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write ClearCA's instruction to `state`."""
    state.writeln("addic 0, 0, 0")

# no inputs; the only output is CA, written in the late stage
ClearCA = GenericOpProperties(
    demo_asm="addic 0, 0, 0",
    inputs=[],
    outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
# the lambdas defer the `OpKind.` lookups until they are called
_SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
_GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
1317
@staticmethod
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SetCA: the op's only output becomes True."""
    state[op.outputs[0]] = (True,)

@staticmethod
def __setca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write SetCA's instruction to `state`."""
    state.writeln("subfc 0, 0, 0")

# no inputs; the only output is CA, written in the late stage
SetCA = GenericOpProperties(
    demo_asm="subfc 0, 0, 0",
    inputs=[],
    outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
# the lambdas defer the `OpKind.` lookups until they are called
_SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
_GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
1334
@staticmethod
def __svadde_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvAddE: add RA and RB element by element, chaining the
    carry through all VL elements."""
    ra = state[op.input_vals[0]]
    rb = state[op.input_vals[1]]
    (carry,) = state[op.input_vals[2]]
    (vl,) = state[op.input_vals[3]]
    rt = []  # type: list[int]
    for i in range(vl):
        total = ra[i] + rb[i] + carry
        rt.append(total & GPR_VALUE_MASK)
        # anything above the low GPR_SIZE_IN_BITS bits is the carry out
        carry = (total >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(rt)
    state[op.outputs[1]] = (carry,)

@staticmethod
def __svadde_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.adde` instruction for `op` to `state`."""
    rt = state.vgpr(op.outputs[0])
    ra = state.vgpr(op.input_vals[0])
    rb = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.adde {rt}, {ra}, {rb}")

# inputs: RA, RB, carry in, VL
# outputs: RT, carry out (tied to the carry in)
SvAddE = GenericOpProperties(
    demo_asm="sv.adde *RT, *RA, *RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
_GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
1364
@staticmethod
def __addze_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate AddZE: RT = RA + carry in, producing a new carry out."""
    RA, = state[op.input_vals[0]]
    carry, = state[op.input_vals[1]]
    v = RA + carry
    RT = v & GPR_VALUE_MASK
    # anything above the low GPR_SIZE_IN_BITS bits is the carry out
    carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = RT,
    state[op.outputs[1]] = carry,

@staticmethod
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `addze` instruction for `op` to `state`."""
    # RT and RA are scalar operands (OD_BASE_SGPR), so they are formatted
    # with `sgpr` like every other scalar register operand of this enum,
    # not with `vgpr`, which is what the vector operands use.
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")

# inputs: RA, carry in
# outputs: RT, carry out (tied to the carry in)
AddZE = GenericOpProperties(
    demo_asm="addze RT, RA",
    inputs=[OD_BASE_SGPR, OD_CA],
    outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
)
_SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
_GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
1389
@staticmethod
def __svsubfe_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvSubFE: compute ~RA + RB + carry element by element,
    chaining the carry through all VL elements."""
    ra = state[op.input_vals[0]]
    rb = state[op.input_vals[1]]
    (carry,) = state[op.input_vals[2]]
    (vl,) = state[op.input_vals[3]]
    rt = []  # type: list[int]
    for i in range(vl):
        inverted_ra = ~ra[i] & GPR_VALUE_MASK
        total = inverted_ra + rb[i] + carry
        rt.append(total & GPR_VALUE_MASK)
        # anything above the low GPR_SIZE_IN_BITS bits is the carry out
        carry = (total >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(rt)
    state[op.outputs[1]] = (carry,)

@staticmethod
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.subfe` instruction for `op` to `state`."""
    rt = state.vgpr(op.outputs[0])
    ra = state.vgpr(op.input_vals[0])
    rb = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.subfe {rt}, {ra}, {rb}")

# inputs: RA, RB, carry in, VL
# outputs: RT, carry out (tied to the carry in)
SvSubFE = GenericOpProperties(
    demo_asm="sv.subfe *RT, *RA, *RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
1419
@staticmethod
def __svandvs_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvAndVS: AND each of the VL elements of the vector RA
    with the scalar RB."""
    ra = state[op.input_vals[0]]
    (rb,) = state[op.input_vals[1]]
    (vl,) = state[op.input_vals[2]]
    state[op.outputs[0]] = tuple(
        ra[i] & rb & GPR_VALUE_MASK for i in range(vl))

@staticmethod
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.and` instruction for `op` to `state`."""
    rt = state.vgpr(op.outputs[0])
    ra = state.vgpr(op.input_vals[0])
    rb = state.sgpr(op.input_vals[1])
    state.writeln(f"sv.and {rt}, {ra}, {rb}")

# inputs: RA (vector), RB (scalar), VL
# outputs: RT (vector)
SvAndVS = GenericOpProperties(
    demo_asm="sv.and *RT, *RA, RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
)
_SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
_GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm
1445
@staticmethod
def __svmaddedu_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvMAddEDU: multiply each of the VL elements of RA by the
    scalar RB and add the running carry.

    The low GPR_SIZE_IN_BITS bits of every result go to RT; the bits
    above them become the carry for the next element and, after the last
    element, the second output.
    """
    ra = state[op.input_vals[0]]
    (rb,) = state[op.input_vals[1]]
    (carry,) = state[op.input_vals[2]]
    (vl,) = state[op.input_vals[3]]
    rt = []  # type: list[int]
    for i in range(vl):
        product_sum = ra[i] * rb + carry
        rt.append(product_sum & GPR_VALUE_MASK)
        carry = product_sum >> GPR_SIZE_IN_BITS
    state[op.outputs[0]] = tuple(rt)
    state[op.outputs[1]] = (carry,)

@staticmethod
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.maddedu` instruction for `op` to `state`."""
    rt = state.vgpr(op.outputs[0])
    ra = state.vgpr(op.input_vals[0])
    rb = state.sgpr(op.input_vals[1])
    rc = state.sgpr(op.input_vals[2])
    state.writeln(f"sv.maddedu {rt}, {ra}, {rb}, {rc}")

# inputs: RA (vector), RB, RC (carry in), VL
# outputs: RT (vector), RC (carry out, tied to the carry in)
# NOTE(review): RT is described as EXTRA3 while every input is EXTRA2 --
# confirm against the sv.maddedu encoding.
SvMAddEDU = GenericOpProperties(
    demo_asm="sv.maddedu *RT, *RA, RB, RC",
    inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
)
_SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
_GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
1476
@staticmethod
def __sradi_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SRADI: arithmetic shift right of RS by the immediate.

    RA receives the shifted value (as an unsigned GPR value).  CA is set
    only when RS is negative and at least one 1-bit was shifted out,
    which is how the Power ISA defines CA for `sradi`.
    """
    rs, = state[op.input_vals[0]]
    imm = op.immediates[0]
    # reinterpret the unsigned register value as a signed integer
    if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
        rs -= 1 << GPR_SIZE_IN_BITS
    v = rs >> imm
    RA = v & GPR_VALUE_MASK
    # Compare using the signed result `v`: shifting it back only
    # reproduces `rs` when no 1-bits were lost.  Using the masked `RA`
    # here would never equal a negative `rs`.
    CA = rs < 0 and (v << imm) != rs
    state[op.outputs[0]] = RA,
    state[op.outputs[1]] = CA,

@staticmethod
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sradi` instruction for `op` to `state`."""
    RA = state.sgpr(op.outputs[0])
    RS = state.sgpr(op.input_vals[0])
    imm = op.immediates[0]
    state.writeln(f"sradi {RA}, {RS}, {imm}")

# inputs: RS; outputs: RA and CA, both written in the late stage
# the shift amount is any value in range(GPR_SIZE_IN_BITS)
SRADI = GenericOpProperties(
    demo_asm="sradi RA, RS, imm",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
             OD_CA.with_write_stage(OpStage.Late)],
    immediates=[range(GPR_SIZE_IN_BITS)],
)
_SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
_GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm
1506
@staticmethod
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SetVLI: the output becomes the op's immediate."""
    new_vl = op.immediates[0]
    state[op.outputs[0]] = (new_vl,)

@staticmethod
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `setvl` instruction for `op` to `state`."""
    new_vl = op.immediates[0]
    state.writeln(f"setvl 0, 0, {new_vl}, 0, 1, 1")

# no inputs; the output is VL, written in the late stage
# the immediate is restricted to 1 through 64 inclusive
SetVLI = GenericOpProperties(
    demo_asm="setvl 0, 0, imm, 0, 1, 1",
    inputs=(),
    outputs=[OD_VL.with_write_stage(OpStage.Late)],
    immediates=[range(1, 65)],
    is_load_immediate=True,
)
_SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
_GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
1526
@staticmethod
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvLI: fill all VL output elements with the immediate."""
    (vl,) = state[op.input_vals[0]]
    # masking maps a negative immediate to its unsigned GPR value
    element = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (element,) * vl

@staticmethod
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.addi` instruction for `op` to `state`."""
    rt = state.vgpr(op.outputs[0])
    value = op.immediates[0]
    state.writeln(f"sv.addi {rt}, 0, {value}")

# inputs: VL; outputs: RT (vector)
SvLI = GenericOpProperties(
    demo_asm="sv.addi *RT, 0, imm",
    inputs=[OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
_GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
1549
@staticmethod
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate LI: the output becomes the immediate."""
    # masking maps a negative immediate to its unsigned GPR value
    value = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (value,)

@staticmethod
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `addi` instruction for `op` to `state`."""
    rt = state.sgpr(op.outputs[0])
    value = op.immediates[0]
    state.writeln(f"addi {rt}, 0, {value}")

# no inputs; the output is RT, written in the late stage
LI = GenericOpProperties(
    demo_asm="addi RT, 0, imm",
    inputs=(),
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[LI] = lambda: OpKind.__li_sim
_GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
1571
@staticmethod
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyToReg: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1576
@staticmethod
def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
    # type: (Loc, Loc, bool, GenAsmState) -> None
    """Write the instruction that copies `src_loc` to `dest_loc`.

    Shared by all the copy-like ops.  Nothing is written when the two
    locs are equal.  Otherwise a load (stack -> GPR), a store
    (GPR -> stack) or an `or` (GPR -> GPR) is written, prefixed with
    `sv.` when `is_vec` is set.

    Raises ValueError if either loc's kind is neither GPR nor StackI64,
    or if both locs are on the stack.
    """
    sv = "sv." if is_vec else ""
    rev = ""
    # the source overlaps the destination and starts below it
    # NOTE(review): presumably `/mrr` makes the copy run in reverse so
    # source registers aren't overwritten before being read -- confirm.
    if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
        rev = "/mrr"
    if src_loc == dest_loc:
        return  # no-op
    if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
    if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
    # Both kinds are now known to be GPR or StackI64, so the
    # `assert_never` calls below sit on branches that can't be reached.
    if src_loc.kind is LocKind.StackI64:
        if dest_loc.kind is LocKind.StackI64:
            raise ValueError(
                f"can't copy from stack to stack: {src_loc} {dest_loc}")
        elif dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        # stack -> GPR: load
        src = state.stack(src_loc)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}ld {dest}, {src}")
    elif dest_loc.kind is LocKind.StackI64:
        if src_loc.kind is not LocKind.GPR:
            assert_never(src_loc.kind)
        # GPR -> stack: store
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.stack(dest_loc)
        state.writeln(f"{sv}std {src}, {dest}")
    elif src_loc.kind is LocKind.GPR:
        if dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        # GPR -> GPR: `or` of the source with itself
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
    else:
        assert_never(src_loc.kind)
1613
@staticmethod
def __veccopytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the vector copy of the first input (in a GPR or on the
    stack) into the output's GPR."""
    source = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
    destination = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)

# inputs: source vector (vector registers or stack), VL
# outputs: destination vector register, written in the late stage
VecCopyToReg = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ),
        OD_VL,
    ],
    outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
    is_copy=True,
)
_SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
_GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
1634
@staticmethod
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyFromReg: the output takes the first input's
    value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied

@staticmethod
def __veccopyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the vector copy of the first input's GPR into the output
    (in a GPR or on the stack)."""
    source = state.loc(op.input_vals[0], LocKind.GPR)
    destination = state.loc(op.outputs[0], (LocKind.GPR, LocKind.StackI64))
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)

# inputs: source vector register, VL
# outputs: destination vector (vector registers or stack), written in
# the late stage
VecCopyFromReg = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[OD_EXTRA3_VGPR, OD_VL],
    outputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        ),
    ],
    is_copy=True,
)
_SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
_GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
1660
@staticmethod
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyToReg: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied

@staticmethod
def __copytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the scalar copy of the first input (in a GPR or on the
    stack) into the output's GPR."""
    source = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
    destination = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=False, state=state)

# inputs: source scalar (registers or stack)
# outputs: destination scalar register, written in the late stage
CopyToReg = GenericOpProperties(
    demo_asm="mv dest, src",
    inputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        ),
    ],
    outputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        ),
    ],
    is_copy=True,
)
_SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
_GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
1690
@staticmethod
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyFromReg: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied

@staticmethod
def __copyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the scalar copy of the first input's GPR into the output
    (in a GPR or on the stack)."""
    source = state.loc(op.input_vals[0], LocKind.GPR)
    destination = state.loc(op.outputs[0], (LocKind.GPR, LocKind.StackI64))
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=False, state=state)

# inputs: source scalar register
# outputs: destination scalar (registers or stack), written in the late
# stage
CopyFromReg = GenericOpProperties(
    demo_asm="mv dest, src",
    inputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        ),
    ],
    outputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        ),
    ],
    is_copy=True,
)
_SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
_GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
1720
@staticmethod
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Concat: gather the single element of every input except
    the last one (VL) into the output vector."""
    elements = op.input_vals[:-1]
    state[op.outputs[0]] = tuple(state[element][0] for element in elements)

@staticmethod
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the vector copy of all the element inputs into the output."""
    source = state.loc(op.input_vals[0:-1], LocKind.GPR)
    destination = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)

# inputs: the spread elements of the source, followed by VL
# outputs: destination vector register, written in the late stage
Concat = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ),
        OD_VL,
    ],
    outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
    is_copy=True,
)
_SIM_FNS[Concat] = lambda: OpKind.__concat_sim
_GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
1746
@staticmethod
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Spread: give each element of the input vector to the
    output at the same index."""
    source_elements = state[op.input_vals[0]]
    for out_idx, element in enumerate(source_elements):
        state[op.outputs[out_idx]] = (element,)

@staticmethod
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the vector copy of the input into all the outputs."""
    source = state.loc(op.input_vals[0], LocKind.GPR)
    destination = state.loc(op.outputs, LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)

# inputs: source vector register, VL
# outputs: the spread elements of the destination, written in the late
# stage
Spread = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[OD_EXTRA3_VGPR, OD_VL],
    outputs=[
        GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        ),
    ],
    is_copy=True,
)
_SIM_FNS[Spread] = lambda: OpKind.__spread_sim
_GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
1773
@staticmethod
def __svld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvLd: load VL consecutive GPR-sized values starting at
    address RA + imm."""
    (ra,) = state[op.input_vals[0]]
    (vl,) = state[op.input_vals[1]]
    base_addr = ra + op.immediates[0]
    state[op.outputs[0]] = tuple(
        state.load(base_addr + GPR_SIZE_IN_BYTES * i) & GPR_VALUE_MASK
        for i in range(vl))

@staticmethod
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.ld` instruction for `op` to `state`."""
    ra = state.sgpr(op.input_vals[0])
    rt = state.vgpr(op.outputs[0])
    offset = op.immediates[0]
    state.writeln(f"sv.ld {rt}, {offset}({ra})")

# inputs: RA (base address), VL; outputs: RT (vector)
SvLd = GenericOpProperties(
    demo_asm="sv.ld *RT, imm(RA)",
    inputs=[OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
)
_SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
_GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
1801
@staticmethod
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Ld: load one value from address RA + imm."""
    (ra,) = state[op.input_vals[0]]
    loaded = state.load(ra + op.immediates[0])
    state[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)

@staticmethod
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `ld` instruction for `op` to `state`."""
    ra = state.sgpr(op.input_vals[0])
    rt = state.sgpr(op.outputs[0])
    offset = op.immediates[0]
    state.writeln(f"ld {rt}, {offset}({ra})")

# inputs: RA (base address); outputs: RT, written in the late stage
Ld = GenericOpProperties(
    demo_asm="ld RT, imm(RA)",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
)
_SIM_FNS[Ld] = lambda: OpKind.__ld_sim
_GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
1825
@staticmethod
def __svstd_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvStd: store the first VL elements of RS to consecutive
    GPR-sized slots starting at address RA + imm."""
    rs = state[op.input_vals[0]]
    (ra,) = state[op.input_vals[1]]
    (vl,) = state[op.input_vals[2]]
    base_addr = ra + op.immediates[0]
    for i in range(vl):
        element_addr = base_addr + GPR_SIZE_IN_BYTES * i
        state.store(element_addr, value=rs[i])

@staticmethod
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.std` instruction for `op` to `state`."""
    rs = state.vgpr(op.input_vals[0])
    ra = state.sgpr(op.input_vals[1])
    offset = op.immediates[0]
    state.writeln(f"sv.std {rs}, {offset}({ra})")

# inputs: RS (vector to store), RA (base address), VL; no outputs
SvStd = GenericOpProperties(
    demo_asm="sv.std *RS, imm(RA)",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
_GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
1852
@staticmethod
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Std: store RS to address RA + imm."""
    (rs,) = state[op.input_vals[0]]
    (ra,) = state[op.input_vals[1]]
    state.store(ra + op.immediates[0], value=rs)

@staticmethod
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `std` instruction for `op` to `state`."""
    rs = state.sgpr(op.input_vals[0])
    ra = state.sgpr(op.input_vals[1])
    offset = op.immediates[0]
    state.writeln(f"std {rs}, {offset}({ra})")

# inputs: RS (value to store), RA (base address); no outputs
Std = GenericOpProperties(
    demo_asm="std RS, imm(RA)",
    inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[Std] = lambda: OpKind.__std_sim
_GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
1877
@staticmethod
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate FuncArgR3: nothing to do."""
    # return value set before simulation
    return None

@staticmethod
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Generate assembly for FuncArgR3: nothing to write."""
    # no instructions needed
    return None

# no inputs; the only output is fixed to GPR 3
FuncArgR3 = GenericOpProperties(
    demo_asm="",
    inputs=[],
    outputs=[
        OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1)),
    ],
)
_SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
_GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1895
1896
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(Interned):
    """Common base of `SSAVal` and `SSAUse`: one operand slot of an op.

    Subclasses choose, through `descriptor_array`, which of the op's
    operand descriptor tuples `operand_idx` indexes into.
    """
    op: "Op"
    operand_idx: int

    def __post_init__(self):
        """Check that `operand_idx` indexes into `descriptor_array`."""
        if not 0 <= self.operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """The descriptor at `operand_idx` in `descriptor_array`."""
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """The descriptor's Ty, after any spread is applied."""
        return self.defining_descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """The descriptor's Ty, before any spread is applied."""
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """The BaseTy of `ty_before_spread`."""
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """Operand index of the first member of self's spread group.

        This is `operand_idx` itself when the operand isn't spread.
        """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """The same kind of object, for operand `unspread_start_idx`."""
        cls = self.__class__
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1958
1959
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """One output of an op."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """The `loc_set_before_spread` of the defining descriptor."""
        return self.defining_descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The output descriptors of the op's properties."""
        return self.op.properties.outputs

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """The input use this output is tied to, or None if not tied."""
        tied_input_index = self.defining_descriptor.tied_input_index
        if tied_input_index is not None:
            return SSAUse(op=self.op, operand_idx=tied_input_index)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """The `write_stage` of the defining descriptor."""
        return self.defining_descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """One `SSAValSubReg` for each of the `ty.reg_len` registers."""
        reg_len = self.ty.reg_len
        return tuple(SSAValSubReg(self, reg_idx) for reg_idx in range(reg_len))
2009
2010
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ one input slot of an `Op`, i.e. `op.input_uses[operand_idx]`.

    The `SSAVal` currently plugged into the slot lives in
    `op.input_vals[operand_idx]` and is reachable through `ssa_val`.
    All descriptor lookups go through `op.properties.inputs`.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of `self.op`. """
        properties = self.op.properties
        return properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ the defining descriptor's `loc_set_before_spread`. """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` stored in this slot of `op.input_vals`; assigning
        to this property stores a new value in the same slot.
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
2039
2040
_T = TypeVar("_T")  # item type stored in an OpInputSeq
_Desc = TypeVar("_Desc")  # per-position descriptor type of an OpInputSeq


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of items belonging to an `Op`.

    Every position has a matching descriptor (see `descriptors`). Each item
    is checked against its descriptor when the sequence is built and again
    on every single-index assignment; the length never changes afterwards.

    Subclasses supply `_get_descriptors` and `_verify_write_with_desc`.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """ build the sequence from `items`.

        Raises `ValueError` if `items` has more or fewer entries than there
        are descriptors, plus whatever `_verify_write_with_desc` raises.
        """
        super().__init__()
        # `op` has to be stored first: `descriptors` may be derived from it
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, candidate in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, candidate)
            self.__items.append(candidate)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        """ the `op` passed to the constructor. """
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per position. """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ cached result of `_get_descriptors()`. """
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` can't be stored at position `idx`, whose
        descriptor is `desc`.
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` may be written at `idx`.

        Returns the normalized (non-negative) index. Raises `TypeError` for
        slices and other non-`int` indexes and `IndexError` for
        out-of-range indexes.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range turns negative indexes into positive ones and
        # raises IndexError when idx is out of range
        normalized = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            normalized, item, self.descriptors[normalized])
        return normalized

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        # NOTE(review): hook for subclasses; nothing in this class calls it
        # -- confirm whether `__init__`/`__setitem__` were meant to.
        pass

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        # slices give back a plain list, not another OpInputSeq
        stored = self.__items
        return stored[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        position = self._verify_write(idx, item)
        self.__items[position] = item

    def __repr__(self):
        # type: () -> str
        return f"{self.__class__.__name__}({self.__items}, op=...)"
2132
2133
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s plugged into an `Op`'s inputs.

    Position `i` is described by `op.properties.inputs[i]`; an item is only
    accepted if it is an `SSAVal` whose `ty` equals that descriptor's `ty`.
    """

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ Raises `ValueError` if `op.input_vals` is already set, plus
        everything `OpInputSeq.__init__` raises.
        """
        # `Op` stores this sequence in its `input_vals` slot, so that's the
        # attribute to check (same pattern as `OpImmediates`, which checks
        # `immediates`). `Op` has no `inputs` attribute, so checking that
        # name could never detect a second assignment.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        # NOTE(review): `OpInputSeq` as visible here never calls `_on_set`,
        # and `SSAUses` isn't defined in the visible part of the file --
        # confirm this hook is still wanted.
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
2157
2158
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`.

    Position `i` is described by the `range` `op.properties.immediates[i]`;
    an item is only accepted if it is an `int` contained in that range.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """ Raises `ValueError` if `op.immediates` is already set, plus
        everything `OpInputSeq.__init__` raises.
        """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
2177
2178
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ a single operation belonging to a `Fn`.

    Attributes:
    * `fn`: the `Fn` this `Op` was created for.
    * `properties`: the `OpProperties` describing inputs/outputs/immediates.
    * `input_vals`: the `SSAVal`s read by this `Op`.
    * `input_uses`: one `SSAUse` per input descriptor.
    * `immediates`: the immediate operands.
    * `outputs`: one `SSAVal` per output descriptor.
    * `name`: the name handed out by `fn._add_op_with_unused_name`.

    `Op`s compare and hash by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        # `fn` and `properties` have to be assigned before the operand
        # sequences are built, since those read `self.properties`.
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        input_count = len(self.properties.inputs)
        self.input_uses = tuple(
            [SSAUse(self, idx) for idx in range(input_count)])
        self.immediates = OpImmediates(immediates, op=self)
        output_count = len(self.properties.outputs)
        self.outputs = tuple(
            [SSAVal(self, idx) for idx in range(output_count)])
        # the name goes last: `_add_op_with_unused_name` rejects an `Op`
        # that already has a `name` attribute.
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ shorthand for `self.properties.kind`. """
        properties = self.properties
        return properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return other is self

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ multi-line description: `name:` on the first line, followed by
        `(outputs) <= Kind(inputs, immediates)`.

        Lines are broken only at the marked wrap points so that, after
        stripping trailing whitespace, they fit within `wrap_width` columns
        where possible. Every line but the first starts with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        pieces = [f"{self.name}:\n"]
        last_output_idx = len(self.outputs) - 1
        for i, out in enumerate(self.outputs):
            text = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                text = "(" + WRAP_POINT + text
            if i == last_output_idx:
                text += WRAP_POINT + ") <= "
            else:
                text += ", " + WRAP_POINT
            pieces.append(text)
        pieces.append(self.kind._name_)
        input_count = len(self.input_vals)
        immediate_count = len(self.immediates)
        if input_count + immediate_count != 0:
            pieces[-1] += "("
            pieces[-1] += WRAP_POINT
        for i, inp in enumerate(self.input_vals):
            text = repr(inp)
            # the closing paren goes after the last operand overall, which
            # is an input only when there are no immediates
            if i == input_count - 1 and immediate_count == 0:
                text += ") " + WRAP_POINT
            else:
                text += ", " + WRAP_POINT
            pieces.append(text)
        for i, imm in enumerate(self.immediates):
            text = hex(imm)
            if i == immediate_count - 1:
                text += ") " + WRAP_POINT
            else:
                text += ", " + WRAP_POINT
            pieces.append(text)
        wrapped = []  # type: list[str]
        for line_idx, raw_line in enumerate("".join(pieces).splitlines()):
            if line_idx != 0:
                raw_line = indent + raw_line
            current = ""
            for segment in raw_line.split(WRAP_POINT):
                if current == "":
                    current = segment
                    continue
                candidate = current + segment
                if len(candidate.rstrip()) <= wrap_width:
                    current = candidate
                else:
                    # segment doesn't fit: flush and start a new line
                    wrapped.append(current.rstrip())
                    current = indent + segment
            wrapped.append(current.rstrip())
        return "\n".join(wrapped)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this `Op`, reading inputs from and writing outputs to
        `state`.

        Raises `ValueError` when an input has no value yet, when an output
        of a `PreRASimState` is already assigned (tolerated for
        `OpKind.FuncArgR3`), when an output is left unassigned, or when a
        value's length differs from its type's `reg_len`.

        A `SimSkipOp` raised while reading a value skips the check of that
        value; one raised by the simulation itself is reported through
        `state.on_skip(self)`.
        """
        def check_len(ssa_val, val):
            # type: (SSAVal, tuple[int, ...]) -> None
            if len(val) != ssa_val.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {ssa_val.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            check_len(inp, val)
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out not in state.ssa_vals:
                    continue
                if self.kind is OpKind.FuncArgR3:
                    continue
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")
        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            except SimSkipOp:
                continue
            check_len(out, val)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ emit this `Op`'s assembly through `self.kind.gen_asm`. """
        any_kind = tuple(LocKind)
        # the returned Locs are discarded: the calls only make sure every
        # input and output can be resolved to a Loc before generating code
        for ssa_val in (*self.input_vals, *self.outputs):
            state.loc(ssa_val, expected_kinds=any_kind)
        self.kind.gen_asm(self, state)
2308
2309
@plain_data.plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ common base of the simulator states.

    `memory` maps byte addresses to byte values. Subclasses define how the
    value of an `SSAVal` is looked up and stored via `__getitem__` and
    `__setitem__`.
    """
    __slots__ = ("memory",)

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value returned by `load_byte` for addresses not in `memory`. """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when simulating `op` raised `SimSkipOp`;
        this base implementation rejects skipping.
        """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        """ read one byte; `addr` is wrapped to `GPR_VALUE_MASK`. """
        wrapped_addr = addr & GPR_VALUE_MASK
        try:
            stored = self.memory[wrapped_addr]
        except KeyError:
            return self._default_memory_value()
        return stored & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value`; `addr` is wrapped to
        `GPR_VALUE_MASK`.
        """
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian integer of `size_in_bytes` bytes.

        Raises `ValueError` unless `addr` is a multiple of `size_in_bytes`.
        With `signed=True` the result is sign-extended from its top bit.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        size_in_bits = size_in_bytes * BITS_IN_BYTE
        value = 0
        for byte_idx in range(size_in_bytes):
            byte = self.load_byte(addr + byte_idx)
            value |= byte << (byte_idx * BITS_IN_BYTE)
        if signed and (value >> (size_in_bits - 1)) != 0:
            value -= 1 << size_in_bits
        return value

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value`, little-endian.

        Raises `ValueError` unless `addr` is a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            byte = (value >> (byte_idx * BITS_IN_BYTE)) & 0xFF
            self.store_byte(addr + byte_idx, byte)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory` used by `__repr__`: addresses in ascending
        order, with every fully-populated aligned run of
        `GPR_SIZE_IN_BYTES` bytes merged into one `<0x...>` entry.
        """
        if len(self.memory) == 0:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        # descending order, so the lowest remaining address is pending[-1]
        # and can be removed cheaply from the end
        pending = sorted(self.memory.keys(), reverse=True)
        entries = []  # type: list[str]
        while len(pending) != 0:
            addr = pending[-1]
            is_full_chunk = (
                len(pending) >= CHUNK_SIZE
                and addr % CHUNK_SIZE == 0
                and pending[-CHUNK_SIZE:]
                == list(range(addr + CHUNK_SIZE - 1, addr - 1, -1)))
            if not is_full_chunk:
                pending.pop()
                entries.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                continue
            value = self.load(addr, size_in_bytes=CHUNK_SIZE)
            entries.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
            del pending[-CHUNK_SIZE:]
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        entries_str = ",\n".join(entries)
        return f"{{\n{entries_str}}}"

    def __repr__(self):
        # type: () -> str
        """ `ClassName(field=value, ...)` over all plain_data fields.

        Unset fields show as `<not set>`; a field `x` is formatted by the
        method `_x__repr` when one exists, otherwise by `repr`.
        """
        rendered = []  # type: list[str]
        for name in plain_data.fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                rendered.append(f"{name}=<not set>")
                continue
            custom_repr = getattr(self, f"_{name}__repr", None)
            value_str = custom_repr() if callable(custom_repr) else repr(value)
            rendered.append(f"{name}={value_str}")
        rendered_str = ", ".join(rendered)
        return f"{self.__class__.__name__}({rendered_str})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        ...
2410
2411
@plain_data.plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores values per `SSAVal` in the `ssa_vals`
    dict (one tuple of ints per `SSAVal`).
    """
    __slots__ = ("ssa_vals",)

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ the value stored for `ssa_val`; missing values are delegated to
        `_handle_undefined_ssa_val`.
        """
        try:
            found = self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)
        return found

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called for lookups of `SSAVal`s without a value; this
        implementation raises `KeyError`.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ store `value` (converted to a tuple of ints) for `ssa_val`.

        Raises `ValueError` unless it has `ssa_val.ty.reg_len` elements.
        """
        elements = tuple(int(el) for el in value)
        if len(elements) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = elements

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals` used by `__repr__`: values are shown as
        tuples of hex numbers, `CHUNK_SIZE` elements per row.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        CHUNK_SIZE = 4
        entries = []  # type: list[str]
        for ssa_val, value in self.ssa_vals.items():
            # every CHUNK_SIZE-th element starts a new row
            parts = [
                ("\n " if i % CHUNK_SIZE == 0 else " ") + hex(el)
                for i, el in enumerate(value)]
            if len(parts) <= CHUNK_SIZE:
                # fits in a single row: no leading line break needed
                parts[0] = parts[0].lstrip()
            if len(parts) == 1:
                # gives the trailing comma of a 1-tuple
                parts.append("")
            entries.append(f"{ssa_val!r}: ({','.join(parts)})")
        if len(entries) == 1 and "\n" not in entries[0]:
            return f"{{{entries[0]}}}"
        entries_str = ",\n".join(entries)
        return f"{{\n{entries_str},\n}}"
2462
2463
class SimSkipOp(Exception):
    """ raised while simulating to signal that the current `Op` should be
    skipped; `Op.sim` catches it and reports the skip via `state.on_skip`.
    """
2466
2467
@plain_data.plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state where unknown inputs raise `SimSkipOp` instead of
    failing: reads of memory bytes that aren't in `memory` and of `SSAVal`s
    without a value both raise it. Every skipped `Op` is added to
    `skipped_ops`.
    """
    __slots__ = ("skipped_ops",)

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        """ record `op` in `skipped_ops`. """
        skipped = self.skipped_ops
        skipped.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        raise SimSkipOp()

    def _default_memory_value(self):
        # type: () -> int
        raise SimSkipOp()
2489
2490
@plain_data.plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ `PreRABaseSimState` that additionally can be registered as the
    "current debugging state", for inspecting values from pdb or similar.
    """
    __slots__ = ()

    # stack of the states registered by `set_as_current_debugging_state`,
    # innermost last
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        stack.append(self)
        try:
            yield
        finally:
            # pop outside of the assert: assert statements are removed when
            # running with `python -O`, which would leave self on the stack
            # forever.
            popped = stack.pop()
            assert popped is self, "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises `ValueError` if no state is currently registered.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2524
2525
@plain_data.plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state where every `SSAVal` lives in a `Loc`.

    `ssa_val_to_loc_map` gives the `Loc` of each `SSAVal`; `loc_values`
    holds one int per `reg_len=1` `Loc`. Reading or writing an `SSAVal`
    goes through the `reg_len=1` sub-Locs of its `Loc`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ Raises `ValueError` if an `SSAVal`'s `ty` differs from its
        `Loc`'s `ty`, or if `loc_values` has a key with `reg_len != 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ repr of `loc_values` used by `__repr__`: one `loc: 0x...` entry
        per line, sorted by kind name and then start.
        """
        # same empty-case output as `_memory__repr`/`_ssa_vals__repr`;
        # without this an empty dict would be shown as "{\n,\n}"
        if len(self.loc_values) == 0:
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ the value of `ssa_val`, one int per register of its `Loc`;
        sub-Locs without an entry in `loc_values` read as 0.

        Raises `KeyError` if `ssa_val` has no `Loc`.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ store `value` (converted to ints) in the sub-Locs of
        `ssa_val`'s `Loc`.

        Raises `ValueError` unless `value` has `ssa_val.ty.reg_len`
        elements, and `KeyError` if `ssa_val` has no `Loc`.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2575 self.loc_values[subloc] = value[i]
2576
2577
@plain_data.plain_data(frozen=True)
class GenAsmState:
    """ state used while generating assembly.

    `allocated_locs` maps every `SSAVal` to the `Loc` allocated for it;
    `output` collects the generated lines, either as a list of strings
    (without line endings) or as a `StringIO`.
    """
    __slots__ = ("allocated_locs", "output")

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """ `output` defaults to a new empty list.

        Raises `ValueError` if an `SSAVal`'s `ty` differs from its `Loc`'s.
        """
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve an `SSAVal`/`Loc`, or a sequence of them, to one `Loc`.

        `SSAVal`s are replaced by their allocated `Loc`s and the resulting
        `Loc`s are combined with `Loc.try_concat`.

        Raises `ValueError` for an empty sequence, when `try_concat` fails,
        or when the result's kind isn't one of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[el] if isinstance(el, SSAVal) else el
            for el in ssa_val_or_locs]  # type: list[Loc]
        if len(locs) == 0:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        combined = first.try_concat(*rest)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if combined.kind in expected_kinds:
            return combined
        if len(expected_kinds) == 1:
            # show a lone kind without the surrounding tuple
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ the start register number of a `LocKind.GPR` `Loc` as a string,
        prefixed with `*` when `is_vec` is true.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        prefix = "*" if is_vec else ""
        return prefix + str(loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ `gpr` with `is_vec=False`. """
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ `gpr` with `is_vec=True`. """
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ the operand string `<start>(1)` for a `LocKind.StackI64` `Loc`.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ join `line_segments` with single spaces and emit them as one
        line of output.
        """
        line = " ".join(line_segments)
        if isinstance(self.output, list):
            self.output.append(line)
            return
        self.output.write(line + "\n")
2644 else:
2645 self.output.write(line + "\n")