optimize LocSet.max_conflicts_with
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 import dataclasses
4 from abc import ABCMeta, abstractmethod
5 from enum import Enum, unique
6 from functools import lru_cache, total_ordering
7 from io import StringIO
8 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
9 Mapping, Sequence, TypeVar, Union, overload)
10 from weakref import WeakValueDictionary as _WeakVDict
11
12 from cached_property import cached_property
13 from nmutil import plain_data # type: ignore
14
15 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
16 final)
17 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
18 OFSet, OSet, bit_count)
19
# Width of one general-purpose register (GPR), in bytes and in bits.
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# Mask with the low GPR_SIZE_IN_BITS bits set.
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
24
25
@final
class Fn:
    """An ordered list of `Op`s plus the bookkeeping used to hand out
    distinct `Op` names.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak-valued: an entry disappears once its Op is garbage collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix to try when a requested name is unavailable
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under a name that isn't currently in use and return
        that name.

        `name` is used as-is when it is non-empty and unused, otherwise
        numeric suffixes are appended to it until an unused name is found.

        Raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """Return the `repr` of every op in `self.ops`, one per line.

        `wrap_width` and `indent` are forwarded to each op's `__repr__`.

        If `as_python_literal` is set, the text is instead returned as a
        sequence of adjacent double-quoted Python string literals, one per
        line of text, each prefixed by `python_indent`.
        """
        text = "\n".join(
            op.__repr__(wrap_width=wrap_width, indent=indent)
            for op in self.ops)
        if not as_python_literal:
            return text
        pieces = [python_indent + "\""]
        for ch in text:
            if ch == "\n":
                # end the current literal and start a new one on the next line
                pieces.append(f"\\n\"\n{python_indent}\"")
            elif ch in "\"\\":
                pieces.append("\\" + ch)
            elif ch.isascii() and ch.isprintable():
                pieces.append(ch)
            else:
                # let repr() pick the escape sequence, minus its quotes
                pieces.append(repr(ch).strip("\"'"))
        pieces.append("\"")
        literal = "".join(pieces)
        # When the text ends with a newline the last line is an empty `""`
        # literal: drop that line, but keep the closing quote of the line
        # before it so the result is still a well-formed literal.
        trailing_empty = f"\n{python_indent}\"\""
        if literal.endswith("\"" + trailing_empty):
            literal = literal[:-len(trailing_empty)]
        return literal

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op` to `self.ops`; raises ValueError if `op.fn` is not
        this `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create a new `Op` of the given `kind` (instantiated with `maxvl`),
        append it to this `Fn` and return it.
        """
        new_op = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Call `sim(state)` on every op, in order."""
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Call `gen_asm(state)` on every op, in order."""
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `self.ops` so that every I64 input of every original op
        is fed by a freshly inserted copy-to-reg op, and every I64 output is
        followed by a copy-from-reg op. Later ops read the copied outputs.

        CA and VL_MAXVL values are never copied; inputs that were produced
        by a SetVLI op get a new SetVLI op right before their user instead.
        (The `pre_ra` prefix presumably means "before register allocation".)
        """
        old_ops = list(self.ops)
        # original output -> the value later ops must read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # SetVLI output -> the SetVLI op that produced it
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in old_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                inp_base_ty = inp.ty.base_ty
                if inp_base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if maxvl != 1:
                        # vector copies take VL as an extra input
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp_base_ty is BaseTy.CA \
                        or inp_base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the ops that use their VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp_base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                out_base_ty = out.ty.base_ty
                if out_base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if maxvl != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out_base_ty is BaseTy.CA \
                        or out_base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out_base_ty)
163
164
@final
@unique
@total_ordering
class OpStage(Enum):
    """The stages of a single Op's execution, ordered `Early < Late`."""
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        stage_num = int(value)
        if stage_num not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = stage_num
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and
    therefore conflict with input reads, telling the compiler that it can't
    share that output's register with any inputs that the output isn't tied
    to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input
    reads, telling the compiler that any inputs can safely use the same
    register as those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return f"OpStage.{self._name_}"

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # only ordered relative to other OpStage members
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
211
212
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramPoint(Interned):
    """A position in a function: the index of an op plus the stage of that
    op's execution.
    """
    op_index: int
    stage: OpStage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering
        and successor/predecessor relations.
        """
        # two program points per op: Early is even, Late is odd
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """Inverse of `int_value`."""
        op_index, stage_value = divmod(int_value, 2)
        return ProgramPoint(op_index=op_index, stage=OpStage(stage_value))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """Return the program point `steps` positions after `self`."""
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """Return the program point `steps` positions before `self`."""
        return self.next(steps=-steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if not isinstance(other, ProgramPoint):
            return NotImplemented
        # order by op first, then by stage within the same op
        if self.op_index == other.op_index:
            return self.stage < other.stage
        return self.op_index < other.op_index

    def __gt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return other.__lt__(self)
        return NotImplemented

    def __le__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return not self.__gt__(other)
        return NotImplemented

    def __ge__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return not self.__lt__(other)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
270
271
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], Interned):
    """A half-open range of `ProgramPoint`s: `start` is included, `stop` is
    not.
    """
    start: ProgramPoint
    stop: ProgramPoint

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """`self` as a `range` over `ProgramPoint.int_value`s."""
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """Inverse of `int_value_range`; raises ValueError unless the
        range's step is 1.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        first = ProgramPoint.from_int_value(int_value_range.start)
        end = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=first, stop=end)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # index/slice the integer range, then convert the result back
        picked = self.int_value_range[__idx]
        if not isinstance(picked, int):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the ProgramPoint reprs without their angle brackets
        start, stop = (repr(point).lstrip("<").rstrip(">")
                       for point in (self.start, self.stop))
        return f"<range:{start}..{stop}>"
322
323
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(Interned):
    """A single register of an `SSAVal`, selected by `reg_idx`."""
    ssa_val: "SSAVal"
    reg_idx: int

    def __post_init__(self):
        # reg_idx must index one of the ssa_val.ty.reg_len registers
        if not 0 <= self.reg_idx < self.ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")

    def __repr__(self):
        # type: () -> str
        return f"{self.ssa_val}[{self.reg_idx}]"
