1 from contextlib
import contextmanager
3 from abc
import ABCMeta
, abstractmethod
4 from enum
import Enum
, unique
5 from functools
import lru_cache
, total_ordering
6 from io
import StringIO
7 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
8 Mapping
, Sequence
, TypeVar
, Union
, overload
)
9 from weakref
import WeakValueDictionary
as _WeakVDict
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import fields
, plain_data
14 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
16 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
21 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
22 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
28 self
.ops
= [] # type: list[Op]
29 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
30 self
.__next
_name
_suffix
= 2
32 def _add_op_with_unused_name(self
, op
, name
=""):
33 # type: (Op, str) -> str
35 raise ValueError("can't add Op to wrong Fn")
36 if hasattr(op
, "name"):
37 raise ValueError("Op already named")
40 if name
!= "" and name
not in self
.__op
_names
:
41 self
.__op
_names
[name
] = op
43 name
= orig_name
+ str(self
.__next
_name
_suffix
)
44 self
.__next
_name
_suffix
+= 1
50 def ops_to_str(self
, as_python_literal
=False, wrap_width
=63,
51 python_indent
=" ", indent
=" "):
52 # type: (bool, int, str, str) -> str
53 l
= [] # type: list[str]
55 l
.append(op
.__repr
__(wrap_width
=wrap_width
, indent
=indent
))
58 l
= [python_indent
+ "\""]
61 l
.append(f
"\\n\"\n{python_indent}\"")
64 elif ch
.isascii() and ch
.isprintable():
67 l
.append(repr(ch
).strip("\"'"))
70 empty_end
= f
"\"\n{python_indent}\"\""
71 if retval
.endswith(empty_end
):
72 retval
= retval
[:-len(empty_end
)]
75 def append_op(self
, op
):
78 raise ValueError("can't add Op to wrong Fn")
81 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
83 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
84 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
85 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
86 self
.append_op(retval
)
90 # type: (BaseSimState) -> None
94 def gen_asm(self
, state
):
95 # type: (GenAsmState) -> None
99 def pre_ra_insert_copies(self
):
101 orig_ops
= list(self
.ops
)
102 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
103 setvli_outputs
= {} # type: dict[SSAVal, Op]
106 for i
in range(len(op
.input_vals
)):
107 inp
= copied_outputs
[op
.input_vals
[i
]]
108 if inp
.ty
.base_ty
is BaseTy
.I64
:
109 maxvl
= inp
.ty
.reg_len
110 if inp
.ty
.reg_len
!= 1:
111 setvl
= self
.append_new_op(
112 OpKind
.SetVLI
, immediates
=[maxvl
],
113 name
=f
"{op.name}.inp{i}.setvl")
114 vl
= setvl
.outputs
[0]
115 mv
= self
.append_new_op(
116 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
117 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
119 mv
= self
.append_new_op(
120 OpKind
.CopyToReg
, input_vals
=[inp
],
121 name
=f
"{op.name}.inp{i}.copy")
122 op
.input_vals
[i
] = mv
.outputs
[0]
123 elif inp
.ty
.base_ty
is BaseTy
.CA \
124 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
125 # all copies would be no-ops, so we don't need to copy,
126 # though we do need to rematerialize SetVLI ops right
128 if inp
in setvli_outputs
:
129 setvl
= self
.append_new_op(
131 immediates
=setvli_outputs
[inp
].immediates
,
132 name
=f
"{op.name}.inp{i}.setvl")
133 inp
= setvl
.outputs
[0]
134 op
.input_vals
[i
] = inp
136 assert_never(inp
.ty
.base_ty
)
138 for i
, out
in enumerate(op
.outputs
):
139 if op
.kind
is OpKind
.SetVLI
:
140 setvli_outputs
[out
] = op
141 if out
.ty
.base_ty
is BaseTy
.I64
:
142 maxvl
= out
.ty
.reg_len
143 if out
.ty
.reg_len
!= 1:
144 setvl
= self
.append_new_op(
145 OpKind
.SetVLI
, immediates
=[maxvl
],
146 name
=f
"{op.name}.out{i}.setvl")
147 vl
= setvl
.outputs
[0]
148 mv
= self
.append_new_op(
149 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
150 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
152 mv
= self
.append_new_op(
153 OpKind
.CopyFromReg
, input_vals
=[out
],
154 name
=f
"{op.name}.out{i}.copy")
155 copied_outputs
[out
] = mv
.outputs
[0]
156 elif out
.ty
.base_ty
is BaseTy
.CA \
157 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
158 # all copies would be no-ops, so we don't need to copy
159 copied_outputs
[out
] = out
161 assert_never(out
.ty
.base_ty
)
168 value
: Literal
[0, 1] # type: ignore
170 def __new__(cls
, value
):
171 # type: (int) -> OpStage
173 if value
not in (0, 1):
174 raise ValueError("invalid value")
175 retval
= object.__new
__(cls
)
176 retval
._value
_ = value
180 """ early stage of Op execution, where all input reads occur.
181 all output writes with `write_stage == Early` occur here too, and therefore
182 conflict with input reads, telling the compiler that it can't share
183 that output's register with any inputs that the output isn't tied to.
185 All outputs, even unused outputs, can't share registers with any other
186 outputs, independent of `write_stage` settings.
189 """ late stage of Op execution, where all output writes with
190 `write_stage == Late` occur, and therefore don't conflict with input reads,
191 telling the compiler that any inputs can safely use the same register as
194 All outputs, even unused outputs, can't share registers with any other
195 outputs, independent of `write_stage` settings.
200 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """Order stages by their numeric `value`.

    Returns `NotImplemented` for any `other` that is not an `OpStage`, so
    Python can fall back to the reflected comparison.
    """
    if not isinstance(other, OpStage):
        return NotImplemented
    return self.value < other.value
209 assert OpStage
.Early
< OpStage
.Late
, "early must be less than late"
212 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
215 class ProgramPoint(metaclass
=InternedMeta
):
216 __slots__
= "op_index", "stage"
218 def __init__(self
, op_index
, stage
):
219 # type: (int, OpStage) -> None
220 self
.op_index
= op_index
226 """ an integer representation of `self` such that it keeps ordering and
227 successor/predecessor relations.
229 return self
.op_index
* 2 + self
.stage
.value
def from_int_value(int_value):
    # type: (int) -> ProgramPoint
    """Build the `ProgramPoint` encoded by `int_value`.

    The quotient of `int_value` by 2 becomes the op index and the remainder
    is passed to `OpStage` to select the stage.
    """
    op_index = int_value // 2
    stage_value = int_value % 2
    return ProgramPoint(op_index=op_index, stage=OpStage(stage_value))
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions after `self`.

    The offset is applied to `self.int_value`; a negative `steps` therefore
    moves backwards.
    """
    target_int_value = self.int_value + steps
    return ProgramPoint.from_int_value(target_int_value)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions before `self`.

    Implemented by delegating to `self.next` with the negated step count.
    """
    backward_steps = -steps
    return self.next(steps=backward_steps)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """Order program points by `op_index`, then by `stage`.

    Returns `NotImplemented` when `other` is not a `ProgramPoint`.
    """
    if not isinstance(other, ProgramPoint):
        return NotImplemented
    if self.op_index == other.op_index:
        # same op: the stage decides the order
        return self.stage < other.stage
    return self.op_index < other.op_index
255 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
258 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
260 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
261 __slots__
= "start", "stop"
263 def __init__(self
, start
, stop
):
264 # type: (ProgramPoint, ProgramPoint) -> None
269 def int_value_range(self
):
271 return range(self
.start
.int_value
, self
.stop
.int_value
)
274 def from_int_value_range(int_value_range
):
275 # type: (range) -> ProgramRange
276 if int_value_range
.step
!= 1:
277 raise ValueError("int_value_range must have step == 1")
279 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
280 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
283 def __getitem__(self
, __idx
):
284 # type: (int) -> ProgramPoint
288 def __getitem__(self
, __idx
):
289 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """Index or slice this range.

    The lookup is done on the equivalent `range` of integer values: a slice
    yields a `range`, which is converted back into a `ProgramRange`; an
    integer index yields an `int`, which is converted into a `ProgramPoint`.
    """
    int_values = range(self.start.int_value, self.stop.int_value)
    picked = int_values[__idx]
    if isinstance(picked, range):
        return ProgramRange.from_int_value_range(picked)
    return ProgramPoint.from_int_value(picked)
301 return len(self
.int_value_range
)
304 # type: () -> Iterator[ProgramPoint]
305 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
309 start
= repr(self
.start
).lstrip("<").rstrip(">")
310 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
311 return f
"<range:{start}..{stop}>"
314 @plain_data(frozen
=True, eq
=False, repr=False)
317 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
318 "def_program_ranges", "use_program_points",
319 "all_program_points")
321 def __init__(self
, fn
):
324 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
325 self
.all_program_points
= ProgramRange(
326 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
327 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
328 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
329 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
330 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
331 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
333 for use
in op
.input_uses
:
334 uses
[use
.ssa_val
].add(use
)
335 use_program_point
= self
.__get
_use
_program
_point
(use
)
336 use_program_points
[use
] = use_program_point
337 live_range_stops
[use
.ssa_val
] = max(
338 live_range_stops
[use
.ssa_val
], use_program_point
.next())
339 for out
in op
.outputs
:
341 def_program_range
= self
.__get
_def
_program
_range
(out
)
342 def_program_ranges
[out
] = def_program_range
343 live_range_stops
[out
] = def_program_range
.stop
344 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
345 self
.def_program_ranges
= FMap(def_program_ranges
)
346 self
.use_program_points
= FMap(use_program_points
)
347 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
348 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
349 for ssa_val
in uses
.keys():
350 live_ranges
[ssa_val
] = live_range
= ProgramRange(
351 start
=self
.def_program_ranges
[ssa_val
].start
,
352 stop
=live_range_stops
[ssa_val
])
353 for program_point
in live_range
:
354 live_at
[program_point
].add(ssa_val
)
355 self
.live_ranges
= FMap(live_ranges
)
356 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """Return the range of program points covered by `ssa_val`'s definition.

    The range starts at the defining op's index in the stage given by the
    defining descriptor's `write_stage`, and stops just past the Late stage
    of that same op.
    """
    stage = ssa_val.defining_descriptor.write_stage
    op_index = self.op_indexes[ssa_val.op]
    first = ProgramPoint(op_index=op_index, stage=stage)
    # always include late stage of ssa_val.op, to ensure outputs always
    # overlap all other outputs.
    late = ProgramPoint(op_index=first.op_index, stage=OpStage.Late)
    # stop is exclusive, so we need the next program point.
    return ProgramRange(start=first, stop=late.next())
369 def __get_use_program_point(self
, ssa_use
):
370 # type: (SSAUse) -> ProgramPoint
371 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
372 "assumed here, ensured by GenericOpProperties.__init__"
374 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """Two analyses are equal when their `fn` attributes compare equal.

    Returns `NotImplemented` when `other` is not a `FnAnalysis`.
    """
    if not isinstance(other, FnAnalysis):
        return NotImplemented
    return self.fn == other.fn
388 return "<FnAnalysis>"
396 VL_MAXVL
= enum
.auto()
399 def only_scalar(self
):
401 if self
is BaseTy
.I64
:
403 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
409 def max_reg_len(self
):
411 if self
is BaseTy
.I64
:
413 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
419 return "BaseTy." + self
._name
_
422 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
424 class Ty(metaclass
=InternedMeta
):
425 __slots__
= "base_ty", "reg_len"
428 def validate(base_ty
, reg_len
):
429 # type: (BaseTy, int) -> str | None
430 """ return a string with the error if the combination is invalid,
431 otherwise return None
433 if base_ty
.only_scalar
and reg_len
!= 1:
434 return f
"can't create a vector of an only-scalar type: {base_ty}"
435 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
436 return "reg_len out of range"
439 def __init__(self
, base_ty
, reg_len
):
440 # type: (BaseTy, int) -> None
441 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
443 raise ValueError(msg
)
444 self
.base_ty
= base_ty
445 self
.reg_len
= reg_len
449 if self
.reg_len
!= 1:
450 reg_len
= f
"*{self.reg_len}"
453 return f
"<{self.base_ty._name_}{reg_len}>"
460 StackI64
= enum
.auto()
462 VL_MAXVL
= enum
.auto()
467 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
469 if self
is LocKind
.CA
:
471 if self
is LocKind
.VL_MAXVL
:
472 return BaseTy
.VL_MAXVL
479 if self
is LocKind
.StackI64
:
481 if self
is LocKind
.GPR
or self
is LocKind
.CA \
482 or self
is LocKind
.VL_MAXVL
:
483 return self
.base_ty
.max_reg_len
488 return "LocKind." + self
._name
_
493 class LocSubKind(Enum
):
494 BASE_GPR
= enum
.auto()
495 SV_EXTRA2_VGPR
= enum
.auto()
496 SV_EXTRA2_SGPR
= enum
.auto()
497 SV_EXTRA3_VGPR
= enum
.auto()
498 SV_EXTRA3_SGPR
= enum
.auto()
499 StackI64
= enum
.auto()
501 VL_MAXVL
= enum
.auto()
505 # type: () -> LocKind
506 # pyright fails typechecking when using `in` here:
507 # reported: https://github.com/microsoft/pyright/issues/4102
508 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
509 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
510 LocSubKind
.SV_EXTRA3_SGPR
):
512 if self
is LocSubKind
.StackI64
:
513 return LocKind
.StackI64
514 if self
is LocSubKind
.CA
:
516 if self
is LocSubKind
.VL_MAXVL
:
517 return LocKind
.VL_MAXVL
522 return self
.kind
.base_ty
525 def allocatable_locs(self
, ty
):
526 # type: (Ty) -> LocSet
527 if ty
.base_ty
!= self
.base_ty
:
528 raise ValueError("type mismatch")
529 if self
is LocSubKind
.BASE_GPR
:
531 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
532 starts
= range(0, 128, 2)
533 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
535 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
536 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
538 elif self
is LocSubKind
.StackI64
:
539 starts
= range(LocKind
.StackI64
.loc_count
)
540 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
541 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
544 retval
= [] # type: list[Loc]
546 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
550 for special_loc
in SPECIAL_GPRS
:
551 if loc
.conflicts(special_loc
):
556 return LocSet(retval
)
559 return "LocSubKind." + self
._name
_
562 @plain_data(frozen
=True, unsafe_hash
=True)
564 class GenericTy(metaclass
=InternedMeta
):
565 __slots__
= "base_ty", "is_vec"
567 def __init__(self
, base_ty
, is_vec
):
568 # type: (BaseTy, bool) -> None
569 self
.base_ty
= base_ty
570 if base_ty
.only_scalar
and is_vec
:
571 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
574 def instantiate(self
, maxvl
):
576 # here's where subvl and elwid would be accounted for
578 return Ty(self
.base_ty
, maxvl
)
579 return Ty(self
.base_ty
, 1)
581 def can_instantiate_to(self
, ty
):
583 if self
.base_ty
!= ty
.base_ty
:
587 return ty
.reg_len
== 1
590 @plain_data(frozen
=True, unsafe_hash
=True)
592 class Loc(metaclass
=InternedMeta
):
593 __slots__
= "kind", "start", "reg_len"
596 def validate(kind
, start
, reg_len
):
597 # type: (LocKind, int, int) -> str | None
598 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
601 if reg_len
> kind
.loc_count
:
602 return "invalid reg_len"
603 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
604 return "start not in valid range"
608 def try_make(kind
, start
, reg_len
):
609 # type: (LocKind, int, int) -> Loc | None
610 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
613 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
615 def __init__(self
, kind
, start
, reg_len
):
616 # type: (LocKind, int, int) -> None
617 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
619 raise ValueError(msg
)
621 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Return True when `self` and `other` overlap.

    Two locations overlap when they have the same `kind` and their
    half-open `[start, stop)` intervals intersect.
    """
    if not (self.kind == other.kind):
        return False
    if not (self.start < other.stop):
        return False
    return other.start < self.stop
630 def make_ty(kind
, reg_len
):
631 # type: (LocKind, int) -> Ty
632 return Ty(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
637 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
642 return self
.start
+ self
.reg_len
644 def try_concat(self
, *others
):
645 # type: (*Loc | None) -> Loc | None
646 reg_len
= self
.reg_len
649 if other
is None or other
.kind
!= self
.kind
:
651 if stop
!= other
.start
:
654 reg_len
+= other
.reg_len
655 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
def get_subloc_at_offset(self, subloc_ty, offset):
    # type: (Ty, int) -> Loc
    """Return the `Loc` of type `subloc_ty` located `offset` registers into
    `self`.

    Raises `ValueError` when `subloc_ty.base_ty` differs from this
    location's base type, or when the requested sub-location does not lie
    entirely inside `self`.
    """
    if subloc_ty.base_ty != self.kind.base_ty:
        raise ValueError("BaseTy mismatch")
    sub_len = subloc_ty.reg_len
    fits = offset >= 0 and offset + sub_len <= self.reg_len
    if not fits:
        raise ValueError("invalid sub-Loc: offset and/or "
                         "subloc_ty.reg_len out of range")
    sub_start = self.start + offset
    return Loc(kind=self.kind, start=sub_start, reg_len=sub_len)
669 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
670 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
671 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
672 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
677 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
678 def __init__(self
, __locs
=()):
679 # type: (Iterable[Loc]) -> None
680 super().__init
__(__locs
)
681 if isinstance(__locs
, LocSet
):
682 self
.__starts
= __locs
.starts
683 self
.__ty
= __locs
.ty
685 starts
= {i
: BitSet() for i
in LocKind
}
686 ty
= None # type: None | Ty
691 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
692 starts
[loc
.kind
].add(loc
.start
)
693 self
.__starts
= FMap(
694 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
699 # type: () -> FMap[LocKind, FBitSet]
704 # type: () -> Ty | None
709 # type: () -> FMap[LocKind, FBitSet]
714 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
718 # type: () -> AbstractSet[LocKind]
719 return self
.starts
.keys()
723 # type: () -> int | None
726 return self
.ty
.reg_len
730 # type: () -> BaseTy | None
733 return self
.ty
.base_ty
735 def concat(self
, *others
):
736 # type: (*LocSet) -> LocSet
739 base_ty
= self
.ty
.base_ty
740 reg_len
= self
.ty
.reg_len
741 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
745 if other
.ty
.base_ty
!= base_ty
:
747 for kind
, other_starts
in other
.starts
.items():
748 if kind
not in starts
:
750 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
751 if starts
[kind
] == 0:
755 reg_len
+= other
.ty
.reg_len
758 # type: () -> Iterable[Loc]
759 for kind
, v
in starts
.items():
761 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
764 return LocSet(locs())
766 @lru_cache(maxsize
=None, typed
=True)
767 def max_conflicts_with(self
, other
):
768 # type: (LocSet | Loc) -> int
769 """the largest number of Locs in `self` that a single Loc
770 from `other` can conflict with
772 if isinstance(other
, LocSet
):
773 return max(self
.max_conflicts_with(i
) for i
in other
)
775 return sum(other
.conflicts(i
) for i
in self
)
778 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
781 @plain_data(frozen
=True, unsafe_hash
=True)
783 class GenericOperandDesc(metaclass
=InternedMeta
):
784 """generic Op operand descriptor"""
785 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
789 self
, ty
, # type: GenericTy
790 sub_kinds
, # type: Iterable[LocSubKind]
792 fixed_loc
=None, # type: Loc | None
793 tied_input_index
=None, # type: int | None
794 spread
=False, # type: bool
795 write_stage
=OpStage
.Early
, # type: OpStage
797 # type: (...) -> None
799 self
.sub_kinds
= OFSet(sub_kinds
)
800 if len(self
.sub_kinds
) == 0:
801 raise ValueError("sub_kinds can't be empty")
802 self
.fixed_loc
= fixed_loc
803 if fixed_loc
is not None:
804 if tied_input_index
is not None:
805 raise ValueError("operand can't be both tied and fixed")
806 if not ty
.can_instantiate_to(fixed_loc
.ty
):
808 f
"fixed_loc has incompatible type for given generic "
809 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
810 if len(self
.sub_kinds
) != 1:
812 "multiple sub_kinds not allowed for fixed operand")
813 for sub_kind
in self
.sub_kinds
:
814 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
816 f
"fixed_loc not in given sub_kind: "
817 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
818 for sub_kind
in self
.sub_kinds
:
819 if sub_kind
.base_ty
!= ty
.base_ty
:
820 raise ValueError(f
"sub_kind is incompatible with type: "
821 f
"sub_kind={sub_kind} ty={ty}")
822 if tied_input_index
is not None and tied_input_index
< 0:
823 raise ValueError("invalid tied_input_index")
824 self
.tied_input_index
= tied_input_index
827 if self
.tied_input_index
is not None:
828 raise ValueError("operand can't be both spread and tied")
829 if self
.fixed_loc
is not None:
830 raise ValueError("operand can't be both spread and fixed")
832 raise ValueError("operand can't be both spread and vector")
833 self
.write_stage
= write_stage
836 def ty_before_spread(self
):
837 # type: () -> GenericTy
839 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Return a descriptor with this one's `ty`, `sub_kinds` and
    `write_stage`, tied to the input at `tied_input_index`.

    Only those three attributes are carried over; every other constructor
    argument is left at its default.
    """
    ty = self.ty
    sub_kinds = self.sub_kinds
    write_stage = self.write_stage
    return GenericOperandDesc(
        ty, sub_kinds,
        tied_input_index=tied_input_index, write_stage=write_stage)
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Return a descriptor with this one's `ty`, `sub_kinds` and
    `write_stage`, pinned to `fixed_loc`.

    Only those three attributes are carried over; every other constructor
    argument is left at its default.
    """
    ty = self.ty
    sub_kinds = self.sub_kinds
    write_stage = self.write_stage
    return GenericOperandDesc(
        ty, sub_kinds, fixed_loc=fixed_loc, write_stage=write_stage)
