2 from abc
import ABCMeta
, abstractmethod
3 from enum
import Enum
, unique
4 from functools
import lru_cache
, total_ordering
5 from io
import StringIO
6 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
7 Mapping
, Sequence
, TypeVar
, Union
, overload
)
8 from weakref
import WeakValueDictionary
as _WeakVDict
10 from cached_property
import cached_property
11 from nmutil
.plain_data
import fields
, plain_data
13 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
15 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
20 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
21 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
27 self
.ops
= [] # type: list[Op]
28 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
29 self
.__next
_name
_suffix
= 2
31 def _add_op_with_unused_name(self
, op
, name
=""):
32 # type: (Op, str) -> str
34 raise ValueError("can't add Op to wrong Fn")
35 if hasattr(op
, "name"):
36 raise ValueError("Op already named")
39 if name
!= "" and name
not in self
.__op
_names
:
40 self
.__op
_names
[name
] = op
42 name
= orig_name
+ str(self
.__next
_name
_suffix
)
43 self
.__next
_name
_suffix
+= 1
49 def ops_to_str(self
, as_python_literal
=False, wrap_width
=63,
50 python_indent
=" ", indent
=" "):
51 # type: (bool, int, str, str) -> str
52 l
= [] # type: list[str]
54 l
.append(op
.__repr
__(wrap_width
=wrap_width
, indent
=indent
))
57 l
= [python_indent
+ "\""]
60 l
.append(f
"\\n\"\n{python_indent}\"")
63 elif ch
.isascii() and ch
.isprintable():
66 l
.append(repr(ch
).strip("\"'"))
69 empty_end
= f
"\"\n{python_indent}\"\""
70 if retval
.endswith(empty_end
):
71 retval
= retval
[:-len(empty_end
)]
74 def append_op(self
, op
):
77 raise ValueError("can't add Op to wrong Fn")
80 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
82 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
83 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
84 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
85 self
.append_op(retval
)
89 # type: (BaseSimState) -> None
93 def gen_asm(self
, state
):
94 # type: (GenAsmState) -> None
98 def pre_ra_insert_copies(self
):
100 orig_ops
= list(self
.ops
)
101 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
102 setvli_outputs
= {} # type: dict[SSAVal, Op]
105 for i
in range(len(op
.input_vals
)):
106 inp
= copied_outputs
[op
.input_vals
[i
]]
107 if inp
.ty
.base_ty
is BaseTy
.I64
:
108 maxvl
= inp
.ty
.reg_len
109 if inp
.ty
.reg_len
!= 1:
110 setvl
= self
.append_new_op(
111 OpKind
.SetVLI
, immediates
=[maxvl
],
112 name
=f
"{op.name}.inp{i}.setvl")
113 vl
= setvl
.outputs
[0]
114 mv
= self
.append_new_op(
115 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
116 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
118 mv
= self
.append_new_op(
119 OpKind
.CopyToReg
, input_vals
=[inp
],
120 name
=f
"{op.name}.inp{i}.copy")
121 op
.input_vals
[i
] = mv
.outputs
[0]
122 elif inp
.ty
.base_ty
is BaseTy
.CA \
123 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
124 # all copies would be no-ops, so we don't need to copy,
125 # though we do need to rematerialize SetVLI ops right
127 if inp
in setvli_outputs
:
128 setvl
= self
.append_new_op(
130 immediates
=setvli_outputs
[inp
].immediates
,
131 name
=f
"{op.name}.inp{i}.setvl")
132 inp
= setvl
.outputs
[0]
133 op
.input_vals
[i
] = inp
135 assert_never(inp
.ty
.base_ty
)
137 for i
, out
in enumerate(op
.outputs
):
138 if op
.kind
is OpKind
.SetVLI
:
139 setvli_outputs
[out
] = op
140 if out
.ty
.base_ty
is BaseTy
.I64
:
141 maxvl
= out
.ty
.reg_len
142 if out
.ty
.reg_len
!= 1:
143 setvl
= self
.append_new_op(
144 OpKind
.SetVLI
, immediates
=[maxvl
],
145 name
=f
"{op.name}.out{i}.setvl")
146 vl
= setvl
.outputs
[0]
147 mv
= self
.append_new_op(
148 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
149 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
151 mv
= self
.append_new_op(
152 OpKind
.CopyFromReg
, input_vals
=[out
],
153 name
=f
"{op.name}.out{i}.copy")
154 copied_outputs
[out
] = mv
.outputs
[0]
155 elif out
.ty
.base_ty
is BaseTy
.CA \
156 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
157 # all copies would be no-ops, so we don't need to copy
158 copied_outputs
[out
] = out
160 assert_never(out
.ty
.base_ty
)
167 value
: Literal
[0, 1] # type: ignore
169 def __new__(cls
, value
):
170 # type: (int) -> OpStage
172 if value
not in (0, 1):
173 raise ValueError("invalid value")
174 retval
= object.__new
__(cls
)
175 retval
._value
_ = value
179 """ early stage of Op execution, where all input reads occur.
180 all output writes with `write_stage == Early` occur here too, and therefore
181 conflict with input reads, telling the compiler that it that can't share
182 that output's register with any inputs that the output isn't tied to.
184 All outputs, even unused outputs, can't share registers with any other
185 outputs, independent of `write_stage` settings.
188 """ late stage of Op execution, where all output writes with
189 `write_stage == Late` occur, and therefore don't conflict with input reads,
190 telling the compiler that any inputs can safely use the same register as
193 All outputs, even unused outputs, can't share registers with any other
194 outputs, independent of `write_stage` settings.
199 return f
"OpStage.{self._name_}"
201 def __lt__(self
, other
):
202 # type: (OpStage | object) -> bool
203 if isinstance(other
, OpStage
):
204 return self
.value
< other
.value
205 return NotImplemented
208 assert OpStage
.Early
< OpStage
.Late
, "early must be less than late"
211 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
214 class ProgramPoint(metaclass
=InternedMeta
):
215 __slots__
= "op_index", "stage"
217 def __init__(self
, op_index
, stage
):
218 # type: (int, OpStage) -> None
219 self
.op_index
= op_index
225 """ an integer representation of `self` such that it keeps ordering and
226 successor/predecessor relations.
228 return self
.op_index
* 2 + self
.stage
.value
231 def from_int_value(int_value
):
232 # type: (int) -> ProgramPoint
233 op_index
, stage
= divmod(int_value
, 2)
234 return ProgramPoint(op_index
=op_index
, stage
=OpStage(stage
))
236 def next(self
, steps
=1):
237 # type: (int) -> ProgramPoint
238 return ProgramPoint
.from_int_value(self
.int_value
+ steps
)
240 def prev(self
, steps
=1):
241 # type: (int) -> ProgramPoint
242 return self
.next(steps
=-steps
)
244 def __lt__(self
, other
):
245 # type: (ProgramPoint | Any) -> bool
246 if not isinstance(other
, ProgramPoint
):
247 return NotImplemented
248 if self
.op_index
!= other
.op_index
:
249 return self
.op_index
< other
.op_index
250 return self
.stage
< other
.stage
254 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
257 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
259 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
260 __slots__
= "start", "stop"
262 def __init__(self
, start
, stop
):
263 # type: (ProgramPoint, ProgramPoint) -> None
268 def int_value_range(self
):
270 return range(self
.start
.int_value
, self
.stop
.int_value
)
273 def from_int_value_range(int_value_range
):
274 # type: (range) -> ProgramRange
275 if int_value_range
.step
!= 1:
276 raise ValueError("int_value_range must have step == 1")
278 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
279 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
282 def __getitem__(self
, __idx
):
283 # type: (int) -> ProgramPoint
287 def __getitem__(self
, __idx
):
288 # type: (slice) -> ProgramRange
291 def __getitem__(self
, __idx
):
292 # type: (int | slice) -> ProgramPoint | ProgramRange
293 v
= range(self
.start
.int_value
, self
.stop
.int_value
)[__idx
]
294 if isinstance(v
, int):
295 return ProgramPoint
.from_int_value(v
)
296 return ProgramRange
.from_int_value_range(v
)
300 return len(self
.int_value_range
)
303 # type: () -> Iterator[ProgramPoint]
304 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
308 start
= repr(self
.start
).lstrip("<").rstrip(">")
309 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
310 return f
"<range:{start}..{stop}>"
313 @plain_data(frozen
=True, eq
=False, repr=False)
316 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
317 "def_program_ranges", "use_program_points",
318 "all_program_points")
320 def __init__(self
, fn
):
323 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
324 self
.all_program_points
= ProgramRange(
325 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
326 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
327 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
328 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
329 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
330 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
332 for use
in op
.input_uses
:
333 uses
[use
.ssa_val
].add(use
)
334 use_program_point
= self
.__get
_use
_program
_point
(use
)
335 use_program_points
[use
] = use_program_point
336 live_range_stops
[use
.ssa_val
] = max(
337 live_range_stops
[use
.ssa_val
], use_program_point
.next())
338 for out
in op
.outputs
:
340 def_program_range
= self
.__get
_def
_program
_range
(out
)
341 def_program_ranges
[out
] = def_program_range
342 live_range_stops
[out
] = def_program_range
.stop
343 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
344 self
.def_program_ranges
= FMap(def_program_ranges
)
345 self
.use_program_points
= FMap(use_program_points
)
346 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
347 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
348 for ssa_val
in uses
.keys():
349 live_ranges
[ssa_val
] = live_range
= ProgramRange(
350 start
=self
.def_program_ranges
[ssa_val
].start
,
351 stop
=live_range_stops
[ssa_val
])
352 for program_point
in live_range
:
353 live_at
[program_point
].add(ssa_val
)
354 self
.live_ranges
= FMap(live_ranges
)
355 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
357 def __get_def_program_range(self
, ssa_val
):
358 # type: (SSAVal) -> ProgramRange
359 write_stage
= ssa_val
.defining_descriptor
.write_stage
360 start
= ProgramPoint(
361 op_index
=self
.op_indexes
[ssa_val
.op
], stage
=write_stage
)
362 # always include late stage of ssa_val.op, to ensure outputs always
363 # overlap all other outputs.
364 # stop is exclusive, so we need the next program point.
365 stop
= ProgramPoint(op_index
=start
.op_index
, stage
=OpStage
.Late
).next()
366 return ProgramRange(start
=start
, stop
=stop
)
368 def __get_use_program_point(self
, ssa_use
):
369 # type: (SSAUse) -> ProgramPoint
370 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
371 "assumed here, ensured by GenericOpProperties.__init__"
373 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
375 def __eq__(self
, other
):
376 # type: (FnAnalysis | Any) -> bool
377 if isinstance(other
, FnAnalysis
):
378 return self
.fn
== other
.fn
379 return NotImplemented
387 return "<FnAnalysis>"
395 VL_MAXVL
= enum
.auto()
398 def only_scalar(self
):
400 if self
is BaseTy
.I64
:
402 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
408 def max_reg_len(self
):
410 if self
is BaseTy
.I64
:
412 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
418 return "BaseTy." + self
._name
_
421 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
423 class Ty(metaclass
=InternedMeta
):
424 __slots__
= "base_ty", "reg_len"
427 def validate(base_ty
, reg_len
):
428 # type: (BaseTy, int) -> str | None
429 """ return a string with the error if the combination is invalid,
430 otherwise return None
432 if base_ty
.only_scalar
and reg_len
!= 1:
433 return f
"can't create a vector of an only-scalar type: {base_ty}"
434 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
435 return "reg_len out of range"
438 def __init__(self
, base_ty
, reg_len
):
439 # type: (BaseTy, int) -> None
440 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
442 raise ValueError(msg
)
443 self
.base_ty
= base_ty
444 self
.reg_len
= reg_len
448 if self
.reg_len
!= 1:
449 reg_len
= f
"*{self.reg_len}"
452 return f
"<{self.base_ty._name_}{reg_len}>"
459 StackI64
= enum
.auto()
461 VL_MAXVL
= enum
.auto()
466 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
468 if self
is LocKind
.CA
:
470 if self
is LocKind
.VL_MAXVL
:
471 return BaseTy
.VL_MAXVL
478 if self
is LocKind
.StackI64
:
480 if self
is LocKind
.GPR
or self
is LocKind
.CA \
481 or self
is LocKind
.VL_MAXVL
:
482 return self
.base_ty
.max_reg_len
487 return "LocKind." + self
._name
_
492 class LocSubKind(Enum
):
493 BASE_GPR
= enum
.auto()
494 SV_EXTRA2_VGPR
= enum
.auto()
495 SV_EXTRA2_SGPR
= enum
.auto()
496 SV_EXTRA3_VGPR
= enum
.auto()
497 SV_EXTRA3_SGPR
= enum
.auto()
498 StackI64
= enum
.auto()
500 VL_MAXVL
= enum
.auto()
504 # type: () -> LocKind
505 # pyright fails typechecking when using `in` here:
506 # reported: https://github.com/microsoft/pyright/issues/4102
507 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
508 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
509 LocSubKind
.SV_EXTRA3_SGPR
):
511 if self
is LocSubKind
.StackI64
:
512 return LocKind
.StackI64
513 if self
is LocSubKind
.CA
:
515 if self
is LocSubKind
.VL_MAXVL
:
516 return LocKind
.VL_MAXVL
521 return self
.kind
.base_ty
524 def allocatable_locs(self
, ty
):
525 # type: (Ty) -> LocSet
526 if ty
.base_ty
!= self
.base_ty
:
527 raise ValueError("type mismatch")
528 if self
is LocSubKind
.BASE_GPR
:
530 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
531 starts
= range(0, 128, 2)
532 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
534 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
535 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
537 elif self
is LocSubKind
.StackI64
:
538 starts
= range(LocKind
.StackI64
.loc_count
)
539 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
540 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
543 retval
= [] # type: list[Loc]
545 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
549 for special_loc
in SPECIAL_GPRS
:
550 if loc
.conflicts(special_loc
):
555 return LocSet(retval
)
558 return "LocSubKind." + self
._name
_
561 @plain_data(frozen
=True, unsafe_hash
=True)
563 class GenericTy(metaclass
=InternedMeta
):
564 __slots__
= "base_ty", "is_vec"
566 def __init__(self
, base_ty
, is_vec
):
567 # type: (BaseTy, bool) -> None
568 self
.base_ty
= base_ty
569 if base_ty
.only_scalar
and is_vec
:
570 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
573 def instantiate(self
, maxvl
):
575 # here's where subvl and elwid would be accounted for
577 return Ty(self
.base_ty
, maxvl
)
578 return Ty(self
.base_ty
, 1)
580 def can_instantiate_to(self
, ty
):
582 if self
.base_ty
!= ty
.base_ty
:
586 return ty
.reg_len
== 1
589 @plain_data(frozen
=True, unsafe_hash
=True)
591 class Loc(metaclass
=InternedMeta
):
592 __slots__
= "kind", "start", "reg_len"
595 def validate(kind
, start
, reg_len
):
596 # type: (LocKind, int, int) -> str | None
597 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
600 if reg_len
> kind
.loc_count
:
601 return "invalid reg_len"
602 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
603 return "start not in valid range"
607 def try_make(kind
, start
, reg_len
):
608 # type: (LocKind, int, int) -> Loc | None
609 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
612 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
614 def __init__(self
, kind
, start
, reg_len
):
615 # type: (LocKind, int, int) -> None
616 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
618 raise ValueError(msg
)
620 self
.reg_len
= reg_len
623 def conflicts(self
, other
):
624 # type: (Loc) -> bool
625 return (self
.kind
== other
.kind
626 and self
.start
< other
.stop
and other
.start
< self
.stop
)
629 def make_ty(kind
, reg_len
):
630 # type: (LocKind, int) -> Ty
631 return Ty(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
636 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
641 return self
.start
+ self
.reg_len
643 def try_concat(self
, *others
):
644 # type: (*Loc | None) -> Loc | None
645 reg_len
= self
.reg_len
648 if other
is None or other
.kind
!= self
.kind
:
650 if stop
!= other
.start
:
653 reg_len
+= other
.reg_len
654 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
656 def get_subloc_at_offset(self
, subloc_ty
, offset
):
657 # type: (Ty, int) -> Loc
658 if subloc_ty
.base_ty
!= self
.kind
.base_ty
:
659 raise ValueError("BaseTy mismatch")
660 if offset
< 0 or offset
+ subloc_ty
.reg_len
> self
.reg_len
:
661 raise ValueError("invalid sub-Loc: offset and/or "
662 "subloc_ty.reg_len out of range")
663 return Loc(kind
=self
.kind
,
664 start
=self
.start
+ offset
, reg_len
=subloc_ty
.reg_len
)
668 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
669 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
670 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
671 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
676 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
677 def __init__(self
, __locs
=()):
678 # type: (Iterable[Loc]) -> None
679 super().__init
__(__locs
)
680 if isinstance(__locs
, LocSet
):
681 self
.__starts
= __locs
.starts
682 self
.__ty
= __locs
.ty
684 starts
= {i
: BitSet() for i
in LocKind
}
685 ty
= None # type: None | Ty
690 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
691 starts
[loc
.kind
].add(loc
.start
)
692 self
.__starts
= FMap(
693 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
698 # type: () -> FMap[LocKind, FBitSet]
703 # type: () -> Ty | None
708 # type: () -> FMap[LocKind, FBitSet]
713 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
717 # type: () -> AbstractSet[LocKind]
718 return self
.starts
.keys()
722 # type: () -> int | None
725 return self
.ty
.reg_len
729 # type: () -> BaseTy | None
732 return self
.ty
.base_ty
734 def concat(self
, *others
):
735 # type: (*LocSet) -> LocSet
738 base_ty
= self
.ty
.base_ty
739 reg_len
= self
.ty
.reg_len
740 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
744 if other
.ty
.base_ty
!= base_ty
:
746 for kind
, other_starts
in other
.starts
.items():
747 if kind
not in starts
:
749 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
750 if starts
[kind
] == 0:
754 reg_len
+= other
.ty
.reg_len
757 # type: () -> Iterable[Loc]
758 for kind
, v
in starts
.items():
760 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
763 return LocSet(locs())
765 @lru_cache(maxsize
=None, typed
=True)
766 def max_conflicts_with(self
, other
):
767 # type: (LocSet | Loc) -> int
768 """the largest number of Locs in `self` that a single Loc
769 from `other` can conflict with
771 if isinstance(other
, LocSet
):
772 return max(self
.max_conflicts_with(i
) for i
in other
)
774 return sum(other
.conflicts(i
) for i
in self
)
777 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
780 @plain_data(frozen
=True, unsafe_hash
=True)
782 class GenericOperandDesc(metaclass
=InternedMeta
):
783 """generic Op operand descriptor"""
784 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
788 self
, ty
, # type: GenericTy
789 sub_kinds
, # type: Iterable[LocSubKind]
791 fixed_loc
=None, # type: Loc | None
792 tied_input_index
=None, # type: int | None
793 spread
=False, # type: bool
794 write_stage
=OpStage
.Early
, # type: OpStage
796 # type: (...) -> None
798 self
.sub_kinds
= OFSet(sub_kinds
)
799 if len(self
.sub_kinds
) == 0:
800 raise ValueError("sub_kinds can't be empty")
801 self
.fixed_loc
= fixed_loc
802 if fixed_loc
is not None:
803 if tied_input_index
is not None:
804 raise ValueError("operand can't be both tied and fixed")
805 if not ty
.can_instantiate_to(fixed_loc
.ty
):
807 f
"fixed_loc has incompatible type for given generic "
808 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
809 if len(self
.sub_kinds
) != 1:
811 "multiple sub_kinds not allowed for fixed operand")
812 for sub_kind
in self
.sub_kinds
:
813 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
815 f
"fixed_loc not in given sub_kind: "
816 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
817 for sub_kind
in self
.sub_kinds
:
818 if sub_kind
.base_ty
!= ty
.base_ty
:
819 raise ValueError(f
"sub_kind is incompatible with type: "
820 f
"sub_kind={sub_kind} ty={ty}")
821 if tied_input_index
is not None and tied_input_index
< 0:
822 raise ValueError("invalid tied_input_index")
823 self
.tied_input_index
= tied_input_index
826 if self
.tied_input_index
is not None:
827 raise ValueError("operand can't be both spread and tied")
828 if self
.fixed_loc
is not None:
829 raise ValueError("operand can't be both spread and fixed")
831 raise ValueError("operand can't be both spread and vector")
832 self
.write_stage
= write_stage
835 def ty_before_spread(self
):
836 # type: () -> GenericTy
838 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
841 def tied_to_input(self
, tied_input_index
):
842 # type: (int) -> Self
843 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
844 tied_input_index
=tied_input_index
,
845 write_stage
=self
.write_stage
)
847 def with_fixed_loc(self
, fixed_loc
):
848 # type: (Loc) -> Self
849 return GenericOperandDesc(self
.ty
, self
.sub_kinds
, fixed_loc
=fixed_loc
,
850 write_stage
=self
.write_stage
)
852 def with_write_stage(self
, write_stage
):
853 # type: (OpStage) -> Self
854 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
855 fixed_loc
=self
.fixed_loc
,
856 tied_input_index
=self
.tied_input_index
,
858 write_stage
=write_stage
)
860 def instantiate(self
, maxvl
):
861 # type: (int) -> Iterable[OperandDesc]
862 # assumes all spread operands have ty.reg_len = 1
866 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
868 def locs_before_spread():
869 # type: () -> Iterable[Loc]
870 if self
.fixed_loc
is not None:
871 if ty_before_spread
!= self
.fixed_loc
.ty
:
873 f
"instantiation failed: type mismatch with fixed_loc: "
874 f
"instantiated type: {ty_before_spread} "
875 f
"fixed_loc: {self.fixed_loc}")
878 for sub_kind
in self
.sub_kinds
:
879 yield from sub_kind
.allocatable_locs(ty_before_spread
)
880 loc_set_before_spread
= LocSet(locs_before_spread())
881 for idx
in range(rep_count
):
884 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
885 tied_input_index
=self
.tied_input_index
,
886 spread_index
=idx
, write_stage
=self
.write_stage
)
889 @plain_data(frozen
=True, unsafe_hash
=True)
891 class OperandDesc(metaclass
=InternedMeta
):
892 """Op operand descriptor"""
893 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
896 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
898 # type: (LocSet, int | None, int | None, OpStage) -> None
899 if len(loc_set_before_spread
) == 0:
900 raise ValueError("loc_set_before_spread must not be empty")
901 self
.loc_set_before_spread
= loc_set_before_spread
902 self
.tied_input_index
= tied_input_index
903 if self
.tied_input_index
is not None and spread_index
is not None:
904 raise ValueError("operand can't be both spread and tied")
905 self
.spread_index
= spread_index
906 self
.write_stage
= write_stage
909 def ty_before_spread(self
):
911 ty
= self
.loc_set_before_spread
.ty
912 assert ty
is not None, (
913 "__init__ checked that the LocSet isn't empty, "
914 "non-empty LocSets should always have ty set")
919 """ Ty after any spread is applied """
920 if self
.spread_index
is not None:
921 # assumes all spread operands have ty.reg_len = 1
922 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
923 return self
.ty_before_spread
926 def reg_offset_in_unspread(self
):
927 """ the number of reg-sized slots in the unspread Loc before self's Loc
929 e.g. if the unspread Loc containing self is:
930 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
931 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
932 then reg_offset_into_unspread == 2 == 10 - 8
934 if self
.spread_index
is None:
936 return self
.spread_index
* self
.ty
.reg_len
939 OD_BASE_SGPR
= GenericOperandDesc(
940 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
941 sub_kinds
=[LocSubKind
.BASE_GPR
])
942 OD_EXTRA3_SGPR
= GenericOperandDesc(
943 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
944 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
])
945 OD_EXTRA3_VGPR
= GenericOperandDesc(
946 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
947 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
])
948 OD_EXTRA2_SGPR
= GenericOperandDesc(
949 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=False),
950 sub_kinds
=[LocSubKind
.SV_EXTRA2_SGPR
])
951 OD_EXTRA2_VGPR
= GenericOperandDesc(
952 ty
=GenericTy(base_ty
=BaseTy
.I64
, is_vec
=True),
953 sub_kinds
=[LocSubKind
.SV_EXTRA2_VGPR
])
954 OD_CA
= GenericOperandDesc(
955 ty
=GenericTy(base_ty
=BaseTy
.CA
, is_vec
=False),
956 sub_kinds
=[LocSubKind
.CA
])
957 OD_VL
= GenericOperandDesc(
958 ty
=GenericTy(base_ty
=BaseTy
.VL_MAXVL
, is_vec
=False),
959 sub_kinds
=[LocSubKind
.VL_MAXVL
])
962 @plain_data(frozen
=True, unsafe_hash
=True)
964 class GenericOpProperties(metaclass
=InternedMeta
):
965 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
966 "is_copy", "is_load_immediate", "has_side_effects")
969 self
, demo_asm
, # type: str
970 inputs
, # type: Iterable[GenericOperandDesc]
971 outputs
, # type: Iterable[GenericOperandDesc]
972 immediates
=(), # type: Iterable[range]
973 is_copy
=False, # type: bool
974 is_load_immediate
=False, # type: bool
975 has_side_effects
=False, # type: bool
977 # type: (...) -> None
978 self
.demo_asm
= demo_asm
# type: str
979 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
980 for inp
in self
.inputs
:
981 if inp
.tied_input_index
is not None:
983 f
"tied_input_index is not allowed on inputs: {inp}")
984 if inp
.write_stage
is not OpStage
.Early
:
986 f
"write_stage is not allowed on inputs: {inp}")
987 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
988 fixed_locs
= [] # type: list[tuple[Loc, int]]
989 for idx
, out
in enumerate(self
.outputs
):
990 if out
.tied_input_index
is not None:
991 if out
.tied_input_index
>= len(self
.inputs
):
992 raise ValueError(f
"tied_input_index out of range: {out}")
993 tied_inp
= self
.inputs
[out
.tied_input_index
]
994 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
995 .with_write_stage(out
.write_stage
)
996 if expected_out
!= out
:
997 raise ValueError(f
"output can't be tied to non-equivalent "
998 f
"input: {out} tied to {tied_inp}")
999 if out
.fixed_loc
is not None:
1000 for other_fixed_loc
, other_idx
in fixed_locs
:
1001 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
1004 f
"conflicting fixed_locs: outputs[{idx}] and "
1005 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
1006 f
"with {other_fixed_loc}")
1007 fixed_locs
.append((out
.fixed_loc
, idx
))
1008 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
1009 self
.is_copy
= is_copy
# type: bool
1010 self
.is_load_immediate
= is_load_immediate
# type: bool
1011 self
.has_side_effects
= has_side_effects
# type: bool
1014 @plain_data(frozen
=True, unsafe_hash
=True)
1016 class OpProperties(metaclass
=InternedMeta
):
1017 __slots__
= "kind", "inputs", "outputs", "maxvl"
1019 def __init__(self
, kind
, maxvl
):
1020 # type: (OpKind, int) -> None
1021 self
.kind
= kind
# type: OpKind
1022 inputs
= [] # type: list[OperandDesc]
1023 for inp
in self
.generic
.inputs
:
1024 inputs
.extend(inp
.instantiate(maxvl
=maxvl
))
1025 self
.inputs
= tuple(inputs
) # type: tuple[OperandDesc, ...]
1026 outputs
= [] # type: list[OperandDesc]
1027 for out
in self
.generic
.outputs
:
1028 outputs
.extend(out
.instantiate(maxvl
=maxvl
))
1029 self
.outputs
= tuple(outputs
) # type: tuple[OperandDesc, ...]
1030 self
.maxvl
= maxvl
# type: int
1034 # type: () -> GenericOpProperties
1035 return self
.kind
.properties
1038 def immediates(self
):
1039 # type: () -> tuple[range, ...]
1040 return self
.generic
.immediates
1045 return self
.generic
.demo_asm
1050 return self
.generic
.is_copy
1053 def is_load_immediate(self
):
1055 return self
.generic
.is_load_immediate
1058 def has_side_effects(self
):
1060 return self
.generic
.has_side_effects
1063 IMM_S16
= range(-1 << 15, 1 << 15)
1065 _SIM_FN
= Callable
[["Op", "BaseSimState"], None]
1066 _SIM_FN2
= Callable
[[], _SIM_FN
]
1067 _SIM_FNS
= {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
1068 _GEN_ASM_FN
= Callable
[["Op", "GenAsmState"], None]
1069 _GEN_ASM_FN2
= Callable
[[], _GEN_ASM_FN
]
1070 _GEN_ASMS
= {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1076 def __init__(self
, properties
):
1077 # type: (GenericOpProperties) -> None
1079 self
.__properties
= properties
1082 def properties(self
):
1083 # type: () -> GenericOpProperties
1084 return self
.__properties
1086 def instantiate(self
, maxvl
):
1087 # type: (int) -> OpProperties
1088 return OpProperties(self
, maxvl
=maxvl
)
1092 return "OpKind." + self
._name
_
1096 # type: () -> _SIM_FN
1097 return _SIM_FNS
[self
.properties
]()
1101 # type: () -> _GEN_ASM_FN
1102 return _GEN_ASMS
[self
.properties
]()
1105 def __clearca_sim(op
, state
):
1106 # type: (Op, BaseSimState) -> None
1107 state
[op
.outputs
[0]] = False,
1110 def __clearca_gen_asm(op
, state
):
1111 # type: (Op, GenAsmState) -> None
1112 state
.writeln("addic 0, 0, 0")
1113 ClearCA
= GenericOpProperties(
1114 demo_asm
="addic 0, 0, 0",
1116 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1118 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1119 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
1122 def __setca_sim(op
, state
):
1123 # type: (Op, BaseSimState) -> None
1124 state
[op
.outputs
[0]] = True,
1127 def __setca_gen_asm(op
, state
):
1128 # type: (Op, GenAsmState) -> None
1129 state
.writeln("subfc 0, 0, 0")
1130 SetCA
= GenericOpProperties(
1131 demo_asm
="subfc 0, 0, 0",
1133 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1135 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1136 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1139 def __svadde_sim(op
, state
):
1140 # type: (Op, BaseSimState) -> None
1141 RA
= state
[op
.input_vals
[0]]
1142 RB
= state
[op
.input_vals
[1]]
1143 carry
, = state
[op
.input_vals
[2]]
1144 VL
, = state
[op
.input_vals
[3]]
1145 RT
= [] # type: list[int]
1147 v
= RA
[i
] + RB
[i
] + carry
1148 RT
.append(v
& GPR_VALUE_MASK
)
1149 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1150 state
[op
.outputs
[0]] = tuple(RT
)
1151 state
[op
.outputs
[1]] = carry
,
1154 def __svadde_gen_asm(op
, state
):
1155 # type: (Op, GenAsmState) -> None
1156 RT
= state
.vgpr(op
.outputs
[0])
1157 RA
= state
.vgpr(op
.input_vals
[0])
1158 RB
= state
.vgpr(op
.input_vals
[1])
1159 state
.writeln(f
"sv.adde {RT}, {RA}, {RB}")
1160 SvAddE
= GenericOpProperties(
1161 demo_asm
="sv.adde *RT, *RA, *RB",
1162 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1163 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1165 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1166 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
@staticmethod
def __addze_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``addze RT, RA``: ``RT = RA + CA``.

    The new CA is the carry out of the register width.
    """
    RA, = state[op.input_vals[0]]
    carry, = state[op.input_vals[1]]
    v = RA + carry
    RT = v & GPR_VALUE_MASK
    carry = (v >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = RT,
    state[op.outputs[1]] = carry,

@staticmethod
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit scalar ``addze RT, RA``.

    Both operands are ``OD_BASE_SGPR``, so they must be formatted with
    ``sgpr``.  The previous ``vgpr`` calls produced ``*``-prefixed
    (vector) operand syntax, which a plain scalar ``addze`` cannot take.
    """
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")

AddZE = GenericOpProperties(
    demo_asm="addze RT, RA",
    inputs=[OD_BASE_SGPR, OD_CA],
    # the CA output is declared tied to input 1 (the CA input)
    outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
)
_SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
_GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
@staticmethod
def __svsubfe_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.subfe``: per element ``RT = ~RA + RB + CA``.

    The carry ripples from element to element.  Inputs are
    ``(RA vector, RB vector, CA, VL)``; outputs are the result vector and
    the final CA.
    """
    ra_vec = state[op.input_vals[0]]
    rb_vec = state[op.input_vals[1]]
    (ca,) = state[op.input_vals[2]]
    (vl,) = state[op.input_vals[3]]
    diffs = []  # type: list[int]
    for idx in range(vl):
        # ones' complement of RA, restricted to the register width
        inverted = ~ra_vec[idx] & GPR_VALUE_MASK
        wide = inverted + rb_vec[idx] + ca
        diffs.append(wide & GPR_VALUE_MASK)
        ca = (wide >> GPR_SIZE_IN_BITS) != 0
    state[op.outputs[0]] = tuple(diffs)
    state[op.outputs[1]] = (ca,)

@staticmethod
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``sv.subfe`` using the allocated vector registers."""
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.subfe {RT}, {RA}, {RB}")

SvSubFE = GenericOpProperties(
    demo_asm="sv.subfe *RT, *RA, *RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
    # the CA output is declared tied to input 2 (the CA input)
    outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
@staticmethod
def __svandvs_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.and *RT, *RA, RB``: AND each of the VL elements of RA
    with the scalar RB."""
    ra_vec = state[op.input_vals[0]]
    (rb,) = state[op.input_vals[1]]
    (vl,) = state[op.input_vals[2]]
    state[op.outputs[0]] = tuple(
        ra_vec[idx] & rb & GPR_VALUE_MASK for idx in range(vl))

@staticmethod
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``sv.and`` with a vector RT/RA and a scalar RB."""
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    state.writeln(f"sv.and {RT}, {RA}, {RB}")

SvAndVS = GenericOpProperties(
    demo_asm="sv.and *RT, *RA, RB",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
)
_SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
_GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm
@staticmethod
def __svmaddedu_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.maddedu *RT, *RA, RB, RC`` over ``VL`` elements.

    For each element ``v = RA[i] * RB + carry``.  The low register-width
    bits go to ``RT[i]``; the high part becomes the carry for the next
    element.  The final carry is written to the second output.
    """
    RA = state[op.input_vals[0]]
    RB, = state[op.input_vals[1]]
    carry, = state[op.input_vals[2]]
    VL, = state[op.input_vals[3]]
    RT = []  # type: list[int]
    for i in range(VL):
        v = RA[i] * RB + carry
        RT.append(v & GPR_VALUE_MASK)
        carry = v >> GPR_SIZE_IN_BITS
    state[op.outputs[0]] = tuple(RT)
    state[op.outputs[1]] = carry,

@staticmethod
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``sv.maddedu`` with a vector RT/RA and scalar RB/RC."""
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    RC = state.sgpr(op.input_vals[2])
    state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")

SvMAddEDU = GenericOpProperties(
    demo_asm="sv.maddedu *RT, *RA, RB, RC",
    inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
    # maddedu's register operands share one EXTRA encoding, and the inputs
    # above are EXTRA2, so RT must use the EXTRA2 constraint as well.
    # OD_EXTRA3_VGPR (used previously) presumably admits registers that
    # EXTRA2 cannot encode.  The carry output is tied to input 2 (RC).
    outputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
)
_SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
_GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
@staticmethod
def __sradi_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sradi RA, RS, imm`` (shift right algebraic doubleword).

    RA receives RS, interpreted as signed, arithmetically shifted right by
    ``imm``.  Per the Power ISA, CA is set only when RS is negative *and*
    at least one 1-bit was shifted out; otherwise CA is cleared.
    """
    rs, = state[op.input_vals[0]]
    imm = op.immediates[0]
    # reinterpret the unsigned register value as two's-complement signed
    if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
        rs -= 1 << GPR_SIZE_IN_BITS
    v = rs >> imm  # Python's >> on int is an arithmetic (floor) shift
    RA = v & GPR_VALUE_MASK
    # 1-bits were lost iff shifting back doesn't reproduce rs.  The previous
    # `(RA << imm) != rs` compared the *masked* RA, which flagged every
    # negative rs and also flagged positive rs that lost bits.
    CA = rs < 0 and (v << imm) != rs
    state[op.outputs[0]] = RA,
    state[op.outputs[1]] = CA,

@staticmethod
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit scalar ``sradi RA, RS, imm``."""
    RA = state.sgpr(op.outputs[0])
    RS = state.sgpr(op.input_vals[0])
    imm = op.immediates[0]
    state.writeln(f"sradi {RA}, {RS}, {imm}")

SRADI = GenericOpProperties(
    demo_asm="sradi RA, RS, imm",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
             OD_CA.with_write_stage(OpStage.Late)],
    immediates=[range(GPR_SIZE_IN_BITS)],
)
_SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
_GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm
@staticmethod
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``setvl``: VL becomes the immediate value."""
    new_vl = op.immediates[0]
    state[op.outputs[0]] = (new_vl,)

@staticmethod
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``setvl`` with the immediate as the new VL."""
    imm = op.immediates[0]
    state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")

SetVLI = GenericOpProperties(
    demo_asm="setvl 0, 0, imm, 0, 1, 1",
    inputs=(),  # NOTE(review): this line was not visible; inferred
    outputs=[OD_VL.with_write_stage(OpStage.Late)],
    immediates=[range(1, 65)],
    is_load_immediate=True,
)
_SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
_GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
@staticmethod
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.addi *RT, 0, imm``.

    All VL elements are set to the immediate, masked to register width.
    """
    (vl,) = state[op.input_vals[0]]
    splat = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (splat,) * vl

@staticmethod
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector load-immediate as ``sv.addi``."""
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.addi {RT}, 0, {imm}")

SvLI = GenericOpProperties(
    demo_asm="sv.addi *RT, 0, imm",
    # NOTE(review): inferred; the sim reads input 0 as VL
    inputs=[OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
_GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
@staticmethod
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``addi RT, 0, imm``: RT = imm masked to register width."""
    state[op.outputs[0]] = (op.immediates[0] & GPR_VALUE_MASK,)

@staticmethod
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the scalar load-immediate as ``addi``."""
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"addi {RT}, 0, {imm}")

LI = GenericOpProperties(
    demo_asm="addi RT, 0, imm",
    inputs=(),  # NOTE(review): this line was not visible; inferred
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
    is_load_immediate=True,
)
_SIM_FNS[LI] = lambda: OpKind.__li_sim
_GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
@staticmethod
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyToReg: the output takes the input's value as-is."""
    value = state[op.input_vals[0]]
    state[op.outputs[0]] = value
@staticmethod
def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
    # type: (Loc, Loc, bool, GenAsmState) -> None
    """Emit asm that copies the value at ``src_loc`` to ``dest_loc``.

    Shared by the copy-like OpKinds' ``*_gen_asm`` functions.  Emits:

    * StackI64 -> GPR: ``ld`` (``sv.ld`` when ``is_vec``)
    * GPR -> StackI64: ``std`` (``sv.std`` when ``is_vec``)
    * GPR -> GPR: ``or dest, src, src`` (``sv.or`` when ``is_vec``), with
      the ``rev`` suffix when the Locs overlap and ``src`` starts below
      ``dest``.

    Nothing is emitted when ``src_loc == dest_loc``.

    Raises:
        ValueError: if either Loc's kind is not GPR/StackI64, or both are
            StackI64.

    NOTE(review): the ``rev = ""`` / ``rev = "/mrr"`` lines, the early
    ``return`` and the exception type of the stack-to-stack case were not
    visible in the damaged listing and were reconstructed; confirm them
    against upstream.
    """
    sv = "sv." if is_vec else ""
    rev = ""
    if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
        # overlapping copy to higher registers: presumably the elements must
        # be copied last-to-first so none is overwritten before it is read
        rev = "/mrr"
    if src_loc == dest_loc:
        return  # no-op
    if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
    if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
        raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
    if src_loc.kind is LocKind.StackI64:
        if dest_loc.kind is LocKind.StackI64:
            raise ValueError(
                f"can't copy from stack to stack: {src_loc} {dest_loc}")
        elif dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        src = state.stack(src_loc)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}ld {dest}, {src}")
    elif dest_loc.kind is LocKind.StackI64:
        if src_loc.kind is not LocKind.GPR:
            assert_never(src_loc.kind)
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.stack(dest_loc)
        state.writeln(f"{sv}std {src}, {dest}")
    elif src_loc.kind is LocKind.GPR:
        if dest_loc.kind is not LocKind.GPR:
            assert_never(dest_loc.kind)
        src = state.gpr(src_loc, is_vec=is_vec)
        dest = state.gpr(dest_loc, is_vec=is_vec)
        state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
    else:
        assert_never(src_loc.kind)
@staticmethod
def __veccopytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a vector copy from a GPR or stack Loc into a GPR Loc."""
    src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
    dest_loc = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=True, state=state)

VecCopyToReg = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=True),
        sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
    ), OD_VL],  # NOTE(review): OD_VL inferred by symmetry w/ VecCopyFromReg
    outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
_GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
@staticmethod
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyFromReg: the output takes the input's value as-is."""
    value = state[op.input_vals[0]]
    state[op.outputs[0]] = value

@staticmethod
def __veccopyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a vector copy from a GPR Loc into a GPR or stack Loc."""
    src_loc = state.loc(op.input_vals[0], LocKind.GPR)
    dest_loc = state.loc(op.outputs[0], (LocKind.GPR, LocKind.StackI64))
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=True, state=state)

VecCopyFromReg = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[OD_EXTRA3_VGPR, OD_VL],
    outputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=True),
        sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        write_stage=OpStage.Late,
    )],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
_GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
@staticmethod
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyToReg: the output takes the input's value as-is."""
    value = state[op.input_vals[0]]
    state[op.outputs[0]] = value

@staticmethod
def __copytoreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a scalar copy from a GPR or stack Loc into a GPR Loc."""
    src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
    dest_loc = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=False, state=state)

CopyToReg = GenericOpProperties(
    demo_asm="mv dest, src",
    inputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                   LocSubKind.StackI64],
    )],
    outputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        write_stage=OpStage.Late,
    )],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
_GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
@staticmethod
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyFromReg: the output takes the input's value as-is."""
    value = state[op.input_vals[0]]
    state[op.outputs[0]] = value

@staticmethod
def __copyfromreg_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a scalar copy from a GPR Loc into a GPR or stack Loc."""
    src_loc = state.loc(op.input_vals[0], LocKind.GPR)
    dest_loc = state.loc(op.outputs[0], (LocKind.GPR, LocKind.StackI64))
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=False, state=state)

CopyFromReg = GenericOpProperties(
    demo_asm="mv dest, src",
    inputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
    )],
    outputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                   LocSubKind.StackI64],
        write_stage=OpStage.Late,
    )],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
_GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
@staticmethod
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Concat: gather the scalar inputs (every input except the
    trailing VL) into one vector value."""
    scalar_inputs = op.input_vals[:-1]
    state[op.outputs[0]] = tuple(state[i][0] for i in scalar_inputs)

@staticmethod
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a vector copy from the concatenated input Locs to the output."""
    src_loc = state.loc(op.input_vals[0:-1], LocKind.GPR)
    dest_loc = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=True, state=state)

Concat = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
        spread=True,  # NOTE(review): this line was not visible; inferred
    ), OD_VL],  # NOTE(review): inferred; the sim drops the last input (VL)
    outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[Concat] = lambda: OpKind.__concat_sim
_GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
@staticmethod
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Spread: element ``k`` of the input vector becomes the
    single-element value of output ``k``."""
    vec = state[op.input_vals[0]]
    for k, elem in enumerate(vec):
        state[op.outputs[k]] = (elem,)

@staticmethod
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit a vector copy from the input Loc to the concatenated outputs."""
    src_loc = state.loc(op.input_vals[0], LocKind.GPR)
    dest_loc = state.loc(op.outputs, LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=src_loc, dest_loc=dest_loc, is_vec=True, state=state)

Spread = GenericOpProperties(
    demo_asm="sv.mv dest, src",
    inputs=[OD_EXTRA3_VGPR, OD_VL],
    outputs=[GenericOperandDesc(
        ty=GenericTy(BaseTy.I64, is_vec=False),
        sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
        spread=True,  # NOTE(review): this line was not visible; inferred
        write_stage=OpStage.Late,
    )],
    is_copy=True,  # NOTE(review): this line was not visible; inferred
)
_SIM_FNS[Spread] = lambda: OpKind.__spread_sim
_GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
@staticmethod
def __svld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.ld *RT, imm(RA)``.

    Loads VL consecutive register-sized words starting at ``RA + imm``.
    """
    (ra,) = state[op.input_vals[0]]
    (vl,) = state[op.input_vals[1]]
    base = ra + op.immediates[0]
    state[op.outputs[0]] = tuple(
        state.load(base + GPR_SIZE_IN_BYTES * idx) & GPR_VALUE_MASK
        for idx in range(vl))

@staticmethod
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``sv.ld`` with a vector RT and a scalar base RA."""
    RA = state.sgpr(op.input_vals[0])
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.ld {RT}, {imm}({RA})")

SvLd = GenericOpProperties(
    demo_asm="sv.ld *RT, imm(RA)",
    inputs=[OD_EXTRA3_SGPR, OD_VL],
    outputs=[OD_EXTRA3_VGPR],
    immediates=[IMM_S16],
)
_SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
_GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
@staticmethod
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``ld RT, imm(RA)``: load one register-sized word."""
    (ra,) = state[op.input_vals[0]]
    word = state.load(ra + op.immediates[0])
    state[op.outputs[0]] = (word & GPR_VALUE_MASK,)

@staticmethod
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit scalar ``ld``."""
    RA = state.sgpr(op.input_vals[0])
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"ld {RT}, {imm}({RA})")

Ld = GenericOpProperties(
    demo_asm="ld RT, imm(RA)",
    inputs=[OD_BASE_SGPR],
    outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
    immediates=[IMM_S16],
)
_SIM_FNS[Ld] = lambda: OpKind.__ld_sim
_GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
@staticmethod
def __svstd_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``sv.std *RS, imm(RA)``.

    Stores VL register-sized words of RS to consecutive words starting at
    ``RA + imm``.
    """
    rs_vec = state[op.input_vals[0]]
    (ra,) = state[op.input_vals[1]]
    (vl,) = state[op.input_vals[2]]
    base = ra + op.immediates[0]
    for idx in range(vl):
        state.store(base + GPR_SIZE_IN_BYTES * idx, value=rs_vec[idx])

@staticmethod
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit ``sv.std`` with a vector RS and a scalar base RA."""
    RS = state.vgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"sv.std {RS}, {imm}({RA})")

SvStd = GenericOpProperties(
    demo_asm="sv.std *RS, imm(RA)",
    inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
    outputs=[],  # NOTE(review): this line was not visible; inferred
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
_GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
@staticmethod
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ``std RS, imm(RA)``: store one register-sized word."""
    (rs,) = state[op.input_vals[0]]
    (ra,) = state[op.input_vals[1]]
    state.store(ra + op.immediates[0], value=rs)

@staticmethod
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit scalar ``std``."""
    RS = state.sgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"std {RS}, {imm}({RA})")

Std = GenericOpProperties(
    demo_asm="std RS, imm(RA)",
    inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
    outputs=[],  # NOTE(review): this line was not visible; inferred
    immediates=[IMM_S16],
    has_side_effects=True,
)
_SIM_FNS[Std] = lambda: OpKind.__std_sim
_GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
@staticmethod
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """No-op: this op's output value is set before simulation starts."""

@staticmethod
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """No-op: the output is fixed to r3, so no instructions are needed."""

FuncArgR3 = GenericOpProperties(
    # NOTE(review): the two lines below were not visible; inferred
    demo_asm="",
    inputs=[],
    outputs=[OD_BASE_SGPR.with_fixed_loc(
        Loc(kind=LocKind.GPR, start=3, reg_len=1))],
)
_SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
_GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """Common base for a reference to operand ``operand_idx`` of ``op``.

    Subclasses choose which descriptor array of ``op`` the index refers to
    (``descriptor_array``).

    NOTE(review): restored from a damaged listing.  The ``self.op``
    assignment, the ``descriptor_array`` body, the property decorators
    (possibly ``cached_property`` upstream) and some type comments were
    inferred; confirm them against upstream.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        self.op = op
        # must come after self.op is set: descriptor_array reads self.op
        if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The operand descriptors that ``operand_idx`` indexes into."""
        raise NotImplementedError

    @property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        return self.descriptor_array[self.operand_idx]

    @property
    def ty(self):
        # type: () -> Ty
        return self.defining_descriptor.ty

    @property
    def ty_before_spread(self):
        # type: () -> GenericTy
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """Operand index of the first operand of this operand's spread group
        (or ``operand_idx`` itself when the descriptor isn't spread)."""
        return self.operand_idx - (self.defining_descriptor.spread_index or 0)

    @property
    def unspread_start(self):
        # type: () -> Self
        """Same-class reference to operand ``unspread_start_idx`` of ``op``."""
        return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAVal(SSAValOrUse):
    """The value defined by output ``operand_idx`` of ``op``.

    NOTE(review): restored from a damaged listing; ``__slots__``, the
    ``__repr__`` header, the property decorators and ``return None`` in
    ``tied_input`` were inferred.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        # outputs index into the op's output descriptors
        return self.op.properties.outputs

    @property
    def tied_input(self):
        # type: () -> None | SSAUse
        """The use of the input this output is tied to, or None if untied."""
        if self.defining_descriptor.tied_input_index is None:
            return None
        return SSAUse(op=self.op,
                      operand_idx=self.defining_descriptor.tied_input_index)

    @property
    def write_stage(self):
        # type: () -> OpStage
        return self.defining_descriptor.write_stage
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAUse(SSAValOrUse):
    """The use of an SSAVal as input ``operand_idx`` of ``op``.

    NOTE(review): restored from a damaged listing; ``__slots__``, the
    ``__repr__`` header and the property/setter decorators were inferred.
    """
    __slots__ = ()

    @property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        # uses index into the op's input descriptors
        return self.op.properties.inputs

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """The SSAVal currently used by this input."""
        return self.op.input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        # writes through to op.input_vals, which validates the new value
        self.op.input_vals[self.operand_idx] = ssa_val
# type parameter for the per-item descriptor type of OpInputSeq subclasses
_Desc = TypeVar("_Desc")
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """Fixed-length, validated sequence of one Op's inputs or immediates.

    The length is fixed to ``len(self.descriptors)``.  Every write, whether
    at construction or through ``__setitem__``, is checked by the
    subclass's ``_verify_write_with_desc`` against the matching descriptor.

    NOTE(review): restored from a damaged listing.  The decorators, the
    abstract-method bodies, ``return idx`` in ``_verify_write``, the
    exception type for slice writes, the ``self.op`` assignment and the
    ``__getitem__`` overload stubs were inferred.  No call site of
    ``_on_set`` is visible in the listing; confirm whether ``__init__`` /
    ``__setitem__`` are meant to invoke it.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """Raise if ``item`` is not acceptable at ``idx`` given ``desc``."""
        raise NotImplementedError

    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """Validate writing ``item`` at ``idx``; return the normalized idx.

        Raises:
            TypeError: if ``idx`` is a slice or is not an int.
            IndexError: if ``idx`` is out of range.
        """
        if not isinstance(idx, int):
            if isinstance(idx, slice):
                raise TypeError(
                    f"can't write to slice of {self.__class__.__name__}")
            raise TypeError(f"can't write with index {idx!r}")
        # normalize idx, raising IndexError if it is out of range
        idx = range(len(self.descriptors))[idx]
        desc = self.descriptors[idx]
        self._verify_write_with_desc(idx, item, desc)
        return idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """Hook for subclasses; does nothing by default."""
        pass

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        raise NotImplementedError

    @property
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        return self._get_descriptors()

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        self.op = op
        self.__items = []  # type: list[_T]
        for idx, item in enumerate(items):
            if idx >= len(self.descriptors):
                raise ValueError("too many items")
            _ = self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    def __iter__(self):
        # type: () -> Iterator[_T]
        yield from self.__items

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        idx = self._verify_write(idx, item)
        self.__items[idx] = item

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{self.__class__.__name__}({self.__items}, op=...)"
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The input SSAVals of an Op, validated against ``op.properties.inputs``.

    Each item must be an ``SSAVal`` whose ``ty`` equals the matching input
    descriptor's ``ty``.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Raise TypeError for non-SSAVals, ValueError on a type mismatch."""
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # Op stores its inputs in the `input_vals` slot.  The old check for
        # an `inputs` attribute could never fire, because Op has no such
        # attribute.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
class OpImmediates(OpInputSeq[int, range]):
    """The immediate operands of an Op.

    Each immediate must be an ``int`` contained in the matching ``range``
    from ``op.properties.immediates``.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        props = self.op.properties
        return props.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        in_range = item in desc
        if not in_range:
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
@plain_data(frozen=True, eq=False, repr=False)
class Op:
    """One operation of a ``Fn``: an ``OpKind`` applied to input SSAVals and
    immediates, defining one SSAVal per output descriptor.

    Equality and hashing are by identity.

    NOTE(review): restored from a damaged listing.  The class line, the
    ``__slots__`` continuation, ``self.fn = fn``, the ``kind``/``__hash__``
    headers, the ``try``/``except KeyError`` wrappers and error-message
    continuations in ``sim``, and several branches of ``__repr__`` (including
    ``hex(imm)``) were inferred.  Whitespace inside string literals (e.g. the
    ``indent`` default) may have been collapsed by the listing.  Confirm all
    of this against upstream.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        inputs_len = len(self.properties.inputs)
        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
        self.immediates = OpImmediates(immediates, op=self)
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
        # Fn assigns a unique name and rejects Ops already named or
        # belonging to a different Fn, so this runs last
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """Render as ``name:`` followed by ``(outputs) <= Kind(inputs, imms)``.

        Lines are wrapped to ``wrap_width`` at zero-width-space break points,
        and continuation lines are prefixed with ``indent``.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        items = [f"{self.name}:\n"]
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        items.append(self.kind._name_)
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
            items[-1] += WRAP_POINT
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = []  # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            if i != 0:
                line_in = indent + line_in
            line_out = ""
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    # start a new (indented) line rather than overflow
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Run this op's simulator on ``state``, checking operand values.

        Raises:
            ValueError: if an input is unassigned or has the wrong number of
                elements, if (pre-RA) an output is already assigned (except
                for FuncArgR3), or if the simulator fails to assign an
                output of the right length.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            if len(val) != inp.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {inp} has wrong number of elements: "
                    f"expected {inp.ty.reg_len} found "
                    f"{len(val)}: {val!r}")
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out in state.ssa_vals:
                    # FuncArgR3's output is pre-assigned before simulation
                    if self.kind is OpKind.FuncArgR3:
                        continue
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        self.kind.sim(self, state)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            if len(val) != out.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {out} has wrong number of elements: "
                    f"expected {out.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Emit this op's assembly into ``state``."""
        all_loc_kinds = tuple(LocKind)
        # look up every operand's Loc first (any kind accepted) so a missing
        # allocation fails before any asm is emitted
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=all_loc_kinds)
        for out in self.outputs:
            state.loc(out, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """Common simulator state.

    Holds a sparse, byte-addressed ``memory`` dict; subclasses provide
    per-SSAVal value access via ``__getitem__``/``__setitem__``.

    NOTE(review): restored from a damaged listing.  ``super().__init__()``,
    ``value &= 0xFF`` in ``store_byte``, ``retval = 0``/``return retval`` in
    ``load``, the empty/single-item returns in ``_memory__repr``, the
    ``__repr__`` header/try/else lines and the abstract ``__getitem__`` /
    ``__setitem__`` stubs were inferred.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at ``addr`` (wrapped to register width); bytes
        never written read as 0."""
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low byte of ``value`` at ``addr`` (wrapped to register
        width)."""
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Little-endian load of ``size_in_bytes`` bytes starting at ``addr``.

        Raises:
            ValueError: if ``addr`` is not a multiple of ``size_in_bytes``.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        for i in range(size_in_bytes):
            # byte i lands in the i-th least significant byte
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            # sign bit set: convert to a negative two's-complement value
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Little-endian store of the low ``size_in_bytes`` bytes of ``value``.

        Raises:
            ValueError: if ``addr`` is not a multiple of ``size_in_bytes``.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """repr of ``memory``, used by ``__repr__``.

        Aligned, fully-populated word-sized chunks are shown as one value;
        everything else is shown byte by byte.
        """
        if len(self.memory) == 0:
            return "{}"
        # sorted descending so the lowest address is at the end (cheap pop)
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def __repr__(self):
        # type: () -> str
        """``ClassName(field=..., ...)``.

        A method named ``_<field>__repr`` is used for a field's text when
        one exists.
        """
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"{self.__class__.__name__}({field_vals_str})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        ...
@plain_data(frozen=True, repr=False)
class PreRASimState(BaseSimState):
    """Simulator state before register allocation.

    SSAVal values are stored directly in the ``ssa_vals`` dict.

    NOTE(review): restored from a damaged listing.  The empty-dict return,
    the ``else:`` and especially the ``CHUNK_SIZE`` value in
    ``_ssa_vals__repr`` were inferred.  The ``"\\n "`` literal looks
    whitespace-collapsed by the listing; confirm against upstream.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """repr of ``ssa_vals`` (picked up by ``BaseSimState.__repr__``)."""
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4  # elements per output line; NOTE(review): inferred
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    element_strs.append("\n " + hex(el))
            if len(element_strs) <= CHUNK_SIZE:
                # short values stay on one line
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # produce a trailing comma for 1-tuples: "(0x1,)"
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        return self.ssa_vals[ssa_val]

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Raises ValueError unless ``len(value) == ssa_val.ty.reg_len``."""
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
@plain_data(frozen=True, repr=False)
class PostRASimState(BaseSimState):
    """Simulator state after register allocation.

    Each SSAVal maps to a Loc through ``ssa_val_to_loc_map``.  Values are
    stored per single-register sub-Loc in ``loc_values``; sub-Locs never
    written read as 0.

    NOTE(review): the ``raise ValueError(`` lines in ``__init__`` and the
    loop header in ``_loc_values__repr`` were not visible and were inferred.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """repr of ``loc_values`` sorted by (kind, start)."""
        locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Raises ValueError unless ``len(value) == ssa_val.ty.reg_len``."""
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
@plain_data(frozen=True)
class GenAsmState:
    """State carried through assembly generation.

    ``allocated_locs`` maps each SSAVal to its allocated Loc. ``output``
    receives emitted lines, either appended to a list or written to a
    StringIO. A fresh list is used when no output is supplied.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty == loc.ty:
                continue
            raise ValueError(
                f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """Resolve SSAVals/Locs into one concatenated Loc of an allowed kind.

        SSAVals are replaced by their entry in ``allocated_locs``, and the
        resulting Locs are joined with ``try_concat``.

        Raises:
            ValueError: if the sequence is empty, the Locs cannot be
                concatenated, or the result's kind is not in
                ``expected_kinds``.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = []  # type: list[Loc]
        for item in ssa_val_or_locs:
            locs.append(self.allocated_locs[item]
                        if isinstance(item, SSAVal) else item)
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            kinds = (expected_kinds,)
        else:
            kinds = expected_kinds
        if retval.kind in kinds:
            return retval
        # show a lone expected kind bare rather than as a 1-tuple
        shown = kinds[0] if len(kinds) == 1 else kinds
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {shown}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """Format a GPR operand as its start number, ``*``-prefixed if vector."""
        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        text = str(loc.start)
        if is_vec:
            text = "*" + text
        return text

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a scalar (non-vector) GPR operand."""
        return self.gpr(ssa_val_or_locs, False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a vector GPR operand."""
        return self.gpr(ssa_val_or_locs, True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a StackI64 operand as ``<start>(1)``."""
        start = self.loc(ssa_val_or_locs, LocKind.StackI64).start
        return f"{start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """Emit the segments joined by single spaces as one output line."""
        text = " ".join(line_segments)
        out = self.output
        if not isinstance(out, list):
            out.write(text + "\n")
            return
        out.append(text)