finished __mergable_check
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 from abc import ABCMeta, abstractmethod
4 from enum import Enum, unique
5 from functools import lru_cache, total_ordering
6 from io import StringIO
7 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
8 Mapping, Sequence, TypeVar, Union, overload)
9 from weakref import WeakValueDictionary as _WeakVDict
10
11 from cached_property import cached_property
12 from nmutil.plain_data import fields, plain_data
13
14 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
15 final)
16 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
17 OFSet, OSet)
18
# Width of one general-purpose register (GPR), in bytes and in bits.
BITS_IN_BYTE = 8
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# All-ones mask that is exactly GPR_SIZE_IN_BITS bits wide.
GPR_VALUE_MASK = ~(~0 << GPR_SIZE_IN_BITS)
23
24
@final
class Fn:
    """ An ordered list of `Op`s, plus the bookkeeping needed to give every
    `Op` added to this `Fn` a unique name.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # names are held weakly, so an entry vanishes once its Op is
        # garbage-collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix tried next when a requested name is empty or taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ register `op` under a name that is not yet in use and return it.

        `name` is used as-is when it is non-empty and free, otherwise numeric
        suffixes are appended to it until an unused name is found.

        raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    # NOTE(review): the `" "` defaults below are reproduced as they appear in
    # the extracted source; runs of spaces may have been collapsed by the
    # extraction -- confirm against the repository.
    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ return the `repr` of every Op in `self.ops`, one per line.

        if `as_python_literal` is set, the text is instead returned as the
        source code of adjacent Python string literals (one per line of text,
        each prefixed by `python_indent`), with quotes, backslashes and
        non-printable/non-ASCII characters escaped.
        """
        text = "\n".join(
            op.__repr__(wrap_width=wrap_width, indent=indent)
            for op in self.ops)
        if not as_python_literal:
            return text
        pieces = [python_indent + "\""]
        for ch in text:
            if ch == "\n":
                # close the current literal and open one on the next line
                pieces.append(f"\\n\"\n{python_indent}\"")
            elif ch in "\"\\":
                pieces.append("\\" + ch)
            elif ch.isascii() and ch.isprintable():
                pieces.append(ch)
            else:
                # reuse repr's escape sequence, minus the surrounding quotes
                pieces.append(repr(ch).strip("\"'"))
        pieces.append("\"")
        literal = "".join(pieces)
        empty_end = f"\"\n{python_indent}\"\""
        if literal.endswith(empty_end):
            # drop the trailing empty literal produced when `text` ends with
            # a newline. `empty_end` starts with the closing quote of the
            # previous literal, so that quote has to be put back.
            literal = literal[:-len(empty_end)] + "\""
        return literal

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to `self.ops`; raises ValueError if `op.fn` is not
        this `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new `Op` of `kind` (instantiated for `maxvl`), append it
        to this `Fn` and return it.
        """
        new_op = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ call `sim(state)` on every Op, in order. """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ call `gen_asm(state)` on every Op, in order. """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops` so that every I64 input of every original Op is
        read through a freshly inserted copy-to-reg Op and every I64 output is
        followed by a copy-from-reg Op. CA and VL_MAXVL values are not copied.
        """
        orig_ops = list(self.ops)
        # maps each original output to the value later Ops should read
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # maps outputs of SetVLI Ops to the Op that produced them
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        # vector copies need a VL, so emit a SetVLI first
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the op that reads their output
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            # bypass append_op: `op` was already in this Fn
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
162
163
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the stage of an Op's execution at which a read or write happens;
    ordered so that `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        value = int(value)
        if not 0 <= value <= 1:
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = value
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return f"OpStage.{self._name_}"

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # only OpStage values are comparable; total_ordering derives the
        # remaining comparison operators from this method
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
210
211
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """ a position in a Fn's list of Ops: the index of an Op together with
    the stage (`OpStage`) within that Op.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering and
        successor/predecessor relations.
        """
        # two program points (one per stage) exist for every Op
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of the `int_value` property. """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` points after `self`. """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` points before `self`. """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if not isinstance(other, ProgramPoint):
            return NotImplemented
        # order by Op index first, then by stage within the Op
        if self.op_index == other.op_index:
            return self.stage < other.stage
        return self.op_index < other.op_index

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
256
257
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ a half-open range `start..stop` of `ProgramPoint`s (`stop` is
    excluded), usable as a sequence of `ProgramPoint`s.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `range` of `ProgramPoint.int_value`s covered by `self`. """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ build a `ProgramRange` from a `range` of int values.

        raises ValueError if the range's step is not 1.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        first = ProgramPoint.from_int_value(int_value_range.start)
        last = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=first, stop=last)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # indexing/slicing is delegated to the underlying `range`
        picked = self.int_value_range[__idx]
        if not isinstance(picked, int):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the endpoints' reprs without their angle brackets
        start, stop = (repr(point).lstrip("<").rstrip(">")
                       for point in (self.start, self.stop))
        return f"<range:{start}..{stop}>"
