1 from collections
import defaultdict
3 from abc
import ABCMeta
, abstractmethod
4 from enum
import Enum
, unique
5 from functools
import lru_cache
, total_ordering
6 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
7 Sequence
, TypeVar
, overload
)
8 from weakref
import WeakValueDictionary
as _WeakVDict
10 from cached_property
import cached_property
11 from nmutil
.plain_data
import fields
, plain_data
13 from bigint_presentation_code
.type_util
import Self
, assert_never
, final
, Literal
14 from bigint_presentation_code
.util
import BitSet
, FBitSet
, FMap
, OFSet
, OSet
20 self
.ops
= [] # type: list[Op]
21 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
22 self
.__next
_name
_suffix
= 2
24 def _add_op_with_unused_name(self
, op
, name
=""):
25 # type: (Op, str) -> str
27 raise ValueError("can't add Op to wrong Fn")
28 if hasattr(op
, "name"):
29 raise ValueError("Op already named")
32 if name
!= "" and name
not in self
.__op
_names
:
33 self
.__op
_names
[name
] = op
35 name
= orig_name
+ str(self
.__next
_name
_suffix
)
36 self
.__next
_name
_suffix
+= 1
42 def append_op(self
, op
):
45 raise ValueError("can't add Op to wrong Fn")
48 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
50 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
51 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
52 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
53 self
.append_op(retval
)
56 def pre_ra_sim(self
, state
):
57 # type: (PreRASimState) -> None
61 def pre_ra_insert_copies(self
):
63 orig_ops
= list(self
.ops
)
64 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
67 for i
in range(len(op
.input_vals
)):
68 inp
= copied_outputs
[op
.input_vals
[i
]]
69 if inp
.ty
.base_ty
is BaseTy
.I64
:
70 maxvl
= inp
.ty
.reg_len
71 if inp
.ty
.reg_len
!= 1:
72 setvl
= self
.append_new_op(
73 OpKind
.SetVLI
, immediates
=[maxvl
],
74 name
=f
"{op.name}.inp{i}.setvl")
76 mv
= self
.append_new_op(
77 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
78 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
80 mv
= self
.append_new_op(
81 OpKind
.CopyToReg
, input_vals
=[inp
],
82 name
=f
"{op.name}.inp{i}.copy")
83 op
.input_vals
[i
] = mv
.outputs
[0]
84 elif inp
.ty
.base_ty
is BaseTy
.CA \
85 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
86 # all copies would be no-ops, so we don't need to copy
87 op
.input_vals
[i
] = inp
89 assert_never(inp
.ty
.base_ty
)
91 for i
, out
in enumerate(op
.outputs
):
92 if out
.ty
.base_ty
is BaseTy
.I64
:
93 maxvl
= out
.ty
.reg_len
94 if out
.ty
.reg_len
!= 1:
95 setvl
= self
.append_new_op(
96 OpKind
.SetVLI
, immediates
=[maxvl
],
97 name
=f
"{op.name}.out{i}.setvl")
99 mv
= self
.append_new_op(
100 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
101 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
103 mv
= self
.append_new_op(
104 OpKind
.CopyFromReg
, input_vals
=[out
],
105 name
=f
"{op.name}.out{i}.copy")
106 copied_outputs
[out
] = mv
.outputs
[0]
107 elif out
.ty
.base_ty
is BaseTy
.CA \
108 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
109 # all copies would be no-ops, so we don't need to copy
110 copied_outputs
[out
] = out
112 assert_never(out
.ty
.base_ty
)
119 value
: Literal
[0, 1] # type: ignore
121 def __new__(cls
, value
):
122 # type: (int) -> OpStage
124 if value
not in (0, 1):
125 raise ValueError("invalid value")
126 retval
= object.__new
__(cls
)
127 retval
._value
_ = value
131 """ early stage of Op execution, where all input reads occur.
132 all output writes with `write_stage == Early` occur here too, and therefore
133 conflict with input reads, telling the compiler that it that can't share
134 that output's register with any inputs that the output isn't tied to.
136 All outputs, even unused outputs, can't share registers with any other
137 outputs, independent of `write_stage` settings.
140 """ late stage of Op execution, where all output writes with
141 `write_stage == Late` occur, and therefore don't conflict with input reads,
142 telling the compiler that any inputs can safely use the same register as
145 All outputs, even unused outputs, can't share registers with any other
146 outputs, independent of `write_stage` settings.
151 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """ order two stages by their integer `value`.

    returns `NotImplemented` for anything that isn't an `OpStage`, leaving
    Python free to try the reflected comparison on `other`.
    """
    if not isinstance(other, OpStage):
        return NotImplemented
    return self.value < other.value
160 assert OpStage
.Early
< OpStage
.Late
, "early must be less than late"
163 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
167 __slots__
= "op_index", "stage"
169 def __init__(self
, op_index
, stage
):
170 # type: (int, OpStage) -> None
171 self
.op_index
= op_index
177 """ an integer representation of `self` such that it keeps ordering and
178 successor/predecessor relations.
180 return self
.op_index
* 2 + self
.stage
.value
def from_int_value(int_value):
    # type: (int) -> ProgramPoint
    """ build the `ProgramPoint` encoded by `int_value`.

    the floor-quotient by 2 becomes `op_index` and the remainder (0 or 1)
    is converted to the matching `OpStage`.
    """
    return ProgramPoint(op_index=int_value // 2,
                        stage=OpStage(int_value % 2))
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """ the program point `steps` positions after `self`, computed on the
    integer representation (`int_value`) and converted back.
    """
    target_int_value = self.int_value + steps
    return ProgramPoint.from_int_value(target_int_value)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """ the program point `steps` positions before `self`.

    implemented as `next` with the step count negated.
    """
    negated_steps = -steps
    return self.next(steps=negated_steps)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """ order program points by `op_index` first, then by `stage`.

    returns `NotImplemented` when `other` is not a `ProgramPoint`.
    """
    if isinstance(other, ProgramPoint):
        if self.op_index == other.op_index:
            # same op: the stage decides
            return self.stage < other.stage
        return self.op_index < other.op_index
    return NotImplemented
206 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
209 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
211 class ProgramRange(Sequence
[ProgramPoint
]):
212 __slots__
= "start", "stop"
214 def __init__(self
, start
, stop
):
215 # type: (ProgramPoint, ProgramPoint) -> None
220 def int_value_range(self
):
222 return range(self
.start
.int_value
, self
.stop
.int_value
)
225 def from_int_value_range(int_value_range
):
226 # type: (range) -> ProgramRange
227 if int_value_range
.step
!= 1:
228 raise ValueError("int_value_range must have step == 1")
230 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
231 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
234 def __getitem__(self
, __idx
):
235 # type: (int) -> ProgramPoint
239 def __getitem__(self
, __idx
):
240 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """ index or slice this range.

    the lookup is done on the `range` of integer values from
    `start.int_value` (inclusive) to `stop.int_value` (exclusive): an
    integer result is turned back into a `ProgramPoint`, anything else
    (the sub-`range` produced by a slice) into a `ProgramRange`.
    """
    selected = range(self.start.int_value, self.stop.int_value)[__idx]
    if not isinstance(selected, int):
        return ProgramRange.from_int_value_range(selected)
    return ProgramPoint.from_int_value(selected)
252 return len(self
.int_value_range
)
255 # type: () -> Iterator[ProgramPoint]
256 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
260 start
= repr(self
.start
).lstrip("<").rstrip(">")
261 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
262 return f
"<range:{start}..{stop}>"
265 @plain_data(frozen
=True, eq
=False)
268 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
269 "def_program_ranges", "use_program_points",
270 "all_program_points")
272 def __init__(self
, fn
):
275 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
276 self
.all_program_points
= ProgramRange(
277 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
278 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
279 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
280 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
281 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
282 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
284 for use
in op
.input_uses
:
285 uses
[use
.ssa_val
].add(use
)
286 use_program_point
= self
.__get
_use
_program
_point
(use
)
287 use_program_points
[use
] = use_program_point
288 live_range_stops
[use
.ssa_val
] = max(
289 live_range_stops
[use
.ssa_val
], use_program_point
.next())
290 for out
in op
.outputs
:
292 def_program_range
= self
.__get
_def
_program
_range
(out
)
293 def_program_ranges
[out
] = def_program_range
294 live_range_stops
[out
] = def_program_range
.stop
295 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
296 self
.def_program_ranges
= FMap(def_program_ranges
)
297 self
.use_program_points
= FMap(use_program_points
)
298 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
299 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
300 for ssa_val
in uses
.keys():
301 live_ranges
[ssa_val
] = live_range
= ProgramRange(
302 start
=self
.def_program_ranges
[ssa_val
].start
,
303 stop
=live_range_stops
[ssa_val
])
304 for program_point
in live_range
:
305 live_at
[program_point
].add(ssa_val
)
306 self
.live_ranges
= FMap(live_ranges
)
307 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """ the range of program points covered by the definition of `ssa_val`.

    starts at the write stage of `ssa_val`'s defining descriptor within
    the op that produces it, and ends just past that op's late stage.
    """
    stage = ssa_val.defining_descriptor.write_stage
    op_index = self.op_indexes[ssa_val.op]
    first_point = ProgramPoint(op_index=op_index, stage=stage)
    # the late stage of ssa_val.op is always included, so that outputs
    # always overlap all other outputs.
    # `stop` is exclusive, hence the step to the following program point.
    late_point = ProgramPoint(op_index=op_index, stage=OpStage.Late)
    return ProgramRange(start=first_point, stop=late_point.next())
320 def __get_use_program_point(self
, ssa_use
):
321 # type: (SSAUse) -> ProgramPoint
322 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
323 "assumed here, ensured by GenericOpProperties.__init__"
325 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """ two analyses are equal when their `fn` attributes compare equal.

    returns `NotImplemented` when `other` is not a `FnAnalysis`.
    """
    if not isinstance(other, FnAnalysis):
        return NotImplemented
    return self.fn == other.fn
343 VL_MAXVL
= enum
.auto()
346 def only_scalar(self
):
348 if self
is BaseTy
.I64
:
350 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
356 def max_reg_len(self
):
358 if self
is BaseTy
.I64
:
360 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
366 return "BaseTy." + self
._name
_
369 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
372 __slots__
= "base_ty", "reg_len"
375 def validate(base_ty
, reg_len
):
376 # type: (BaseTy, int) -> str | None
377 """ return a string with the error if the combination is invalid,
378 otherwise return None
380 if base_ty
.only_scalar
and reg_len
!= 1:
381 return f
"can't create a vector of an only-scalar type: {base_ty}"
382 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
383 return "reg_len out of range"
386 def __init__(self
, base_ty
, reg_len
):
387 # type: (BaseTy, int) -> None
388 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
390 raise ValueError(msg
)
391 self
.base_ty
= base_ty
392 self
.reg_len
= reg_len
396 if self
.reg_len
!= 1:
397 reg_len
= f
"*{self.reg_len}"
400 return f
"<{self.base_ty._name_}{reg_len}>"
407 StackI64
= enum
.auto()
409 VL_MAXVL
= enum
.auto()
414 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
416 if self
is LocKind
.CA
:
418 if self
is LocKind
.VL_MAXVL
:
419 return BaseTy
.VL_MAXVL
426 if self
is LocKind
.StackI64
:
428 if self
is LocKind
.GPR
or self
is LocKind
.CA \
429 or self
is LocKind
.VL_MAXVL
:
430 return self
.base_ty
.max_reg_len
435 return "LocKind." + self
._name
_
440 class LocSubKind(Enum
):
441 BASE_GPR
= enum
.auto()
442 SV_EXTRA2_VGPR
= enum
.auto()
443 SV_EXTRA2_SGPR
= enum
.auto()
444 SV_EXTRA3_VGPR
= enum
.auto()
445 SV_EXTRA3_SGPR
= enum
.auto()
446 StackI64
= enum
.auto()
448 VL_MAXVL
= enum
.auto()
452 # type: () -> LocKind
453 # pyright fails typechecking when using `in` here:
454 # reported: https://github.com/microsoft/pyright/issues/4102
455 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
456 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
457 LocSubKind
.SV_EXTRA3_SGPR
):
459 if self
is LocSubKind
.StackI64
:
460 return LocKind
.StackI64
461 if self
is LocSubKind
.CA
:
463 if self
is LocSubKind
.VL_MAXVL
:
464 return LocKind
.VL_MAXVL
469 return self
.kind
.base_ty
472 def allocatable_locs(self
, ty
):
473 # type: (Ty) -> LocSet
474 if ty
.base_ty
!= self
.base_ty
:
475 raise ValueError("type mismatch")
476 if self
is LocSubKind
.BASE_GPR
:
478 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
479 starts
= range(0, 128, 2)
480 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
482 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
483 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
485 elif self
is LocSubKind
.StackI64
:
486 starts
= range(LocKind
.StackI64
.loc_count
)
487 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
488 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
491 retval
= [] # type: list[Loc]
493 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
497 for special_loc
in SPECIAL_GPRS
:
498 if loc
.conflicts(special_loc
):
503 return LocSet(retval
)
506 return "LocSubKind." + self
._name
_
509 @plain_data(frozen
=True, unsafe_hash
=True)
512 __slots__
= "base_ty", "is_vec"
514 def __init__(self
, base_ty
, is_vec
):
515 # type: (BaseTy, bool) -> None
516 self
.base_ty
= base_ty
517 if base_ty
.only_scalar
and is_vec
:
518 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
521 def instantiate(self
, maxvl
):
523 # here's where subvl and elwid would be accounted for
525 return Ty(self
.base_ty
, maxvl
)
526 return Ty(self
.base_ty
, 1)
528 def can_instantiate_to(self
, ty
):
530 if self
.base_ty
!= ty
.base_ty
:
534 return ty
.reg_len
== 1
537 @plain_data(frozen
=True, unsafe_hash
=True)
540 __slots__
= "kind", "start", "reg_len"
543 def validate(kind
, start
, reg_len
):
544 # type: (LocKind, int, int) -> str | None
545 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
548 if reg_len
> kind
.loc_count
:
549 return "invalid reg_len"
550 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
551 return "start not in valid range"
555 def try_make(kind
, start
, reg_len
):
556 # type: (LocKind, int, int) -> Loc | None
557 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
560 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
562 def __init__(self
, kind
, start
, reg_len
):
563 # type: (LocKind, int, int) -> None
564 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
566 raise ValueError(msg
)
568 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """ True when `self` and `other` have the same `kind` and their
    half-open `[start, stop)` intervals overlap.
    """
    if self.kind != other.kind:
        # different kinds never conflict
        return False
    return self.start < other.stop and other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """ the `Ty` with `kind`'s base type and the given `reg_len`. """
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
584 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
589 return self
.start
+ self
.reg_len
591 def try_concat(self
, *others
):
592 # type: (*Loc | None) -> Loc | None
593 reg_len
= self
.reg_len
596 if other
is None or other
.kind
!= self
.kind
:
598 if stop
!= other
.start
:
601 reg_len
+= other
.reg_len
602 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
606 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
607 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
608 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
609 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
613 @plain_data(frozen
=True, eq
=False)
615 class LocSet(AbstractSet
[Loc
]):
616 __slots__
= "starts", "ty"
618 def __init__(self
, __locs
=()):
619 # type: (Iterable[Loc]) -> None
620 if isinstance(__locs
, LocSet
):
621 self
.starts
= __locs
.starts
# type: FMap[LocKind, FBitSet]
622 self
.ty
= __locs
.ty
# type: Ty | None
624 starts
= {i
: BitSet() for i
in LocKind
}
630 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
631 starts
[loc
.kind
].add(loc
.start
)
633 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
638 # type: () -> FMap[LocKind, FBitSet]
643 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
647 # type: () -> AbstractSet[LocKind]
648 return self
.starts
.keys()
652 # type: () -> int | None
655 return self
.ty
.reg_len
659 # type: () -> BaseTy | None
662 return self
.ty
.base_ty
664 def concat(self
, *others
):
665 # type: (*LocSet) -> LocSet
668 base_ty
= self
.ty
.base_ty
669 reg_len
= self
.ty
.reg_len
670 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
674 if other
.ty
.base_ty
!= base_ty
:
676 for kind
, other_starts
in other
.starts
.items():
677 if kind
not in starts
:
679 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
680 if starts
[kind
] == 0:
684 reg_len
+= other
.ty
.reg_len
687 # type: () -> Iterable[Loc]
688 for kind
, v
in starts
.items():
690 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
693 return LocSet(locs())
695 def __contains__(self
, loc
):
696 # type: (Loc | Any) -> bool
697 if not isinstance(loc
, Loc
) or loc
.ty
!= self
.ty
:
699 if loc
.kind
not in self
.starts
:
701 return loc
.start
in self
.starts
[loc
.kind
]
704 # type: () -> Iterator[Loc]
707 for kind
, starts
in self
.starts
.items():
709 yield Loc(kind
=kind
, start
=start
, reg_len
=self
.ty
.reg_len
)
713 return sum((len(v
) for v
in self
.starts
.values()), 0)
720 return super()._hash
()
725 @lru_cache(maxsize
=None, typed
=True)
726 def max_conflicts_with(self
, other
):
727 # type: (LocSet | Loc) -> int
728 """the largest number of Locs in `self` that a single Loc
729 from `other` can conflict with
731 if isinstance(other
, LocSet
):
732 return max(self
.max_conflicts_with(i
) for i
in other
)
734 return sum(other
.conflicts(i
) for i
in self
)
737 @plain_data(frozen
=True, unsafe_hash
=True)
739 class GenericOperandDesc
:
740 """generic Op operand descriptor"""
741 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
745 self
, ty
, # type: GenericTy
746 sub_kinds
, # type: Iterable[LocSubKind]
748 fixed_loc
=None, # type: Loc | None
749 tied_input_index
=None, # type: int | None
750 spread
=False, # type: bool
751 write_stage
=OpStage
.Early
, # type: OpStage
753 # type: (...) -> None
755 self
.sub_kinds
= OFSet(sub_kinds
)
756 if len(self
.sub_kinds
) == 0:
757 raise ValueError("sub_kinds can't be empty")
758 self
.fixed_loc
= fixed_loc
759 if fixed_loc
is not None:
760 if tied_input_index
is not None:
761 raise ValueError("operand can't be both tied and fixed")
762 if not ty
.can_instantiate_to(fixed_loc
.ty
):
764 f
"fixed_loc has incompatible type for given generic "
765 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
766 if len(self
.sub_kinds
) != 1:
768 "multiple sub_kinds not allowed for fixed operand")
769 for sub_kind
in self
.sub_kinds
:
770 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
772 f
"fixed_loc not in given sub_kind: "
773 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
774 for sub_kind
in self
.sub_kinds
:
775 if sub_kind
.base_ty
!= ty
.base_ty
:
776 raise ValueError(f
"sub_kind is incompatible with type: "
777 f
"sub_kind={sub_kind} ty={ty}")
778 if tied_input_index
is not None and tied_input_index
< 0:
779 raise ValueError("invalid tied_input_index")
780 self
.tied_input_index
= tied_input_index
783 if self
.tied_input_index
is not None:
784 raise ValueError("operand can't be both spread and tied")
785 if self
.fixed_loc
is not None:
786 raise ValueError("operand can't be both spread and fixed")
788 raise ValueError("operand can't be both spread and vector")
789 self
.write_stage
= write_stage
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """ a new descriptor with this one's `ty`, `sub_kinds` and
    `write_stage`, tied to the input at `tied_input_index`.

    `fixed_loc` and `spread` are not passed on, so the new descriptor gets
    the constructor defaults for them.
    """
    ty, sub_kinds = self.ty, self.sub_kinds
    return GenericOperandDesc(
        ty, sub_kinds,
        write_stage=self.write_stage,
        tied_input_index=tied_input_index)
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """ a new descriptor with this one's `ty`, `sub_kinds` and
    `write_stage`, pinned to `fixed_loc`.

    `tied_input_index` and `spread` are not passed on, so the new
    descriptor gets the constructor defaults for them.
    """
    ty, sub_kinds = self.ty, self.sub_kinds
    return GenericOperandDesc(
        ty, sub_kinds,
        write_stage=self.write_stage,
        fixed_loc=fixed_loc)
802 def with_write_stage(self
, write_stage
):
803 # type: (OpStage) -> Self
804 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
805 fixed_loc
=self
.fixed_loc
,
806 tied_input_index
=self
.tied_input_index
,
808 write_stage
=write_stage
)
810 def instantiate(self
, maxvl
):
811 # type: (int) -> Iterable[OperandDesc]
812 # assumes all spread operands have ty.reg_len = 1
817 ty
= self
.ty
.instantiate(maxvl
=maxvl
)
820 # type: () -> Iterable[Loc]
821 if self
.fixed_loc
is not None:
822 if ty
!= self
.fixed_loc
.ty
:
824 f
"instantiation failed: type mismatch with fixed_loc: "
825 f
"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
828 for sub_kind
in self
.sub_kinds
:
829 yield from sub_kind
.allocatable_locs(ty
)
830 loc_set_before_spread
= LocSet(locs())
831 for idx
in range(rep_count
):
834 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
835 tied_input_index
=self
.tied_input_index
,
836 spread_index
=idx
, write_stage
=self
.write_stage
)
839 @plain_data(frozen
=True, unsafe_hash
=True)
842 """Op operand descriptor"""
843 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
846 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
848 # type: (LocSet, int | None, int | None, OpStage) -> None
849 if len(loc_set_before_spread
) == 0:
850 raise ValueError("loc_set_before_spread must not be empty")
851 self
.loc_set_before_spread
= loc_set_before_spread
852 self
.tied_input_index
= tied_input_index
853 if self
.tied_input_index
is not None and self
.spread_index
is not None:
854 raise ValueError("operand can't be both spread and tied")
855 self
.spread_index
= spread_index
856 self
.write_stage
= write_stage
859 def ty_before_spread(self
):
861 ty
= self
.loc_set_before_spread
.ty
862 assert ty
is not None, (
863 "__init__ checked that the LocSet isn't empty, "
864 "non-empty LocSets should always have ty set")
869 """ Ty after any spread is applied """
870 if self
.spread_index
is not None:
871 # assumes all spread operands have ty.reg_len = 1
872 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
873 return self
.ty_before_spread
876 def reg_offset_in_unspread(self
):
877 """ the number of reg-sized slots in the unspread Loc before self's Loc
879 e.g. if the unspread Loc containing self is:
880 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
881 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
882 then reg_offset_into_unspread == 2 == 10 - 8
884 if self
.spread_index
is None:
886 return self
.spread_index
* self
.ty
.reg_len
# Shared operand descriptors, one per (type, register sub-kind) combination.

# scalar I64 operands
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.BASE_GPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
# vector I64 operand
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
    ty=GenericTy(BaseTy.I64, is_vec=True))
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
    ty=GenericTy(BaseTy.I64, is_vec=False))
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
    ty=GenericTy(BaseTy.I64, is_vec=True))
