2 from abc
import ABCMeta
, abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
, total_ordering
5 from io
import StringIO
6 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
7 Mapping
, Sequence
, TypeVar
, Union
, overload
)
8 from weakref
import WeakValueDictionary
as _WeakVDict
10 from cached_property
import cached_property
11 from nmutil
.plain_data
import fields
, plain_data
13 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
15 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
# width of a general-purpose register, derived from its byte size
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
# mask that truncates an int to a single GPR's worth of bits
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
28 self
.ops
= [] # type: list[Op]
29 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
30 self
.__next
_name
_suffix
= 2
32 def _add_op_with_unused_name(self
, op
, name
=""):
33 # type: (Op, str) -> str
35 raise ValueError("can't add Op to wrong Fn")
36 if hasattr(op
, "name"):
37 raise ValueError("Op already named")
40 if name
!= "" and name
not in self
.__op
_names
:
41 self
.__op
_names
[name
] = op
43 name
= orig_name
+ str(self
.__next
_name
_suffix
)
44 self
.__next
_name
_suffix
+= 1
50 def append_op(self
, op
):
53 raise ValueError("can't add Op to wrong Fn")
56 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
58 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
59 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
60 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
61 self
.append_op(retval
)
65 # type: (BaseSimState) -> None
69 def gen_asm(self
, state
):
70 # type: (GenAsmState) -> None
74 def pre_ra_insert_copies(self
):
76 orig_ops
= list(self
.ops
)
77 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
78 setvli_outputs
= {} # type: dict[SSAVal, Op]
81 for i
in range(len(op
.input_vals
)):
82 inp
= copied_outputs
[op
.input_vals
[i
]]
83 if inp
.ty
.base_ty
is BaseTy
.I64
:
84 maxvl
= inp
.ty
.reg_len
85 if inp
.ty
.reg_len
!= 1:
86 setvl
= self
.append_new_op(
87 OpKind
.SetVLI
, immediates
=[maxvl
],
88 name
=f
"{op.name}.inp{i}.setvl")
90 mv
= self
.append_new_op(
91 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
92 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
94 mv
= self
.append_new_op(
95 OpKind
.CopyToReg
, input_vals
=[inp
],
96 name
=f
"{op.name}.inp{i}.copy")
97 op
.input_vals
[i
] = mv
.outputs
[0]
98 elif inp
.ty
.base_ty
is BaseTy
.CA \
99 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
100 # all copies would be no-ops, so we don't need to copy,
101 # though we do need to rematerialize SetVLI ops right
103 if inp
in setvli_outputs
:
104 setvl
= self
.append_new_op(
106 immediates
=setvli_outputs
[inp
].immediates
,
107 name
=f
"{op.name}.inp{i}.setvl")
108 inp
= setvl
.outputs
[0]
109 op
.input_vals
[i
] = inp
111 assert_never(inp
.ty
.base_ty
)
113 for i
, out
in enumerate(op
.outputs
):
114 if op
.kind
is OpKind
.SetVLI
:
115 setvli_outputs
[out
] = op
116 if out
.ty
.base_ty
is BaseTy
.I64
:
117 maxvl
= out
.ty
.reg_len
118 if out
.ty
.reg_len
!= 1:
119 setvl
= self
.append_new_op(
120 OpKind
.SetVLI
, immediates
=[maxvl
],
121 name
=f
"{op.name}.out{i}.setvl")
122 vl
= setvl
.outputs
[0]
123 mv
= self
.append_new_op(
124 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
125 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
127 mv
= self
.append_new_op(
128 OpKind
.CopyFromReg
, input_vals
=[out
],
129 name
=f
"{op.name}.out{i}.copy")
130 copied_outputs
[out
] = mv
.outputs
[0]
131 elif out
.ty
.base_ty
is BaseTy
.CA \
132 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
133 # all copies would be no-ops, so we don't need to copy
134 copied_outputs
[out
] = out
136 assert_never(out
.ty
.base_ty
)
143 value
: Literal
[0, 1] # type: ignore
145 def __new__(cls
, value
):
146 # type: (int) -> OpStage
148 if value
not in (0, 1):
149 raise ValueError("invalid value")
150 retval
= object.__new
__(cls
)
151 retval
._value
_ = value
155 """ early stage of Op execution, where all input reads occur.
156 all output writes with `write_stage == Early` occur here too, and therefore
157 conflict with input reads, telling the compiler that it that can't share
158 that output's register with any inputs that the output isn't tied to.
160 All outputs, even unused outputs, can't share registers with any other
161 outputs, independent of `write_stage` settings.
164 """ late stage of Op execution, where all output writes with
165 `write_stage == Late` occur, and therefore don't conflict with input reads,
166 telling the compiler that any inputs can safely use the same register as
169 All outputs, even unused outputs, can't share registers with any other
170 outputs, independent of `write_stage` settings.
175 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """Order OpStage values by their integer `value`.

    Returns NotImplemented for non-OpStage operands so Python can try the
    reflected comparison.
    """
    if isinstance(other, OpStage):
        return self.value < other.value
    return NotImplemented
# module-load sanity check: the rest of the file relies on Early sorting
# before Late (e.g. ProgramPoint ordering and int_value encoding)
assert OpStage.Early < OpStage.Late, "early must be less than late"
187 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
190 class ProgramPoint(metaclass
=InternedMeta
):
191 __slots__
= "op_index", "stage"
193 def __init__(self
, op_index
, stage
):
194 # type: (int, OpStage) -> None
195 self
.op_index
= op_index
201 """ an integer representation of `self` such that it keeps ordering and
202 successor/predecessor relations.
204 return self
.op_index
* 2 + self
.stage
.value
207 def from_int_value(int_value
):
208 # type: (int) -> ProgramPoint
209 op_index
, stage
= divmod(int_value
, 2)
210 return ProgramPoint(op_index
=op_index
, stage
=OpStage(stage
))
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the ProgramPoint `steps` positions after `self`.

    Positions are counted in `int_value` units, so each step moves by one
    stage; negative `steps` moves backwards.
    """
    return ProgramPoint.from_int_value(self.int_value + steps)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the ProgramPoint `steps` positions before `self`.

    Implemented as `self.next(steps=-steps)`.
    """
    return self.next(steps=-steps)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """Order program points by `op_index`, then by `stage` within one op.

    Returns NotImplemented for non-ProgramPoint operands.
    """
    if not isinstance(other, ProgramPoint):
        return NotImplemented
    if self.op_index != other.op_index:
        return self.op_index < other.op_index
    return self.stage < other.stage
230 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
233 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
235 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
236 __slots__
= "start", "stop"
238 def __init__(self
, start
, stop
):
239 # type: (ProgramPoint, ProgramPoint) -> None
244 def int_value_range(self
):
246 return range(self
.start
.int_value
, self
.stop
.int_value
)
249 def from_int_value_range(int_value_range
):
250 # type: (range) -> ProgramRange
251 if int_value_range
.step
!= 1:
252 raise ValueError("int_value_range must have step == 1")
254 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
255 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
258 def __getitem__(self
, __idx
):
259 # type: (int) -> ProgramPoint
263 def __getitem__(self
, __idx
):
264 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """Index or slice this range of program points.

    An int index returns the ProgramPoint at that position (negative indexes
    and IndexError follow `range` semantics). A slice returns a
    ProgramRange; `from_int_value_range` raises ValueError if the resulting
    slice does not have step 1.
    """
    # reuse the int_value-based range rather than rebuilding it here
    v = self.int_value_range[__idx]
    if isinstance(v, int):
        return ProgramPoint.from_int_value(v)
    return ProgramRange.from_int_value_range(v)
276 return len(self
.int_value_range
)
279 # type: () -> Iterator[ProgramPoint]
280 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
284 start
= repr(self
.start
).lstrip("<").rstrip(">")
285 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
286 return f
"<range:{start}..{stop}>"
289 @plain_data(frozen
=True, eq
=False, repr=False)
292 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
293 "def_program_ranges", "use_program_points",
294 "all_program_points")
296 def __init__(self
, fn
):
299 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
300 self
.all_program_points
= ProgramRange(
301 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
302 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
303 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
304 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
305 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
306 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
308 for use
in op
.input_uses
:
309 uses
[use
.ssa_val
].add(use
)
310 use_program_point
= self
.__get
_use
_program
_point
(use
)
311 use_program_points
[use
] = use_program_point
312 live_range_stops
[use
.ssa_val
] = max(
313 live_range_stops
[use
.ssa_val
], use_program_point
.next())
314 for out
in op
.outputs
:
316 def_program_range
= self
.__get
_def
_program
_range
(out
)
317 def_program_ranges
[out
] = def_program_range
318 live_range_stops
[out
] = def_program_range
.stop
319 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
320 self
.def_program_ranges
= FMap(def_program_ranges
)
321 self
.use_program_points
= FMap(use_program_points
)
322 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
323 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
324 for ssa_val
in uses
.keys():
325 live_ranges
[ssa_val
] = live_range
= ProgramRange(
326 start
=self
.def_program_ranges
[ssa_val
].start
,
327 stop
=live_range_stops
[ssa_val
])
328 for program_point
in live_range
:
329 live_at
[program_point
].add(ssa_val
)
330 self
.live_ranges
= FMap(live_ranges
)
331 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """Compute the program range over which `ssa_val` is being defined.

    The range starts at the defining op's `write_stage` (taken from the
    output's defining descriptor). It ends just after that same op's Late
    stage.
    """
    write_stage = ssa_val.defining_descriptor.write_stage
    start = ProgramPoint(
        op_index=self.op_indexes[ssa_val.op], stage=write_stage)
    # always include late stage of ssa_val.op, to ensure outputs always
    # overlap all other outputs.
    # stop is exclusive, so we need the next program point.
    stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
    return ProgramRange(start=start, stop=stop)
344 def __get_use_program_point(self
, ssa_use
):
345 # type: (SSAUse) -> ProgramPoint
346 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
347 "assumed here, ensured by GenericOpProperties.__init__"
349 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """Two analyses are equal when they analyze equal `fn` objects.

    Returns NotImplemented for non-FnAnalysis operands.
    """
    if isinstance(other, FnAnalysis):
        return self.fn == other.fn
    return NotImplemented
363 return "<FnAnalysis>"
371 VL_MAXVL
= enum
.auto()
374 def only_scalar(self
):
376 if self
is BaseTy
.I64
:
378 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
384 def max_reg_len(self
):
386 if self
is BaseTy
.I64
:
388 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
394 return "BaseTy." + self
._name
_
397 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
399 class Ty(metaclass
=InternedMeta
):
400 __slots__
= "base_ty", "reg_len"
403 def validate(base_ty
, reg_len
):
404 # type: (BaseTy, int) -> str | None
405 """ return a string with the error if the combination is invalid,
406 otherwise return None
408 if base_ty
.only_scalar
and reg_len
!= 1:
409 return f
"can't create a vector of an only-scalar type: {base_ty}"
410 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
411 return "reg_len out of range"
414 def __init__(self
, base_ty
, reg_len
):
415 # type: (BaseTy, int) -> None
416 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
418 raise ValueError(msg
)
419 self
.base_ty
= base_ty
420 self
.reg_len
= reg_len
424 if self
.reg_len
!= 1:
425 reg_len
= f
"*{self.reg_len}"
428 return f
"<{self.base_ty._name_}{reg_len}>"
435 StackI64
= enum
.auto()
437 VL_MAXVL
= enum
.auto()
442 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
444 if self
is LocKind
.CA
:
446 if self
is LocKind
.VL_MAXVL
:
447 return BaseTy
.VL_MAXVL
454 if self
is LocKind
.StackI64
:
456 if self
is LocKind
.GPR
or self
is LocKind
.CA \
457 or self
is LocKind
.VL_MAXVL
:
458 return self
.base_ty
.max_reg_len
463 return "LocKind." + self
._name
_
468 class LocSubKind(Enum
):
469 BASE_GPR
= enum
.auto()
470 SV_EXTRA2_VGPR
= enum
.auto()
471 SV_EXTRA2_SGPR
= enum
.auto()
472 SV_EXTRA3_VGPR
= enum
.auto()
473 SV_EXTRA3_SGPR
= enum
.auto()
474 StackI64
= enum
.auto()
476 VL_MAXVL
= enum
.auto()
480 # type: () -> LocKind
481 # pyright fails typechecking when using `in` here:
482 # reported: https://github.com/microsoft/pyright/issues/4102
483 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
484 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
485 LocSubKind
.SV_EXTRA3_SGPR
):
487 if self
is LocSubKind
.StackI64
:
488 return LocKind
.StackI64
489 if self
is LocSubKind
.CA
:
491 if self
is LocSubKind
.VL_MAXVL
:
492 return LocKind
.VL_MAXVL
497 return self
.kind
.base_ty
500 def allocatable_locs(self
, ty
):
501 # type: (Ty) -> LocSet
502 if ty
.base_ty
!= self
.base_ty
:
503 raise ValueError("type mismatch")
504 if self
is LocSubKind
.BASE_GPR
:
506 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
507 starts
= range(0, 128, 2)
508 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
510 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
511 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
513 elif self
is LocSubKind
.StackI64
:
514 starts
= range(LocKind
.StackI64
.loc_count
)
515 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
516 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
519 retval
= [] # type: list[Loc]
521 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
525 for special_loc
in SPECIAL_GPRS
:
526 if loc
.conflicts(special_loc
):
531 return LocSet(retval
)
534 return "LocSubKind." + self
._name
_
537 @plain_data(frozen
=True, unsafe_hash
=True)
539 class GenericTy(metaclass
=InternedMeta
):
540 __slots__
= "base_ty", "is_vec"
542 def __init__(self
, base_ty
, is_vec
):
543 # type: (BaseTy, bool) -> None
544 self
.base_ty
= base_ty
545 if base_ty
.only_scalar
and is_vec
:
546 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
549 def instantiate(self
, maxvl
):
551 # here's where subvl and elwid would be accounted for
553 return Ty(self
.base_ty
, maxvl
)
554 return Ty(self
.base_ty
, 1)
556 def can_instantiate_to(self
, ty
):
558 if self
.base_ty
!= ty
.base_ty
:
562 return ty
.reg_len
== 1
565 @plain_data(frozen
=True, unsafe_hash
=True)
567 class Loc(metaclass
=InternedMeta
):
568 __slots__
= "kind", "start", "reg_len"
571 def validate(kind
, start
, reg_len
):
572 # type: (LocKind, int, int) -> str | None
573 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
576 if reg_len
> kind
.loc_count
:
577 return "invalid reg_len"
578 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
579 return "start not in valid range"
583 def try_make(kind
, start
, reg_len
):
584 # type: (LocKind, int, int) -> Loc | None
585 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
588 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
590 def __init__(self
, kind
, start
, reg_len
):
591 # type: (LocKind, int, int) -> None
592 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
594 raise ValueError(msg
)
596 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Return True if `self` and `other` share at least one register slot.

    That is true when both have the same `kind` and their half-open
    `[start, stop)` intervals overlap.
    """
    return (self.kind == other.kind
            and self.start < other.stop and other.start < self.stop)
605 def make_ty(kind
, reg_len
):
606 # type: (LocKind, int) -> Ty
607 return Ty(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
612 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
617 return self
.start
+ self
.reg_len
619 def try_concat(self
, *others
):
620 # type: (*Loc | None) -> Loc | None
621 reg_len
= self
.reg_len
624 if other
is None or other
.kind
!= self
.kind
:
626 if stop
!= other
.start
:
629 reg_len
+= other
.reg_len
630 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
def get_subloc_at_offset(self, subloc_ty, offset):
    # type: (Ty, int) -> Loc
    """Return the sub-Loc of type `subloc_ty` starting `offset` registers
    into `self`.

    Raises ValueError if `subloc_ty`'s base type differs from this Loc's
    kind. Also raises ValueError if the sub-Loc would not lie entirely
    within `self`.
    """
    if subloc_ty.base_ty != self.kind.base_ty:
        raise ValueError("BaseTy mismatch")
    if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
        raise ValueError("invalid sub-Loc: offset and/or "
                         "subloc_ty.reg_len out of range")
    return Loc(kind=self.kind,
               start=self.start + offset, reg_len=subloc_ty.reg_len)
644 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
645 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
646 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
647 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
652 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
653 def __init__(self
, __locs
=()):
654 # type: (Iterable[Loc]) -> None
655 super().__init
__(__locs
)
656 if isinstance(__locs
, LocSet
):
657 self
.__starts
= __locs
.starts
658 self
.__ty
= __locs
.ty
660 starts
= {i
: BitSet() for i
in LocKind
}
661 ty
= None # type: None | Ty
666 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
667 starts
[loc
.kind
].add(loc
.start
)
668 self
.__starts
= FMap(
669 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
674 # type: () -> FMap[LocKind, FBitSet]
679 # type: () -> Ty | None
684 # type: () -> FMap[LocKind, FBitSet]
689 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
693 # type: () -> AbstractSet[LocKind]
694 return self
.starts
.keys()
698 # type: () -> int | None
701 return self
.ty
.reg_len
705 # type: () -> BaseTy | None
708 return self
.ty
.base_ty
710 def concat(self
, *others
):
711 # type: (*LocSet) -> LocSet
714 base_ty
= self
.ty
.base_ty
715 reg_len
= self
.ty
.reg_len
716 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
720 if other
.ty
.base_ty
!= base_ty
:
722 for kind
, other_starts
in other
.starts
.items():
723 if kind
not in starts
:
725 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
726 if starts
[kind
] == 0:
730 reg_len
+= other
.ty
.reg_len
733 # type: () -> Iterable[Loc]
734 for kind
, v
in starts
.items():
736 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
739 return LocSet(locs())
741 @lru_cache(maxsize
=None, typed
=True)
742 def max_conflicts_with(self
, other
):
743 # type: (LocSet | Loc) -> int
744 """the largest number of Locs in `self` that a single Loc
745 from `other` can conflict with
747 if isinstance(other
, LocSet
):
748 return max(self
.max_conflicts_with(i
) for i
in other
)
750 return sum(other
.conflicts(i
) for i
in self
)
753 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
756 @plain_data(frozen
=True, unsafe_hash
=True)
758 class GenericOperandDesc(metaclass
=InternedMeta
):
759 """generic Op operand descriptor"""
760 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
764 self
, ty
, # type: GenericTy
765 sub_kinds
, # type: Iterable[LocSubKind]
767 fixed_loc
=None, # type: Loc | None
768 tied_input_index
=None, # type: int | None
769 spread
=False, # type: bool
770 write_stage
=OpStage
.Early
, # type: OpStage
772 # type: (...) -> None
774 self
.sub_kinds
= OFSet(sub_kinds
)
775 if len(self
.sub_kinds
) == 0:
776 raise ValueError("sub_kinds can't be empty")
777 self
.fixed_loc
= fixed_loc
778 if fixed_loc
is not None:
779 if tied_input_index
is not None:
780 raise ValueError("operand can't be both tied and fixed")
781 if not ty
.can_instantiate_to(fixed_loc
.ty
):
783 f
"fixed_loc has incompatible type for given generic "
784 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
785 if len(self
.sub_kinds
) != 1:
787 "multiple sub_kinds not allowed for fixed operand")
788 for sub_kind
in self
.sub_kinds
:
789 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
791 f
"fixed_loc not in given sub_kind: "
792 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
793 for sub_kind
in self
.sub_kinds
:
794 if sub_kind
.base_ty
!= ty
.base_ty
:
795 raise ValueError(f
"sub_kind is incompatible with type: "
796 f
"sub_kind={sub_kind} ty={ty}")
797 if tied_input_index
is not None and tied_input_index
< 0:
798 raise ValueError("invalid tied_input_index")
799 self
.tied_input_index
= tied_input_index
802 if self
.tied_input_index
is not None:
803 raise ValueError("operand can't be both spread and tied")
804 if self
.fixed_loc
is not None:
805 raise ValueError("operand can't be both spread and fixed")
807 raise ValueError("operand can't be both spread and vector")
808 self
.write_stage
= write_stage
811 def ty_before_spread(self
):
812 # type: () -> GenericTy
814 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Return a copy of this descriptor tied to input `tied_input_index`.

    The copy keeps `ty`, `sub_kinds`, and `write_stage`. It does not carry
    over `fixed_loc` or `spread`, which `__init__` rejects in combination
    with a tied input.
    """
    return GenericOperandDesc(self.ty, self.sub_kinds,
                              tied_input_index=tied_input_index,
                              write_stage=self.write_stage)
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Return a copy of this descriptor pinned to `fixed_loc`.

    The copy keeps `ty`, `sub_kinds`, and `write_stage`. It does not carry
    over `tied_input_index` or `spread`, which `__init__` rejects in
    combination with a fixed location.
    """
    return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                              write_stage=self.write_stage)