312
313
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(metaclass=InternedMeta):
    """ one register-sized piece of an `SSAVal`, selected by `reg_idx`. """
    __slots__ = "ssa_val", "reg_idx"

    def __init__(self, ssa_val, reg_idx):
        # type: (SSAVal, int) -> None
        # reg_idx must index into the registers covered by ssa_val's type
        if not 0 <= reg_idx < ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")
        self.ssa_val = ssa_val
        self.reg_idx = reg_idx

    def __repr__(self):
        # type: () -> str
        return f"{self.ssa_val}[{self.reg_idx}]"
329
330
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ analysis results for a `Fn`, computed on construction: where every
    SSA value is defined and used, each value's live range, which values are
    live at each program point, plus copy and constant information.

    Equality and hashing are delegated to the analyzed `Fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        # from the early stage of the first Op up to (excluding) the early
        # stage of the one-past-the-end Op
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs first: they refer to values defined by earlier Ops, so
            # their entries in `uses`/`stops` already exist
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                use_point = self.__get_use_program_point(use)
                use_points[use] = use_point
                # the live range must reach just past the latest use
                stops[used_val] = max(stops[used_val], use_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                stops[out] = def_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for point in live_range:
                live_at[point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        # evaluate the cached properties once, right away
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the program range over which `ssa_val` is being defined. """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the program point at which `ssa_use` reads its value. """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        """
        originals = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        for op in self.op_indexes.keys():
            props = op.properties
            if not props.is_copy:
                continue
            copy_reg_len = props.copy_reg_len
            # propagate copies of copies: sources that are themselves copies
            # are replaced by what they were copied from
            sources = [
                originals.get(sub_reg, sub_reg)
                for inp in op.input_vals[:props.copy_inputs_len]
                for sub_reg in inp.ssa_val_sub_regs
            ]  # type: list[SSAValSubReg]
            assert len(sources) == copy_reg_len, "logic error"
            dests = []  # type: list[SSAValSubReg]
            for out in op.outputs[:props.copy_outputs_len]:
                dests.extend(out.ssa_val_sub_regs)
            assert len(dests) == copy_reg_len, "logic error"
            for src, dest in zip(sources, dests):
                originals[dest] = src
        return FMap(originals)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """ the `ssa_vals` collected by running `fn.sim` on a fresh
        `ConstPropagationState`.
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """ `const_ssa_vals`, split into one entry per register. """
        per_reg = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, regs in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(regs), "logic error"
            for reg_idx, reg_value in enumerate(regs):
                per_reg[SSAValSubReg(ssa_val, reg_idx)] = reg_value
        return FMap(per_reg)

    def is_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value; a value that isn't a
        # known constant can't be proven equal this way
        try:
            a_const = self.const_ssa_val_sub_regs[a]
            b_const = self.const_ssa_val_sub_regs[b]
        except KeyError:
            return False
        if a_const == b_const:
            return True
        return False
481
482
483 @unique
484 @final
485 class BaseTy(Enum):
486 I64 = enum.auto()
487 CA = enum.auto()
488 VL_MAXVL = enum.auto()
489
490 @cached_property
491 def only_scalar(self):
492 # type: () -> bool
493 if self is BaseTy.I64:
494 return False
495 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
496 return True
497 else:
498 assert_never(self)
499
500 @cached_property
501 def max_reg_len(self):
502 # type: () -> int
503 if self is BaseTy.I64:
504 return 128
505 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
506 return 1
507 else:
508 assert_never(self)
509
510 def __repr__(self):
511 return "BaseTy." + self._name_
512
513
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """ a concrete type: a `BaseTy` plus the number of registers
    (`reg_len`) a value of the type spans.
    """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars print without a length suffix, e.g. <I64> vs. <I64*4>
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
546
547
548 @unique
549 @final
550 class LocKind(Enum):
551 GPR = enum.auto()
552 StackI64 = enum.auto()
553 CA = enum.auto()
554 VL_MAXVL = enum.auto()
555
556 @cached_property
557 def base_ty(self):
558 # type: () -> BaseTy
559 if self is LocKind.GPR or self is LocKind.StackI64:
560 return BaseTy.I64
561 if self is LocKind.CA:
562 return BaseTy.CA
563 if self is LocKind.VL_MAXVL:
564 return BaseTy.VL_MAXVL
565 else:
566 assert_never(self)
567
568 @cached_property
569 def loc_count(self):
570 # type: () -> int
571 if self is LocKind.StackI64:
572 return 512
573 if self is LocKind.GPR or self is LocKind.CA \
574 or self is LocKind.VL_MAXVL:
575 return self.base_ty.max_reg_len
576 else:
577 assert_never(self)
578
579 def __repr__(self):
580 return "LocKind." + self._name_
581
582
583 @final
584 @unique
585 class LocSubKind(Enum):
586 BASE_GPR = enum.auto()
587 SV_EXTRA2_VGPR = enum.auto()
588 SV_EXTRA2_SGPR = enum.auto()
589 SV_EXTRA3_VGPR = enum.auto()
590 SV_EXTRA3_SGPR = enum.auto()
591 StackI64 = enum.auto()
592 CA = enum.auto()
593 VL_MAXVL = enum.auto()
594
595 @cached_property
596 def kind(self):
597 # type: () -> LocKind
598 # pyright fails typechecking when using `in` here:
599 # reported: https://github.com/microsoft/pyright/issues/4102
600 if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
601 LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
602 LocSubKind.SV_EXTRA3_SGPR):
603 return LocKind.GPR
604 if self is LocSubKind.StackI64:
605 return LocKind.StackI64
606 if self is LocSubKind.CA:
607 return LocKind.CA
608 if self is LocSubKind.VL_MAXVL:
609 return LocKind.VL_MAXVL
610 assert_never(self)
611
612 @property
613 def base_ty(self):
614 return self.kind.base_ty
615
616 @lru_cache()
617 def allocatable_locs(self, ty):
618 # type: (Ty) -> LocSet
619 if ty.base_ty != self.base_ty:
620 raise ValueError("type mismatch")
621 if self is LocSubKind.BASE_GPR:
622 starts = range(32)
623 elif self is LocSubKind.SV_EXTRA2_VGPR:
624 starts = range(0, 128, 2)
625 elif self is LocSubKind.SV_EXTRA2_SGPR:
626 starts = range(64)
627 elif self is LocSubKind.SV_EXTRA3_VGPR \
628 or self is LocSubKind.SV_EXTRA3_SGPR:
629 starts = range(128)
630 elif self is LocSubKind.StackI64:
631 starts = range(LocKind.StackI64.loc_count)
632 elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
633 return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
634 else:
635 assert_never(self)
636 retval = [] # type: list[Loc]
637 for start in starts:
638 loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
639 if loc is None:
640 continue
641 conflicts = False
642 for special_loc in SPECIAL_GPRS:
643 if loc.conflicts(special_loc):
644 conflicts = True
645 break
646 if not conflicts:
647 retval.append(loc)
648 return LocSet(retval)
649
650 def __repr__(self):
651 return "LocSubKind." + self._name_
652
653
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """ a type whose vector length is not fixed yet: a `BaseTy` plus a flag
    saying whether it is a vector. `instantiate` turns it into a `Ty`.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if is_vec and base_ty.only_scalar:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for the given `maxvl`: vectors span `maxvl`
        registers, scalars always span one.
        """
        # here's where subvl and elwid would be accounted for
        if not self.is_vec:
            return Ty(self.base_ty, 1)
        return Ty(self.base_ty, maxvl)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if `ty` is a possible result of `instantiate`. """
        if self.base_ty != ty.base_ty:
            return False
        if not self.is_vec:
            # scalars only ever instantiate to a single register
            return ty.reg_len == 1
        return True
680
681
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """ a contiguous run of `reg_len` register-sized locations of kind
    `kind`, beginning at index `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like the constructor, but returns None instead of raising when
        the arguments don't describe a valid `Loc`.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """ raises ValueError if `validate` reports an error. """
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` have the same kind and overlap. """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ the `Ty` of a `Loc` with the given kind and length. """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ index one past the last location covered by `self`. """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ return the `Loc` covering `self` followed by all of `others`.

        returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the previous `Loc` stops, or if the
        combined `Loc` would not be valid (e.g. it is longer than its type
        allows).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than Loc(...): the pieces can each be valid while
        # their concatenation is not, and that must give None, not raise
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the `Loc` of type `subloc_ty` that begins `offset` registers into
        `self`.

        raises ValueError if the base types differ or the sub-Loc would not
        lie entirely within `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