337
338
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """Analysis results for one `Fn`: uses, program points and live ranges
    of its SSA values, plus copy and constant information.

    Two `FnAnalysis` objects compare equal (and hash alike) when their `fn`
    does.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        first_point = ProgramPoint(op_index=0, stage=OpStage.Early)
        end_point = ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)
        self.all_program_points = ProgramRange(
            start=first_point, stop=end_point)
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each value's live range found so far
        range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                use_point = self.__get_use_program_point(use)
                use_points[use] = use_point
                # the value must stay live through this use
                range_stops[used_val] = max(
                    range_stops[used_val], use_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                range_stops[out] = def_range.stop
        self.uses = FMap((val, OFSet(val_uses))
                         for val, val_uses in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=range_stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((point, OFSet(vals))
                            for point, vals in live_at.items())
        # evaluate the cached properties eagerly
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """Return the range of program points covered by the definition of
        `ssa_val`: from its write stage through the late stage of its op.
        """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        late = ProgramPoint(op_index=start.op_index, stage=OpStage.Late)
        return ProgramRange(start=start, stop=late.next())

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """Return the program point at which `ssa_use` reads its value: the
        early stage of the using op.
        """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they
        are a copy of, looking through all layers of copies. The map excludes
        all SSAValSubRegs that aren't copies of other SSAValSubRegs.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        originals = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        for op in self.op_indexes.keys():
            props = op.properties
            if not props.is_copy:
                continue
            copy_reg_len = props.copy_reg_len
            sources = []  # type: list[SSAValSubReg]
            for inp in op.input_vals[:props.copy_inputs_len]:
                for sub_reg in inp.ssa_val_sub_regs:
                    # propagate copies of copies
                    sources.append(originals.get(sub_reg, sub_reg))
            assert len(sources) == copy_reg_len, "logic error"
            dests = []  # type: list[SSAValSubReg]
            for out in op.outputs[:props.copy_outputs_len]:
                dests.extend(out.ssa_val_sub_regs)
            assert len(dests) == copy_reg_len, "logic error"
            for source, dest in zip(sources, dests):
                originals[dest] = source
        return FMap(originals)

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
        """ map from SSAVals to the full set of SSAVals that are related by
        being sources/destinations of copies, transitively looking through
        all copies.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        # every SSAVal starts out in a set of its own
        sets_map = {ssa_val: OSet([ssa_val]) for ssa_val in self.uses.keys()}
        for copy, original in self.copies.items():
            copy_set = sets_map[copy.ssa_val]
            original_set = sets_map[original.ssa_val]
            if copy_set is original_set:
                continue  # already merged
            # merge original_set into copy_set and repoint all members
            copy_set |= original_set
            for member in copy_set:
                sets_map[member] = copy_set
        # deduplicate by identity: this way we construct each OFSet only
        # once rather than once for each SSAVal
        distinct_sets = {id(s): s for s in sets_map.values()}
        related = {}  # type: dict[SSAVal, OFSet[SSAVal]]
        for members in distinct_sets.values():
            frozen = OFSet(members)
            for member in frozen:
                related[member] = frozen
        return FMap(related)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """The `ssa_vals` recorded by running `self.fn.sim` on a fresh
        `ConstPropagationState`.
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """`const_ssa_vals` split up into one entry per sub-register."""
        per_sub_reg = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, reg_value in enumerate(const_val):
                per_sub_reg[SSAValSubReg(ssa_val, reg_idx)] = reg_value
        return FMap(per_sub_reg)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value
        consts = self.const_ssa_val_sub_regs
        if a not in consts or b not in consts:
            return False
        if consts[a] == consts[b]:
            return True
        return False
520
521
@unique
@final
class BaseTy(Enum):
    """The base types a value can have."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if this base type can't be used with `reg_len != 1`."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """The largest `reg_len` allowed for this base type."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
551
552
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(Interned):
    """A concrete type: a base type plus a register count."""
    base_ty: BaseTy
    reg_len: int

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __post_init__(self):
        # reject invalid combinations at construction time
        error = self.validate(base_ty=self.base_ty, reg_len=self.reg_len)
        if error is not None:
            raise ValueError(error)

    def __repr__(self):
        # type: () -> str
        # the `*N` suffix is only shown when reg_len isn't 1
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
583
584
@unique
@final
class LocKind(Enum):
    """The kinds of locations a value can be stored in."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The base type of the values this kind of location holds."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """How many single-register locations of this kind exist."""
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
618
619
@final
@unique
class LocSubKind(Enum):
    """Subdivisions of `LocKind`: each member belongs to one `LocKind`
    (see `kind`) and has its own set of allocatable `Loc`s.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind belongs to."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        gpr_sub_kinds = (
            LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
            LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
            LocSubKind.SV_EXTRA3_SGPR)
        if self in gpr_sub_kinds:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        """The base type of `self.kind`."""
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Return the set of `Loc`s of type `ty` that values with this
        sub-kind can be allocated to.

        Locations that overlap any of `SPECIAL_GPRS` are left out.
        Raises ValueError if `ty` has a different base type than `self`.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        # the possible starting indexes depend on the sub-kind
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        elif self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        allowed = []  # type: list[Loc]
        for start in starts:
            # try_make returns None for locations that don't fit
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            allowed.append(loc)
        return LocSet(allowed)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
689
690
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericTy(Interned):
    """A type whose register count isn't fixed yet: a base type plus
    whether it is a vector.
    """
    base_ty: BaseTy
    is_vec: bool

    def __post_init__(self):
        # only-scalar base types can never be vectors
        if self.is_vec and self.base_ty.only_scalar:
            raise ValueError(f"base_ty={self.base_ty} requires is_vec=False")

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete `Ty`: `maxvl` registers for a vector,
        otherwise one register.
        """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """Return True if `ty` has the same base type and, unless `self` is
        a vector, exactly one register.
        """
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
715
716
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class Loc(Interned):
    """A run of `reg_len` consecutive locations of one `LocKind`, beginning
    at index `start`.
    """
    kind: LocKind
    start: int
    reg_len: int

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return a string describing why the combination is invalid, or
        None if it is valid.
        """
        # the location's type must itself be valid
        error = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if error is not None:
            return error
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Return the requested `Loc`, or None if it would be invalid."""
        if Loc.validate(kind=kind, start=start, reg_len=reg_len) is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __post_init__(self):
        error = self.validate(kind=self.kind, start=self.start,
                              reg_len=self.reg_len)
        if error is not None:
            raise ValueError(error)

    def conflicts(self, other):
        # type: (Loc) -> bool
        """Return True if `self` and `other` have the same kind and their
        index ranges overlap.
        """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """Return the `Ty` of a `Loc` with the given kind and `reg_len`."""
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        """The `Ty` of values that fit this location."""
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @cached_property
    def stop(self):
        # type: () -> int
        """The index one past the last location covered (exclusive end)."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Return the single `Loc` covering `self` followed by each of
        `others` in order, or None if that isn't possible.

        None is returned when any of `others` is None, has a different
        kind, or doesn't start exactly where the previous one stops, and
        also when the combined location isn't a valid `Loc`.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None  # not adjacent
            stop = other.stop
            reg_len += other.reg_len
        # the combined reg_len can be too large to form a valid Ty even
        # though every piece is valid on its own, so use try_make rather
        # than letting Loc() raise ValueError from a `try_` method.
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """Return the `Loc` of type `subloc_ty` that starts `offset`
        locations into `self`.

        Raises ValueError if the base types differ or the sub-Loc doesn't
        fit inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)

    def get_superloc_with_self_at_offset(self, superloc_ty, offset):
        # type: (Ty, int) -> Loc
        """get the Loc containing `self` such that:
        `retval.get_subloc_at_offset(self.ty, offset) == self`
        and `retval.ty == superloc_ty`

        Raises ValueError if the base types differ or `self` doesn't fit
        inside `superloc_ty` at `offset`.
        """
        if superloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + self.reg_len > superloc_ty.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "superloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start - offset, reg_len=superloc_ty.reg_len)
