working on adding MergedSSAVal.__mergable_check that ensures copy
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 from abc import ABCMeta, abstractmethod
4 from enum import Enum, unique
5 from functools import lru_cache, total_ordering
6 from io import StringIO
7 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
8 Mapping, Sequence, TypeVar, Union, overload)
9 from weakref import WeakValueDictionary as _WeakVDict
10
11 from cached_property import cached_property
12 from nmutil.plain_data import fields, plain_data
13
14 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
15 final)
16 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
17 OFSet, OSet)
18
# width of one general-purpose register, expressed in bytes and in bits
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask covering exactly GPR_SIZE_IN_BITS bits
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
23
24
@final
class Fn:
    """ an ordered list of `Op`s, plus the bookkeeping used to hand out a
    unique name to every `Op` that belongs to it.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak map from every name already handed out to the Op owning it
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix to try next when a requested name can't be used
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ register `op` under a name that isn't in use yet and return that
        name. `name` is used as-is when it is non-empty and free, otherwise
        numeric suffixes are appended to it until a free name is found.

        raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ return the `repr` of every op, one per line.

        if `as_python_literal` is set, the text is instead returned as the
        source of a sequence of adjacent Python string literals, one per
        line of text, each prefixed by `python_indent`.
        """
        text = "\n".join(
            op.__repr__(wrap_width=wrap_width, indent=indent)
            for op in self.ops)
        if not as_python_literal:
            return text

        def escape(ch):
            # type: (str) -> str
            if ch == "\n":
                # end the current literal and start a new one on the next line
                return f"\\n\"\n{python_indent}\""
            if ch in "\"\\":
                return "\\" + ch
            if ch.isascii() and ch.isprintable():
                return ch
            # let repr() produce the escape sequence, then drop its quotes
            return repr(ch).strip("\"'")

        pieces = [python_indent + "\""]
        pieces.extend(map(escape, text))
        pieces.append("\"")
        literal = "".join(pieces)
        # drop a trailing empty literal left behind by a final newline
        empty_end = f"\"\n{python_indent}\"\""
        if literal.endswith(empty_end):
            literal = literal[:-len(empty_end)]
        return literal

    def append_op(self, op):
        # type: (Op) -> None
        """ add `op` to the end of `self.ops`.

        raises ValueError if `op` belongs to a different `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new `Op` of the given kind, append it and return it """
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties,
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ run `Op.sim` on every op, in order """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ run `Op.gen_asm` on every op, in order """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops` so that every I64 input of every original op
        is read through a freshly inserted copy op, and every I64 output is
        followed by a copy op whose result replaces it for all later ops.

        CA and VL_MAXVL values are passed through without copying, except
        that a `SetVLI` is re-emitted right before each op using its output.
        """
        old_ops = list(self.ops)
        # original output -> value that later ops must read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # output of a SetVLI op -> that SetVLI op
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in old_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                base_ty = inp.ty.base_ty
                if base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if maxvl == 1:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    else:
                        # vector copies need VL set up first
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the ops using VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                base_ty = out.ty.base_ty
                if base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if maxvl == 1:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(base_ty)
162
163
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the stage within a single Op's execution at which a read or write
    happens; `Early` orders before `Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        int_value = int(value)
        if int_value not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = int_value
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return f"OpStage.{self._name_}"

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # only stages are comparable to each other, ordered by their value
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


# the rest of the file relies on this ordering when comparing program points
assert OpStage.Early < OpStage.Late, "early must be less than late"
210
211
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """ a position in a function: the index of an Op in the op list together
    with the `OpStage` within that Op. ordered by op index, then by stage.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ `self` encoded as a single integer, such that ordering and
        successor/predecessor relations are kept.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions before `self` """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            # lexicographic: op index first, stage breaks ties
            return (self.op_index, self.stage) \
                < (other.op_index, other.stage)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