758
759
# Single-register GPR locations that `LocSubKind.allocatable_locs` never
# hands out: r0, r1, r2 and r13.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13)
)
766
767
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """ an ordered, frozen set of `Loc`s that all share one `Ty`.

    Besides the `Loc`s themselves, the set keeps a per-`LocKind` bit set of
    the start indexes of its members (`starts`).
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """ raises ValueError if the `Loc`s don't all have the same type. """
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse what it already computed
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only kinds that actually have members are kept
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ for each kind present in the set, the start indexes of its
        `Loc`s.
        """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """ the common type of all members, or None for an empty set. """
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ like `starts`, but holding each member's `stop` index. """
        if self.ty is None:
            return FMap()
        # every member has the same length, so stop == start + reg_len is a
        # plain left shift of the start bits
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the set of all valid `Loc`s formed by a member of `self`
        immediately followed by a member of each of `others`, in order.

        returns an empty `LocSet` if no such `Loc` exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None or other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot, since kinds are deleted while looping
            for kind in list(starts):
                other_starts = other.starts.get(kind)
                if other_starts is None:
                    # `other` has no Loc of this kind to continue with
                    del starts[kind]
                    continue
                # keep only the starts where `other` has a member beginning
                # right after the registers accumulated so far
                starts[kind].bits &= other_starts.bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # skip combinations that don't form a valid Loc
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with
        """
        if isinstance(other, LocSet):
            return max(self.max_conflicts_with(i) for i in other)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
871
872
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor"""
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ raises ValueError for inconsistent combinations, e.g. an operand
        that is both tied and fixed, or a sub-kind whose base type differs
        from `ty`.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # the fixed location must be one the sub-kind could allocate
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            # a spread operand can't be tied, fixed or a vector
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the vector form of `ty` for spread operands, otherwise `ty`. """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a descriptor with the same type, sub-kinds and write stage, tied
        to the given input.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            tied_input_index=tied_input_index,
            write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a descriptor with the same type, sub-kinds and write stage, fixed
        to `fixed_loc`.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=fixed_loc,
            write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with `write_stage` replaced. """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=self.fixed_loc,
            tied_input_index=self.tied_input_index,
            spread=self.spread,
            write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ yield the concrete `OperandDesc`s for the given `maxvl`: one per
        element (`maxvl` in total) for a spread operand, a single one
        otherwise.
        """
        # assumes all spread operands have ty.reg_len = 1
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty_before_spread)
                return
            if ty_before_spread != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty_before_spread} "
                    f"fixed_loc: {self.fixed_loc}")
            yield self.fixed_loc
        loc_set_before_spread = LocSet(locs_before_spread())
        # non-spread operands produce one descriptor with no spread index
        spread_indexes = range(maxvl) if self.spread else (None,)
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
980
981
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        """ raises ValueError if `loc_set_before_spread` is empty or the
        operand is both tied and spread.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        if spread_index is not None and self.tied_input_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the type shared by all `Loc`s in `loc_set_before_spread`. """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is not None:
            return self.spread_index * self.ty.reg_len
        return 0
1030
1031
# Commonly used generic operand descriptors.
#
# Each one pairs a `GenericTy` with the single `LocSubKind` it may be
# allocated from.

# scalar I64 operands
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.BASE_GPR],
    ty=GenericTy(is_vec=False, base_ty=BaseTy.I64),
)
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
    ty=GenericTy(is_vec=False, base_ty=BaseTy.I64),
)
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
    ty=GenericTy(is_vec=False, base_ty=BaseTy.I64),
)

# vector I64 operands
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
    ty=GenericTy(is_vec=True, base_ty=BaseTy.I64),
)
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
    ty=GenericTy(is_vec=True, base_ty=BaseTy.I64),
)

# special-purpose operands
OD_CA = GenericOperandDesc(
    sub_kinds=[LocSubKind.CA],
    ty=GenericTy(is_vec=False, base_ty=BaseTy.CA),
)
OD_VL = GenericOperandDesc(
    sub_kinds=[LocSubKind.VL_MAXVL],
    ty=GenericTy(is_vec=False, base_ty=BaseTy.VL_MAXVL),
)
1053
1054
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """ Properties of an Op kind before instantiation.

    Fields:

    * `demo_asm`: example assembly text for the op.
    * `inputs` / `outputs`: tuples of `GenericOperandDesc`.
    * `immediates`: tuple of `range`s, one per immediate.
    * `is_copy`, `is_load_immediate`, `has_side_effects`: boolean flags,
      stored as given.

    Raises `ValueError` from `__init__` when:

    * an input has `tied_input_index` set or a `write_stage` other than
      `OpStage.Early`;
    * an output's `tied_input_index` is not a valid index into `inputs`;
    * an output is tied to an input it is not equivalent to;
    * two outputs have conflicting `fixed_loc`s.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # fixed_locs seen so far, each paired with the index of its output
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                # reject negative indexes too: Python would otherwise
                # silently index `inputs` from the end
                if not 0 <= out.tied_input_index < len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                # the output must be exactly the tied input, except for the
                # tie itself and the output's own write stage
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1105
1106
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """ An `OpKind`'s properties instantiated for one particular `maxvl`.

    `inputs` and `outputs` hold the concrete `OperandDesc`s produced by
    instantiating the kind's generic operand descriptors.  `copy_reg_len`
    is the total `reg_len` of the copied inputs, which must equal that of
    the copied outputs (both are zero for ops that aren't copies).
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        # expand each generic descriptor into its concrete descriptor(s)
        self.inputs = tuple(
            desc
            for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        src_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        dest_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if src_reg_len != dest_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{src_reg_len} != {dest_reg_len}")
        self.copy_reg_len = src_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the generic (un-instantiated) properties of `self.kind` """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """ forwarded from `self.generic` """
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """ forwarded from `self.generic` """
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """ how many leading inputs are the source of a copy

        0 if this isn't a copy, 1 if the first input isn't spread, otherwise
        the number of leading inputs whose `spread_index` equals their
        position.
        """
        if not self.is_copy:
            return 0
        descs = self.inputs
        if descs[0].spread_index is None:
            return 1
        matched = 0
        while matched < len(descs) and descs[matched].spread_index == matched:
            matched += 1
        return matched

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """ how many leading outputs are the destination of a copy

        0 if this isn't a copy, 1 if the first output isn't spread, otherwise
        the number of leading outputs whose `spread_index` equals their
        position.
        """
        if not self.is_copy:
            return 0
        descs = self.outputs
        if descs[0].spread_index is None:
            return 1
        matched = 0
        while matched < len(descs) and descs[matched].spread_index == matched:
            matched += 1
        return matched
