1 from contextlib
import contextmanager
3 from abc
import ABCMeta
, abstractmethod
4 from enum
import Enum
, unique
5 from functools
import lru_cache
, total_ordering
6 from io
import StringIO
7 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
8 Mapping
, Sequence
, TypeVar
, Union
, overload
)
9 from weakref
import WeakValueDictionary
as _WeakVDict
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import fields
, plain_data
14 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
16 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
# GPR size in bits: the GPR byte size times the number of bits per byte.
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
# mask with exactly the low GPR_SIZE_IN_BITS bits set.
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
28 self
.ops
= [] # type: list[Op]
29 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
30 self
.__next
_name
_suffix
= 2
32 def _add_op_with_unused_name(self
, op
, name
=""):
33 # type: (Op, str) -> str
35 raise ValueError("can't add Op to wrong Fn")
36 if hasattr(op
, "name"):
37 raise ValueError("Op already named")
40 if name
!= "" and name
not in self
.__op
_names
:
41 self
.__op
_names
[name
] = op
43 name
= orig_name
+ str(self
.__next
_name
_suffix
)
44 self
.__next
_name
_suffix
+= 1
50 def ops_to_str(self
, as_python_literal
=False, wrap_width
=63,
51 python_indent
=" ", indent
=" "):
52 # type: (bool, int, str, str) -> str
53 l
= [] # type: list[str]
55 l
.append(op
.__repr
__(wrap_width
=wrap_width
, indent
=indent
))
58 l
= [python_indent
+ "\""]
61 l
.append(f
"\\n\"\n{python_indent}\"")
64 elif ch
.isascii() and ch
.isprintable():
67 l
.append(repr(ch
).strip("\"'"))
70 empty_end
= f
"\"\n{python_indent}\"\""
71 if retval
.endswith(empty_end
):
72 retval
= retval
[:-len(empty_end
)]
75 def append_op(self
, op
):
78 raise ValueError("can't add Op to wrong Fn")
81 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
83 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
84 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
85 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
86 self
.append_op(retval
)
90 # type: (BaseSimState) -> None
94 def gen_asm(self
, state
):
95 # type: (GenAsmState) -> None
99 def pre_ra_insert_copies(self
):
101 orig_ops
= list(self
.ops
)
102 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
103 setvli_outputs
= {} # type: dict[SSAVal, Op]
106 for i
in range(len(op
.input_vals
)):
107 inp
= copied_outputs
[op
.input_vals
[i
]]
108 if inp
.ty
.base_ty
is BaseTy
.I64
:
109 maxvl
= inp
.ty
.reg_len
110 if inp
.ty
.reg_len
!= 1:
111 setvl
= self
.append_new_op(
112 OpKind
.SetVLI
, immediates
=[maxvl
],
113 name
=f
"{op.name}.inp{i}.setvl")
114 vl
= setvl
.outputs
[0]
115 mv
= self
.append_new_op(
116 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
117 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
119 mv
= self
.append_new_op(
120 OpKind
.CopyToReg
, input_vals
=[inp
],
121 name
=f
"{op.name}.inp{i}.copy")
122 op
.input_vals
[i
] = mv
.outputs
[0]
123 elif inp
.ty
.base_ty
is BaseTy
.CA \
124 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
125 # all copies would be no-ops, so we don't need to copy,
126 # though we do need to rematerialize SetVLI ops right
128 if inp
in setvli_outputs
:
129 setvl
= self
.append_new_op(
131 immediates
=setvli_outputs
[inp
].immediates
,
132 name
=f
"{op.name}.inp{i}.setvl")
133 inp
= setvl
.outputs
[0]
134 op
.input_vals
[i
] = inp
136 assert_never(inp
.ty
.base_ty
)
138 for i
, out
in enumerate(op
.outputs
):
139 if op
.kind
is OpKind
.SetVLI
:
140 setvli_outputs
[out
] = op
141 if out
.ty
.base_ty
is BaseTy
.I64
:
142 maxvl
= out
.ty
.reg_len
143 if out
.ty
.reg_len
!= 1:
144 setvl
= self
.append_new_op(
145 OpKind
.SetVLI
, immediates
=[maxvl
],
146 name
=f
"{op.name}.out{i}.setvl")
147 vl
= setvl
.outputs
[0]
148 mv
= self
.append_new_op(
149 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
150 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
152 mv
= self
.append_new_op(
153 OpKind
.CopyFromReg
, input_vals
=[out
],
154 name
=f
"{op.name}.out{i}.copy")
155 copied_outputs
[out
] = mv
.outputs
[0]
156 elif out
.ty
.base_ty
is BaseTy
.CA \
157 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
158 # all copies would be no-ops, so we don't need to copy
159 copied_outputs
[out
] = out
161 assert_never(out
.ty
.base_ty
)
168 value
: Literal
[0, 1] # type: ignore
170 def __new__(cls
, value
):
171 # type: (int) -> OpStage
173 if value
not in (0, 1):
174 raise ValueError("invalid value")
175 retval
= object.__new
__(cls
)
176 retval
._value
_ = value
180 """ early stage of Op execution, where all input reads occur.
181 all output writes with `write_stage == Early` occur here too, and therefore
182 conflict with input reads, telling the compiler that it can't share
183 that output's register with any inputs that the output isn't tied to.
185 All outputs, even unused outputs, can't share registers with any other
186 outputs, independent of `write_stage` settings.
189 """ late stage of Op execution, where all output writes with
190 `write_stage == Late` occur, and therefore don't conflict with input reads,
191 telling the compiler that any inputs can safely use the same register as
194 All outputs, even unused outputs, can't share registers with any other
195 outputs, independent of `write_stage` settings.
200 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """Order stages by their integer `value`.

    Returns `NotImplemented` when `other` is not an `OpStage`, so Python
    can fall back to the reflected comparison.
    """
    if not isinstance(other, OpStage):
        return NotImplemented
    return self.value < other.value
# sanity check, evaluated when this statement runs: Early must sort before
# Late under OpStage's ordering.
assert OpStage.Early < OpStage.Late, "early must be less than late"
212 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
215 class ProgramPoint(metaclass
=InternedMeta
):
216 __slots__
= "op_index", "stage"
218 def __init__(self
, op_index
, stage
):
219 # type: (int, OpStage) -> None
220 self
.op_index
= op_index
226 """ an integer representation of `self` such that it keeps ordering and
227 successor/predecessor relations.
229 return self
.op_index
* 2 + self
.stage
.value
def from_int_value(int_value):
    # type: (int) -> ProgramPoint
    """Build the ProgramPoint encoded by `int_value`.

    The encoding is `op_index * 2 + stage.value`, so the low bit selects
    the stage and the remaining bits are the op index.
    """
    stage = OpStage(int_value % 2)
    return ProgramPoint(op_index=int_value // 2, stage=stage)
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions after `self`.

    Stepping is done on the integer encoding (`int_value`), so a negative
    `steps` moves backwards.
    """
    target_int_value = self.int_value + steps
    return ProgramPoint.from_int_value(target_int_value)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions before `self`.

    Implemented as `next` with the step count negated.
    """
    backwards = -steps
    return self.next(steps=backwards)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """Order program points by op index first, then by stage.

    Returns `NotImplemented` when `other` is not a `ProgramPoint`.
    """
    if isinstance(other, ProgramPoint):
        if self.op_index == other.op_index:
            # same op: the stage decides
            return self.stage < other.stage
        return self.op_index < other.op_index
    return NotImplemented
255 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
258 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
260 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
261 __slots__
= "start", "stop"
263 def __init__(self
, start
, stop
):
264 # type: (ProgramPoint, ProgramPoint) -> None
269 def int_value_range(self
):
271 return range(self
.start
.int_value
, self
.stop
.int_value
)
274 def from_int_value_range(int_value_range
):
275 # type: (range) -> ProgramRange
276 if int_value_range
.step
!= 1:
277 raise ValueError("int_value_range must have step == 1")
279 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
280 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
283 def __getitem__(self
, __idx
):
284 # type: (int) -> ProgramPoint
288 def __getitem__(self
, __idx
):
289 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """Index or slice the range.

    The lookup is delegated to a `range` over the integer encodings of
    `start`..`stop`: an integer index gives a ProgramPoint, a slice gives
    a ProgramRange built from the resulting sub-`range`.
    """
    int_values = range(self.start.int_value, self.stop.int_value)
    picked = int_values[__idx]
    if not isinstance(picked, int):
        return ProgramRange.from_int_value_range(picked)
    return ProgramPoint.from_int_value(picked)
301 return len(self
.int_value_range
)
304 # type: () -> Iterator[ProgramPoint]
305 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
309 start
= repr(self
.start
).lstrip("<").rstrip(">")
310 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
311 return f
"<range:{start}..{stop}>"
@plain_data(frozen=True, unsafe_hash=True)
class SSAValSubReg(metaclass=InternedMeta):
    """An SSAVal paired with the index of one of its `ty.reg_len` registers."""
    __slots__ = "ssa_val", "reg_idx"

    def __init__(self, ssa_val, reg_idx):
        # type: (SSAVal, int) -> None
        """Store `ssa_val` and `reg_idx`.

        Raises ValueError unless `0 <= reg_idx < ssa_val.ty.reg_len`.
        """
        if not 0 <= reg_idx < ssa_val.ty.reg_len:
            raise ValueError("reg_idx out of range")
        self.ssa_val = ssa_val
        self.reg_idx = reg_idx
327 @plain_data(frozen
=True, eq
=False, repr=False)
330 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
331 "def_program_ranges", "use_program_points",
332 "all_program_points")
334 def __init__(self
, fn
):
337 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
338 self
.all_program_points
= ProgramRange(
339 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
340 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
341 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
342 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
343 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
344 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
346 for use
in op
.input_uses
:
347 uses
[use
.ssa_val
].add(use
)
348 use_program_point
= self
.__get
_use
_program
_point
(use
)
349 use_program_points
[use
] = use_program_point
350 live_range_stops
[use
.ssa_val
] = max(
351 live_range_stops
[use
.ssa_val
], use_program_point
.next())
352 for out
in op
.outputs
:
354 def_program_range
= self
.__get
_def
_program
_range
(out
)
355 def_program_ranges
[out
] = def_program_range
356 live_range_stops
[out
] = def_program_range
.stop
357 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
358 self
.def_program_ranges
= FMap(def_program_ranges
)
359 self
.use_program_points
= FMap(use_program_points
)
360 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
361 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
362 for ssa_val
in uses
.keys():
363 live_ranges
[ssa_val
] = live_range
= ProgramRange(
364 start
=self
.def_program_ranges
[ssa_val
].start
,
365 stop
=live_range_stops
[ssa_val
])
366 for program_point
in live_range
:
367 live_at
[program_point
].add(ssa_val
)
368 self
.live_ranges
= FMap(live_ranges
)
369 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
370 self
.copies
# initialize
371 self
.const_ssa_vals
# initialize
372 self
.const_ssa_val_sub_regs
# initialize
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """Return the program range covered by the definition of `ssa_val`.

    The range starts at the write stage of the defining op and stops just
    after that op's late stage.
    """
    write_stage = ssa_val.defining_descriptor.write_stage
    def_op_index = self.op_indexes[ssa_val.op]
    first = ProgramPoint(op_index=def_op_index, stage=write_stage)
    # always include late stage of ssa_val.op, to ensure outputs always
    # overlap all other outputs.
    late = ProgramPoint(op_index=first.op_index, stage=OpStage.Late)
    # stop is exclusive, so we need the next program point.
    return ProgramRange(start=first, stop=late.next())
385 def __get_use_program_point(self
, ssa_use
):
386 # type: (SSAUse) -> ProgramPoint
387 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
388 "assumed here, ensured by GenericOpProperties.__init__"
390 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """Two analyses are equal when their `fn` attributes are equal.

    Returns `NotImplemented` when `other` is not a `FnAnalysis`.
    """
    if not isinstance(other, FnAnalysis):
        return NotImplemented
    return self.fn == other.fn
404 return "<FnAnalysis>"
408 # type: () -> FMap[SSAValSubReg, SSAValSubReg]
409 """ map from SSAValSubRegs to the original SSAValSubRegs that they are
410 a copy of, looking through all layers of copies. The map excludes all
411 SSAValSubRegs that aren't copies of other SSAValSubRegs.
413 retval
= {} # type: dict[SSAValSubReg, SSAValSubReg]
414 for op
in self
.op_indexes
.keys():
415 if not op
.properties
.is_copy
:
417 copy_reg_len
= op
.properties
.copy_reg_len
418 copy_inputs
= [] # type: list[SSAValSubReg]
419 for inp
in op
.input_vals
[:op
.properties
.copy_inputs_len
]:
420 for inp_sub_reg
in inp
.ssa_val_sub_regs
:
421 # propagate copies of copies
422 inp_sub_reg
= retval
.get(inp_sub_reg
, inp_sub_reg
)
423 copy_inputs
.append(inp_sub_reg
)
424 assert len(copy_inputs
) == copy_reg_len
, "logic error"
425 copy_outputs
= [] # type: list[SSAValSubReg]
426 for out
in op
.outputs
[:op
.properties
.copy_outputs_len
]:
427 copy_outputs
.extend(out
.ssa_val_sub_regs
)
428 assert len(copy_outputs
) == copy_reg_len
, "logic error"
429 for inp
, out
in zip(copy_inputs
, copy_outputs
):
434 def const_ssa_vals(self
):
435 # type: () -> FMap[SSAVal, tuple[int, ...]]
436 state
= ConstPropagationState(
437 ssa_vals
={}, memory
={}, skipped_ops
=OSet())
439 return FMap(state
.ssa_vals
)
442 def const_ssa_val_sub_regs(self
):
443 # type: () -> FMap[SSAValSubReg, int]
444 retval
= {} # type: dict[SSAValSubReg, int]
445 for ssa_val
, const_val
in self
.const_ssa_vals
.items():
446 assert ssa_val
.ty
.reg_len
== len(const_val
), "logic error"
447 for reg_idx
, v
in enumerate(const_val
):
448 retval
[SSAValSubReg(ssa_val
, reg_idx
)] = v
451 def are_always_equal(self
, a
, b
):
452 # type: (SSAValSubReg, SSAValSubReg) -> bool
453 """check if a and b are known to be always equal to each other.
454 This means they can be allocated to the same location if other
455 constraints don't prevent that.
457 this can happen for a number of reasons, such as:
458 * a and b are copies of the same thing
459 * a and b are known to be constants and they have the same value
461 if a
.ssa_val
.base_ty
!= b
.ssa_val
.base_ty
:
462 return False # can't be equal, they have different types
463 # look through copies
464 a
= self
.copies
.get(a
, a
)
465 b
= self
.copies
.get(b
, b
)
468 # check if they have the same constant value
470 a_const_val
= self
.const_ssa_val_sub_regs
[a
]
471 b_const_val
= self
.const_ssa_val_sub_regs
[b
]
472 if a_const_val
== b_const_val
:
484 VL_MAXVL
= enum
.auto()
487 def only_scalar(self
):
489 if self
is BaseTy
.I64
:
491 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
497 def max_reg_len(self
):
499 if self
is BaseTy
.I64
:
501 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
507 return "BaseTy." + self
._name
_
510 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
512 class Ty(metaclass
=InternedMeta
):
513 __slots__
= "base_ty", "reg_len"
516 def validate(base_ty
, reg_len
):
517 # type: (BaseTy, int) -> str | None
518 """ return a string with the error if the combination is invalid,
519 otherwise return None
521 if base_ty
.only_scalar
and reg_len
!= 1:
522 return f
"can't create a vector of an only-scalar type: {base_ty}"
523 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
524 return "reg_len out of range"
527 def __init__(self
, base_ty
, reg_len
):
528 # type: (BaseTy, int) -> None
529 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
531 raise ValueError(msg
)
532 self
.base_ty
= base_ty
533 self
.reg_len
= reg_len
537 if self
.reg_len
!= 1:
538 reg_len
= f
"*{self.reg_len}"
541 return f
"<{self.base_ty._name_}{reg_len}>"
548 StackI64
= enum
.auto()
550 VL_MAXVL
= enum
.auto()
555 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
557 if self
is LocKind
.CA
:
559 if self
is LocKind
.VL_MAXVL
:
560 return BaseTy
.VL_MAXVL
567 if self
is LocKind
.StackI64
:
569 if self
is LocKind
.GPR
or self
is LocKind
.CA \
570 or self
is LocKind
.VL_MAXVL
:
571 return self
.base_ty
.max_reg_len
576 return "LocKind." + self
._name
_
581 class LocSubKind(Enum
):
582 BASE_GPR
= enum
.auto()
583 SV_EXTRA2_VGPR
= enum
.auto()
584 SV_EXTRA2_SGPR
= enum
.auto()
585 SV_EXTRA3_VGPR
= enum
.auto()
586 SV_EXTRA3_SGPR
= enum
.auto()
587 StackI64
= enum
.auto()
589 VL_MAXVL
= enum
.auto()
593 # type: () -> LocKind
594 # pyright fails typechecking when using `in` here:
595 # reported: https://github.com/microsoft/pyright/issues/4102
596 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
597 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
598 LocSubKind
.SV_EXTRA3_SGPR
):
600 if self
is LocSubKind
.StackI64
:
601 return LocKind
.StackI64
602 if self
is LocSubKind
.CA
:
604 if self
is LocSubKind
.VL_MAXVL
:
605 return LocKind
.VL_MAXVL
610 return self
.kind
.base_ty
613 def allocatable_locs(self
, ty
):
614 # type: (Ty) -> LocSet
615 if ty
.base_ty
!= self
.base_ty
:
616 raise ValueError("type mismatch")
617 if self
is LocSubKind
.BASE_GPR
:
619 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
620 starts
= range(0, 128, 2)
621 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
623 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
624 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
626 elif self
is LocSubKind
.StackI64
:
627 starts
= range(LocKind
.StackI64
.loc_count
)
628 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
629 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
632 retval
= [] # type: list[Loc]
634 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
638 for special_loc
in SPECIAL_GPRS
:
639 if loc
.conflicts(special_loc
):
644 return LocSet(retval
)
647 return "LocSubKind." + self
._name
_
650 @plain_data(frozen
=True, unsafe_hash
=True)
652 class GenericTy(metaclass
=InternedMeta
):
653 __slots__
= "base_ty", "is_vec"
655 def __init__(self
, base_ty
, is_vec
):
656 # type: (BaseTy, bool) -> None
657 self
.base_ty
= base_ty
658 if base_ty
.only_scalar
and is_vec
:
659 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
662 def instantiate(self
, maxvl
):
664 # here's where subvl and elwid would be accounted for
666 return Ty(self
.base_ty
, maxvl
)
667 return Ty(self
.base_ty
, 1)
669 def can_instantiate_to(self
, ty
):
671 if self
.base_ty
!= ty
.base_ty
:
675 return ty
.reg_len
== 1
678 @plain_data(frozen
=True, unsafe_hash
=True)
680 class Loc(metaclass
=InternedMeta
):
681 __slots__
= "kind", "start", "reg_len"
684 def validate(kind
, start
, reg_len
):
685 # type: (LocKind, int, int) -> str | None
686 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
689 if reg_len
> kind
.loc_count
:
690 return "invalid reg_len"
691 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
692 return "start not in valid range"
696 def try_make(kind
, start
, reg_len
):
697 # type: (LocKind, int, int) -> Loc | None
698 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
701 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
703 def __init__(self
, kind
, start
, reg_len
):
704 # type: (LocKind, int, int) -> None
705 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
707 raise ValueError(msg
)
709 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Return True when `self` and `other` overlap.

    Two Locs conflict when they have the same `kind` and their half-open
    `[start, stop)` intervals intersect.
    """
    if self.kind != other.kind:
        return False
    return self.start < other.stop and other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """Return the Ty with `kind`'s base type and the given `reg_len`."""
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
725 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
730 return self
.start
+ self
.reg_len
732 def try_concat(self
, *others
):
733 # type: (*Loc | None) -> Loc | None
734 reg_len
= self
.reg_len
737 if other
is None or other
.kind
!= self
.kind
:
739 if stop
!= other
.start
:
742 reg_len
+= other
.reg_len
743 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
def get_subloc_at_offset(self, subloc_ty, offset):
    # type: (Ty, int) -> Loc
    """Return the Loc of type `subloc_ty` that starts `offset` registers
    into `self`.

    Raises ValueError if `subloc_ty` has a different base type than
    `self.kind`, or if the requested sub-Loc does not fit inside `self`.
    """
    if subloc_ty.base_ty != self.kind.base_ty:
        raise ValueError("BaseTy mismatch")
    sub_reg_len = subloc_ty.reg_len
    fits = offset >= 0 and offset + sub_reg_len <= self.reg_len
    if not fits:
        raise ValueError("invalid sub-Loc: offset and/or "
                         "subloc_ty.reg_len out of range")
    sub_start = self.start + offset
    return Loc(kind=self.kind, start=sub_start, reg_len=subloc_ty.reg_len)
757 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
758 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
759 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
760 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
765 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
766 def __init__(self
, __locs
=()):
767 # type: (Iterable[Loc]) -> None
768 super().__init
__(__locs
)
769 if isinstance(__locs
, LocSet
):
770 self
.__starts
= __locs
.starts
771 self
.__ty
= __locs
.ty
773 starts
= {i
: BitSet() for i
in LocKind
}
774 ty
= None # type: None | Ty
779 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
780 starts
[loc
.kind
].add(loc
.start
)
781 self
.__starts
= FMap(
782 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
787 # type: () -> FMap[LocKind, FBitSet]
792 # type: () -> Ty | None
797 # type: () -> FMap[LocKind, FBitSet]
802 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
806 # type: () -> AbstractSet[LocKind]
807 return self
.starts
.keys()
811 # type: () -> int | None
814 return self
.ty
.reg_len
818 # type: () -> BaseTy | None
821 return self
.ty
.base_ty
823 def concat(self
, *others
):
824 # type: (*LocSet) -> LocSet
827 base_ty
= self
.ty
.base_ty
828 reg_len
= self
.ty
.reg_len
829 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
833 if other
.ty
.base_ty
!= base_ty
:
835 for kind
, other_starts
in other
.starts
.items():
836 if kind
not in starts
:
838 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
839 if starts
[kind
] == 0:
843 reg_len
+= other
.ty
.reg_len
846 # type: () -> Iterable[Loc]
847 for kind
, v
in starts
.items():
849 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
852 return LocSet(locs())
854 @lru_cache(maxsize
=None, typed
=True)
855 def max_conflicts_with(self
, other
):
856 # type: (LocSet | Loc) -> int
857 """the largest number of Locs in `self` that a single Loc
858 from `other` can conflict with
860 if isinstance(other
, LocSet
):
861 return max(self
.max_conflicts_with(i
) for i
in other
)
863 return sum(other
.conflicts(i
) for i
in self
)
866 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
869 @plain_data(frozen
=True, unsafe_hash
=True)
871 class GenericOperandDesc(metaclass
=InternedMeta
):
872 """generic Op operand descriptor"""
873 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
877 self
, ty
, # type: GenericTy
878 sub_kinds
, # type: Iterable[LocSubKind]
880 fixed_loc
=None, # type: Loc | None
881 tied_input_index
=None, # type: int | None
882 spread
=False, # type: bool
883 write_stage
=OpStage
.Early
, # type: OpStage
885 # type: (...) -> None
887 self
.sub_kinds
= OFSet(sub_kinds
)
888 if len(self
.sub_kinds
) == 0:
889 raise ValueError("sub_kinds can't be empty")
890 self
.fixed_loc
= fixed_loc
891 if fixed_loc
is not None:
892 if tied_input_index
is not None:
893 raise ValueError("operand can't be both tied and fixed")
894 if not ty
.can_instantiate_to(fixed_loc
.ty
):
896 f
"fixed_loc has incompatible type for given generic "
897 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
898 if len(self
.sub_kinds
) != 1:
900 "multiple sub_kinds not allowed for fixed operand")
901 for sub_kind
in self
.sub_kinds
:
902 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
904 f
"fixed_loc not in given sub_kind: "
905 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
906 for sub_kind
in self
.sub_kinds
:
907 if sub_kind
.base_ty
!= ty
.base_ty
:
908 raise ValueError(f
"sub_kind is incompatible with type: "
909 f
"sub_kind={sub_kind} ty={ty}")
910 if tied_input_index
is not None and tied_input_index
< 0:
911 raise ValueError("invalid tied_input_index")
912 self
.tied_input_index
= tied_input_index
915 if self
.tied_input_index
is not None:
916 raise ValueError("operand can't be both spread and tied")
917 if self
.fixed_loc
is not None:
918 raise ValueError("operand can't be both spread and fixed")
920 raise ValueError("operand can't be both spread and vector")
921 self
.write_stage
= write_stage
924 def ty_before_spread(self
):
925 # type: () -> GenericTy
927 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Return a descriptor like `self` but tied to the given input.

    Only `ty`, `sub_kinds` and `write_stage` are carried over; `fixed_loc`
    and `spread` are left at the constructor defaults.
    """
    return GenericOperandDesc(
        self.ty,
        self.sub_kinds,
        write_stage=self.write_stage,
        tied_input_index=tied_input_index,
    )
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Return a descriptor like `self` but pinned to `fixed_loc`.

    Only `ty`, `sub_kinds` and `write_stage` are carried over;
    `tied_input_index` and `spread` are left at the constructor defaults.
    """
    return GenericOperandDesc(
        self.ty,
        self.sub_kinds,
        write_stage=self.write_stage,
        fixed_loc=fixed_loc,
    )
941 def with_write_stage(self
, write_stage
):
942 # type: (OpStage) -> Self
943 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
944 fixed_loc
=self
.fixed_loc
,
945 tied_input_index
=self
.tied_input_index
,
947 write_stage
=write_stage
)
949 def instantiate(self
, maxvl
):
950 # type: (int) -> Iterable[OperandDesc]
951 # assumes all spread operands have ty.reg_len = 1
955 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
957 def locs_before_spread():
958 # type: () -> Iterable[Loc]
959 if self
.fixed_loc
is not None:
960 if ty_before_spread
!= self
.fixed_loc
.ty
:
962 f
"instantiation failed: type mismatch with fixed_loc: "
963 f
"instantiated type: {ty_before_spread} "
964 f
"fixed_loc: {self.fixed_loc}")
967 for sub_kind
in self
.sub_kinds
:
968 yield from sub_kind
.allocatable_locs(ty_before_spread
)
969 loc_set_before_spread
= LocSet(locs_before_spread())
970 for idx
in range(rep_count
):
973 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
974 tied_input_index
=self
.tied_input_index
,
975 spread_index
=idx
, write_stage
=self
.write_stage
)
978 @plain_data(frozen
=True, unsafe_hash
=True)
980 class OperandDesc(metaclass
=InternedMeta
):
981 """Op operand descriptor"""
982 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
985 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
987 # type: (LocSet, int | None, int | None, OpStage) -> None
988 if len(loc_set_before_spread
) == 0:
989 raise ValueError("loc_set_before_spread must not be empty")
990 self
.loc_set_before_spread
= loc_set_before_spread
991 self
.tied_input_index
= tied_input_index
992 if self
.tied_input_index
is not None and spread_index
is not None:
993 raise ValueError("operand can't be both spread and tied")
994 self
.spread_index
= spread_index
995 self
.write_stage
= write_stage
998 def ty_before_spread(self
):
1000 ty
= self
.loc_set_before_spread
.ty
1001 assert ty
is not None, (
1002 "__init__ checked that the LocSet isn't empty, "
1003 "non-empty LocSets should always have ty set")
1008 """ Ty after any spread is applied """
1009 if self
.spread_index
is not None:
1010 # assumes all spread operands have ty.reg_len = 1
1011 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
1012 return self
.ty_before_spread
1015 def reg_offset_in_unspread(self
):
1016 """ the number of reg-sized slots in the unspread Loc before self's Loc
1018 e.g. if the unspread Loc containing self is:
1019 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1020 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1021 then reg_offset_in_unspread == 2 == 10 - 8
1023 if self
.spread_index
is None:
1025 return self
.spread_index
* self
.ty
.reg_len
# Shared generic operand descriptors. Each one pairs a GenericTy with the
# single LocSubKind it is allowed to be allocated in; the *_SGPR ones are
# scalar (is_vec=False) and the *_VGPR ones are vectors (is_vec=True).

# scalar I64 in a BASE_GPR location
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# scalar I64 in an SV_EXTRA3_SGPR location
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector of I64 in an SV_EXTRA3_VGPR location
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 in an SV_EXTRA2_SGPR location
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector of I64 in an SV_EXTRA2_VGPR location
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# the CA operand (always scalar)
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# the VL_MAXVL operand (always scalar)
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
1051 @plain_data(frozen
=True, unsafe_hash
=True)
1053 class GenericOpProperties(metaclass
=InternedMeta
):
1054 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
1055 "is_copy", "is_load_immediate", "has_side_effects")
1058 self
, demo_asm
, # type: str
1059 inputs
, # type: Iterable[GenericOperandDesc]
1060 outputs
, # type: Iterable[GenericOperandDesc]
1061 immediates
=(), # type: Iterable[range]
1062 is_copy
=False, # type: bool
1063 is_load_immediate
=False, # type: bool
1064 has_side_effects
=False, # type: bool
1066 # type: (...) -> None
1067 self
.demo_asm
= demo_asm
# type: str
1068 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
1069 for inp
in self
.inputs
:
1070 if inp
.tied_input_index
is not None:
1072 f
"tied_input_index is not allowed on inputs: {inp}")
1073 if inp
.write_stage
is not OpStage
.Early
:
1075 f
"write_stage is not allowed on inputs: {inp}")
1076 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
1077 fixed_locs
= [] # type: list[tuple[Loc, int]]
1078 for idx
, out
in enumerate(self
.outputs
):
1079 if out
.tied_input_index
is not None:
1080 if out
.tied_input_index
>= len(self
.inputs
):
1081 raise ValueError(f
"tied_input_index out of range: {out}")
1082 tied_inp
= self
.inputs
[out
.tied_input_index
]
1083 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
1084 .with_write_stage(out
.write_stage
)
1085 if expected_out
!= out
:
1086 raise ValueError(f
"output can't be tied to non-equivalent "
1087 f
"input: {out} tied to {tied_inp}")
1088 if out
.fixed_loc
is not None:
1089 for other_fixed_loc
, other_idx
in fixed_locs
:
1090 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
1093 f
"conflicting fixed_locs: outputs[{idx}] and "
1094 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
1095 f
"with {other_fixed_loc}")
1096 fixed_locs
.append((out
.fixed_loc
, idx
))
1097 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
1098 self
.is_copy
= is_copy
# type: bool
1099 self
.is_load_immediate
= is_load_immediate
# type: bool
1100 self
.has_side_effects
= has_side_effects
# type: bool
1103 @plain_data(frozen
=True, unsafe_hash
=True)
1105 class OpProperties(metaclass
=InternedMeta
):
1106 __slots__
= "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """Instantiate `kind`'s generic operand descriptors for `maxvl`.

    Every generic input/output descriptor is expanded through its
    `instantiate(maxvl=...)`, and the results are stored flattened, in
    order, as the `inputs` / `outputs` tuples. The summed register lengths
    of the first `copy_inputs_len` inputs and the first `copy_outputs_len`
    outputs must agree; that common total is stored as `copy_reg_len`.

    Raises ValueError when the two totals differ.
    """
    # set first; the `generic` lookups below presumably go through
    # `self.kind` -- TODO confirm
    self.kind = kind  # type: OpKind
    self.inputs = tuple(
        desc
        for generic_inp in self.generic.inputs
        for desc in generic_inp.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.outputs = tuple(
        desc
        for generic_out in self.generic.outputs
        for desc in generic_out.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
    in_reg_len = sum(
        inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
    out_reg_len = sum(
        out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
    if in_reg_len != out_reg_len:
        raise ValueError(f"invalid copy: copy's input reg len must "
                         f"match its output reg len: "
                         f"{in_reg_len} != {out_reg_len}")
    self.copy_reg_len = in_reg_len
1134 # type: () -> GenericOpProperties
1135 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """Return the `immediates` of the generic properties, unchanged."""
    generic = self.generic
    return generic.immediates
1145 return self
.generic
.demo_asm
1150 return self
.generic
.is_copy
1153 def is_load_immediate(self
):
1155 return self
.generic
.is_load_immediate
1158 def has_side_effects(self
):
1160 return self
.generic
.has_side_effects
1163 def copy_inputs_len(self
):
1165 if not self
.is_copy
:
1167 if self
.inputs
[0].spread_index
is None:
1170 for i
, inp
in enumerate(self
.inputs
):
1171 if inp
.spread_index
!= i
:
1177 def copy_outputs_len(self
):
1179 if not self
.is_copy
:
1181 if self
.outputs
[0].spread_index
is None:
1184 for i
, out
in enumerate(self
.outputs
):
1185 if out
.spread_index
!= i
:
# every value representable as a signed 16-bit integer: -32768 ..= 32767
IMM_S16 = range(-1 << 15, 1 << 15)

# signature of a function that simulates one `Op` against a simulator state
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-argument callable that returns a `_SIM_FN` when called
_SIM_FN2 = Callable[[], _SIM_FN]
# registry: maps an op kind's `GenericOpProperties` to its `_SIM_FN2`
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# signature of a function that writes the assembly for one `Op`
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-argument callable that returns a `_GEN_ASM_FN` when called
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry: maps an op kind's `GenericOpProperties` to its `_GEN_ASM_FN2`
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1204 def __init__(self
, properties
):
1205 # type: (GenericOpProperties) -> None
1207 self
.__properties
= properties
1210 def properties(self
):
1211 # type: () -> GenericOpProperties
1212 return self
.__properties
1214 def instantiate(self
, maxvl
):
1215 # type: (int) -> OpProperties
1216 return OpProperties(self
, maxvl
=maxvl
)
1220 return "OpKind." + self
._name
_
1224 # type: () -> _SIM_FN
1225 return _SIM_FNS
[self
.properties
]()
1229 # type: () -> _GEN_ASM_FN
1230 return _GEN_ASMS
[self
.properties
]()
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `ClearCA`: store a cleared carry in the op's only output."""
    ca_out = op.outputs[0]
    state[ca_out] = (False,)
