add is_copy_related to interference graph edges
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from collections import defaultdict
2 from contextlib import contextmanager
3 import enum
4 from abc import ABCMeta, abstractmethod
5 from enum import Enum, unique
6 from functools import lru_cache, total_ordering
7 from io import StringIO
8 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
9 Mapping, Sequence, TypeVar, Union, overload)
10 from weakref import WeakValueDictionary as _WeakVDict
11
12 from cached_property import cached_property
13 from nmutil.plain_data import fields, plain_data
14
15 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
16 final)
17 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
18 OFSet, OSet)
19
# Width of one general-purpose register (GPR), in bytes and in bits.
BITS_IN_BYTE = 8
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# All-ones mask covering exactly GPR_SIZE_IN_BITS bits.
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
24
25
@final
class Fn:
    """ An ordered list of `Op`s (`self.ops`) together with the bookkeeping
    needed to hand out a unique name to every `Op`.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak-valued map: an entry disappears once its Op is garbage
        # collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix to try when a requested name is empty/taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ register `op` under a name that isn't currently in use and
        return that name.

        `name` is used as-is when it is non-empty and unused, otherwise
        increasing numeric suffixes are appended to it until an unused name
        is found.

        raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ return the `repr` of all ops, one per line.

        `wrap_width` and `indent` are forwarded to each op's `__repr__`.
        If `as_python_literal` is set, the text is instead returned as the
        source code of adjacent Python string literals (one per line of
        text), each prefixed by `python_indent`.
        """
        text = "\n".join(
            op.__repr__(wrap_width=wrap_width, indent=indent)
            for op in self.ops)
        if not as_python_literal:
            return text
        pieces = [python_indent + "\""]
        for ch in text:
            if ch == "\n":
                # end the current literal and start a new one on the next
                # source line
                pieces.append(f"\\n\"\n{python_indent}\"")
            elif ch in "\"\\":
                pieces.append("\\" + ch)
            elif ch.isascii() and ch.isprintable():
                pieces.append(ch)
            else:
                pieces.append(repr(ch).strip("\"'"))
        pieces.append("\"")
        literal = "".join(pieces)
        # text ending in a newline leaves a trailing empty literal; drop it
        # but keep the closing quote of the literal before it (removing that
        # quote too would produce an unterminated string literal).
        empty_piece = f"\n{python_indent}\"\""
        if literal.endswith("\"" + empty_piece):
            literal = literal[:-len(empty_piece)]
        return literal

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to `self.ops`; raises ValueError if `op.fn` isn't
        `self`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new `Op` of the given `kind` (instantiated with
        `maxvl`), append it to `self.ops` and return it.
        """
        new_op = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ call `op.sim(state)` for every op, in order. """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ call `op.gen_asm(state)` for every op, in order. """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops` with copy ops inserted around every original
        op (presumably run before register allocation, going by the name):

        * every I64 input is replaced by the output of a new
          `CopyToReg`/`VecCopyToReg` op inserted right before the op.
        * every I64 output gets a new `CopyFromReg`/`VecCopyFromReg` op
          inserted right after the op; later ops read that copy.
        * CA and VL_MAXVL values are never copied, but an input that is the
          output of a `SetVLI` op is replaced by the output of a new
          `SetVLI` op (same immediates) inserted right before the op.
        """
        orig_ops = list(self.ops)
        # original output -> the value later ops must read in its place
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # output of a SetVLI op -> that SetVLI op
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if maxvl != 1:
                        # vector copies need a VL value, so create one
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the op that uses their VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if maxvl != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
163
164
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the stage within one Op's execution at which a read/write happens;
    ordered so that `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        stage_value = int(value)
        if stage_value not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = stage_value
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return f"OpStage.{self._name_}"

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # only OpStages are comparable; total_ordering derives the rest
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
211
212
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """ a point in a function's execution: the index of an Op plus the
    `OpStage` within that Op. Ordered by `op_index`, then by `stage`.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering and
        successor/predecessor relations.
        """
        # two stages per op, so the stage occupies the lowest bit
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        op_index = int_value // 2
        stage = OpStage(int_value % 2)
        return ProgramPoint(op_index=op_index, stage=stage)

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` points after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` points before `self` """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            if self.op_index == other.op_index:
                return self.stage < other.stage
            return self.op_index < other.op_index
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
257
258
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ a half-open range of `ProgramPoint`s: `start` is included, `stop`
    is excluded. Behaves as a sequence of the points it contains.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `range` of `ProgramPoint.int_value`s covered by `self` """
        first = self.start.int_value
        past_last = self.stop.int_value
        return range(first, past_last)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; raises ValueError unless the
        range's step is 1.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        start = ProgramPoint.from_int_value(int_value_range.start)
        stop = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=start, stop=stop)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # let `range` handle negative indexes, bounds checks and slicing
        picked = self.int_value_range[__idx]
        if not isinstance(picked, int):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return (ProgramPoint.from_int_value(i) for i in self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the points' reprs without their surrounding angle brackets
        start, stop = (repr(point).lstrip("<").rstrip(">")
                       for point in (self.start, self.stop))
        return f"<range:{start}..{stop}>"
313
314
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(metaclass=InternedMeta):
    """ one register-sized piece of an `SSAVal`: the register at index
    `reg_idx` out of the `ssa_val.ty.reg_len` registers of `ssa_val`.
    """
    __slots__ = "ssa_val", "reg_idx"

    def __init__(self, ssa_val, reg_idx):
        # type: (SSAVal, int) -> None
        if not 0 <= reg_idx < ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")
        self.ssa_val = ssa_val
        self.reg_idx = reg_idx

    def __repr__(self):
        # type: () -> str
        return f"{self.ssa_val}[{self.reg_idx}]"
330
331
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ analysis results computed once from a `Fn`: op indexes, uses of every
    `SSAVal`, definition/use program points, live ranges, copy relations and
    known-constant values. Equality and hashing are those of `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range found so far
        stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs first: every used SSAVal must already have been defined
            # by an earlier op, otherwise the lookups below raise KeyError
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                point = self.__get_use_program_point(use)
                use_points[use] = point
                # stop is exclusive, so extend to just past the use
                stops[used_val] = max(stops[used_val], point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                stops[out] = def_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        # evaluate the cached properties now rather than on first use
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the range of program points within the defining op over which
        `ssa_val` is being written.
        """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        last = ProgramPoint(op_index=start.op_index, stage=OpStage.Late)
        return ProgramRange(start=start, stop=last.next())

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the program point at which `ssa_use` reads its value: always the
        early stage of the using op.
        """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        op_index = self.op_indexes[ssa_use.op]
        return ProgramPoint(op_index=op_index, stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        originals = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        for op in self.op_indexes.keys():
            props = op.properties
            if not props.is_copy:
                continue
            copy_reg_len = props.copy_reg_len
            # sources, already resolved through earlier copies so that
            # copies of copies map straight to the original
            sources = [
                originals.get(sub_reg, sub_reg)
                for inp in op.input_vals[:props.copy_inputs_len]
                for sub_reg in inp.ssa_val_sub_regs
            ]  # type: list[SSAValSubReg]
            assert len(sources) == copy_reg_len, "logic error"
            dests = [
                sub_reg
                for out in op.outputs[:props.copy_outputs_len]
                for sub_reg in out.ssa_val_sub_regs
            ]  # type: list[SSAValSubReg]
            assert len(dests) == copy_reg_len, "logic error"
            originals.update(zip(dests, sources))
        return FMap(originals)

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
        """ map from SSAVals to the full set of SSAVals that are related by
        being sources/destinations of copies, transitively looking through all
        copies.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        # every SSAVal starts in a set of its own; sets are merged whenever
        # a copy connects two of them
        groups = {ssa_val: OSet([ssa_val]) for ssa_val in self.uses.keys()}
        for copy, original in self.copies.items():
            merged = groups[copy.ssa_val]
            absorbed = groups[original.ssa_val]
            if merged is absorbed:
                continue
            merged |= absorbed
            for member in merged:
                groups[member] = merged
        # this way we construct each OFSet only once rather than
        # for each SSAVal
        distinct_groups = {id(group): group for group in groups.values()}
        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
        for group in distinct_groups.values():
            frozen_group = OFSet(group)
            for member in frozen_group:
                retval[member] = frozen_group
        return FMap(retval)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """ the `ssa_vals` recorded by simulating `fn` with a fresh
        `ConstPropagationState`.
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """ `const_ssa_vals` split up per register: one entry for every
        sub-register of every SSAVal in `const_ssa_vals`.
        """
        retval = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, reg_value in enumerate(const_val):
                retval[SSAValSubReg(ssa_val, reg_idx)] = reg_value
        return FMap(retval)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a_orig = self.copies.get(a, a)
        b_orig = self.copies.get(b, b)
        if a_orig == b_orig:
            return True
        # check if they have the same constant value; a missing entry means
        # that sub-register isn't a known constant
        try:
            a_const_val = self.const_ssa_val_sub_regs[a_orig]
            b_const_val = self.const_ssa_val_sub_regs[b_orig]
        except KeyError:
            return False
        if a_const_val == b_const_val:
            return True
        return False
513
514
@unique
@final
class BaseTy(Enum):
    """ the base kinds of value; a `Ty` is a `BaseTy` plus a register
    count.
    """
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """ True if values of this base type can only have `reg_len == 1` """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """ the largest `reg_len` allowed for this base type """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
544
545
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """ a concrete type: a `BaseTy` and the number of registers
    (`reg_len`) a value of the type occupies.
    """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # single-register types are shown without the `*reg_len` suffix
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
578
579
@unique
@final
class LocKind(Enum):
    """ the kinds of location a value can be stored in """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of values that locations of this kind hold """
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ how many register-sized locations of this kind exist """
        if self is LocKind.StackI64:
            return 512
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            # for these kinds the count equals the base type's maximum
            # register length
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
613
614
@final
@unique
class LocSubKind(Enum):
    """ finer-grained subdivision of `LocKind`, distinguishing the different
    sets of GPRs an operand can be restricted to.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` this sub-kind belongs to """
        # identity checks are used rather than `in`: pyright fails
        # typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self is LocSubKind.BASE_GPR \
                or self is LocSubKind.SV_EXTRA2_VGPR \
                or self is LocSubKind.SV_EXTRA2_SGPR \
                or self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ the set of all `Loc`s of type `ty` that fit this sub-kind and
        don't overlap any of `SPECIAL_GPRS`.

        raises ValueError if `ty`'s base type doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # there is exactly one such location
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        # candidate start indexes for each remaining sub-kind
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        usable = []  # type: list[Loc]
        for start in starts:
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                # doesn't fit at this start index
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            usable.append(loc)
        return LocSet(usable)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
684
685
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """ a type whose register count isn't fixed yet: a `BaseTy` plus
    whether the value is a vector (`is_vec`).
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        if is_vec and base_ty.only_scalar:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.base_ty = base_ty
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for the given `maxvl`: vectors get `maxvl`
        registers, scalars always get one.
        """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if `ty` has the same base type and, for scalars, a
        `reg_len` of 1.
        """
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
712
713
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """ a location: `reg_len` consecutive register-sized slots of kind
    `kind`, beginning at index `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like calling `Loc(...)`, but returns None instead of raising
        when the arguments are invalid.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` have the same kind and overlap """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ the `Ty` of a `Loc` with the given `kind` and `reg_len` """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ index just past the last slot of `self` (exclusive end) """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ the `Loc` covering `self` followed by all of `others`, or None
        if that isn't possible: some `other` is None, has a different kind,
        doesn't start exactly where the previous one stops, or the combined
        `Loc` would be invalid.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the combined length can be invalid even though every part is valid
        # (e.g. StackI64 has more slots than the largest allowed I64
        # reg_len), so use try_make to return None rather than raise.
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the `Loc` of type `subloc_ty` that starts `offset` slots into
        `self`.

        raises ValueError if the base types differ or the sub-Loc wouldn't
        lie entirely within `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
790
791
# single-register GPR locations that LocSubKind.allocatable_locs never hands
# out: any candidate Loc overlapping one of these is skipped.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=start, reg_len=1)
    for start in (0, 1, 2, 13)
)
798
799
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """ a set of `Loc`s that all have the same `Ty`, with the start indexes
    additionally stored as one bit set per `LocKind`.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # already computed by the LocSet we're copying
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only kinds that actually have locations are kept
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ for each `LocKind` present, the set of `Loc.start` values """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """ the common `Ty` of all contained `Loc`s, None if empty """
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ like `starts`, but with every bit set shifted left by `reg_len`
        (presumably giving the set of `Loc.stop` values -- this assumes bit
        `i` of `bits` stands for index `i`).
        """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the set of all valid `Loc`s that consist of a `Loc` from `self`
        immediately followed by a `Loc` from each of `others`, in order.

        returns an empty `LocSet` if no such `Loc` exists, including when
        any involved set is empty or has a different base type.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        # total length of the concatenation built so far
        reg_len = self.ty.reg_len
        # surviving start indexes of the concatenated Locs, per kind
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None or other.ty.base_ty != base_ty:
                return LocSet()
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has nothing of this kind to append, so no
                    # concatenation of this kind exists
                    del starts[kind]
                    continue
                # a start survives only if `other` has a Loc beginning
                # exactly `reg_len` slots later
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        # try_make filters out combinations that aren't valid Locs
        candidates = (
            Loc.try_make(kind=kind, start=start, reg_len=reg_len)
            for kind, kind_starts in starts.items()
            for start in kind_starts)
        return LocSet(loc for loc in candidates if loc is not None)

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with

        returns 0 when `other` is an empty `LocSet`.
        """
        if isinstance(other, LocSet):
            return max((self.max_conflicts_with(i) for i in other),
                       default=0)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
903
904
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor

    describes one operand before `maxvl` is known; `instantiate` turns it
    into concrete `OperandDesc`s.
    """
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ raises ValueError for inconsistent combinations: no sub-kinds,
        sub-kinds not matching `ty`, an operand that is more than one of
        tied/fixed/spread, a spread vector, a negative `tied_input_index`,
        or a `fixed_loc` that doesn't fit `ty` and the single sub-kind.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if not len(self.sub_kinds):
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # exactly one sub-kind at this point
            for sub_kind in self.sub_kinds:
                allowed = sub_kind.allocatable_locs(fixed_loc.ty)
                if fixed_loc not in allowed:
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the type of the whole operand: spread operands are scalars
        that together form a vector of the same base type.
        """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is tied to the input at `tied_input_index`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is fixed to `fixed_loc`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with only `write_stage` replaced """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ yield the concrete `OperandDesc`s for the given `maxvl`: `maxvl`
        of them (with `spread_index` 0, 1, ...) for a spread operand,
        otherwise a single one with `spread_index=None`.

        raises ValueError if `fixed_loc`'s type doesn't match the
        instantiated type.
        """
        # assumes all spread operands have ty.reg_len = 1
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            fixed_loc = self.fixed_loc
            if fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty_before_spread)
                return
            if ty_before_spread != fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty_before_spread} "
                    f"fixed_loc: {self.fixed_loc}")
            yield fixed_loc

        loc_set_before_spread = LocSet(locs_before_spread())
        if self.spread:
            spread_indexes = range(maxvl)
        else:
            spread_indexes = [None]
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
1012
1013
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor

    the concrete form produced by `GenericOperandDesc.instantiate`.
    """
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        if not len(loc_set_before_spread):
            raise ValueError("loc_set_before_spread must not be empty")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        if spread_index is not None and self.tied_input_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the `Ty` shared by all `Loc`s in `loc_set_before_spread` """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        spread_index = self.spread_index
        if spread_index is None:
            return 0
        return spread_index * self.ty.reg_len
1062
1063
def _single_sub_kind_od(base_ty, sub_kind, is_vec=False):
    # type: (BaseTy, LocSubKind, bool) -> GenericOperandDesc
    """Build a `GenericOperandDesc` allowing exactly one `LocSubKind`."""
    return GenericOperandDesc(
        ty=GenericTy(base_ty=base_ty, is_vec=is_vec),
        sub_kinds=[sub_kind])


# commonly used operand descriptors: scalar/vector 64-bit GPR operands for
# each GPR sub-kind, plus the CA and VL/MAXVL operands
OD_BASE_SGPR = _single_sub_kind_od(BaseTy.I64, LocSubKind.BASE_GPR)
OD_EXTRA3_SGPR = _single_sub_kind_od(BaseTy.I64, LocSubKind.SV_EXTRA3_SGPR)
OD_EXTRA3_VGPR = _single_sub_kind_od(
    BaseTy.I64, LocSubKind.SV_EXTRA3_VGPR, is_vec=True)
OD_EXTRA2_SGPR = _single_sub_kind_od(BaseTy.I64, LocSubKind.SV_EXTRA2_SGPR)
OD_EXTRA2_VGPR = _single_sub_kind_od(
    BaseTy.I64, LocSubKind.SV_EXTRA2_VGPR, is_vec=True)
OD_CA = _single_sub_kind_od(BaseTy.CA, LocSubKind.CA)
OD_VL = _single_sub_kind_od(BaseTy.VL_MAXVL, LocSubKind.VL_MAXVL)
1085
1086
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """Properties of a kind of Op that don't depend on `maxvl`.

    Stores the demo assembly string, the generic input/output operand
    descriptors, the allowed range of each immediate, and the `is_copy`,
    `is_load_immediate` and `has_side_effects` flags.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Validate and store the properties.

        Raises `ValueError` if:
        * an input has `tied_input_index` set or a `write_stage` other than
          `OpStage.Early`
        * an output's `tied_input_index` isn't a valid index into `inputs`
        * a tied output isn't equivalent to the input it is tied to
        * two outputs have conflicting `fixed_loc`s
        """
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # fixed_locs seen so far, paired with the index of their output
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                # reject negative indexes too: they would otherwise silently
                # select an input counted from the end of `inputs`
                if not 0 <= out.tied_input_index < len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                # the output must be exactly what tying the input produces,
                # other than the write stage
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1137
1138
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """An `OpKind`'s properties instantiated for a specific `maxvl`.

    `inputs` and `outputs` are the generic operand descriptors instantiated
    (and thereby spread) for `maxvl`; everything else is forwarded from the
    kind's `GenericOpProperties`.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        """Raises `ValueError` if the copied inputs and copied outputs don't
        cover the same total number of registers.
        """
        self.kind = kind  # type: OpKind
        self.inputs = tuple(
            desc
            for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        # total register count of the operands that take part in the copy
        # (both totals are 0 for non-copy ops)
        src_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        dest_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if src_reg_len != dest_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{src_reg_len} != {dest_reg_len}")
        self.copy_reg_len = src_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the `maxvl`-independent properties of `self.kind` """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """ forwarded from `self.generic` """
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """ forwarded from `self.generic` """
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """ number of leading inputs that are part of the copy

        0 for non-copy ops, 1 when the first input isn't spread, otherwise
        the length of the leading run of inputs whose `spread_index` equals
        their position.
        """
        if not self.is_copy:
            return 0
        if self.inputs[0].spread_index is None:
            return 1
        for position, inp in enumerate(self.inputs):
            if inp.spread_index != position:
                return position
        return len(self.inputs)

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """ number of leading outputs that are part of the copy

        0 for non-copy ops, 1 when the first output isn't spread, otherwise
        the length of the leading run of outputs whose `spread_index` equals
        their position.
        """
        if not self.is_copy:
            return 0
        if self.outputs[0].spread_index is None:
            return 1
        for position, out in enumerate(self.outputs):
            if out.spread_index != position:
                return position
        return len(self.outputs)
1225
1226
# allowed values of a signed 16-bit immediate
IMM_S16 = range(-(2 ** 15), 2 ** 15)

# signature of a function simulating one Op
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-argument getter returning a `_SIM_FN`
_SIM_FN2 = Callable[[], _SIM_FN]
# registry of simulation-function getters, filled in by `OpKind`'s body
_SIM_FNS = dict()  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# signature of a function generating assembly for one Op
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-argument getter returning a `_GEN_ASM_FN`
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry of assembly-generator getters, filled in by `OpKind`'s body
_GEN_ASMS = dict()  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1235
1236
@unique
@final
class OpKind(Enum):
    """Every kind of Op.

    Each member's value is the `GenericOpProperties` for that kind. Right
    after each member is defined, getters for its simulation function and
    assembly generator are registered in `_SIM_FNS` / `_GEN_ASMS`, keyed by
    those properties; the `sim` and `gen_asm` properties look them up.

    Note: this is an `Enum`, so any plain name assigned in the class body
    becomes a member -- helpers must stay `staticmethod`s.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the `maxvl`-independent properties of this kind of Op """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ get this kind's properties instantiated for `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the function that simulates Ops of this kind """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the function that generates assembly for Ops of this kind """
        return _GEN_ASMS[self.properties]()

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate ClearCA: the CA output becomes False """
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetCA: the CA output becomes True """
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvAddE: for each of the VL elements compute
        `RA[i] + RB[i] + carry`, chaining the carry from element to element
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate AddZE: compute `RA + carry`, producing the masked sum
        and the carry out
        """
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # AddZE's GPR operands are OD_BASE_SGPR (scalar), so format them
        # with `sgpr` like every other scalar op in this class
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvSubFE: for each of the VL elements compute
        `~RA[i] + RB[i] + carry`, chaining the carry from element to element
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvAndVS: AND each of the VL elements of the vector RA
        with the scalar RB
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvMAddEDU: for each of the VL elements compute
        `RA[i] * RB + carry`; the low GPR-sized part is the element's
        result and the remaining high part is the next element's carry
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    # NOTE(review): the inputs use EXTRA2 sub-kinds but the vector output
    # uses EXTRA3 -- confirm that RT isn't also limited to EXTRA2 here
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SRADI: arithmetic shift right of RS by the immediate

        CA is set only when RS is negative and at least one 1-bit is
        shifted out.
        """
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as signed
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # `(v << imm) != rs` is true exactly when 1-bits were shifted out;
        # compare using the signed `v`, not the masked `RA`
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetVLI: the VL output becomes the immediate """
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLI: all VL elements become the masked immediate """
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate LI: the output becomes the masked immediate """
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyToReg: the output is the first input's value """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ write the assembly for copying `src_loc` to `dest_loc`

        * identical locations: nothing is written
        * stack to GPR: `ld` (`sv.ld` if `is_vec`)
        * GPR to stack: `std` (`sv.std` if `is_vec`)
        * GPR to GPR: `or` (`sv.or` if `is_vec`), with `/mrr` appended when
          the locations conflict and `src_loc` starts before `dest_loc`

        Raises `ValueError` if either kind is something other than GPR or
        StackI64, or if both are StackI64.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # NOTE(review): presumably `/mrr` makes an overlapping copy run in
        # reverse element order so the source isn't overwritten before it
        # is read -- confirm against the SVP64 spec
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyFromReg: the output is the first input's value
        """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyToReg: the output is the input's value """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyFromReg: the output is the input's value """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Concat: the output is made of the first element of
        every input except the last input (VL)
        """
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Spread: each element of the first input goes to the
        output with the same index
        """
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLd: load VL consecutive GPR-sized values starting at
        address `RA + immediate`
        """
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Ld: load one value from address `RA + immediate` """
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvStd: store the VL elements of RS to consecutive
        GPR-sized slots starting at address `RA + immediate`
        """
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Std: store RS to address `RA + immediate` """
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1862
1863
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """Common base of `SSAVal` and `SSAUse`: one operand slot of an `Op`.

    Subclasses provide `descriptor_array`, the tuple of `OperandDesc` that
    `operand_idx` indexes into.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        """Raises `ValueError` if `operand_idx` isn't a valid non-negative
        index into `descriptor_array`.
        """
        super().__init__()
        # `op` must be set first: `descriptor_array` reads it
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the descriptor at `operand_idx` in `descriptor_array` """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ forwarded from `defining_descriptor` """
        return self.defining_descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ forwarded from `defining_descriptor` """
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `base_ty` of `ty_before_spread` """
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` minus this operand's `spread_index`, or just
        `operand_idx` when the operand isn't spread
        """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ the same kind of object for operand `unspread_start_idx` """
        return type(self)(op=self.op, operand_idx=self.unspread_start_idx)
1927
1928
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """An output of an `Op`, described by `op.properties.outputs`."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ forwarded from `defining_descriptor.loc_set_before_spread` """
        return self.defining_descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of `self.op` """
        return self.op.properties.outputs

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the input use this output is tied to, or `None` if not tied """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ forwarded from `defining_descriptor` """
        return self.defining_descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """ one `SSAValSubReg` for each of this value's `ty.reg_len`
        registers, in order
        """
        sub_regs = []  # type: list[SSAValSubReg]
        for reg_idx in range(self.ty.reg_len):
            sub_regs.append(SSAValSubReg(self, reg_idx))
        return tuple(sub_regs)
1978
1979
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """An input slot of an `Op`, described by `op.properties.inputs`."""
    __slots__ = ()

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ forwarded from `defining_descriptor.loc_set_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of `self.op` """
        properties = self.op.properties
        return properties.inputs

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently in this input slot; assigning to this
        property stores the new value into `self.op.input_vals`
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
2008
2009
# element type stored in an `OpInputSeq`
_T = TypeVar("_T")
# type of the per-element descriptors of an `OpInputSeq`
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of items belonging to an `Op`, with exactly one
    item per descriptor.

    Subclasses provide the descriptors (`_get_descriptors`) and the check
    that an item is acceptable for a descriptor (`_verify_write_with_desc`).
    Items can be replaced one at a time by integer index, the length never
    changes after construction.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """ verify and store `items`.

        Raises `ValueError` if the number of items doesn't match the number
        of descriptors, plus whatever `_verify_write` raises for a bad item.
        """
        super().__init__()
        self.__op = op
        self.__items = []  # type: list[_T]
        accepted = self.__items
        for position, element in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, element)
            accepted.append(element)
        if len(accepted) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        """ the `op` that was passed to `__init__`. """
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per item. """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ cached result of `_get_descriptors()`. """
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` must not be stored at `idx`, where `desc` is the
        descriptor for that index.
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` may be written at `idx`.

        Returns the normalized (non-negative) index. Raises `TypeError` for
        non-`int` indexes and `IndexError` for out-of-range ones.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a `range` normalizes negative indexes and raises
        # IndexError when the index is out of range
        normalized_idx = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            normalized_idx, item, self.descriptors[normalized_idx])
        return normalized_idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook that subclasses may override, does nothing by default. """
        pass

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for element in self.__items:
            yield element

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        stored = self.__items
        return stored[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        normalized_idx = self._verify_write(idx, item)
        self.__items[normalized_idx] = item

    def __repr__(self):
        # type: () -> str
        class_name = self.__class__.__name__
        return f"{class_name}({self.__items}, op=...)"
2101
2102
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s used as the inputs of an `Op`, one per entry of
    `op.properties.inputs`.
    """

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ verify and store the input values for `op`.

        Raises `ValueError` if `op` already has its input values set, or if
        the number of items doesn't match the number of input descriptors.
        """
        # `Op` stores this sequence in its `input_vals` slot (there is no
        # `inputs` attribute), so that is the attribute that has to be
        # checked to detect an `Op` whose inputs were already created.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ require `item` to be an `SSAVal` whose type equals the type of
        the corresponding input descriptor.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
2126
2127
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`, one `int` per entry of
    `op.properties.immediates`; each descriptor is the `range` of values
    allowed for that immediate.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """ verify and store the immediates for `op`.

        Raises `ValueError` if `op` already has an `immediates` attribute.
        """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        op_properties = self.op.properties
        return op_properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ require `item` to be an `int` contained in the `range` `desc`. """
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
2146
2147
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation belonging to a `Fn`.

    An `Op` has input values, immediates and outputs as described by its
    `OpProperties`. `Op`s are compared and hashed by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        """ create the `Op` and register it with `fn` under a name derived
        from `name`; the name actually assigned is stored in `self.name`.
        """
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        num_inputs = len(self.properties.inputs)
        self.input_uses = tuple(
            SSAUse(self, use_idx) for use_idx in range(num_inputs))
        self.immediates = OpImmediates(immediates, op=self)
        num_outputs = len(self.properties.outputs)
        self.outputs = tuple(
            SSAVal(self, out_idx) for out_idx in range(num_outputs))
        # the name is assigned last: `Fn._add_op_with_unused_name` requires
        # `self.fn` to be set and `self.name` to not be set yet
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ shorthand for `self.properties.kind`. """
        op_properties = self.properties
        return op_properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ format as `name:` followed by
        `(outputs) <= Kind(inputs, immediates)`.

        The text after the name line is wrapped at the marked wrap points
        whenever a line would otherwise be longer than `wrap_width`; all
        lines except the first are prefixed with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        separator = ", " + WRAP_POINT
        close_args = ") " + WRAP_POINT
        last_output = len(self.outputs) - 1
        last_input = len(self.input_vals) - 1
        last_immediate = len(self.immediates) - 1
        has_immediates = len(self.immediates) != 0

        pieces = [f"{self.name}:\n"]
        for out_idx, output in enumerate(self.outputs):
            piece = f"<...outputs[{out_idx}]: {output.ty}>"
            if out_idx == 0:
                piece = "(" + WRAP_POINT + piece
            if out_idx == last_output:
                piece += WRAP_POINT + ") <= "
            else:
                piece += separator
            pieces.append(piece)

        kind_piece = self.kind._name_
        if len(self.input_vals) + len(self.immediates) != 0:
            kind_piece += "("
        pieces.append(kind_piece + WRAP_POINT)

        for in_idx, input_val in enumerate(self.input_vals):
            # the argument list is only closed here when no immediates follow
            closes_args = in_idx == last_input and not has_immediates
            suffix = close_args if closes_args else separator
            pieces.append(repr(input_val) + suffix)
        for imm_idx, immediate in enumerate(self.immediates):
            suffix = close_args if imm_idx == last_immediate else separator
            pieces.append(hex(immediate) + suffix)

        wrapped = []  # type: list[str]
        unwrapped_lines = "".join(pieces).splitlines()
        for line_idx, unwrapped in enumerate(unwrapped_lines):
            if line_idx != 0:
                unwrapped = indent + unwrapped
            current = ""
            for segment in unwrapped.split(WRAP_POINT):
                if current == "":
                    current = segment
                    continue
                candidate = current + segment
                if len(candidate.rstrip()) <= wrap_width:
                    current = candidate
                else:
                    # doesn't fit: finish the current line and start a new,
                    # indented one with this segment
                    wrapped.append(current.rstrip())
                    current = indent + segment
            wrapped.append(current.rstrip())
        return "\n".join(wrapped)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this `Op`, reading inputs from and writing outputs to
        `state`.

        Raises `ValueError` when an input isn't assigned or has the wrong
        number of elements, when (for `PreRASimState`) an output is already
        assigned and this isn't a `FuncArgR3`, or when an output is missing
        or has the wrong number of elements afterwards. If simulating the
        kind raises `SimSkipOp`, `state.on_skip(self)` is called instead.
        """
        def check_element_count(ssa_val, elements):
            # type: (SSAVal, tuple[int, ...]) -> None
            if len(elements) == ssa_val.ty.reg_len:
                return
            raise ValueError(
                f"value of SSAVal {ssa_val} has wrong number of elements: "
                f"expected {ssa_val.ty.reg_len} found "
                f"{len(elements)}: {elements!r}")

        for input_val in self.input_vals:
            try:
                elements = state[input_val]
            except KeyError:
                raise ValueError(f"SSAVal {input_val} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            check_element_count(input_val, elements)

        if isinstance(state, PreRASimState):
            may_be_preassigned = self.kind is OpKind.FuncArgR3
            for output in self.outputs:
                if output not in state.ssa_vals:
                    continue
                if may_be_preassigned:
                    continue
                raise ValueError(f"SSAVal {output} already assigned before "
                                 f"running {self}")

        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)

        for output in self.outputs:
            try:
                elements = state[output]
            except KeyError:
                raise ValueError(
                    f"running {self} failed to assign to {output}")
            except SimSkipOp:
                continue
            check_element_count(output, elements)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for this `Op` by delegating to its kind.

        Before delegating, `state.loc` is called for every input and output
        accepting every `LocKind`; the results are discarded, so those calls
        only serve to run the checks `state.loc` performs.
        """
        any_loc_kind = tuple(LocKind)
        operands = list(self.input_vals) + list(self.outputs)
        for ssa_val in operands:
            state.loc(ssa_val, expected_kinds=any_loc_kind)
        self.kind.gen_asm(self, state)
2277
2278
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ base class of the simulator states.

    Holds `memory`, a `dict` mapping addresses to byte values, and declares
    the `state[ssa_val]` read/write interface that subclasses implement.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value returned by `load_byte` for addresses not in `memory`. """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when simulating `op` raised `SimSkipOp`; not
        supported unless overridden.
        """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        """ read the byte at `addr` (wrapped to `GPR_VALUE_MASK`). """
        wrapped_addr = addr & GPR_VALUE_MASK
        try:
            stored = self.memory[wrapped_addr]
        except KeyError:
            return self._default_memory_value()
        return stored & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value` at `addr` (wrapped to
        `GPR_VALUE_MASK`).
        """
        wrapped_addr = addr & GPR_VALUE_MASK
        self.memory[wrapped_addr] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian integer of `size_in_bytes` bytes at `addr`.

        Raises `ValueError` if `addr` isn't a multiple of `size_in_bytes`.
        When `signed` is set, the result is interpreted as two's complement.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        size_in_bits = size_in_bytes * BITS_IN_BYTE
        loaded = 0
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            loaded |= self.load_byte(addr + byte_idx) << shift
        if signed:
            sign = loaded >> (size_in_bits - 1)
            if sign != 0:
                loaded -= 1 << size_in_bits
        return loaded

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value` at `addr` in
        little-endian order.

        Raises `ValueError` if `addr` isn't a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            self.store_byte(addr + byte_idx, (value >> shift) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ format `memory` for `__repr__`, in ascending address order.

        A run of `GPR_SIZE_IN_BYTES` consecutive addresses starting at an
        aligned address is shown as one `<0x...>` word, everything else is
        shown as individual bytes.
        """
        if not self.memory:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        addrs = sorted(self.memory.keys())
        entries = []  # type: list[str]
        pos = 0
        while pos < len(addrs):
            addr = addrs[pos]
            chunk_addrs = addrs[pos:pos + CHUNK_SIZE]
            is_full_aligned_chunk = (
                addr % CHUNK_SIZE == 0
                and chunk_addrs == list(range(addr, addr + CHUNK_SIZE)))
            if is_full_aligned_chunk:
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                entries.append(
                    f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                pos += CHUNK_SIZE
            else:
                entries.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                pos += 1
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        entries_str = ",\n".join(entries)
        return f"{{\n{entries_str}}}"

    def __repr__(self):
        # type: () -> str
        """ format as `ClassName(field=value, ...)`.

        A field `x` is formatted using the method `_x__repr` if the instance
        has a callable attribute of that name, else using `repr`. Fields that
        can't be read are shown as `<not set>`.
        """
        def format_field(name):
            # type: (str) -> str
            try:
                value = getattr(self, name)
            except AttributeError:
                return f"{name}=<not set>"
            custom_repr = getattr(self, f"_{name}__repr", None)
            if callable(custom_repr):
                return f"{name}={custom_repr()}"
            return f"{name}={value!r}"

        formatted = [format_field(name) for name in fields(self)]
        formatted_str = ", ".join(formatted)
        return f"{self.__class__.__name__}({formatted_str})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        ...
2379
2380
@plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores the value of each `SSAVal` directly in
    the `ssa_vals` dict, as a tuple with one `int` per register.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ format `ssa_vals` for `__repr__`.

        The elements of each value are shown in hex, `CHUNK_SIZE` per line;
        values with at most `CHUNK_SIZE` elements stay on a single line.
        """
        if not self.ssa_vals:
            return "{}"
        CHUNK_SIZE = 4
        entries = []  # type: list[str]
        for ssa_val, elements in self.ssa_vals.items():
            element_strs = [
                (" " if el_idx % CHUNK_SIZE != 0 else "\n ") + hex(element)
                for el_idx, element in enumerate(elements)]
            if len(element_strs) <= CHUNK_SIZE:
                # everything fits on one line, so drop the leading newline
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # gives the trailing comma of a single-element tuple
                element_strs.append("")
            elements_str = ",".join(element_strs)
            entries.append(f"{ssa_val!r}: ({elements_str})")
        if len(entries) == 1 and "\n" not in entries[0]:
            return f"{{{entries[0]}}}"
        entries_str = ",\n".join(entries)
        return f"{{\n{entries_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ get the value of `ssa_val`, deferring to
        `_handle_undefined_ssa_val` when it has no value.
        """
        try:
            found = self.ssa_vals[ssa_val]
        except KeyError:
            found = self._handle_undefined_ssa_val(ssa_val)
        return found

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` for an `SSAVal` that has no value; raises
        `KeyError` unless overridden.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ set the value of `ssa_val`; `value` is converted to a tuple of
        `int` and must have exactly `ssa_val.ty.reg_len` elements.
        """
        elements = tuple(int(element) for element in value)
        if len(elements) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = elements
2431
2432
class SimSkipOp(Exception):
    """ raised during simulation to signal that the current `Op` should be
    skipped instead of simulated; `Op.sim` catches it and calls
    `state.on_skip`.
    """
2435
2436
@plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state where reading an unknown `SSAVal` or an unset memory
    byte raises `SimSkipOp`, and skipped `Op`s are collected in
    `skipped_ops`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        """ record `op` as skipped. """
        skipped = self.skipped_ops
        skipped.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ an `SSAVal` without a value makes the reading `Op` get skipped. """
        raise SimSkipOp()

    def _default_memory_value(self):
        # type: () -> int
        """ a memory byte without a value makes the reading `Op` get
        skipped.
        """
        raise SimSkipOp()
