1 from contextlib
import contextmanager
4 from abc
import ABCMeta
, abstractmethod
5 from enum
import Enum
, unique
6 from functools
import lru_cache
, total_ordering
7 from io
import StringIO
8 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
9 Mapping
, Sequence
, TypeVar
, Union
, overload
)
10 from weakref
import WeakValueDictionary
as _WeakVDict
12 from cached_property
import cached_property
13 from nmutil
import plain_data
# type: ignore
15 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
17 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, Interned
,
18 OFSet
, OSet
, bit_count
)
# Width of a general-purpose register in bits, derived from the byte-size
# constants defined earlier in this file.
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
# Mask selecting exactly the low GPR_SIZE_IN_BITS bits of a value.
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
29 self
.ops
= [] # type: list[Op]
30 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
31 self
.__next
_name
_suffix
= 2
33 def _add_op_with_unused_name(self
, op
, name
=""):
34 # type: (Op, str) -> str
36 raise ValueError("can't add Op to wrong Fn")
37 if hasattr(op
, "name"):
38 raise ValueError("Op already named")
41 if name
!= "" and name
not in self
.__op
_names
:
42 self
.__op
_names
[name
] = op
44 name
= orig_name
+ str(self
.__next
_name
_suffix
)
45 self
.__next
_name
_suffix
+= 1
51 def ops_to_str(self
, as_python_literal
=False, wrap_width
=63,
52 python_indent
=" ", indent
=" "):
53 # type: (bool, int, str, str) -> str
54 l
= [] # type: list[str]
56 l
.append(op
.__repr
__(wrap_width
=wrap_width
, indent
=indent
))
59 l
= [python_indent
+ "\""]
62 l
.append(f
"\\n\"\n{python_indent}\"")
65 elif ch
.isascii() and ch
.isprintable():
68 l
.append(repr(ch
).strip("\"'"))
71 empty_end
= f
"\"\n{python_indent}\"\""
72 if retval
.endswith(empty_end
):
73 retval
= retval
[:-len(empty_end
)]
76 def append_op(self
, op
):
79 raise ValueError("can't add Op to wrong Fn")
82 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
84 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
85 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
86 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
87 self
.append_op(retval
)
91 # type: (BaseSimState) -> None
95 def gen_asm(self
, state
):
96 # type: (GenAsmState) -> None
100 def pre_ra_insert_copies(self
):
102 orig_ops
= list(self
.ops
)
103 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
104 setvli_outputs
= {} # type: dict[SSAVal, Op]
107 for i
in range(len(op
.input_vals
)):
108 inp
= copied_outputs
[op
.input_vals
[i
]]
109 if inp
.ty
.base_ty
is BaseTy
.I64
:
110 maxvl
= inp
.ty
.reg_len
111 if inp
.ty
.reg_len
!= 1:
112 setvl
= self
.append_new_op(
113 OpKind
.SetVLI
, immediates
=[maxvl
],
114 name
=f
"{op.name}.inp{i}.setvl")
115 vl
= setvl
.outputs
[0]
116 mv
= self
.append_new_op(
117 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
118 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
120 mv
= self
.append_new_op(
121 OpKind
.CopyToReg
, input_vals
=[inp
],
122 name
=f
"{op.name}.inp{i}.copy")
123 op
.input_vals
[i
] = mv
.outputs
[0]
124 elif inp
.ty
.base_ty
is BaseTy
.CA \
125 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
126 # all copies would be no-ops, so we don't need to copy,
127 # though we do need to rematerialize SetVLI ops right
129 if inp
in setvli_outputs
:
130 setvl
= self
.append_new_op(
132 immediates
=setvli_outputs
[inp
].immediates
,
133 name
=f
"{op.name}.inp{i}.setvl")
134 inp
= setvl
.outputs
[0]
135 op
.input_vals
[i
] = inp
137 assert_never(inp
.ty
.base_ty
)
139 for i
, out
in enumerate(op
.outputs
):
140 if op
.kind
is OpKind
.SetVLI
:
141 setvli_outputs
[out
] = op
142 if out
.ty
.base_ty
is BaseTy
.I64
:
143 maxvl
= out
.ty
.reg_len
144 if out
.ty
.reg_len
!= 1:
145 setvl
= self
.append_new_op(
146 OpKind
.SetVLI
, immediates
=[maxvl
],
147 name
=f
"{op.name}.out{i}.setvl")
148 vl
= setvl
.outputs
[0]
149 mv
= self
.append_new_op(
150 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
151 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
153 mv
= self
.append_new_op(
154 OpKind
.CopyFromReg
, input_vals
=[out
],
155 name
=f
"{op.name}.out{i}.copy")
156 copied_outputs
[out
] = mv
.outputs
[0]
157 elif out
.ty
.base_ty
is BaseTy
.CA \
158 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
159 # all copies would be no-ops, so we don't need to copy
160 copied_outputs
[out
] = out
162 assert_never(out
.ty
.base_ty
)
169 value
: Literal
[0, 1] # type: ignore
171 def __new__(cls
, value
):
172 # type: (int) -> OpStage
174 if value
not in (0, 1):
175 raise ValueError("invalid value")
176 retval
= object.__new
__(cls
)
177 retval
._value
_ = value
181 """ early stage of Op execution, where all input reads occur.
182 all output writes with `write_stage == Early` occur here too, and therefore
183 conflict with input reads, telling the compiler that it that can't share
184 that output's register with any inputs that the output isn't tied to.
186 All outputs, even unused outputs, can't share registers with any other
187 outputs, independent of `write_stage` settings.
190 """ late stage of Op execution, where all output writes with
191 `write_stage == Late` occur, and therefore don't conflict with input reads,
192 telling the compiler that any inputs can safely use the same register as
195 All outputs, even unused outputs, can't share registers with any other
196 outputs, independent of `write_stage` settings.
201 return f
"OpStage.{self._name_}"
203 def __lt__(self
, other
):
204 # type: (OpStage | object) -> bool
205 if isinstance(other
, OpStage
):
206 return self
.value
< other
.value
207 return NotImplemented
# sanity check: the comparison defined on OpStage must order the early
# stage strictly before the late stage, as liveness analysis relies on it
assert OpStage.Early < OpStage.Late, "early must be less than late"
213 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
215 class ProgramPoint(Interned
):
222 """ an integer representation of `self` such that it keeps ordering and
223 successor/predecessor relations.
225 return self
.op_index
* 2 + self
.stage
.value
228 def from_int_value(int_value
):
229 # type: (int) -> ProgramPoint
230 op_index
, stage
= divmod(int_value
, 2)
231 return ProgramPoint(op_index
=op_index
, stage
=OpStage(stage
))
233 def next(self
, steps
=1):
234 # type: (int) -> ProgramPoint
235 return ProgramPoint
.from_int_value(self
.int_value
+ steps
)
237 def prev(self
, steps
=1):
238 # type: (int) -> ProgramPoint
239 return self
.next(steps
=-steps
)
241 def __lt__(self
, other
):
242 # type: (ProgramPoint | Any) -> bool
243 if not isinstance(other
, ProgramPoint
):
244 return NotImplemented
245 if self
.op_index
!= other
.op_index
:
246 return self
.op_index
< other
.op_index
247 return self
.stage
< other
.stage
249 def __gt__(self
, other
):
250 # type: (ProgramPoint | Any) -> bool
251 if not isinstance(other
, ProgramPoint
):
252 return NotImplemented
253 return other
.__lt
__(self
)
255 def __le__(self
, other
):
256 # type: (ProgramPoint | Any) -> bool
257 if not isinstance(other
, ProgramPoint
):
258 return NotImplemented
259 return not self
.__gt
__(other
)
261 def __ge__(self
, other
):
262 # type: (ProgramPoint | Any) -> bool
263 if not isinstance(other
, ProgramPoint
):
264 return NotImplemented
265 return not self
.__lt
__(other
)
269 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
272 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
274 class ProgramRange(Sequence
[ProgramPoint
], Interned
):
279 def int_value_range(self
):
281 return range(self
.start
.int_value
, self
.stop
.int_value
)
284 def from_int_value_range(int_value_range
):
285 # type: (range) -> ProgramRange
286 if int_value_range
.step
!= 1:
287 raise ValueError("int_value_range must have step == 1")
289 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
290 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
293 def __getitem__(self
, __idx
):
294 # type: (int) -> ProgramPoint
298 def __getitem__(self
, __idx
):
299 # type: (slice) -> ProgramRange
302 def __getitem__(self
, __idx
):
303 # type: (int | slice) -> ProgramPoint | ProgramRange
304 v
= range(self
.start
.int_value
, self
.stop
.int_value
)[__idx
]
305 if isinstance(v
, int):
306 return ProgramPoint
.from_int_value(v
)
307 return ProgramRange
.from_int_value_range(v
)
311 return len(self
.int_value_range
)
314 # type: () -> Iterator[ProgramPoint]
315 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
319 start
= repr(self
.start
).lstrip("<").rstrip(">")
320 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
321 return f
"<range:{start}..{stop}>"
324 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
326 class SSAValSubReg(Interned
):
330 def __post_init__(self
):
331 if self
.reg_idx
< 0 or self
.reg_idx
>= self
.ssa_val
.ty
.reg_len
:
332 raise ValueError("reg_idx out of range")
336 return f
"{self.ssa_val}[{self.reg_idx}]"
339 @plain_data.plain_data(frozen
=True, eq
=False, repr=False)
342 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
343 "def_program_ranges", "use_program_points",
344 "all_program_points")
346 def __init__(self
, fn
):
349 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
350 self
.all_program_points
= ProgramRange(
351 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
352 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
353 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
354 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
355 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
356 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
358 for use
in op
.input_uses
:
359 uses
[use
.ssa_val
].add(use
)
360 use_program_point
= self
.__get
_use
_program
_point
(use
)
361 use_program_points
[use
] = use_program_point
362 live_range_stops
[use
.ssa_val
] = max(
363 live_range_stops
[use
.ssa_val
], use_program_point
.next())
364 for out
in op
.outputs
:
366 def_program_range
= self
.__get
_def
_program
_range
(out
)
367 def_program_ranges
[out
] = def_program_range
368 live_range_stops
[out
] = def_program_range
.stop
369 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
370 self
.def_program_ranges
= FMap(def_program_ranges
)
371 self
.use_program_points
= FMap(use_program_points
)
372 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
373 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
374 for ssa_val
in uses
.keys():
375 live_ranges
[ssa_val
] = live_range
= ProgramRange(
376 start
=self
.def_program_ranges
[ssa_val
].start
,
377 stop
=live_range_stops
[ssa_val
])
378 for program_point
in live_range
:
379 live_at
[program_point
].add(ssa_val
)
380 self
.live_ranges
= FMap(live_ranges
)
381 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
382 self
.copies
# initialize
383 self
.const_ssa_vals
# initialize
384 self
.const_ssa_val_sub_regs
# initialize
386 def __get_def_program_range(self
, ssa_val
):
387 # type: (SSAVal) -> ProgramRange
388 write_stage
= ssa_val
.defining_descriptor
.write_stage
389 start
= ProgramPoint(
390 op_index
=self
.op_indexes
[ssa_val
.op
], stage
=write_stage
)
391 # always include late stage of ssa_val.op, to ensure outputs always
392 # overlap all other outputs.
393 # stop is exclusive, so we need the next program point.
394 stop
= ProgramPoint(op_index
=start
.op_index
, stage
=OpStage
.Late
).next()
395 return ProgramRange(start
=start
, stop
=stop
)
397 def __get_use_program_point(self
, ssa_use
):
398 # type: (SSAUse) -> ProgramPoint
399 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
400 "assumed here, ensured by GenericOpProperties.__init__"
402 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
404 def __eq__(self
, other
):
405 # type: (FnAnalysis | Any) -> bool
406 if isinstance(other
, FnAnalysis
):
407 return self
.fn
== other
.fn
408 return NotImplemented
416 return "<FnAnalysis>"
420 # type: () -> FMap[SSAValSubReg, SSAValSubReg]
421 """ map from SSAValSubRegs to the original SSAValSubRegs that they are
422 a copy of, looking through all layers of copies. The map excludes all
423 SSAValSubRegs that aren't copies of other SSAValSubRegs.
424 This ignores inputs of copy Ops that aren't actually being copied
425 (e.g. the VL input of VecCopyToReg).
427 retval
= {} # type: dict[SSAValSubReg, SSAValSubReg]
428 for op
in self
.op_indexes
.keys():
429 if not op
.properties
.is_copy
:
431 copy_reg_len
= op
.properties
.copy_reg_len
432 copy_inputs
= [] # type: list[SSAValSubReg]
433 for inp
in op
.input_vals
[:op
.properties
.copy_inputs_len
]:
434 for inp_sub_reg
in inp
.ssa_val_sub_regs
:
435 # propagate copies of copies
436 inp_sub_reg
= retval
.get(inp_sub_reg
, inp_sub_reg
)
437 copy_inputs
.append(inp_sub_reg
)
438 assert len(copy_inputs
) == copy_reg_len
, "logic error"
439 copy_outputs
= [] # type: list[SSAValSubReg]
440 for out
in op
.outputs
[:op
.properties
.copy_outputs_len
]:
441 copy_outputs
.extend(out
.ssa_val_sub_regs
)
442 assert len(copy_outputs
) == copy_reg_len
, "logic error"
443 for inp
, out
in zip(copy_inputs
, copy_outputs
):
448 def copy_related_ssa_vals(self
):
449 # type: () -> FMap[SSAVal, OFSet[SSAVal]]
450 """ map from SSAVals to the full set of SSAVals that are related by
451 being sources/destinations of copies, transitively looking through all
453 This ignores inputs of copy Ops that aren't actually being copied
454 (e.g. the VL input of VecCopyToReg).
456 sets_map
= {i
: OSet([i
]) for i
in self
.uses
.keys()}
457 for k
, v
in self
.copies
.items():
458 k_set
= sets_map
[k
.ssa_val
]
459 v_set
= sets_map
[v
.ssa_val
]
460 # merge k_set and v_set
466 # this way we construct each OFSet only once rather than
468 sets_set
= {id(i
): i
for i
in sets_map
.values()}
469 retval
= {} # type: dict[SSAVal, OFSet[SSAVal]]
470 for v
in sets_set
.values():
477 def const_ssa_vals(self
):
478 # type: () -> FMap[SSAVal, tuple[int, ...]]
479 state
= ConstPropagationState(
480 ssa_vals
={}, memory
={}, skipped_ops
=OSet())
482 return FMap(state
.ssa_vals
)
485 def const_ssa_val_sub_regs(self
):
486 # type: () -> FMap[SSAValSubReg, int]
487 retval
= {} # type: dict[SSAValSubReg, int]
488 for ssa_val
, const_val
in self
.const_ssa_vals
.items():
489 assert ssa_val
.ty
.reg_len
== len(const_val
), "logic error"
490 for reg_idx
, v
in enumerate(const_val
):
491 retval
[SSAValSubReg(ssa_val
, reg_idx
)] = v
494 def is_always_equal(self
, a
, b
):
495 # type: (SSAValSubReg, SSAValSubReg) -> bool
496 """check if a and b are known to be always equal to each other.
497 This means they can be allocated to the same location if other
498 constraints don't prevent that.
500 this can happen for a number of reasons, such as:
501 * a and b are copies of the same thing
502 * a and b are known to be constants and they have the same value
504 if a
.ssa_val
.base_ty
!= b
.ssa_val
.base_ty
:
505 return False # can't be equal, they have different types
506 # look through copies
507 a
= self
.copies
.get(a
, a
)
508 b
= self
.copies
.get(b
, b
)
511 # check if they have the same constant value
513 a_const_val
= self
.const_ssa_val_sub_regs
[a
]
514 b_const_val
= self
.const_ssa_val_sub_regs
[b
]
515 if a_const_val
== b_const_val
:
527 VL_MAXVL
= enum
.auto()
530 def only_scalar(self
):
532 if self
is BaseTy
.I64
:
534 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
540 def max_reg_len(self
):
542 if self
is BaseTy
.I64
:
544 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
550 return "BaseTy." + self
._name
_
553 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
560 def validate(base_ty
, reg_len
):
561 # type: (BaseTy, int) -> str | None
562 """ return a string with the error if the combination is invalid,
563 otherwise return None
565 if base_ty
.only_scalar
and reg_len
!= 1:
566 return f
"can't create a vector of an only-scalar type: {base_ty}"
567 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
568 return "reg_len out of range"
571 def __post_init__(self
):
572 msg
= self
.validate(base_ty
=self
.base_ty
, reg_len
=self
.reg_len
)
574 raise ValueError(msg
)
578 if self
.reg_len
!= 1:
579 reg_len
= f
"*{self.reg_len}"
582 return f
"<{self.base_ty._name_}{reg_len}>"
589 StackI64
= enum
.auto()
591 VL_MAXVL
= enum
.auto()
596 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
598 if self
is LocKind
.CA
:
600 if self
is LocKind
.VL_MAXVL
:
601 return BaseTy
.VL_MAXVL
608 if self
is LocKind
.StackI64
:
610 if self
is LocKind
.GPR
or self
is LocKind
.CA \
611 or self
is LocKind
.VL_MAXVL
:
612 return self
.base_ty
.max_reg_len
617 return "LocKind." + self
._name
_
622 class LocSubKind(Enum
):
623 BASE_GPR
= enum
.auto()
624 SV_EXTRA2_VGPR
= enum
.auto()
625 SV_EXTRA2_SGPR
= enum
.auto()
626 SV_EXTRA3_VGPR
= enum
.auto()
627 SV_EXTRA3_SGPR
= enum
.auto()
628 StackI64
= enum
.auto()
630 VL_MAXVL
= enum
.auto()
634 # type: () -> LocKind
635 # pyright fails typechecking when using `in` here:
636 # reported: https://github.com/microsoft/pyright/issues/4102
637 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
638 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
639 LocSubKind
.SV_EXTRA3_SGPR
):
641 if self
is LocSubKind
.StackI64
:
642 return LocKind
.StackI64
643 if self
is LocSubKind
.CA
:
645 if self
is LocSubKind
.VL_MAXVL
:
646 return LocKind
.VL_MAXVL
651 return self
.kind
.base_ty
654 def allocatable_locs(self
, ty
):
655 # type: (Ty) -> LocSet
656 if ty
.base_ty
!= self
.base_ty
:
657 raise ValueError("type mismatch")
658 if self
is LocSubKind
.BASE_GPR
:
660 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
661 starts
= range(0, 128, 2)
662 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
664 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
665 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
667 elif self
is LocSubKind
.StackI64
:
668 starts
= range(LocKind
.StackI64
.loc_count
)
669 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
670 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
673 retval
= [] # type: list[Loc]
675 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
679 for special_loc
in SPECIAL_GPRS
:
680 if loc
.conflicts(special_loc
):
685 return LocSet(retval
)
688 return "LocSubKind." + self
._name
_
691 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True)
693 class GenericTy(Interned
):
697 def __post_init__(self
):
698 if self
.base_ty
.only_scalar
and self
.is_vec
:
699 raise ValueError(f
"base_ty={self.base_ty} requires is_vec=False")
701 def instantiate(self
, maxvl
):
703 # here's where subvl and elwid would be accounted for
705 return Ty(self
.base_ty
, maxvl
)
706 return Ty(self
.base_ty
, 1)
708 def can_instantiate_to(self
, ty
):
710 if self
.base_ty
!= ty
.base_ty
:
714 return ty
.reg_len
== 1
717 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True)
725 def validate(kind
, start
, reg_len
):
726 # type: (LocKind, int, int) -> str | None
727 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
730 if reg_len
> kind
.loc_count
:
731 return "invalid reg_len"
732 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
733 return "start not in valid range"
737 def try_make(kind
, start
, reg_len
):
738 # type: (LocKind, int, int) -> Loc | None
739 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
742 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
744 def __post_init__(self
):
745 msg
= self
.validate(kind
=self
.kind
, start
=self
.start
,
746 reg_len
=self
.reg_len
)
748 raise ValueError(msg
)
750 def conflicts(self
, other
):
751 # type: (Loc) -> bool
752 return (self
.kind
== other
.kind
753 and self
.start
< other
.stop
and other
.start
< self
.stop
)
756 def make_ty(kind
, reg_len
):
757 # type: (LocKind, int) -> Ty
758 return Ty(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
763 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
768 return self
.start
+ self
.reg_len
770 def try_concat(self
, *others
):
771 # type: (*Loc | None) -> Loc | None
772 reg_len
= self
.reg_len
775 if other
is None or other
.kind
!= self
.kind
:
777 if stop
!= other
.start
:
780 reg_len
+= other
.reg_len
781 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
783 def get_subloc_at_offset(self
, subloc_ty
, offset
):
784 # type: (Ty, int) -> Loc
785 if subloc_ty
.base_ty
!= self
.kind
.base_ty
:
786 raise ValueError("BaseTy mismatch")
787 if offset
< 0 or offset
+ subloc_ty
.reg_len
> self
.reg_len
:
788 raise ValueError("invalid sub-Loc: offset and/or "
789 "subloc_ty.reg_len out of range")
790 return Loc(kind
=self
.kind
,
791 start
=self
.start
+ offset
, reg_len
=subloc_ty
.reg_len
)
793 def get_superloc_with_self_at_offset(self
, superloc_ty
, offset
):
794 # type: (Ty, int) -> Loc
795 """get the Loc containing `self` such that:
796 `retval.get_subloc_at_offset(self.ty, offset) == self`
797 and `retval.ty == superloc_ty`
799 if superloc_ty
.base_ty
!= self
.kind
.base_ty
:
800 raise ValueError("BaseTy mismatch")
801 if offset
< 0 or offset
+ self
.reg_len
> superloc_ty
.reg_len
:
802 raise ValueError("invalid sub-Loc: offset and/or "
803 "superloc_ty.reg_len out of range")
804 return Loc(kind
=self
.kind
,
805 start
=self
.start
- offset
, reg_len
=superloc_ty
.reg_len
)
809 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
810 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
811 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
812 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
817 class LocSet(OFSet
[Loc
], Interned
):
818 def __init__(self
, __locs
=()):
819 # type: (Iterable[Loc]) -> None
820 super().__init
__(__locs
)
821 if isinstance(__locs
, LocSet
):
822 self
.__starts
= __locs
.starts
823 self
.__ty
= __locs
.ty
825 starts
= {i
: BitSet() for i
in LocKind
}
826 ty
= None # type: None | Ty
831 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
832 starts
[loc
.kind
].add(loc
.start
)
833 self
.__starts
= FMap(
834 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
839 # type: () -> FMap[LocKind, FBitSet]
844 # type: () -> Ty | None
849 # type: () -> FMap[LocKind, FBitSet]
854 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
858 # type: () -> AbstractSet[LocKind]
859 return self
.starts
.keys()
863 # type: () -> int | None
866 return self
.ty
.reg_len
870 # type: () -> BaseTy | None
873 return self
.ty
.base_ty
875 def concat(self
, *others
):
876 # type: (*LocSet) -> LocSet
879 base_ty
= self
.ty
.base_ty
880 reg_len
= self
.ty
.reg_len
881 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
885 if other
.ty
.base_ty
!= base_ty
:
887 for kind
, other_starts
in other
.starts
.items():
888 if kind
not in starts
:
890 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
891 if starts
[kind
] == 0:
895 reg_len
+= other
.ty
.reg_len
898 # type: () -> Iterable[Loc]
899 for kind
, v
in starts
.items():
901 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
904 return LocSet(locs())
906 @lru_cache(maxsize
=None, typed
=True)
907 def max_conflicts_with(self
, other
):
908 # type: (LocSet | Loc) -> int
909 """the largest number of Locs in `self` that a single Loc
910 from `other` can conflict with
912 if isinstance(other
, LocSet
):
913 return max(self
.max_conflicts_with(i
) for i
in other
)
915 # now we do the equivalent of:
916 # return sum(other.conflicts(i) for i in self)
917 reg_len
= self
.reg_len
920 starts
= self
.starts
.get(other
.kind
)
923 # now we do the equivalent of:
924 # return sum(other.start < start + reg_len
925 # and start < other.start + other.reg_len
926 # for start in starts)
927 stops
= starts
.bits
<< reg_len
929 # find all the bit indexes `i` where `i < other.start + 1`
930 lt_other_start_plus_1
= ~
(~
0 << (other
.start
+ 1))
932 # find all the bit indexes `i` where
933 # `i < other.start + other.reg_len + reg_len`
934 lt_other_start_plus_other_reg_len_plus_reg_len
= (
935 ~
(~
0 << (other
.start
+ other
.reg_len
+ reg_len
)))
936 included
= ~
(stops
& lt_other_start_plus_1
)
938 included
&= lt_other_start_plus_other_reg_len_plus_reg_len
939 return bit_count(included
)
942 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
946 # type: () -> Loc | None
947 """if len(self) == 1 then return the Loc in self, otherwise None"""
953 return None # len(self) > 1
957 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True)
959 class GenericOperandDesc(Interned
):
960 """generic Op operand descriptor"""
962 sub_kinds
: OFSet
[LocSubKind
]
963 fixed_loc
: "Loc | None" = None
964 tied_input_index
: "int | None" = None
966 write_stage
: OpStage
= OpStage
.Early
969 self
, ty
, # type: GenericTy
970 sub_kinds
, # type: Iterable[LocSubKind]
972 fixed_loc
=None, # type: Loc | None
973 tied_input_index
=None, # type: int | None
974 spread
=False, # type: bool
975 write_stage
=OpStage
.Early
, # type: OpStage
977 # type: (...) -> None
978 object.__setattr
__(self
, "ty", ty
)
979 object.__setattr
__(self
, "sub_kinds", OFSet(sub_kinds
))
980 if len(self
.sub_kinds
) == 0:
981 raise ValueError("sub_kinds can't be empty")
982 object.__setattr
__(self
, "fixed_loc", fixed_loc
)
983 if fixed_loc
is not None:
984 if tied_input_index
is not None:
985 raise ValueError("operand can't be both tied and fixed")
986 if not ty
.can_instantiate_to(fixed_loc
.ty
):
988 f
"fixed_loc has incompatible type for given generic "
989 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
990 if len(self
.sub_kinds
) != 1:
992 "multiple sub_kinds not allowed for fixed operand")
993 for sub_kind
in self
.sub_kinds
:
994 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
996 f
"fixed_loc not in given sub_kind: "
997 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
998 for sub_kind
in self
.sub_kinds
:
999 if sub_kind
.base_ty
!= ty
.base_ty
:
1000 raise ValueError(f
"sub_kind is incompatible with type: "
1001 f
"sub_kind={sub_kind} ty={ty}")
1002 if tied_input_index
is not None and tied_input_index
< 0:
1003 raise ValueError("invalid tied_input_index")
1004 object.__setattr
__(self
, "tied_input_index", tied_input_index
)
1005 object.__setattr
__(self
, "spread", spread
)
1007 if self
.tied_input_index
is not None:
1008 raise ValueError("operand can't be both spread and tied")
1009 if self
.fixed_loc
is not None:
1010 raise ValueError("operand can't be both spread and fixed")
1012 raise ValueError("operand can't be both spread and vector")
1013 object.__setattr
__(self
, "write_stage", write_stage
)
1016 def ty_before_spread(self
):
1017 # type: () -> GenericTy
1019 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
1022 def tied_to_input(self
, tied_input_index
):
1023 # type: (int) -> Self
1024 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
1025 tied_input_index
=tied_input_index
,
1026 write_stage
=self
.write_stage
)
1028 def with_fixed_loc(self
, fixed_loc
):
1029 # type: (Loc) -> Self
1030 return GenericOperandDesc(self
.ty
, self
.sub_kinds
, fixed_loc
=fixed_loc
,
1031 write_stage
=self
.write_stage
)
1033 def with_write_stage(self
, write_stage
):
1034 # type: (OpStage) -> Self
1035 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
1036 fixed_loc
=self
.fixed_loc
,
1037 tied_input_index
=self
.tied_input_index
,
1039 write_stage
=write_stage
)
1041 def instantiate(self
, maxvl
):
1042 # type: (int) -> Iterable[OperandDesc]
1043 # assumes all spread operands have ty.reg_len = 1
1047 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
1049 def locs_before_spread():
1050 # type: () -> Iterable[Loc]
1051 if self
.fixed_loc
is not None:
1052 if ty_before_spread
!= self
.fixed_loc
.ty
:
1054 f
"instantiation failed: type mismatch with fixed_loc: "
1055 f
"instantiated type: {ty_before_spread} "
1056 f
"fixed_loc: {self.fixed_loc}")
1057 yield self
.fixed_loc
1059 for sub_kind
in self
.sub_kinds
:
1060 yield from sub_kind
.allocatable_locs(ty_before_spread
)
1061 loc_set_before_spread
= LocSet(locs_before_spread())
1062 for idx
in range(rep_count
):
1065 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
1066 tied_input_index
=self
.tied_input_index
,
1067 spread_index
=idx
, write_stage
=self
.write_stage
)
1070 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True)
1072 class OperandDesc(Interned
):
1073 """Op operand descriptor"""
1074 loc_set_before_spread
: LocSet
1075 tied_input_index
: "int | None"
1076 spread_index
: "int | None"
1077 write_stage
: "OpStage"
1079 def __post_init__(self
):
1080 if len(self
.loc_set_before_spread
) == 0:
1081 raise ValueError("loc_set_before_spread must not be empty")
1082 if self
.tied_input_index
is not None and self
.spread_index
is not None:
1083 raise ValueError("operand can't be both spread and tied")
1086 def ty_before_spread(self
):
1088 ty
= self
.loc_set_before_spread
.ty
1089 assert ty
is not None, (
1090 "__init__ checked that the LocSet isn't empty, "
1091 "non-empty LocSets should always have ty set")
1096 """ Ty after any spread is applied """
1097 if self
.spread_index
is not None:
1098 # assumes all spread operands have ty.reg_len = 1
1099 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
1100 return self
.ty_before_spread
1103 def reg_offset_in_unspread(self
):
1104 """ the number of reg-sized slots in the unspread Loc before self's Loc
1106 e.g. if the unspread Loc containing self is:
1107 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1108 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1109 then reg_offset_into_unspread == 2 == 10 - 8
1111 if self
.spread_index
is None:
1113 return self
.spread_index
* self
.ty
.reg_len
# Pre-built generic operand descriptors for the common operand shapes:

# scalar I64 restricted to LocSubKind.BASE_GPR locations
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# scalar I64 restricted to LocSubKind.SV_EXTRA3_SGPR locations
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector of I64 restricted to LocSubKind.SV_EXTRA3_VGPR locations
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 restricted to LocSubKind.SV_EXTRA2_SGPR locations
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector of I64 restricted to LocSubKind.SV_EXTRA2_VGPR locations
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# the scalar carry flag (CA)
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# the scalar VL/MAXVL vector-length operand
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
1139 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True)
1141 class GenericOpProperties(Interned
):
1143 inputs
: "tuple[GenericOperandDesc, ...]"
1144 outputs
: "tuple[GenericOperandDesc, ...]"
1145 immediates
: "tuple[range, ...]"
1147 is_load_immediate
: bool
1148 has_side_effects
: bool
1151 self
, demo_asm
, # type: str
1152 inputs
, # type: Iterable[GenericOperandDesc]
1153 outputs
, # type: Iterable[GenericOperandDesc]
1154 immediates
=(), # type: Iterable[range]
1155 is_copy
=False, # type: bool
1156 is_load_immediate
=False, # type: bool
1157 has_side_effects
=False, # type: bool
1159 # type: (...) -> None
1160 object.__setattr
__(self
, "demo_asm", demo_asm
)
1161 object.__setattr
__(self
, "inputs", tuple(inputs
))
1162 for inp
in self
.inputs
:
1163 if inp
.tied_input_index
is not None:
1165 f
"tied_input_index is not allowed on inputs: {inp}")
1166 if inp
.write_stage
is not OpStage
.Early
:
1168 f
"write_stage is not allowed on inputs: {inp}")
1169 object.__setattr
__(self
, "outputs", tuple(outputs
))
1170 fixed_locs
= [] # type: list[tuple[Loc, int]]
1171 for idx
, out
in enumerate(self
.outputs
):
1172 if out
.tied_input_index
is not None:
1173 if out
.tied_input_index
>= len(self
.inputs
):
1174 raise ValueError(f
"tied_input_index out of range: {out}")
1175 tied_inp
= self
.inputs
[out
.tied_input_index
]
1176 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
1177 .with_write_stage(out
.write_stage
)
1178 if expected_out
!= out
:
1179 raise ValueError(f
"output can't be tied to non-equivalent "
1180 f
"input: {out} tied to {tied_inp}")
1181 if out
.fixed_loc
is not None:
1182 for other_fixed_loc
, other_idx
in fixed_locs
:
1183 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
1186 f
"conflicting fixed_locs: outputs[{idx}] and "
1187 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
1188 f
"with {other_fixed_loc}")
1189 fixed_locs
.append((out
.fixed_loc
, idx
))
1190 object.__setattr
__(self
, "immediates", tuple(immediates
))
1191 object.__setattr
__(self
, "is_copy", is_copy
)
1192 object.__setattr
__(self
, "is_load_immediate", is_load_immediate
)
1193 object.__setattr
__(self
, "has_side_effects", has_side_effects
)
1196 @plain_data.plain_data(frozen
=True, unsafe_hash
=True)
1199 __slots__
= "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
1201 def __init__(self
, kind
, maxvl
):
1202 # type: (OpKind, int) -> None
1203 self
.kind
= kind
# type: OpKind
1204 inputs
= [] # type: list[OperandDesc]
1205 for inp
in self
.generic
.inputs
:
1206 inputs
.extend(inp
.instantiate(maxvl
=maxvl
))
1207 self
.inputs
= tuple(inputs
) # type: tuple[OperandDesc, ...]
1208 outputs
= [] # type: list[OperandDesc]
1209 for out
in self
.generic
.outputs
:
1210 outputs
.extend(out
.instantiate(maxvl
=maxvl
))
1211 self
.outputs
= tuple(outputs
) # type: tuple[OperandDesc, ...]
1212 self
.maxvl
= maxvl
# type: int
1213 copy_input_reg_len
= 0
1214 for inp
in self
.inputs
[:self
.copy_inputs_len
]:
1215 copy_input_reg_len
+= inp
.ty
.reg_len
1216 copy_output_reg_len
= 0
1217 for out
in self
.outputs
[:self
.copy_outputs_len
]:
1218 copy_output_reg_len
+= out
.ty
.reg_len
1219 if copy_input_reg_len
!= copy_output_reg_len
:
1220 raise ValueError(f
"invalid copy: copy's input reg len must "
1221 f
"match its output reg len: "
1222 f
"{copy_input_reg_len} != {copy_output_reg_len}")
1223 self
.copy_reg_len
= copy_input_reg_len
1227 # type: () -> GenericOpProperties
1228 return self
.kind
.properties
1231 def immediates(self
):
1232 # type: () -> tuple[range, ...]
1233 return self
.generic
.immediates
1238 return self
.generic
.demo_asm
1243 return self
.generic
.is_copy
1246 def is_load_immediate(self
):
1248 return self
.generic
.is_load_immediate
1251 def has_side_effects(self
):
1253 return self
.generic
.has_side_effects
1256 def copy_inputs_len(self
):
1258 if not self
.is_copy
:
1260 if self
.inputs
[0].spread_index
is None:
1263 for i
, inp
in enumerate(self
.inputs
):
1264 if inp
.spread_index
!= i
:
1270 def copy_outputs_len(self
):
1272 if not self
.is_copy
:
1274 if self
.outputs
[0].spread_index
is None:
1277 for i
, out
in enumerate(self
.outputs
):
1278 if out
.spread_index
!= i
:
# Signed 16-bit immediate field range (as used by addi/ld/std-style
# instructions).
IMM_S16 = range(-(1 << 15), 1 << 15)