1193
1194
# the values representable by a signed 16-bit immediate
IMM_S16 = range(-(1 << 15), 1 << 15)

# simulation function for an Op, and the zero-argument getter that returns it
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
# registry of simulation-function getters, keyed by an op's generic properties
_SIM_FNS = dict()  # type: dict[GenericOpProperties | Any, _SIM_FN2]

# assembly generator for an Op, and the zero-argument getter that returns it
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry of assembly-generator getters, keyed by an op's generic properties
_GEN_ASMS = dict()  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1203
1204
@unique
@final
class OpKind(Enum):
    """ All the kinds of Op.

    Each member's value is a `GenericOpProperties`.  Right before each
    member, two private static methods are defined: `__<name>_sim`, which
    simulates the op on a `BaseSimState`, and `__<name>_gen_asm`, which
    writes the op's assembly to a `GenAsmState`.  Right after each member,
    getters for those two functions are registered in `_SIM_FNS` and
    `_GEN_ASMS` under the member's `GenericOpProperties`; `sim` and
    `gen_asm` look them up from there.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the `GenericOpProperties` this member was defined with """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ create the `OpProperties` of this kind for the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the simulation function registered in `_SIM_FNS` """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the assembly generator registered in `_GEN_ASMS` """
        return _GEN_ASMS[self.properties]()

    # ---- ClearCA: CA = False ----

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    # ---- SetCA: CA = True ----

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    # ---- SvAddE: element-wise RT[i] = RA[i] + RB[i] + carry ----

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            # the carry out of each element feeds into the next element
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    # ---- AddZE: scalar RT = RA + carry ----

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # RT and RA are scalars (OD_BASE_SGPR), so format them with sgpr
        # like every other scalar op here, not vgpr
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    # ---- SvSubFE: element-wise RT[i] = ~RA[i] + RB[i] + carry ----

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    # ---- SvAndVS: element-wise RT[i] = RA[i] & RB, with scalar RB ----

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    # ---- SvMAddEDU: element-wise RT[i] = RA[i] * RB + carry ----

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            # unlike the add/sub ops, the carry here is a full integer:
            # everything above the low GPR_SIZE_IN_BITS bits of the product
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    # NOTE(review): RT is declared as OD_EXTRA3_VGPR while every other
    # register operand of this op uses an EXTRA2 descriptor -- confirm that
    # this mismatch is intentional.
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    # ---- SRADI: scalar arithmetic shift right by an immediate ----

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as a signed integer
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # CA is set only when RS is negative and at least one 1 bit was
        # shifted out; compare using the signed result `v`, since the masked
        # `RA` could never equal a negative `rs`
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    # ---- SetVLI: VL = immediate ----

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    # ---- SvLI: every element of RT = immediate ----

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    # ---- LI: scalar RT = immediate ----

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    # ---- VecCopyToReg: vector copy from a register or the stack ----

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ write the assembly for copying `src_loc` to `dest_loc`

        Shared by all the copy-like ops.  Writes nothing when both Locs are
        equal, a load for stack -> GPR, a store for GPR -> stack, and an
        `or` for GPR -> GPR.  The instruction gets the `sv.` prefix when
        `is_vec` is set.

        Raises `ValueError` if either Loc's kind isn't GPR or StackI64, or
        when asked to copy from stack to stack.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # presumably `/mrr` makes an overlapping copy run in reverse element
        # order so it doesn't overwrite source elements before reading them
        # -- TODO confirm
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    # ---- VecCopyFromReg: vector copy to a register or the stack ----

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    # ---- CopyToReg: scalar copy from a register or the stack ----

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    # ---- CopyFromReg: scalar copy to a register or the stack ----

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    # ---- Concat: gather spread scalar inputs into one vector ----

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # the last input is VL, all the others are the scalars to gather
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    # ---- Spread: scatter one vector into spread scalar outputs ----

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    # ---- SvLd: vector load from RA + immediate ----

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            # elements are loaded from consecutive GPR-sized slots
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    # ---- Ld: scalar load from RA + immediate ----

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    # ---- SvStd: vector store to RA + immediate ----

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            # elements are stored to consecutive GPR-sized slots
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    # ---- Std: scalar store to RA + immediate ----

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    # ---- FuncArgR3: the value fixed in GPR 3 ----

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1830
1831
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """ Common base of `SSAVal` and `SSAUse`: one operand slot of an `Op`.

    `op` is the Op the slot belongs to and `operand_idx` is the slot's index
    into `descriptor_array`, which subclasses must provide.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `op` has to be stored first, `descriptor_array` is looked up on self
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the `OperandDesc`s that `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the `OperandDesc` for this operand slot """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ forwarded from `defining_descriptor` """
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ forwarded from `defining_descriptor` """
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of `ty_before_spread` """
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` of the first operand in the spread containing self

        this is just `operand_idx` when self isn't spread.
        """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ the same kind of object as self, but at `unspread_start_idx` """
        cls = self.__class__
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1895
1896
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ An output of an `Op`: `operand_idx` indexes the op's outputs. """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of `self.op` """
        properties = self.op.properties
        return properties.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ forwarded from `defining_descriptor.loc_set_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the input use that this output is tied to, or `None` """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ forwarded from `defining_descriptor` """
        descriptor = self.defining_descriptor
        return descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """ one `SSAValSubReg` for each of self's `ty.reg_len` registers """
        sub_regs = []  # type: list[SSAValSubReg]
        for reg_idx in range(self.ty.reg_len):
            sub_regs.append(SSAValSubReg(self, reg_idx))
        return tuple(sub_regs)