2458
2459
@plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ `PreRABaseSimState` that can additionally be registered as the
    "current debugging state" for inspection from pdb or similar.
    """
    __slots__ = ()

    # stack of the states registered by `set_as_current_debugging_state`,
    # the innermost (most recently registered) state is last
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        # push before entering the `try`, so the `finally` only ever pops
        # what was actually pushed
        stack.append(self)
        try:
            yield
        finally:
            # pop outside of the `assert`: assert statements are removed
            # when running with `python -O`, and the pop must still happen
            popped = stack.pop()
            assert popped is self, \
                "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises `ValueError` if no state is currently registered.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        if len(stack) == 0:
            raise ValueError("no current debugging state")
        return stack[-1]
2493
2494
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state where each `SSAVal` lives in the `Loc` given by
    `ssa_val_to_loc_map`, and values are stored per single-register `Loc`
    in `loc_values`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ Raises `ValueError` if an `SSAVal` is mapped to a `Loc` of a
        different type, or if `loc_values` has a key with `reg_len != 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ format `loc_values` for `__repr__`, sorted by the name of the
        `Loc`'s kind and then by its start.
        """
        if len(self.loc_values) == 0:
            # without this an empty dict would be shown as "{\n,\n}"
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ get the value of `ssa_val` by reading each single-register
        sub-`Loc` of its `Loc`; sub-`Loc`s without a value read as `0`.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ set the value of `ssa_val` by writing each element to the
        corresponding single-register sub-`Loc` of its `Loc`.

        Raises `ValueError` unless `value` has exactly `ssa_val.ty.reg_len`
        elements.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2545
