2 from abc
import ABCMeta
, abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
, total_ordering
5 from io
import StringIO
6 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
7 Mapping
, Sequence
, TypeVar
, Union
, overload
)
8 from weakref
import WeakValueDictionary
as _WeakVDict
10 from cached_property
import cached_property
11 from nmutil
.plain_data
import fields
, plain_data
13 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
15 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
22 self
.ops
= [] # type: list[Op]
23 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
24 self
.__next
_name
_suffix
= 2
26 def _add_op_with_unused_name(self
, op
, name
=""):
27 # type: (Op, str) -> str
29 raise ValueError("can't add Op to wrong Fn")
30 if hasattr(op
, "name"):
31 raise ValueError("Op already named")
34 if name
!= "" and name
not in self
.__op
_names
:
35 self
.__op
_names
[name
] = op
37 name
= orig_name
+ str(self
.__next
_name
_suffix
)
38 self
.__next
_name
_suffix
+= 1
44 def append_op(self
, op
):
47 raise ValueError("can't add Op to wrong Fn")
50 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
52 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
53 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
54 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
55 self
.append_op(retval
)
59 # type: (BaseSimState) -> None
63 def gen_asm(self
, state
):
64 # type: (GenAsmState) -> None
68 def pre_ra_insert_copies(self
):
70 orig_ops
= list(self
.ops
)
71 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
72 setvli_outputs
= {} # type: dict[SSAVal, Op]
75 for i
in range(len(op
.input_vals
)):
76 inp
= copied_outputs
[op
.input_vals
[i
]]
77 if inp
.ty
.base_ty
is BaseTy
.I64
:
78 maxvl
= inp
.ty
.reg_len
79 if inp
.ty
.reg_len
!= 1:
80 setvl
= self
.append_new_op(
81 OpKind
.SetVLI
, immediates
=[maxvl
],
82 name
=f
"{op.name}.inp{i}.setvl")
84 mv
= self
.append_new_op(
85 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
86 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
88 mv
= self
.append_new_op(
89 OpKind
.CopyToReg
, input_vals
=[inp
],
90 name
=f
"{op.name}.inp{i}.copy")
91 op
.input_vals
[i
] = mv
.outputs
[0]
92 elif inp
.ty
.base_ty
is BaseTy
.CA \
93 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
94 # all copies would be no-ops, so we don't need to copy,
95 # though we do need to rematerialize SetVLI ops right
97 if inp
in setvli_outputs
:
98 setvl
= self
.append_new_op(
100 immediates
=setvli_outputs
[inp
].immediates
,
101 name
=f
"{op.name}.inp{i}.setvl")
102 inp
= setvl
.outputs
[0]
103 op
.input_vals
[i
] = inp
105 assert_never(inp
.ty
.base_ty
)
107 for i
, out
in enumerate(op
.outputs
):
108 if op
.kind
is OpKind
.SetVLI
:
109 setvli_outputs
[out
] = op
110 if out
.ty
.base_ty
is BaseTy
.I64
:
111 maxvl
= out
.ty
.reg_len
112 if out
.ty
.reg_len
!= 1:
113 setvl
= self
.append_new_op(
114 OpKind
.SetVLI
, immediates
=[maxvl
],
115 name
=f
"{op.name}.out{i}.setvl")
116 vl
= setvl
.outputs
[0]
117 mv
= self
.append_new_op(
118 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
119 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
121 mv
= self
.append_new_op(
122 OpKind
.CopyFromReg
, input_vals
=[out
],
123 name
=f
"{op.name}.out{i}.copy")
124 copied_outputs
[out
] = mv
.outputs
[0]
125 elif out
.ty
.base_ty
is BaseTy
.CA \
126 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
127 # all copies would be no-ops, so we don't need to copy
128 copied_outputs
[out
] = out
130 assert_never(out
.ty
.base_ty
)
137 value
: Literal
[0, 1] # type: ignore
139 def __new__(cls
, value
):
140 # type: (int) -> OpStage
142 if value
not in (0, 1):
143 raise ValueError("invalid value")
144 retval
= object.__new
__(cls
)
145 retval
._value
_ = value
149 """ early stage of Op execution, where all input reads occur.
150 all output writes with `write_stage == Early` occur here too, and therefore
151 conflict with input reads, telling the compiler that it can't share
152 that output's register with any inputs that the output isn't tied to.
154 All outputs, even unused outputs, can't share registers with any other
155 outputs, independent of `write_stage` settings.
158 """ late stage of Op execution, where all output writes with
159 `write_stage == Late` occur, and therefore don't conflict with input reads,
160 telling the compiler that any inputs can safely use the same register as
163 All outputs, even unused outputs, can't share registers with any other
164 outputs, independent of `write_stage` settings.
169 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """ order stages by their integer `value`.

    Returns `NotImplemented` for anything that isn't an `OpStage`, so
    Python can try the reflected comparison.
    """
    if not isinstance(other, OpStage):
        return NotImplemented
    return self.value < other.value
178 assert OpStage
.Early
< OpStage
.Late
, "early must be less than late"
181 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
184 class ProgramPoint(metaclass
=InternedMeta
):
185 __slots__
= "op_index", "stage"
187 def __init__(self
, op_index
, stage
):
188 # type: (int, OpStage) -> None
189 self
.op_index
= op_index
195 """ an integer representation of `self` such that it keeps ordering and
196 successor/predecessor relations.
198 return self
.op_index
* 2 + self
.stage
.value
def from_int_value(int_value):
    # type: (int) -> ProgramPoint
    """ build the `ProgramPoint` for an integer encoded as
    `op_index * 2 + stage.value`.
    """
    # floor-division / modulo by 2 split the integer back into its parts
    return ProgramPoint(op_index=int_value // 2,
                        stage=OpStage(int_value % 2))
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """ return the program point `steps` positions after `self`;
    `steps` is simply added to `self.int_value`.
    """
    target_int_value = self.int_value + steps
    return ProgramPoint.from_int_value(target_int_value)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """ return the program point `steps` positions before `self`.

    Implemented as `self.next` with the step count negated.
    """
    backwards = -steps
    return self.next(steps=backwards)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """ order program points by `op_index` first, then by `stage`.

    Returns `NotImplemented` when `other` isn't a `ProgramPoint`.
    """
    if isinstance(other, ProgramPoint):
        if self.op_index == other.op_index:
            # same Op: the stage decides
            return self.stage < other.stage
        return self.op_index < other.op_index
    return NotImplemented
224 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
227 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
229 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
230 __slots__
= "start", "stop"
232 def __init__(self
, start
, stop
):
233 # type: (ProgramPoint, ProgramPoint) -> None
238 def int_value_range(self
):
240 return range(self
.start
.int_value
, self
.stop
.int_value
)
243 def from_int_value_range(int_value_range
):
244 # type: (range) -> ProgramRange
245 if int_value_range
.step
!= 1:
246 raise ValueError("int_value_range must have step == 1")
248 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
249 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
252 def __getitem__(self
, __idx
):
253 # type: (int) -> ProgramPoint
257 def __getitem__(self
, __idx
):
258 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """ index or slice this range.

    An integer index gives a `ProgramPoint`; a slice gives a
    `ProgramRange` built from the sliced integer `range`.
    """
    picked = range(self.start.int_value, self.stop.int_value)[__idx]
    if not isinstance(picked, int):
        # slicing a range yields another range
        return ProgramRange.from_int_value_range(picked)
    return ProgramPoint.from_int_value(picked)
270 return len(self
.int_value_range
)
273 # type: () -> Iterator[ProgramPoint]
274 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
278 start
= repr(self
.start
).lstrip("<").rstrip(">")
279 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
280 return f
"<range:{start}..{stop}>"
283 @plain_data(frozen
=True, eq
=False, repr=False)
286 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
287 "def_program_ranges", "use_program_points",
288 "all_program_points")
290 def __init__(self
, fn
):
293 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
294 self
.all_program_points
= ProgramRange(
295 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
296 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
297 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
298 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
299 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
300 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
302 for use
in op
.input_uses
:
303 uses
[use
.ssa_val
].add(use
)
304 use_program_point
= self
.__get
_use
_program
_point
(use
)
305 use_program_points
[use
] = use_program_point
306 live_range_stops
[use
.ssa_val
] = max(
307 live_range_stops
[use
.ssa_val
], use_program_point
.next())
308 for out
in op
.outputs
:
310 def_program_range
= self
.__get
_def
_program
_range
(out
)
311 def_program_ranges
[out
] = def_program_range
312 live_range_stops
[out
] = def_program_range
.stop
313 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
314 self
.def_program_ranges
= FMap(def_program_ranges
)
315 self
.use_program_points
= FMap(use_program_points
)
316 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
317 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
318 for ssa_val
in uses
.keys():
319 live_ranges
[ssa_val
] = live_range
= ProgramRange(
320 start
=self
.def_program_ranges
[ssa_val
].start
,
321 stop
=live_range_stops
[ssa_val
])
322 for program_point
in live_range
:
323 live_at
[program_point
].add(ssa_val
)
324 self
.live_ranges
= FMap(live_ranges
)
325 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """ return the range of program points covered by the definition of
    `ssa_val`: from its write stage in the defining Op up to and including
    that Op's late stage.
    """
    stage_written = ssa_val.defining_descriptor.write_stage
    first = ProgramPoint(
        op_index=self.op_indexes[ssa_val.op], stage=stage_written)
    # always include late stage of ssa_val.op, to ensure outputs always
    # overlap all other outputs.
    last_included = ProgramPoint(op_index=first.op_index, stage=OpStage.Late)
    # stop is exclusive, so we need the next program point.
    return ProgramRange(start=first, stop=last_included.next())
338 def __get_use_program_point(self
, ssa_use
):
339 # type: (SSAUse) -> ProgramPoint
340 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
341 "assumed here, ensured by GenericOpProperties.__init__"
343 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """ two analyses are equal when their `fn` attributes are equal;
    returns `NotImplemented` for objects that aren't a `FnAnalysis`.
    """
    if not isinstance(other, FnAnalysis):
        return NotImplemented
    return self.fn == other.fn
357 return "<FnAnalysis>"
365 VL_MAXVL
= enum
.auto()
368 def only_scalar(self
):
370 if self
is BaseTy
.I64
:
372 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
378 def max_reg_len(self
):
380 if self
is BaseTy
.I64
:
382 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
388 return "BaseTy." + self
._name
_
391 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
393 class Ty(metaclass
=InternedMeta
):
394 __slots__
= "base_ty", "reg_len"
397 def validate(base_ty
, reg_len
):
398 # type: (BaseTy, int) -> str | None
399 """ return a string with the error if the combination is invalid,
400 otherwise return None
402 if base_ty
.only_scalar
and reg_len
!= 1:
403 return f
"can't create a vector of an only-scalar type: {base_ty}"
404 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
405 return "reg_len out of range"
408 def __init__(self
, base_ty
, reg_len
):
409 # type: (BaseTy, int) -> None
410 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
412 raise ValueError(msg
)
413 self
.base_ty
= base_ty
414 self
.reg_len
= reg_len
418 if self
.reg_len
!= 1:
419 reg_len
= f
"*{self.reg_len}"
422 return f
"<{self.base_ty._name_}{reg_len}>"
429 StackI64
= enum
.auto()
431 VL_MAXVL
= enum
.auto()
436 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
438 if self
is LocKind
.CA
:
440 if self
is LocKind
.VL_MAXVL
:
441 return BaseTy
.VL_MAXVL
448 if self
is LocKind
.StackI64
:
450 if self
is LocKind
.GPR
or self
is LocKind
.CA \
451 or self
is LocKind
.VL_MAXVL
:
452 return self
.base_ty
.max_reg_len
457 return "LocKind." + self
._name
_
462 class LocSubKind(Enum
):
463 BASE_GPR
= enum
.auto()
464 SV_EXTRA2_VGPR
= enum
.auto()
465 SV_EXTRA2_SGPR
= enum
.auto()
466 SV_EXTRA3_VGPR
= enum
.auto()
467 SV_EXTRA3_SGPR
= enum
.auto()
468 StackI64
= enum
.auto()
470 VL_MAXVL
= enum
.auto()
474 # type: () -> LocKind
475 # pyright fails typechecking when using `in` here:
476 # reported: https://github.com/microsoft/pyright/issues/4102
477 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
478 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
479 LocSubKind
.SV_EXTRA3_SGPR
):
481 if self
is LocSubKind
.StackI64
:
482 return LocKind
.StackI64
483 if self
is LocSubKind
.CA
:
485 if self
is LocSubKind
.VL_MAXVL
:
486 return LocKind
.VL_MAXVL
491 return self
.kind
.base_ty
494 def allocatable_locs(self
, ty
):
495 # type: (Ty) -> LocSet
496 if ty
.base_ty
!= self
.base_ty
:
497 raise ValueError("type mismatch")
498 if self
is LocSubKind
.BASE_GPR
:
500 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
501 starts
= range(0, 128, 2)
502 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
504 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
505 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
507 elif self
is LocSubKind
.StackI64
:
508 starts
= range(LocKind
.StackI64
.loc_count
)
509 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
510 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
513 retval
= [] # type: list[Loc]
515 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
519 for special_loc
in SPECIAL_GPRS
:
520 if loc
.conflicts(special_loc
):
525 return LocSet(retval
)
528 return "LocSubKind." + self
._name
_
531 @plain_data(frozen
=True, unsafe_hash
=True)
533 class GenericTy(metaclass
=InternedMeta
):
534 __slots__
= "base_ty", "is_vec"
536 def __init__(self
, base_ty
, is_vec
):
537 # type: (BaseTy, bool) -> None
538 self
.base_ty
= base_ty
539 if base_ty
.only_scalar
and is_vec
:
540 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
543 def instantiate(self
, maxvl
):
545 # here's where subvl and elwid would be accounted for
547 return Ty(self
.base_ty
, maxvl
)
548 return Ty(self
.base_ty
, 1)
550 def can_instantiate_to(self
, ty
):
552 if self
.base_ty
!= ty
.base_ty
:
556 return ty
.reg_len
== 1
559 @plain_data(frozen
=True, unsafe_hash
=True)
561 class Loc(metaclass
=InternedMeta
):
562 __slots__
= "kind", "start", "reg_len"
565 def validate(kind
, start
, reg_len
):
566 # type: (LocKind, int, int) -> str | None
567 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
570 if reg_len
> kind
.loc_count
:
571 return "invalid reg_len"
572 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
573 return "start not in valid range"
577 def try_make(kind
, start
, reg_len
):
578 # type: (LocKind, int, int) -> Loc | None
579 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
582 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
584 def __init__(self
, kind
, start
, reg_len
):
585 # type: (LocKind, int, int) -> None
586 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
588 raise ValueError(msg
)
590 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """ return True when `self` and `other` have the same kind and their
    half-open `[start, stop)` intervals overlap.
    """
    if not (self.kind == other.kind):
        return False
    # two half-open intervals overlap iff each starts before the other stops
    return self.start < other.stop and other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """ build the `Ty` with `kind`'s base type and the given `reg_len`. """
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
606 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
611 return self
.start
+ self
.reg_len
613 def try_concat(self
, *others
):
614 # type: (*Loc | None) -> Loc | None
615 reg_len
= self
.reg_len
618 if other
is None or other
.kind
!= self
.kind
:
620 if stop
!= other
.start
:
623 reg_len
+= other
.reg_len
624 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
def get_subloc_at_offset(self, subloc_ty, offset):
    # type: (Ty, int) -> Loc
    """ return the `Loc` of type `subloc_ty` located `offset` registers
    into `self`.

    Raises `ValueError` if the base types differ or if the requested
    sub-Loc doesn't fit inside `self`.
    """
    if subloc_ty.base_ty != self.kind.base_ty:
        raise ValueError("BaseTy mismatch")
    sub_len = subloc_ty.reg_len
    fits = 0 <= offset and offset + sub_len <= self.reg_len
    if not fits:
        raise ValueError("invalid sub-Loc: offset and/or "
                         "subloc_ty.reg_len out of range")
    return Loc(kind=self.kind, start=self.start + offset, reg_len=sub_len)
638 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
639 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
640 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
641 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
646 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
647 def __init__(self
, __locs
=()):
648 # type: (Iterable[Loc]) -> None
649 super().__init
__(__locs
)
650 if isinstance(__locs
, LocSet
):
651 self
.__starts
= __locs
.starts
652 self
.__ty
= __locs
.ty
654 starts
= {i
: BitSet() for i
in LocKind
}
655 ty
= None # type: None | Ty
660 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
661 starts
[loc
.kind
].add(loc
.start
)
662 self
.__starts
= FMap(
663 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
668 # type: () -> FMap[LocKind, FBitSet]
673 # type: () -> Ty | None
678 # type: () -> FMap[LocKind, FBitSet]
683 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
687 # type: () -> AbstractSet[LocKind]
688 return self
.starts
.keys()
692 # type: () -> int | None
695 return self
.ty
.reg_len
699 # type: () -> BaseTy | None
702 return self
.ty
.base_ty
704 def concat(self
, *others
):
705 # type: (*LocSet) -> LocSet
708 base_ty
= self
.ty
.base_ty
709 reg_len
= self
.ty
.reg_len
710 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
714 if other
.ty
.base_ty
!= base_ty
:
716 for kind
, other_starts
in other
.starts
.items():
717 if kind
not in starts
:
719 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
720 if starts
[kind
] == 0:
724 reg_len
+= other
.ty
.reg_len
727 # type: () -> Iterable[Loc]
728 for kind
, v
in starts
.items():
730 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
733 return LocSet(locs())
735 @lru_cache(maxsize
=None, typed
=True)
736 def max_conflicts_with(self
, other
):
737 # type: (LocSet | Loc) -> int
738 """the largest number of Locs in `self` that a single Loc
739 from `other` can conflict with
741 if isinstance(other
, LocSet
):
742 return max(self
.max_conflicts_with(i
) for i
in other
)
744 return sum(other
.conflicts(i
) for i
in self
)
747 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
750 @plain_data(frozen
=True, unsafe_hash
=True)
752 class GenericOperandDesc(metaclass
=InternedMeta
):
753 """generic Op operand descriptor"""
754 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
758 self
, ty
, # type: GenericTy
759 sub_kinds
, # type: Iterable[LocSubKind]
761 fixed_loc
=None, # type: Loc | None
762 tied_input_index
=None, # type: int | None
763 spread
=False, # type: bool
764 write_stage
=OpStage
.Early
, # type: OpStage
766 # type: (...) -> None
768 self
.sub_kinds
= OFSet(sub_kinds
)
769 if len(self
.sub_kinds
) == 0:
770 raise ValueError("sub_kinds can't be empty")
771 self
.fixed_loc
= fixed_loc
772 if fixed_loc
is not None:
773 if tied_input_index
is not None:
774 raise ValueError("operand can't be both tied and fixed")
775 if not ty
.can_instantiate_to(fixed_loc
.ty
):
777 f
"fixed_loc has incompatible type for given generic "
778 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
779 if len(self
.sub_kinds
) != 1:
781 "multiple sub_kinds not allowed for fixed operand")
782 for sub_kind
in self
.sub_kinds
:
783 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
785 f
"fixed_loc not in given sub_kind: "
786 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
787 for sub_kind
in self
.sub_kinds
:
788 if sub_kind
.base_ty
!= ty
.base_ty
:
789 raise ValueError(f
"sub_kind is incompatible with type: "
790 f
"sub_kind={sub_kind} ty={ty}")
791 if tied_input_index
is not None and tied_input_index
< 0:
792 raise ValueError("invalid tied_input_index")
793 self
.tied_input_index
= tied_input_index
796 if self
.tied_input_index
is not None:
797 raise ValueError("operand can't be both spread and tied")
798 if self
.fixed_loc
is not None:
799 raise ValueError("operand can't be both spread and fixed")
801 raise ValueError("operand can't be both spread and vector")
802 self
.write_stage
= write_stage
805 def ty_before_spread(self
):
806 # type: () -> GenericTy
808 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """ return a descriptor with the same `ty`, `sub_kinds` and
    `write_stage` as `self`, tied to the input at `tied_input_index`.
    """
    tied_desc = GenericOperandDesc(
        self.ty, self.sub_kinds,
        tied_input_index=tied_input_index,
        write_stage=self.write_stage)
    return tied_desc
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """ return a descriptor with the same `ty`, `sub_kinds` and
    `write_stage` as `self`, pinned to `fixed_loc`.
    """
    fixed_desc = GenericOperandDesc(
        self.ty, self.sub_kinds,
        fixed_loc=fixed_loc,
        write_stage=self.write_stage)
    return fixed_desc
822 def with_write_stage(self
, write_stage
):
823 # type: (OpStage) -> Self
824 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
825 fixed_loc
=self
.fixed_loc
,
826 tied_input_index
=self
.tied_input_index
,
828 write_stage
=write_stage
)
830 def instantiate(self
, maxvl
):
831 # type: (int) -> Iterable[OperandDesc]
832 # assumes all spread operands have ty.reg_len = 1
836 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
838 def locs_before_spread():
839 # type: () -> Iterable[Loc]
840 if self
.fixed_loc
is not None:
841 if ty_before_spread
!= self
.fixed_loc
.ty
:
843 f
"instantiation failed: type mismatch with fixed_loc: "
844 f
"instantiated type: {ty_before_spread} "
845 f
"fixed_loc: {self.fixed_loc}")
848 for sub_kind
in self
.sub_kinds
:
849 yield from sub_kind
.allocatable_locs(ty_before_spread
)
850 loc_set_before_spread
= LocSet(locs_before_spread())
851 for idx
in range(rep_count
):
854 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
855 tied_input_index
=self
.tied_input_index
,
856 spread_index
=idx
, write_stage
=self
.write_stage
)
859 @plain_data(frozen
=True, unsafe_hash
=True)
861 class OperandDesc(metaclass
=InternedMeta
):
862 """Op operand descriptor"""
863 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
866 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
868 # type: (LocSet, int | None, int | None, OpStage) -> None
869 if len(loc_set_before_spread
) == 0:
870 raise ValueError("loc_set_before_spread must not be empty")
871 self
.loc_set_before_spread
= loc_set_before_spread
872 self
.tied_input_index
= tied_input_index
873 if self
.tied_input_index
is not None and spread_index
is not None:
874 raise ValueError("operand can't be both spread and tied")
875 self
.spread_index
= spread_index
876 self
.write_stage
= write_stage
879 def ty_before_spread(self
):
881 ty
= self
.loc_set_before_spread
.ty
882 assert ty
is not None, (
883 "__init__ checked that the LocSet isn't empty, "
884 "non-empty LocSets should always have ty set")
889 """ Ty after any spread is applied """
890 if self
.spread_index
is not None:
891 # assumes all spread operands have ty.reg_len = 1
892 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
893 return self
.ty_before_spread
896 def reg_offset_in_unspread(self
):
897 """ the number of reg-sized slots in the unspread Loc before self's Loc
899 e.g. if the unspread Loc containing self is:
900 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
901 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
902 then reg_offset_in_unspread == 2 == 10 - 8
904 if self
.spread_index
is None:
906 return self
.spread_index
* self
.ty
.reg_len
# Shared generic operand descriptors used when declaring Op kinds.
# Each one pairs a GenericTy with the single LocSubKind it may be
# allocated from.

# non-vector I64 operand in LocSubKind.BASE_GPR
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# non-vector I64 operand in LocSubKind.SV_EXTRA3_SGPR
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector I64 operand in LocSubKind.SV_EXTRA3_VGPR
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# non-vector I64 operand in LocSubKind.SV_EXTRA2_SGPR
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector I64 operand in LocSubKind.SV_EXTRA2_VGPR
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# non-vector CA operand in LocSubKind.CA
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# non-vector VL_MAXVL operand in LocSubKind.VL_MAXVL
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
932 @plain_data(frozen
=True, unsafe_hash
=True)
934 class GenericOpProperties(metaclass
=InternedMeta
):
935 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
936 "is_copy", "is_load_immediate", "has_side_effects")
939 self
, demo_asm
, # type: str
940 inputs
, # type: Iterable[GenericOperandDesc]
941 outputs
, # type: Iterable[GenericOperandDesc]
942 immediates
=(), # type: Iterable[range]
943 is_copy
=False, # type: bool
944 is_load_immediate
=False, # type: bool
945 has_side_effects
=False, # type: bool
947 # type: (...) -> None
948 self
.demo_asm
= demo_asm
# type: str
949 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
950 for inp
in self
.inputs
:
951 if inp
.tied_input_index
is not None:
953 f
"tied_input_index is not allowed on inputs: {inp}")
954 if inp
.write_stage
is not OpStage
.Early
:
956 f
"write_stage is not allowed on inputs: {inp}")
957 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
958 fixed_locs
= [] # type: list[tuple[Loc, int]]
959 for idx
, out
in enumerate(self
.outputs
):
960 if out
.tied_input_index
is not None:
961 if out
.tied_input_index
>= len(self
.inputs
):
962 raise ValueError(f
"tied_input_index out of range: {out}")
963 tied_inp
= self
.inputs
[out
.tied_input_index
]
964 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
965 .with_write_stage(out
.write_stage
)
966 if expected_out
!= out
:
967 raise ValueError(f
"output can't be tied to non-equivalent "
968 f
"input: {out} tied to {tied_inp}")
969 if out
.fixed_loc
is not None:
970 for other_fixed_loc
, other_idx
in fixed_locs
:
971 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
974 f
"conflicting fixed_locs: outputs[{idx}] and "
975 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
976 f
"with {other_fixed_loc}")
977 fixed_locs
.append((out
.fixed_loc
, idx
))
978 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
979 self
.is_copy
= is_copy
# type: bool
980 self
.is_load_immediate
= is_load_immediate
# type: bool
981 self
.has_side_effects
= has_side_effects
# type: bool
984 @plain_data(frozen
=True, unsafe_hash
=True)
986 class OpProperties(metaclass
=InternedMeta
):
987 __slots__
= "kind", "inputs", "outputs", "maxvl"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """ instantiate the generic properties of `kind` for the given `maxvl`.

    Every generic input/output descriptor is expanded through its
    `instantiate(maxvl=...)` and the results are flattened, in order, into
    the `inputs` and `outputs` tuples.
    """
    # `kind` must be stored first: `self.generic` is derived from it
    self.kind = kind  # type: OpKind
    self.inputs = tuple(
        desc
        for generic_inp in self.generic.inputs
        for desc in generic_inp.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.outputs = tuple(
        desc
        for generic_out in self.generic.outputs
        for desc in generic_out.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
1004 # type: () -> GenericOpProperties
1005 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """ the allowed ranges of the immediates, forwarded unchanged from
    the generic properties.
    """
    generic_props = self.generic
    return generic_props.immediates
1015 return self
.generic
.demo_asm
1020 return self
.generic
.is_copy
1023 def is_load_immediate(self
):
1025 return self
.generic
.is_load_immediate
1028 def has_side_effects(self
):
1030 return self
.generic
.has_side_effects
1033 IMM_S16
= range(-1 << 15, 1 << 15)
1035 _SIM_FN
= Callable
[["Op", "BaseSimState"], None]
1036 _SIM_FN2
= Callable
[[], _SIM_FN
]
1037 _SIM_FNS
= {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
1038 _GEN_ASM_FN
= Callable
[["Op", "GenAsmState"], None]
1039 _GEN_ASM_FN2
= Callable
[[], _GEN_ASM_FN
]
1040 _GEN_ASMS
= {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1046 def __init__(self
, properties
):
1047 # type: (GenericOpProperties) -> None
1049 self
.__properties
= properties
def properties(self):
    # type: () -> GenericOpProperties
    """ the `GenericOpProperties` stored for this kind by `__init__`. """
    stored = self.__properties
    return stored
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """ return the `OpProperties` of this kind for the given `maxvl`. """
    instantiated = OpProperties(self, maxvl=maxvl)
    return instantiated
1062 return "OpKind." + self
._name
_
1066 # type: () -> _SIM_FN
1067 return _SIM_FNS
[self
.properties
]()
1071 # type: () -> _GEN_ASM_FN
1072 return _GEN_ASMS
[self
.properties
]()
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """ simulate ClearCA: store the 1-tuple `(False,)` for the Op's
    only output.
    """
    state[op.outputs[0]] = (False,)
def __clearca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """ emit the assembly for ClearCA; `op` is not consulted. """
    asm_line = "addic 0, 0, 0"
    state.writeln(asm_line)
1083 ClearCA
= GenericOpProperties(
1084 demo_asm
="addic 0, 0, 0",
1086 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1088 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1089 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """ simulate SetCA: store the 1-tuple `(True,)` for the Op's
    only output.
    """
    state[op.outputs[0]] = (True,)
def __setca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """ emit the assembly for SetCA; `op` is not consulted. """
    asm_line = "subfc 0, 0, 0"
    state.writeln(asm_line)
1100 SetCA
= GenericOpProperties(
1101 demo_asm
="subfc 0, 0, 0",
1103 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1105 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1106 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1109 def __svadde_sim(op
, state
):
1110 # type: (Op, BaseSimState) -> None
1111 RA
= state
[op
.input_vals
[0]]
1112 RB
= state
[op
.input_vals
[1]]
1113 carry
, = state
[op
.input_vals
[2]]
1114 VL
, = state
[op
.input_vals
[3]]
1115 RT
= [] # type: list[int]
1117 v
= RA
[i
] + RB
[i
] + carry
1118 RT
.append(v
& GPR_VALUE_MASK
)
1119 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1120 state
[op
.outputs
[0]] = tuple(RT
)
1121 state
[op
.outputs
[1]] = carry
,
def __svadde_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """ emit `sv.adde` for `op`: output 0 is RT, inputs 0 and 1 are
    RA and RB, all formatted through `state.vgpr`.
    """
    RT = state.vgpr(op.outputs[0])
    sources = op.input_vals
    RA = state.vgpr(sources[0])
    RB = state.vgpr(sources[1])
    state.writeln(f"sv.adde {RT}, {RA}, {RB}")
1130 SvAddE
= GenericOpProperties(
1131 demo_asm
="sv.adde *RT, *RA, *RB",
1132 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1133 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1135 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1136 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
1139 def __addze_sim(op
, state
):
1140 # type: (Op, BaseSimState) -> None
1141 RA
, = state
[op
.input_vals
[0]]
1142 carry
, = state
[op
.input_vals
[1]]
1144 RT
= v
& GPR_VALUE_MASK
1145 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1146 state
[op
.outputs
[0]] = RT
,
1147 state
[op
.outputs
[1]] = carry
,
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """ emit `addze RT, RA` for `op`: output 0 is RT, input 0 is RA.

    AddZE declares both operands as `OD_BASE_SGPR` (scalar registers) and
    its demo_asm is the scalar form `addze RT, RA`, so the registers are
    formatted with `state.sgpr`, as the scalar operands of `sv.maddedu`
    are, rather than with the vector formatter `state.vgpr`.
    """
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")
1155 AddZE
= GenericOpProperties(
1156 demo_asm
="addze RT, RA",
1157 inputs
=[OD_BASE_SGPR
, OD_CA
],
1158 outputs
=[OD_BASE_SGPR
, OD_CA
.tied_to_input(1)],
1160 _SIM_FNS
[AddZE
] = lambda: OpKind
.__addze
_sim
1161 _GEN_ASMS
[AddZE
] = lambda: OpKind
.__addze
_gen
_asm
1164 def __svsubfe_sim(op
, state
):
1165 # type: (Op, BaseSimState) -> None
1166 RA
= state
[op
.input_vals
[0]]
1167 RB
= state
[op
.input_vals
[1]]
1168 carry
, = state
[op
.input_vals
[2]]
1169 VL
, = state
[op
.input_vals
[3]]
1170 RT
= [] # type: list[int]
1172 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
1173 RT
.append(v
& GPR_VALUE_MASK
)
1174 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1175 state
[op
.outputs
[0]] = tuple(RT
)
1176 state
[op
.outputs
[1]] = carry
,
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """ emit `sv.subfe` for `op`: output 0 is RT, inputs 0 and 1 are
    RA and RB, all formatted through `state.vgpr`.
    """
    RT = state.vgpr(op.outputs[0])
    sources = op.input_vals
    RA = state.vgpr(sources[0])
    RB = state.vgpr(sources[1])
    state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
1185 SvSubFE
= GenericOpProperties(
1186 demo_asm
="sv.subfe *RT, *RA, *RB",
1187 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1188 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1190 _SIM_FNS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_sim
1191 _GEN_ASMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_gen
_asm
1194 def __svmaddedu_sim(op
, state
):
1195 # type: (Op, BaseSimState) -> None
1196 RA
= state
[op
.input_vals
[0]]
1197 RB
, = state
[op
.input_vals
[1]]
1198 carry
, = state
[op
.input_vals
[2]]
1199 VL
, = state
[op
.input_vals
[3]]
1200 RT
= [] # type: list[int]
1202 v
= RA
[i
] * RB
+ carry
1203 RT
.append(v
& GPR_VALUE_MASK
)
1204 carry
= v
>> GPR_SIZE_IN_BITS
1205 state
[op
.outputs
[0]] = tuple(RT
)
1206 state
[op
.outputs
[1]] = carry
,
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.maddedu` assembly line for `op` to `state`.

    The output and input 0 are rendered as vector GPRs, inputs 1 and 2 as
    scalar GPRs; any later input of `op` does not appear in the text.
    """
    vec_dest = state.vgpr(op.outputs[0])
    vec_src = state.vgpr(op.input_vals[0])
    multiplier = state.sgpr(op.input_vals[1])
    carry_reg = state.sgpr(op.input_vals[2])
    state.writeln("sv.maddedu {}, {}, {}, {}".format(
        vec_dest, vec_src, multiplier, carry_reg))
1216 SvMAddEDU
= GenericOpProperties(
1217 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1218 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1219 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1221 _SIM_FNS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_sim
1222 _GEN_ASMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_gen
_asm
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `op` by storing its first immediate, wrapped in a 1-tuple,
    as the value of its first output.
    """
    new_vl = op.immediates[0]
    state[op.outputs[0]] = (new_vl,)
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write a `setvl` assembly line whose third operand is the first
    immediate of `op`; all other operands are fixed literals.
    """
    state.writeln("setvl 0, 0, {}, 0, 1, 1".format(op.immediates[0]))
1234 SetVLI
= GenericOpProperties(
1235 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1237 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1238 immediates
=[range(1, 65)],
1239 is_load_immediate
=True,
1241 _SIM_FNS
[SetVLI
] = lambda: OpKind
.__setvli
_sim
1242 _GEN_ASMS
[SetVLI
] = lambda: OpKind
.__setvli
_gen
_asm
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a vector load-immediate.

    Reads the vector length from input 0, masks the first immediate with
    `GPR_VALUE_MASK`, and stores a tuple holding that masked value repeated
    vector-length times as the value of output 0.
    """
    (vec_len,) = state[op.input_vals[0]]
    masked = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = tuple([masked] * vec_len)
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.addi` line that loads the first immediate of `op` into
    its output, which is rendered as a vector GPR.
    """
    dest = state.vgpr(op.outputs[0])
    state.writeln("sv.addi {}, 0, {}".format(dest, op.immediates[0]))
1257 SvLI
= GenericOpProperties(
1258 demo_asm
="sv.addi *RT, 0, imm",
1260 outputs
=[OD_EXTRA3_VGPR
],
1261 immediates
=[IMM_S16
],
1262 is_load_immediate
=True,
1264 _SIM_FNS
[SvLI
] = lambda: OpKind
.__svli
_sim
1265 _GEN_ASMS
[SvLI
] = lambda: OpKind
.__svli
_gen
_asm
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a scalar load-immediate: store the first immediate of `op`,
    masked with `GPR_VALUE_MASK`, as a 1-tuple in output 0.
    """
    masked = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (masked,)
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `addi` line that loads the first immediate of `op` into
    its output, which is rendered as a scalar GPR.
    """
    dest = state.sgpr(op.outputs[0])
    state.writeln("addi {}, 0, {}".format(dest, op.immediates[0]))
1279 LI
= GenericOpProperties(
1280 demo_asm
="addi RT, 0, imm",
1282 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1283 immediates
=[IMM_S16
],
1284 is_load_immediate
=True,
1286 _SIM_FNS
[LI
] = lambda: OpKind
.__li
_sim
1287 _GEN_ASMS
[LI
] = lambda: OpKind
.__li
_gen
_asm
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy by assigning the value of input 0 to output 0."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1295 def __copy_to_from_reg_gen_asm(src_loc
, dest_loc
, is_vec
, state
):
1296 # type: (Loc, Loc, bool, GenAsmState) -> None
1297 sv
= "sv." if is_vec
else ""
1299 if src_loc
.conflicts(dest_loc
) and src_loc
.start
< dest_loc
.start
:
1301 if src_loc
== dest_loc
:
1303 if src_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1304 raise ValueError(f
"invalid src_loc.kind: {src_loc.kind}")
1305 if dest_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1306 raise ValueError(f
"invalid dest_loc.kind: {dest_loc.kind}")
1307 if src_loc
.kind
is LocKind
.StackI64
:
1308 if dest_loc
.kind
is LocKind
.StackI64
:
1310 f
"can't copy from stack to stack: {src_loc} {dest_loc}")
1311 elif dest_loc
.kind
is not LocKind
.GPR
:
1312 assert_never(dest_loc
.kind
)
1313 src
= state
.stack(src_loc
)
1314 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1315 state
.writeln(f
"{sv}ld {dest}, {src}")
1316 elif dest_loc
.kind
is LocKind
.StackI64
:
1317 if src_loc
.kind
is not LocKind
.GPR
:
1318 assert_never(src_loc
.kind
)
1319 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1320 dest
= state
.stack(dest_loc
)
1321 state
.writeln(f
"{sv}std {src}, {dest}")
1322 elif src_loc
.kind
is LocKind
.GPR
:
1323 if dest_loc
.kind
is not LocKind
.GPR
:
1324 assert_never(dest_loc
.kind
)
1325 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1326 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1327 state
.writeln(f
"{sv}or{rev} {dest}, {src}, {src}")
1329 assert_never(src_loc
.kind
)
1332 def __veccopytoreg_gen_asm(op
, state
):
1333 # type: (Op, GenAsmState) -> None
1334 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1336 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1337 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1338 is_vec
=True, state
=state
)
1340 VecCopyToReg
= GenericOpProperties(
1341 demo_asm
="sv.mv dest, src",
1342 inputs
=[GenericOperandDesc(
1343 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1344 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1346 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1349 _SIM_FNS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_sim
1350 _GEN_ASMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_gen
_asm
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy by assigning the value of input 0 to output 0."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1358 def __veccopyfromreg_gen_asm(op
, state
):
1359 # type: (Op, GenAsmState) -> None
1360 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1361 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1363 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1364 is_vec
=True, state
=state
)
1365 VecCopyFromReg
= GenericOpProperties(
1366 demo_asm
="sv.mv dest, src",
1367 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1368 outputs
=[GenericOperandDesc(
1369 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1370 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1371 write_stage
=OpStage
.Late
,
1375 _SIM_FNS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_sim
1376 _GEN_ASMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_gen
_asm
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy by assigning the value of input 0 to output 0."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1384 def __copytoreg_gen_asm(op
, state
):
1385 # type: (Op, GenAsmState) -> None
1386 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1388 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1389 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1390 is_vec
=False, state
=state
)
1391 CopyToReg
= GenericOpProperties(
1392 demo_asm
="mv dest, src",
1393 inputs
=[GenericOperandDesc(
1394 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1395 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1396 LocSubKind
.StackI64
],
1398 outputs
=[GenericOperandDesc(
1399 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1400 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1401 write_stage
=OpStage
.Late
,
1405 _SIM_FNS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_sim
1406 _GEN_ASMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_gen
_asm
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy by assigning the value of input 0 to output 0."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1414 def __copyfromreg_gen_asm(op
, state
):
1415 # type: (Op, GenAsmState) -> None
1416 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1417 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1419 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1420 is_vec
=False, state
=state
)
1421 CopyFromReg
= GenericOpProperties(
1422 demo_asm
="mv dest, src",
1423 inputs
=[GenericOperandDesc(
1424 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1425 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1427 outputs
=[GenericOperandDesc(
1428 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1429 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1430 LocSubKind
.StackI64
],
1431 write_stage
=OpStage
.Late
,
1435 _SIM_FNS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_sim
1436 _GEN_ASMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_gen
_asm
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a concat: gather element 0 of every input except the last
    one into a tuple and store that tuple as the value of output 0.
    """
    # the final input is left out of the gathered elements
    # (presumably it is the VL operand -- confirm against Concat's inputs)
    gathered = []  # type: list[int]
    for inp in op.input_vals[:-1]:
        gathered.append(state[inp][0])
    state[op.outputs[0]] = tuple(gathered)
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements a concat.

    The source location is the GPR `Loc` that `state.loc` builds from every
    input except the last; the destination is the GPR `Loc` of output 0.
    The actual text is produced by `__copy_to_from_reg_gen_asm`.
    """
    sources = op.input_vals[0:-1]
    combined_src = state.loc(sources, LocKind.GPR)
    dest = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=combined_src, dest_loc=dest, is_vec=True, state=state)
1451 Concat
= GenericOpProperties(
1452 demo_asm
="sv.mv dest, src",
1453 inputs
=[GenericOperandDesc(
1454 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1455 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1458 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1461 _SIM_FNS
[Concat
] = lambda: OpKind
.__concat
_sim
1462 _GEN_ASMS
[Concat
] = lambda: OpKind
.__concat
_gen
_asm
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a spread: element `i` of the value of input 0 becomes the
    1-tuple value of output `i`.
    """
    elements = state[op.input_vals[0]]
    for position, element in enumerate(elements):
        target = op.outputs[position]
        state[target] = (element,)
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements a spread.

    The source is the GPR `Loc` of input 0; the destination is the GPR `Loc`
    that `state.loc` builds from all outputs of `op` taken together.
    The actual text is produced by `__copy_to_from_reg_gen_asm`.
    """
    src = state.loc(op.input_vals[0], LocKind.GPR)
    combined_dest = state.loc(op.outputs, LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src, dest_loc=combined_dest, is_vec=True, state=state)
1477 Spread
= GenericOpProperties(
1478 demo_asm
="sv.mv dest, src",
1479 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1480 outputs
=[GenericOperandDesc(
1481 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1482 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1484 write_stage
=OpStage
.Late
,
1488 _SIM_FNS
[Spread
] = lambda: OpKind
.__spread
_sim
1489 _GEN_ASMS
[Spread
] = lambda: OpKind
.__spread
_gen
_asm
1492 def __svld_sim(op
, state
):
1493 # type: (Op, BaseSimState) -> None
1494 RA
, = state
[op
.input_vals
[0]]
1495 VL
, = state
[op
.input_vals
[1]]
1496 addr
= RA
+ op
.immediates
[0]
1497 RT
= [] # type: list[int]
1499 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
1500 RT
.append(v
& GPR_VALUE_MASK
)
1501 state
[op
.outputs
[0]] = tuple(RT
)
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.ld` line for `op`: a vector destination loaded from
    the first immediate as displacement off the scalar base in input 0.
    """
    base = state.sgpr(op.input_vals[0])
    dest = state.vgpr(op.outputs[0])
    displacement = op.immediates[0]
    state.writeln("sv.ld {}, {}({})".format(dest, displacement, base))
1510 SvLd
= GenericOpProperties(
1511 demo_asm
="sv.ld *RT, imm(RA)",
1512 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1513 outputs
=[OD_EXTRA3_VGPR
],
1514 immediates
=[IMM_S16
],
1516 _SIM_FNS
[SvLd
] = lambda: OpKind
.__svld
_sim
1517 _GEN_ASMS
[SvLd
] = lambda: OpKind
.__svld
_gen
_asm
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a scalar load.

    The address is the value of input 0 plus the first immediate; the value
    returned by `state.load` for it is masked with `GPR_VALUE_MASK` and
    stored as a 1-tuple in output 0.
    """
    (base,) = state[op.input_vals[0]]
    loaded = state.load(base + op.immediates[0])
    state[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the scalar `ld` line for `op`: destination, then the first
    immediate as displacement off the base register in input 0.
    """
    base = state.sgpr(op.input_vals[0])
    dest = state.sgpr(op.outputs[0])
    displacement = op.immediates[0]
    state.writeln("ld {}, {}({})".format(dest, displacement, base))
1534 Ld
= GenericOpProperties(
1535 demo_asm
="ld RT, imm(RA)",
1536 inputs
=[OD_BASE_SGPR
],
1537 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1538 immediates
=[IMM_S16
],
1540 _SIM_FNS
[Ld
] = lambda: OpKind
.__ld
_sim
1541 _GEN_ASMS
[Ld
] = lambda: OpKind
.__ld
_gen
_asm
1544 def __svstd_sim(op
, state
):
1545 # type: (Op, BaseSimState) -> None
1546 RS
= state
[op
.input_vals
[0]]
1547 RA
, = state
[op
.input_vals
[1]]
1548 VL
, = state
[op
.input_vals
[2]]
1549 addr
= RA
+ op
.immediates
[0]
1551 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.std` line for `op`: the vector source in input 0 stored
    at the first immediate as displacement off the scalar base in input 1.
    """
    stored = state.vgpr(op.input_vals[0])
    base = state.sgpr(op.input_vals[1])
    displacement = op.immediates[0]
    state.writeln("sv.std {}, {}({})".format(stored, displacement, base))
1560 SvStd
= GenericOpProperties(
1561 demo_asm
="sv.std *RS, imm(RA)",
1562 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1564 immediates
=[IMM_S16
],
1565 has_side_effects
=True,
1567 _SIM_FNS
[SvStd
] = lambda: OpKind
.__svstd
_sim
1568 _GEN_ASMS
[SvStd
] = lambda: OpKind
.__svstd
_gen
_asm
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a scalar store: pass the value of input 0 to `state.store`
    at the address formed by the value of input 1 plus the first immediate.
    """
    (stored,) = state[op.input_vals[0]]
    (base,) = state[op.input_vals[1]]
    state.store(base + op.immediates[0], value=stored)
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the scalar `std` line for `op`: the source register in input 0
    stored at the first immediate as displacement off the base in input 1.
    """
    stored = state.sgpr(op.input_vals[0])
    base = state.sgpr(op.input_vals[1])
    displacement = op.immediates[0]
    state.writeln("std {}, {}({})".format(stored, displacement, base))
1585 Std
= GenericOpProperties(
1586 demo_asm
="std RS, imm(RA)",
1587 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1589 immediates
=[IMM_S16
],
1590 has_side_effects
=True,
1592 _SIM_FNS
[Std
] = lambda: OpKind
.__std
_sim
1593 _GEN_ASMS
[Std
] = lambda: OpKind
.__std
_gen
_asm
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Do nothing: the return value is set before simulation."""
    return None
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Do nothing: no instructions are needed for this op."""
    return None
1604 FuncArgR3
= GenericOpProperties(
1607 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1608 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1610 _SIM_FNS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_sim
1611 _GEN_ASMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_gen
_asm
1614 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1615 class SSAValOrUse(metaclass
=InternedMeta
):
1616 __slots__
= "op", "operand_idx"
1618 def __init__(self
, op
, operand_idx
):
1619 # type: (Op, int) -> None
1622 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1623 raise ValueError("invalid operand_idx")
1624 self
.operand_idx
= operand_idx
1633 def descriptor_array(self
):
1634 # type: () -> tuple[OperandDesc, ...]
1638 def defining_descriptor(self
):
1639 # type: () -> OperandDesc
1640 return self
.descriptor_array
[self
.operand_idx
]
1645 return self
.defining_descriptor
.ty
1648 def ty_before_spread(self
):
1650 return self
.defining_descriptor
.ty_before_spread
1654 # type: () -> BaseTy
1655 return self
.ty_before_spread
.base_ty
1658 def reg_offset_in_unspread(self
):
1659 """ the number of reg-sized slots in the unspread Loc before self's Loc
1661 e.g. if the unspread Loc containing self is:
1662 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1663 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1664 then reg_offset_into_unspread == 2 == 10 - 8
1666 return self
.defining_descriptor
.reg_offset_in_unspread
1669 def unspread_start_idx(self
):
1671 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1674 def unspread_start(self
):
1676 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1679 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1681 class SSAVal(SSAValOrUse
):
1686 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1689 def def_loc_set_before_spread(self
):
1690 # type: () -> LocSet
1691 return self
.defining_descriptor
.loc_set_before_spread
1694 def descriptor_array(self
):
1695 # type: () -> tuple[OperandDesc, ...]
1696 return self
.op
.properties
.outputs
1699 def tied_input(self
):
1700 # type: () -> None | SSAUse
1701 if self
.defining_descriptor
.tied_input_index
is None:
1703 return SSAUse(op
=self
.op
,
1704 operand_idx
=self
.defining_descriptor
.tied_input_index
)
1707 def write_stage(self
):
1708 # type: () -> OpStage
1709 return self
.defining_descriptor
.write_stage
1712 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1714 class SSAUse(SSAValOrUse
):
1718 def use_loc_set_before_spread(self
):
1719 # type: () -> LocSet
1720 return self
.defining_descriptor
.loc_set_before_spread
1723 def descriptor_array(self
):
1724 # type: () -> tuple[OperandDesc, ...]
1725 return self
.op
.properties
.inputs
1729 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1733 # type: () -> SSAVal
1734 return self
.op
.input_vals
[self
.operand_idx
]
1737 def ssa_val(self
, ssa_val
):
1738 # type: (SSAVal) -> None
1739 self
.op
.input_vals
[self
.operand_idx
] = ssa_val
1743 _Desc
= TypeVar("_Desc")
1746 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1748 def _verify_write_with_desc(self
, idx
, item
, desc
):
1749 # type: (int, _T | Any, _Desc) -> None
1750 raise NotImplementedError
1753 def _verify_write(self
, idx
, item
):
1754 # type: (int | Any, _T | Any) -> int
1755 if not isinstance(idx
, int):
1756 if isinstance(idx
, slice):
1758 f
"can't write to slice of {self.__class__.__name__}")
1759 raise TypeError(f
"can't write with index {idx!r}")
1760 # normalize idx, raising IndexError if it is out of range
1761 idx
= range(len(self
.descriptors
))[idx
]
1762 desc
= self
.descriptors
[idx
]
1763 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1766 def _on_set(self
, idx
, new_item
, old_item
):
1767 # type: (int, _T, _T | None) -> None
1771 def _get_descriptors(self
):
1772 # type: () -> tuple[_Desc, ...]
1773 raise NotImplementedError
1777 def descriptors(self
):
1778 # type: () -> tuple[_Desc, ...]
1779 return self
._get
_descriptors
()
1786 def __init__(self
, items
, op
):
1787 # type: (Iterable[_T], Op) -> None
1790 self
.__items
= [] # type: list[_T]
1791 for idx
, item
in enumerate(items
):
1792 if idx
>= len(self
.descriptors
):
1793 raise ValueError("too many items")
1794 _
= self
._verify
_write
(idx
, item
)
1795 self
.__items
.append(item
)
1796 if len(self
.__items
) < len(self
.descriptors
):
1797 raise ValueError("not enough items")
1801 # type: () -> Iterator[_T]
1802 yield from self
.__items
1805 def __getitem__(self
, idx
):
1810 def __getitem__(self
, idx
):
1811 # type: (slice) -> list[_T]
1815 def __getitem__(self
, idx
):
1816 # type: (int | slice) -> _T | list[_T]
1817 return self
.__items
[idx
]
1820 def __setitem__(self
, idx
, item
):
1821 # type: (int, _T) -> None
1822 idx
= self
._verify
_write
(idx
, item
)
1823 self
.__items
[idx
] = item
1828 return len(self
.__items
)
1832 return f
"{self.__class__.__name__}({self.__items}, op=...)"
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The sequence of `SSAVal`s used as the inputs of one `Op`.

    Each position is checked against the matching entry of
    `op.properties.inputs`: the item must be an `SSAVal` whose `ty` equals
    the descriptor's `ty`.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` stores this sequence in its `input_vals` slot; it has no
        # `inputs` attribute, so checking for "inputs" could never detect an
        # Op that already owns an input sequence. `OpImmediates` performs the
        # same guard against its real attribute name, `immediates`.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
class OpImmediates(OpInputSeq[int, range]):
    """The sequence of integer immediates of one `Op`.

    Each position is checked against the matching `range` in
    `op.properties.immediates`.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
1880 @plain_data(frozen
=True, eq
=False, repr=False)
1883 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
1886 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
1887 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1889 self
.properties
= properties
1890 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
1891 inputs_len
= len(self
.properties
.inputs
)
1892 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
1893 self
.immediates
= OpImmediates(immediates
, op
=self
)
1894 outputs_len
= len(self
.properties
.outputs
)
1895 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1896 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
1900 # type: () -> OpKind
1901 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: an `Op` equals only itself.

    Returns `NotImplemented` when `other` is not an `Op`.
    """
    if not isinstance(other, Op):
        return NotImplemented
    return self is other
1911 return object.__hash
__(self
)
1915 field_vals
= [] # type: list[str]
1916 for name
in fields(self
):
1917 if name
== "properties":
1922 value
= getattr(self
, name
)
1923 except AttributeError:
1924 field_vals
.append(f
"{name}=<not set>")
1926 if isinstance(value
, OpInputSeq
):
1927 value
= list(value
) # type: ignore
1928 field_vals
.append(f
"{name}={value!r}")
1929 field_vals_str
= ", ".join(field_vals
)
1930 return f
"Op({field_vals_str})"
1932 def sim(self
, state
):
1933 # type: (BaseSimState) -> None
1934 for inp
in self
.input_vals
:
1938 raise ValueError(f
"SSAVal {inp} not yet assigned when "
1940 if len(val
) != inp
.ty
.reg_len
:
1942 f
"value of SSAVal {inp} has wrong number of elements: "
1943 f
"expected {inp.ty.reg_len} found "
1944 f
"{len(val)}: {val!r}")
1945 if isinstance(state
, PreRASimState
):
1946 for out
in self
.outputs
:
1947 if out
in state
.ssa_vals
:
1948 if self
.kind
is OpKind
.FuncArgR3
:
1950 raise ValueError(f
"SSAVal {out} already assigned before "
1952 self
.kind
.sim(self
, state
)
1953 for out
in self
.outputs
:
1957 raise ValueError(f
"running {self} failed to assign to {out}")
1958 if len(val
) != out
.ty
.reg_len
:
1960 f
"value of SSAVal {out} has wrong number of elements: "
1961 f
"expected {out.ty.reg_len} found "
1962 f
"{len(val)}: {val!r}")
def gen_asm(self, state):
    # type: (GenAsmState) -> None
    """Generate assembly for this op into `state`.

    First calls `state.loc` on every input and then every output, accepting
    any `LocKind`, and discards the results. Then hands over to
    `self.kind.gen_asm`.
    """
    any_kind = tuple(LocKind)
    for group in (self.input_vals, self.outputs):
        for ssa_val in group:
            # result unused; the call is made only for its checks
            state.loc(ssa_val, expected_kinds=any_kind)
    self.kind.gen_asm(self, state)
1974 GPR_SIZE_IN_BYTES
= 8
1976 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
1977 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
1980 @plain_data(frozen
=True, repr=False)
1981 class BaseSimState(metaclass
=ABCMeta
):
1982 __slots__
= "memory",
1984 def __init__(self
, memory
):
1985 # type: (dict[int, int]) -> None
1987 self
.memory
= memory
# type: dict[int, int]
1989 def load_byte(self
, addr
):
1990 # type: (int) -> int
1991 addr
&= GPR_VALUE_MASK
1992 return self
.memory
.get(addr
, 0) & 0xFF
1994 def store_byte(self
, addr
, value
):
1995 # type: (int, int) -> None
1996 addr
&= GPR_VALUE_MASK
1998 self
.memory
[addr
] = value
2000 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
2001 # type: (int, int, bool) -> int
2002 if addr
% size_in_bytes
!= 0:
2003 raise ValueError(f
"address not aligned: {hex(addr)} "
2004 f
"required alignment: {size_in_bytes}")
2006 for i
in range(size_in_bytes
):
2007 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
2008 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
2009 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
2012 def store(self
, addr
, value
, size_in_bytes
=GPR_SIZE_IN_BYTES
):
2013 # type: (int, int, int) -> None
2014 if addr
% size_in_bytes
!= 0:
2015 raise ValueError(f
"address not aligned: {hex(addr)} "
2016 f
"required alignment: {size_in_bytes}")
2017 for i
in range(size_in_bytes
):
2018 self
.store_byte(addr
+ i
, (value
>> i
* BITS_IN_BYTE
) & 0xFF)
2020 def _memory__repr(self
):
2022 if len(self
.memory
) == 0:
2024 keys
= sorted(self
.memory
.keys(), reverse
=True)
2025 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
2026 items
= [] # type: list[str]
2027 while len(keys
) != 0:
2029 if (len(keys
) >= CHUNK_SIZE
2030 and addr
% CHUNK_SIZE
== 0
2031 and keys
[-CHUNK_SIZE
:]
2032 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
2033 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
2034 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
2035 keys
[-CHUNK_SIZE
:] = ()
2037 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
2039 return f
"{{{items[0]}}}"
2040 items_str
= ",\n".join(items
)
2041 return f
"{{\n{items_str}}}"
2045 field_vals
= [] # type: list[str]
2046 for name
in fields(self
):
2048 value
= getattr(self
, name
)
2049 except AttributeError:
2050 field_vals
.append(f
"{name}=<not set>")
2052 repr_fn
= getattr(self
, f
"_{name}__repr", None)
2053 if callable(repr_fn
):
2054 field_vals
.append(f
"{name}={repr_fn()}")
2056 field_vals
.append(f
"{name}={value!r}")
2057 field_vals_str
= ", ".join(field_vals
)
2058 return f
"{self.__class__.__name__}({field_vals_str})"
2061 def __getitem__(self
, ssa_val
):
2062 # type: (SSAVal) -> tuple[int, ...]
2066 def __setitem__(self
, ssa_val
, value
):
2067 # type: (SSAVal, tuple[int, ...]) -> None
2071 @plain_data(frozen
=True, repr=False)
2073 class PreRASimState(BaseSimState
):
2074 __slots__
= "ssa_vals",
2076 def __init__(self
, ssa_vals
, memory
):
2077 # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
2078 super().__init
__(memory
)
2079 self
.ssa_vals
= ssa_vals
# type: dict[SSAVal, tuple[int, ...]]
2081 def _ssa_vals__repr(self
):
2083 if len(self
.ssa_vals
) == 0:
2085 items
= [] # type: list[str]
2087 for k
, v
in self
.ssa_vals
.items():
2088 element_strs
= [] # type: list[str]
2089 for i
, el
in enumerate(v
):
2090 if i
% CHUNK_SIZE
!= 0:
2091 element_strs
.append(" " + hex(el
))
2093 element_strs
.append("\n " + hex(el
))
2094 if len(element_strs
) <= CHUNK_SIZE
:
2095 element_strs
[0] = element_strs
[0].lstrip()
2096 if len(element_strs
) == 1:
2097 element_strs
.append("")
2098 v_str
= ",".join(element_strs
)
2099 items
.append(f
"{k!r}: ({v_str})")
2100 if len(items
) == 1 and "\n" not in items
[0]:
2101 return f
"{{{items[0]}}}"
2102 items_str
= ",\n".join(items
)
2103 return f
"{{\n{items_str},\n}}"
2105 def __getitem__(self
, ssa_val
):
2106 # type: (SSAVal) -> tuple[int, ...]
2107 return self
.ssa_vals
[ssa_val
]
2109 def __setitem__(self
, ssa_val
, value
):
2110 # type: (SSAVal, tuple[int, ...]) -> None
2111 if len(value
) != ssa_val
.ty
.reg_len
:
2112 raise ValueError("value has wrong len")
2113 self
.ssa_vals
[ssa_val
] = value
2116 @plain_data(frozen
=True, repr=False)
2118 class PostRASimState(BaseSimState
):
2119 __slots__
= "ssa_val_to_loc_map", "loc_values"
2121 def __init__(self
, ssa_val_to_loc_map
, memory
, loc_values
):
2122 # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
2123 super().__init
__(memory
)
2124 self
.ssa_val_to_loc_map
= FMap(ssa_val_to_loc_map
)
2125 for ssa_val
, loc
in self
.ssa_val_to_loc_map
.items():
2126 if ssa_val
.ty
!= loc
.ty
:
2128 f
"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
2129 self
.loc_values
= loc_values
2130 for loc
in self
.loc_values
.keys():
2131 if loc
.reg_len
!= 1:
2133 "loc_values must only contain Locs with reg_len=1, all "
2134 "larger Locs will be split into reg_len=1 sub-Locs")
2136 def _loc_values__repr(self
):
2138 locs
= sorted(self
.loc_values
.keys(), key
=lambda v
: (v
.kind
, v
.start
))
2139 items
= [] # type: list[str]
2141 items
.append(f
"{loc}: 0x{self.loc_values[loc]:x}")
2142 items_str
= ",\n".join(items
)
2143 return f
"{{\n{items_str},\n}}"
2145 def __getitem__(self
, ssa_val
):
2146 # type: (SSAVal) -> tuple[int, ...]
2147 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2148 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2149 retval
= [] # type: list[int]
2150 for i
in range(loc
.reg_len
):
2151 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2152 retval
.append(self
.loc_values
.get(subloc
, 0))
2153 return tuple(retval
)
2155 def __setitem__(self
, ssa_val
, value
):
2156 # type: (SSAVal, tuple[int, ...]) -> None
2157 if len(value
) != ssa_val
.ty
.reg_len
:
2158 raise ValueError("value has wrong len")
2159 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2160 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2161 for i
in range(loc
.reg_len
):
2162 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2163 self
.loc_values
[subloc
] = value
[i
]
2166 @plain_data(frozen
=True)
2168 __slots__
= "allocated_locs", "output"
2170 def __init__(self
, allocated_locs
, output
=None):
2171 # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
2173 self
.allocated_locs
= FMap(allocated_locs
)
2174 for ssa_val
, loc
in self
.allocated_locs
.items():
2175 if ssa_val
.ty
!= loc
.ty
:
2177 f
"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
2180 self
.output
= output
2182 __SSA_VAL_OR_LOCS
= Union
[SSAVal
, Loc
, Sequence
["SSAVal | Loc"]]
2184 def loc(self
, ssa_val_or_locs
, expected_kinds
):
2185 # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
2186 if isinstance(ssa_val_or_locs
, (SSAVal
, Loc
)):
2187 ssa_val_or_locs
= [ssa_val_or_locs
]
2188 locs
= [] # type: list[Loc]
2189 for i
in ssa_val_or_locs
:
2190 if isinstance(i
, SSAVal
):
2191 locs
.append(self
.allocated_locs
[i
])
2195 raise ValueError("invalid Loc sequence: must not be empty")
2196 retval
= locs
[0].try_concat(*locs
[1:])
2198 raise ValueError("invalid Loc sequence: try_concat failed")
2199 if isinstance(expected_kinds
, LocKind
):
2200 expected_kinds
= expected_kinds
,
2201 if retval
.kind
not in expected_kinds
:
2202 if len(expected_kinds
) == 1:
2203 expected_kinds
= expected_kinds
[0]
2204 raise ValueError(f
"LocKind mismatch: {ssa_val_or_locs}: found "
2205 f
"{retval.kind} expected {expected_kinds}")
def gpr(self, ssa_val_or_locs, is_vec):
    # type: (__SSA_VAL_OR_LOCS, bool) -> str
    """Return the assembly operand text for a GPR location.

    The location is resolved with `self.loc`, which is asked for
    `LocKind.GPR`. The text is the location's `start`, prefixed with `*`
    when `is_vec` is true.
    """
    location = self.loc(ssa_val_or_locs, LocKind.GPR)
    prefix = ""
    if is_vec:
        prefix = "*"
    return prefix + str(location.start)
def sgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the scalar GPR operand text: `self.gpr` with `is_vec=False`."""
    operand = self.gpr(ssa_val_or_locs, is_vec=False)
    return operand
def vgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the vector GPR operand text: `self.gpr` with `is_vec=True`."""
    operand = self.gpr(ssa_val_or_locs, is_vec=True)
    return operand
def stack(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the assembly operand text for a stack slot.

    The location is resolved with `self.loc`, which is asked for
    `LocKind.StackI64`. The text is the location's `start` followed by the
    literal `(1)`.
    """
    slot = self.loc(ssa_val_or_locs, LocKind.StackI64)
    return "{}(1)".format(slot.start)
2227 def writeln(self
, *line_segments
):
2228 # type: (*str) -> None
2229 line
= " ".join(line_segments
)
2230 if isinstance(self
.output
, list):
2231 self
.output
.append(line
)
2233 self
.output
.write(line
+ "\n")