1946
1947
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ An input slot of an `Op`: `operand_idx` indexes the op's inputs. """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of `self.op` """
        properties = self.op.properties
        return properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ forwarded from `defining_descriptor.loc_set_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently in this slot of `self.op.input_vals`

        assigning to this replaces that entry of `self.op.input_vals`.
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1976
1977
# type parameters of `OpInputSeq`:
# `_T` is the item type, `_Desc` is the type of the per-item descriptors
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")
1980
1981
1982 class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
@abstractmethod
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, _T | Any, _Desc) -> None
    """ check that `item` may be stored at index `idx`

    `idx` is already normalized and `desc` is `self.descriptors[idx]`.
    Must be implemented by subclasses.
    """
    raise NotImplementedError
1987
@final
def _verify_write(self, idx, item):
    # type: (int | Any, _T | Any) -> int
    """ check that `item` may be written at `idx`, returning the normalized idx

    Raises `TypeError` if `idx` is a slice or any other non-int, and
    `IndexError` if it is out of range.  Negative indexes count from the end.
    """
    if isinstance(idx, slice):
        raise TypeError(
            f"can't write to slice of {self.__class__.__name__}")
    if not isinstance(idx, int):
        raise TypeError(f"can't write with index {idx!r}")
    # indexing a range normalizes idx and range-checks it in one step
    normalized = range(len(self.descriptors))[idx]
    self._verify_write_with_desc(
        normalized, item, self.descriptors[normalized])
    return normalized
2001
2002 def _on_set(self, idx, new_item, old_item):
2003 # type: (int, _T, _T | None) -> None
2004 pass
2005
2006 @abstractmethod
2007 def _get_descriptors(self):
2008 # type: () -> tuple[_Desc, ...]
2009 raise NotImplementedError
2010
2011 @cached_property
2012 @final
2013 def descriptors(self):
2014 # type: () -> tuple[_Desc, ...]
2015 return self._get_descriptors()
2016
2017 @property
2018 @final
2019 def op(self):
2020 return self.__op
2021
2022 def __init__(self, items, op):
2023 # type: (Iterable[_T], Op) -> None
2024 super().__init__()
2025 self.__op = op
2026 self.__items = [] # type: list[_T]
2027 for idx, item in enumerate(items):
2028 if idx >= len(self.descriptors):
2029 raise ValueError("too many items")
2030 _ = self._verify_write(idx, item)
2031 self.__items.append(item)
2032 if len(self.__items) < len(self.descriptors):
2033 raise ValueError("not enough items")
2034
2035 @final
2036 def __iter__(self):
2037 # type: () -> Iterator[_T]
2038 yield from self.__items
2039
2040 @overload
2041 def __getitem__(self, idx):
2042 # type: (int) -> _T
2043 ...
2044
2045 @overload
2046 def __getitem__(self, idx):
2047 # type: (slice) -> list[_T]
2048 ...
2049
2050 @final
2051 def __getitem__(self, idx):
2052 # type: (int | slice) -> _T | list[_T]
2053 return self.__items[idx]
2054
2055 @final
2056 def __setitem__(self, idx, item):
2057 # type: (int, _T) -> None
2058 idx = self._verify_write(idx, item)
2059 self.__items[idx] = item
2060
2061 @final
2062 def __len__(self):
2063 # type: () -> int
2064 return len(self.__items)
2065
2066 def __repr__(self):
2067 # type: () -> str
2068 return f"{self.__class__.__name__}({self.__items}, op=...)"
2069
2070
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the input `SSAVal`s of an `Op`, one per entry of
    `op.properties.inputs`.

    Every value written must be an `SSAVal` whose `ty` equals the `ty` of the
    corresponding input descriptor.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ raise `TypeError` if `item` isn't an `SSAVal`, `ValueError` if
        its type doesn't match the descriptor's type.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ create the input sequence for `op`.

        Raises `ValueError` if `op` already has its `input_vals` set.
        """
        # `Op` stores this sequence in its `input_vals` slot (there is no
        # `inputs` attribute on `Op`), so that is the attribute that has to
        # be checked for the guard to ever trigger -- this mirrors the
        # `immediates` check in `OpImmediates`.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
2094
2095
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`, one per entry of
    `op.properties.immediates`.

    Each descriptor is a `range`; every value written must be an `int`
    contained in the corresponding range.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """ create the immediates sequence for `op`.

        Raises `ValueError` if `op` already has its `immediates` set.
        """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ raise `TypeError` if `item` isn't an `int`, `ValueError` if it
        isn't in the range `desc`.
        """
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
2114
2115
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ a single operation belonging to a `Fn`.

    Holds the `OpProperties` describing the operation, the input `SSAVal`s
    and immediates it consumes, one `SSAUse` per input, one `SSAVal` per
    output, and the name handed out by the owning `Fn`.  Ops compare and hash
    by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        self.input_uses = tuple(
            SSAUse(self, idx) for idx in range(len(self.properties.inputs)))
        self.immediates = OpImmediates(immediates, op=self)
        self.outputs = tuple(
            SSAVal(self, idx) for idx in range(len(self.properties.outputs)))
        # the name is assigned last, `fn` picks a name that is not yet in use
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ shorthand for `self.properties.kind`. """
        properties = self.properties
        return properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ format as `name:` on the first line followed by
        `(outputs) <= Kind(inputs, immediates)`, wrapped so lines are at most
        `wrap_width` characters where a wrap point allows it.  Every line
        after the first starts with `indent`, wrapped continuation lines get
        one more `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        pieces = [f"{self.name}:\n"]

        last_output_idx = len(self.outputs) - 1
        for i, out in enumerate(self.outputs):
            prefix = "(" + WRAP_POINT if i == 0 else ""
            if i == last_output_idx:
                suffix = WRAP_POINT + ") <= "
            else:
                suffix = ", " + WRAP_POINT
            pieces.append(f"{prefix}<...outputs[{i}]: {out.ty}>{suffix}")

        input_count = len(self.input_vals)
        immediate_count = len(self.immediates)
        head = self.kind._name_
        if input_count + immediate_count != 0:
            head += "("
        head += WRAP_POINT
        pieces.append(head)

        for i, inp in enumerate(self.input_vals):
            # the closing paren goes after the last input only when no
            # immediates follow
            closes = i == input_count - 1 and immediate_count == 0
            terminator = ") " if closes else ", "
            pieces.append(repr(inp) + terminator + WRAP_POINT)
        for i, imm in enumerate(self.immediates):
            terminator = ", " if i != immediate_count - 1 else ") "
            pieces.append(hex(imm) + terminator + WRAP_POINT)

        wrapped = []  # type: list[str]
        for line_num, text in enumerate("".join(pieces).splitlines()):
            if line_num != 0:
                text = indent + text
            current = ""
            for segment in text.split(WRAP_POINT):
                if current == "":
                    current = segment
                    continue
                candidate = current + segment
                if len(candidate.rstrip()) > wrap_width:
                    # doesn't fit: emit what we have and start a
                    # continuation line
                    wrapped.append(current.rstrip())
                    current = indent + segment
                else:
                    current = candidate
            wrapped.append(current.rstrip())
        return "\n".join(wrapped)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this Op on `state`.

        Checks that all inputs are assigned with the right number of
        elements, that (for `PreRASimState`) outputs are not already assigned
        (except for `OpKind.FuncArgR3`), runs `self.kind.sim`, then checks
        that all outputs were assigned with the right number of elements.
        Values for which `state` raises `SimSkipOp` are not checked; if
        `self.kind.sim` raises `SimSkipOp`, `state.on_skip(self)` is called.

        Raises `ValueError` when any of the checks fail.
        """
        def check_len(ssa_val, val):
            # type: (SSAVal, tuple[int, ...]) -> None
            if len(val) != ssa_val.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {ssa_val.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            check_len(inp, val)

        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out not in state.ssa_vals:
                    continue
                if self.kind is not OpKind.FuncArgR3:
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")

        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)

        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            except SimSkipOp:
                continue
            check_len(out, val)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for this Op using `self.kind.gen_asm`.

        Before doing so, `state.loc` is called on every input and output,
        accepting any `LocKind`.
        """
        all_loc_kinds = tuple(LocKind)
        for ssa_val in (*self.input_vals, *self.outputs):
            state.loc(ssa_val, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
2245
2246
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ base class for simulator states.

    Provides a byte-addressed `memory` dict plus multi-byte `load`/`store`
    helpers; subclasses supply the `SSAVal` storage through `__getitem__` and
    `__setitem__`.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value returned by `load_byte` for addresses not in `memory`. """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when simulating `op` was skipped; not
        supported by default.
        """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        """ read the byte at `addr` (wrapped to `GPR_VALUE_MASK`). """
        wrapped_addr = addr & GPR_VALUE_MASK
        try:
            stored = self.memory[wrapped_addr]
        except KeyError:
            return self._default_memory_value()
        return stored & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value` at `addr` (wrapped to
        `GPR_VALUE_MASK`).
        """
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read `size_in_bytes` bytes starting at `addr`, least-significant
        byte first, optionally sign-extending the result.

        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        size_in_bits = size_in_bytes * BITS_IN_BYTE
        retval = 0
        for byte_idx in range(size_in_bytes):
            byte = self.load_byte(addr + byte_idx)
            retval |= byte << (byte_idx * BITS_IN_BYTE)
        if signed:
            sign_bit_set = (retval >> (size_in_bits - 1)) != 0
            if sign_bit_set:
                retval -= 1 << size_in_bits
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value` starting at
        `addr`, least-significant byte first.

        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            shifted = value >> (byte_idx * BITS_IN_BYTE)
            self.store_byte(addr + byte_idx, shifted & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory` in increasing address order.

        Runs of `GPR_SIZE_IN_BYTES` consecutive addresses that start at an
        aligned address are merged into a single `<0x...>` entry, everything
        else is shown byte by byte.
        """
        if len(self.memory) == 0:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        addrs = sorted(self.memory.keys())
        items = []  # type: list[str]
        pos = 0
        while pos < len(addrs):
            addr = addrs[pos]
            is_full_chunk = (
                addr % CHUNK_SIZE == 0
                and addrs[pos:pos + CHUNK_SIZE]
                == list(range(addr, addr + CHUNK_SIZE)))
            if is_full_chunk:
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                pos += CHUNK_SIZE
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                pos += 1
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def __repr__(self):
        # type: () -> str
        """ repr listing every field; a field `x` is formatted using the
        method `_x__repr` when one exists, otherwise using `repr`.
        """
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            text = repr_fn() if callable(repr_fn) else repr(value)
            field_vals.append(f"{name}={text}")
        class_name = self.__class__.__name__
        return f"{class_name}({', '.join(field_vals)})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        ...