828 def with_write_stage(self
, write_stage
):
829 # type: (OpStage) -> Self
830 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
831 fixed_loc
=self
.fixed_loc
,
832 tied_input_index
=self
.tied_input_index
,
834 write_stage
=write_stage
)
836 def instantiate(self
, maxvl
):
837 # type: (int) -> Iterable[OperandDesc]
838 # assumes all spread operands have ty.reg_len = 1
842 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
844 def locs_before_spread():
845 # type: () -> Iterable[Loc]
846 if self
.fixed_loc
is not None:
847 if ty_before_spread
!= self
.fixed_loc
.ty
:
849 f
"instantiation failed: type mismatch with fixed_loc: "
850 f
"instantiated type: {ty_before_spread} "
851 f
"fixed_loc: {self.fixed_loc}")
854 for sub_kind
in self
.sub_kinds
:
855 yield from sub_kind
.allocatable_locs(ty_before_spread
)
856 loc_set_before_spread
= LocSet(locs_before_spread())
857 for idx
in range(rep_count
):
860 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
861 tied_input_index
=self
.tied_input_index
,
862 spread_index
=idx
, write_stage
=self
.write_stage
)
865 @plain_data(frozen
=True, unsafe_hash
=True)
867 class OperandDesc(metaclass
=InternedMeta
):
868 """Op operand descriptor"""
869 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
872 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
874 # type: (LocSet, int | None, int | None, OpStage) -> None
875 if len(loc_set_before_spread
) == 0:
876 raise ValueError("loc_set_before_spread must not be empty")
877 self
.loc_set_before_spread
= loc_set_before_spread
878 self
.tied_input_index
= tied_input_index
879 if self
.tied_input_index
is not None and spread_index
is not None:
880 raise ValueError("operand can't be both spread and tied")
881 self
.spread_index
= spread_index
882 self
.write_stage
= write_stage
885 def ty_before_spread(self
):
887 ty
= self
.loc_set_before_spread
.ty
888 assert ty
is not None, (
889 "__init__ checked that the LocSet isn't empty, "
890 "non-empty LocSets should always have ty set")
895 """ Ty after any spread is applied """
896 if self
.spread_index
is not None:
897 # assumes all spread operands have ty.reg_len = 1
898 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
899 return self
.ty_before_spread
902 def reg_offset_in_unspread(self
):
903 """ the number of reg-sized slots in the unspread Loc before self's Loc
905 e.g. if the unspread Loc containing self is:
906 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
907 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
908 then reg_offset_into_unspread == 2 == 10 - 8
910 if self
.spread_index
is None:
912 return self
.spread_index
* self
.ty
.reg_len
915 OD_BASE_SGPR
= GenericOperandDesc(
916 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
917 sub_kinds
=[LocSubKind
.BASE_GPR
])
918 OD_EXTRA3_SGPR
= GenericOperandDesc(
919 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
920 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
])
921 OD_EXTRA3_VGPR
= GenericOperandDesc(
922 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
923 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
])
924 OD_EXTRA2_SGPR
= GenericOperandDesc(
925 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
926 sub_kinds
=[LocSubKind
.SV_EXTRA2_SGPR
])
927 OD_EXTRA2_VGPR
= GenericOperandDesc(
928 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
929 sub_kinds
=[LocSubKind
.SV_EXTRA2_VGPR
])
930 OD_CA
= GenericOperandDesc(
931 ty
=GenericTy(base_ty
=BaseTy
.CA
, is_vec
=False),
932 sub_kinds
=[LocSubKind
.CA
])
933 OD_VL
= GenericOperandDesc(
934 ty
=GenericTy(base_ty
=BaseTy
.VL_MAXVL
, is_vec
=False),
935 sub_kinds
=[LocSubKind
.VL_MAXVL
])
938 @plain_data(frozen
=True, unsafe_hash
=True)
940 class GenericOpProperties(metaclass
=InternedMeta
):
941 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
942 "is_copy", "is_load_immediate", "has_side_effects")
945 self
, demo_asm
, # type: str
946 inputs
, # type: Iterable[GenericOperandDesc]
947 outputs
, # type: Iterable[GenericOperandDesc]
948 immediates
=(), # type: Iterable[range]
949 is_copy
=False, # type: bool
950 is_load_immediate
=False, # type: bool
951 has_side_effects
=False, # type: bool
953 # type: (...) -> None
954 self
.demo_asm
= demo_asm
# type: str
955 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
956 for inp
in self
.inputs
:
957 if inp
.tied_input_index
is not None:
959 f
"tied_input_index is not allowed on inputs: {inp}")
960 if inp
.write_stage
is not OpStage
.Early
:
962 f
"write_stage is not allowed on inputs: {inp}")
963 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
964 fixed_locs
= [] # type: list[tuple[Loc, int]]
965 for idx
, out
in enumerate(self
.outputs
):
966 if out
.tied_input_index
is not None:
967 if out
.tied_input_index
>= len(self
.inputs
):
968 raise ValueError(f
"tied_input_index out of range: {out}")
969 tied_inp
= self
.inputs
[out
.tied_input_index
]
970 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
971 .with_write_stage(out
.write_stage
)
972 if expected_out
!= out
:
973 raise ValueError(f
"output can't be tied to non-equivalent "
974 f
"input: {out} tied to {tied_inp}")
975 if out
.fixed_loc
is not None:
976 for other_fixed_loc
, other_idx
in fixed_locs
:
977 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
980 f
"conflicting fixed_locs: outputs[{idx}] and "
981 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
982 f
"with {other_fixed_loc}")
983 fixed_locs
.append((out
.fixed_loc
, idx
))
984 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
985 self
.is_copy
= is_copy
# type: bool
986 self
.is_load_immediate
= is_load_immediate
# type: bool
987 self
.has_side_effects
= has_side_effects
# type: bool
990 @plain_data(frozen
=True, unsafe_hash
=True)
992 class OpProperties(metaclass
=InternedMeta
):
993 __slots__
= "kind", "inputs", "outputs", "maxvl"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """Instantiate the generic operand descriptors of `kind` for `maxvl`.

    Each generic input/output may expand to several OperandDescs (via
    `GenericOperandDesc.instantiate`). The results are flattened, in order,
    into the `inputs`/`outputs` tuples.
    """
    # must be assigned first: `self.generic` is derived from `self.kind`
    self.kind = kind  # type: OpKind
    inputs = []  # type: list[OperandDesc]
    for inp in self.generic.inputs:
        inputs.extend(inp.instantiate(maxvl=maxvl))
    self.inputs = tuple(inputs)  # type: tuple[OperandDesc, ...]
    outputs = []  # type: list[OperandDesc]
    for out in self.generic.outputs:
        outputs.extend(out.instantiate(maxvl=maxvl))
    self.outputs = tuple(outputs)  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
1010 # type: () -> GenericOpProperties
1011 return self
.kind
.properties
1014 def immediates(self
):
1015 # type: () -> tuple[range, ...]
1016 return self
.generic
.immediates
1021 return self
.generic
.demo_asm
1026 return self
.generic
.is_copy
1029 def is_load_immediate(self
):
1031 return self
.generic
.is_load_immediate
1034 def has_side_effects(self
):
1036 return self
.generic
.has_side_effects
# range of values representable as a signed 16-bit immediate
IMM_S16 = range(-1 << 15, 1 << 15)

# a simulation function: mutates the sim state for one Op
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-arg factory returning a _SIM_FN (lets registration be lazy)
_SIM_FN2 = Callable[[], _SIM_FN]
# registry: GenericOpProperties -> factory for its simulation function
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# an assembly-generation function: writes asm for one Op
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-arg factory returning a _GEN_ASM_FN
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry: GenericOpProperties -> factory for its asm-generation function
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1052 def __init__(self
, properties
):
1053 # type: (GenericOpProperties) -> None
1055 self
.__properties
= properties
1058 def properties(self
):
1059 # type: () -> GenericOpProperties
1060 return self
.__properties
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """Build the concrete OpProperties for this op kind at vector length
    `maxvl`."""
    return OpProperties(self, maxvl=maxvl)
1068 return "OpKind." + self
._name
_
1072 # type: () -> _SIM_FN
1073 return _SIM_FNS
[self
.properties
]()
1077 # type: () -> _GEN_ASM_FN
1078 return _GEN_ASMS
[self
.properties
]()
1081 def __clearca_sim(op
, state
):
1082 # type: (Op, BaseSimState) -> None
1083 state
[op
.outputs
[0]] = False,
1086 def __clearca_gen_asm(op
, state
):
1087 # type: (Op, GenAsmState) -> None
1088 state
.writeln("addic 0, 0, 0")
1089 ClearCA
= GenericOpProperties(
1090 demo_asm
="addic 0, 0, 0",
1092 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1094 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1095 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
1098 def __setca_sim(op
, state
):
1099 # type: (Op, BaseSimState) -> None
1100 state
[op
.outputs
[0]] = True,
1103 def __setca_gen_asm(op
, state
):
1104 # type: (Op, GenAsmState) -> None
1105 state
.writeln("subfc 0, 0, 0")
1106 SetCA
= GenericOpProperties(
1107 demo_asm
="subfc 0, 0, 0",
1109 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1111 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1112 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1115 def __svadde_sim(op
, state
):
1116 # type: (Op, BaseSimState) -> None
1117 RA
= state
[op
.input_vals
[0]]
1118 RB
= state
[op
.input_vals
[1]]
1119 carry
, = state
[op
.input_vals
[2]]
1120 VL
, = state
[op
.input_vals
[3]]
1121 RT
= [] # type: list[int]
1123 v
= RA
[i
] + RB
[i
] + carry
1124 RT
.append(v
& GPR_VALUE_MASK
)
1125 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1126 state
[op
.outputs
[0]] = tuple(RT
)
1127 state
[op
.outputs
[1]] = carry
,
1130 def __svadde_gen_asm(op
, state
):
1131 # type: (Op, GenAsmState) -> None
1132 RT
= state
.vgpr(op
.outputs
[0])
1133 RA
= state
.vgpr(op
.input_vals
[0])
1134 RB
= state
.vgpr(op
.input_vals
[1])
1135 state
.writeln(f
"sv.adde {RT}, {RA}, {RB}")
1136 SvAddE
= GenericOpProperties(
1137 demo_asm
="sv.adde *RT, *RA, *RB",
1138 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1139 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1141 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1142 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
1145 def __addze_sim(op
, state
):
1146 # type: (Op, BaseSimState) -> None
1147 RA
, = state
[op
.input_vals
[0]]
1148 carry
, = state
[op
.input_vals
[1]]
1150 RT
= v
& GPR_VALUE_MASK
1151 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1152 state
[op
.outputs
[0]] = RT
,
1153 state
[op
.outputs
[1]] = carry
,
1156 def __addze_gen_asm(op
, state
):
1157 # type: (Op, GenAsmState) -> None
1158 RT
= state
.vgpr(op
.outputs
[0])
1159 RA
= state
.vgpr(op
.input_vals
[0])
1160 state
.writeln(f
"addze {RT}, {RA}")
1161 AddZE
= GenericOpProperties(
1162 demo_asm
="addze RT, RA",
1163 inputs
=[OD_BASE_SGPR
, OD_CA
],
1164 outputs
=[OD_BASE_SGPR
, OD_CA
.tied_to_input(1)],
1166 _SIM_FNS
[AddZE
] = lambda: OpKind
.__addze
_sim
1167 _GEN_ASMS
[AddZE
] = lambda: OpKind
.__addze
_gen
_asm
1170 def __svsubfe_sim(op
, state
):
1171 # type: (Op, BaseSimState) -> None
1172 RA
= state
[op
.input_vals
[0]]
1173 RB
= state
[op
.input_vals
[1]]
1174 carry
, = state
[op
.input_vals
[2]]
1175 VL
, = state
[op
.input_vals
[3]]
1176 RT
= [] # type: list[int]
1178 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
1179 RT
.append(v
& GPR_VALUE_MASK
)
1180 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1181 state
[op
.outputs
[0]] = tuple(RT
)
1182 state
[op
.outputs
[1]] = carry
,
1185 def __svsubfe_gen_asm(op
, state
):
1186 # type: (Op, GenAsmState) -> None
1187 RT
= state
.vgpr(op
.outputs
[0])
1188 RA
= state
.vgpr(op
.input_vals
[0])
1189 RB
= state
.vgpr(op
.input_vals
[1])
1190 state
.writeln(f
"sv.subfe {RT}, {RA}, {RB}")
1191 SvSubFE
= GenericOpProperties(
1192 demo_asm
="sv.subfe *RT, *RA, *RB",
1193 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1194 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1196 _SIM_FNS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_sim
1197 _GEN_ASMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_gen
_asm
1200 def __svmaddedu_sim(op
, state
):
1201 # type: (Op, BaseSimState) -> None
1202 RA
= state
[op
.input_vals
[0]]
1203 RB
, = state
[op
.input_vals
[1]]
1204 carry
, = state
[op
.input_vals
[2]]
1205 VL
, = state
[op
.input_vals
[3]]
1206 RT
= [] # type: list[int]
1208 v
= RA
[i
] * RB
+ carry
1209 RT
.append(v
& GPR_VALUE_MASK
)
1210 carry
= v
>> GPR_SIZE_IN_BITS
1211 state
[op
.outputs
[0]] = tuple(RT
)
1212 state
[op
.outputs
[1]] = carry
,
1215 def __svmaddedu_gen_asm(op
, state
):
1216 # type: (Op, GenAsmState) -> None
1217 RT
= state
.vgpr(op
.outputs
[0])
1218 RA
= state
.vgpr(op
.input_vals
[0])
1219 RB
= state
.sgpr(op
.input_vals
[1])
1220 RC
= state
.sgpr(op
.input_vals
[2])
1221 state
.writeln(f
"sv.maddedu {RT}, {RA}, {RB}, {RC}")
1222 SvMAddEDU
= GenericOpProperties(
1223 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1224 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1225 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1227 _SIM_FNS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_sim
1228 _GEN_ASMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_gen
_asm
1231 def __sradi_sim(op
, state
):
1232 # type: (Op, BaseSimState) -> None
1233 rs
, = state
[op
.input_vals
[0]]
1234 imm
= op
.immediates
[0]
1235 if rs
>= 1 << (GPR_SIZE_IN_BITS
- 1):
1236 rs
-= 1 << GPR_SIZE_IN_BITS
1238 RA
= v
& GPR_VALUE_MASK
1239 CA
= (RA
<< imm
) != rs
1240 state
[op
.outputs
[0]] = RA
,
1241 state
[op
.outputs
[1]] = CA
,
1244 def __sradi_gen_asm(op
, state
):
1245 # type: (Op, GenAsmState) -> None
1246 RA
= state
.sgpr(op
.outputs
[0])
1247 RS
= state
.sgpr(op
.input_vals
[1])
1248 imm
= op
.immediates
[0]
1249 state
.writeln(f
"sradi {RA}, {RS}, {imm}")
1250 SRADI
= GenericOpProperties(
1251 demo_asm
="sradi RA, RS, imm",
1252 inputs
=[OD_BASE_SGPR
],
1253 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
),
1254 OD_CA
.with_write_stage(OpStage
.Late
)],
1255 immediates
=[range(GPR_SIZE_IN_BITS
)],
1257 _SIM_FNS
[SRADI
] = lambda: OpKind
.__sradi
_sim
1258 _GEN_ASMS
[SRADI
] = lambda: OpKind
.__sradi
_gen
_asm
1261 def __setvli_sim(op
, state
):
1262 # type: (Op, BaseSimState) -> None
1263 state
[op
.outputs
[0]] = op
.immediates
[0],
1266 def __setvli_gen_asm(op
, state
):
1267 # type: (Op, GenAsmState) -> None
1268 imm
= op
.immediates
[0]
1269 state
.writeln(f
"setvl 0, 0, {imm}, 0, 1, 1")
1270 SetVLI
= GenericOpProperties(
1271 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1273 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1274 immediates
=[range(1, 65)],
1275 is_load_immediate
=True,
1277 _SIM_FNS
[SetVLI
] = lambda: OpKind
.__setvli
_sim
1278 _GEN_ASMS
[SetVLI
] = lambda: OpKind
.__setvli
_gen
_asm
1281 def __svli_sim(op
, state
):
1282 # type: (Op, BaseSimState) -> None
1283 VL
, = state
[op
.input_vals
[0]]
1284 imm
= op
.immediates
[0] & GPR_VALUE_MASK
1285 state
[op
.outputs
[0]] = (imm
,) * VL
1288 def __svli_gen_asm(op
, state
):
1289 # type: (Op, GenAsmState) -> None
1290 RT
= state
.vgpr(op
.outputs
[0])
1291 imm
= op
.immediates
[0]
1292 state
.writeln(f
"sv.addi {RT}, 0, {imm}")
1293 SvLI
= GenericOpProperties(
1294 demo_asm
="sv.addi *RT, 0, imm",
1296 outputs
=[OD_EXTRA3_VGPR
],
1297 immediates
=[IMM_S16
],
1298 is_load_immediate
=True,
1300 _SIM_FNS
[SvLI
] = lambda: OpKind
.__svli
_sim
1301 _GEN_ASMS
[SvLI
] = lambda: OpKind
.__svli
_gen
_asm
1304 def __li_sim(op
, state
):
1305 # type: (Op, BaseSimState) -> None
1306 imm
= op
.immediates
[0] & GPR_VALUE_MASK
1307 state
[op
.outputs
[0]] = imm
,
1310 def __li_gen_asm(op
, state
):
1311 # type: (Op, GenAsmState) -> None
1312 RT
= state
.sgpr(op
.outputs
[0])
1313 imm
= op
.immediates
[0]
1314 state
.writeln(f
"addi {RT}, 0, {imm}")
1315 LI
= GenericOpProperties(
1316 demo_asm
="addi RT, 0, imm",
1318 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1319 immediates
=[IMM_S16
],
1320 is_load_immediate
=True,
1322 _SIM_FNS
[LI
] = lambda: OpKind
.__li
_sim
1323 _GEN_ASMS
[LI
] = lambda: OpKind
.__li
_gen
_asm
1326 def __veccopytoreg_sim(op
, state
):
1327 # type: (Op, BaseSimState) -> None
1328 state
[op
.outputs
[0]] = state
[op
.input_vals
[0]]
1331 def __copy_to_from_reg_gen_asm(src_loc
, dest_loc
, is_vec
, state
):
1332 # type: (Loc, Loc, bool, GenAsmState) -> None
1333 sv
= "sv." if is_vec
else ""
1335 if src_loc
.conflicts(dest_loc
) and src_loc
.start
< dest_loc
.start
:
1337 if src_loc
== dest_loc
:
1339 if src_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1340 raise ValueError(f
"invalid src_loc.kind: {src_loc.kind}")
1341 if dest_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1342 raise ValueError(f
"invalid dest_loc.kind: {dest_loc.kind}")
1343 if src_loc
.kind
is LocKind
.StackI64
:
1344 if dest_loc
.kind
is LocKind
.StackI64
:
1346 f
"can't copy from stack to stack: {src_loc} {dest_loc}")
1347 elif dest_loc
.kind
is not LocKind
.GPR
:
1348 assert_never(dest_loc
.kind
)
1349 src
= state
.stack(src_loc
)
1350 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1351 state
.writeln(f
"{sv}ld {dest}, {src}")
1352 elif dest_loc
.kind
is LocKind
.StackI64
:
1353 if src_loc
.kind
is not LocKind
.GPR
:
1354 assert_never(src_loc
.kind
)
1355 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1356 dest
= state
.stack(dest_loc
)
1357 state
.writeln(f
"{sv}std {src}, {dest}")
1358 elif src_loc
.kind
is LocKind
.GPR
:
1359 if dest_loc
.kind
is not LocKind
.GPR
:
1360 assert_never(dest_loc
.kind
)
1361 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1362 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1363 state
.writeln(f
"{sv}or{rev} {dest}, {src}, {src}")
1365 assert_never(src_loc
.kind
)
1368 def __veccopytoreg_gen_asm(op
, state
):
1369 # type: (Op, GenAsmState) -> None
1370 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1372 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1373 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1374 is_vec
=True, state
=state
)
1376 VecCopyToReg
= GenericOpProperties(
1377 demo_asm
="sv.mv dest, src",
1378 inputs
=[GenericOperandDesc(
1379 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1380 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1382 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1385 _SIM_FNS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_sim
1386 _GEN_ASMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_gen
_asm
1389 def __veccopyfromreg_sim(op
, state
):
1390 # type: (Op, BaseSimState) -> None
1391 state
[op
.outputs
[0]] = state
[op
.input_vals
[0]]
1394 def __veccopyfromreg_gen_asm(op
, state
):
1395 # type: (Op, GenAsmState) -> None
1396 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1397 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1399 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1400 is_vec
=True, state
=state
)
1401 VecCopyFromReg
= GenericOpProperties(
1402 demo_asm
="sv.mv dest, src",
1403 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1404 outputs
=[GenericOperandDesc(
1405 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1406 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1407 write_stage
=OpStage
.Late
,
1411 _SIM_FNS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_sim
1412 _GEN_ASMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_gen
_asm
1415 def __copytoreg_sim(op
, state
):
1416 # type: (Op, BaseSimState) -> None
1417 state
[op
.outputs
[0]] = state
[op
.input_vals
[0]]
1420 def __copytoreg_gen_asm(op
, state
):
1421 # type: (Op, GenAsmState) -> None
1422 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1424 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1425 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1426 is_vec
=False, state
=state
)
1427 CopyToReg
= GenericOpProperties(
1428 demo_asm
="mv dest, src",
1429 inputs
=[GenericOperandDesc(
1430 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1431 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1432 LocSubKind
.StackI64
],
1434 outputs
=[GenericOperandDesc(
1435 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1436 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1437 write_stage
=OpStage
.Late
,
1441 _SIM_FNS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_sim
1442 _GEN_ASMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_gen
_asm
1445 def __copyfromreg_sim(op
, state
):
1446 # type: (Op, BaseSimState) -> None
1447 state
[op
.outputs
[0]] = state
[op
.input_vals
[0]]
1450 def __copyfromreg_gen_asm(op
, state
):
1451 # type: (Op, GenAsmState) -> None
1452 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1453 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1455 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1456 is_vec
=False, state
=state
)
1457 CopyFromReg
= GenericOpProperties(
1458 demo_asm
="mv dest, src",
1459 inputs
=[GenericOperandDesc(
1460 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1461 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1463 outputs
=[GenericOperandDesc(
1464 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1465 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1466 LocSubKind
.StackI64
],
1467 write_stage
=OpStage
.Late
,
1471 _SIM_FNS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_sim
1472 _GEN_ASMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_gen
_asm
1475 def __concat_sim(op
, state
):
1476 # type: (Op, BaseSimState) -> None
1477 state
[op
.outputs
[0]] = tuple(
1478 state
[i
][0] for i
in op
.input_vals
[:-1])
1481 def __concat_gen_asm(op
, state
):
1482 # type: (Op, GenAsmState) -> None
1483 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1484 src_loc
=state
.loc(op
.input_vals
[0:-1], LocKind
.GPR
),
1485 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1486 is_vec
=True, state
=state
)
1487 Concat
= GenericOpProperties(
1488 demo_asm
="sv.mv dest, src",
1489 inputs
=[GenericOperandDesc(
1490 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1491 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1494 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1497 _SIM_FNS
[Concat
] = lambda: OpKind
.__concat
_sim
1498 _GEN_ASMS
[Concat
] = lambda: OpKind
.__concat
_gen
_asm
1501 def __spread_sim(op
, state
):
1502 # type: (Op, BaseSimState) -> None
1503 for idx
, inp
in enumerate(state
[op
.input_vals
[0]]):
1504 state
[op
.outputs
[idx
]] = inp
,
1507 def __spread_gen_asm(op
, state
):
1508 # type: (Op, GenAsmState) -> None
1509 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1510 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1511 dest_loc
=state
.loc(op
.outputs
, LocKind
.GPR
),
1512 is_vec
=True, state
=state
)
1513 Spread
= GenericOpProperties(
1514 demo_asm
="sv.mv dest, src",
1515 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1516 outputs
=[GenericOperandDesc(
1517 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1518 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1520 write_stage
=OpStage
.Late
,
1524 _SIM_FNS
[Spread
] = lambda: OpKind
.__spread
_sim
1525 _GEN_ASMS
[Spread
] = lambda: OpKind
.__spread
_gen
_asm
1528 def __svld_sim(op
, state
):
1529 # type: (Op, BaseSimState) -> None
1530 RA
, = state
[op
.input_vals
[0]]
1531 VL
, = state
[op
.input_vals
[1]]
1532 addr
= RA
+ op
.immediates
[0]
1533 RT
= [] # type: list[int]
1535 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
1536 RT
.append(v
& GPR_VALUE_MASK
)
1537 state
[op
.outputs
[0]] = tuple(RT
)
1540 def __svld_gen_asm(op
, state
):
1541 # type: (Op, GenAsmState) -> None
1542 RA
= state
.sgpr(op
.input_vals
[0])
1543 RT
= state
.vgpr(op
.outputs
[0])
1544 imm
= op
.immediates
[0]
1545 state
.writeln(f
"sv.ld {RT}, {imm}({RA})")
1546 SvLd
= GenericOpProperties(
1547 demo_asm
="sv.ld *RT, imm(RA)",
1548 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1549 outputs
=[OD_EXTRA3_VGPR
],
1550 immediates
=[IMM_S16
],
1552 _SIM_FNS
[SvLd
] = lambda: OpKind
.__svld
_sim
1553 _GEN_ASMS
[SvLd
] = lambda: OpKind
.__svld
_gen
_asm
1556 def __ld_sim(op
, state
):
1557 # type: (Op, BaseSimState) -> None
1558 RA
, = state
[op
.input_vals
[0]]
1559 addr
= RA
+ op
.immediates
[0]
1560 v
= state
.load(addr
)
1561 state
[op
.outputs
[0]] = v
& GPR_VALUE_MASK
,
1564 def __ld_gen_asm(op
, state
):
1565 # type: (Op, GenAsmState) -> None
1566 RA
= state
.sgpr(op
.input_vals
[0])
1567 RT
= state
.sgpr(op
.outputs
[0])
1568 imm
= op
.immediates
[0]
1569 state
.writeln(f
"ld {RT}, {imm}({RA})")
1570 Ld
= GenericOpProperties(
1571 demo_asm
="ld RT, imm(RA)",
1572 inputs
=[OD_BASE_SGPR
],
1573 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1574 immediates
=[IMM_S16
],
1576 _SIM_FNS
[Ld
] = lambda: OpKind
.__ld
_sim
1577 _GEN_ASMS
[Ld
] = lambda: OpKind
.__ld
_gen
_asm
1580 def __svstd_sim(op
, state
):
1581 # type: (Op, BaseSimState) -> None
1582 RS
= state
[op
.input_vals
[0]]
1583 RA
, = state
[op
.input_vals
[1]]
1584 VL
, = state
[op
.input_vals
[2]]
1585 addr
= RA
+ op
.immediates
[0]
1587 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
1590 def __svstd_gen_asm(op
, state
):
1591 # type: (Op, GenAsmState) -> None
1592 RS
= state
.vgpr(op
.input_vals
[0])
1593 RA
= state
.sgpr(op
.input_vals
[1])
1594 imm
= op
.immediates
[0]
1595 state
.writeln(f
"sv.std {RS}, {imm}({RA})")
1596 SvStd
= GenericOpProperties(
1597 demo_asm
="sv.std *RS, imm(RA)",
1598 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1600 immediates
=[IMM_S16
],
1601 has_side_effects
=True,
1603 _SIM_FNS
[SvStd
] = lambda: OpKind
.__svstd
_sim
1604 _GEN_ASMS
[SvStd
] = lambda: OpKind
.__svstd
_gen
_asm
1607 def __std_sim(op
, state
):
1608 # type: (Op, BaseSimState) -> None
1609 RS
, = state
[op
.input_vals
[0]]
1610 RA
, = state
[op
.input_vals
[1]]
1611 addr
= RA
+ op
.immediates
[0]
1612 state
.store(addr
, value
=RS
)
1615 def __std_gen_asm(op
, state
):
1616 # type: (Op, GenAsmState) -> None
1617 RS
= state
.sgpr(op
.input_vals
[0])
1618 RA
= state
.sgpr(op
.input_vals
[1])
1619 imm
= op
.immediates
[0]
1620 state
.writeln(f
"std {RS}, {imm}({RA})")
1621 Std
= GenericOpProperties(
1622 demo_asm
="std RS, imm(RA)",
1623 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1625 immediates
=[IMM_S16
],
1626 has_side_effects
=True,
1628 _SIM_FNS
[Std
] = lambda: OpKind
.__std
_sim
1629 _GEN_ASMS
[Std
] = lambda: OpKind
.__std
_gen
_asm
1632 def __funcargr3_sim(op
, state
):
1633 # type: (Op, BaseSimState) -> None
1634 pass # return value set before simulation
1637 def __funcargr3_gen_asm(op
, state
):
1638 # type: (Op, GenAsmState) -> None
1639 pass # no instructions needed
1640 FuncArgR3
= GenericOpProperties(
1643 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1644 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1646 _SIM_FNS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_sim
1647 _GEN_ASMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_gen
_asm
1650 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1651 class SSAValOrUse(metaclass
=InternedMeta
):
1652 __slots__
= "op", "operand_idx"
1654 def __init__(self
, op
, operand_idx
):
1655 # type: (Op, int) -> None
1658 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1659 raise ValueError("invalid operand_idx")
1660 self
.operand_idx
= operand_idx
1669 def descriptor_array(self
):
1670 # type: () -> tuple[OperandDesc, ...]
1674 def defining_descriptor(self
):
1675 # type: () -> OperandDesc
1676 return self
.descriptor_array
[self
.operand_idx
]
1681 return self
.defining_descriptor
.ty
1684 def ty_before_spread(self
):
1686 return self
.defining_descriptor
.ty_before_spread
1690 # type: () -> BaseTy
1691 return self
.ty_before_spread
.base_ty
1694 def reg_offset_in_unspread(self
):
1695 """ the number of reg-sized slots in the unspread Loc before self's Loc
1697 e.g. if the unspread Loc containing self is:
1698 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1699 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1700 then reg_offset_into_unspread == 2 == 10 - 8
1702 return self
.defining_descriptor
.reg_offset_in_unspread
1705 def unspread_start_idx(self
):
1707 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1710 def unspread_start(self
):
1712 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1715 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1717 class SSAVal(SSAValOrUse
):
1722 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1725 def def_loc_set_before_spread(self
):
1726 # type: () -> LocSet
1727 return self
.defining_descriptor
.loc_set_before_spread
1730 def descriptor_array(self
):
1731 # type: () -> tuple[OperandDesc, ...]
1732 return self
.op
.properties
.outputs
1735 def tied_input(self
):
1736 # type: () -> None | SSAUse
1737 if self
.defining_descriptor
.tied_input_index
is None:
1739 return SSAUse(op
=self
.op
,
1740 operand_idx
=self
.defining_descriptor
.tied_input_index
)
1743 def write_stage(self
):
1744 # type: () -> OpStage
1745 return self
.defining_descriptor
.write_stage
1748 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1750 class SSAUse(SSAValOrUse
):
1754 def use_loc_set_before_spread(self
):
1755 # type: () -> LocSet
1756 return self
.defining_descriptor
.loc_set_before_spread
1759 def descriptor_array(self
):
1760 # type: () -> tuple[OperandDesc, ...]
1761 return self
.op
.properties
.inputs
1765 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1769 # type: () -> SSAVal
1770 return self
.op
.input_vals
[self
.operand_idx
]
1773 def ssa_val(self
, ssa_val
):
1774 # type: (SSAVal) -> None
1775 self
.op
.input_vals
[self
.operand_idx
] = ssa_val
1779 _Desc
= TypeVar("_Desc")
1782 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1784 def _verify_write_with_desc(self
, idx
, item
, desc
):
1785 # type: (int, _T | Any, _Desc) -> None
1786 raise NotImplementedError
1789 def _verify_write(self
, idx
, item
):
1790 # type: (int | Any, _T | Any) -> int
1791 if not isinstance(idx
, int):
1792 if isinstance(idx
, slice):
1794 f
"can't write to slice of {self.__class__.__name__}")
1795 raise TypeError(f
"can't write with index {idx!r}")
1796 # normalize idx, raising IndexError if it is out of range
1797 idx
= range(len(self
.descriptors
))[idx
]
1798 desc
= self
.descriptors
[idx
]
1799 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1802 def _on_set(self
, idx
, new_item
, old_item
):
1803 # type: (int, _T, _T | None) -> None
1807 def _get_descriptors(self
):
1808 # type: () -> tuple[_Desc, ...]
1809 raise NotImplementedError
1813 def descriptors(self
):
1814 # type: () -> tuple[_Desc, ...]
1815 return self
._get
_descriptors
()
1822 def __init__(self
, items
, op
):
1823 # type: (Iterable[_T], Op) -> None
1826 self
.__items
= [] # type: list[_T]
1827 for idx
, item
in enumerate(items
):
1828 if idx
>= len(self
.descriptors
):
1829 raise ValueError("too many items")
1830 _
= self
._verify
_write
(idx
, item
)
1831 self
.__items
.append(item
)
1832 if len(self
.__items
) < len(self
.descriptors
):
1833 raise ValueError("not enough items")
1837 # type: () -> Iterator[_T]
1838 yield from self
.__items
1841 def __getitem__(self
, idx
):
1846 def __getitem__(self
, idx
):
1847 # type: (slice) -> list[_T]
1851 def __getitem__(self
, idx
):
1852 # type: (int | slice) -> _T | list[_T]
1853 return self
.__items
[idx
]
1856 def __setitem__(self
, idx
, item
):
1857 # type: (int, _T) -> None
1858 idx
= self
._verify
_write
(idx
, item
)
1859 self
.__items
[idx
] = item
1864 return len(self
.__items
)
1868 return f
"{self.__class__.__name__}({self.__items}, op=...)"
1872 class OpInputVals(OpInputSeq
[SSAVal
, OperandDesc
]):
1873 def _get_descriptors(self
):
1874 # type: () -> tuple[OperandDesc, ...]
1875 return self
.op
.properties
.inputs
1877 def _verify_write_with_desc(self
, idx
, item
, desc
):
1878 # type: (int, SSAVal | Any, OperandDesc) -> None
1879 if not isinstance(item
, SSAVal
):
1880 raise TypeError("expected value of type SSAVal")
1881 if item
.ty
!= desc
.ty
:
1882 raise ValueError(f
"assigned item's type {item.ty!r} doesn't match "
1883 f
"corresponding input's type {desc.ty!r}")
1885 def _on_set(self
, idx
, new_item
, old_item
):
1886 # type: (int, SSAVal, SSAVal | None) -> None
1887 SSAUses
._on
_op
_input
_set
(self
, idx
, new_item
, old_item
) # type: ignore
1889 def __init__(self
, items
, op
):
1890 # type: (Iterable[SSAVal], Op) -> None
1891 if hasattr(op
, "inputs"):
1892 raise ValueError("Op.inputs already set")
1893 super().__init
__(items
, op
)
1897 class OpImmediates(OpInputSeq
[int, range]):
1898 def _get_descriptors(self
):
1899 # type: () -> tuple[range, ...]
1900 return self
.op
.properties
.immediates
1902 def _verify_write_with_desc(self
, idx
, item
, desc
):
1903 # type: (int, int | Any, range) -> None
1904 if not isinstance(item
, int):
1905 raise TypeError("expected value of type int")
1906 if item
not in desc
:
1907 raise ValueError(f
"immediate value {item!r} not in {desc!r}")
1909 def __init__(self
, items
, op
):
1910 # type: (Iterable[int], Op) -> None
1911 if hasattr(op
, "immediates"):
1912 raise ValueError("Op.immediates already set")
1913 super().__init
__(items
, op
)
1916 @plain_data(frozen
=True, eq
=False, repr=False)
1919 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
1922 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
1923 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1925 self
.properties
= properties
1926 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
1927 inputs_len
= len(self
.properties
.inputs
)
1928 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
1929 self
.immediates
= OpImmediates(immediates
, op
=self
)
1930 outputs_len
= len(self
.properties
.outputs
)
1931 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1932 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
1936 # type: () -> OpKind
1937 return self
.properties
.kind
1939 def __eq__(self
, other
):
1940 # type: (Op | Any) -> bool
1941 if isinstance(other
, Op
):
1942 return self
is other
1943 return NotImplemented
1947 return object.__hash
__(self
)
1951 field_vals
= [] # type: list[str]
1952 for name
in fields(self
):
1953 if name
== "properties":
1958 value
= getattr(self
, name
)
1959 except AttributeError:
1960 field_vals
.append(f
"{name}=<not set>")
1962 if isinstance(value
, OpInputSeq
):
1963 value
= list(value
) # type: ignore
1964 field_vals
.append(f
"{name}={value!r}")
1965 field_vals_str
= ", ".join(field_vals
)
1966 return f
"Op({field_vals_str})"
1968 def sim(self
, state
):
1969 # type: (BaseSimState) -> None
1970 for inp
in self
.input_vals
:
1974 raise ValueError(f
"SSAVal {inp} not yet assigned when "
1976 if len(val
) != inp
.ty
.reg_len
:
1978 f
"value of SSAVal {inp} has wrong number of elements: "
1979 f
"expected {inp.ty.reg_len} found "
1980 f
"{len(val)}: {val!r}")
1981 if isinstance(state
, PreRASimState
):
1982 for out
in self
.outputs
:
1983 if out
in state
.ssa_vals
:
1984 if self
.kind
is OpKind
.FuncArgR3
:
1986 raise ValueError(f
"SSAVal {out} already assigned before "
1988 self
.kind
.sim(self
, state
)
1989 for out
in self
.outputs
:
1993 raise ValueError(f
"running {self} failed to assign to {out}")
1994 if len(val
) != out
.ty
.reg_len
:
1996 f
"value of SSAVal {out} has wrong number of elements: "
1997 f
"expected {out.ty.reg_len} found "
1998 f
"{len(val)}: {val!r}")
2000 def gen_asm(self
, state
):
2001 # type: (GenAsmState) -> None
2002 all_loc_kinds
= tuple(LocKind
)
2003 for inp
in self
.input_vals
:
2004 state
.loc(inp
, expected_kinds
=all_loc_kinds
)
2005 for out
in self
.outputs
:
2006 state
.loc(out
, expected_kinds
=all_loc_kinds
)
2007 self
.kind
.gen_asm(self
, state
)
2010 @plain_data(frozen
=True, repr=False)
2011 class BaseSimState(metaclass
=ABCMeta
):
2012 __slots__
= "memory",
2014 def __init__(self
, memory
):
2015 # type: (dict[int, int]) -> None
2017 self
.memory
= memory
# type: dict[int, int]
2019 def load_byte(self
, addr
):
2020 # type: (int) -> int
2021 addr
&= GPR_VALUE_MASK
2022 return self
.memory
.get(addr
, 0) & 0xFF
2024 def store_byte(self
, addr
, value
):
2025 # type: (int, int) -> None
2026 addr
&= GPR_VALUE_MASK
2028 self
.memory
[addr
] = value
2030 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
2031 # type: (int, int, bool) -> int
2032 if addr
% size_in_bytes
!= 0:
2033 raise ValueError(f
"address not aligned: {hex(addr)} "
2034 f
"required alignment: {size_in_bytes}")
2036 for i
in range(size_in_bytes
):
2037 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
2038 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
2039 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
2042 def store(self
, addr
, value
, size_in_bytes
=GPR_SIZE_IN_BYTES
):
2043 # type: (int, int, int) -> None
2044 if addr
% size_in_bytes
!= 0:
2045 raise ValueError(f
"address not aligned: {hex(addr)} "
2046 f
"required alignment: {size_in_bytes}")
2047 for i
in range(size_in_bytes
):
2048 self
.store_byte(addr
+ i
, (value
>> i
* BITS_IN_BYTE
) & 0xFF)
2050 def _memory__repr(self
):
2052 if len(self
.memory
) == 0:
2054 keys
= sorted(self
.memory
.keys(), reverse
=True)
2055 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
2056 items
= [] # type: list[str]
2057 while len(keys
) != 0:
2059 if (len(keys
) >= CHUNK_SIZE
2060 and addr
% CHUNK_SIZE
== 0
2061 and keys
[-CHUNK_SIZE
:]
2062 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
2063 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
2064 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
2065 keys
[-CHUNK_SIZE
:] = ()
2067 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
2069 return f
"{{{items[0]}}}"
2070 items_str
= ",\n".join(items
)
2071 return f
"{{\n{items_str}}}"
2075 field_vals
= [] # type: list[str]
2076 for name
in fields(self
):
2078 value
= getattr(self
, name
)
2079 except AttributeError:
2080 field_vals
.append(f
"{name}=<not set>")
2082 repr_fn
= getattr(self
, f
"_{name}__repr", None)
2083 if callable(repr_fn
):
2084 field_vals
.append(f
"{name}={repr_fn()}")
2086 field_vals
.append(f
"{name}={value!r}")
2087 field_vals_str
= ", ".join(field_vals
)
2088 return f
"{self.__class__.__name__}({field_vals_str})"
2091 def __getitem__(self
, ssa_val
):
2092 # type: (SSAVal) -> tuple[int, ...]
2096 def __setitem__(self
, ssa_val
, value
):
2097 # type: (SSAVal, tuple[int, ...]) -> None
2101 @plain_data(frozen
=True, repr=False)
2103 class PreRASimState(BaseSimState
):
2104 __slots__
= "ssa_vals",
2106 def __init__(self
, ssa_vals
, memory
):
2107 # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
2108 super().__init
__(memory
)
2109 self
.ssa_vals
= ssa_vals
# type: dict[SSAVal, tuple[int, ...]]
2111 def _ssa_vals__repr(self
):
2113 if len(self
.ssa_vals
) == 0:
2115 items
= [] # type: list[str]
2117 for k
, v
in self
.ssa_vals
.items():
2118 element_strs
= [] # type: list[str]
2119 for i
, el
in enumerate(v
):
2120 if i
% CHUNK_SIZE
!= 0:
2121 element_strs
.append(" " + hex(el
))
2123 element_strs
.append("\n " + hex(el
))
2124 if len(element_strs
) <= CHUNK_SIZE
:
2125 element_strs
[0] = element_strs
[0].lstrip()
2126 if len(element_strs
) == 1:
2127 element_strs
.append("")
2128 v_str
= ",".join(element_strs
)
2129 items
.append(f
"{k!r}: ({v_str})")
2130 if len(items
) == 1 and "\n" not in items
[0]:
2131 return f
"{{{items[0]}}}"
2132 items_str
= ",\n".join(items
)
2133 return f
"{{\n{items_str},\n}}"
2135 def __getitem__(self
, ssa_val
):
2136 # type: (SSAVal) -> tuple[int, ...]
2137 return self
.ssa_vals
[ssa_val
]
2139 def __setitem__(self
, ssa_val
, value
):
2140 # type: (SSAVal, tuple[int, ...]) -> None
2141 if len(value
) != ssa_val
.ty
.reg_len
:
2142 raise ValueError("value has wrong len")
2143 self
.ssa_vals
[ssa_val
] = value
2146 @plain_data(frozen
=True, repr=False)
2148 class PostRASimState(BaseSimState
):
2149 __slots__
= "ssa_val_to_loc_map", "loc_values"
2151 def __init__(self
, ssa_val_to_loc_map
, memory
, loc_values
):
2152 # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
2153 super().__init
__(memory
)
2154 self
.ssa_val_to_loc_map
= FMap(ssa_val_to_loc_map
)
2155 for ssa_val
, loc
in self
.ssa_val_to_loc_map
.items():
2156 if ssa_val
.ty
!= loc
.ty
:
2158 f
"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
2159 self
.loc_values
= loc_values
2160 for loc
in self
.loc_values
.keys():
2161 if loc
.reg_len
!= 1:
2163 "loc_values must only contain Locs with reg_len=1, all "
2164 "larger Locs will be split into reg_len=1 sub-Locs")
2166 def _loc_values__repr(self
):
2168 locs
= sorted(self
.loc_values
.keys(), key
=lambda v
: (v
.kind
, v
.start
))
2169 items
= [] # type: list[str]
2171 items
.append(f
"{loc}: 0x{self.loc_values[loc]:x}")
2172 items_str
= ",\n".join(items
)
2173 return f
"{{\n{items_str},\n}}"
2175 def __getitem__(self
, ssa_val
):
2176 # type: (SSAVal) -> tuple[int, ...]
2177 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2178 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2179 retval
= [] # type: list[int]
2180 for i
in range(loc
.reg_len
):
2181 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2182 retval
.append(self
.loc_values
.get(subloc
, 0))
2183 return tuple(retval
)
2185 def __setitem__(self
, ssa_val
, value
):
2186 # type: (SSAVal, tuple[int, ...]) -> None
2187 if len(value
) != ssa_val
.ty
.reg_len
:
2188 raise ValueError("value has wrong len")
2189 loc
= self
.ssa_val_to_loc_map
[ssa_val
]
2190 subloc_ty
= Ty(base_ty
=loc
.ty
.base_ty
, reg_len
=1)
2191 for i
in range(loc
.reg_len
):
2192 subloc
= loc
.get_subloc_at_offset(subloc_ty
=subloc_ty
, offset
=i
)
2193 self
.loc_values
[subloc
] = value
[i
]
2196 @plain_data(frozen
=True)
2198 __slots__
= "allocated_locs", "output"
2200 def __init__(self
, allocated_locs
, output
=None):
2201 # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
2203 self
.allocated_locs
= FMap(allocated_locs
)
2204 for ssa_val
, loc
in self
.allocated_locs
.items():
2205 if ssa_val
.ty
!= loc
.ty
:
2207 f
"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
2210 self
.output
= output
2212 __SSA_VAL_OR_LOCS
= Union
[SSAVal
, Loc
, Sequence
["SSAVal | Loc"]]
2214 def loc(self
, ssa_val_or_locs
, expected_kinds
):
2215 # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
2216 if isinstance(ssa_val_or_locs
, (SSAVal
, Loc
)):
2217 ssa_val_or_locs
= [ssa_val_or_locs
]
2218 locs
= [] # type: list[Loc]
2219 for i
in ssa_val_or_locs
:
2220 if isinstance(i
, SSAVal
):
2221 locs
.append(self
.allocated_locs
[i
])
2225 raise ValueError("invalid Loc sequence: must not be empty")
2226 retval
= locs
[0].try_concat(*locs
[1:])
2228 raise ValueError("invalid Loc sequence: try_concat failed")
2229 if isinstance(expected_kinds
, LocKind
):
2230 expected_kinds
= expected_kinds
,
2231 if retval
.kind
not in expected_kinds
:
2232 if len(expected_kinds
) == 1:
2233 expected_kinds
= expected_kinds
[0]
2234 raise ValueError(f
"LocKind mismatch: {ssa_val_or_locs}: found "
2235 f
"{retval.kind} expected {expected_kinds}")
2238 def gpr(self
, ssa_val_or_locs
, is_vec
):
2239 # type: (__SSA_VAL_OR_LOCS, bool) -> str
2240 loc
= self
.loc(ssa_val_or_locs
, LocKind
.GPR
)
2241 vec_str
= "*" if is_vec
else ""
2242 return vec_str
+ str(loc
.start
)
2244 def sgpr(self
, ssa_val_or_locs
):
2245 # type: (__SSA_VAL_OR_LOCS) -> str
2246 return self
.gpr(ssa_val_or_locs
, is_vec
=False)
2248 def vgpr(self
, ssa_val_or_locs
):
2249 # type: (__SSA_VAL_OR_LOCS) -> str
2250 return self
.gpr(ssa_val_or_locs
, is_vec
=True)
2252 def stack(self
, ssa_val_or_locs
):
2253 # type: (__SSA_VAL_OR_LOCS) -> str
2254 loc
= self
.loc(ssa_val_or_locs
, LocKind
.StackI64
)
2255 return f
"{loc.start}(1)"
2257 def writeln(self
, *line_segments
):
2258 # type: (*str) -> None
2259 line
= " ".join(line_segments
)
2260 if isinstance(self
.output
, list):
2261 self
.output
.append(line
)
2263 self
.output
.write(line
+ "\n")