853 def with_write_stage(self
, write_stage
):
854 # type: (OpStage) -> Self
855 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
856 fixed_loc
=self
.fixed_loc
,
857 tied_input_index
=self
.tied_input_index
,
859 write_stage
=write_stage
)
861 def instantiate(self
, maxvl
):
862 # type: (int) -> Iterable[OperandDesc]
863 # assumes all spread operands have ty.reg_len = 1
867 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
869 def locs_before_spread():
870 # type: () -> Iterable[Loc]
871 if self
.fixed_loc
is not None:
872 if ty_before_spread
!= self
.fixed_loc
.ty
:
874 f
"instantiation failed: type mismatch with fixed_loc: "
875 f
"instantiated type: {ty_before_spread} "
876 f
"fixed_loc: {self.fixed_loc}")
879 for sub_kind
in self
.sub_kinds
:
880 yield from sub_kind
.allocatable_locs(ty_before_spread
)
881 loc_set_before_spread
= LocSet(locs_before_spread())
882 for idx
in range(rep_count
):
885 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
886 tied_input_index
=self
.tied_input_index
,
887 spread_index
=idx
, write_stage
=self
.write_stage
)
890 @plain_data(frozen
=True, unsafe_hash
=True)
892 class OperandDesc(metaclass
=InternedMeta
):
893 """Op operand descriptor"""
894 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
897 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
899 # type: (LocSet, int | None, int | None, OpStage) -> None
900 if len(loc_set_before_spread
) == 0:
901 raise ValueError("loc_set_before_spread must not be empty")
902 self
.loc_set_before_spread
= loc_set_before_spread
903 self
.tied_input_index
= tied_input_index
904 if self
.tied_input_index
is not None and spread_index
is not None:
905 raise ValueError("operand can't be both spread and tied")
906 self
.spread_index
= spread_index
907 self
.write_stage
= write_stage
910 def ty_before_spread(self
):
912 ty
= self
.loc_set_before_spread
.ty
913 assert ty
is not None, (
914 "__init__ checked that the LocSet isn't empty, "
915 "non-empty LocSets should always have ty set")
920 """ Ty after any spread is applied """
921 if self
.spread_index
is not None:
922 # assumes all spread operands have ty.reg_len = 1
923 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
924 return self
.ty_before_spread
927 def reg_offset_in_unspread(self
):
928 """ the number of reg-sized slots in the unspread Loc before self's Loc
930 e.g. if the unspread Loc containing self is:
931 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
932 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
933 then reg_offset_in_unspread == 2 == 10 - 8
935 if self
.spread_index
is None:
937 return self
.spread_index
* self
.ty
.reg_len
940 OD_BASE_SGPR
= GenericOperandDesc(
941 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
942 sub_kinds
=[LocSubKind
.BASE_GPR
])
943 OD_EXTRA3_SGPR
= GenericOperandDesc(
944 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
945 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
])
946 OD_EXTRA3_VGPR
= GenericOperandDesc(
947 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
948 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
])
949 OD_EXTRA2_SGPR
= GenericOperandDesc(
950 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
951 sub_kinds
=[LocSubKind
.SV_EXTRA2_SGPR
])
952 OD_EXTRA2_VGPR
= GenericOperandDesc(
953 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
954 sub_kinds
=[LocSubKind
.SV_EXTRA2_VGPR
])
955 OD_CA
= GenericOperandDesc(
956 ty
=GenericTy(base_ty
=BaseTy
.CA
, is_vec
=False),
957 sub_kinds
=[LocSubKind
.CA
])
958 OD_VL
= GenericOperandDesc(
959 ty
=GenericTy(base_ty
=BaseTy
.VL_MAXVL
, is_vec
=False),
960 sub_kinds
=[LocSubKind
.VL_MAXVL
])
963 @plain_data(frozen
=True, unsafe_hash
=True)
965 class GenericOpProperties(metaclass
=InternedMeta
):
966 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
967 "is_copy", "is_load_immediate", "has_side_effects")
970 self
, demo_asm
, # type: str
971 inputs
, # type: Iterable[GenericOperandDesc]
972 outputs
, # type: Iterable[GenericOperandDesc]
973 immediates
=(), # type: Iterable[range]
974 is_copy
=False, # type: bool
975 is_load_immediate
=False, # type: bool
976 has_side_effects
=False, # type: bool
978 # type: (...) -> None
979 self
.demo_asm
= demo_asm
# type: str
980 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
981 for inp
in self
.inputs
:
982 if inp
.tied_input_index
is not None:
984 f
"tied_input_index is not allowed on inputs: {inp}")
985 if inp
.write_stage
is not OpStage
.Early
:
987 f
"write_stage is not allowed on inputs: {inp}")
988 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
989 fixed_locs
= [] # type: list[tuple[Loc, int]]
990 for idx
, out
in enumerate(self
.outputs
):
991 if out
.tied_input_index
is not None:
992 if out
.tied_input_index
>= len(self
.inputs
):
993 raise ValueError(f
"tied_input_index out of range: {out}")
994 tied_inp
= self
.inputs
[out
.tied_input_index
]
995 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
996 .with_write_stage(out
.write_stage
)
997 if expected_out
!= out
:
998 raise ValueError(f
"output can't be tied to non-equivalent "
999 f
"input: {out} tied to {tied_inp}")
1000 if out
.fixed_loc
is not None:
1001 for other_fixed_loc
, other_idx
in fixed_locs
:
1002 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
1005 f
"conflicting fixed_locs: outputs[{idx}] and "
1006 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
1007 f
"with {other_fixed_loc}")
1008 fixed_locs
.append((out
.fixed_loc
, idx
))
1009 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
1010 self
.is_copy
= is_copy
# type: bool
1011 self
.is_load_immediate
= is_load_immediate
# type: bool
1012 self
.has_side_effects
= has_side_effects
# type: bool
1015 @plain_data(frozen
=True, unsafe_hash
=True)
1017 class OpProperties(metaclass
=InternedMeta
):
1018 __slots__
= "kind", "inputs", "outputs", "maxvl"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """Instantiate the generic properties of `kind` for the given `maxvl`.

    Each generic input/output descriptor is expanded through its
    `instantiate(maxvl=...)` method, and the resulting descriptors are
    flattened, in order, into the `inputs` / `outputs` tuples.
    """
    # `kind` must be set first: `self.generic` is derived from it.
    self.kind = kind  # type: OpKind
    self.inputs = tuple(
        desc
        for generic_inp in self.generic.inputs
        for desc in generic_inp.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.outputs = tuple(
        desc
        for generic_out in self.generic.outputs
        for desc in generic_out.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
1035 # type: () -> GenericOpProperties
1036 return self
.kind
.properties
1039 def immediates(self
):
1040 # type: () -> tuple[range, ...]
1041 return self
.generic
.immediates
1046 return self
.generic
.demo_asm
1051 return self
.generic
.is_copy
1054 def is_load_immediate(self
):
1056 return self
.generic
.is_load_immediate
1059 def has_side_effects(self
):
1061 return self
.generic
.has_side_effects
1064 IMM_S16
= range(-1 << 15, 1 << 15)
1066 _SIM_FN
= Callable
[["Op", "BaseSimState"], None]
1067 _SIM_FN2
= Callable
[[], _SIM_FN
]
1068 _SIM_FNS
= {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
1069 _GEN_ASM_FN
= Callable
[["Op", "GenAsmState"], None]
1070 _GEN_ASM_FN2
= Callable
[[], _GEN_ASM_FN
]
1071 _GEN_ASMS
= {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1077 def __init__(self
, properties
):
1078 # type: (GenericOpProperties) -> None
1080 self
.__properties
= properties
1083 def properties(self
):
1084 # type: () -> GenericOpProperties
1085 return self
.__properties
1087 def instantiate(self
, maxvl
):
1088 # type: (int) -> OpProperties
1089 return OpProperties(self
, maxvl
=maxvl
)
1093 return "OpKind." + self
._name
_
1097 # type: () -> _SIM_FN
1098 return _SIM_FNS
[self
.properties
]()
1102 # type: () -> _GEN_ASM_FN
1103 return _GEN_ASMS
[self
.properties
]()
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ClearCA: store the 1-tuple `(False,)` for the op's first
    output in `state`."""
    ca_out = op.outputs[0]
    state[ca_out] = (False,)