2546
@plain_data(frozen=True)
class GenAsmState:
    """ state used while generating assembly: the `Loc` allocated to each
    `SSAVal`, and the `output` that generated lines are written to (either
    a `list` of lines or an object with a `write` method).
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """ Raises `ValueError` if an `SSAVal` is allocated to a `Loc` of a
        different type. `output` defaults to a new empty `list`.
        """
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ get the single `Loc` covering `ssa_val_or_locs`.

        Every `SSAVal` is replaced by its allocated `Loc`, then all `Loc`s
        are combined using `Loc.try_concat`. Raises `ValueError` if the
        sequence is empty, if `try_concat` returns `None`, or if the kind of
        the resulting `Loc` isn't one of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[element] if isinstance(element, SSAVal)
            else element
            for element in ssa_val_or_locs]  # type: list[Loc]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first_loc, rest = locs[0], locs[1:]
        combined = first_loc.try_concat(*rest)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = expected_kinds,
        if combined.kind in expected_kinds:
            return combined
        if len(expected_kinds) == 1:
            # show a single expected kind on its own rather than as a tuple
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ get the text for a GPR operand: the start of the `Loc`, prefixed
        with `*` when `is_vec` is set.
        """
        gpr_loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        prefix = "*" if is_vec else ""
        return prefix + str(gpr_loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=False`. """
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=True`. """
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ get the text for a `StackI64` operand: the start of the `Loc`
        followed by `(1)`.
        """
        stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{stack_loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ join `line_segments` with spaces and write them to `output` as
        one line.
        """
        line = " ".join(line_segments)
        output = self.output
        if isinstance(output, list):
            output.append(line)
            return
        output.write(line + "\n")