def __clearca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the fixed instruction used for `ClearCA`; `op` is not read."""
    asm_line = "addic 0, 0, 0"
    state.writeln(asm_line)
1241 ClearCA
= GenericOpProperties(
1242 demo_asm
="addic 0, 0, 0",
1244 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1246 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1247 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `SetCA`: store a set carry in the op's only output."""
    ca_out = op.outputs[0]
    state[ca_out] = (True,)
def __setca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the fixed instruction used for `SetCA`; `op` is not read."""
    asm_line = "subfc 0, 0, 0"
    state.writeln(asm_line)
1258 SetCA
= GenericOpProperties(
1259 demo_asm
="subfc 0, 0, 0",
1261 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1263 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1264 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1267 def __svadde_sim(op
, state
):
1268 # type: (Op, BaseSimState) -> None
1269 RA
= state
[op
.input_vals
[0]]
1270 RB
= state
[op
.input_vals
[1]]
1271 carry
, = state
[op
.input_vals
[2]]
1272 VL
, = state
[op
.input_vals
[3]]
1273 RT
= [] # type: list[int]
1275 v
= RA
[i
] + RB
[i
] + carry
1276 RT
.append(v
& GPR_VALUE_MASK
)
1277 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1278 state
[op
.outputs
[0]] = tuple(RT
)
1279 state
[op
.outputs
[1]] = carry
,
def __svadde_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.adde` instruction for `op`.

    The destination and both addends are formatted with `state.vgpr`; the
    carry and VL inputs do not appear in the assembly text.
    """
    dest = state.vgpr(op.outputs[0])
    lhs = state.vgpr(op.input_vals[0])
    rhs = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.adde {dest}, {lhs}, {rhs}")