# CA and VL_MAXVL operands (never vectors)
OD_CA = GenericOperandDesc(
    sub_kinds=[LocSubKind.CA],
    ty=GenericTy(BaseTy.CA, is_vec=False))
OD_VL = GenericOperandDesc(
    sub_kinds=[LocSubKind.VL_MAXVL],
    ty=GenericTy(BaseTy.VL_MAXVL, is_vec=False))
912 @plain_data(frozen
=True, unsafe_hash
=True)
914 class GenericOpProperties
:
915 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
916 "is_copy", "is_load_immediate", "has_side_effects")
919 self
, demo_asm
, # type: str
920 inputs
, # type: Iterable[GenericOperandDesc]
921 outputs
, # type: Iterable[GenericOperandDesc]
922 immediates
=(), # type: Iterable[range]
923 is_copy
=False, # type: bool
924 is_load_immediate
=False, # type: bool
925 has_side_effects
=False, # type: bool
927 # type: (...) -> None
928 self
.demo_asm
= demo_asm
# type: str
929 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
930 for inp
in self
.inputs
:
931 if inp
.tied_input_index
is not None:
933 f
"tied_input_index is not allowed on inputs: {inp}")
934 if inp
.write_stage
is not OpStage
.Early
:
936 f
"write_stage is not allowed on inputs: {inp}")
937 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
938 fixed_locs
= [] # type: list[tuple[Loc, int]]
939 for idx
, out
in enumerate(self
.outputs
):
940 if out
.tied_input_index
is not None:
941 if out
.tied_input_index
>= len(self
.inputs
):
942 raise ValueError(f
"tied_input_index out of range: {out}")
943 tied_inp
= self
.inputs
[out
.tied_input_index
]
944 if tied_inp
.tied_to_input(out
.tied_input_index
) != out
:
945 raise ValueError(f
"output can't be tied to non-equivalent "
946 f
"input: {out} tied to {tied_inp}")
947 if out
.fixed_loc
is not None:
948 for other_fixed_loc
, other_idx
in fixed_locs
:
949 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
952 f
"conflicting fixed_locs: outputs[{idx}] and "
953 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
954 f
"with {other_fixed_loc}")
955 fixed_locs
.append((out
.fixed_loc
, idx
))
956 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
957 self
.is_copy
= is_copy
# type: bool
958 self
.is_load_immediate
= is_load_immediate
# type: bool
959 self
.has_side_effects
= has_side_effects
# type: bool
962 @plain_data(frozen
=True, unsafe_hash
=True)
965 __slots__
= "kind", "inputs", "outputs", "maxvl"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """ instantiate the generic properties of `kind` for a given `maxvl`.

    every generic input/output descriptor reachable through `self.generic`
    is instantiated with `maxvl`; each one may yield several `OperandDesc`s
    and all of them are stored, flattened and in order, in `inputs` /
    `outputs`.
    """
    # `kind` has to be stored first: `self.generic` is used right below
    self.kind = kind  # type: OpKind
    self.inputs = tuple(
        desc
        for generic_inp in self.generic.inputs
        for desc in generic_inp.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.outputs = tuple(
        desc
        for generic_out in self.generic.outputs
        for desc in generic_out.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
982 # type: () -> GenericOpProperties
983 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """ the `immediates` of the generic properties, passed through
    unchanged.
    """
    generic_properties = self.generic
    return generic_properties.immediates
993 return self
.generic
.demo_asm
998 return self
.generic
.is_copy
1001 def is_load_immediate(self
):
1003 return self
.generic
.is_load_immediate
1006 def has_side_effects(self
):
1008 return self
.generic
.has_side_effects
# all values representable as a signed 16-bit immediate
IMM_S16 = range(-(2 ** 15), 2 ** 15)

# signature of a simulation function: takes the Op and the simulator state
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# zero-argument factory returning a simulation function
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# registry of simulation-function factories, keyed by op properties
_PRE_RA_SIMS = dict()  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
1021 def __init__(self
, properties
):
1022 # type: (GenericOpProperties) -> None
1024 self
.__properties
= properties
def properties(self):
    # type: () -> GenericOpProperties
    """ the `GenericOpProperties` stored in the private `__properties`
    attribute.
    """
    stored = self.__properties
    return stored
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """ the `OpProperties` of this kind for the given `maxvl`. """
    instantiated = OpProperties(self, maxvl=maxvl)
    return instantiated
1037 return "OpKind." + self
._name
_
def pre_ra_sim(self):
    # type: () -> _PRE_RA_SIM_FN
    """ the simulation function registered for this kind.

    `_PRE_RA_SIMS` maps `self.properties` to a zero-argument factory; the
    factory is called here and its result returned.
    """
    sim_factory = _PRE_RA_SIMS[self.properties]
    return sim_factory()
def __clearca_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the op by storing the 1-tuple `(False,)` as the value of
    its first output in `state.ssa_vals`.
    """
    carry_out = op.outputs[0]
    state.ssa_vals[carry_out] = (False,)
1048 ClearCA
= GenericOpProperties(
1049 demo_asm
="addic 0, 0, 0",
1051 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1053 _PRE_RA_SIMS
[ClearCA
] = lambda: OpKind
.__clearca
_pre
_ra
_sim
def __setca_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the op by storing the 1-tuple `(True,)` as the value of
    its first output in `state.ssa_vals`.
    """
    carry_out = op.outputs[0]
    state.ssa_vals[carry_out] = (True,)
1059 SetCA
= GenericOpProperties(
1060 demo_asm
="subfc 0, 0, 0",
1062 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1064 _PRE_RA_SIMS
[SetCA
] = lambda: OpKind
.__setca
_pre
_ra
_sim
1067 def __svadde_pre_ra_sim(op
, state
):
1068 # type: (Op, PreRASimState) -> None
1069 RA
= state
.ssa_vals
[op
.input_vals
[0]]
1070 RB
= state
.ssa_vals
[op
.input_vals
[1]]
1071 carry
, = state
.ssa_vals
[op
.input_vals
[2]]
1072 VL
, = state
.ssa_vals
[op
.input_vals
[3]]
1073 RT
= [] # type: list[int]
1075 v
= RA
[i
] + RB
[i
] + carry
1076 RT
.append(v
& GPR_VALUE_MASK
)
1077 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1078 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
1079 state
.ssa_vals
[op
.outputs
[1]] = carry
,
1080 SvAddE
= GenericOpProperties(
1081 demo_asm
="sv.adde *RT, *RA, *RB",
1082 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1083 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
1085 _PRE_RA_SIMS
[SvAddE
] = lambda: OpKind
.__svadde
_pre
_ra
_sim
1088 def __svsubfe_pre_ra_sim(op
, state
):
1089 # type: (Op, PreRASimState) -> None
1090 RA
= state
.ssa_vals
[op
.input_vals
[0]]
1091 RB
= state
.ssa_vals
[op
.input_vals
[1]]
1092 carry
, = state
.ssa_vals
[op
.input_vals
[2]]
1093 VL
, = state
.ssa_vals
[op
.input_vals
[3]]
1094 RT
= [] # type: list[int]
1096 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
1097 RT
.append(v
& GPR_VALUE_MASK
)
1098 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1099 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
1100 state
.ssa_vals
[op
.outputs
[1]] = carry
,
1101 SvSubFE
= GenericOpProperties(
1102 demo_asm
="sv.subfe *RT, *RA, *RB",
1103 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1104 outputs
=[OD_EXTRA3_VGPR
, OD_CA
],
1106 _PRE_RA_SIMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_pre
_ra
_sim
1109 def __svmaddedu_pre_ra_sim(op
, state
):
1110 # type: (Op, PreRASimState) -> None
1111 RA
= state
.ssa_vals
[op
.input_vals
[0]]
1112 RB
, = state
.ssa_vals
[op
.input_vals
[1]]
1113 carry
, = state
.ssa_vals
[op
.input_vals
[2]]
1114 VL
, = state
.ssa_vals
[op
.input_vals
[3]]
1115 RT
= [] # type: list[int]
1117 v
= RA
[i
] * RB
+ carry
1118 RT
.append(v
& GPR_VALUE_MASK
)
1119 carry
= v
>> GPR_SIZE_IN_BITS
1120 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
1121 state
.ssa_vals
[op
.outputs
[1]] = carry
,
1122 SvMAddEDU
= GenericOpProperties(
1123 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1124 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1125 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1127 _PRE_RA_SIMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_pre
_ra
_sim
def __setvli_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the op by storing its first immediate, wrapped in a
    1-tuple, as the value of its first output in `state.ssa_vals`.
    """
    new_vl = op.immediates[0]
    state.ssa_vals[op.outputs[0]] = (new_vl,)
1133 SetVLI
= GenericOpProperties(
1134 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1136 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1137 immediates
=[range(1, 65)],
1138 is_load_immediate
=True,
1140 _PRE_RA_SIMS
[SetVLI
] = lambda: OpKind
.__setvli
_pre
_ra
_sim
def __svli_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the op by filling its first output with copies of the
    immediate.

    the value of the first input must be a 1-tuple; its single element is
    the number of copies. the immediate is masked with `GPR_VALUE_MASK`
    before being replicated.
    """
    (vl,) = state.ssa_vals[op.input_vals[0]]
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state.ssa_vals[op.outputs[0]] = tuple([masked_imm] * vl)
1148 SvLI
= GenericOpProperties(
1149 demo_asm
="sv.addi *RT, 0, imm",
1151 outputs
=[OD_EXTRA3_VGPR
],
1152 immediates
=[IMM_S16
],
1153 is_load_immediate
=True,
1155 _PRE_RA_SIMS
[SvLI
] = lambda: OpKind
.__svli
_pre
_ra
_sim
def __li_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the op by storing its first immediate, masked with
    `GPR_VALUE_MASK` and wrapped in a 1-tuple, as the value of its first
    output in `state.ssa_vals`.
    """
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state.ssa_vals[op.outputs[0]] = (masked_imm,)
1162 LI
= GenericOpProperties(
1163 demo_asm
="addi RT, 0, imm",
1165 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1166 immediates
=[IMM_S16
],
1167 is_load_immediate
=True,
1169 _PRE_RA_SIMS
[LI
] = lambda: OpKind
.__li
_pre
_ra
_sim
def __veccopytoreg_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """ simulate the copy: the first output takes the value currently
    recorded for the first input in `state.ssa_vals`.
    """
    src_value = state.ssa_vals[op.input_vals[0]]
    state.ssa_vals[op.outputs[0]] = src_value
1175 VecCopyToReg
= GenericOpProperties(
1176 demo_asm
="sv.mv dest, src",
1177 inputs
=[GenericOperandDesc(
1178 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1179 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1181 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1184 _PRE_RA_SIMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_pre
_ra
_sim
def __veccopyfromreg_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `VecCopyFromReg`.

    The value of `op.input_vals[0]` is assigned unchanged to
    `op.outputs[0]`; no other input is read here.
    """
    copied = state.ssa_vals[op.input_vals[0]]
    state.ssa_vals[op.outputs[0]] = copied
1190 VecCopyFromReg
= GenericOpProperties(
1191 demo_asm
="sv.mv dest, src",
1192 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1193 outputs
=[GenericOperandDesc(
1194 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1195 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1196 write_stage
=OpStage
.Late
,
1200 _PRE_RA_SIMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_pre
_ra
_sim
def __copytoreg_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `CopyToReg`.

    The value of `op.input_vals[0]` is assigned unchanged to
    `op.outputs[0]`.
    """
    copied = state.ssa_vals[op.input_vals[0]]
    state.ssa_vals[op.outputs[0]] = copied
1206 CopyToReg
= GenericOpProperties(
1207 demo_asm
="mv dest, src",
1208 inputs
=[GenericOperandDesc(
1209 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1210 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1211 LocSubKind
.StackI64
],
1213 outputs
=[GenericOperandDesc(
1214 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1215 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1216 write_stage
=OpStage
.Late
,
1220 _PRE_RA_SIMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_pre
_ra
_sim
def __copyfromreg_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `CopyFromReg`.

    The value of `op.input_vals[0]` is assigned unchanged to
    `op.outputs[0]`.
    """
    copied = state.ssa_vals[op.input_vals[0]]
    state.ssa_vals[op.outputs[0]] = copied
1226 CopyFromReg
= GenericOpProperties(
1227 demo_asm
="mv dest, src",
1228 inputs
=[GenericOperandDesc(
1229 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1230 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1232 outputs
=[GenericOperandDesc(
1233 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1234 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1235 LocSubKind
.StackI64
],
1236 write_stage
=OpStage
.Late
,
1240 _PRE_RA_SIMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_pre
_ra
_sim
def __concat_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `Concat`.

    Builds the output from element 0 of the value of every input except
    the last one, in input order, and assigns it to `op.outputs[0]`.
    """
    # the final input is deliberately left out of the result
    # NOTE(review): presumably that last input is VL -- confirm
    elements = []  # type: list[int]
    for inp in op.input_vals[:-1]:
        elements.append(state.ssa_vals[inp][0])
    state.ssa_vals[op.outputs[0]] = tuple(elements)
1247 Concat
= GenericOpProperties(
1248 demo_asm
="sv.mv dest, src",
1249 inputs
=[GenericOperandDesc(
1250 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1251 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1254 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1257 _PRE_RA_SIMS
[Concat
] = lambda: OpKind
.__concat
_pre
_ra
_sim
def __spread_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `Spread`.

    Element `n` of the value of `op.input_vals[0]` is assigned, wrapped in
    a 1-element tuple, to `op.outputs[n]`.
    """
    source = state.ssa_vals[op.input_vals[0]]
    for out_idx, element in enumerate(source):
        state.ssa_vals[op.outputs[out_idx]] = (element,)
1264 Spread
= GenericOpProperties(
1265 demo_asm
="sv.mv dest, src",
1266 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1267 outputs
=[GenericOperandDesc(
1268 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1269 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1271 write_stage
=OpStage
.Late
,
1275 _PRE_RA_SIMS
[Spread
] = lambda: OpKind
.__spread
_pre
_ra
_sim
1278 def __svld_pre_ra_sim(op
, state
):
1279 # type: (Op, PreRASimState) -> None
1280 RA
, = state
.ssa_vals
[op
.input_vals
[0]]
1281 VL
, = state
.ssa_vals
[op
.input_vals
[1]]
1282 addr
= RA
+ op
.immediates
[0]
1283 RT
= [] # type: list[int]
1285 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
1286 RT
.append(v
& GPR_VALUE_MASK
)
1287 state
.ssa_vals
[op
.outputs
[0]] = tuple(RT
)
1288 SvLd
= GenericOpProperties(
1289 demo_asm
="sv.ld *RT, imm(RA)",
1290 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1291 outputs
=[OD_EXTRA3_VGPR
],
1292 immediates
=[IMM_S16
],
1294 _PRE_RA_SIMS
[SvLd
] = lambda: OpKind
.__svld
_pre
_ra
_sim
def __ld_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `Ld`.

    Loads from address `RA + op.immediates[0]` via `state.load` and assigns
    the result, masked with `GPR_VALUE_MASK`, to `op.outputs[0]` as a
    1-element tuple.
    """
    # unpacking raises unless the base-address value holds one element
    (base,) = state.ssa_vals[op.input_vals[0]]
    loaded = state.load(base + op.immediates[0])
    state.ssa_vals[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
1303 Ld
= GenericOpProperties(
1304 demo_asm
="ld RT, imm(RA)",
1305 inputs
=[OD_BASE_SGPR
],
1306 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1307 immediates
=[IMM_S16
],
1309 _PRE_RA_SIMS
[Ld
] = lambda: OpKind
.__ld
_pre
_ra
_sim
1312 def __svstd_pre_ra_sim(op
, state
):
1313 # type: (Op, PreRASimState) -> None
1314 RS
= state
.ssa_vals
[op
.input_vals
[0]]
1315 RA
, = state
.ssa_vals
[op
.input_vals
[1]]
1316 VL
, = state
.ssa_vals
[op
.input_vals
[2]]
1317 addr
= RA
+ op
.immediates
[0]
1319 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
1320 SvStd
= GenericOpProperties(
1321 demo_asm
="sv.std *RS, imm(RA)",
1322 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1324 immediates
=[IMM_S16
],
1325 has_side_effects
=True,
1327 _PRE_RA_SIMS
[SvStd
] = lambda: OpKind
.__svstd
_pre
_ra
_sim
def __std_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `Std`.

    Stores the single element of `op.input_vals[0]` at address
    `RA + op.immediates[0]`, where RA is the single element of
    `op.input_vals[1]`, using `state.store`.
    """
    (stored_value,) = state.ssa_vals[op.input_vals[0]]
    (base,) = state.ssa_vals[op.input_vals[1]]
    target_addr = base + op.immediates[0]
    state.store(target_addr, value=stored_value)
1336 Std
= GenericOpProperties(
1337 demo_asm
="std RT, imm(RA)",
1338 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1340 immediates
=[IMM_S16
],
1341 has_side_effects
=True,
1343 _PRE_RA_SIMS
[Std
] = lambda: OpKind
.__std
_pre
_ra
_sim
def __funcargr3_pre_ra_sim(op, state):
    # type: (Op, PreRASimState) -> None
    """Pre-RA simulation for `FuncArgR3`: intentionally does nothing.

    The output value is set before simulation starts, so there is nothing
    to compute here.
    """
    return None
1349 FuncArgR3
= GenericOpProperties(
1352 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1353 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1355 _PRE_RA_SIMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_pre
_ra
_sim
1358 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1359 class SSAValOrUse(metaclass
=ABCMeta
):
1360 __slots__
= "op", "operand_idx"
1362 def __init__(self
, op
, operand_idx
):
1363 # type: (Op, int) -> None
1366 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1367 raise ValueError("invalid operand_idx")
1368 self
.operand_idx
= operand_idx
1377 def descriptor_array(self
):
1378 # type: () -> tuple[OperandDesc, ...]
def defining_descriptor(self):
    # type: () -> OperandDesc
    """The entry of `self.descriptor_array` at index `self.operand_idx`."""
    all_descriptors = self.descriptor_array
    return all_descriptors[self.operand_idx]
1389 return self
.defining_descriptor
.ty
1392 def ty_before_spread(self
):
1394 return self
.defining_descriptor
.ty_before_spread
1398 # type: () -> BaseTy
1399 return self
.ty_before_spread
.base_ty
1402 def reg_offset_in_unspread(self
):
1403 """ the number of reg-sized slots in the unspread Loc before self's Loc
1405 e.g. if the unspread Loc containing self is:
1406 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1407 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1408 then reg_offset_into_unspread == 2 == 10 - 8
1410 return self
.defining_descriptor
.reg_offset_in_unspread
1413 def unspread_start_idx(self
):
1415 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1418 def unspread_start(self
):
1420 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1423 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1425 class SSAVal(SSAValOrUse
):
1430 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
def def_loc_set_before_spread(self):
    # type: () -> LocSet
    """The `loc_set_before_spread` of this value's defining descriptor."""
    descriptor = self.defining_descriptor
    return descriptor.loc_set_before_spread
def descriptor_array(self):
    # type: () -> tuple[OperandDesc, ...]
    """The output descriptors of the op this value belongs to."""
    op_properties = self.op.properties
    return op_properties.outputs
1443 def tied_input(self
):
1444 # type: () -> None | SSAUse
1445 if self
.defining_descriptor
.tied_input_index
is None:
1447 return SSAUse(op
=self
.op
,
1448 operand_idx
=self
.defining_descriptor
.tied_input_index
)
def write_stage(self):
    # type: () -> OpStage
    """The `write_stage` of this value's defining descriptor."""
    descriptor = self.defining_descriptor
    return descriptor.write_stage
1456 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1458 class SSAUse(SSAValOrUse
):
def use_loc_set_before_spread(self):
    # type: () -> LocSet
    """The `loc_set_before_spread` of this use's defining descriptor."""
    descriptor = self.defining_descriptor
    return descriptor.loc_set_before_spread
def descriptor_array(self):
    # type: () -> tuple[OperandDesc, ...]
    """The input descriptors of the op this use belongs to."""
    op_properties = self.op.properties
    return op_properties.inputs
1473 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1477 # type: () -> SSAVal
1478 return self
.op
.input_vals
[self
.operand_idx
]
def ssa_val(self, ssa_val):
    # type: (SSAVal) -> None
    """Replace the value this use reads.

    Writes `ssa_val` into `self.op.input_vals` at `self.operand_idx`; any
    validation is whatever that container performs on item assignment.
    """
    op_inputs = self.op.input_vals
    op_inputs[self.operand_idx] = ssa_val
1487 _Desc
= TypeVar("_Desc")
1490 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, _T | Any, _Desc) -> None
    """Hook for checking `item` against the descriptor `desc` at `idx`.

    Called by `_verify_write` once `idx` has been normalized. This base
    version always raises `NotImplementedError`.
    """
    raise NotImplementedError()
1497 def _verify_write(self
, idx
, item
):
1498 # type: (int | Any, _T | Any) -> int
1499 if not isinstance(idx
, int):
1500 if isinstance(idx
, slice):
1502 f
"can't write to slice of {self.__class__.__name__}")
1503 raise TypeError(f
"can't write with index {idx!r}")
1504 # normalize idx, raising IndexError if it is out of range
1505 idx
= range(len(self
.descriptors
))[idx
]
1506 desc
= self
.descriptors
[idx
]
1507 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1510 def _on_set(self
, idx
, new_item
, old_item
):
1511 # type: (int, _T, _T | None) -> None
def _get_descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """Hook returning one descriptor per position of this sequence.

    This base version always raises `NotImplementedError`.
    """
    raise NotImplementedError()
def descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """The descriptors for this sequence, as given by `_get_descriptors`."""
    result = self._get_descriptors()
    return result
1530 def __init__(self
, items
, op
):
1531 # type: (Iterable[_T], Op) -> None
1534 self
.__items
= [] # type: list[_T]
1535 for idx
, item
in enumerate(items
):
1536 if idx
>= len(self
.descriptors
):
1537 raise ValueError("too many items")
1538 _
= self
._verify
_write
(idx
, item
)
1539 self
.__items
.append(item
)
1540 if len(self
.__items
) < len(self
.descriptors
):
1541 raise ValueError("not enough items")
1545 # type: () -> Iterator[_T]
1546 yield from self
.__items
1549 def __getitem__(self
, idx
):
1554 def __getitem__(self
, idx
):
1555 # type: (slice) -> list[_T]
def __getitem__(self, idx):
    # type: (int | slice) -> _T | list[_T]
    """Return the item at `idx`, or a list of items when `idx` is a slice.

    Indexing is delegated directly to the underlying item list.
    """
    stored = self.__items
    return stored[idx]
def __setitem__(self, idx, item):
    # type: (int, _T) -> None
    """Validate and store `item` at position `idx`.

    `_verify_write` checks the write and returns the index that is then
    used for the assignment into the underlying item list.
    """
    checked_idx = self._verify_write(idx, item)
    self.__items[checked_idx] = item
1572 return len(self
.__items
)
1576 return f
"{self.__class__.__name__}({self.__items}, op=...)"
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The `SSAVal`s an `Op` reads, one per entry of `op.properties.inputs`.

    Every value written must be an `SSAVal` whose type equals the type of
    the `OperandDesc` at the same position.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        """Return the input operand descriptors of the owning op."""
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Check that `item` may be stored at the position described by `desc`.

        Raises:
            TypeError: `item` is not an `SSAVal`.
            ValueError: `item.ty` differs from `desc.ty`.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        """Forward the change of input `idx` to `SSAUses._on_op_input_set`."""
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """Create the input sequence for `op` from `items`.

        Raises:
            ValueError: `op` already has its `input_vals` attribute set.
        """
        # `Op` keeps this sequence in its `input_vals` attribute, so that is
        # the attribute to test; checking "inputs" could never detect an
        # already-initialized Op.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
class OpImmediates(OpInputSeq[int, range]):
    """The immediate operands of an `Op`.

    Each position is described by a `range` taken from
    `op.properties.immediates`; a stored value must be an `int` contained
    in that range.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """Create the immediates sequence for `op` from `items`.

        Raises:
            ValueError: `op` already has its `immediates` attribute set.
        """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        """Return the allowed range of every immediate of the owning op."""
        op_properties = self.op.properties
        return op_properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """Check that `item` is an `int` lying inside the range `desc`.

        Raises:
            TypeError: `item` is not an `int`.
            ValueError: `item` is not contained in `desc`.
        """
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
1624 @plain_data(frozen
=True, eq
=False, repr=False)
1627 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
1630 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
1631 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1633 self
.properties
= properties
1634 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
1635 inputs_len
= len(self
.properties
.inputs
)
1636 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
1637 self
.immediates
= OpImmediates(immediates
, op
=self
)
1638 outputs_len
= len(self
.properties
.outputs
)
1639 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
1640 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
1644 # type: () -> OpKind
1645 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: an `Op` equals only itself.

    Returns `NotImplemented` when `other` is not an `Op`.
    """
    if not isinstance(other, Op):
        return NotImplemented
    return self is other
1655 return object.__hash
__(self
)
1659 field_vals
= [] # type: list[str]
1660 for name
in fields(self
):
1661 if name
== "properties":
1666 value
= getattr(self
, name
)
1667 except AttributeError:
1668 field_vals
.append(f
"{name}=<not set>")
1670 if isinstance(value
, OpInputSeq
):
1671 value
= list(value
) # type: ignore
1672 field_vals
.append(f
"{name}={value!r}")
1673 field_vals_str
= ", ".join(field_vals
)
1674 return f
"Op({field_vals_str})"
1676 def pre_ra_sim(self
, state
):
1677 # type: (PreRASimState) -> None
1678 for inp
in self
.input_vals
:
1679 if inp
not in state
.ssa_vals
:
1680 raise ValueError(f
"SSAVal {inp} not yet assigned when "
1682 if len(state
.ssa_vals
[inp
]) != inp
.ty
.reg_len
:
1684 f
"value of SSAVal {inp} has wrong number of elements: "
1685 f
"expected {inp.ty.reg_len} found "
1686 f
"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
1687 for out
in self
.outputs
:
1688 if out
in state
.ssa_vals
:
1689 if self
.kind
is OpKind
.FuncArgR3
:
1691 raise ValueError(f
"SSAVal {out} already assigned before "
1693 self
.kind
.pre_ra_sim(self
, state
)
1694 for out
in self
.outputs
:
1695 if out
not in state
.ssa_vals
:
1696 raise ValueError(f
"running {self} failed to assign to {out}")
1697 if len(state
.ssa_vals
[out
]) != out
.ty
.reg_len
:
1699 f
"value of SSAVal {out} has wrong number of elements: "
1700 f
"expected {out.ty.reg_len} found "
1701 f
"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1704 GPR_SIZE_IN_BYTES
= 8
1706 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* BITS_IN_BYTE
1707 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
1710 @plain_data(frozen
=True, repr=False)
1712 class PreRASimState
:
1713 __slots__
= "ssa_vals", "memory"
def __init__(self, ssa_vals, memory):
    # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
    """Create a simulation state around the given dictionaries.

    Both arguments are stored as-is (not copied): `ssa_vals` maps each
    `SSAVal` to a tuple of ints, `memory` maps an int address to an int.
    """
    # values currently assigned to SSA values
    self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
    # memory contents keyed by address
    self.memory = memory  # type: dict[int, int]
def load_byte(self, addr):
    # type: (int) -> int
    """Read one byte of simulated memory.

    `addr` is first masked with `GPR_VALUE_MASK`. Addresses that have no
    entry in `self.memory` read as 0, and only the low 8 bits of the stored
    value are returned.
    """
    wrapped_addr = addr & GPR_VALUE_MASK
    stored = self.memory.get(wrapped_addr, 0)
    return stored & 0xFF
1725 def store_byte(self
, addr
, value
):
1726 # type: (int, int) -> None
1727 addr
&= GPR_VALUE_MASK
1729 self
.memory
[addr
] = value
1731 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
1732 # type: (int, int, bool) -> int
1733 if addr
% size_in_bytes
!= 0:
1734 raise ValueError(f
"address not aligned: {hex(addr)} "
1735 f
"required alignment: {size_in_bytes}")
1737 for i
in range(size_in_bytes
):
1738 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
1739 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
1740 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
    # type: (int, int, int) -> None
    """Write `value` to memory as `size_in_bytes` bytes starting at `addr`.

    Byte `n` of the write is bits `n * BITS_IN_BYTE` and up of `value`
    (lowest-order byte at the lowest address), each written through
    `store_byte`.

    Raises:
        ValueError: `addr` is not a multiple of `size_in_bytes`.
    """
    misalignment = addr % size_in_bytes
    if misalignment != 0:
        raise ValueError(f"address not aligned: {hex(addr)} "
                         f"required alignment: {size_in_bytes}")
    for byte_idx in range(size_in_bytes):
        shift = byte_idx * BITS_IN_BYTE
        byte = (value >> shift) & 0xFF
        self.store_byte(addr + byte_idx, byte)
1751 def _memory__repr(self
):
1753 if len(self
.memory
) == 0:
1755 keys
= sorted(self
.memory
.keys(), reverse
=True)
1756 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
1757 items
= [] # type: list[str]
1758 while len(keys
) != 0:
1760 if (len(keys
) >= CHUNK_SIZE
1761 and addr
% CHUNK_SIZE
== 0
1762 and keys
[-CHUNK_SIZE
:]
1763 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
1764 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
1765 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
1766 keys
[-CHUNK_SIZE
:] = ()
1768 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
1770 return f
"{{{items[0]}}}"
1771 items_str
= ",\n".join(items
)
1772 return f
"{{\n{items_str}}}"
1774 def _ssa_vals__repr(self
):
1776 if len(self
.ssa_vals
) == 0:
1778 items
= [] # type: list[str]
1780 for k
, v
in self
.ssa_vals
.items():
1781 element_strs
= [] # type: list[str]
1782 for i
, el
in enumerate(v
):
1783 if i
% CHUNK_SIZE
!= 0:
1784 element_strs
.append(" " + hex(el
))
1786 element_strs
.append("\n " + hex(el
))
1787 if len(element_strs
) <= CHUNK_SIZE
:
1788 element_strs
[0] = element_strs
[0].lstrip()
1789 if len(element_strs
) == 1:
1790 element_strs
.append("")
1791 v_str
= ",".join(element_strs
)
1792 items
.append(f
"{k!r}: ({v_str})")
1793 if len(items
) == 1 and "\n" not in items
[0]:
1794 return f
"{{{items[0]}}}"
1795 items_str
= ",\n".join(items
)
1796 return f
"{{\n{items_str},\n}}"
1800 field_vals
= [] # type: list[str]
1801 for name
in fields(self
):
1803 value
= getattr(self
, name
)
1804 except AttributeError:
1805 field_vals
.append(f
"{name}=<not set>")
1807 repr_fn
= getattr(self
, f
"_{name}__repr", None)
1808 if callable(repr_fn
):
1809 field_vals
.append(f
"{name}={repr_fn()}")
1811 field_vals
.append(f
"{name}={value!r}")
1812 field_vals_str
= ", ".join(field_vals
)
1813 return f
"PreRASimState({field_vals_str})"