2346 ...
2347
2348
@plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores values directly per `SSAVal` in the
    `ssa_vals` dict.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals`; each value is shown as a tuple of hex
        numbers, broken into lines of `CHUNK_SIZE` elements when it has more
        than `CHUNK_SIZE` elements.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        CHUNK_SIZE = 4
        items = []  # type: list[str]
        for ssa_val, elements in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(elements):
                starts_chunk = i % CHUNK_SIZE == 0
                separator = "\n " if starts_chunk else " "
                element_strs.append(separator + hex(el))
            if len(element_strs) <= CHUNK_SIZE:
                # fits on one line: drop the leading line break
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # produces the trailing comma of a 1-tuple
                element_strs.append("")
            joined = ",".join(element_strs)
            items.append(f"{ssa_val!r}: ({joined})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        try:
            value = self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)
        return value

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` when `ssa_val` has no value; raises
        `KeyError` by default.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ store `value`, converted to a tuple of ints; raises `ValueError`
        unless it has `ssa_val.ty.reg_len` elements.
        """
        elements = tuple(int(el) for el in value)
        if len(elements) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = elements
2398 self.ssa_vals[ssa_val] = value
2399
2400
class SimSkipOp(Exception):
    """ raised while simulating to signal that a value is unavailable and
    the Op being simulated should be skipped; `Op.sim` catches it.
    """
2403
2404
@plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state where unknown memory and unassigned `SSAVal`s raise
    `SimSkipOp` instead of producing a value, and where skipped Ops are
    collected in `skipped_ops`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def on_skip(self, op):
        # type: (Op) -> None
        """ record `op` in `skipped_ops`. """
        skipped = self.skipped_ops
        skipped.add(op)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        # reading an unassigned SSAVal skips instead of being an error
        raise SimSkipOp

    def _default_memory_value(self):
        # type: () -> int
        # reading memory that was never written skips instead of giving 0
        raise SimSkipOp
2425 self.skipped_ops.add(op)
2426
2427
@plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ `PreRABaseSimState` that can additionally be registered as the
    current state for debugging.
    """
    __slots__ = ()

    # stack of the states registered by `set_as_current_debugging_state`,
    # innermost last
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        # type: () -> Iterator[None]
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        stack.append(self)
        try:
            yield
        finally:
            # pop outside of the assert: assert statements are removed when
            # running with `python -O`, which would leave self on the stack
            popped = stack.pop()
            assert popped is self, "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises `ValueError` if no state is currently set.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        if len(stack) == 0:
            raise ValueError("no current debugging state")
        return stack[-1]