1288 SvAddE
= GenericOpProperties(
1289 demo_asm
="sv.adde *RT, *RA, *RB",
1290 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1291 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1293 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1294 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
1297 def __addze_sim(op
, state
):
1298 # type: (Op, BaseSimState) -> None
1299 RA
, = state
[op
.input_vals
[0]]
1300 carry
, = state
[op
.input_vals
[1]]
1302 RT
= v
& GPR_VALUE_MASK
1303 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1304 state
[op
.outputs
[0]] = RT
,
1305 state
[op
.outputs
[1]] = carry
,
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the scalar `addze RT, RA` instruction for `op`.

    `AddZE` declares scalar operands (`OD_BASE_SGPR`) for both its register
    input and its register output, so the register names are formatted with
    `state.sgpr`, like the other scalar ops in this class, rather than with
    the vector formatter `state.vgpr`. The carry input/output does not
    appear in the assembly text.
    """
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")
1313 AddZE
= GenericOpProperties(
1314 demo_asm
="addze RT, RA",
1315 inputs
=[OD_BASE_SGPR
, OD_CA
],
1316 outputs
=[OD_BASE_SGPR
, OD_CA
.tied_to_input(1)],
1318 _SIM_FNS
[AddZE
] = lambda: OpKind
.__addze
_sim
1319 _GEN_ASMS
[AddZE
] = lambda: OpKind
.__addze
_gen
_asm
1322 def __svsubfe_sim(op
, state
):
1323 # type: (Op, BaseSimState) -> None
1324 RA
= state
[op
.input_vals
[0]]
1325 RB
= state
[op
.input_vals
[1]]
1326 carry
, = state
[op
.input_vals
[2]]
1327 VL
, = state
[op
.input_vals
[3]]
1328 RT
= [] # type: list[int]
1330 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
1331 RT
.append(v
& GPR_VALUE_MASK
)
1332 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1333 state
[op
.outputs
[0]] = tuple(RT
)
1334 state
[op
.outputs
[1]] = carry
,
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.subfe` instruction for `op`.

    The destination and the first two inputs are formatted with
    `state.vgpr`; the carry and VL inputs do not appear in the text.
    """
    dest = state.vgpr(op.outputs[0])
    first = state.vgpr(op.input_vals[0])
    second = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.subfe {dest}, {first}, {second}")
1343 SvSubFE
= GenericOpProperties(
1344 demo_asm
="sv.subfe *RT, *RA, *RB",
1345 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1346 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1348 _SIM_FNS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_sim
1349 _GEN_ASMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_gen
_asm
1352 def __svandvs_sim(op
, state
):
1353 # type: (Op, BaseSimState) -> None
1354 RA
= state
[op
.input_vals
[0]]
1355 RB
, = state
[op
.input_vals
[1]]
1356 VL
, = state
[op
.input_vals
[2]]
1357 RT
= [] # type: list[int]
1359 RT
.append(RA
[i
] & RB
& GPR_VALUE_MASK
)
1360 state
[op
.outputs
[0]] = tuple(RT
)
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the vector-and-scalar `sv.and` instruction for `op`.

    The output and first input are formatted with `state.vgpr`, the second
    input with `state.sgpr`.
    """
    dest = state.vgpr(op.outputs[0])
    vec_src = state.vgpr(op.input_vals[0])
    scalar_src = state.sgpr(op.input_vals[1])
    state.writeln(f"sv.and {dest}, {vec_src}, {scalar_src}")
1369 SvAndVS
= GenericOpProperties(
1370 demo_asm
="sv.and *RT, *RA, RB",
1371 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1372 outputs
=[OD_EXTRA3_VGPR
],
1374 _SIM_FNS
[SvAndVS
] = lambda: OpKind
.__svandvs
_sim
1375 _GEN_ASMS
[SvAndVS
] = lambda: OpKind
.__svandvs
_gen
_asm
1378 def __svmaddedu_sim(op
, state
):
1379 # type: (Op, BaseSimState) -> None
1380 RA
= state
[op
.input_vals
[0]]
1381 RB
, = state
[op
.input_vals
[1]]
1382 carry
, = state
[op
.input_vals
[2]]
1383 VL
, = state
[op
.input_vals
[3]]
1384 RT
= [] # type: list[int]
1386 v
= RA
[i
] * RB
+ carry
1387 RT
.append(v
& GPR_VALUE_MASK
)
1388 carry
= v
>> GPR_SIZE_IN_BITS
1389 state
[op
.outputs
[0]] = tuple(RT
)
1390 state
[op
.outputs
[1]] = carry
,
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.maddedu` instruction for `op`.

    The output and first input are formatted with `state.vgpr`; the second
    and third inputs with `state.sgpr`.
    """
    dest = state.vgpr(op.outputs[0])
    vec_src = state.vgpr(op.input_vals[0])
    multiplier = state.sgpr(op.input_vals[1])
    carry_reg = state.sgpr(op.input_vals[2])
    state.writeln(
        f"sv.maddedu {dest}, {vec_src}, {multiplier}, {carry_reg}")
1400 SvMAddEDU
= GenericOpProperties(
1401 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1402 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1403 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1405 _SIM_FNS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_sim
1406 _GEN_ASMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_gen
_asm
1409 def __sradi_sim(op
, state
):
1410 # type: (Op, BaseSimState) -> None
1411 rs
, = state
[op
.input_vals
[0]]
1412 imm
= op
.immediates
[0]
1413 if rs
>= 1 << (GPR_SIZE_IN_BITS
- 1):
1414 rs
-= 1 << GPR_SIZE_IN_BITS
1416 RA
= v
& GPR_VALUE_MASK
1417 CA
= (RA
<< imm
) != rs
1418 state
[op
.outputs
[0]] = RA
,
1419 state
[op
.outputs
[1]] = CA
,
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sradi` instruction for `op`.

    Both registers are formatted with `state.sgpr`; the shift amount is the
    op's first immediate.
    """
    dest = state.sgpr(op.outputs[0])
    src = state.sgpr(op.input_vals[0])
    shift = op.immediates[0]
    state.writeln(f"sradi {dest}, {src}, {shift}")
1428 SRADI
= GenericOpProperties(
1429 demo_asm
="sradi RA, RS, imm",
1430 inputs
=[OD_BASE_SGPR
],
1431 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
),
1432 OD_CA
.with_write_stage(OpStage
.Late
)],
1433 immediates
=[range(GPR_SIZE_IN_BITS
)],
1435 _SIM_FNS
[SRADI
] = lambda: OpKind
.__sradi
_sim
1436 _GEN_ASMS
[SRADI
] = lambda: OpKind
.__sradi
_gen
_asm
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `SetVLI`: the output becomes the op's first immediate."""
    new_vl = op.immediates[0]
    state[op.outputs[0]] = (new_vl,)
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `setvl` instruction with the op's first immediate."""
    vl_imm = op.immediates[0]
    state.writeln(f"setvl 0, 0, {vl_imm}, 0, 1, 1")
1448 SetVLI
= GenericOpProperties(
1449 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1451 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1452 immediates
=[range(1, 65)],
1453 is_load_immediate
=True,
1455 _SIM_FNS
[SetVLI
] = lambda: OpKind
.__setvli
_sim
1456 _GEN_ASMS
[SetVLI
] = lambda: OpKind
.__setvli
_gen
_asm
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `SvLI`: fill the output vector with the masked immediate.

    The vector length is the single element of the first input's value;
    the immediate is reduced with `GPR_VALUE_MASK` and repeated that many
    times.
    """
    (vl,) = state[op.input_vals[0]]
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (masked_imm,) * vl
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.addi` instruction that loads the op's immediate.

    The destination is formatted with `state.vgpr`.
    """
    dest = state.vgpr(op.outputs[0])
    value = op.immediates[0]
    state.writeln(f"sv.addi {dest}, 0, {value}")
1471 SvLI
= GenericOpProperties(
1472 demo_asm
="sv.addi *RT, 0, imm",
1474 outputs
=[OD_EXTRA3_VGPR
],
1475 immediates
=[IMM_S16
],
1476 is_load_immediate
=True,
1478 _SIM_FNS
[SvLI
] = lambda: OpKind
.__svli
_sim
1479 _GEN_ASMS
[SvLI
] = lambda: OpKind
.__svli
_gen
_asm
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `LI`: the output becomes the immediate masked to GPR width."""
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = (masked_imm,)
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `addi` instruction that loads the op's immediate.

    The destination is formatted with `state.sgpr`.
    """
    dest = state.sgpr(op.outputs[0])
    value = op.immediates[0]
    state.writeln(f"addi {dest}, 0, {value}")
1493 LI
= GenericOpProperties(
1494 demo_asm
="addi RT, 0, imm",
1496 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1497 immediates
=[IMM_S16
],
1498 is_load_immediate
=True,
1500 _SIM_FNS
[LI
] = lambda: OpKind
.__li
_sim
1501 _GEN_ASMS
[LI
] = lambda: OpKind
.__li
_gen
_asm
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `VecCopyToReg`: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1509 def __copy_to_from_reg_gen_asm(src_loc
, dest_loc
, is_vec
, state
):
1510 # type: (Loc, Loc, bool, GenAsmState) -> None
1511 sv
= "sv." if is_vec
else ""
1513 if src_loc
.conflicts(dest_loc
) and src_loc
.start
< dest_loc
.start
:
1515 if src_loc
== dest_loc
:
1517 if src_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1518 raise ValueError(f
"invalid src_loc.kind: {src_loc.kind}")
1519 if dest_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1520 raise ValueError(f
"invalid dest_loc.kind: {dest_loc.kind}")
1521 if src_loc
.kind
is LocKind
.StackI64
:
1522 if dest_loc
.kind
is LocKind
.StackI64
:
1524 f
"can't copy from stack to stack: {src_loc} {dest_loc}")
1525 elif dest_loc
.kind
is not LocKind
.GPR
:
1526 assert_never(dest_loc
.kind
)
1527 src
= state
.stack(src_loc
)
1528 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1529 state
.writeln(f
"{sv}ld {dest}, {src}")
1530 elif dest_loc
.kind
is LocKind
.StackI64
:
1531 if src_loc
.kind
is not LocKind
.GPR
:
1532 assert_never(src_loc
.kind
)
1533 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1534 dest
= state
.stack(dest_loc
)
1535 state
.writeln(f
"{sv}std {src}, {dest}")
1536 elif src_loc
.kind
is LocKind
.GPR
:
1537 if dest_loc
.kind
is not LocKind
.GPR
:
1538 assert_never(dest_loc
.kind
)
1539 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1540 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1541 state
.writeln(f
"{sv}or{rev} {dest}, {src}, {src}")
1543 assert_never(src_loc
.kind
)
1546 def __veccopytoreg_gen_asm(op
, state
):
1547 # type: (Op, GenAsmState) -> None
1548 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1550 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1551 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1552 is_vec
=True, state
=state
)
1554 VecCopyToReg
= GenericOpProperties(
1555 demo_asm
="sv.mv dest, src",
1556 inputs
=[GenericOperandDesc(
1557 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1558 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1560 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1563 _SIM_FNS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_sim
1564 _GEN_ASMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_gen
_asm
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `VecCopyFromReg`: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1572 def __veccopyfromreg_gen_asm(op
, state
):
1573 # type: (Op, GenAsmState) -> None
1574 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1575 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1577 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1578 is_vec
=True, state
=state
)
1579 VecCopyFromReg
= GenericOpProperties(
1580 demo_asm
="sv.mv dest, src",
1581 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1582 outputs
=[GenericOperandDesc(
1583 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1584 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1585 write_stage
=OpStage
.Late
,
1589 _SIM_FNS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_sim
1590 _GEN_ASMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_gen
_asm
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `CopyToReg`: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1598 def __copytoreg_gen_asm(op
, state
):
1599 # type: (Op, GenAsmState) -> None
1600 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1602 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1603 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1604 is_vec
=False, state
=state
)
1605 CopyToReg
= GenericOpProperties(
1606 demo_asm
="mv dest, src",
1607 inputs
=[GenericOperandDesc(
1608 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1609 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1610 LocSubKind
.StackI64
],
1612 outputs
=[GenericOperandDesc(
1613 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1614 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1615 write_stage
=OpStage
.Late
,
1619 _SIM_FNS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_sim
1620 _GEN_ASMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_gen
_asm
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `CopyFromReg`: the output takes the first input's value."""
    copied = state[op.input_vals[0]]
    state[op.outputs[0]] = copied
1628 def __copyfromreg_gen_asm(op
, state
):
1629 # type: (Op, GenAsmState) -> None
1630 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1631 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1633 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1634 is_vec
=False, state
=state
)
1635 CopyFromReg
= GenericOpProperties(
1636 demo_asm
="mv dest, src",
1637 inputs
=[GenericOperandDesc(
1638 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1639 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1641 outputs
=[GenericOperandDesc(
1642 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1643 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1644 LocSubKind
.StackI64
],
1645 write_stage
=OpStage
.Late
,
1649 _SIM_FNS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_sim
1650 _GEN_ASMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_gen
_asm
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `Concat`: gather the scalar inputs into one vector value.

    Element 0 of every input's value except the last input is collected,
    in input order, into the tuple stored for the op's output.
    """
    elements = [state[src][0] for src in op.input_vals[:-1]]
    state[op.outputs[0]] = tuple(elements)
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the register copy that implements `Concat`.

    The source Loc is looked up for all inputs except the last one, the
    destination Loc for the single output; both are requested as GPR
    locations and handed to the shared vector copy helper.
    """
    source = state.loc(op.input_vals[0:-1], LocKind.GPR)
    destination = state.loc(op.outputs[0], LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)
1665 Concat
= GenericOpProperties(
1666 demo_asm
="sv.mv dest, src",
1667 inputs
=[GenericOperandDesc(
1668 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1669 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1672 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1675 _SIM_FNS
[Concat
] = lambda: OpKind
.__concat
_sim
1676 _GEN_ASMS
[Concat
] = lambda: OpKind
.__concat
_gen
_asm
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `Spread`: split the input vector into scalar outputs.

    Element `k` of the first input's value is stored, as a one-element
    tuple, in the op's `k`-th output.
    """
    elements = state[op.input_vals[0]]
    for position, element in enumerate(elements):
        state[op.outputs[position]] = (element,)
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the register copy that implements `Spread`.

    The source Loc is looked up for the first input, the destination Loc
    for all outputs together; both are requested as GPR locations and
    handed to the shared vector copy helper.
    """
    source = state.loc(op.input_vals[0], LocKind.GPR)
    destination = state.loc(op.outputs, LocKind.GPR)
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=source, dest_loc=destination, is_vec=True, state=state)
1691 Spread
= GenericOpProperties(
1692 demo_asm
="sv.mv dest, src",
1693 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1694 outputs
=[GenericOperandDesc(
1695 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1696 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1698 write_stage
=OpStage
.Late
,
1702 _SIM_FNS
[Spread
] = lambda: OpKind
.__spread
_sim
1703 _GEN_ASMS
[Spread
] = lambda: OpKind
.__spread
_gen
_asm
1706 def __svld_sim(op
, state
):
1707 # type: (Op, BaseSimState) -> None
1708 RA
, = state
[op
.input_vals
[0]]
1709 VL
, = state
[op
.input_vals
[1]]
1710 addr
= RA
+ op
.immediates
[0]
1711 RT
= [] # type: list[int]
1713 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
1714 RT
.append(v
& GPR_VALUE_MASK
)
1715 state
[op
.outputs
[0]] = tuple(RT
)
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.ld` instruction for `op`.

    The base address register is formatted with `state.sgpr`, the
    destination with `state.vgpr`; the offset is the op's first immediate.
    """
    base = state.sgpr(op.input_vals[0])
    dest = state.vgpr(op.outputs[0])
    offset = op.immediates[0]
    state.writeln(f"sv.ld {dest}, {offset}({base})")
1724 SvLd
= GenericOpProperties(
1725 demo_asm
="sv.ld *RT, imm(RA)",
1726 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1727 outputs
=[OD_EXTRA3_VGPR
],
1728 immediates
=[IMM_S16
],
1730 _SIM_FNS
[SvLd
] = lambda: OpKind
.__svld
_sim
1731 _GEN_ASMS
[SvLd
] = lambda: OpKind
.__svld
_gen
_asm
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `Ld`: load one value from memory into the output.

    The address is the first input's single element plus the op's first
    immediate; the value returned by `state.load` is masked with
    `GPR_VALUE_MASK` before being stored.
    """
    (base,) = state[op.input_vals[0]]
    address = base + op.immediates[0]
    loaded = state.load(address)
    state[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `ld` instruction for `op`.

    Both registers are formatted with `state.sgpr`; the offset is the op's
    first immediate.
    """
    base = state.sgpr(op.input_vals[0])
    dest = state.sgpr(op.outputs[0])
    offset = op.immediates[0]
    state.writeln(f"ld {dest}, {offset}({base})")
1748 Ld
= GenericOpProperties(
1749 demo_asm
="ld RT, imm(RA)",
1750 inputs
=[OD_BASE_SGPR
],
1751 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1752 immediates
=[IMM_S16
],
1754 _SIM_FNS
[Ld
] = lambda: OpKind
.__ld
_sim
1755 _GEN_ASMS
[Ld
] = lambda: OpKind
.__ld
_gen
_asm
1758 def __svstd_sim(op
, state
):
1759 # type: (Op, BaseSimState) -> None
1760 RS
= state
[op
.input_vals
[0]]
1761 RA
, = state
[op
.input_vals
[1]]
1762 VL
, = state
[op
.input_vals
[2]]
1763 addr
= RA
+ op
.immediates
[0]
1765 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `sv.std` instruction for `op`.

    The stored register is formatted with `state.vgpr`, the base address
    register with `state.sgpr`; the offset is the op's first immediate.
    """
    stored = state.vgpr(op.input_vals[0])
    base = state.sgpr(op.input_vals[1])
    offset = op.immediates[0]
    state.writeln(f"sv.std {stored}, {offset}({base})")
1774 SvStd
= GenericOpProperties(
1775 demo_asm
="sv.std *RS, imm(RA)",
1776 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1778 immediates
=[IMM_S16
],
1779 has_side_effects
=True,
1781 _SIM_FNS
[SvStd
] = lambda: OpKind
.__svstd
_sim
1782 _GEN_ASMS
[SvStd
] = lambda: OpKind
.__svstd
_gen
_asm
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate `Std`: store the first input's value to memory.

    The address is the second input's single element plus the op's first
    immediate.
    """
    (stored,) = state[op.input_vals[0]]
    (base,) = state[op.input_vals[1]]
    address = base + op.immediates[0]
    state.store(address, value=stored)
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Write the `std` instruction for `op`.

    Both registers are formatted with `state.sgpr`; the offset is the op's
    first immediate.
    """
    stored = state.sgpr(op.input_vals[0])
    base = state.sgpr(op.input_vals[1])
    offset = op.immediates[0]
    state.writeln(f"std {stored}, {offset}({base})")
1799 Std
= GenericOpProperties(
1800 demo_asm
="std RS, imm(RA)",
1801 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1803 immediates
=[IMM_S16
],
1804 has_side_effects
=True,
1806 _SIM_FNS
[Std
] = lambda: OpKind
.__std
_sim
1807 _GEN_ASMS
[Std
] = lambda: OpKind
.__std
_gen
_asm
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Do nothing: the output's value is set before simulation."""
    return None
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Do nothing: no instructions are needed for this op."""
    return None
1818 FuncArgR3
= GenericOpProperties(
1821 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1822 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1824 _SIM_FNS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_sim
1825 _GEN_ASMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_gen
_asm
1828 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1829 class SSAValOrUse(metaclass
=InternedMeta
):
1830 __slots__
= "op", "operand_idx"
1832 def __init__(self
, op
, operand_idx
):
1833 # type: (Op, int) -> None
1836 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1837 raise ValueError("invalid operand_idx")
1838 self
.operand_idx
= operand_idx
1847 def descriptor_array(self
):
1848 # type: () -> tuple[OperandDesc, ...]
1852 def defining_descriptor(self
):
1853 # type: () -> OperandDesc
1854 return self
.descriptor_array
[self
.operand_idx
]
1859 return self
.defining_descriptor
.ty
1862 def ty_before_spread(self
):
1864 return self
.defining_descriptor
.ty_before_spread
1868 # type: () -> BaseTy
1869 return self
.ty_before_spread
.base_ty
1872 def reg_offset_in_unspread(self
):
1873 """ the number of reg-sized slots in the unspread Loc before self's Loc
1875 e.g. if the unspread Loc containing self is:
1876 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1877 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1878 then reg_offset_into_unspread == 2 == 10 - 8
1880 return self
.defining_descriptor
.reg_offset_in_unspread
1883 def unspread_start_idx(self
):
1885 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1888 def unspread_start(self
):
1890 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1893 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1895 class SSAVal(SSAValOrUse
):
1900 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
1903 def def_loc_set_before_spread(self
):
1904 # type: () -> LocSet
1905 return self
.defining_descriptor
.loc_set_before_spread
1908 def descriptor_array(self
):
1909 # type: () -> tuple[OperandDesc, ...]
1910 return self
.op
.properties
.outputs
1913 def tied_input(self
):
1914 # type: () -> None | SSAUse
1915 if self
.defining_descriptor
.tied_input_index
is None:
1917 return SSAUse(op
=self
.op
,
1918 operand_idx
=self
.defining_descriptor
.tied_input_index
)
1921 def write_stage(self
):
1922 # type: () -> OpStage
1923 return self
.defining_descriptor
.write_stage
1926 def current_debugging_value(self
):
1927 # type: () -> tuple[int, ...]
1928 """ get the current value for debugging in pdb or similar.
1930 This is intended for use with
1931 `PreRASimState.set_current_debugging_state`.
1933 This is only intended for debugging, do not use in unit tests or
1936 return PreRASimState
.get_current_debugging_state()[self
]
1939 def ssa_val_sub_regs(self
):
1940 # type: () -> tuple[SSAValSubReg, ...]
1941 return tuple(SSAValSubReg(self
, i
) for i
in range(self
.ty
.reg_len
))
1944 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1946 class SSAUse(SSAValOrUse
):
1950 def use_loc_set_before_spread(self
):
1951 # type: () -> LocSet
1952 return self
.defining_descriptor
.loc_set_before_spread
1955 def descriptor_array(self
):
1956 # type: () -> tuple[OperandDesc, ...]
1957 return self
.op
.properties
.inputs
1961 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1965 # type: () -> SSAVal
1966 return self
.op
.input_vals
[self
.operand_idx
]
1969 def ssa_val(self
, ssa_val
):
1970 # type: (SSAVal) -> None
1971 self
.op
.input_vals
[self
.operand_idx
] = ssa_val
1975 _Desc
= TypeVar("_Desc")
1978 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
1980 def _verify_write_with_desc(self
, idx
, item
, desc
):
1981 # type: (int, _T | Any, _Desc) -> None
1982 raise NotImplementedError
1985 def _verify_write(self
, idx
, item
):
1986 # type: (int | Any, _T | Any) -> int
1987 if not isinstance(idx
, int):
1988 if isinstance(idx
, slice):
1990 f
"can't write to slice of {self.__class__.__name__}")
1991 raise TypeError(f
"can't write with index {idx!r}")
1992 # normalize idx, raising IndexError if it is out of range
1993 idx
= range(len(self
.descriptors
))[idx
]
1994 desc
= self
.descriptors
[idx
]
1995 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
1998 def _on_set(self
, idx
, new_item
, old_item
):
1999 # type: (int, _T, _T | None) -> None
2003 def _get_descriptors(self
):
2004 # type: () -> tuple[_Desc, ...]
2005 raise NotImplementedError
2009 def descriptors(self
):
2010 # type: () -> tuple[_Desc, ...]
2011 return self
._get
_descriptors
()
2018 def __init__(self
, items
, op
):
2019 # type: (Iterable[_T], Op) -> None
2022 self
.__items
= [] # type: list[_T]
2023 for idx
, item
in enumerate(items
):
2024 if idx
>= len(self
.descriptors
):
2025 raise ValueError("too many items")
2026 _
= self
._verify
_write
(idx
, item
)
2027 self
.__items
.append(item
)
2028 if len(self
.__items
) < len(self
.descriptors
):
2029 raise ValueError("not enough items")
2033 # type: () -> Iterator[_T]
2034 yield from self
.__items
2037 def __getitem__(self
, idx
):
2042 def __getitem__(self
, idx
):
2043 # type: (slice) -> list[_T]
2047 def __getitem__(self
, idx
):
2048 # type: (int | slice) -> _T | list[_T]
2049 return self
.__items
[idx
]
2052 def __setitem__(self
, idx
, item
):
2053 # type: (int, _T) -> None
2054 idx
= self
._verify
_write
(idx
, item
)
2055 self
.__items
[idx
] = item
2060 return len(self
.__items
)
2064 return f
"{self.__class__.__name__}({self.__items}, op=...)"
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The sequence of `SSAVal`s an `Op` takes as inputs.

    Each item is checked against the `OperandDesc` at the same index in the
    owning op's `properties.inputs`.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        """Return the owning op's input operand descriptors."""
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Check that `item` may be stored at `idx`.

        Raises:
            TypeError: if `item` is not an `SSAVal`.
            ValueError: if `item`'s type differs from the descriptor's type.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        """Forward the change of input `idx` to `SSAUses`."""
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """Create the input sequence for `op` from `items`.

        Raises:
            ValueError: if `op` already has its `input_vals` attribute set.
        """
        # `Op` stores this sequence in its `input_vals` slot (there is no
        # `inputs` attribute on `Op`), so that is the attribute to test --
        # mirroring `OpImmediates`, which tests `immediates`.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
class OpImmediates(OpInputSeq[int, range]):
    """The sequence of integer immediates of an `Op`.

    Each immediate must lie in the `range` at the same index in the owning
    op's `properties.immediates`.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """Create the immediates for `op`; refuse if `op` already has them."""
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        """Return the allowed range for every immediate of the owning op."""
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """Raise unless `item` is an `int` contained in the range `desc`."""
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
2112 @plain_data(frozen
=True, eq
=False, repr=False)
2115 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
2118 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
2119 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
2121 self
.properties
= properties
2122 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
2123 inputs_len
= len(self
.properties
.inputs
)
2124 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
2125 self
.immediates
= OpImmediates(immediates
, op
=self
)
2126 outputs_len
= len(self
.properties
.outputs
)
2127 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
2128 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
2132 # type: () -> OpKind
2133 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: an `Op` is only equal to itself.

    Returns `NotImplemented` when `other` is not an `Op`.
    """
    return self is other if isinstance(other, Op) else NotImplemented
2143 return object.__hash
__(self
)
2145 def __repr__(self
, wrap_width
=63, indent
=" "):
2146 # type: (int, str) -> str
2147 WRAP_POINT
= "\u200B" # zero-width space
2148 items
= [f
"{self.name}:\n"]
2149 for i
, out
in enumerate(self
.outputs
):
2150 item
= f
"<...outputs[{i}]: {out.ty}>"
2152 item
= "(" + WRAP_POINT
+ item
2153 if i
!= len(self
.outputs
) - 1:
2154 item
+= ", " + WRAP_POINT
2156 item
+= WRAP_POINT
+ ") <= "
2158 items
.append(self
.kind
._name
_)
2159 if len(self
.input_vals
) + len(self
.immediates
) != 0:
2161 items
[-1] += WRAP_POINT
2162 for i
, inp
in enumerate(self
.input_vals
):
2164 if i
!= len(self
.input_vals
) - 1 or len(self
.immediates
) != 0:
2165 item
+= ", " + WRAP_POINT
2167 item
+= ") " + WRAP_POINT
2169 for i
, imm
in enumerate(self
.immediates
):
2171 if i
!= len(self
.immediates
) - 1:
2172 item
+= ", " + WRAP_POINT
2174 item
+= ") " + WRAP_POINT
2176 lines
= [] # type: list[str]
2177 for i
, line_in
in enumerate("".join(items
).splitlines()):
2179 line_in
= indent
+ line_in
2181 for part
in line_in
.split(WRAP_POINT
):
2185 trial_line_out
= line_out
+ part
2186 if len(trial_line_out
.rstrip()) > wrap_width
:
2187 lines
.append(line_out
.rstrip())
2188 line_out
= indent
+ part
2190 line_out
= trial_line_out
2191 lines
.append(line_out
.rstrip())
2192 return "\n".join(lines
)
2194 def sim(self
, state
):
2195 # type: (BaseSimState) -> None
2196 for inp
in self
.input_vals
:
2200 raise ValueError(f
"SSAVal {inp} not yet assigned when "
2204 if len(val
) != inp
.ty
.reg_len
:
2206 f
"value of SSAVal {inp} has wrong number of elements: "
2207 f
"expected {inp.ty.reg_len} found "
2208 f
"{len(val)}: {val!r}")
2209 if isinstance(state
, PreRASimState
):
2210 for out
in self
.outputs
:
2211 if out
in state
.ssa_vals
:
2212 if self
.kind
is OpKind
.FuncArgR3
:
2214 raise ValueError(f
"SSAVal {out} already assigned before "
2217 self
.kind
.sim(self
, state
)
2220 for out
in self
.outputs
:
2224 raise ValueError(f
"running {self} failed to assign to {out}")
2227 if len(val
) != out
.ty
.reg_len
:
2229 f
"value of SSAVal {out} has wrong number of elements: "
2230 f
"expected {out.ty.reg_len} found "
2231 f
"{len(val)}: {val!r}")
def gen_asm(self, state):
    # type: (GenAsmState) -> None
    """Write the assembly for this op through `state`.

    `state.loc` is first called for every input and then every output,
    accepting any `LocKind`; the results are discarded (presumably this
    only checks that each value has a location -- confirm against
    `GenAsmState.loc`). The op kind's assembly generator then runs.
    """
    any_kind = tuple(LocKind)
    for ssa_vals in (self.input_vals, self.outputs):
        for ssa_val in ssa_vals:
            state.loc(ssa_val, expected_kinds=any_kind)
    self.kind.gen_asm(self, state)
2243 @plain_data(frozen
=True, repr=False)
2244 class BaseSimState(metaclass
=ABCMeta
):
2245 __slots__
= "memory",
2247 def __init__(self
, memory
):
2248 # type: (dict[int, int]) -> None
2250 self
.memory
= memory
# type: dict[int, int]
2252 def _default_memory_value(self
):
def on_skip(self, op):
    # type: (Op) -> None
    """Reject skipping `op`: this state always raises `ValueError`."""
    message = "skipping instructions not supported"
    raise ValueError(message)
2260 def load_byte(self
, addr
):
2261 # type: (int) -> int
2262 addr
&= GPR_VALUE_MASK
2264 return self
.memory
[addr
] & 0xFF
2266 return self
._default
_memory
_value
()
2268 def store_byte(self
, addr
, value
):
2269 # type: (int, int) -> None
2270 addr
&= GPR_VALUE_MASK
2272 self
.memory
[addr
] = value
2274 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
2275 # type: (int, int, bool) -> int
2276 if addr
% size_in_bytes
!= 0:
2277 raise ValueError(f
"address not aligned: {hex(addr)} "
2278 f
"required alignment: {size_in_bytes}")
2280 for i
in range(size_in_bytes
):
2281 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
2282 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
2283 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
    # type: (int, int, int) -> None
    """Store `value` at `addr` as `size_in_bytes` bytes, low byte first.

    Byte `k` of `value` (counting from the least significant) is written
    with `store_byte` to address `addr + k`.

    Raises:
        ValueError: if `addr` is not a multiple of `size_in_bytes`.
    """
    misaligned = addr % size_in_bytes != 0
    if misaligned:
        raise ValueError(f"address not aligned: {hex(addr)} "
                         f"required alignment: {size_in_bytes}")
    for byte_idx in range(size_in_bytes):
        shift = byte_idx * BITS_IN_BYTE
        self.store_byte(addr + byte_idx, (value >> shift) & 0xFF)
2294 def _memory__repr(self
):
2296 if len(self
.memory
) == 0:
2298 keys
= sorted(self
.memory
.keys(), reverse
=True)
2299 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
2300 items
= [] # type: list[str]
2301 while len(keys
) != 0:
2303 if (len(keys
) >= CHUNK_SIZE
2304 and addr
% CHUNK_SIZE
== 0
2305 and keys
[-CHUNK_SIZE
:]
2306 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
2307 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
2308 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
2309 keys
[-CHUNK_SIZE
:] = ()
2311 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
2313 return f
"{{{items[0]}}}"
2314 items_str
= ",\n".join(items
)
2315 return f
"{{\n{items_str}}}"
2319 field_vals
= [] # type: list[str]
2320 for name
in fields(self
):
2322 value
= getattr(self
, name
)
2323 except AttributeError:
2324 field_vals
.append(f
"{name}=<not set>")
2326 repr_fn
= getattr(self
, f
"_{name}__repr", None)
2327 if callable(repr_fn
):
2328 field_vals
.append(f
"{name}={repr_fn()}")
2330 field_vals
.append(f
"{name}={value!r}")
2331 field_vals_str
= ", ".join(field_vals
)
2332 return f
"{self.__class__.__name__}({field_vals_str})"
2335 def __getitem__(self
, ssa_val
):
2336 # type: (SSAVal) -> tuple[int, ...]
2340 def __setitem__(self
, ssa_val
, value
):
2341 # type: (SSAVal, tuple[int, ...]) -> None
2345 @plain_data(frozen
=True, repr=False)
2346 class PreRABaseSimState(BaseSimState
):
2347 __slots__
= "ssa_vals",
def __init__(self, ssa_vals, memory):
    # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
    """Set up the state from an SSA-value table and a memory dict.

    `memory` is handed to the base-class initializer; `ssa_vals` is stored
    as-is (by reference, no copy is made).
    """
    super().__init__(memory)
    # maps each SSAVal to the tuple of element values assigned to it
    self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
2354 def _ssa_vals__repr(self
):
2356 if len(self
.ssa_vals
) == 0:
2358 items
= [] # type: list[str]
2360 for k
, v
in self
.ssa_vals
.items():
2361 element_strs
= [] # type: list[str]
2362 for i
, el
in enumerate(v
):
2363 if i
% CHUNK_SIZE
!= 0:
2364 element_strs
.append(" " + hex(el
))
2366 element_strs
.append("\n " + hex(el
))
2367 if len(element_strs
) <= CHUNK_SIZE
:
2368 element_strs
[0] = element_strs
[0].lstrip()
2369 if len(element_strs
) == 1:
2370 element_strs
.append("")
2371 v_str
= ",".join(element_strs
)
2372 items
.append(f
"{k!r}: ({v_str})")
2373 if len(items
) == 1 and "\n" not in items
[0]:
2374 return f
"{{{items[0]}}}"
2375 items_str
= ",\n".join(items
)
2376 return f
"{{\n{items_str},\n}}"
2378 def __getitem__(self
, ssa_val
):
2379 # type: (SSAVal) -> tuple[int, ...]
2381 return self
.ssa_vals
[ssa_val
]
2383 return self
._handle
_undefined
_ssa
_val
(ssa_val
)
def _handle_undefined_ssa_val(self, ssa_val):
    # type: (SSAVal) -> tuple[int, ...]
    """Fallback used by `__getitem__` for an SSAVal that has no value.

    This implementation always raises `KeyError` whose args are the
    message and the offending `ssa_val`.
    """
    reason = "SSAVal has no value set"
    raise KeyError(reason, ssa_val)
def __setitem__(self, ssa_val, value):
    # type: (SSAVal, tuple[int, ...]) -> None
    """Assign `value` to `ssa_val` in `self.ssa_vals`.

    Raises `ValueError` when the number of elements in `value` differs
    from `ssa_val.ty.reg_len`; nothing is stored in that case.
    """
    actual_len = len(value)
    expected_len = ssa_val.ty.reg_len
    if actual_len == expected_len:
        self.ssa_vals[ssa_val] = value
        return
    raise ValueError("value has wrong len")
2396 class SimSkipOp(Exception):
2400 @plain_data(frozen
=True, repr=False)
2402 class ConstPropagationState(PreRABaseSimState
):
2403 __slots__
= "skipped_ops",
def __init__(self, ssa_vals, memory, skipped_ops):
    # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
    """Set up the state from SSA values, memory and a skipped-op set.

    `ssa_vals` and `memory` are forwarded to the base-class initializer;
    `skipped_ops` is stored as-is (by reference) and is the collection
    that `on_skip` adds to.
    """
    super().__init__(ssa_vals, memory)
    self.skipped_ops = skipped_ops
2410 def _default_memory_value(self
):
2414 def _handle_undefined_ssa_val(self
, ssa_val
):
2415 # type: (SSAVal) -> tuple[int, ...]
def on_skip(self, op):
    # type: (Op) -> None
    """Record `op` as skipped by adding it to `self.skipped_ops`."""
    skipped = self.skipped_ops
    skipped.add(op)
2423 @plain_data(frozen
=True, repr=False)
2424 class PreRASimState(PreRABaseSimState
):
2427 __CURRENT_DEBUGGING_STATE
= [] # type: list[PreRASimState]
2430 def set_as_current_debugging_state(self
):
2431 """ return a context manager that sets self as the current state for
2432 debugging in pdb or similar. This is intended only for use with
2433 `get_current_debugging_state` which should not be used in unit tests
2437 PreRASimState
.__CURRENT
_DEBUGGING
_STATE
.append(self
)
2440 assert self
is PreRASimState
.__CURRENT
_DEBUGGING
_STATE
.pop(), \
2441 "inconsistent __CURRENT_DEBUGGING_STATE"
def get_current_debugging_state():
    # type: () -> PreRASimState
    """Return the most recently registered debugging state.

    States are registered by `set_as_current_debugging_state`; the last
    one pushed (and not yet popped) is returned.

    This is meant for interactive debugging (pdb or similar) only -- do
    not use it in unit tests.

    Raises `ValueError` when no state is currently registered.
    """
    registered = PreRASimState.__CURRENT_DEBUGGING_STATE
    if not registered:
        raise ValueError("no current debugging state")
    # the list is used as a stack: the newest entry is at the end
    return registered[-1]
2458 @plain_data(frozen
=True, repr=False)
2460 class PostRASimState(BaseSimState
):
2461 __slots__
= "ssa_val_to_loc_map", "loc_values"
2463 def __init__(self
, ssa_val_to_loc_map
, memory
, loc_values
):
2464 # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
2465 super().__init
__(memory
)
2466 self
.ssa_val_to_loc_map
= FMap(ssa_val_to_loc_map
)
2467 for ssa_val
, loc
in self
.ssa_val_to_loc_map
.items():
2468 if ssa_val
.ty
!= loc
.ty
:
2470 f
"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
2471 self
.loc_values
= loc_values
2472 for loc
in self
.loc_values
.keys():
2473 if loc
.reg_len
!= 1:
2475 "loc_values must only contain Locs with reg_len=1, all "
2476 "larger Locs will be split into reg_len=1 sub-Locs")
2478 def _loc_values__repr(self
):
2480 locs
= sorted(self
.loc_values
.keys(),
2481 key
=lambda v
: (v
.kind
.name
, v
.start
))
2482 items
= [] # type: list[str]
2484 items
.append(f
"{loc}: 0x{self.loc_values[loc]:x}")
2485 items_str
= ",\n".join(items
)
2486 return f
"{{\n{items_str},\n}}"
def __getitem__(self, ssa_val):
    # type: (SSAVal) -> tuple[int, ...]
    """Read the value of `ssa_val` out of `self.loc_values`.

    The Loc assigned to `ssa_val` is split into `reg_len` sub-Locs of
    `reg_len=1`; the result has one element per sub-Loc, in offset order.
    A sub-Loc with no entry in `self.loc_values` reads as 0.

    A missing `ssa_val` propagates whatever the lookup in
    `self.ssa_val_to_loc_map` raises.
    """
    whole_loc = self.ssa_val_to_loc_map[ssa_val]
    # type of a single-register piece of `whole_loc`
    unit_ty = Ty(base_ty=whole_loc.ty.base_ty, reg_len=1)
    return tuple(
        self.loc_values.get(
            whole_loc.get_subloc_at_offset(subloc_ty=unit_ty, offset=idx),
            0)
        for idx in range(whole_loc.reg_len))
def __setitem__(self, ssa_val, value):
    # type: (SSAVal, tuple[int, ...]) -> None
    """Write `value` into the Loc assigned to `ssa_val`.

    The Loc is split into `reg_len` sub-Locs of `reg_len=1` and element
    `i` of `value` is stored in `self.loc_values` under the sub-Loc at
    offset `i`.

    Raises `ValueError` when the number of elements in `value` differs
    from `ssa_val.ty.reg_len`; nothing is stored in that case.
    """
    actual_len = len(value)
    if actual_len != ssa_val.ty.reg_len:
        raise ValueError("value has wrong len")
    whole_loc = self.ssa_val_to_loc_map[ssa_val]
    # type of a single-register piece of `whole_loc`
    unit_ty = Ty(base_ty=whole_loc.ty.base_ty, reg_len=1)
    for idx in range(whole_loc.reg_len):
        target = whole_loc.get_subloc_at_offset(subloc_ty=unit_ty,
                                                offset=idx)
        self.loc_values[target] = value[idx]
2509 @plain_data(frozen
=True)
2511 __slots__
= "allocated_locs", "output"
2513 def __init__(self
, allocated_locs
, output
=None):
2514 # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
2516 self
.allocated_locs
= FMap(allocated_locs
)
2517 for ssa_val
, loc
in self
.allocated_locs
.items():
2518 if ssa_val
.ty
!= loc
.ty
:
2520 f
"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
2523 self
.output
= output
2525 __SSA_VAL_OR_LOCS
= Union
[SSAVal
, Loc
, Sequence
["SSAVal | Loc"]]
2527 def loc(self
, ssa_val_or_locs
, expected_kinds
):
2528 # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
2529 if isinstance(ssa_val_or_locs
, (SSAVal
, Loc
)):
2530 ssa_val_or_locs
= [ssa_val_or_locs
]
2531 locs
= [] # type: list[Loc]
2532 for i
in ssa_val_or_locs
:
2533 if isinstance(i
, SSAVal
):
2534 locs
.append(self
.allocated_locs
[i
])
2538 raise ValueError("invalid Loc sequence: must not be empty")
2539 retval
= locs
[0].try_concat(*locs
[1:])
2541 raise ValueError("invalid Loc sequence: try_concat failed")
2542 if isinstance(expected_kinds
, LocKind
):
2543 expected_kinds
= expected_kinds
,
2544 if retval
.kind
not in expected_kinds
:
2545 if len(expected_kinds
) == 1:
2546 expected_kinds
= expected_kinds
[0]
2547 raise ValueError(f
"LocKind mismatch: {ssa_val_or_locs}: found "
2548 f
"{retval.kind} expected {expected_kinds}")
def gpr(self, ssa_val_or_locs, is_vec):
    # type: (__SSA_VAL_OR_LOCS, bool) -> str
    """Return the operand text for a value that lives in a GPR.

    The location is resolved with `self.loc`, which requires it to be of
    kind `LocKind.GPR`.  The result is the location's `start` as a
    string, prefixed with `*` when `is_vec` is true.
    """
    gpr_loc = self.loc(ssa_val_or_locs, LocKind.GPR)
    if is_vec:
        return "*" + str(gpr_loc.start)
    return "" + str(gpr_loc.start)
def sgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the non-vector GPR operand text.

    Shorthand for `self.gpr(ssa_val_or_locs, is_vec=False)`.
    """
    operand = self.gpr(ssa_val_or_locs, is_vec=False)
    return operand
def vgpr(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the vector GPR operand text.

    Shorthand for `self.gpr(ssa_val_or_locs, is_vec=True)`.
    """
    operand = self.gpr(ssa_val_or_locs, is_vec=True)
    return operand
def stack(self, ssa_val_or_locs):
    # type: (__SSA_VAL_OR_LOCS) -> str
    """Return the operand text for a value that lives in a stack slot.

    The location is resolved with `self.loc`, which requires it to be of
    kind `LocKind.StackI64`.  The result is the location's `start`
    followed by the literal text `(1)`.
    """
    # NOTE(review): the "(1)" suffix presumably addresses relative to
    # register 1 -- confirm against the target's assembly syntax.
    stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
    slot_start = stack_loc.start
    return f"{slot_start}(1)"
2570 def writeln(self
, *line_segments
):
2571 # type: (*str) -> None
2572 line
= " ".join(line_segments
)
2573 if isinstance(self
.output
, list):
2574 self
.output
.append(line
)
2576 self
.output
.write(line
+ "\n")