1111 def __clearca_gen_asm(op
, state
):
1112 # type: (Op, GenAsmState) -> None
1113 state
.writeln("addic 0, 0, 0")
1114 ClearCA
= GenericOpProperties(
1115 demo_asm
="addic 0, 0, 0",
1117 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1119 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1120 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SetCA: store the 1-tuple `(True,)` for the op's first
    output in `state`."""
    ca_out = op.outputs[0]
    state[ca_out] = (True,)
1128 def __setca_gen_asm(op
, state
):
1129 # type: (Op, GenAsmState) -> None
1130 state
.writeln("subfc 0, 0, 0")
1131 SetCA
= GenericOpProperties(
1132 demo_asm
="subfc 0, 0, 0",
1134 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1136 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1137 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1140 def __svadde_sim(op
, state
):
1141 # type: (Op, BaseSimState) -> None
1142 RA
= state
[op
.input_vals
[0]]
1143 RB
= state
[op
.input_vals
[1]]
1144 carry
, = state
[op
.input_vals
[2]]
1145 VL
, = state
[op
.input_vals
[3]]
1146 RT
= [] # type: list[int]
1148 v
= RA
[i
] + RB
[i
] + carry
1149 RT
.append(v
& GPR_VALUE_MASK
)
1150 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1151 state
[op
.outputs
[0]] = tuple(RT
)
1152 state
[op
.outputs
[1]] = carry
,
1155 def __svadde_gen_asm(op
, state
):
1156 # type: (Op, GenAsmState) -> None
1157 RT
= state
.vgpr(op
.outputs
[0])
1158 RA
= state
.vgpr(op
.input_vals
[0])
1159 RB
= state
.vgpr(op
.input_vals
[1])
1160 state
.writeln(f
"sv.adde {RT}, {RA}, {RB}")
1161 SvAddE
= GenericOpProperties(
1162 demo_asm
="sv.adde *RT, *RA, *RB",
1163 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1164 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1166 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1167 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
1170 def __addze_sim(op
, state
):
1171 # type: (Op, BaseSimState) -> None
1172 RA
, = state
[op
.input_vals
[0]]
1173 carry
, = state
[op
.input_vals
[1]]
1175 RT
= v
& GPR_VALUE_MASK
1176 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1177 state
[op
.outputs
[0]] = RT
,
1178 state
[op
.outputs
[1]] = carry
,
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the `addze RT, RA` instruction for `op`.

    `op.outputs[0]` supplies RT and `op.input_vals[0]` supplies RA.
    """
    # AddZE declares both GPR operands as OD_BASE_SGPR and its demo_asm is
    # "addze RT, RA", so the registers are formatted with the scalar
    # accessor, as __svandvs_gen_asm does for its scalar RB operand.
    # NOTE(review): assumes GenAsmState.sgpr accepts these values exactly
    # like GenAsmState.vgpr does -- confirm against GenAsmState.
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")
1186 AddZE
= GenericOpProperties(
1187 demo_asm
="addze RT, RA",
1188 inputs
=[OD_BASE_SGPR
, OD_CA
],
1189 outputs
=[OD_BASE_SGPR
, OD_CA
.tied_to_input(1)],
1191 _SIM_FNS
[AddZE
] = lambda: OpKind
.__addze
_sim
1192 _GEN_ASMS
[AddZE
] = lambda: OpKind
.__addze
_gen
_asm
def __svsubfe_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sv.subfe`: element-wise subtract-from-extended over `VL`.

    Each element computes `~RA[i] + RB[i] + carry` (with `~RA[i]` taken at
    GPR width).  Inputs (in order) are the `RA` and `RB` vectors, the
    incoming carry and `VL`; outputs are the truncated results and the
    final carry-out.
    """
    RA = state[op.input_vals[0]]
    RB = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    # the carry out of each element feeds into the next element
    for i in range(VL):
        v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
        RT.append(v & GPR_VALUE_MASK)
        carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the `sv.subfe` instruction for `op` into `state`.

    The carry and VL operands are not named in the emitted text.
    """
    dest_val = op.outputs[0]
    lhs_val = op.input_vals[0]
    rhs_val = op.input_vals[1]
    RT = state.vgpr(dest_val)
    RA = state.vgpr(lhs_val)
    RB = state.vgpr(rhs_val)
    asm_line = f"sv.subfe {RT}, {RA}, {RB}"
    state.writeln(asm_line)
1216 SvSubFE
= GenericOpProperties(
1217 demo_asm
="sv.subfe *RT, *RA, *RB",
1218 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1219 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1221 _SIM_FNS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_sim
1222 _GEN_ASMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_gen
_asm
def __svandvs_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sv.and` with a vector `RA` and a scalar `RB`.

    Inputs (in order) are the `RA` vector, the scalar `RB` and `VL`.  The
    single output is the vector of `RA[i] & RB`, masked to GPR width.
    """
    RA = state[op.input_vals[0]]
    RB, = state[op.input_vals[1]]
    VL, = state[op.input_vals[2]]
    RT = []  # type: list[int]
    for i in range(VL):
        RT.append(RA[i] & RB & GPR_VALUE_MASK)
    state[op.outputs[0]] = tuple(RT)
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit `sv.and` with vector RT/RA and a scalar RB into `state`."""
    dest_val = op.outputs[0]
    vec_val = op.input_vals[0]
    scalar_val = op.input_vals[1]
    RT = state.vgpr(dest_val)
    RA = state.vgpr(vec_val)
    RB = state.sgpr(scalar_val)
    asm_line = f"sv.and {RT}, {RA}, {RB}"
    state.writeln(asm_line)
1242 SvAndVS
= GenericOpProperties(
1243 demo_asm
="sv.and *RT, *RA, RB",
1244 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1245 outputs
=[OD_EXTRA3_VGPR
],
1247 _SIM_FNS
[SvAndVS
] = lambda: OpKind
.__svandvs
_sim
1248 _GEN_ASMS
[SvAndVS
] = lambda: OpKind
.__svandvs
_gen
_asm
def __svmaddedu_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sv.maddedu`: vector-by-scalar multiply with a carry word.

    Each element computes `RA[i] * RB + carry`; the low GPR-width bits are
    the result element and the remaining high bits become the carry word
    for the next element.  Inputs (in order) are the `RA` vector, the scalar
    `RB`, the incoming carry word and `VL`; outputs are the result vector
    and the final carry word.
    """
    RA = state[op.input_vals[0]]
    RB, = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    for i in range(VL):
        v = RA[i] * RB + carry
        RT.append(v & GPR_VALUE_MASK)
        # unlike adde, the carry here is a whole word, not a single bit
        carry = v >> GPR_SIZE_IN_BITS
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit `sv.maddedu` with vector RT/RA and scalar RB/RC into `state`."""
    dest_val = op.outputs[0]
    vec_val = op.input_vals[0]
    factor_val = op.input_vals[1]
    carry_val = op.input_vals[2]
    RT = state.vgpr(dest_val)
    RA = state.vgpr(vec_val)
    RB = state.sgpr(factor_val)
    RC = state.sgpr(carry_val)
    asm_line = f"sv.maddedu {RT}, {RA}, {RB}, {RC}"
    state.writeln(asm_line)
1273 SvMAddEDU
= GenericOpProperties(
1274 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1275 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1276 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1278 _SIM_FNS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_sim
1279 _GEN_ASMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_gen
_asm
def __sradi_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sradi`: arithmetic shift right by an immediate.

    The input is reinterpreted as a signed GPR-width integer and shifted
    right by `op.immediates[0]`.  Outputs are the result (masked to GPR
    width) and `CA`, which is set only when the source is negative and at
    least one 1-bit was shifted out.
    """
    rs, = state[op.input_vals[0]]
    imm = op.immediates[0]
    # convert the unsigned register value to its signed interpretation
    if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
        rs -= 1 << GPR_SIZE_IN_BITS
    v = rs >> imm
    RA = v & GPR_VALUE_MASK
    # compare using the signed result: shifting back only differs from the
    # source when non-zero bits were lost
    CA = rs < 0 and (v << imm) != rs
    state[op.outputs[0]] = RA,
    state[op.outputs[1]] = CA,
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the scalar `sradi RA, RS, imm` instruction into `state`."""
    dest_val = op.outputs[0]
    src_val = op.input_vals[0]
    RA = state.sgpr(dest_val)
    RS = state.sgpr(src_val)
    imm = op.immediates[0]
    asm_line = f"sradi {RA}, {RS}, {imm}"
    state.writeln(asm_line)
1301 SRADI
= GenericOpProperties(
1302 demo_asm
="sradi RA, RS, imm",
1303 inputs
=[OD_BASE_SGPR
],
1304 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
),
1305 OD_CA
.with_write_stage(OpStage
.Late
)],
1306 immediates
=[range(GPR_SIZE_IN_BITS
)],
1308 _SIM_FNS
[SRADI
] = lambda: OpKind
.__sradi
_sim
1309 _GEN_ASMS
[SRADI
] = lambda: OpKind
.__sradi
_gen
_asm
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `setvl` with an immediate: the output is the immediate."""
    new_vl = op.immediates[0]
    state[op.outputs[0]] = (new_vl,)
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a `setvl` instruction carrying the op's immediate."""
    imm = op.immediates[0]
    asm_line = f"setvl 0, 0, {imm}, 0, 1, 1"
    state.writeln(asm_line)
1321 SetVLI
= GenericOpProperties(
1322 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1324 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1325 immediates
=[range(1, 65)],
1326 is_load_immediate
=True,
1328 _SIM_FNS
[SetVLI
] = lambda: OpKind
.__setvli
_sim
1329 _GEN_ASMS
[SetVLI
] = lambda: OpKind
.__setvli
_gen
_asm
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a vector load-immediate.

    The output is `VL` copies of the immediate, masked to GPR width.
    """
    VL, = state[op.input_vals[0]]
    element = op.immediates[0] & GPR_VALUE_MASK
    filled = (element,) * VL
    state[op.outputs[0]] = filled
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector load-immediate as `sv.addi RT, 0, imm`."""
    dest_val = op.outputs[0]
    RT = state.vgpr(dest_val)
    imm = op.immediates[0]
    asm_line = f"sv.addi {RT}, 0, {imm}"
    state.writeln(asm_line)
1344 SvLI
= GenericOpProperties(
1345 demo_asm
="sv.addi *RT, 0, imm",
1347 outputs
=[OD_EXTRA3_VGPR
],
1348 immediates
=[IMM_S16
],
1349 is_load_immediate
=True,
1351 _SIM_FNS
[SvLI
] = lambda: OpKind
.__svli
_sim
1352 _GEN_ASMS
[SvLI
] = lambda: OpKind
.__svli
_gen
_asm
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate a scalar load-immediate: output is the masked immediate."""
    state[op.outputs[0]] = (op.immediates[0] & GPR_VALUE_MASK,)
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the scalar load-immediate as `addi RT, 0, imm`."""
    dest_val = op.outputs[0]
    RT = state.sgpr(dest_val)
    imm = op.immediates[0]
    asm_line = f"addi {RT}, 0, {imm}"
    state.writeln(asm_line)
1366 LI
= GenericOpProperties(
1367 demo_asm
="addi RT, 0, imm",
1369 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1370 immediates
=[IMM_S16
],
1371 is_load_immediate
=True,
1373 _SIM_FNS
[LI
] = lambda: OpKind
.__li
_sim
1374 _GEN_ASMS
[LI
] = lambda: OpKind
.__li
_gen
_asm
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy: the output takes the first input's value as-is."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1382 def __copy_to_from_reg_gen_asm(src_loc
, dest_loc
, is_vec
, state
):
1383 # type: (Loc, Loc, bool, GenAsmState) -> None
1384 sv
= "sv." if is_vec
else ""
1386 if src_loc
.conflicts(dest_loc
) and src_loc
.start
< dest_loc
.start
:
1388 if src_loc
== dest_loc
:
1390 if src_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1391 raise ValueError(f
"invalid src_loc.kind: {src_loc.kind}")
1392 if dest_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1393 raise ValueError(f
"invalid dest_loc.kind: {dest_loc.kind}")
1394 if src_loc
.kind
is LocKind
.StackI64
:
1395 if dest_loc
.kind
is LocKind
.StackI64
:
1397 f
"can't copy from stack to stack: {src_loc} {dest_loc}")
1398 elif dest_loc
.kind
is not LocKind
.GPR
:
1399 assert_never(dest_loc
.kind
)
1400 src
= state
.stack(src_loc
)
1401 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1402 state
.writeln(f
"{sv}ld {dest}, {src}")
1403 elif dest_loc
.kind
is LocKind
.StackI64
:
1404 if src_loc
.kind
is not LocKind
.GPR
:
1405 assert_never(src_loc
.kind
)
1406 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1407 dest
= state
.stack(dest_loc
)
1408 state
.writeln(f
"{sv}std {src}, {dest}")
1409 elif src_loc
.kind
is LocKind
.GPR
:
1410 if dest_loc
.kind
is not LocKind
.GPR
:
1411 assert_never(dest_loc
.kind
)
1412 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1413 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1414 state
.writeln(f
"{sv}or{rev} {dest}, {src}, {src}")
1416 assert_never(src_loc
.kind
)
1419 def __veccopytoreg_gen_asm(op
, state
):
1420 # type: (Op, GenAsmState) -> None
1421 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1423 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1424 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1425 is_vec
=True, state
=state
)
1427 VecCopyToReg
= GenericOpProperties(
1428 demo_asm
="sv.mv dest, src",
1429 inputs
=[GenericOperandDesc(
1430 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1431 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1433 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1436 _SIM_FNS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_sim
1437 _GEN_ASMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_gen
_asm
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy: the output takes the first input's value as-is."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1445 def __veccopyfromreg_gen_asm(op
, state
):
1446 # type: (Op, GenAsmState) -> None
1447 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1448 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1450 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1451 is_vec
=True, state
=state
)
1452 VecCopyFromReg
= GenericOpProperties(
1453 demo_asm
="sv.mv dest, src",
1454 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1455 outputs
=[GenericOperandDesc(
1456 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1457 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1458 write_stage
=OpStage
.Late
,
1462 _SIM_FNS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_sim
1463 _GEN_ASMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_gen
_asm
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy: the output takes the first input's value as-is."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1471 def __copytoreg_gen_asm(op
, state
):
1472 # type: (Op, GenAsmState) -> None
1473 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1475 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1476 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1477 is_vec
=False, state
=state
)
1478 CopyToReg
= GenericOpProperties(
1479 demo_asm
="mv dest, src",
1480 inputs
=[GenericOperandDesc(
1481 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1482 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1483 LocSubKind
.StackI64
],
1485 outputs
=[GenericOperandDesc(
1486 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1487 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1488 write_stage
=OpStage
.Late
,
1492 _SIM_FNS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_sim
1493 _GEN_ASMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_gen
_asm
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate the copy: the output takes the first input's value as-is."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1501 def __copyfromreg_gen_asm(op
, state
):
1502 # type: (Op, GenAsmState) -> None
1503 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1504 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1506 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1507 is_vec
=False, state
=state
)
1508 CopyFromReg
= GenericOpProperties(
1509 demo_asm
="mv dest, src",
1510 inputs
=[GenericOperandDesc(
1511 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1512 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1514 outputs
=[GenericOperandDesc(
1515 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1516 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1517 LocSubKind
.StackI64
],
1518 write_stage
=OpStage
.Late
,
1522 _SIM_FNS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_sim
1523 _GEN_ASMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_gen
_asm
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate concatenation of scalars into one vector value.

    Every input except the last contributes its first element, in order.
    """
    elements = []  # type: list[int]
    for scalar_val in op.input_vals[:-1]:
        elements.append(state[scalar_val][0])
    state[op.outputs[0]] = tuple(elements)
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements the concatenation.

    The source Loc is looked up from all inputs except the last, and the
    destination Loc from the single output; both are expected in GPRs.
    """
    source = state.loc(op.input_vals[0:-1], LocKind.GPR)
    destination = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)
1538 Concat
= GenericOpProperties(
1539 demo_asm
="sv.mv dest, src",
1540 inputs
=[GenericOperandDesc(
1541 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1542 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1545 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1548 _SIM_FNS
[Concat
] = lambda: OpKind
.__concat
_sim
1549 _GEN_ASMS
[Concat
] = lambda: OpKind
.__concat
_gen
_asm
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate spreading a vector: element `idx` goes to output `idx`."""
    elements = state[op.input_vals[0]]
    for idx, element in enumerate(elements):
        target = op.outputs[idx]
        state[target] = (element,)
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements the spread.

    The source Loc is looked up from the first input, and the destination
    Loc from the whole tuple of outputs; both are expected in GPRs.
    """
    source = state.loc(op.input_vals[0], LocKind.GPR)
    destination = state.loc(op.outputs, LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)
1564 Spread
= GenericOpProperties(
1565 demo_asm
="sv.mv dest, src",
1566 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1567 outputs
=[GenericOperandDesc(
1568 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1569 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1571 write_stage
=OpStage
.Late
,
1575 _SIM_FNS
[Spread
] = lambda: OpKind
.__spread
_sim
1576 _GEN_ASMS
[Spread
] = lambda: OpKind
.__spread
_gen
_asm
def __svld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sv.ld`: load `VL` consecutive GPR-sized words.

    The base address is `RA + op.immediates[0]`; element `i` is loaded from
    `GPR_SIZE_IN_BYTES * i` bytes past the base.  Inputs (in order) are `RA`
    and `VL`; the single output is the vector of loaded values.
    """
    RA, = state[op.input_vals[0]]
    VL, = state[op.input_vals[1]]
    addr = RA + op.immediates[0]
    RT = []  # type: list[int]
    for i in range(VL):
        v = state.load(addr + GPR_SIZE_IN_BYTES * i)
        RT.append(v & GPR_VALUE_MASK)
    state[op.outputs[0]] = tuple(RT)
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit `sv.ld RT, imm(RA)` with a vector RT and a scalar RA."""
    base_val = op.input_vals[0]
    dest_val = op.outputs[0]
    RA = state.sgpr(base_val)
    RT = state.vgpr(dest_val)
    imm = op.immediates[0]
    asm_line = f"sv.ld {RT}, {imm}({RA})"
    state.writeln(asm_line)
1597 SvLd
= GenericOpProperties(
1598 demo_asm
="sv.ld *RT, imm(RA)",
1599 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1600 outputs
=[OD_EXTRA3_VGPR
],
1601 immediates
=[IMM_S16
],
1603 _SIM_FNS
[SvLd
] = lambda: OpKind
.__svld
_sim
1604 _GEN_ASMS
[SvLd
] = lambda: OpKind
.__svld
_gen
_asm
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `ld`: load one word from `RA + op.immediates[0]`."""
    RA, = state[op.input_vals[0]]
    loaded = state.load(RA + op.immediates[0])
    state[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the scalar `ld RT, imm(RA)` instruction into `state`."""
    base_val = op.input_vals[0]
    dest_val = op.outputs[0]
    RA = state.sgpr(base_val)
    RT = state.sgpr(dest_val)
    imm = op.immediates[0]
    asm_line = f"ld {RT}, {imm}({RA})"
    state.writeln(asm_line)
1621 Ld
= GenericOpProperties(
1622 demo_asm
="ld RT, imm(RA)",
1623 inputs
=[OD_BASE_SGPR
],
1624 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1625 immediates
=[IMM_S16
],
1627 _SIM_FNS
[Ld
] = lambda: OpKind
.__ld
_sim
1628 _GEN_ASMS
[Ld
] = lambda: OpKind
.__ld
_gen
_asm
def __svstd_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `sv.std`: store `VL` consecutive GPR-sized words.

    The base address is `RA + op.immediates[0]`; element `i` of `RS` is
    stored `GPR_SIZE_IN_BYTES * i` bytes past the base.  Inputs (in order)
    are the `RS` vector, `RA` and `VL`; there are no outputs.
    """
    RS = state[op.input_vals[0]]
    RA, = state[op.input_vals[1]]
    VL, = state[op.input_vals[2]]
    addr = RA + op.immediates[0]
    for i in range(VL):
        state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit `sv.std RS, imm(RA)` with a vector RS and a scalar RA."""
    data_val = op.input_vals[0]
    base_val = op.input_vals[1]
    RS = state.vgpr(data_val)
    RA = state.sgpr(base_val)
    imm = op.immediates[0]
    asm_line = f"sv.std {RS}, {imm}({RA})"
    state.writeln(asm_line)
1647 SvStd
= GenericOpProperties(
1648 demo_asm
="sv.std *RS, imm(RA)",
1649 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1651 immediates
=[IMM_S16
],
1652 has_side_effects
=True,
1654 _SIM_FNS
[SvStd
] = lambda: OpKind
.__svstd
_sim
1655 _GEN_ASMS
[SvStd
] = lambda: OpKind
.__svstd
_gen
_asm
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `std`: store `RS` at address `RA + op.immediates[0]`."""
    RS, = state[op.input_vals[0]]
    RA, = state[op.input_vals[1]]
    state.store(RA + op.immediates[0], value=RS)
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the scalar `std RS, imm(RA)` instruction into `state`."""
    data_val = op.input_vals[0]
    base_val = op.input_vals[1]
    RS = state.sgpr(data_val)
    RA = state.sgpr(base_val)
    imm = op.immediates[0]
    asm_line = f"std {RS}, {imm}({RA})"
    state.writeln(asm_line)
1672 Std
= GenericOpProperties(
1673 demo_asm
="std RS, imm(RA)",
1674 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1676 immediates
=[IMM_S16
],
1677 has_side_effects
=True,
1679 _SIM_FNS
[Std
] = lambda: OpKind
.__std
_sim
1680 _GEN_ASMS
[Std
] = lambda: OpKind
.__std
_gen
_asm
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Do nothing: the return value is set before simulation."""
    return None
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Do nothing: no instructions are needed for this op."""
    return None
1691 FuncArgR3
= GenericOpProperties(
1694 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1695 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1697 _SIM_FNS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_sim
1698 _GEN_ASMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_gen
_asm
1701 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1702 class SSAValOrUse(metaclass
=InternedMeta
):
1703 __slots__
= "op", "operand_idx"
1705 def __init__(self
, op
, operand_idx
):
1706 # type: (Op, int) -> None
1709 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1710 raise ValueError("invalid operand_idx")
1711 self
.operand_idx
= operand_idx
1720 def descriptor_array(self
):
1721 # type: () -> tuple[OperandDesc, ...]
1725 def defining_descriptor(self
):
1726 # type: () -> OperandDesc
1727 return self
.descriptor_array
[self
.operand_idx
]
1732 return self
.defining_descriptor
.ty
1735 def ty_before_spread(self
):
1737 return self
.defining_descriptor
.ty_before_spread
1741 # type: () -> BaseTy
1742 return self
.ty_before_spread
.base_ty
1745 def reg_offset_in_unspread(self
):
1746 """ the number of reg-sized slots in the unspread Loc before self's Loc
1748 e.g. if the unspread Loc containing self is:
1749 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1750 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1751 then reg_offset_into_unspread == 2 == 10 - 8
1753 return self
.defining_descriptor
.reg_offset_in_unspread
1756 def unspread_start_idx(self
):
1758 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1761 def unspread_start(self
):
1763 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1766 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1768 class SSAVal(SSAValOrUse
):
1773 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1776 def def_loc_set_before_spread(self
):
1777 # type: () -> LocSet
1778 return self
.defining_descriptor
.loc_set_before_spread
1781 def descriptor_array(self
):
1782 # type: () -> tuple[OperandDesc, ...]
1783 return self
.op
.properties
.outputs
1786 def tied_input(self
):
1787 # type: () -> None | SSAUse
1788 if self
.defining_descriptor
.tied_input_index
is None:
1790 return SSAUse(op
=self
.op
,
1791 operand_idx
=self
.defining_descriptor
.tied_input_index
)
1794 def write_stage(self
):
1795 # type: () -> OpStage
1796 return self
.defining_descriptor
.write_stage
1799 def current_debugging_value(self
):
1800 # type: () -> tuple[int, ...]
1801 """ get the current value for debugging in pdb or similar.
1803 This is intended for use with
1804 `PreRASimState.set_current_debugging_state`.
1806 This is only intended for debugging, do not use in unit tests or
1809 return PreRASimState
.get_current_debugging_state()[self
]
1812 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1814 class SSAUse(SSAValOrUse
):
1818 def use_loc_set_before_spread(self
):
1819 # type: () -> LocSet
1820 return self
.defining_descriptor
.loc_set_before_spread
1823 def descriptor_array(self
):
1824 # type: () -> tuple[OperandDesc, ...]
1825 return self
.op
.properties
.inputs
1829 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1833 # type: () -> SSAVal
1834 return self
.op
.input_vals
[self
.operand_idx
]
1837 def ssa_val(self
, ssa_val
):
1838 # type: (SSAVal) -> None
1839 self
.op
.input_vals
[self
.operand_idx
] = ssa_val
1843 _Desc
= TypeVar("_Desc")
1846 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1848 def _verify_write_with_desc(self
, idx
, item
, desc
):
1849 # type: (int, _T | Any, _Desc) -> None
1850 raise NotImplementedError
1853 def _verify_write(self
, idx
, item
):
1854 # type: (int | Any, _T | Any) -> int
1855 if not isinstance(idx
, int):
1856 if isinstance(idx
, slice):
1858 f
"can't write to slice of {self.__class__.__name__}")
1859 raise TypeError(f
"can't write with index {idx!r}")
1860 # normalize idx, raising IndexError if it is out of range
1861 idx
= range(len(self
.descriptors
))[idx
]
1862 desc
= self
.descriptors
[idx
]
1863 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1866 def _on_set(self
, idx
, new_item
, old_item
):
1867 # type: (int, _T, _T | None) -> None
1871 def _get_descriptors(self
):
1872 # type: () -> tuple[_Desc, ...]
1873 raise NotImplementedError
1877 def descriptors(self
):
1878 # type: () -> tuple[_Desc, ...]
1879 return self
._get
_descriptors
()
1886 def __init__(self
, items
, op
):
1887 # type: (Iterable[_T], Op) -> None
1890 self
.__items
= [] # type: list[_T]
1891 for idx
, item
in enumerate(items
):
1892 if idx
>= len(self
.descriptors
):
1893 raise ValueError("too many items")
1894 _
= self
._verify
_write
(idx
, item
)
1895 self
.__items
.append(item
)
1896 if len(self
.__items
) < len(self
.descriptors
):
1897 raise ValueError("not enough items")
1901 # type: () -> Iterator[_T]
1902 yield from self
.__items
1905 def __getitem__(self
, idx
):
1910 def __getitem__(self
, idx
):
1911 # type: (slice) -> list[_T]
1915 def __getitem__(self
, idx
):
1916 # type: (int | slice) -> _T | list[_T]
1917 return self
.__items
[idx
]
1920 def __setitem__(self
, idx
, item
):
1921 # type: (int, _T) -> None
1922 idx
= self
._verify
_write
(idx
, item
)
1923 self
.__items
[idx
] = item
1928 return len(self
.__items
)
1932 return f
"{self.__class__.__name__}({self.__items}, op=...)"
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The sequence of `SSAVal` inputs of an `Op`.

    Each slot is described by the matching entry of
    `op.properties.inputs`; writes are checked against that descriptor.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Raise unless `item` is an `SSAVal` whose type matches `desc`."""
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` stores this sequence in its `input_vals` attribute, so that is
        # the attribute to check to detect a second construction for one Op.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.inputs already set")
        super().__init__(items, op)
class OpImmediates(OpInputSeq[int, range]):
    """The sequence of integer immediates of an `Op`.

    Each slot is described by a `range` from `op.properties.immediates`;
    a written value must be an `int` contained in that range.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """Raise unless `item` is an `int` inside the range `desc`."""
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
1980 @plain_data(frozen
=True, eq
=False, repr=False)
1983 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
1986 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
1987 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1989 self
.properties
= properties
1990 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
1991 inputs_len
= len(self
.properties
.inputs
)
1992 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
1993 self
.immediates
= OpImmediates(immediates
, op
=self
)
1994 outputs_len
= len(self
.properties
.outputs
)
1995 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1996 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
2000 # type: () -> OpKind
2001 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: two `Op`s are equal only if they are the same
    object.  Non-`Op` operands defer to the other side's comparison."""
    if not isinstance(other, Op):
        return NotImplemented
    return other is self
2011 return object.__hash
__(self
)
2013 def __repr__(self
, wrap_width
=63, indent
=" "):
2014 # type: (int, str) -> str
2015 WRAP_POINT
= "\u200B" # zero-width space
2016 items
= [f
"{self.name}:\n"]
2017 for i
, out
in enumerate(self
.outputs
):
2018 item
= f
"<...outputs[{i}]: {out.ty}>"
2020 item
= "(" + WRAP_POINT
+ item
2021 if i
!= len(self
.outputs
) - 1:
2022 item
+= ", " + WRAP_POINT
2024 item
+= WRAP_POINT
+ ") <= "
2026 items
.append(self
.kind
._name
_)
2027 if len(self
.input_vals
) + len(self
.immediates
) != 0:
2029 items
[-1] += WRAP_POINT
2030 for i
, inp
in enumerate(self
.input_vals
):
2032 if i
!= len(self
.input_vals
) - 1 or len(self
.immediates
) != 0:
2033 item
+= ", " + WRAP_POINT
2035 item
+= ") " + WRAP_POINT
2037 for i
, imm
in enumerate(self
.immediates
):
2039 if i
!= len(self
.immediates
) - 1:
2040 item
+= ", " + WRAP_POINT
2042 item
+= ") " + WRAP_POINT
2044 lines
= [] # type: list[str]
2045 for i
, line_in
in enumerate("".join(items
).splitlines()):
2047 line_in
= indent
+ line_in
2049 for part
in line_in
.split(WRAP_POINT
):
2053 trial_line_out
= line_out
+ part
2054 if len(trial_line_out
.rstrip()) > wrap_width
:
2055 lines
.append(line_out
.rstrip())
2056 line_out
= indent
+ part
2058 line_out
= trial_line_out
2059 lines
.append(line_out
.rstrip())
2060 return "\n".join(lines
)
2062 def sim(self
, state
):
2063 # type: (BaseSimState) -> None
2064 for inp
in self
.input_vals
:
2068 raise ValueError(f
"SSAVal {inp} not yet assigned when "
2070 if len(val
) != inp
.ty
.reg_len
:
2072 f
"value of SSAVal {inp} has wrong number of elements: "
2073 f
"expected {inp.ty.reg_len} found "
2074 f
"{len(val)}: {val!r}")
2075 if isinstance(state
, PreRASimState
):
2076 for out
in self
.outputs
:
2077 if out
in state
.ssa_vals
:
2078 if self
.kind
is OpKind
.FuncArgR3
:
2080 raise ValueError(f
"SSAVal {out} already assigned before "
2082 self
.kind
.sim(self
, state
)
2083 for out
in self
.outputs
:
2087 raise ValueError(f
"running {self} failed to assign to {out}")
2088 if len(val
) != out
.ty
.reg_len
:
2090 f
"value of SSAVal {out} has wrong number of elements: "
2091 f
"expected {out.ty.reg_len} found "
2092 f
"{len(val)}: {val!r}")
def gen_asm(self, state):
    # type: (GenAsmState) -> None
    """Generate assembly for this op into `state`.

    `state.loc` is first called for every input and then every output,
    accepting any `LocKind`, before dispatching to the op kind's generator.
    """
    any_kind = tuple(LocKind)
    for ssa_val in (*self.input_vals, *self.outputs):
        state.loc(ssa_val, expected_kinds=any_kind)
    self.kind.gen_asm(self, state)
2104 @plain_data(frozen
=True, repr=False)
2105 class BaseSimState(metaclass
=ABCMeta
):
2106 __slots__
= "memory",
2108 def __init__(self
, memory
):
2109 # type: (dict[int, int]) -> None
2111 self
.memory
= memory
# type: dict[int, int]
def load_byte(self, addr):
    # type: (int) -> int
    """Return the byte stored at `addr`.

    The address is wrapped to GPR width; addresses with no entry in
    `self.memory` read as 0.
    """
    wrapped_addr = addr & GPR_VALUE_MASK
    stored = self.memory.get(wrapped_addr, 0)
    return stored & 0xFF
2118 def store_byte(self
, addr
, value
):
2119 # type: (int, int) -> None
2120 addr
&= GPR_VALUE_MASK
2122 self
.memory
[addr
] = value
def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
    # type: (int, int, bool) -> int
    """Load a `size_in_bytes`-byte value from `addr`.

    Bytes are assembled with the byte at `addr` as the least-significant
    one.  When `signed` is true the result is sign-extended from its top
    bit.

    Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
    """
    if addr % size_in_bytes != 0:
        raise ValueError(f"address not aligned: {hex(addr)} "
                         f"required alignment: {size_in_bytes}")
    retval = 0
    for i in range(size_in_bytes):
        retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
    # a set top bit means the value is negative in two's complement
    if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
        retval -= 1 << size_in_bytes * BITS_IN_BYTE
    return retval
def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
    # type: (int, int, int) -> None
    """Store the low `size_in_bytes` bytes of `value` at `addr`.

    The least-significant byte is written to `addr`, the next one to
    `addr + 1`, and so on.

    Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
    """
    misaligned = addr % size_in_bytes != 0
    if misaligned:
        raise ValueError(f"address not aligned: {hex(addr)} "
                         f"required alignment: {size_in_bytes}")
    for byte_idx in range(size_in_bytes):
        shift = byte_idx * BITS_IN_BYTE
        byte = (value >> shift) & 0xFF
        self.store_byte(addr + byte_idx, byte)
2144 def _memory__repr(self
):
2146 if len(self
.memory
) == 0:
2148 keys
= sorted(self
.memory
.keys(), reverse
=True)
2149 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
2150 items
= [] # type: list[str]
2151 while len(keys
) != 0:
2153 if (len(keys
) >= CHUNK_SIZE
2154 and addr
% CHUNK_SIZE
== 0
2155 and keys
[-CHUNK_SIZE
:]
2156 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
2157 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
2158 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
2159 keys
[-CHUNK_SIZE
:] = ()
2161 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
2163 return f
"{{{items[0]}}}"
2164 items_str
= ",\n".join(items
)
2165 return f
"{{\n{items_str}}}"
2169 field_vals
= [] # type: list[str]
2170 for name
in fields(self
):
2172 value
= getattr(self
, name
)
2173 except AttributeError:
2174 field_vals
.append(f
"{name}=<not set>")
2176 repr_fn
= getattr(self
, f
"_{name}__repr", None)
2177 if callable(repr_fn
):
2178 field_vals
.append(f
"{name}={repr_fn()}")
2180 field_vals
.append(f
"{name}={value!r}")
2181 field_vals_str
= ", ".join(field_vals
)
2182 return f
"{self.__class__.__name__}({field_vals_str})"
2185 def __getitem__(self
, ssa_val
):
2186 # type: (SSAVal) -> tuple[int, ...]
2190 def __setitem__(self
, ssa_val
, value
):
2191 # type: (SSAVal, tuple[int, ...]) -> None
2195 @plain_data(frozen
=True, repr=False)
2197 class PreRASimState(BaseSimState
):
2198 __slots__
= "ssa_vals",
2200 def __init__(self
, ssa_vals
, memory
):
2201 # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
2202 super().__init
__(memory
)
2203 self
.ssa_vals
= ssa_vals
# type: dict[SSAVal, tuple[int, ...]]
2205 def _ssa_vals__repr(self
):
2207 if len(self
.ssa_vals
) == 0:
2209 items
= [] # type: list[str]
2211 for k
, v
in self
.ssa_vals
.items():
2212 element_strs
= [] # type: list[str]
2213 for i
, el
in enumerate(v
):
2214 if i
% CHUNK_SIZE
!= 0:
2215 element_strs
.append(" " + hex(el
))
2217 element_strs
.append("\n " + hex(el
))
2218 if len(element_strs
) <= CHUNK_SIZE
:
2219 element_strs
[0] = element_strs
[0].lstrip()
2220 if len(element_strs
) == 1:
2221 element_strs
.append("")
2222 v_str
= ",".join(element_strs
)
2223 items
.append(f
"{k!r}: ({v_str})")
2224 if len(items
) == 1 and "\n" not in items
[0]:
2225 return f
"{{{items[0]}}}"
2226 items_str
= ",\n".join(items
)
2227 return f
"{{\n{items_str},\n}}"
def __getitem__(self, ssa_val):
    # type: (SSAVal) -> tuple[int, ...]
    """Return the value recorded for `ssa_val` in `self.ssa_vals`."""
    recorded = self.ssa_vals
    return recorded[ssa_val]
def __setitem__(self, ssa_val, value):
    # type: (SSAVal, tuple[int, ...]) -> None
    """Record `value` for `ssa_val`.

    Raises `ValueError` if `value` does not have exactly
    `ssa_val.ty.reg_len` elements.
    """
    expected_len = ssa_val.ty.reg_len
    if len(value) == expected_len:
        self.ssa_vals[ssa_val] = value
        return
    raise ValueError("value has wrong len")
2239 __CURRENT_DEBUGGING_STATE
= [] # type: list[PreRASimState]
2242 def set_as_current_debugging_state(self
):
2243 """ return a context manager that sets self as the current state for
2244 debugging in pdb or similar. This is intended only for use with
2245 `get_current_debugging_state` which should not be used in unit tests
2249 PreRASimState
.__CURRENT
_DEBUGGING
_STATE
.append(self
)
2252 assert self
is PreRASimState
.__CURRENT
_DEBUGGING
_STATE
.pop(), \
2253 "inconsistent __CURRENT_DEBUGGING_STATE"
2256 def get_current_debugging_state():
2257 # type: () -> PreRASimState
2258 """ get the current state for debugging in pdb or similar.
2260 This is intended for use with `set_current_debugging_state`.
2262 This is only intended for debugging, do not use in unit tests or
2265 if len(PreRASimState
.__CURRENT
_DEBUGGING
_STATE
) == 0:
2266 raise ValueError("no current debugging state")
2267 return PreRASimState
.__CURRENT
_DEBUGGING
_STATE
[-1]
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """ simulator state that stores values per `Loc` rather than per SSAVal.

    `ssa_val_to_loc_map` gives the `Loc` of every SSAVal, `loc_values`
    holds one int per `reg_len=1` Loc. Indexing with an SSAVal reads or
    writes the values of the sub-Locs of the Loc mapped to that SSAVal.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ `memory` is forwarded to the base class, `ssa_val_to_loc_map`
        is copied into an `FMap` and `loc_values` is kept by reference.

        raises ValueError when a SSAVal and its Loc have different types
        or when `loc_values` has a key with `reg_len != 1`.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values:
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ format `loc_values` as a multi-line `{loc: 0xvalue, ...}`
        string with the Locs ordered by kind name, then by start.
        """
        def sort_key(v):
            # type: (Loc) -> tuple[str, int]
            return v.kind.name, v.start

        items_str = ",\n".join(
            f"{loc}: 0x{self.loc_values[loc]:x}"
            for loc in sorted(self.loc_values, key=sort_key))
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ return one value per `reg_len=1` sub-Loc of the Loc mapped to
        `ssa_val`; sub-Locs missing from `loc_values` read as 0.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        sublocs = (
            loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=offset)
            for offset in range(loc.reg_len))
        return tuple(self.loc_values.get(subloc, 0) for subloc in sublocs)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """ write `value[i]` to the i-th `reg_len=1` sub-Loc of the Loc
        mapped to `ssa_val`.

        raises ValueError if `len(value)` isn't `ssa_val.ty.reg_len`.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for offset in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(
                subloc_ty=subloc_ty, offset=offset)
            self.loc_values[subloc] = value[offset]
2321 @plain_data(frozen
=True)
# the two instance attributes assigned by `__init__`.
# NOTE(review): presumably also the field list used by the `plain_data`
# decorator applied to this class -- confirm against nmutil.plain_data.
__slots__ = "allocated_locs", "output"
def __init__(self, allocated_locs, output=None):
    # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
    """ `allocated_locs` maps every SSAVal to the Loc allocated for it and
    is copied into an `FMap`; `output` is where `writeln` puts lines.

    raises ValueError when a SSAVal and its Loc have different `ty`.

    NOTE(review): `output=None` is taken to mean "collect lines in a new
    list" -- confirm.
    """
    super().__init__()
    self.allocated_locs = FMap(allocated_locs)
    for ssa_val, loc in self.allocated_locs.items():
        if ssa_val.ty != loc.ty:
            raise ValueError(
                f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
    self.output = [] if output is None else output
# argument shapes accepted by `loc` and the helpers that forward to it:
# a single SSAVal, a single Loc, or a sequence mixing SSAVals and Locs.
__SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
def loc(self, ssa_val_or_locs, expected_kinds):
    # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
    """ resolve `ssa_val_or_locs` to a single Loc of an expected kind.

    a lone SSAVal or Loc is treated as a one-element sequence. every
    SSAVal is replaced by its entry in `self.allocated_locs`, then all
    resulting Locs are combined with `Loc.try_concat`.

    `expected_kinds` is one LocKind or a tuple of acceptable LocKinds.

    raises ValueError when the sequence is empty, when `try_concat`
    returns None, or when the combined Loc's kind isn't expected.
    """
    if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
        ssa_val_or_locs = [ssa_val_or_locs]
    locs = [
        self.allocated_locs[item] if isinstance(item, SSAVal) else item
        for item in ssa_val_or_locs
    ]  # type: list[Loc]
    if not locs:
        raise ValueError("invalid Loc sequence: must not be empty")
    first, *rest = locs
    retval = first.try_concat(*rest)
    if retval is None:
        raise ValueError("invalid Loc sequence: try_concat failed")
    if isinstance(expected_kinds, LocKind):
        expected_kinds = (expected_kinds,)
    if retval.kind in expected_kinds:
        return retval
    # report a lone expected kind bare instead of as a 1-tuple
    if len(expected_kinds) == 1:
        expected_kinds = expected_kinds[0]
    raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                     f"{retval.kind} expected {expected_kinds}")
def gpr(self, ssa_val_or_locs, is_vec):
    # type: (__SSA_VAL_OR_LOCS, bool) -> str
    """ return the start of the `LocKind.GPR` Loc that `ssa_val_or_locs`
    resolves to (see `loc`) as a string, prefixed by `*` when `is_vec`.
    """
    start = self.loc(ssa_val_or_locs, LocKind.GPR).start
    prefix = "*" if is_vec else ""
    return prefix + str(start)
def sgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """ shorthand for `gpr` with `is_vec=False`. """
    operand = self.gpr(ssa_val_or_locs, is_vec=False)
    return operand
def vgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """ shorthand for `gpr` with `is_vec=True`. """
    operand = self.gpr(ssa_val_or_locs, is_vec=True)
    return operand
def stack(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """ return `<start>(1)` for the `LocKind.StackI64` Loc that
    `ssa_val_or_locs` resolves to (see `loc`).

    NOTE(review): the `(1)` suffix presumably addresses the slot relative
    to register 1 -- confirm.
    """
    loc = self.loc(ssa_val_or_locs, expected_kinds=LocKind.StackI64)
    return f"{loc.start}(1)"
def writeln(self, *line_segments):
    # type: (*str) -> None
    """ emit one line made of `line_segments` joined by single spaces.

    when `self.output` is a list the line is appended without a newline,
    otherwise the line plus a trailing newline is passed to
    `self.output.write`.
    """
    text = " ".join(line_segments)
    sink = self.output
    if not isinstance(sink, list):
        sink.write(text + "\n")
        return
    sink.append(text)