2460 return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2461
2462
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state for after register allocation.

    Values are stored per `Loc` in `loc_values` (only `Loc`s with
    `reg_len == 1`); `ssa_val_to_loc_map` gives the `Loc` assigned to each
    `SSAVal`, so reading/writing an `SSAVal` reads/writes the sub-`Loc`s of
    its assigned `Loc`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ raises `ValueError` if any `SSAVal`'s type differs from its
        `Loc`'s type, or if `loc_values` has a key with `reg_len != 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ repr of `loc_values`, sorted by `(kind.name, start)`. """
        if len(self.loc_values) == 0:
            # match the other `_*__repr` helpers instead of producing a
            # lone comma between braces
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ read the value of every sub-`Loc` of `ssa_val`'s `Loc`;
        sub-`Loc`s that have no entry in `loc_values` read as 0.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ write `value` (converted to ints) element by element to the
        sub-`Loc`s of `ssa_val`'s `Loc`; raises `ValueError` unless it has
        `ssa_val.ty.reg_len` elements.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2512 self.loc_values[subloc] = value[i]
2513
2514
@plain_data(frozen=True)
class GenAsmState:
    """ state used while generating assembly: the `Loc` allocated to each
    `SSAVal` and the output that assembly lines are written to.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """ `output` defaults to a new empty list.

        Raises `ValueError` if any `SSAVal`'s type differs from the type of
        its allocated `Loc`.
        """
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ get the single `Loc` covering `ssa_val_or_locs`.

        `SSAVal`s are replaced by their allocated `Loc`s, then all `Loc`s are
        combined using `Loc.try_concat`.

        Raises `ValueError` if the sequence is empty, if `try_concat` fails,
        or if the resulting `Loc`'s kind is not one of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[item] if isinstance(item, SSAVal) else item
            for item in ssa_val_or_locs
        ]  # type: list[Loc]
        if len(locs) == 0:
            raise ValueError("invalid Loc sequence: must not be empty")
        first_loc = locs[0]
        remaining_locs = locs[1:]
        combined = first_loc.try_concat(*remaining_locs)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = expected_kinds,
        if combined.kind in expected_kinds:
            return combined
        if len(expected_kinds) == 1:
            # show a lone expected kind without the tuple around it
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ format a `LocKind.GPR` location as its `start`, prefixed with
        `*` when `is_vec` is set.
        """
        gpr_loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        prefix = "*" if is_vec else ""
        return prefix + str(gpr_loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=False`. """
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=True`. """
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ format a `LocKind.StackI64` location as `<start>(1)`. """
        stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{stack_loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ join `line_segments` with spaces and emit them as one line:
        appended to `output` when it is a list, otherwise written to it
        followed by a newline.
        """
        line = " ".join(line_segments)
        if not isinstance(self.output, list):
            self.output.write(line + "\n")
            return
        self.output.append(line)