256
257
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ the sequence of consecutive `ProgramPoint`s from `start` (inclusive)
    to `stop` (exclusive).
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `range` of `ProgramPoint.int_value`s covered by `self` """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; raises ValueError unless the step
        of the given range is 1.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        first = ProgramPoint.from_int_value(int_value_range.start)
        last = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=first, stop=last)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # index/slice the integer encoding, then convert the result back
        int_values = range(self.start.int_value, self.stop.int_value)
        picked = int_values[__idx]
        if isinstance(picked, int):
            return ProgramPoint.from_int_value(picked)
        return ProgramRange.from_int_value_range(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return (ProgramPoint.from_int_value(i) for i in self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the endpoints' reprs, minus their surrounding angle brackets
        start, stop = (repr(point).lstrip("<").rstrip(">")
                       for point in (self.start, self.stop))
        return f"<range:{start}..{stop}>"
312
313
@plain_data(frozen=True, unsafe_hash=True)
@final
class SSAValSubReg(metaclass=InternedMeta):
    """ a single register-sized piece of an `SSAVal`, picked by `reg_idx` """
    __slots__ = "ssa_val", "reg_idx"

    def __init__(self, ssa_val, reg_idx):
        # type: (SSAVal, int) -> None
        # reg_idx must index into the registers covered by ssa_val's type
        if not 0 <= reg_idx < ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")
        self.ssa_val = ssa_val
        self.reg_idx = reg_idx
325
326
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ analysis results computed once from a `Fn`: the uses of every
    `SSAVal`, where each value is defined and used, and which values are live
    at every program point. equality and hashing are delegated to `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        first_point = ProgramPoint(op_index=0, stage=OpStage.Early)
        end_point = ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)
        self.all_program_points = ProgramRange(
            start=first_point, stop=end_point)
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of the live range of each value seen so far
        stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                use_point = self.__get_use_program_point(use)
                use_points[use] = use_point
                # a value stays live through the point where it is read
                stops[used_val] = max(stops[used_val], use_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                stops[out] = def_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for point in live_range:
                live_at[point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
        self.copies  # initialize
        self.const_ssa_vals  # initialize
        self.const_ssa_val_sub_regs  # initialize

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the program points at which the defining op writes `ssa_val` """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        late = ProgramPoint(op_index=start.op_index, stage=OpStage.Late)
        return ProgramRange(start=start, stop=late.next())

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the program point at which `ssa_use` reads its value """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        op_index = self.op_indexes[ssa_use.op]
        return ProgramPoint(op_index=op_index, stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"

    @cached_property
    def copies(self):
        # type: () -> FMap[SSAValSubReg, SSAValSubReg]
        """ map from SSAValSubRegs to the original SSAValSubRegs that they are
        a copy of, looking through all layers of copies. The map excludes all
        SSAValSubRegs that aren't copies of other SSAValSubRegs.
        """
        originals = {}  # type: dict[SSAValSubReg, SSAValSubReg]
        for op in self.op_indexes.keys():
            props = op.properties
            if not props.is_copy:
                continue
            copy_reg_len = props.copy_reg_len
            # propagate copies of copies by looking up each input sub-reg in
            # the map built so far
            sources = [
                originals.get(sub_reg, sub_reg)
                for inp in op.input_vals[:props.copy_inputs_len]
                for sub_reg in inp.ssa_val_sub_regs
            ]  # type: list[SSAValSubReg]
            assert len(sources) == copy_reg_len, "logic error"
            dests = [
                sub_reg
                for out in op.outputs[:props.copy_outputs_len]
                for sub_reg in out.ssa_val_sub_regs
            ]  # type: list[SSAValSubReg]
            assert len(dests) == copy_reg_len, "logic error"
            originals.update(zip(dests, sources))
        return FMap(originals)

    @cached_property
    def const_ssa_vals(self):
        # type: () -> FMap[SSAVal, tuple[int, ...]]
        """ the `ssa_vals` left in a fresh `ConstPropagationState` after
        simulating `self.fn` with it.
        """
        state = ConstPropagationState(
            ssa_vals={}, memory={}, skipped_ops=OSet())
        self.fn.sim(state)
        return FMap(state.ssa_vals)

    @cached_property
    def const_ssa_val_sub_regs(self):
        # type: () -> FMap[SSAValSubReg, int]
        """ `const_ssa_vals` split up into one entry per sub-register """
        per_sub_reg = {}  # type: dict[SSAValSubReg, int]
        for ssa_val, const_val in self.const_ssa_vals.items():
            assert ssa_val.ty.reg_len == len(const_val), "logic error"
            for reg_idx, v in enumerate(const_val):
                per_sub_reg[SSAValSubReg(ssa_val, reg_idx)] = v
        return FMap(per_sub_reg)

    def are_always_equal(self, a, b):
        # type: (SSAValSubReg, SSAValSubReg) -> bool
        """check if a and b are known to be always equal to each other.
        This means they can be allocated to the same location if other
        constraints don't prevent that.

        this can happen for a number of reasons, such as:
        * a and b are copies of the same thing
        * a and b are known to be constants and they have the same value
        """
        if a.ssa_val.base_ty != b.ssa_val.base_ty:
            return False  # can't be equal, they have different types
        # look through copies
        a = self.copies.get(a, a)
        b = self.copies.get(b, b)
        if a == b:
            return True
        # check if they have the same constant value
        consts = self.const_ssa_val_sub_regs
        try:
            a_const_val = consts[a]
            b_const_val = consts[b]
        except KeyError:
            # at least one of them isn't a known constant
            return False
        return a_const_val == b_const_val
477
478
479 @unique
480 @final
481 class BaseTy(Enum):
482 I64 = enum.auto()
483 CA = enum.auto()
484 VL_MAXVL = enum.auto()
485
486 @cached_property
487 def only_scalar(self):
488 # type: () -> bool
489 if self is BaseTy.I64:
490 return False
491 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
492 return True
493 else:
494 assert_never(self)
495
496 @cached_property
497 def max_reg_len(self):
498 # type: () -> int
499 if self is BaseTy.I64:
500 return 128
501 elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
502 return 1
503 else:
504 assert_never(self)
505
506 def __repr__(self):
507 return "BaseTy." + self._name_
508
509
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """ a concrete type: a `BaseTy` together with the number of registers
    (`reg_len`) a value of the type takes up.
    """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars are shown without the `*reg_len` suffix
        reg_len = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{reg_len}>"
542
543
@unique
@final
class LocKind(Enum):
    """ the kinds of location a value can be placed in """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of values stored in locations of this kind """
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ how many locations of this kind exist """
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        return "LocKind." + self._name_
577
578
@final
@unique
class LocSubKind(Enum):
    """ a refinement of `LocKind`: every sub-kind belongs to one `LocKind`
    and allows only some of that kind's locations (see `allocatable_locs`).
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` that this sub-kind belongs to """
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        gpr_sub_kinds = (
            LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
            LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
            LocSubKind.SV_EXTRA3_SGPR)
        if self in gpr_sub_kinds:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ all `Loc`s of type `ty` that this sub-kind allows, leaving out
        every `Loc` that conflicts with one of `SPECIAL_GPRS`.

        raises ValueError if `ty`'s base type doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        # pick the candidate start indexes for this sub-kind
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        elif self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        allowed = []  # type: list[Loc]
        for start in starts:
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                # doesn't fit at this start index
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            allowed.append(loc)
        return LocSet(allowed)

    def __repr__(self):
        return "LocSubKind." + self._name_
648
649
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """ a type whose register count isn't fixed yet: a `BaseTy` plus whether
    it is a vector. `instantiate` turns it into a concrete `Ty`.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if is_vec and base_ty.only_scalar:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for the given `maxvl`: vectors get `maxvl`
        registers, scalars always get one.
        """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ check if some `maxvl` could make `instantiate` produce a type
        with `ty`'s base type and register count.
        """
        if self.base_ty != ty.base_ty:
            return False
        # vectors accept any length, scalars only a single register
        return True if self.is_vec else ty.reg_len == 1
676
677
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """ a contiguous run of `reg_len` locations of a single `LocKind`,
    beginning at index `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like calling `Loc(...)`, except that None is returned instead of
        raising when the arguments don't pass `validate`.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ check if `self` and `other` have the same kind and overlap """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ the index just past the last location covered by `self` """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ join `self` and `others` into one `Loc` covering all of them.

        returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the preceding `Loc` stops, or if the
        joined `Loc` would not pass `validate`.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the pieces are individually valid, but their total length can still
        # exceed the type's max_reg_len (e.g. adjacent stack slots), so use
        # try_make rather than letting the constructor raise
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the `Loc` of type `subloc_ty` that starts `offset` locations
        into `self`.

        raises ValueError if the base types differ or the sub-Loc wouldn't
        lie entirely inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