# Per-OpKind simulation callbacks.  Each registry maps an op's
# GenericOpProperties (the enum member's value) to a zero-argument factory
# returning the actual callback; the indirection lets the class body
# register name-mangled static methods before the enum is finished.
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# Per-OpKind assembly-generation callbacks, same registration scheme.
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1297 def __init__(self
, properties
):
1298 # type: (GenericOpProperties) -> None
1300 self
.__properties
= properties
1303 def properties(self
):
1304 # type: () -> GenericOpProperties
1305 return self
.__properties
1307 def instantiate(self
, maxvl
):
1308 # type: (int) -> OpProperties
1309 return OpProperties(self
, maxvl
=maxvl
)
1313 return "OpKind." + self
._name
_
1317 # type: () -> _SIM_FN
1318 return _SIM_FNS
[self
.properties
]()
1322 # type: () -> _GEN_ASM_FN
1323 return _GEN_ASMS
[self
.properties
]()
@staticmethod
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # CA values are stored as 1-tuples; clearing writes (False,).
    state[op.outputs[0]] = False,

@staticmethod
def __clearca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # addic with a zero immediate clears the carry flag.
    state.writeln("addic 0, 0, 0")
ClearCA = GenericOpProperties(
    demo_asm="addic 0, 0, 0",
    inputs=[],
    outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
_SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
_GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
@staticmethod
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # CA values are stored as 1-tuples; setting writes (True,).
    state[op.outputs[0]] = True,

@staticmethod
def __setca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # subfc r0, r0, r0 always generates a carry, setting CA.
    state.writeln("subfc 0, 0, 0")
SetCA = GenericOpProperties(
    demo_asm="subfc 0, 0, 0",
    inputs=[],
    outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
_SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
_GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
@staticmethod
def __svadde_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector add-extended: element-wise RA + RB with carry chained
    # through all VL elements; final carry is the second output.
    RA = state[op.input_vals[0]]
    RB = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    for i in range(VL):
        v = RA[i] + RB[i] + carry
        RT.append(v & GPR_VALUE_MASK)
        carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,

@staticmethod
def __svadde_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.adde {RT}, {RA}, {RB}")
SvAddE = GenericOpProperties(
    demo_asm="sv.adde *RT, *RA, *RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
_GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
@staticmethod
def __addze_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Scalar add-to-zero-extended: RA + CA, producing result and new CA.
    RA, = state[op.input_vals[0]]
    carry, = state[op.input_vals[1]]
    v = RA + carry
    RT = v & GPR_VALUE_MASK
    carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = RT,
    state[op.outputs[1]] = carry,

@staticmethod
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")
AddZE = GenericOpProperties(
    demo_asm="addze RT, RA",
    inputs=[OD_BASE_SGPR, OD_CA],
    outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
)
_SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
_GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
@staticmethod
def __svsubfe_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector subtract-from-extended: element-wise ~RA + RB + carry
    # (two's-complement subtraction) with carry chained through VL
    # elements; final carry is the second output.
    RA = state[op.input_vals[0]]
    RB = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    for i in range(VL):
        v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
        RT.append(v & GPR_VALUE_MASK)
        carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,

@staticmethod
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
SvSubFE = GenericOpProperties(
    demo_asm="sv.subfe *RT, *RA, *RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
@staticmethod
def __svandvs_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector AND with a scalar: RT[i] = RA[i] & RB for VL elements.
    RA = state[op.input_vals[0]]
    RB, = state[op.input_vals[1]]
    VL, = state[op.input_vals[2]]
    RT = []  # type: list[int]
    for i in range(VL):
        RT.append(RA[i] & RB & GPR_VALUE_MASK)
    state[op.outputs[0]] = tuple(RT)

@staticmethod
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    state.writeln(f"sv.and {RT}, {RA}, {RB}")
SvAndVS = GenericOpProperties(
    demo_asm="sv.and *RT, *RA, RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
)
_SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
_GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm
@staticmethod
def __svmaddedu_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector multiply-add extended double unsigned: RT[i] is the low
    # word of RA[i] * RB + carry; the high word becomes the carry into
    # the next element, and the final carry is the second output.
    RA = state[op.input_vals[0]]
    RB, = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    for i in range(VL):
        v = RA[i] * RB + carry
        RT.append(v & GPR_VALUE_MASK)
        carry = v >> GPR_SIZE_IN_BITS
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,

@staticmethod
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    RC = state.sgpr(op.input_vals[2])
    state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
SvMAddEDU = GenericOpProperties(
    demo_asm="sv.maddedu *RT, *RA, RB, RC",
    inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
)
_SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
_GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
@staticmethod
def __sradi_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Shift-right algebraic immediate: sign-extend rs, arithmetic-shift
    # by imm; CA records whether any 1 bits were shifted out of a
    # negative value (shifting the result back differs from rs).
    rs, = state[op.input_vals[0]]
    imm = op.immediates[0]
    if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
        rs -= 1 << GPR_SIZE_IN_BITS
    v = rs >> imm  # NOTE(review): reconstructed line — confirm against upstream
    RA = v & GPR_VALUE_MASK
    CA = (RA << imm) != rs
    state[op.outputs[0]] = RA,
    state[op.outputs[1]] = CA,

@staticmethod
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RA = state.sgpr(op.outputs[0])
    RS = state.sgpr(op.input_vals[0])
    imm = op.immediates[0]
    state.writeln(f"sradi {RA}, {RS}, {imm}")
SRADI = GenericOpProperties(
    demo_asm="sradi RA, RS, imm",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
             OD_CA.with_write_stage(OpStage.Late)],
    immediates=[range(GPR_SIZE_IN_BITS)],
)
_SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
_GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm
@staticmethod
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # VL is set directly from the immediate (stored as a 1-tuple).
    state[op.outputs[0]] = op.immediates[0],

@staticmethod
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    imm = op.immediates[0]
    state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
SetVLI = GenericOpProperties(
    demo_asm="setvl 0, 0, imm, 0, 1, 1",
    inputs=[],
    outputs=[OD_VL.with_write_stage(OpStage.Late)],
    immediates=[range(1, 65)],
    is_load_immediate=True,
)
_SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
_GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
@staticmethod
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector load-immediate: every one of the VL elements gets imm.
    VL, = state[op.input_vals[0]]
    imm = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (imm,) * VL

@staticmethod
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.addi {RT}, 0, {imm}")
SvLI = GenericOpProperties(
    demo_asm="sv.addi *RT, 0, imm",
    inputs=[OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
_GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
@staticmethod
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Scalar load-immediate, masked to a GPR-sized unsigned value.
    imm = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = imm,

@staticmethod
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"addi {RT}, 0, {imm}")
LI = GenericOpProperties(
    demo_asm="addi RT, 0, imm",
    inputs=[],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[LI] = lambda: OpKind.__li_sim
_GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
@staticmethod
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # A register copy is the identity in simulation.
    state[op.outputs[0]] = state[op.input_vals[0]]
@staticmethod
def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
    # type: (Loc, Loc, bool, GenAsmState) -> None
    # Shared asm emitter for all register/stack copy ops.  Picks
    # sv.ld/sv.std/or (or scalar equivalents) from the location kinds.
    # Stack-to-stack copies are rejected.
    sv = "sv." if is_vec else ""
    rev = ""
    if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
        # overlapping upward copy must run in reverse element order
        # NOTE(review): the rev value is reconstructed — confirm upstream
        rev = "/mrr"
    if src_loc == dest_loc:
        return  # no-op copy
    if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
    if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
    if src_loc.kind is LocKind.StackI64:
        if dest_loc.kind is LocKind.StackI64:
            raise ValueError(
                f"can't copy from stack to stack: {src_loc} {dest_loc}")
        elif dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        src = state.stack(src_loc)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}ld {dest}, {src}")
    elif dest_loc.kind is LocKind.StackI64:
        if src_loc.kind is not LocKind.GPR:
            assert_never(src_loc.kind)
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.stack(dest_loc)
        state.writeln(f"{sv}std {src}, {dest}")
    elif src_loc.kind is LocKind.GPR:
        if dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
    else:
        assert_never(src_loc.kind)
@staticmethod
def __veccopytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # Vector copy into a register: source may be GPR or stack.
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(
            op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
        dest_loc=state.loc(op.outputs[0], LocKind.GPR),
        is_vec=True, state=state)
1647 VecCopyToReg
= GenericOpProperties(
1648 demo_asm
="sv.mv dest, src",
1649 inputs
=[GenericOperandDesc(
1650 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1651 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1653 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
# Register the simulation and asm-generation callbacks for VecCopyToReg.
_SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
_GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
@staticmethod
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # A register copy is the identity in simulation.
    state[op.outputs[0]] = state[op.input_vals[0]]

@staticmethod
def __veccopyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # Vector copy out of a register: destination may be GPR or stack.
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0], LocKind.GPR),
        dest_loc=state.loc(
            op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
        is_vec=True, state=state)
1672 VecCopyFromReg
= GenericOpProperties(
1673 demo_asm
="sv.mv dest, src",
1674 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1675 outputs
=[GenericOperandDesc(
1676 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1677 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1678 write_stage
=OpStage
.Late
,
# Register the simulation and asm-generation callbacks for VecCopyFromReg.
_SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
_GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
@staticmethod
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # A register copy is the identity in simulation.
    state[op.outputs[0]] = state[op.input_vals[0]]

@staticmethod
def __copytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # Scalar copy into a register: source may be GPR or stack.
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(
            op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
        dest_loc=state.loc(op.outputs[0], LocKind.GPR),
        is_vec=False, state=state)
1698 CopyToReg
= GenericOpProperties(
1699 demo_asm
="mv dest, src",
1700 inputs
=[GenericOperandDesc(
1701 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1702 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1703 LocSubKind
.StackI64
],
1705 outputs
=[GenericOperandDesc(
1706 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1707 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1708 write_stage
=OpStage
.Late
,
# Register the simulation and asm-generation callbacks for CopyToReg.
_SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
_GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
@staticmethod
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # A register copy is the identity in simulation.
    state[op.outputs[0]] = state[op.input_vals[0]]

@staticmethod
def __copyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    # Scalar copy out of a register: destination may be GPR or stack.
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0], LocKind.GPR),
        dest_loc=state.loc(
            op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
        is_vec=False, state=state)
1728 CopyFromReg
= GenericOpProperties(
1729 demo_asm
="mv dest, src",
1730 inputs
=[GenericOperandDesc(
1731 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1732 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1734 outputs
=[GenericOperandDesc(
1735 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1736 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1737 LocSubKind
.StackI64
],
1738 write_stage
=OpStage
.Late
,
# Register the simulation and asm-generation callbacks for CopyFromReg.
_SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
_GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
@staticmethod
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Concatenate the scalar inputs (all but the trailing VL input)
    # into one vector value.
    state[op.outputs[0]] = tuple(
        state[i][0] for i in op.input_vals[:-1])

@staticmethod
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
        dest_loc=state.loc(op.outputs[0], LocKind.GPR),
        is_vec=True, state=state)
1758 Concat
= GenericOpProperties(
1759 demo_asm
="sv.mv dest, src",
1760 inputs
=[GenericOperandDesc(
1761 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1762 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1765 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
# Register the simulation and asm-generation callbacks for Concat.
_SIM_FNS[Concat] = lambda: OpKind.__concat_sim
_GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
@staticmethod
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Split a vector input into one scalar output (a 1-tuple) per
    # element.
    for idx, inp in enumerate(state[op.input_vals[0]]):
        state[op.outputs[idx]] = inp,

@staticmethod
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0], LocKind.GPR),
        dest_loc=state.loc(op.outputs, LocKind.GPR),
        is_vec=True, state=state)
1784 Spread
= GenericOpProperties(
1785 demo_asm
="sv.mv dest, src",
1786 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1787 outputs
=[GenericOperandDesc(
1788 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1789 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1791 write_stage
=OpStage
.Late
,
# Register the simulation and asm-generation callbacks for Spread.
_SIM_FNS[Spread] = lambda: OpKind.__spread_sim
_GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
@staticmethod
def __svld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector load: VL consecutive GPR-sized words starting at RA + imm.
    RA, = state[op.input_vals[0]]
    VL, = state[op.input_vals[1]]
    addr = RA + op.immediates[0]
    RT = []  # type: list[int]
    for i in range(VL):
        v = state.load(addr + GPR_SIZE_IN_BYTES * i)
        RT.append(v & GPR_VALUE_MASK)
    state[op.outputs[0]] = tuple(RT)

@staticmethod
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RA = state.sgpr(op.input_vals[0])
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.ld {RT}, {imm}({RA})")
SvLd = GenericOpProperties(
    demo_asm="sv.ld *RT, imm(RA)",
    inputs=[OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
)
_SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
_GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
@staticmethod
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Scalar load of one GPR-sized word from RA + imm.
    RA, = state[op.input_vals[0]]
    addr = RA + op.immediates[0]
    v = state.load(addr)
    state[op.outputs[0]] = v & GPR_VALUE_MASK,

@staticmethod
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RA = state.sgpr(op.input_vals[0])
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"ld {RT}, {imm}({RA})")
Ld = GenericOpProperties(
    demo_asm="ld RT, imm(RA)",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
)
_SIM_FNS[Ld] = lambda: OpKind.__ld_sim
_GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
@staticmethod
def __svstd_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Vector store: VL consecutive GPR-sized words starting at RA + imm.
    RS = state[op.input_vals[0]]
    RA, = state[op.input_vals[1]]
    VL, = state[op.input_vals[2]]
    addr = RA + op.immediates[0]
    for i in range(VL):
        state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

@staticmethod
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RS = state.vgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"sv.std {RS}, {imm}({RA})")
SvStd = GenericOpProperties(
    demo_asm="sv.std *RS, imm(RA)",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
_GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
@staticmethod
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    # Scalar store of one GPR-sized word to RA + imm.
    RS, = state[op.input_vals[0]]
    RA, = state[op.input_vals[1]]
    addr = RA + op.immediates[0]
    state.store(addr, value=RS)

@staticmethod
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    RS = state.sgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"std {RS}, {imm}({RA})")
Std = GenericOpProperties(
    demo_asm="std RS, imm(RA)",
    inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
    outputs=[],
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[Std] = lambda: OpKind.__std_sim
_GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
@staticmethod
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    pass  # return value set before simulation

@staticmethod
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    pass  # no instructions needed
1911 FuncArgR3
= GenericOpProperties(
1914 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1915 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
# Register the simulation and asm-generation callbacks for FuncArgR3.
_SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
_GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1921 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
1922 class SSAValOrUse(Interned
):
1926 def __post_init__(self
):
1927 if self
.operand_idx
< 0 or \
1928 self
.operand_idx
>= len(self
.descriptor_array
):
1929 raise ValueError("invalid operand_idx")
1938 def descriptor_array(self
):
1939 # type: () -> tuple[OperandDesc, ...]
1943 def defining_descriptor(self
):
1944 # type: () -> OperandDesc
1945 return self
.descriptor_array
[self
.operand_idx
]
1950 return self
.defining_descriptor
.ty
1953 def ty_before_spread(self
):
1955 return self
.defining_descriptor
.ty_before_spread
1959 # type: () -> BaseTy
1960 return self
.ty_before_spread
.base_ty
1963 def reg_offset_in_unspread(self
):
1964 """ the number of reg-sized slots in the unspread Loc before self's Loc
1966 e.g. if the unspread Loc containing self is:
1967 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1968 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1969 then reg_offset_into_unspread == 2 == 10 - 8
1971 return self
.defining_descriptor
.reg_offset_in_unspread
1974 def unspread_start_idx(self
):
1976 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1979 def unspread_start(self
):
1981 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1984 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
1986 class SSAVal(SSAValOrUse
):
1991 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1994 def def_loc_set_before_spread(self
):
1995 # type: () -> LocSet
1996 return self
.defining_descriptor
.loc_set_before_spread
1999 def descriptor_array(self
):
2000 # type: () -> tuple[OperandDesc, ...]
2001 return self
.op
.properties
.outputs
2004 def tied_input(self
):
2005 # type: () -> None | SSAUse
2006 if self
.defining_descriptor
.tied_input_index
is None:
2008 return SSAUse(op
=self
.op
,
2009 operand_idx
=self
.defining_descriptor
.tied_input_index
)
2012 def write_stage(self
):
2013 # type: () -> OpStage
2014 return self
.defining_descriptor
.write_stage
2017 def current_debugging_value(self
):
2018 # type: () -> tuple[int, ...]
2019 """ get the current value for debugging in pdb or similar.
2021 This is intended for use with
2022 `PreRASimState.set_current_debugging_state`.
2024 This is only intended for debugging, do not use in unit tests or
2027 return PreRASimState
.get_current_debugging_state()[self
]
2030 def ssa_val_sub_regs(self
):
2031 # type: () -> tuple[SSAValSubReg, ...]
2032 return tuple(SSAValSubReg(self
, i
) for i
in range(self
.ty
.reg_len
))
2035 @dataclasses.dataclass(frozen
=True, unsafe_hash
=True, repr=False)
2037 class SSAUse(SSAValOrUse
):
2041 def use_loc_set_before_spread(self
):
2042 # type: () -> LocSet
2043 return self
.defining_descriptor
.loc_set_before_spread
2046 def descriptor_array(self
):
2047 # type: () -> tuple[OperandDesc, ...]
2048 return self
.op
.properties
.inputs
2052 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
2056 # type: () -> SSAVal
2057 return self
.op
.input_vals
[self
.operand_idx
]
2060 def ssa_val(self
, ssa_val
):
2061 # type: (SSAVal) -> None
2062 self
.op
.input_vals
[self
.operand_idx
] = ssa_val
2066 _Desc
= TypeVar("_Desc")
2069 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
2071 def _verify_write_with_desc(self
, idx
, item
, desc
):
2072 # type: (int, _T | Any, _Desc) -> None
2073 raise NotImplementedError
2076 def _verify_write(self
, idx
, item
):
2077 # type: (int | Any, _T | Any) -> int
2078 if not isinstance(idx
, int):
2079 if isinstance(idx
, slice):
2081 f
"can't write to slice of {self.__class__.__name__}")
2082 raise TypeError(f
"can't write with index {idx!r}")
2083 # normalize idx, raising IndexError if it is out of range
2084 idx
= range(len(self
.descriptors
))[idx
]
2085 desc
= self
.descriptors
[idx
]
2086 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
2089 def _on_set(self
, idx
, new_item
, old_item
):
2090 # type: (int, _T, _T | None) -> None
2094 def _get_descriptors(self
):
2095 # type: () -> tuple[_Desc, ...]
2096 raise NotImplementedError
2100 def descriptors(self
):
2101 # type: () -> tuple[_Desc, ...]
2102 return self
._get
_descriptors
()
2109 def __init__(self
, items
, op
):
2110 # type: (Iterable[_T], Op) -> None
2113 self
.__items
= [] # type: list[_T]
2114 for idx
, item
in enumerate(items
):
2115 if idx
>= len(self
.descriptors
):
2116 raise ValueError("too many items")
2117 _
= self
._verify
_write
(idx
, item
)
2118 self
.__items
.append(item
)
2119 if len(self
.__items
) < len(self
.descriptors
):
2120 raise ValueError("not enough items")
2124 # type: () -> Iterator[_T]
2125 yield from self
.__items
2128 def __getitem__(self
, idx
):
2133 def __getitem__(self
, idx
):
2134 # type: (slice) -> list[_T]
2138 def __getitem__(self
, idx
):
2139 # type: (int | slice) -> _T | list[_T]
2140 return self
.__items
[idx
]
2143 def __setitem__(self
, idx
, item
):
2144 # type: (int, _T) -> None
2145 idx
= self
._verify
_write
(idx
, item
)
2146 self
.__items
[idx
] = item
2151 return len(self
.__items
)
2155 return f
"{self.__class__.__name__}({self.__items}, op=...)"
2159 class OpInputVals(OpInputSeq
[SSAVal
, OperandDesc
]):
2160 def _get_descriptors(self
):
2161 # type: () -> tuple[OperandDesc, ...]
2162 return self
.op
.properties
.inputs
2164 def _verify_write_with_desc(self
, idx
, item
, desc
):
2165 # type: (int, SSAVal | Any, OperandDesc) -> None
2166 if not isinstance(item
, SSAVal
):
2167 raise TypeError("expected value of type SSAVal")
2168 if item
.ty
!= desc
.ty
:
2169 raise ValueError(f
"assigned item's type {item.ty!r} doesn't match "
2170 f
"corresponding input's type {desc.ty!r}")
2172 def _on_set(self
, idx
, new_item
, old_item
):
2173 # type: (int, SSAVal, SSAVal | None) -> None
2174 SSAUses
._on
_op
_input
_set
(self
, idx
, new_item
, old_item
) # type: ignore
2176 def __init__(self
, items
, op
):
2177 # type: (Iterable[SSAVal], Op) -> None
2178 if hasattr(op
, "inputs"):
2179 raise ValueError("Op.inputs already set")
2180 super().__init
__(items
, op
)
2184 class OpImmediates(OpInputSeq
[int, range]):
2185 def _get_descriptors(self
):
2186 # type: () -> tuple[range, ...]
2187 return self
.op
.properties
.immediates
2189 def _verify_write_with_desc(self
, idx
, item
, desc
):
2190 # type: (int, int | Any, range) -> None
2191 if not isinstance(item
, int):
2192 raise TypeError("expected value of type int")
2193 if item
not in desc
:
2194 raise ValueError(f
"immediate value {item!r} not in {desc!r}")
2196 def __init__(self
, items
, op
):
2197 # type: (Iterable[int], Op) -> None
2198 if hasattr(op
, "immediates"):
2199 raise ValueError("Op.immediates already set")
2200 super().__init
__(items
, op
)
2203 @plain_data.plain_data(frozen
=True, eq
=False, repr=False)
2206 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
2209 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
2210 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
2212 self
.properties
= properties
2213 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
2214 inputs_len
= len(self
.properties
.inputs
)
2215 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
2216 self
.immediates
= OpImmediates(immediates
, op
=self
)
2217 outputs_len
= len(self
.properties
.outputs
)
2218 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
2219 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
2223 # type: () -> OpKind
2224 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    # Ops compare by identity: two Ops are equal only when they are
    # the same object; non-Op comparands defer to the other operand.
    if not isinstance(other, Op):
        return NotImplemented
    return self is other
2234 return object.__hash
__(self
)
2236 def __repr__(self
, wrap_width
=63, indent
=" "):
2237 # type: (int, str) -> str
2238 WRAP_POINT
= "\u200B" # zero-width space
2239 items
= [f
"{self.name}:\n"]
2240 for i
, out
in enumerate(self
.outputs
):
2241 item
= f
"<...outputs[{i}]: {out.ty}>"
2243 item
= "(" + WRAP_POINT
+ item
2244 if i
!= len(self
.outputs
) - 1:
2245 item
+= ", " + WRAP_POINT
2247 item
+= WRAP_POINT
+ ") <= "
2249 items
.append(self
.kind
._name
_)
2250 if len(self
.input_vals
) + len(self
.immediates
) != 0:
2252 items
[-1] += WRAP_POINT
2253 for i
, inp
in enumerate(self
.input_vals
):
2255 if i
!= len(self
.input_vals
) - 1 or len(self
.immediates
) != 0:
2256 item
+= ", " + WRAP_POINT
2258 item
+= ") " + WRAP_POINT
2260 for i
, imm
in enumerate(self
.immediates
):
2262 if i
!= len(self
.immediates
) - 1:
2263 item
+= ", " + WRAP_POINT
2265 item
+= ") " + WRAP_POINT
2267 lines
= [] # type: list[str]
2268 for i
, line_in
in enumerate("".join(items
).splitlines()):
2270 line_in
= indent
+ line_in
2272 for part
in line_in
.split(WRAP_POINT
):
2276 trial_line_out
= line_out
+ part
2277 if len(trial_line_out
.rstrip()) > wrap_width
:
2278 lines
.append(line_out
.rstrip())
2279 line_out
= indent
+ part
2281 line_out
= trial_line_out
2282 lines
.append(line_out
.rstrip())
2283 return "\n".join(lines
)
2285 def sim(self
, state
):
2286 # type: (BaseSimState) -> None
2287 for inp
in self
.input_vals
:
2291 raise ValueError(f
"SSAVal {inp} not yet assigned when "
2295 if len(val
) != inp
.ty
.reg_len
:
2297 f
"value of SSAVal {inp} has wrong number of elements: "
2298 f
"expected {inp.ty.reg_len} found "
2299 f
"{len(val)}: {val!r}")
2300 if isinstance(state
, PreRASimState
):
2301 for out
in self
.outputs
:
2302 if out
in state
.ssa_vals
:
2303 if self
.kind
is OpKind
.FuncArgR3
:
2305 raise ValueError(f
"SSAVal {out} already assigned before "
2308 self
.kind
.sim(self
, state
)
2311 for out
in self
.outputs
:
2315 raise ValueError(f
"running {self} failed to assign to {out}")
2318 if len(val
) != out
.ty
.reg_len
:
2320 f
"value of SSAVal {out} has wrong number of elements: "
2321 f
"expected {out.ty.reg_len} found "
2322 f
"{len(val)}: {val!r}")
2324 def gen_asm(self
, state
):
2325 # type: (GenAsmState) -> None
2326 all_loc_kinds
= tuple(LocKind
)
2327 for inp
in self
.input_vals
:
2328 state
.loc(inp
, expected_kinds
=all_loc_kinds
)
2329 for out
in self
.outputs
:
2330 state
.loc(out
, expected_kinds
=all_loc_kinds
)
2331 self
.kind
.gen_asm(self
, state
)
2334 @plain_data.plain_data(frozen
=True, repr=False)
2335 class BaseSimState(metaclass
=ABCMeta
):
2336 __slots__
= "memory",
    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        # memory is a sparse byte-addressed store; addresses never written
        # fall back to `_default_memory_value` in `load_byte`
        super().__init__()
        self.memory = memory  # type: dict[int, int]
2343 def _default_memory_value(self
):
    def on_skip(self, op):
        # type: (Op) -> None
        """ hook called when simulation of `op` is skipped; base sim states
        don't support skipping (ConstPropagationState overrides this). """
        raise ValueError("skipping instructions not supported")
    def load_byte(self, addr):
        # type: (int) -> int
        """ read one byte of memory at `addr` (wrapped to GPR width);
        unwritten addresses yield `_default_memory_value()`. """
        addr &= GPR_VALUE_MASK
        if addr in self.memory:
            return self.memory[addr] & 0xFF
        # deliberately LBYL: subclasses may make the default raise, so it
        # must only be evaluated when the address is actually missing
        return self._default_memory_value()
    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write one byte of memory at `addr` (wrapped to GPR width). """
        addr &= GPR_VALUE_MASK
        # NOTE(review): byte-masking of `value` assumed here (symmetric
        # with load_byte's `& 0xFF`) -- confirm against upstream
        value &= 0xFF
        self.memory[addr] = value
    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ load a little-endian integer of `size_in_bytes` bytes from
        naturally-aligned `addr`; sign-extends the result when `signed`.

        Raises ValueError when `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        # assemble bytes little-endian: byte i lands at bit i * 8
        for i in range(size_in_bytes):
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        # two's-complement sign extension when the top bit is set
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval
2377 def store(self
, addr
, value
, size_in_bytes
=GPR_SIZE_IN_BYTES
):
2378 # type: (int, int, int) -> None
2379 if addr
% size_in_bytes
!= 0:
2380 raise ValueError(f
"address not aligned: {hex(addr)} "
2381 f
"required alignment: {size_in_bytes}")
2382 for i
in range(size_in_bytes
):
2383 self
.store_byte(addr
+ i
, (value
>> i
* BITS_IN_BYTE
) & 0xFF)
    def _memory__repr(self):
        # type: () -> str
        # custom repr for the `memory` field, picked up by __repr__ below
        # via the `_<field>__repr` naming convention. Runs of CHUNK_SIZE
        # consecutive bytes starting at an aligned address are collapsed
        # into one word-sized entry; stray bytes are printed individually.
        if len(self.memory) == 0:
            return "{}"
        # sorted descending so keys.pop() / keys[-1] yield ascending addrs
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"
    def __repr__(self):
        # type: () -> str
        """ generic repr over plain_data fields; a field named `foo` can
        customize its rendering by defining a `_foo__repr` method. """
        field_vals = []  # type: list[str]
        for name in plain_data.fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                # field never assigned (e.g. partially-constructed state)
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"{self.__class__.__name__}({field_vals_str})"
    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ read the current value of `ssa_val`, one int per register. """
        ...
    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ set the current value of `ssa_val`, one int per register. """
        ...
2436 @plain_data.plain_data(frozen
=True, repr=False)
2437 class PreRABaseSimState(BaseSimState
):
2438 __slots__
= "ssa_vals",
    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        # maps each SSAVal to its simulated value, one int per register
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
2445 def _ssa_vals__repr(self
):
2447 if len(self
.ssa_vals
) == 0:
2449 items
= [] # type: list[str]
2451 for k
, v
in self
.ssa_vals
.items():
2452 element_strs
= [] # type: list[str]
2453 for i
, el
in enumerate(v
):
2454 if i
% CHUNK_SIZE
!= 0:
2455 element_strs
.append(" " + hex(el
))
2457 element_strs
.append("\n " + hex(el
))
2458 if len(element_strs
) <= CHUNK_SIZE
:
2459 element_strs
[0] = element_strs
[0].lstrip()
2460 if len(element_strs
) == 1:
2461 element_strs
.append("")
2462 v_str
= ",".join(element_strs
)
2463 items
.append(f
"{k!r}: ({v_str})")
2464 if len(items
) == 1 and "\n" not in items
[0]:
2465 return f
"{{{items[0]}}}"
2466 items_str
= ",\n".join(items
)
2467 return f
"{{\n{items_str},\n}}"
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ read `ssa_val`'s value; values not yet assigned defer to
        `_handle_undefined_ssa_val` (which subclasses may override). """
        try:
            return self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)
    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        # hook for subclasses (overridden by ConstPropagationState);
        # by default an unassigned SSAVal is an error
        raise KeyError("SSAVal has no value set", ssa_val)
2480 def __setitem__(self
, ssa_val
, value
):
2481 # type: (SSAVal, Iterable[int]) -> None
2482 value
= tuple(map(int, value
))
2483 if len(value
) != ssa_val
.ty
.reg_len
:
2484 raise ValueError("value has wrong len")
2485 self
.ssa_vals
[ssa_val
] = value
# exception used to signal that simulation of an Op should be skipped
# (see `BaseSimState.on_skip` / `ConstPropagationState.on_skip`)
class SimSkipOp(Exception):
    # NOTE(review): class body not visible here -- original may have had a
    # docstring instead of `pass`; confirm against upstream
    pass
2492 @plain_data.plain_data(frozen
=True, repr=False)
2494 class ConstPropagationState(PreRABaseSimState
):
2495 __slots__
= "skipped_ops",
    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        # Ops whose simulation was skipped during constant propagation
        self.skipped_ops = skipped_ops
2502 def _default_memory_value(self
):
2506 def _handle_undefined_ssa_val(self
, ssa_val
):
2507 # type: (SSAVal) -> tuple[int, ...]
    def on_skip(self, op):
        # type: (Op) -> None
        # constant propagation tolerates skipped Ops: just record them
        self.skipped_ops.add(op)
2515 @plain_data.plain_data(frozen
=True, repr=False)
2516 class PreRASimState(PreRABaseSimState
):
    # stack of states installed via `set_as_current_debugging_state`;
    # name-mangled, accessed only through the debugging helpers
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]
    @contextmanager
    def set_as_current_debugging_state(self):
        # type: () -> Iterator[None]
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        # NOTE(review): docstring tail and try/finally reconstructed --
        # confirm wording against upstream
        PreRASimState.__CURRENT_DEBUGGING_STATE.append(self)
        try:
            yield
        finally:
            # pop in finally so an exception in the with-body can't leave
            # a stale state installed
            assert self is PreRASimState.__CURRENT_DEBUGGING_STATE.pop(), \
                "inconsistent __CURRENT_DEBUGGING_STATE"
    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        # most-recently-installed state wins (stack discipline)
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2550 @plain_data.plain_data(frozen
=True, repr=False)
2552 class PostRASimState(BaseSimState
):
2553 __slots__
= "ssa_val_to_loc_map", "loc_values"
    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        super().__init__(memory)
        # frozen mapping from each SSAVal to its register-allocated Loc
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # per-register values; every key must be a reg_len=1 Loc so that
        # multi-register reads/writes can be decomposed uniformly
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")
    def _loc_values__repr(self):
        # type: () -> str
        # custom repr for the `loc_values` field: one "Loc: 0x..." entry
        # per line, sorted by (kind, start) for stable, readable output
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"
2580 def __getitem__(self
, ssa_val
):
2581 # type: (SSAVal) -> tuple[int, ...]
2582 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2583 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2584 retval
= [] # type: list[int]
2585 for i
in range(loc
.reg_len
):
2586 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2587 retval
.append(self
.loc_values
.get(subloc
, 0))
2588 return tuple(retval
)
2590 def __setitem__(self
, ssa_val
, value
):
2591 # type: (SSAVal, Iterable[int]) -> None
2592 value
= tuple(map(int, value
))
2593 if len(value
) != ssa_val
.ty
.reg_len
:
2594 raise ValueError("value has wrong len")
2595 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2596 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2597 for i
in range(loc
.reg_len
):
2598 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2599 self
.loc_values
[subloc
] = value
[i
]
2602 @plain_data.plain_data(frozen
=True)
2604 __slots__
= "allocated_locs", "output"
2606 def __init__(self
, allocated_locs
, output
=None):
2607 # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
2609 self
.allocated_locs
= FMap(allocated_locs
)
2610 for ssa_val
, loc
in self
.allocated_locs
.items():
2611 if ssa_val
.ty
!= loc
.ty
:
2613 f
"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
2616 self
.output
= output
    # argument type accepted by `loc` and its wrappers: a single SSAVal or
    # Loc, or a sequence of them (concatenated into one contiguous Loc)
    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve `ssa_val_or_locs` to a single Loc.

        SSAVals are looked up in `allocated_locs`; a sequence is
        concatenated into one contiguous Loc. Raises ValueError when the
        sequence is empty, can't be concatenated, or the resulting Loc's
        kind isn't one of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = []  # type: list[Loc]
        for i in ssa_val_or_locs:
            if isinstance(i, SSAVal):
                locs.append(self.allocated_locs[i])
            else:
                locs.append(i)
        if len(locs) == 0:
            raise ValueError("invalid Loc sequence: must not be empty")
        retval = locs[0].try_concat(*locs[1:])
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        # normalize a lone LocKind to a 1-tuple for the membership test
        if isinstance(expected_kinds, LocKind):
            expected_kinds = expected_kinds,
        if retval.kind not in expected_kinds:
            # unwrap a 1-tuple so the error message reads naturally
            if len(expected_kinds) == 1:
                expected_kinds = expected_kinds[0]
            raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                             f"{retval.kind} expected {expected_kinds}")
        return retval
2644 def gpr(self
, ssa_val_or_locs
, is_vec
):
2645 # type: (__SSA_VAL_OR_LOCS, bool) -> str
2646 loc
= self
.loc(ssa_val_or_locs
, LocKind
.GPR
)
2647 vec_str
= "*" if is_vec
else ""
2648 return vec_str
+ str(loc
.start
)
    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ render a scalar GPR operand. """
        return self.gpr(ssa_val_or_locs, is_vec=False)
    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ render a vector GPR operand ("*"-prefixed). """
        return self.gpr(ssa_val_or_locs, is_vec=True)
2658 def stack(self
, ssa_val_or_locs
):
2659 # type: (__SSA_VAL_OR_LOCS) -> str
2660 loc
= self
.loc(ssa_val_or_locs
, LocKind
.StackI64
)
2661 return f
"{loc.start}(1)"
    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ join `line_segments` with spaces and append the resulting line
        to the output (a list of lines, or a writable text stream). """
        line = " ".join(line_segments)
        if isinstance(self.output, list):
            self.output.append(line)
        else:
            self.output.write(line + "\n")