806
807
# Single-register GPR locations (r0, r1, r2 and r13) that
# LocSubKind.allocatable_locs never hands out.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=gpr_index, reg_len=1)
    for gpr_index in (0, 1, 2, 13))
814
815
@final
class LocSet(OFSet[Loc], Interned):
    """A set of `Loc`s that all have the same type.

    Besides the set itself, the start indexes of the `Loc`s are kept as one
    bit set per `LocKind`, so that some operations can be done with integer
    bit operations.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # nothing to recompute when copying another LocSet
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts_by_kind = {kind: BitSet() for kind in LocKind}
        common_ty = None  # type: None | Ty
        for loc in self:
            if common_ty is None:
                common_ty = loc.ty
            elif common_ty != loc.ty:
                raise ValueError(
                    f"conflicting types: {common_ty} != {loc.ty}")
            starts_by_kind[loc.kind].add(loc.start)
        # kinds without any Loc are left out of the map
        self.__starts = FMap(
            (kind, FBitSet(bits))
            for kind, bits in starts_by_kind.items() if len(bits) != 0)
        self.__ty = common_ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """The start indexes of all `Loc`s, grouped by kind."""
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """The type shared by all `Loc`s, or None if the set is empty."""
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """The (exclusive) stop indexes of all `Loc`s, grouped by kind."""
        if self.ty is None:
            return FMap()
        # stop == start + reg_len, which is a left shift of the start bits
        shift = self.ty.reg_len
        return FMap(
            (kind, FBitSet(bits=kind_starts.bits << shift))
            for kind, kind_starts in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        """The kinds that have at least one `Loc` in the set."""
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        """`self.ty.reg_len`, or None if the set is empty."""
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        """`self.ty.base_ty`, or None if the set is empty."""
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return the set of `Loc`s that can be formed by taking one `Loc`
        from `self` followed immediately by one `Loc` from each of `others`,
        in order.

        The result is empty if any of the sets is empty or has a different
        base type than `self`.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        # total reg_len of everything concatenated so far
        reg_len = self.ty.reg_len
        starts = {kind: BitSet(bits) for kind, bits in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            for kind in list(starts):
                other_starts = other.starts.get(kind)
                if other_starts is None:
                    # `other` has no Loc of this kind, so nothing of this
                    # kind can be concatenated with it
                    del starts[kind]
                    continue
                # a start survives only if `other` has a Loc that begins
                # exactly `reg_len` locations after it
                starts[kind].bits &= other_starts.bits >> reg_len
                if len(starts[kind]) == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, kind_starts in starts.items():
                for start in kind_starts:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with

        Returns 0 if `self` is empty or `other` is an empty `LocSet`.
        """
        if isinstance(other, LocSet):
            # default=0: an empty `other` has no Loc that could conflict
            return max(
                (self.max_conflicts_with(loc) for loc in other), default=0)
        reg_len = self.reg_len
        if reg_len is None:
            return 0
        starts = self.starts.get(other.kind)
        if starts is None:
            return 0
        # now we do the equivalent of:
        # return sum(other.conflicts(i) for i in self)
        # which is the equivalent of:
        # return sum(other.start < start + reg_len
        #            and start < other.start + other.reg_len
        #            for start in starts)
        # written in terms of `stop = start + reg_len`, that is:
        # other.start + 1 <= stop < other.start + other.reg_len + reg_len
        stops = starts.bits << reg_len

        # all the bit indexes `i` where `i < other.start + 1`
        lt_other_start_plus_1 = ~(~0 << (other.start + 1))

        # all the bit indexes `i` where
        # `i < other.start + other.reg_len + reg_len`
        lt_other_start_plus_other_reg_len_plus_reg_len = (
            ~(~0 << (other.start + other.reg_len + reg_len)))
        included = stops & ~lt_other_start_plus_1
        included &= lt_other_start_plus_other_reg_len_plus_reg_len
        return bit_count(included)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"

    @cached_property
    def only_loc(self):
        # type: () -> Loc | None
        """if len(self) == 1 then return the Loc in self, otherwise None"""
        found = None
        for loc in self:
            if found is not None:
                return None  # len(self) > 1
            found = loc
        return found
