699102b655fe86d19dd67283e21620861c012985
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 from abc import ABCMeta, abstractmethod
4 from enum import Enum, unique
5 from functools import lru_cache, total_ordering
6 from io import StringIO
7 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
8 Mapping, Sequence, TypeVar, Union, overload)
9 from weakref import WeakValueDictionary as _WeakVDict
10
11 from cached_property import cached_property
12 from nmutil.plain_data import fields, plain_data
13
14 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
15 final)
16 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
17 OFSet, OSet)
18
# size of a general-purpose register
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones value covering every bit of a GPR
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
23
24
@final
class Fn:
    """A function being compiled: an ordered list of `Op`s.

    Also hands out names to the `Op`s that belong to it, so that no two
    live `Op`s of the same `Fn` share a name.
    """

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]
        # weak values: an entry disappears once its Op is garbage-collected,
        # so only names of Ops that are still alive count as taken.
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # one counter shared by all names, so generated suffixes never repeat
        # within this Fn.
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Reserve and return a name for `op` not used by any live Op of self.

        `name` is used unchanged if it's non-empty and not taken; otherwise a
        numeric suffix is appended to it until an unused name is found.
        This doesn't set `op.name` itself.

        Raises ValueError if `op` belongs to a different Fn or already has a
        `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name != "" and name not in self.__op_names:
                self.__op_names[name] = op
                return name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """Return the text form of all ops, one `Op.__repr__` per line.

        If `as_python_literal` is set, the text is instead rendered as a
        sequence of adjacent Python string literals, one per line of text,
        each line prefixed by `python_indent`, suitable for pasting into
        Python source (e.g. as an expected value in a test).
        """
        lines = [op.__repr__(wrap_width=wrap_width, indent=indent)
                 for op in self.ops]
        retval = "\n".join(lines)
        if not as_python_literal:
            return retval
        parts = [python_indent + "\""]
        for ch in retval:
            if ch == "\n":
                # end the current literal after the escaped newline and start
                # a new literal on the next source line
                parts.append(f"\\n\"\n{python_indent}\"")
            elif ch in "\"\\":
                parts.append("\\" + ch)
            elif ch.isascii() and ch.isprintable():
                parts.append(ch)
            else:
                # use Python's escape sequence, minus repr's quotes
                parts.append(repr(ch).strip("\"'"))
        parts.append("\"")
        retval = "".join(parts)
        # text ending in a newline leaves an empty `""` literal on its own
        # line; drop only that line. The quote that closes the previous
        # literal must stay, otherwise the result is an unterminated literal.
        empty_end = f"\n{python_indent}\"\""
        if retval.endswith(empty_end):
            retval = retval[:-len(empty_end)]
        return retval

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op`, which must belong to self, to `self.ops`."""
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an Op of `kind` instantiated for `maxvl`, append it to
        `self.ops` and return it.
        """
        retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(retval)
        return retval

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Simulate every op in order on `state`."""
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Generate assembly for every op in order into `state`."""
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rewrite `self.ops` so I64 values flow through explicit copy ops.

        Before each original op, every I64 input is copied by a new
        CopyToReg op (scalar) or SetVLI + VecCopyToReg ops (vector), and the
        op is changed to read the copy. After each original op, every I64
        output is copied by a new CopyFromReg op (scalar) or SetVLI +
        VecCopyFromReg ops (vector), and later users read that copy instead.
        CA and VL_MAXVL values aren't copied. Inputs that were produced by
        an original SetVLI op are instead rematerialized by a new SetVLI op
        placed right before the op that reads them.
        """
        orig_ops = list(self.ops)
        # original output -> value that later users should read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # outputs of original SetVLI ops -> the SetVLI op that produced them
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the ops VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
162
163
@final
@unique
@total_ordering
class OpStage(Enum):
    """The two stages in which an Op executes; used to order operand reads
    and writes relative to each other. Ordered: `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        stage_value = int(value)
        if stage_value != 0 and stage_value != 1:
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = stage_value
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value
207
208
# import-time sanity check that OpStage's ordering matches execution order
assert OpStage.Late > OpStage.Early, "early must be less than late"
210
211
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """A point in a Fn's execution: one stage of the Op at `op_index`.

    Points are ordered first by `op_index`, then by `stage`.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ `2 * op_index + stage.value`: an integer encoding of `self` that
        preserves ordering, so adjacent integers are adjacent ProgramPoints.
        """
        stages_per_op = 2
        return stages_per_op * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """inverse of `int_value`"""
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """the ProgramPoint `steps` points after self"""
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """the ProgramPoint `steps` points before self"""
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return (self.op_index, self.stage) < (other.op_index, other.stage)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        stage_name = self.stage._name_
        return f"<ops[{self.op_index}]:{stage_name}>"
256
257
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """The half-open range `[start, stop)` of ProgramPoints, usable as a
    Sequence of ProgramPoints.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """self as a range of `ProgramPoint.int_value`s"""
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """inverse of `int_value_range`; only step-1 ranges are allowed"""
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        start = ProgramPoint.from_int_value(int_value_range.start)
        stop = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=start, stop=stop)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        selected = self.int_value_range[__idx]
        if isinstance(selected, range):
            return ProgramRange.from_int_value_range(selected)
        return ProgramPoint.from_int_value(selected)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return (ProgramPoint.from_int_value(v) for v in self.int_value_range)

    def __repr__(self):
        # type: () -> str
        def without_brackets(point):
            # type: (ProgramPoint) -> str
            return repr(point).lstrip("<").rstrip(">")
        return (f"<range:{without_brackets(self.start)}.."
                f"{without_brackets(self.stop)}>")
312
313
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(metaclass=InternedMeta):
    """One register (at index `reg_idx`) out of the registers of `ssa_val`."""
    __slots__ = "ssa_val", "reg_idx"

    def __init__(self, ssa_val, reg_idx):
        # type: (SSAVal, int) -> None
        if not 0 <= reg_idx < ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")
        self.ssa_val = ssa_val
        self.reg_idx = reg_idx

    def __repr__(self):
        # type: () -> str
        return "{}[{}]".format(self.ssa_val, self.reg_idx)
329
330
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """Analysis of a `Fn`, computed when constructed.

    For every SSAVal defined by an op of `fn`, this records its uses, the
    ProgramRange in which it's written (`def_program_ranges`) and in which
    it's live (`live_ranges`), plus the set of SSAVals live at every
    ProgramPoint (`live_at`). Copy relationships and known-constant values
    are available through the cached properties.

    Equality and hashing only consider `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        # index of each Op within fn.ops
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        # both stages of every Op of fn; the (exclusive) stop is the Early
        # stage of the index just past the last Op.
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range, extended as uses are
        # found
        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs are processed before this op's outputs are registered,
            # so each input must be an output of an earlier op in fn.ops;
            # otherwise the `uses[...]` lookup below raises KeyError.
            for use in op.input_uses:
                uses[use.ssa_val].add(use)
                use_program_point = self.__get_use_program_point(use)
                use_program_points[use] = use_program_point
                # stop is exclusive, so extend it to just past this use
                live_range_stops[use.ssa_val] = max(
                    live_range_stops[use.ssa_val], use_program_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_program_range = self.__get_def_program_range(out)
                def_program_ranges[out] = def_program_range
                # an output is live for at least its whole def range, even if
                # it's never used
                live_range_stops[out] = def_program_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_program_ranges)
        self.use_program_points = FMap(use_program_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses.keys():
            # live from where it's first written until just past its last
            # use (or the end of its def range, whichever is later)
            live_ranges[ssa_val] = live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_range_stops[ssa_val])
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        # force these cached properties to be computed now rather than on
        # first access (copy_related_ssa_vals stays lazy)
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """range in which `ssa_val` is written: from its descriptor's
        `write_stage` of the defining op through that op's Late stage
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        start = ProgramPoint(
            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """inputs are all read in the Early stage of the using op"""
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        # analyses are equal if they analyze the same Fn
        if isinstance(other, FnAnalysis):
            return self.fn == other.fn
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        # ops are visited in program order, so a copy's source has already
        # been resolved to its original by the time the copy is visited
        for op in self.op_indexes.keys():
            if not op.properties.is_copy:
                continue
            copy_reg_len = op.properties.copy_reg_len
            copy_inputs = []  # type: list[SSAValSubReg]
            for inp in op.input_vals[:op.properties.copy_inputs_len]:
                for inp_sub_reg in inp.ssa_val_sub_regs:
                    # propagate copies of copies
                    inp_sub_reg = retval.get(inp_sub_reg, inp_sub_reg)
                    copy_inputs.append(inp_sub_reg)
            assert len(copy_inputs) == copy_reg_len, "logic error"
            copy_outputs = []  # type: list[SSAValSubReg]
            for out in op.outputs[:op.properties.copy_outputs_len]:
                copy_outputs.extend(out.ssa_val_sub_regs)
            assert len(copy_outputs) == copy_reg_len, "logic error"
            # the i-th copied output register is a copy of the i-th copied
            # input register
            for inp, out in zip(copy_inputs, copy_outputs):
                retval[out] = inp
        return FMap(retval)

    @cached_property
    def copy_related_ssa_vals(self):
        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
        """ map from SSAVals to the full set of SSAVals that are related by
        being sources/destinations of copies, transitively looking through all
        copies.
        This ignores inputs of copy Ops that aren't actually being copied
        (e.g. the VL input of VecCopyToReg).
        """
        # start with every SSAVal in its own set, then merge the sets of
        # each copy's source and destination
        sets_map = {i: OSet([i]) for i in self.uses.keys()}
        for k, v in self.copies.items():
            k_set = sets_map[k.ssa_val]
            v_set = sets_map[v.ssa_val]
            # merge k_set and v_set
            if k_set is v_set:
                continue
            k_set |= v_set
            for i in k_set:
                sets_map[i] = k_set
        # this way we construct each OFSet only once rather than
        # for each SSAVal
        sets_set = {id(i): i for i in sets_map.values()}
        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
        for v in sets_set.values():
            v = OFSet(v)
            for k in v:
                retval[k] = v
        return FMap(retval)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """SSAVals whose values were determined by simulating `fn` with a
        ConstPropagationState, mapped to their per-register values
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """`const_ssa_vals` split into individual registers"""
        retval = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, v in enumerate(const_val):
                retval[SSAValSubReg(ssa_val, reg_idx)] = v
        return FMap(retval)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value

        A False result only means equality couldn't be proven.
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value
        try:
            a_const_val = self.const_ssa_val_sub_regs[a]
            b_const_val = self.const_ssa_val_sub_regs[b]
            if a_const_val == b_const_val:
                return True
        except KeyError:
            # at least one of them isn't a known constant
            pass
        return False