754
755
# single-register GPR Locs that `LocSubKind.allocatable_locs` never hands out.
# NOTE(review): presumably these are the registers reserved by the ABI
# (r0, r1, r2 and r13) -- confirm against the target ABI.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13)
)
762
763
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """ an ordered frozen set of `Loc`s that all have the same `Ty`, with the
    start indexes of its members additionally indexed by `LocKind`.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # reuse the index already computed for the other set
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # kinds without any Locs are left out of the map
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ for every kind with at least one member, the set of the members'
        start indexes.
        """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """ the type shared by all members, None if the set is empty """
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ like `starts`, but holding the members' stop indexes """
        if self.ty is None:
            return FMap()
        # every member has the same reg_len, so stop = start + reg_len is
        # just a shift of the whole bit set
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the set of all `Loc`s that can be formed by picking one `Loc`
        from `self` and one from each of `others`, such that each picked
        `Loc` has the same kind as, and starts right where, the previously
        picked one stops.

        returns an empty `LocSet` if any of the sets is empty, if the base
        types differ, or if no such combination exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        # start indexes (by kind) of the combinations that are still possible
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot, since kinds get removed along the way
            for kind in list(starts):
                other_starts = other.starts.get(kind)
                if other_starts is None:
                    # `other` has no Locs of this kind at all, so nothing of
                    # this kind can be continued
                    del starts[kind]
                    continue
                # a combination starting at `s` needs `other` to have a Loc
                # starting at `s + reg_len`
                starts[kind].bits &= other_starts.bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with
        """
        if isinstance(other, LocSet):
            return max(self.max_conflicts_with(i) for i in other)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
867
868
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor"""
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ raises ValueError if `sub_kinds` is empty or doesn't fit `ty`,
        if `fixed_loc` doesn't fit `ty` and the single sub-kind, if
        `tied_input_index` is negative, or if more than one of fixed, tied
        and spread is requested, or spread is combined with a vector type.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            self.__check_fixed_loc(ty, fixed_loc, tied_input_index)
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    def __check_fixed_loc(self, ty, fixed_loc, tied_input_index):
        # type: (GenericTy, Loc, int | None) -> None
        """ the checks `__init__` makes when a `fixed_loc` is given; expects
        `self.sub_kinds` to be set already.
        """
        if tied_input_index is not None:
            raise ValueError("operand can't be both tied and fixed")
        if not ty.can_instantiate_to(fixed_loc.ty):
            raise ValueError(
                f"fixed_loc has incompatible type for given generic "
                f"type: fixed_loc={fixed_loc} generic ty={ty}")
        if len(self.sub_kinds) != 1:
            raise ValueError(
                "multiple sub_kinds not allowed for fixed operand")
        for sub_kind in self.sub_kinds:
            if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={sub_kind}")

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ `ty`, except that spread operands get the vector type with the
        same base type.
        """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a descriptor with `self`'s type, sub-kinds and write stage that
        is tied to the given input.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            tied_input_index=tied_input_index, write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a descriptor with `self`'s type, sub-kinds and write stage that
        is fixed to the given `Loc`.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=fixed_loc, write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with only the write stage replaced """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=self.fixed_loc,
            tied_input_index=self.tied_input_index,
            spread=self.spread,
            write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ yield the concrete `OperandDesc`s for the given `maxvl`: one for
        a non-spread operand, `maxvl` of them for a spread operand.
        """
        # assumes all spread operands have ty.reg_len = 1
        if self.spread:
            spread_indexes = range(maxvl)  # type: Iterable[int | None]
        else:
            spread_indexes = (None,)
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            fixed_loc = self.fixed_loc
            if fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty_before_spread)
                return
            if ty_before_spread != fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty_before_spread} "
                    f"fixed_loc: {self.fixed_loc}")
            yield fixed_loc

        loc_set_before_spread = LocSet(locs_before_spread())
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
976
977
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        """ raises ValueError if `loc_set_before_spread` is empty or if the
        operand is both tied and spread.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        is_tied = self.tied_input_index is not None
        if is_tied and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the type shared by all `Loc`s in `loc_set_before_spread` """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        spread_index = self.spread_index
        if spread_index is None:
            # not spread, so self's Loc is the whole unspread Loc
            return 0
        return spread_index * self.ty.reg_len
1026
1027
# Shared generic operand descriptors used by the OpKind definitions.
# Each one pairs a GenericTy (base type + scalar/vector flag) with the
# LocSubKinds the operand may be allocated in.

# scalar I64 operand allocatable in LocSubKind.BASE_GPR
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# scalar I64 operand allocatable in LocSubKind.SV_EXTRA3_SGPR
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector I64 operand allocatable in LocSubKind.SV_EXTRA3_VGPR
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 operand allocatable in LocSubKind.SV_EXTRA2_SGPR
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector I64 operand allocatable in LocSubKind.SV_EXTRA2_VGPR
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# scalar operand of base type CA, allocatable in LocSubKind.CA
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# scalar operand of base type VL_MAXVL, allocatable in LocSubKind.VL_MAXVL
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
1049
1050
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """Properties of an op kind that don't depend on `maxvl`.

    Holds the demo assembly text, the generic input/output operand
    descriptors, the allowed ranges of the immediates, and flags telling
    whether the op is a copy, is a load-immediate, or has side effects.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Validate and store the properties.

        Raises ValueError if:
        * an input has `tied_input_index` set or a `write_stage` other
          than `OpStage.Early`
        * an output's `tied_input_index` isn't a valid index into `inputs`
        * a tied output isn't equivalent to the input it is tied to
        * two outputs have conflicting `fixed_loc`s
        """
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # fixed locations seen so far, paired with the owning output's index
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                # also reject negative indexes: they would otherwise wrap
                # around and silently select an input counted from the end
                if not 0 <= out.tied_input_index < len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                # a tied output must be the tied input's descriptor, except
                # for being tied and for its write stage
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1101
1102
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """An OpKind's properties instantiated for a specific `maxvl`.

    `inputs` and `outputs` are the OperandDescs produced by instantiating
    every generic operand descriptor of the kind with `maxvl`.
    `copy_reg_len` is the total register length of the copied operands
    (0 for ops that aren't copies).
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        # `kind` must be stored first: `self.generic` reads it
        self.kind = kind  # type: OpKind
        generic = self.generic
        input_descs = [desc
                       for generic_inp in generic.inputs
                       for desc in generic_inp.instantiate(maxvl=maxvl)]
        self.inputs = tuple(input_descs)  # type: tuple[OperandDesc, ...]
        output_descs = [desc
                        for generic_out in generic.outputs
                        for desc in generic_out.instantiate(maxvl=maxvl)]
        self.outputs = tuple(output_descs)  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int
        # a copy must read exactly as many registers as it writes
        copy_input_reg_len = sum(
            inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
        copy_output_reg_len = sum(
            out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
        if copy_input_reg_len != copy_output_reg_len:
            raise ValueError(f"invalid copy: copy's input reg len must "
                             f"match its output reg len: "
                             f"{copy_input_reg_len} != {copy_output_reg_len}")
        self.copy_reg_len = copy_input_reg_len

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the maxvl-independent properties of `self.kind` """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """ forwarded from `self.generic` """
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """ forwarded from `self.generic` """
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """ forwarded from `self.generic` """
        return self.generic.has_side_effects

    @cached_property
    def copy_inputs_len(self):
        # type: () -> int
        """ number of leading inputs that are the source of the copy.

        0 if this isn't a copy, 1 if the first input isn't spread, otherwise
        the number of leading inputs whose spread_index equals their position.
        """
        if not self.is_copy:
            return 0
        operands = self.inputs
        if operands[0].spread_index is None:
            return 1
        count = 0
        while count < len(operands) \
                and operands[count].spread_index == count:
            count += 1
        return count

    @cached_property
    def copy_outputs_len(self):
        # type: () -> int
        """ number of leading outputs that are the destination of the copy.

        0 if this isn't a copy, 1 if the first output isn't spread, otherwise
        the number of leading outputs whose spread_index equals their
        position.
        """
        if not self.is_copy:
            return 0
        operands = self.outputs
        if operands[0].spread_index is None:
            return 1
        count = 0
        while count < len(operands) \
                and operands[count].spread_index == count:
            count += 1
        return count
1189
1190
# all values representable as a signed 16-bit immediate
IMM_S16 = range(-1 << 15, 1 << 15)

# a function that simulates a single Op against a simulator state
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-argument factory returning a _SIM_FN
_SIM_FN2 = Callable[[], _SIM_FN]
# registry of simulation-function factories, keyed by the op kind's
# GenericOpProperties; filled in by the body of OpKind
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# a function that generates the assembly for a single Op
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-argument factory returning a _GEN_ASM_FN
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry of assembly-generator factories, keyed by the op kind's
# GenericOpProperties; filled in by the body of OpKind
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1199
1200
@unique
@final
class OpKind(Enum):
    """All kinds of Op; each member's value is its GenericOpProperties.

    Every member is defined together with two private static methods: one
    that simulates the op and one that generates its assembly. They are
    registered in `_SIM_FNS` / `_GEN_ASMS` keyed by the member's
    GenericOpProperties and are retrieved through `sim` / `gen_asm`.

    The registered values are lambdas rather than the functions themselves
    because the name `OpKind` isn't bound yet while the class body runs; the
    lambdas are only called later, from `sim` / `gen_asm`.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the maxvl-independent properties of this op kind """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ get this kind's OpProperties for the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the function simulating ops of this kind """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the function generating assembly for ops of this kind """
        return _GEN_ASMS[self.properties]()

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ set the CA output to False """
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ set the CA output to True """
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ element-wise RA[i] + RB[i] + carry over VL elements.

        The carry out of each element is the carry in of the next one; the
        final carry is written to the second output.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ scalar RA + carry, producing the sum and the carry out """
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # both register operands are OD_BASE_SGPR (scalar), so they are
        # formatted with sgpr like in every other scalar op of this class
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ element-wise ~RA[i] + RB[i] + carry over VL elements.

        The carry out of each element is the carry in of the next one; the
        final carry is written to the second output.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ bitwise-and each of the VL elements of RA with the scalar RB """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ element-wise RA[i] * RB + carry over VL elements.

        The low GPR_SIZE_IN_BITS bits of each result go to RT, the remaining
        high bits are the carry into the next element; the final carry is
        written to the second output.
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    # NOTE(review): RT uses OD_EXTRA3_VGPR while all other register operands
    # of this op use the EXTRA2 descriptors -- confirm this is intended
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ arithmetic shift right of RS by the immediate.

        CA is set only if RS is negative and at least one 1-bit was shifted
        out, as specified for `sradi` by the Power ISA.
        """
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as a signed integer
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # compare using the signed result `v`: shifting it back only
        # reproduces `rs` if no 1-bits were lost
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ set the VL output to the immediate """
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ fill VL elements with the immediate (masked to GPR width) """
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ set the output to the immediate (masked to GPR width) """
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ copy the first input's value to the output unchanged """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ shared assembly generator for all copy ops.

        Writes nothing if `src_loc == dest_loc`. Otherwise writes a `ld`
        when copying from the stack to a GPR, a `std` when copying from a
        GPR to the stack, or an `or` when copying between GPRs; the
        mnemonic is prefixed with `sv.` if `is_vec` is set.

        Raises ValueError if either Loc's kind is neither GPR nor StackI64,
        or when asked to copy from stack to stack.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # the source overlaps the destination and starts before it
        # NOTE(review): /mrr presumably selects reverse element order so the
        # overlapping copy doesn't overwrite unread source elements -- confirm
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ copy the first input's value to the output unchanged """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ copy the input's value to the output unchanged """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ copy the input's value to the output unchanged """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ build the output from the first element of every input except
        the last one (the last input is VL)
        """
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ write each element of the first input to its own output """
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ load VL consecutive GPR-sized values starting at RA + imm """
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ load one value from address RA + imm """
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ store VL consecutive GPR-sized values starting at RA + imm """
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ store RS to address RA + imm """
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    # the output is fixed to GPR 3
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1826
1827
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """Common base of SSAVal and SSAUse: an operand slot of an Op.

    Identifies the slot by the owning `op` and by `operand_idx`, the index
    into `descriptor_array` (which the subclass picks).
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `op` must be stored first: `descriptor_array` is derived from it
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the operand descriptors that `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the descriptor of this operand slot """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ Ty before any spread is applied """
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ operand index of the first operand of the spread containing
        self, or self's own index if self isn't spread
        """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ like `unspread_start_idx`, but as an instance of self's class """
        start_idx = self.unspread_start_idx
        return self.__class__(op=self.op, operand_idx=start_idx)
1891
1892
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """An output of an Op; `operand_idx` indexes the op's outputs."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ the LocSet of the defining descriptor, before any spread """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of `self.op` """
        op_properties = self.op.properties
        return op_properties.outputs

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the input of `self.op` this output is tied to, None if not tied """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ the stage at which this output is written """
        descriptor = self.defining_descriptor
        return descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ get the current value for debugging in pdb or similar.

        This is intended for use with
        `PreRASimState.set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]

    @cached_property
    def ssa_val_sub_regs(self):
        # type: () -> tuple[SSAValSubReg, ...]
        """ one SSAValSubReg per register of this value, in order """
        sub_regs = [SSAValSubReg(self, reg_idx)
                    for reg_idx in range(self.ty.reg_len)]
        return tuple(sub_regs)
1942
1943
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """An input slot of an Op; `operand_idx` indexes the op's inputs."""
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of `self.op` """
        op_properties = self.op.properties
        return op_properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ the LocSet of the defining descriptor, before any spread """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the SSAVal currently in this input slot of `self.op`;
        assigning replaces it
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1972
1973
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of items belonging to an `Op`, with exactly
    one descriptor per item.

    Construction fails with `ValueError` unless the number of items equals
    `len(self.descriptors)`. Afterwards only single items can be replaced
    (never slices), and every write is first checked against the matching
    descriptor by `_verify_write_with_desc`.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        # `op` must be stored first: `descriptors` is computed from it
        self.__op = op
        self.__items = []  # type: list[_T]
        for idx, item in enumerate(items):
            if idx >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per item; called once and cached
        by `descriptors`.
        """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at the already-normalized
        index `idx`, whose descriptor is `desc`.
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` may be written at `idx` and return `idx`
        normalized to a non-negative index.

        Raises `TypeError` for slices and other non-`int` indexes and
        `IndexError` for out-of-range indexes.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        descriptors = self.descriptors
        # indexing a range normalizes negative indexes and raises
        # IndexError when idx is out of range
        idx = range(len(descriptors))[idx]
        self._verify_write_with_desc(idx, item, descriptors[idx])
        return idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        # NOTE(review): hook for subclasses; neither `__init__` nor
        # `__setitem__` of this class calls it -- confirm whether that
        # is intentional.
        pass

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        items = self.__items
        return items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        checked_idx = self._verify_write(idx, item)
        self.__items[checked_idx] = item

    @final
    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def __repr__(self):
        # type: () -> str
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.__items}, op=...)"
2065
2066
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the input `SSAVal`s of an `Op`, one per entry in
    `op.properties.inputs`.

    Every stored item must be an `SSAVal` whose type equals the type of
    the corresponding input descriptor.
    """

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` keeps this sequence in its `input_vals` attribute, so that
        # is the attribute that tells us whether `op` already has its
        # inputs. (`Op` has no `inputs` attribute, so checking for that
        # name could never detect a second construction.)
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ raise `TypeError` if `item` is not an `SSAVal`, or `ValueError`
        if its type differs from `desc.ty`.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
2090
2091
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`, one `int` per entry in
    `op.properties.immediates`; each entry is the `range` of values allowed
    for that immediate.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # an `Op` gets its immediates exactly once
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ raise `TypeError` if `item` is not an `int`, or `ValueError` if
        it is not a member of the range `desc`.
        """
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
2110
2111
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation of a `Fn`.

    Holds the `OpProperties`, the input `SSAVal`s together with their
    `SSAUse`s, the immediates and the output `SSAVal`s. `name` is assigned
    by `fn._add_op_with_unused_name`. Two `Op`s are equal only if they are
    the same object.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        input_count = len(self.properties.inputs)
        self.input_uses = tuple(
            SSAUse(self, idx) for idx in range(input_count))
        self.immediates = OpImmediates(immediates, op=self)
        output_count = len(self.properties.outputs)
        self.outputs = tuple(
            SSAVal(self, idx) for idx in range(output_count))
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        properties = self.properties
        return properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """ format as `name:` followed by
        `(outputs...) <= Kind(inputs..., immediates...)`.

        The text is first built with invisible wrap points between the
        pieces; a line is then broken at a wrap point whenever adding the
        next piece would make it longer than `wrap_width`. Every line but
        the first starts with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        pieces = [f"{self.name}:\n"]

        last_out = len(self.outputs) - 1
        for i, out in enumerate(self.outputs):
            opening = "(" + WRAP_POINT if i == 0 else ""
            if i != last_out:
                closing = ", " + WRAP_POINT
            else:
                closing = WRAP_POINT + ") <= "
            pieces.append(f"{opening}<...outputs[{i}]: {out.ty}>{closing}")

        has_args = len(self.input_vals) + len(self.immediates) != 0
        pieces.append(
            self.kind._name_ + ("(" if has_args else "") + WRAP_POINT)

        # the closing paren goes after the very last argument, which is an
        # input only when there are no immediates
        last_inp = len(self.input_vals) - 1
        for i, inp in enumerate(self.input_vals):
            if i != last_inp or len(self.immediates) != 0:
                separator = ", "
            else:
                separator = ") "
            pieces.append(repr(inp) + separator + WRAP_POINT)

        last_imm = len(self.immediates) - 1
        for i, imm in enumerate(self.immediates):
            separator = ", " if i != last_imm else ") "
            pieces.append(hex(imm) + separator + WRAP_POINT)

        lines = []  # type: list[str]
        for line_no, text in enumerate("".join(pieces).splitlines()):
            if line_no != 0:
                text = indent + text
            current = ""
            for part in text.split(WRAP_POINT):
                if current == "":
                    current = part
                    continue
                candidate = current + part
                if len(candidate.rstrip()) > wrap_width:
                    # `part` doesn't fit: finish the line, start a new one
                    lines.append(current.rstrip())
                    current = indent + part
                else:
                    current = candidate
            lines.append(current.rstrip())
        return "\n".join(lines)

    @staticmethod
    def _check_sim_val_len(ssa_val, val):
        # type: (SSAVal, tuple[int, ...]) -> None
        """ raise `ValueError` unless `val` has `ssa_val.ty.reg_len`
        elements.
        """
        if len(val) != ssa_val.ty.reg_len:
            raise ValueError(
                f"value of SSAVal {ssa_val} has wrong number of elements: "
                f"expected {ssa_val.ty.reg_len} found "
                f"{len(val)}: {val!r}")

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this `Op` on `state`.

        Checks that every input has a value of the right length, runs
        `self.kind.sim`, then checks that every output was assigned a
        value of the right length. Values whose lookup raises `SimSkipOp`
        are not checked; if `self.kind.sim` itself raises `SimSkipOp`,
        `state.on_skip(self)` is called instead.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            except SimSkipOp:
                continue
            self._check_sim_val_len(inp, val)

        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out not in state.ssa_vals:
                    continue
                if self.kind is OpKind.FuncArgR3:
                    # the one kind allowed to have its output pre-assigned
                    continue
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")

        try:
            self.kind.sim(self, state)
        except SimSkipOp:
            state.on_skip(self)

        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            except SimSkipOp:
                continue
            self._check_sim_val_len(out, val)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for this `Op` through `self.kind.gen_asm`,
        after looking up the `Loc` of every input and output (which raises
        if one can't be resolved).
        """
        any_kind = tuple(LocKind)
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=any_kind)
        for out in self.outputs:
            state.loc(out, expected_kinds=any_kind)
        self.kind.gen_asm(self, state)
2241
2242
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ base class of simulator states: byte-addressed `memory` plus
    abstract access to the values of `SSAVal`s.

    `memory` maps addresses to byte values. Multi-byte loads and stores
    are little-endian and must be aligned to their size.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def _default_memory_value(self):
        # type: () -> int
        """ value returned when loading a byte that is not in `memory`. """
        return 0

    def on_skip(self, op):
        # type: (Op) -> None
        """ called by `Op.sim` when simulating `op` raised `SimSkipOp`. """
        raise ValueError("skipping instructions not supported")

    def load_byte(self, addr):
        # type: (int) -> int
        addr &= GPR_VALUE_MASK
        try:
            byte = self.memory[addr]
        except KeyError:
            return self._default_memory_value()
        return byte & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ load a little-endian integer of `size_in_bytes` bytes from
        `addr`, sign-extending it if `signed`.

        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        bit_width = size_in_bytes * BITS_IN_BYTE
        retval = 0
        for byte_idx in range(size_in_bytes):
            shift = byte_idx * BITS_IN_BYTE
            retval |= self.load_byte(addr + byte_idx) << shift
        if signed and retval >> (bit_width - 1) != 0:
            # sign bit is set: reinterpret as a negative number
            retval -= 1 << bit_width
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ store the low `size_in_bytes` bytes of `value` at `addr` in
        little-endian order.

        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_idx in range(size_in_bytes):
            byte = (value >> byte_idx * BITS_IN_BYTE) & 0xFF
            self.store_byte(addr + byte_idx, byte)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory` in increasing address order, used by
        `__repr__`.

        An aligned run of `GPR_SIZE_IN_BYTES` consecutive addresses is
        shown as a single word in angle brackets, everything else is shown
        byte by byte.
        """
        if len(self.memory) == 0:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        # sorted in decreasing order so the lowest remaining address is
        # always at the end of the list
        pending = sorted(self.memory.keys(), reverse=True)
        entries = []  # type: list[str]
        while len(pending) != 0:
            addr = pending[-1]
            full_chunk = list(range(addr + CHUNK_SIZE - 1, addr - 1, -1))
            if (len(pending) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and pending[-CHUNK_SIZE:] == full_chunk):
                word = self.load(addr, size_in_bytes=CHUNK_SIZE)
                entries.append(
                    f"0x{addr:05x}: <0x{word:0{CHUNK_SIZE * 2}x}>")
                del pending[-CHUNK_SIZE:]
            else:
                pending.pop()
                byte = self.memory[addr]
                entries.append(f"0x{addr:05x}: 0x{byte:02x}")
        if len(entries) == 1:
            return "{" + entries[0] + "}"
        return "{\n" + ",\n".join(entries) + "}"

    def __repr__(self):
        # type: () -> str
        """ `ClassName(field=value, ...)` over the plain-data fields.

        A field named `x` is formatted by the method `_x__repr` when one
        exists, otherwise by `repr`. Fields that have not been assigned
        yet are shown as `<not set>`.
        """
        parts = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                parts.append(f"{name}=<not set>")
                continue
            custom_repr = getattr(self, f"_{name}__repr", None)
            if callable(custom_repr):
                text = custom_repr()
            else:
                text = repr(value)
            parts.append(f"{name}={text}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        ...
2343
2344
@plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores the value of each `SSAVal` directly in
    the `ssa_vals` dict.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals`, used by `__repr__`: element values are
        shown in hex, `CHUNK_SIZE` per line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        CHUNK_SIZE = 4
        entries = []  # type: list[str]
        for ssa_val, elements in self.ssa_vals.items():
            # the first element of each chunk starts a new line
            parts = [
                (" " if i % CHUNK_SIZE != 0 else "\n ") + hex(el)
                for i, el in enumerate(elements)]
            if len(parts) <= CHUNK_SIZE:
                # everything fits on one line: drop the leading newline
                parts[0] = parts[0].lstrip()
            if len(parts) == 1:
                # gives the trailing comma of a 1-tuple
                parts.append("")
            joined = ",".join(parts)
            entries.append(f"{ssa_val!r}: ({joined})")
        if len(entries) == 1 and "\n" not in entries[0]:
            return "{" + entries[0] + "}"
        return "{\n" + ",\n".join(entries) + ",\n}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ssa_vals = self.ssa_vals
        try:
            return ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` when `ssa_val` has no value; this
        implementation always raises `KeyError`.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        expected_len = ssa_val.ty.reg_len
        if len(value) != expected_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2394
2395
class SimSkipOp(Exception):
    """ raised while simulating to signal that a value is not available
    and the current `Op` should be skipped instead of simulated.
    """
2398
2399
@plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
    """ simulator state where unknown values are not an error: reading an
    `SSAVal` or a memory byte that has no value raises `SimSkipOp`, and
    every skipped `Op` is recorded in `skipped_ops`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        self.skipped_ops = skipped_ops

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        raise SimSkipOp()

    def _default_memory_value(self):
        # type: () -> int
        raise SimSkipOp()

    def on_skip(self, op):
        # type: (Op) -> None
        skipped = self.skipped_ops
        skipped.add(op)
2421
2422
@plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ `PreRABaseSimState` that can additionally be registered as the
    "current debugging state", a process-wide stack consulted by
    `get_current_debugging_state`.
    """
    __slots__ = ()

    # stack of the states currently registered for debugging, innermost
    # last; shared by all instances
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        stack.append(self)
        try:
            yield
        finally:
            # pop unconditionally: if the pop lived inside the `assert` it
            # would be skipped under `python -O`, leaving the stack
            # permanently unbalanced
            popped = stack.pop()
            assert popped is self, "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        Raises `ValueError` if no state is currently set.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        stack = PreRASimState.__CURRENT_DEBUGGING_STATE
        if len(stack) == 0:
            raise ValueError("no current debugging state")
        return stack[-1]
2456
2457
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state for after register allocation: every `SSAVal` is
    mapped to a `Loc` by `ssa_val_to_loc_map`, and values are stored per
    single-register `Loc` in `loc_values`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ raises `ValueError` if an `SSAVal`'s type differs from the type
        of its `Loc`, or if `loc_values` has a key with `reg_len != 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ repr of `loc_values`, used by `__repr__`: one `loc: 0xvalue`
        entry per line, sorted by kind name and then start.
        """
        if len(self.loc_values) == 0:
            # same as the other field reprs; without this an empty dict
            # would be rendered as a lone comma between braces
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    @staticmethod
    def _sub_locs(loc):
        # type: (Loc) -> Iterator[Loc]
        """ yield the `reg_len=1` sub-`Loc` of `loc` at each offset in
        `range(loc.reg_len)`.
        """
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for offset in range(loc.reg_len):
            yield loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=offset)

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ value of `ssa_val`, read element by element from the sub-`Loc`s
        of its `Loc`; sub-`Loc`s that have no stored value read as 0.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        return tuple(self.loc_values.get(subloc, 0)
                     for subloc in self._sub_locs(loc))

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """ store `value` element by element into the sub-`Loc`s of the
        `Loc` of `ssa_val`.

        Raises `ValueError` if `value` doesn't have `ssa_val.ty.reg_len`
        elements.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        for i, subloc in enumerate(self._sub_locs(loc)):
            self.loc_values[subloc] = value[i]
2507
2508
@plain_data(frozen=True)
class GenAsmState:
    """ state used while generating assembly: the `Loc` allocated to each
    `SSAVal` and the `output` that generated lines are written to.

    `output` is either a list that lines are appended to or a `StringIO`
    that newline-terminated lines are written to; a new list is used when
    none is given.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve one `SSAVal`/`Loc`, or a sequence of them, to a single
        `Loc`.

        `SSAVal`s are replaced by their allocated `Loc`, then all `Loc`s
        are combined with `try_concat`. Raises `ValueError` if the
        sequence is empty, if `try_concat` fails, or if the kind of the
        result is not one of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[item] if isinstance(item, SSAVal) else item
            for item in ssa_val_or_locs]  # type: list[Loc]
        if len(locs) == 0:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, rest = locs[0], locs[1:]
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = expected_kinds,
        if retval.kind in expected_kinds:
            return retval
        if len(expected_kinds) == 1:
            # a single kind is reported without the tuple around it
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ operand text for a GPR `Loc`: its start, prefixed with `*`
        when `is_vec`.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        prefix = "*" if is_vec else ""
        return prefix + str(loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ operand text for a `StackI64` `Loc`: its start followed by
        `(1)`.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ emit one output line made of `line_segments` joined by single
        spaces.
        """
        line = " ".join(line_segments)
        out = self.output
        if isinstance(out, list):
            out.append(line)
        else:
            out.write(line + "\n")