955
956
957 @dataclasses.dataclass(frozen=True, unsafe_hash=True)
958 @final
959 class GenericOperandDesc(Interned):
960 """generic Op operand descriptor"""
961 ty: GenericTy
962 sub_kinds: OFSet[LocSubKind]
963 fixed_loc: "Loc | None" = None
964 tied_input_index: "int | None" = None
965 spread: bool = False
966 write_stage: OpStage = OpStage.Early
967
968 def __init__(
969 self, ty, # type: GenericTy
970 sub_kinds, # type: Iterable[LocSubKind]
971 *,
972 fixed_loc=None, # type: Loc | None
973 tied_input_index=None, # type: int | None
974 spread=False, # type: bool
975 write_stage=OpStage.Early, # type: OpStage
976 ):
977 # type: (...) -> None
978 object.__setattr__(self, "ty", ty)
979 object.__setattr__(self, "sub_kinds", OFSet(sub_kinds))
980 if len(self.sub_kinds) == 0:
981 raise ValueError("sub_kinds can't be empty")
982 object.__setattr__(self, "fixed_loc", fixed_loc)
983 if fixed_loc is not None:
984 if tied_input_index is not None:
985 raise ValueError("operand can't be both tied and fixed")
986 if not ty.can_instantiate_to(fixed_loc.ty):
987 raise ValueError(
988 f"fixed_loc has incompatible type for given generic "
989 f"type: fixed_loc={fixed_loc} generic ty={ty}")
990 if len(self.sub_kinds) != 1:
991 raise ValueError(
992 "multiple sub_kinds not allowed for fixed operand")
993 for sub_kind in self.sub_kinds:
994 if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
995 raise ValueError(
996 f"fixed_loc not in given sub_kind: "
997 f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
998 for sub_kind in self.sub_kinds:
999 if sub_kind.base_ty != ty.base_ty:
1000 raise ValueError(f"sub_kind is incompatible with type: "
1001 f"sub_kind={sub_kind} ty={ty}")
1002 if tied_input_index is not None and tied_input_index < 0:
1003 raise ValueError("invalid tied_input_index")
1004 object.__setattr__(self, "tied_input_index", tied_input_index)
1005 object.__setattr__(self, "spread", spread)
1006 if spread:
1007 if self.tied_input_index is not None:
1008 raise ValueError("operand can't be both spread and tied")
1009 if self.fixed_loc is not None:
1010 raise ValueError("operand can't be both spread and fixed")
1011 if self.ty.is_vec:
1012 raise ValueError("operand can't be both spread and vector")
1013 object.__setattr__(self, "write_stage", write_stage)
1014
1015 @cached_property
1016 def ty_before_spread(self):
1017 # type: () -> GenericTy
1018 if self.spread:
1019 return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
1020 return self.ty
1021
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """ create a descriptor like `self` that is tied to an input.

    Only `ty`, `sub_kinds` and `write_stage` are carried over; `fixed_loc`
    and `spread` are left at the constructor's defaults.

    tied_input_index: index of the input this operand is tied to.
    """
    tied_desc = GenericOperandDesc(
        ty=self.ty,
        sub_kinds=self.sub_kinds,
        write_stage=self.write_stage,
        tied_input_index=tied_input_index,
    )
    return tied_desc
1027
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """ create a descriptor like `self` that is pinned to `fixed_loc`.

    Only `ty`, `sub_kinds` and `write_stage` are carried over;
    `tied_input_index` and `spread` are left at the constructor's
    defaults.
    """
    fixed_desc = GenericOperandDesc(
        ty=self.ty,
        sub_kinds=self.sub_kinds,
        write_stage=self.write_stage,
        fixed_loc=fixed_loc,
    )
    return fixed_desc
1032
def with_write_stage(self, write_stage):
    # type: (OpStage) -> Self
    """ create a copy of `self` where only `write_stage` is replaced.

    All other properties (`ty`, `sub_kinds`, `fixed_loc`,
    `tied_input_index` and `spread`) are passed through unchanged.
    """
    restaged_desc = GenericOperandDesc(
        ty=self.ty,
        sub_kinds=self.sub_kinds,
        fixed_loc=self.fixed_loc,
        tied_input_index=self.tied_input_index,
        spread=self.spread,
        write_stage=write_stage,
    )
    return restaged_desc
1040
def instantiate(self, maxvl):
    # type: (int) -> Iterable[OperandDesc]
    """ generate the concrete `OperandDesc`s for this generic descriptor.

    A non-spread operand produces exactly one `OperandDesc` with
    `spread_index=None`; a spread operand produces `maxvl` of them with
    `spread_index` counting up from 0. All produced descriptors share the
    same `LocSet`.

    This is a generator, so nothing (including the `ValueError` raised
    when the instantiated type doesn't match `fixed_loc`'s type) happens
    until the result is iterated.

    assumes all spread operands have ty.reg_len = 1
    """
    unspread_ty = self.ty_before_spread.instantiate(maxvl=maxvl)

    def candidate_locs():
        # type: () -> Iterable[Loc]
        fixed = self.fixed_loc
        if fixed is None:
            # not fixed: every Loc any of the sub_kinds can allocate
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(unspread_ty)
            return
        if unspread_ty != fixed.ty:
            raise ValueError(
                f"instantiation failed: type mismatch with fixed_loc: "
                f"instantiated type: {unspread_ty} "
                f"fixed_loc: {fixed}")
        yield fixed

    shared_loc_set = LocSet(candidate_locs())
    if self.spread:
        spread_indexes = range(maxvl)  # type: Iterable[int | None]
    else:
        spread_indexes = (None,)
    for spread_index in spread_indexes:
        yield OperandDesc(loc_set_before_spread=shared_loc_set,
                          tied_input_index=self.tied_input_index,
                          spread_index=spread_index,
                          write_stage=self.write_stage)
1068
1069
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class OperandDesc(Interned):
    """ descriptor for one operand of an instantiated Op.

    loc_set_before_spread: the `LocSet` for the whole (unspread) operand,
        must not be empty.
    tied_input_index: index of the input this operand is tied to, or
        `None` if it isn't tied.
    spread_index: position of this operand within its spread group, or
        `None` if it isn't spread. Can't be combined with
        `tied_input_index`.
    write_stage: the `OpStage` given for this operand.
    """
    loc_set_before_spread: LocSet
    tied_input_index: "int | None"
    spread_index: "int | None"
    write_stage: "OpStage"

    def __post_init__(self):
        if not len(self.loc_set_before_spread):
            raise ValueError("loc_set_before_spread must not be empty")
        is_tied = self.tied_input_index is not None
        is_spread = self.spread_index is not None
        if is_tied and is_spread:
            raise ValueError("operand can't be both spread and tied")

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ Ty of the whole operand, before any spread is applied """
        unspread_ty = self.loc_set_before_spread.ty
        assert unspread_ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return unspread_ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        spread_index = self.spread_index
        if spread_index is None:
            return 0
        return spread_index * self.ty.reg_len
1114
1115
# commonly used generic operand descriptors, all with default
# fixed_loc/tied_input_index/spread/write_stage.

# scalar (non-vector) I64 operands
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.BASE_GPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
)
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA3_SGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
)
# vector I64 operand
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
)
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA2_SGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
)
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA2_VGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
)
# the CA operand
OD_CA = GenericOperandDesc(
    sub_kinds=(LocSubKind.CA,),
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
)
# the VL/MAXVL operand
OD_VL = GenericOperandDesc(
    sub_kinds=(LocSubKind.VL_MAXVL,),
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
)
1137
1138
@dataclasses.dataclass(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(Interned):
    """ properties of a kind of Op, before instantiation with a `maxvl`.

    demo_asm: assembly text demonstrating the instruction.
    inputs: generic descriptors of the input operands.
    outputs: generic descriptors of the output operands.
    immediates: the allowed range for each immediate.
    is_copy: flag passed through from the constructor.
    is_load_immediate: flag passed through from the constructor.
    has_side_effects: flag passed through from the constructor.
    """
    demo_asm: str
    inputs: "tuple[GenericOperandDesc, ...]"
    outputs: "tuple[GenericOperandDesc, ...]"
    immediates: "tuple[range, ...]"
    is_copy: bool
    is_load_immediate: bool
    has_side_effects: bool

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """ validate the operand descriptors and fill in all fields.

        Raises ValueError if:
        * an input is tied or has a write_stage other than
          `OpStage.Early`.
        * an output is tied to an out-of-range input, or to an input that
          isn't equivalent to the output.
        * two outputs have conflicting `fixed_loc`s.
        """

        def set_field(name, value):
            # type: (str, Any) -> None
            # the dataclass is frozen, so plain assignment isn't available
            object.__setattr__(self, name, value)

        set_field("demo_asm", demo_asm)
        set_field("inputs", tuple(inputs))
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        set_field("outputs", tuple(outputs))
        # fixed locs of the outputs checked so far, with their output index
        seen_fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied_idx = out.tied_input_index
            if tied_idx is not None:
                if tied_idx >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_idx]
                # a tied output must be exactly what you get from tying
                # the input to itself, apart from the write stage
                expected_out = tied_inp.tied_to_input(tied_idx)
                expected_out = expected_out.with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            for prev_fixed_loc, prev_idx in seen_fixed_locs:
                if prev_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{prev_idx}]: {out.fixed_loc} conflicts "
                        f"with {prev_fixed_loc}")
            seen_fixed_locs.append((out.fixed_loc, idx))
        set_field("immediates", tuple(immediates))
        set_field("is_copy", is_copy)
        set_field("is_load_immediate", is_load_immediate)
        set_field("has_side_effects", has_side_effects)
1194
1195
@plain_data.plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """ properties of an `OpKind` instantiated for a specific `maxvl`.

    `inputs` and `outputs` are the concatenation of everything the kind's
    generic operand descriptors instantiate to. `copy_reg_len` is the
    total `reg_len` of the operands counted by `copy_inputs_len` (always
    equal to that of the ones counted by `copy_outputs_len`).
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        """ instantiate `kind`'s generic properties for `maxvl`.

        Raises ValueError if the copied inputs and copied outputs don't
        add up to the same number of registers.
        """
        self.kind = kind  # type: OpKind
        generic = self.generic
        self.inputs = tuple(
            desc
            for generic_inp in generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        copy_input_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        copy_output_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if copy_input_reg_len != copy_output_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{copy_input_reg_len} != {copy_output_reg_len}")
        self.copy_reg_len = copy_input_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the generic (not yet instantiated) properties of `self.kind` """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """ same as `self.generic.immediates` """
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """ same as `self.generic.demo_asm` """
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """ same as `self.generic.is_copy` """
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """ same as `self.generic.is_load_immediate` """
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """ same as `self.generic.has_side_effects` """
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """ the number of leading inputs that are the source of a copy.

        0 for non-copy ops, 1 if the first input isn't spread, otherwise
        the length of the leading run of inputs whose `spread_index`
        equals their position.
        """
        if not self.is_copy:
            return 0
        if self.inputs[0].spread_index is None:
            return 1
        count = 0
        while (count < len(self.inputs)
               and self.inputs[count].spread_index == count):
            count += 1
        return count

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """ the number of leading outputs that are the dest of a copy.

        0 for non-copy ops, 1 if the first output isn't spread, otherwise
        the length of the leading run of outputs whose `spread_index`
        equals their position.
        """
        if not self.is_copy:
            return 0
        if self.outputs[0].spread_index is None:
            return 1
        count = 0
        while (count < len(self.outputs)
               and self.outputs[count].spread_index == count):
            count += 1
        return count
1282
1283
# range of values that fit in a signed 16-bit immediate
IMM_S16 = range(-(1 << 15), 1 << 15)

# signature of a function simulating an Op
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-argument callable returning the simulation function
_SIM_FN2 = Callable[[], _SIM_FN]
# registry of simulation-function getters, keyed by the Op's properties
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]

# signature of a function generating assembly for an Op
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-argument callable returning the assembly-generation function
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry of assembly-generation-function getters, keyed by the Op's
# properties
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1292
1293
@unique
@final
class OpKind(Enum):
    """ all the kinds of Op.

    Each member's value is the `GenericOpProperties` describing that kind.
    Every member is followed by registering getters for its simulation and
    assembly-generation functions in `_SIM_FNS` and `_GEN_ASMS`; those are
    looked up by the `sim` and `gen_asm` properties.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the generic properties this kind was defined with """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ create the `OpProperties` of this kind for the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the function simulating Ops of this kind """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the function generating assembly for Ops of this kind """
        return _GEN_ASMS[self.properties]()

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate ClearCA: the CA output becomes False """
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for ClearCA """
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetCA: the CA output becomes True """
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SetCA """
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvAddE: add the first VL elements of RA and RB,
        propagating the carry from each element into the next one.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvAddE """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate AddZE: add the carry input to RA, producing the
        truncated sum and the new carry.
        """
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for AddZE.

        RT and RA are both scalar operands (`OD_BASE_SGPR`), so they are
        formatted using `state.sgpr` like in all the other scalar ops.
        """
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvSubFE: for the first VL elements compute
        `~RA[i] + RB[i] + carry`, propagating the carry from each element
        into the next one.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvSubFE """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvAndVS: bitwise-and each of the first VL elements of
        the vector RA with the scalar RB.
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvAndVS """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvMAddEDU: for the first VL elements compute
        `RA[i] * RB + carry`, where the low 64 bits are the result element
        and the remaining high bits are the carry into the next element.
        The initial carry is input 2 and the final carry is output 1.
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvMAddEDU """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SRADI: arithmetic shift right by an immediate.

        CA follows the Power ISA definition of `sradi`: it is set only if
        the input is negative and at least one 1-bit was shifted out.
        """
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned 64-bit register value as signed
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # compare using the signed result `v`, not the masked `RA`: with
        # `RA` every negative input would compare as not equal.
        # shifting back only reproduces `rs` if no 1-bits were shifted out.
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SRADI """
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetVLI: the VL output becomes the immediate """
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SetVLI """
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLI: the output becomes VL copies of the immediate
        (masked to the GPR width).
        """
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvLI """
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate LI: the output becomes the immediate (masked to the
        GPR width).
        """
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for LI """
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyToReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ emit the asm copying `src_loc` to `dest_loc`, shared by all the
        copy-like ops.

        Nothing is emitted when `src_loc == dest_loc`. Otherwise a load is
        emitted for stack -> GPR, a store for GPR -> stack and an `or` for
        GPR -> GPR; all of them get a `sv.` prefix when `is_vec` is set.

        Raises ValueError if either Loc has a kind other than GPR or
        StackI64, or for a stack -> stack copy.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # overlapping copy where the source starts below the dest
        # NOTE(review): `/mrr` presumably makes the copy run in reverse
        # element order so sources aren't overwritten before they're
        # read -- confirm against the SVP64 spec.
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for VecCopyToReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyFromReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for VecCopyFromReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyToReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for CopyToReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyFromReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for CopyFromReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Concat: the output is the tuple of the single values
        of all inputs except the last one (which is VL).
        """
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for Concat """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Spread: each element of input 0 becomes the value of
        the output with the same index.
        """
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for Spread """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLd: load VL consecutive GPR-sized values starting
        at address `RA + immediate`.
        """
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvLd """
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Ld: load one value from address `RA + immediate` """
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for Ld """
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvStd: store the first VL elements of RS to
        consecutive GPR-sized slots starting at address `RA + immediate`.
        """
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for SvStd """
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Std: store RS at address `RA + immediate` """
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for Std """
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate FuncArgR3: does nothing """
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ emit asm for FuncArgR3: emits nothing """
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1919
1920
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(Interned):
    """ common base class of `SSAVal` and `SSAUse`.

    op: the Op this value/use belongs to.
    operand_idx: index into `descriptor_array` selecting the operand;
        must be in range.
    """
    op: "Op"
    operand_idx: int

    def __post_init__(self):
        if not 0 <= self.operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the operand descriptors that `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the operand descriptor selected by `operand_idx` """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ the defining descriptor's `ty` """
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the defining descriptor's `ty_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `base_ty` of `ty_before_spread` """
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` of the first operand in self's spread group,
        which is just `operand_idx` when self isn't spread.
        """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ same kind of object as self, but for `unspread_start_idx` """
        return type(self)(op=self.op, operand_idx=self.unspread_start_idx)
1982
1983
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ an output of an `Op`: `operand_idx` indexes `op.properties.outputs`.
    """
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of the defining `Op` """
        properties = self.op.properties
        return properties.outputs

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        idx = self.operand_idx
        return f"<{op_name}.outputs[{idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ `loc_set_before_spread` of the defining descriptor """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the `SSAUse` of the same `Op` at the descriptor's
        `tied_input_index`, or `None` when that index is `None`.
        """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is None:
            return None
        return SSAUse(op=self.op, operand_idx=tied_idx)

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ `write_stage` of the defining descriptor """
        descriptor = self.defining_descriptor
        return descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This looks `self` up in the state returned by
        `PreRASimState.get_current_debugging_state()`, so it is intended
        for use with `PreRASimState.set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """ one `SSAValSubReg(self, i)` for each `i` in
        `range(self.ty.reg_len)`.
        """
        sub_regs = [SSAValSubReg(self, reg_idx)
                    for reg_idx in range(self.ty.reg_len)]
        return tuple(sub_regs)
2033
2034
@dataclasses.dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ an input slot of an `Op`: `operand_idx` indexes
    `op.properties.inputs` and `op.input_vals`.
    """
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of the using `Op` """
        properties = self.op.properties
        return properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ `loc_set_before_spread` of the defining descriptor """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        idx = self.operand_idx
        return f"<{op_name}.input_uses[{idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently stored in this input slot of the `Op` """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    # NOTE(review): SSAUse is a frozen dataclass, and the `__setattr__`
    # that `dataclasses` generates for frozen classes rejects attribute
    # assignment on instances -- confirm this setter is actually reachable
    # through `use.ssa_val = ...`.
    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
2063
2064
# element type held by an OpInputSeq (the subclasses below use SSAVal / int)
_T = TypeVar("_T")
# descriptor type each element is verified against (the subclasses below use
# OperandDesc / range)
_Desc = TypeVar("_Desc")
2067
2068
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of per-`Op` items, each checked against a
    matching descriptor.

    Subclasses supply the descriptors through `_get_descriptors` and the
    per-item check through `_verify_write_with_desc`. The length always
    equals the number of descriptors; single items can be replaced but the
    sequence can't grow, shrink or be slice-assigned.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """ verify and store `items`, requiring exactly one per descriptor.

        Raises ValueError if there are more or fewer items than
        descriptors; also propagates whatever `_verify_write` raises.
        """
        super().__init__()
        self.__op = op
        stored = []  # type: list[_T]
        self.__items = stored
        for position, candidate in enumerate(items):
            if not position < len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, candidate)
            stored.append(candidate)
        if len(stored) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        """ the `Op` passed to the constructor """
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per item """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ cached result of `_get_descriptors()` """
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at `idx`, given `desc` """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` may be written at `idx`.

        Returns the normalized (non-negative) index. Raises TypeError for
        slices and other non-int indexes and IndexError for out-of-range
        indexes.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range normalizes negative indexes and raises
        # IndexError when idx is out of range
        normalized = range(len(self.descriptors))[idx]
        descriptor = self.descriptors[normalized]
        self._verify_write_with_desc(normalized, item, descriptor)
        return normalized

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook that does nothing by default.

        NOTE(review): nothing in this class calls `_on_set` -- confirm
        whether `__init__`/`__setitem__` are supposed to invoke it.
        """

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for stored_item in self.__items:
            yield stored_item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        stored = self.__items
        return stored[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        normalized = self._verify_write(idx, item)
        self.__items[normalized] = item

    def __repr__(self):
        # type: () -> str
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.__items}, op=...)"
2156
2157
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s an `Op` takes as inputs, each checked against the
    matching entry of `op.properties.inputs`.
    """

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ build the input sequence for `op`.

        Raises ValueError if `op` already has its `input_vals` set, and
        whatever `OpInputSeq.__init__` raises for invalid `items`.
        """
        # `Op` stores this sequence in its `input_vals` slot, so that is
        # the attribute to test; `Op` has no `inputs` attribute, which made
        # the old `hasattr(op, "inputs")` guard always false.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ require `item` to be an `SSAVal` whose `ty` equals `desc.ty`.

        Raises TypeError for non-`SSAVal` items and ValueError on a type
        mismatch.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        """ forward the change to `SSAUses._on_op_input_set` """
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
2181
2182
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`; each must be an `int` contained
    in the matching `range` of `op.properties.immediates`.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """ build the immediates for `op`.

        Raises ValueError if `op` already has an `immediates` attribute.
        """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ raise TypeError for non-`int` items and ValueError for values
        outside of `desc`.
        """
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
2201
2202
@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation in a `Fn`.

    Holds the input `SSAVal`s (`input_vals`), one `SSAUse` per input
    (`input_uses`), the immediates, and one `SSAVal` per output
    (`outputs`). Equality and hashing are by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        """ create the Op and register it with `fn` under an unused name
        derived from `name`.

        `input_vals` and `immediates` are validated against
        `properties.inputs` / `properties.immediates` by `OpInputVals` and
        `OpImmediates`.
        """
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        inputs_len = len(self.properties.inputs)
        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
        self.immediates = OpImmediates(immediates, op=self)
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
        # the name is assigned last: `_add_op_with_unused_name` rejects ops
        # that already have a `name` attribute
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ shorthand for `self.properties.kind` """
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        # identity comparison: two Ops are equal only if they are the same
        # object
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        # identity hash, consistent with __eq__
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ render as `name:` followed by `(outputs) <= Kind(inputs, imms)`.

        Text is wrapped at the marked wrap points whenever a line would
        exceed `wrap_width` characters; every line after the first is
        prefixed with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        items = [f"{self.name}:\n"]
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        items.append(self.kind._name_)
        # the argument list is only parenthesized when there is at least one
        # input or immediate
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
        items[-1] += WRAP_POINT
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            # the closing paren goes after the last input only when no
            # immediates follow
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = []  # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            if i != 0:
                line_in = indent + line_in
            line_out = ""
            # greedily pack the parts between wrap points onto the current
            # line, starting a new indented line when it would get too long
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this Op on `state`.

        Checks that every input has a value of the right length, runs
        `self.kind.sim`, then checks that every output was assigned a value
        of the right length. Raises ValueError when any check fails. A
        `SimSkipOp` raised while reading a value skips that value's check;
        one raised by `self.kind.sim` is reported via `state.on_skip`.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            if len(val) != inp.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {inp} has wrong number of elements: "
                    f"expected {inp.ty.reg_len} found "
                    f"{len(val)}: {val!r}")
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out in state.ssa_vals:
                    # an already-assigned output is only tolerated for
                    # FuncArgR3 ops
                    if self.kind is OpKind.FuncArgR3:
                        continue
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            except SimSkipOp:
                continue
            if len(val) != out.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {out} has wrong number of elements: "
                    f"expected {out.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for this Op through `self.kind.gen_asm`.

        `state.loc` is first called for every input and output, accepting
        any `LocKind`; the results are discarded, so those calls act only
        as a check that each value has an allocated location.
        """
        all_loc_kinds = tuple(LocKind)
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=all_loc_kinds)
        for out in self.outputs:
            state.loc(out, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
2332
2333
@plain_data.plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ base class for simulator states: byte-addressed `memory` plus an
    abstract `SSAVal` -> value mapping interface.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value `load_byte` returns for addresses missing from `memory` """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called with an `Op` whose simulation was skipped; the base
        implementation always raises ValueError.
        """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        """ read one byte; `addr` is wrapped with `GPR_VALUE_MASK` """
        wrapped_addr = addr & GPR_VALUE_MASK
        try:
            stored = self.memory[wrapped_addr]
        except KeyError:
            return self._default_memory_value()
        return stored & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value`; `addr` is wrapped with
        `GPR_VALUE_MASK`.
        """
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian integer of `size_in_bytes` bytes.

        Raises ValueError unless `addr` is a multiple of `size_in_bytes`.
        With `signed` set, a value with its top bit set is returned as the
        corresponding negative number.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        loaded = 0
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            loaded |= self.load_byte(addr + byte_idx) << shift
        if signed:
            bit_width = size_in_bytes * BITS_IN_BYTE
            if loaded >> (bit_width - 1) != 0:
                loaded -= 1 << bit_width
        return loaded

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write `value` as `size_in_bytes` little-endian bytes.

        Raises ValueError unless `addr` is a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            self.store_byte(addr + byte_idx, (value >> shift) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory` used by `__repr__`.

        Addresses are listed in ascending order. An aligned run of
        `GPR_SIZE_IN_BYTES` consecutive addresses is shown as one
        `<0x...>` word; everything else is shown byte by byte.
        """
        if len(self.memory) == 0:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        addrs = sorted(self.memory.keys())
        entries = []  # type: list[str]
        pos = 0
        while pos < len(addrs):
            addr = addrs[pos]
            window = addrs[pos:pos + CHUNK_SIZE]
            is_full_word = (addr % CHUNK_SIZE == 0
                            and window == list(range(addr,
                                                     addr + CHUNK_SIZE)))
            if is_full_word:
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                entries.append(
                    f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                pos += CHUNK_SIZE
            else:
                entries.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                pos += 1
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        joined = ",\n".join(entries)
        return f"{{\n{joined}}}"

    def __repr__(self):
        # type: () -> str
        """ `ClassName(field=value, ...)` over the plain_data fields.

        A field uses the method named `_<field>__repr` when that exists
        and is callable; fields that aren't set render as `<not set>`.
        """
        rendered = []  # type: list[str]
        for name in plain_data.fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                rendered.append(f"{name}=<not set>")
                continue
            custom_repr = getattr(self, f"_{name}__repr", None)
            if callable(custom_repr):
                text = custom_repr()
            else:
                text = repr(value)
            rendered.append(f"{name}={text}")
        joined = ", ".join(rendered)
        return f"{self.__class__.__name__}({joined})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        ...
2434
2435
@plain_data.plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores values directly per `SSAVal` in the
    `ssa_vals` dict.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals` used by `__repr__`; each value is shown as
        a tuple of hex numbers, at most `CHUNK_SIZE` per line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        CHUNK_SIZE = 4
        entries = []  # type: list[str]
        for ssa_val, elements in self.ssa_vals.items():
            parts = []  # type: list[str]
            for idx, element in enumerate(elements):
                # the first element of every chunk starts a new line
                separator = " " if idx % CHUNK_SIZE != 0 else "\n "
                parts.append(separator + hex(element))
            if len(parts) <= CHUNK_SIZE:
                # values that fit in one chunk stay on the key's line
                parts[0] = parts[0].lstrip()
                if len(parts) == 1:
                    # gives the trailing comma of a one-element tuple
                    parts.append("")
            entries.append(f"{ssa_val!r}: ({','.join(parts)})")
        if len(entries) == 1 and "\n" not in entries[0]:
            return f"{{{entries[0]}}}"
        joined = ",\n".join(entries)
        return f"{{\n{joined},\n}}"

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` for an `SSAVal` without a value; the
        base implementation raises KeyError.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        try:
            found = self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)
        return found

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ store `value` converted to a tuple of ints; raises ValueError
        unless it has exactly `ssa_val.ty.reg_len` elements.
        """
        elements = tuple(int(element) for element in value)
        if len(elements) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = elements
2486
2487
class SimSkipOp(Exception):
    """ raised while simulating to signal that the value or `Op` being
    processed should be skipped; `Op.sim` catches it.
    """
2490
2491
@plain_data.plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state in which reading unknown memory or an unassigned
    `SSAVal` raises `SimSkipOp`, and skipped `Op`s are collected in
    `skipped_ops`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        """ record `op` in `skipped_ops` """
        skipped = self.skipped_ops
        skipped.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ an unassigned `SSAVal` raises `SimSkipOp` instead of KeyError """
        raise SimSkipOp()

    def _default_memory_value(self):
        # type: () -> int
        """ memory that was never written raises `SimSkipOp` """
        raise SimSkipOp()
2513
2514
@plain_data.plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ pre-register-allocation simulator state, with a class-wide stack of
    "current debugging states" for inspecting values from a debugger.
    """
    __slots__ = ()

    # stack of the states made current by `set_as_current_debugging_state`;
    # the last entry is the innermost active one
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        stack.append(self)
        try:
            yield
        finally:
            # pop outside of the assert: `assert` statements are removed
            # under `python -O`, which would otherwise leave self on the
            # stack forever
            popped = stack.pop()
            assert popped is self, "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises ValueError if no state is currently set.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        if len(stack) == 0:
            raise ValueError("no current debugging state")
        return stack[-1]
2548
2549
@plain_data.plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ post-register-allocation simulator state.

    Each `SSAVal` is mapped to a `Loc` by `ssa_val_to_loc_map`; the values
    themselves live in `loc_values`, keyed by `reg_len=1` `Loc`s.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ Raises ValueError if an `SSAVal`'s `ty` differs from its `Loc`'s
        `ty`, or if any key of `loc_values` doesn't have `reg_len == 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ repr of `loc_values` used by `__repr__`, sorted by kind name
        and then by start.
        """
        # an empty mapping used to render as the malformed "{\n,\n}";
        # return "{}" like the other field reprs do
        if len(self.loc_values) == 0:
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ read the value of `ssa_val`, one element per register of its
        `Loc`; registers missing from `loc_values` read as 0.

        Raises KeyError if `ssa_val` has no `Loc`.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ write `value` element by element to the registers of
        `ssa_val`'s `Loc`.

        Raises ValueError unless `value` has exactly `ssa_val.ty.reg_len`
        elements, and KeyError if `ssa_val` has no `Loc`.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2600
2601
@plain_data.plain_data(frozen=True)
class GenAsmState:
    """ state for assembly generation: the `Loc` allocated to each `SSAVal`
    plus the `output` the generated lines are written to.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """ `output` defaults to a new empty list.

        Raises ValueError if an `SSAVal`'s `ty` differs from its `Loc`'s.
        """
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve an `SSAVal`, a `Loc`, or a sequence of them to a single
        `Loc`, joining multiple entries with `Loc.try_concat`.

        Raises ValueError for an empty sequence, when `try_concat` returns
        `None`, or when the resulting kind isn't in `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        resolved = [
            self.allocated_locs[entry] if isinstance(entry, SSAVal) else entry
            for entry in ssa_val_or_locs
        ]  # type: list[Loc]
        if not resolved:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = resolved
        combined = first.try_concat(*rest)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = expected_kinds,
        if combined.kind in expected_kinds:
            return combined
        # a single expected kind is reported without the tuple around it
        if len(expected_kinds) == 1:
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ the start of the GPR `Loc` as a string, prefixed with `*` when
        `is_vec` is true.
        """
        gpr_loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        reg_text = str(gpr_loc.start)
        if is_vec:
            return "*" + reg_text
        return "" + reg_text

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ `gpr` with `is_vec=False` """
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ `gpr` with `is_vec=True` """
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ the operand string `<start>(1)` for a `LocKind.StackI64` Loc """
        stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{stack_loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ emit one line made of `line_segments` joined by spaces: it is
        appended to a list `output`, otherwise written with a trailing
        newline.
        """
        text = " ".join(line_segments)
        sink = self.output
        if isinstance(sink, list):
            sink.append(text)
            return
        sink.write(text + "\n")