512
513
@unique
@final
class BaseTy(Enum):
    """The kind of value held in each register of a `Ty`."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if vectors of this base type can't be created."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest `reg_len` a `Ty` with this base type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
543
544
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """A concrete type: `reg_len` consecutive registers of `base_ty`."""
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """Check whether `base_ty` and `reg_len` form a valid Ty.

        Returns the error message if they don't, otherwise None.
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = Ty.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
577
578
@unique
@final
class LocKind(Enum):
    """A separate space of locations a value can be allocated to."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The BaseTy of values stored in this kind of location."""
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        elif self is LocKind.CA:
            return BaseTy.CA
        elif self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        else:
            assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of individual locations of this kind."""
        if (self is LocKind.GPR or self is LocKind.CA
                or self is LocKind.VL_MAXVL):
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
612
613
@final
@unique
class LocSubKind(Enum):
    """A subset of a LocKind's locations. Each sub-kind restricts which
    start positions are allocatable; see `allocatable_locs`.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The LocKind this is a subset of."""
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
                    LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
                    LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Every Loc of type `ty` this sub-kind may use, excluding any Loc
        that overlaps one of SPECIAL_GPRS.

        Raises ValueError if `ty` has a different base type.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # there's only a single location of these kinds
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            # only even-numbered starting registers
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif (self is LocSubKind.SV_EXTRA3_VGPR
              or self is LocSubKind.SV_EXTRA3_SGPR):
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        allowed = []  # type: list[Loc]
        for start in starts:
            candidate = Loc.try_make(
                kind=self.kind, start=start, reg_len=ty.reg_len)
            if candidate is None:
                continue  # doesn't fit in the available locations
            if any(candidate.conflicts(special) for special in SPECIAL_GPRS):
                continue
            allowed.append(candidate)
        return LocSet(allowed)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
683
684
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """A type independent of MAXVL: a base type that's either a scalar or a
    vector whose length is chosen by `instantiate`.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.base_ty = base_ty
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """The concrete Ty for `maxvl`: vectors get `maxvl` registers,
        scalars one.
        """
        # here's where subvl and elwid would be accounted for
        return Ty(self.base_ty, maxvl if self.is_vec else 1)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some `maxvl` makes `instantiate` produce `ty`."""
        if self.base_ty == ty.base_ty:
            return True if self.is_vec else ty.reg_len == 1
        return False
711
712
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """A run of `reg_len` consecutive locations of `kind`, beginning at
    `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return an error message if the arguments don't describe a valid
        Loc, otherwise None.
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like the constructor, but returns None instead of raising
        ValueError for invalid arguments.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if self and other share at least one location."""
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """exclusive end of the locations covered by self"""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Concatenate self with `others`, in order, into a single Loc.

        Returns None if any of `others` is None, has a different kind, or
        doesn't begin exactly where the previous Loc ends, or if the combined
        Loc would be invalid (e.g. longer than its BaseTy's max_reg_len, which
        can happen for StackI64, whose loc_count exceeds the I64 max_reg_len).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than the constructor, so an over-long result gives
        # None instead of raising ValueError
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """The Loc of type `subloc_ty` starting `offset` locations into self.

        Raises ValueError if it wouldn't lie entirely within self or the base
        types don't match.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
789
790
# GPRs never handed out for allocation: LocSubKind.allocatable_locs drops any
# Loc that conflicts with one of these.
# NOTE(review): presumably reserved by the ABI (r1 stack pointer, r2 TOC,
# r13 thread pointer, and r0) -- confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg, reg_len=1) for reg in (0, 1, 2, 13))
797
798
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """An immutable ordered set of Locs that all have the same Ty.

    The Locs' start positions are also kept as one bitset per LocKind,
    which makes bulk operations like `concat` cheap.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse its precomputed data
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """start positions of the Locs, per LocKind (only non-empty kinds)"""
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """the Ty shared by all Locs, or None if self is empty"""
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """exclusive stop positions of the Locs, per LocKind"""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """The set of Locs formed by a Loc from self immediately followed by
        a Loc from each of `others` in order, all of the same LocKind.

        Returns an empty LocSet if any input is empty, the base types differ,
        or no such Loc exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None or other.ty.base_ty != base_ty:
                return LocSet()
            # continuing a run of kind `kind` needs a Loc of that same kind in
            # `other` beginning exactly `reg_len` past the run's start, so
            # shift other's starts down by `reg_len` and intersect. Kinds
            # missing from `other` can't be continued at all, so drop them.
            remaining = {}  # type: dict[LocKind, BitSet]
            for kind, kind_starts in starts.items():
                other_starts = other.starts.get(kind)
                if other_starts is None:
                    continue
                kind_starts.bits &= other_starts.bits >> reg_len
                if kind_starts.bits != 0:
                    remaining[kind] = kind_starts
            starts = remaining
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # try_make drops combined Locs that would be invalid,
                    # e.g. too long for the BaseTy
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    # NOTE(review): lru_cache on a method keeps every `self`/`other` it has
    # seen alive for the life of the process.
    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with (0 if `other` is an empty LocSet)
        """
        if isinstance(other, LocSet):
            return max((self.max_conflicts_with(i) for i in other), default=0)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
902
903
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor

    Describes an operand independently of MAXVL; `instantiate` turns it into
    the concrete `OperandDesc`s for a particular MAXVL.
    """
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        kinds = OFSet(sub_kinds)
        if len(kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            (only_kind,) = kinds
            if fixed_loc not in only_kind.allocatable_locs(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={only_kind}")
        mismatched = next(
            (k for k in kinds if k.base_ty != ty.base_ty), None)
        if mismatched is not None:
            raise ValueError(f"sub_kind is incompatible with type: "
                             f"sub_kind={mismatched} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        if spread:
            if tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.ty = ty
        self.sub_kinds = kinds
        self.fixed_loc = fixed_loc
        self.tied_input_index = tied_input_index
        self.spread = spread
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """a spread operand is a vector split into scalars; this is that
        vector type (or just `ty` for non-spread operands)
        """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """copy of self with only ty, sub_kinds and write_stage kept, tied
        to input `tied_input_index`
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """copy of self with only ty, sub_kinds and write_stage kept, fixed
        to `fixed_loc`
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """copy of self with only `write_stage` replaced"""
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """yield the OperandDescs for `maxvl`: one per spread element for
        spread operands, otherwise exactly one
        """
        # assumes all spread operands have ty.reg_len = 1
        unspread_ty = self.ty_before_spread.instantiate(maxvl=maxvl)
        if self.fixed_loc is not None:
            if unspread_ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {unspread_ty} "
                    f"fixed_loc: {self.fixed_loc}")
            candidate_locs = [self.fixed_loc]  # type: list[Loc]
        else:
            candidate_locs = []
            for sub_kind in self.sub_kinds:
                candidate_locs.extend(sub_kind.allocatable_locs(unspread_ty))
        loc_set = LocSet(candidate_locs)
        spread_indexes = range(maxvl) if self.spread else [None]
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
1011
1012
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor

    The concrete form of a GenericOperandDesc for a particular MAXVL.
    """
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """the Ty of the Locs in `loc_set_before_spread`"""
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is not None:
            return self.spread_index * self.ty.reg_len
        return 0
1061
1062
# Shared generic operand descriptors used by the `OpKind` definitions.
# Each one restricts its operand to the single `LocSubKind` its name refers
# to. The GPR ones are `BaseTy.I64`, scalar (SGPR, `is_vec=False`) or vector
# (VGPR, `is_vec=True`). `OD_CA` and `OD_VL` are scalar operands of
# `BaseTy.CA` and `BaseTy.VL_MAXVL` respectively.
OD_BASE_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), [LocSubKind.BASE_GPR])
OD_EXTRA3_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), [LocSubKind.SV_EXTRA3_SGPR])
OD_EXTRA3_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), [LocSubKind.SV_EXTRA3_VGPR])
OD_EXTRA2_SGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=False), [LocSubKind.SV_EXTRA2_SGPR])
OD_EXTRA2_VGPR = GenericOperandDesc(
    GenericTy(BaseTy.I64, is_vec=True), [LocSubKind.SV_EXTRA2_VGPR])
OD_CA = GenericOperandDesc(
    GenericTy(BaseTy.CA, is_vec=False), [LocSubKind.CA])
OD_VL = GenericOperandDesc(
    GenericTy(BaseTy.VL_MAXVL, is_vec=False), [LocSubKind.VL_MAXVL])
1084
1085
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """Properties shared by every instantiation of an `OpKind`.

    Holds the following:
    * the demo assembly text;
    * the generic input/output operand descriptors;
    * the allowed range of each immediate;
    * the `is_copy`, `is_load_immediate` and `has_side_effects` flags.

    Construction validates the operands. Every violation raises `ValueError`.
    * Inputs must not be tied, and must have `OpStage.Early` as their
      `write_stage`.
    * A tied output's `tied_input_index` must be a valid, non-negative index
      into `inputs`.
    * A tied output must equal its input after `tied_to_input(...)` and
      `with_write_stage(...)` are applied.
    * No two outputs may have conflicting `fixed_loc`s.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                # reject negative indexes explicitly: Python would otherwise
                # silently index `self.inputs` from the end
                if not 0 <= out.tied_input_index < len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1136
1137
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """An `OpKind`'s `GenericOpProperties` instantiated for one `maxvl`.

    `inputs` and `outputs` hold the concrete `OperandDesc`s produced by
    `GenericOperandDesc.instantiate`. A spread generic operand contributes
    one entry per element.

    `copy_reg_len` is the total `reg_len` of a copy's source operands. This
    equals that of its destination operands, and is 0 for non-copy ops.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        generic = self.generic
        inputs = [desc
                  for generic_inp in generic.inputs
                  for desc in generic_inp.instantiate(maxvl=maxvl)]
        self.inputs = tuple(inputs)  # type: tuple[OperandDesc, ...]
        outputs = [desc
                   for generic_out in generic.outputs
                   for desc in generic_out.instantiate(maxvl=maxvl)]
        self.outputs = tuple(outputs)  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        # a copy must move exactly as many registers as it writes
        copy_input_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        copy_output_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if copy_input_reg_len != copy_output_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{copy_input_reg_len} != {copy_output_reg_len}")
        self.copy_reg_len = copy_input_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The uninstantiated properties, taken from `kind`."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Allowed range of each immediate (forwarded from `generic`)."""
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """Demo assembly text (forwarded from `generic`)."""
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """Whether this op is a copy (forwarded from `generic`)."""
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """Whether this op loads an immediate (forwarded from `generic`)."""
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """Whether this op has side effects (forwarded from `generic`)."""
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """Number of leading `inputs` that form a copy's source.

        The value is:
        * 0 for non-copies;
        * 1 when the first input isn't spread;
        * otherwise, the length of the leading run of inputs whose
          `spread_index` equals their position.
        """
        if not self.is_copy:
            return 0
        if self.inputs[0].spread_index is None:
            return 1
        count = 0
        while (count < len(self.inputs)
               and self.inputs[count].spread_index == count):
            count += 1
        return count

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """Number of leading `outputs` that form a copy's destination.

        The value is:
        * 0 for non-copies;
        * 1 when the first output isn't spread;
        * otherwise, the length of the leading run of outputs whose
          `spread_index` equals their position.
        """
        if not self.is_copy:
            return 0
        if self.outputs[0].spread_index is None:
            return 1
        count = 0
        while (count < len(self.outputs)
               and self.outputs[count].spread_index == count):
            count += 1
        return count
1224
1225
# signed 16-bit immediate range: -32768 ..= 32767
# (used by the addi/ld/std-style OpKinds)
IMM_S16 = range(-(1 << 15), 1 << 15)

# Simulation hooks: `_SIM_FNS[props]()` returns the function simulating an
# `Op` whose `OpKind` has properties `props` (see `OpKind.sim`).
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]

# Assembly-generation hooks, looked up the same way (see `OpKind.gen_asm`).
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1234
1235
@unique
@final
class OpKind(Enum):
    """Every kind of `Op`. Each member's value is its `GenericOpProperties`.

    Next to each member, the class body registers two factories in
    `_SIM_FNS` and `_GEN_ASMS`, keyed by that member's
    `GenericOpProperties`. One returns the member's simulation function and
    the other its assembly-generation function. The `sim` and `gen_asm`
    properties look these up.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """The `GenericOpProperties` this member was defined with."""
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """Return this kind's `OpProperties` for the given `maxvl`."""
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """Simulation function registered for this kind in `_SIM_FNS`."""
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """Asm-generation function registered for this kind in `_GEN_ASMS`."""
        return _GEN_ASMS[self.properties]()

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT[i] = RA[i] + RB[i] + carry, with the carry rippling
        # from element 0 upward; the final carry is written to CA
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # scalar RT = RA + carry; carry-out is written to CA
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # addze is a plain scalar instruction (its operands are
        # OD_BASE_SGPR), so its registers are formatted as scalars
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT[i] = ~RA[i] + RB[i] + carry (i.e. RB - RA with an
        # inverted borrow), carry rippling from element 0 upward
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT[i] = RA[i] & RB with scalar RB
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT[i] = low GPR bits of RA[i] * RB + carry; the high
        # bits become the carry into the next element, and the final carry
        # is written to the (tied) scalar RC output
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as two's complement
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm  # Python's >> on a negative int is arithmetic
        RA = v & GPR_VALUE_MASK
        # Power ISA sradi: CA is set iff RS is negative AND any 1-bits were
        # shifted out. `v << imm` is rs with the shifted-out bits cleared,
        # so it differs from rs exactly when some of those bits were 1.
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # broadcast the (masked) immediate into all VL elements
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        # Emit a GPR<->GPR (`or`), stack->GPR (`ld`) or GPR->stack (`std`)
        # move. Nothing is emitted when src_loc == dest_loc; stack->stack
        # raises ValueError.
        sv = "sv." if is_vec else ""
        rev = ""
        # overlapping move to higher registers: iterate in reverse so
        # sources aren't overwritten before they're read
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # gather the single-element spread inputs (all but the trailing VL
        # input) into one vector value
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # scatter each element of the vector input into its own output
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # load VL consecutive GPR-sized words starting at RA + imm
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # store VL elements of RS to consecutive GPR-sized words at RA + imm
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1861
1862
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """Common base of `SSAVal` and `SSAUse`. It identifies operand number
    `operand_idx` of `op` within the subclass's `descriptor_array`.

    Raises `ValueError` on construction if `operand_idx` is out of range.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        self.op = op
        # `descriptor_array` reads `self.op`, so it must be assigned first
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        """Debugging representation; provided by each subclass."""

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The `OperandDesc`s that `operand_idx` indexes into."""

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """The `OperandDesc` at `operand_idx` in `descriptor_array`."""
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """Type after any spread is applied."""
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """Type before any spread is applied."""
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """The `base_ty` of `ty_before_spread`."""
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        """Count of reg-sized slots in the unspread Loc that precede self's
        Loc.

        Example: if the unspread Loc containing self is
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)` and self's Loc is
        `Loc(kind=LocKind.GPR, start=10, reg_len=1)`, then
        `reg_offset_in_unspread == 2 == 10 - 8`.
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """Index of the first operand of the spread group containing self.
        For non-spread operands this is `operand_idx` itself."""
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """Same-class operand at `unspread_start_idx` of the same `op`."""
        cls = type(self)
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1926
1927
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """An SSA value, i.e. output number `operand_idx` of `op`."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """The defining output descriptor's `loc_set_before_spread`."""
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """Output descriptors of `op`'s instantiated properties."""
        properties = self.op.properties
        return properties.outputs

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """The `SSAUse` this output is tied to, or None if it isn't tied."""
        tied_input_index = self.defining_descriptor.tied_input_index
        if tied_input_index is None:
            return None
        return SSAUse(op=self.op, operand_idx=tied_input_index)

    @property
    def write_stage(self):
        # type: () -> OpStage
        """`OpStage` at which this output is written."""
        descriptor = self.defining_descriptor
        return descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """Look up this value's current value for debugging (pdb or similar).

        Works together with `PreRASimState.set_current_debugging_state`.
        For debugging only: don't use in unit tests or production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """One `SSAValSubReg` per register in `ty.reg_len`, in order."""
        sub_regs = [SSAValSubReg(self, reg_idx)
                    for reg_idx in range(self.ty.reg_len)]
        return tuple(sub_regs)
1977
1978
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """A use of an SSA value, i.e. input number `operand_idx` of `op`."""
    __slots__ = ()

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """The input descriptor's `loc_set_before_spread`."""
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """Input descriptors of `op`'s instantiated properties."""
        properties = self.op.properties
        return properties.inputs

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """The `SSAVal` currently stored in `op.input_vals` at this input."""
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
2007
2008
# element type held by an `OpInputSeq`
_T = TypeVar("_T")
# per-item descriptor type that an `OpInputSeq` validates writes against
_Desc = TypeVar("_Desc")
2011
2012
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of per-operand items belonging to an `Op`.

    The length always equals `len(self.descriptors)`, where `descriptors`
    comes from the subclass's `_get_descriptors`. Every write (the initial
    items as well as later `seq[idx] = item` assignments) is checked by the
    subclass's `_verify_write_with_desc` against the descriptor at that
    index. Only single-index writes are supported; the length never changes.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        # `op` must be set before `descriptors` is first used below, since
        # subclasses' `_get_descriptors` read `self.op`
        self.__op = op
        self.__items = []  # type: list[_T]
        for idx, item in enumerate(items):
            if idx >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        # type: () -> Op
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return one descriptor per item; cached by `descriptors`. """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at `idx` given `desc`. """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ validate writing `item` at `idx`; return `idx` normalized to a
        non-negative index.

        Raises `TypeError` for slice or other non-`int` indexes and
        `IndexError` for out-of-range indexes, in addition to whatever
        `_verify_write_with_desc` raises.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a `range` normalizes negative indexes and raises
        # `IndexError` when out of bounds
        normalized = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(normalized, item,
                                     self.descriptors[normalized])
        return normalized

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook for subclasses; does nothing by default.

        NOTE(review): nothing in this class invokes `_on_set` (neither
        `__init__` nor `__setitem__` calls it); confirm that is intended.
        """

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        yield from self.__items

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        self.__items[self._verify_write(idx, item)] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{type(self).__name__}({self.__items}, op=...)"
2100
2101
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s an `Op` reads, one per descriptor in
    `op.properties.inputs`; each item's `ty` must equal its descriptor's `ty`.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` stores this sequence in its `input_vals` slot and has no
        # `inputs` attribute, so that is the attribute to check for (like
        # `OpImmediates` checks `immediates`)
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
2125
2126
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`: item `i` must be an `int`
    contained in the `range` `op.properties.immediates[i]`.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        props = self.op.properties
        return props.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
2145
2146
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation in a `Fn`.

    `properties` describes the operation (its `kind` and operand
    descriptors). `input_vals` holds the `SSAVal`s it reads, `input_uses`
    has one `SSAUse` per input slot, `immediates` holds its immediate
    operands, and `outputs` has one `SSAVal` per output descriptor. `name`
    is registered with `fn` at construction. `Op`s compare and hash by
    identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        inputs_len = len(self.properties.inputs)
        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
        self.immediates = OpImmediates(immediates, op=self)
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
        # registers the name with `fn`, which appends a numeric suffix when
        # `name` is empty or already taken
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ multi-line description: the name, then the outputs, kind, inputs
        and (hex) immediates. Text is broken only at zero-width-space
        markers, so that lines stay within `wrap_width` where possible.
        Continuation lines are prefixed with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        items = [f"{self.name}:\n"]
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        items.append(self.kind._name_)
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
        items[-1] += WRAP_POINT
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = []  # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            if i != 0:
                line_in = indent + line_in
            line_out = ""
            # greedily pack WRAP_POINT-separated parts onto each line
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this op on `state` by dispatching to `self.kind.sim`.

        Before running, every input's value must be present with
        `ty.reg_len` elements. For a `PreRASimState`, no output may already
        be assigned (except for `OpKind.FuncArgR3` ops). After running,
        every output must be assigned with `ty.reg_len` elements. If the
        kind raises `SimSkipOp`, `state.on_skip(self)` is called. Inputs or
        outputs whose lookup raises `SimSkipOp` are not checked.

        Raises `ValueError` when any of these checks fail.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError as e:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}") from e
            except SimSkipOp:
                continue
            if len(val) != inp.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {inp} has wrong number of elements: "
                    f"expected {inp.ty.reg_len} found "
                    f"{len(val)}: {val!r}")
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out in state.ssa_vals:
                    if self.kind is OpKind.FuncArgR3:
                        continue
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError as e:
                raise ValueError(
                    f"running {self} failed to assign to {out}") from e
            except SimSkipOp:
                continue
            if len(val) != out.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {out} has wrong number of elements: "
                    f"expected {out.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ check that every input and output has an allocated `Loc` of any
        `LocKind` in `state`, then dispatch to `self.kind.gen_asm`.
        """
        all_loc_kinds = tuple(LocKind)
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=all_loc_kinds)
        for out in self.outputs:
            state.loc(out, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
2276
2277
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ abstract simulator state.

    Holds byte-addressed `memory` (address -> byte value), with helpers to
    read and write single bytes and naturally-aligned little-endian
    multi-byte values. Subclasses define how `SSAVal` values are read and
    written via `__getitem__`/`__setitem__`.
    """
    __slots__ = ("memory",)

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ return the value of `ssa_val`. """

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ set the value of `ssa_val`. """

    def _default_memory_value(self):
        # type: () -> int
        """ value read from a memory byte that was never stored: 0. """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when `op`'s kind raised `SimSkipOp`; this
        base state doesn't support skipping.
        """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        wrapped_addr = addr & GPR_VALUE_MASK
        try:
            byte = self.memory[wrapped_addr]
        except KeyError:
            return self._default_memory_value()
        return byte & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian `size_in_bytes`-byte value at `addr`, which
        must be a multiple of `size_in_bytes`; sign-extend if `signed`.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        bit_width = size_in_bytes * BITS_IN_BYTE
        value = 0
        for byte_idx in range(size_in_bytes):
            byte = self.load_byte(addr + byte_idx)
            value |= byte << (byte_idx * BITS_IN_BYTE)
        if signed and value >> (bit_width - 1):
            value -= 1 << bit_width
        return value

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value`, little-endian,
        at `addr`, which must be a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            self.store_byte(addr + byte_idx, (value >> shift) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ format `memory` for `__repr__`, grouping complete aligned
        `GPR_SIZE_IN_BYTES`-byte runs into single word entries.
        """
        if not self.memory:
            return "{}"
        chunk = GPR_SIZE_IN_BYTES
        # sorted descending so the lowest remaining address is at the end
        pending = sorted(self.memory, reverse=True)
        entries = []  # type: list[str]
        while pending:
            addr = pending[-1]
            is_whole_word = (
                len(pending) >= chunk
                and addr % chunk == 0
                and pending[-chunk:] == list(reversed(range(addr,
                                                             addr + chunk))))
            if is_whole_word:
                word = self.load(addr, size_in_bytes=chunk)
                entries.append(f"0x{addr:05x}: <0x{word:0{chunk * 2}x}>")
                del pending[-chunk:]
            else:
                byte = self.memory[pending.pop()]
                entries.append(f"0x{addr:05x}: 0x{byte:02x}")
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        return "{\n" + ",\n".join(entries) + "}"

    def __repr__(self):
        # type: () -> str
        # a field with a `_<field>__repr` method is formatted with it
        parts = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                parts.append(f"{name}=<not set>")
            else:
                custom_repr = getattr(self, f"_{name}__repr", None)
                if callable(custom_repr):
                    parts.append(f"{name}={custom_repr()}")
                else:
                    parts.append(f"{name}={value!r}")
        return f"{type(self).__name__}({', '.join(parts)})"
2378
2379
@plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ base for pre-register-allocation simulator states: each `SSAVal`'s
    value is stored directly in `ssa_vals` as a tuple of ints with
    `ty.reg_len` elements.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ format `ssa_vals` for `__repr__`, with at most 4 hex elements
        per line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # start a new line every CHUNK_SIZE elements
                    element_strs.append("\n " + hex(el))
            # values that fit in one chunk stay on a single line; the `0 <`
            # guard keeps an empty value tuple from raising IndexError
            if 0 < len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # trailing comma so a 1-element tuple prints as `(0x1,)`
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        try:
            return self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` for an `SSAVal` with no value; raises
        `KeyError` here, subclasses may override.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2430
2431
class SimSkipOp(Exception):
    """ raised during simulation when a value (an `SSAVal` or a memory
    byte) isn't available, signalling that the current `Op` should be
    skipped rather than evaluated.
    """
2434
2435
@plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ pre-RA simulator state for constant propagation.

    Reading an `SSAVal` with no value, or a memory byte that was never
    stored, raises `SimSkipOp` (rather than `KeyError` / reading 0), and
    every `Op` skipped as a result is added to `skipped_ops`.
    """
    __slots__ = ("skipped_ops",)

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        self.skipped_ops.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        raise SimSkipOp

    def _default_memory_value(self):
        # type: () -> int
        raise SimSkipOp
2457
2458
@plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ pre-register-allocation simulator state.

    It inherits all value/memory behavior from `PreRABaseSimState`: an
    `SSAVal` with no value raises `KeyError`, and unset memory reads as 0.
    It also keeps a class-wide stack of "current" states for inspection
    from a debugger.
    """
    __slots__ = ()

    # stack pushed/popped by `set_as_current_debugging_state`; the last
    # entry is the current state
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        # type: () -> Iterator[None]
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        # push before entering `try`, so `finally` only ever pops an entry
        # this call actually pushed
        stack.append(self)
        try:
            yield
        finally:
            # the pop must not be inside an `assert`: asserts are stripped
            # under `python -O`, which would leave `self` on the stack
            popped = stack.pop()
            if popped is not self:
                raise AssertionError("inconsistent __CURRENT_DEBUGGING_STATE")

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2492
2493
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ post-register-allocation simulator state.

    Each `SSAVal` is mapped to a `Loc` by `ssa_val_to_loc_map`. Register
    contents live in `loc_values`, keyed only by `reg_len == 1` `Loc`s, and
    a location that was never written reads as 0.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ format `loc_values` for `__repr__`, sorted by kind name then
        start.
        """
        # like the other `_*__repr` helpers, an empty mapping is "{}"
        # instead of the malformed "{\n,\n}"
        if len(self.loc_values) == 0:
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    @staticmethod
    def _iter_sublocs(loc):
        # type: (Loc) -> Iterator[Loc]
        """ yield the `reg_len == 1` sub-`Loc`s of `loc`, in offset order. """
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for offset in range(loc.reg_len):
            yield loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=offset)

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        loc = self.ssa_val_to_loc_map[ssa_val]
        return tuple(self.loc_values.get(subloc, 0)
                     for subloc in self._iter_sublocs(loc))

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        for i, subloc in enumerate(self._iter_sublocs(loc)):
            self.loc_values[subloc] = value[i]
2544
2545
@plain_data(frozen=True)
class GenAsmState:
    """ state for assembly generation: the `Loc` allocated to each `SSAVal`,
    plus the `output` (a list of lines or a text stream) that `writeln`
    emits to.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        # default to collecting lines in a fresh list
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve `ssa_val_or_locs` to a single `Loc`.

        `SSAVal`s are looked up in `allocated_locs`. A sequence is joined
        with `Loc.try_concat`. Raises `ValueError` if the sequence is empty,
        can't be concatenated, or the result's kind isn't in
        `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [self.allocated_locs[i] if isinstance(i, SSAVal) else i
                for i in ssa_val_or_locs]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if retval.kind in expected_kinds:
            return retval
        if len(expected_kinds) == 1:
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ register number of the GPR `Loc`, prefixed by "*" if `is_vec`. """
        start = self.loc(ssa_val_or_locs, LocKind.GPR).start
        prefix = "*" if is_vec else ""
        return prefix + str(start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ `StackI64` `Loc` formatted as `<start>(1)`. """
        start = self.loc(ssa_val_or_locs, LocKind.StackI64).start
        return f"{start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ emit one line: `line_segments` joined with " ". """
        line = " ".join(line_segments)
        if not isinstance(self.output, list):
            self.output.write(line + "\n")
            return
        self.output.append(line)