1 from contextlib
import contextmanager
3 from abc
import ABCMeta
, abstractmethod
4 from enum
import Enum
, unique
5 from functools
import lru_cache
, total_ordering
6 from io
import StringIO
7 from typing
import (AbstractSet
, Any
, Callable
, Generic
, Iterable
, Iterator
,
8 Mapping
, Sequence
, TypeVar
, Union
, overload
)
9 from weakref
import WeakValueDictionary
as _WeakVDict
11 from cached_property
import cached_property
12 from nmutil
.plain_data
import fields
, plain_data
14 from bigint_presentation_code
.type_util
import (Literal
, Self
, assert_never
,
16 from bigint_presentation_code
.util
import (BitSet
, FBitSet
, FMap
, InternedMeta
,
# Width of one GPR in bits, derived from its size in bytes.
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# Bit mask with the low GPR_SIZE_IN_BITS bits set.
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
28 self
.ops
= [] # type: list[Op]
29 self
.__op
_names
= _WeakVDict() # type: _WeakVDict[str, Op]
30 self
.__next
_name
_suffix
= 2
32 def _add_op_with_unused_name(self
, op
, name
=""):
33 # type: (Op, str) -> str
35 raise ValueError("can't add Op to wrong Fn")
36 if hasattr(op
, "name"):
37 raise ValueError("Op already named")
40 if name
!= "" and name
not in self
.__op
_names
:
41 self
.__op
_names
[name
] = op
43 name
= orig_name
+ str(self
.__next
_name
_suffix
)
44 self
.__next
_name
_suffix
+= 1
50 def ops_to_str(self
, as_python_literal
=False, wrap_width
=63,
51 python_indent
=" ", indent
=" "):
52 # type: (bool, int, str, str) -> str
53 l
= [] # type: list[str]
55 l
.append(op
.__repr
__(wrap_width
=wrap_width
, indent
=indent
))
58 l
= [python_indent
+ "\""]
61 l
.append(f
"\\n\"\n{python_indent}\"")
64 elif ch
.isascii() and ch
.isprintable():
67 l
.append(repr(ch
).strip("\"'"))
70 empty_end
= f
"\"\n{python_indent}\"\""
71 if retval
.endswith(empty_end
):
72 retval
= retval
[:-len(empty_end
)]
75 def append_op(self
, op
):
78 raise ValueError("can't add Op to wrong Fn")
81 def append_new_op(self
, kind
, input_vals
=(), immediates
=(), name
="",
83 # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
84 retval
= Op(fn
=self
, properties
=kind
.instantiate(maxvl
=maxvl
),
85 input_vals
=input_vals
, immediates
=immediates
, name
=name
)
86 self
.append_op(retval
)
90 # type: (BaseSimState) -> None
94 def gen_asm(self
, state
):
95 # type: (GenAsmState) -> None
99 def pre_ra_insert_copies(self
):
101 orig_ops
= list(self
.ops
)
102 copied_outputs
= {} # type: dict[SSAVal, SSAVal]
103 setvli_outputs
= {} # type: dict[SSAVal, Op]
106 for i
in range(len(op
.input_vals
)):
107 inp
= copied_outputs
[op
.input_vals
[i
]]
108 if inp
.ty
.base_ty
is BaseTy
.I64
:
109 maxvl
= inp
.ty
.reg_len
110 if inp
.ty
.reg_len
!= 1:
111 setvl
= self
.append_new_op(
112 OpKind
.SetVLI
, immediates
=[maxvl
],
113 name
=f
"{op.name}.inp{i}.setvl")
114 vl
= setvl
.outputs
[0]
115 mv
= self
.append_new_op(
116 OpKind
.VecCopyToReg
, input_vals
=[inp
, vl
],
117 maxvl
=maxvl
, name
=f
"{op.name}.inp{i}.copy")
119 mv
= self
.append_new_op(
120 OpKind
.CopyToReg
, input_vals
=[inp
],
121 name
=f
"{op.name}.inp{i}.copy")
122 op
.input_vals
[i
] = mv
.outputs
[0]
123 elif inp
.ty
.base_ty
is BaseTy
.CA \
124 or inp
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
125 # all copies would be no-ops, so we don't need to copy,
126 # though we do need to rematerialize SetVLI ops right
128 if inp
in setvli_outputs
:
129 setvl
= self
.append_new_op(
131 immediates
=setvli_outputs
[inp
].immediates
,
132 name
=f
"{op.name}.inp{i}.setvl")
133 inp
= setvl
.outputs
[0]
134 op
.input_vals
[i
] = inp
136 assert_never(inp
.ty
.base_ty
)
138 for i
, out
in enumerate(op
.outputs
):
139 if op
.kind
is OpKind
.SetVLI
:
140 setvli_outputs
[out
] = op
141 if out
.ty
.base_ty
is BaseTy
.I64
:
142 maxvl
= out
.ty
.reg_len
143 if out
.ty
.reg_len
!= 1:
144 setvl
= self
.append_new_op(
145 OpKind
.SetVLI
, immediates
=[maxvl
],
146 name
=f
"{op.name}.out{i}.setvl")
147 vl
= setvl
.outputs
[0]
148 mv
= self
.append_new_op(
149 OpKind
.VecCopyFromReg
, input_vals
=[out
, vl
],
150 maxvl
=maxvl
, name
=f
"{op.name}.out{i}.copy")
152 mv
= self
.append_new_op(
153 OpKind
.CopyFromReg
, input_vals
=[out
],
154 name
=f
"{op.name}.out{i}.copy")
155 copied_outputs
[out
] = mv
.outputs
[0]
156 elif out
.ty
.base_ty
is BaseTy
.CA \
157 or out
.ty
.base_ty
is BaseTy
.VL_MAXVL
:
158 # all copies would be no-ops, so we don't need to copy
159 copied_outputs
[out
] = out
161 assert_never(out
.ty
.base_ty
)
168 value
: Literal
[0, 1] # type: ignore
170 def __new__(cls
, value
):
171 # type: (int) -> OpStage
173 if value
not in (0, 1):
174 raise ValueError("invalid value")
175 retval
= object.__new
__(cls
)
176 retval
._value
_ = value
180 """ early stage of Op execution, where all input reads occur.
181 all output writes with `write_stage == Early` occur here too, and therefore
182 conflict with input reads, telling the compiler that it can't share
183 that output's register with any inputs that the output isn't tied to.
185 All outputs, even unused outputs, can't share registers with any other
186 outputs, independent of `write_stage` settings.
189 """ late stage of Op execution, where all output writes with
190 `write_stage == Late` occur, and therefore don't conflict with input reads,
191 telling the compiler that any inputs can safely use the same register as
194 All outputs, even unused outputs, can't share registers with any other
195 outputs, independent of `write_stage` settings.
200 return f
"OpStage.{self._name_}"
def __lt__(self, other):
    # type: (OpStage | object) -> bool
    """Order stages by their integer `value`.

    Returns `NotImplemented` when `other` is not an `OpStage`, leaving it to
    Python to try the reflected comparison.
    """
    if not isinstance(other, OpStage):
        return NotImplemented
    return self.value < other.value
209 assert OpStage
.Early
< OpStage
.Late
, "early must be less than late"
212 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
215 class ProgramPoint(metaclass
=InternedMeta
):
216 __slots__
= "op_index", "stage"
218 def __init__(self
, op_index
, stage
):
219 # type: (int, OpStage) -> None
220 self
.op_index
= op_index
226 """ an integer representation of `self` such that it keeps ordering and
227 successor/predecessor relations.
229 return self
.op_index
* 2 + self
.stage
.value
def from_int_value(int_value):
    # type: (int) -> ProgramPoint
    """Build the `ProgramPoint` encoded by the integer `int_value`.

    The floor quotient of `int_value` by 2 becomes `op_index`; the remainder
    (0 or 1) is converted to the matching `OpStage`.
    """
    op_index = int_value // 2
    stage = OpStage(int_value % 2)
    return ProgramPoint(op_index=op_index, stage=stage)
def next(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions after `self`.

    `steps` is added to `self.int_value`, so a negative value moves
    backwards.
    """
    target_int_value = self.int_value + steps
    return ProgramPoint.from_int_value(target_int_value)
def prev(self, steps=1):
    # type: (int) -> ProgramPoint
    """Return the program point `steps` positions before `self`.

    Implemented as `self.next` with the step count negated.
    """
    backwards = -steps
    return self.next(steps=backwards)
def __lt__(self, other):
    # type: (ProgramPoint | Any) -> bool
    """Order program points by `op_index` first, then by `stage`.

    Returns `NotImplemented` when `other` is not a `ProgramPoint`.
    """
    if not isinstance(other, ProgramPoint):
        return NotImplemented
    if self.op_index == other.op_index:
        # same Op: the stage decides the ordering
        return self.stage < other.stage
    return self.op_index < other.op_index
255 return f
"<ops[{self.op_index}]:{self.stage._name_}>"
258 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
260 class ProgramRange(Sequence
[ProgramPoint
], metaclass
=InternedMeta
):
261 __slots__
= "start", "stop"
263 def __init__(self
, start
, stop
):
264 # type: (ProgramPoint, ProgramPoint) -> None
269 def int_value_range(self
):
271 return range(self
.start
.int_value
, self
.stop
.int_value
)
274 def from_int_value_range(int_value_range
):
275 # type: (range) -> ProgramRange
276 if int_value_range
.step
!= 1:
277 raise ValueError("int_value_range must have step == 1")
279 start
=ProgramPoint
.from_int_value(int_value_range
.start
),
280 stop
=ProgramPoint
.from_int_value(int_value_range
.stop
))
283 def __getitem__(self
, __idx
):
284 # type: (int) -> ProgramPoint
288 def __getitem__(self
, __idx
):
289 # type: (slice) -> ProgramRange
def __getitem__(self, __idx):
    # type: (int | slice) -> ProgramPoint | ProgramRange
    """Index or slice this range of program points.

    The lookup is delegated to a `range` over the integer values of
    `start`..`stop`. An integer result is converted back to a
    `ProgramPoint`; a sliced `range` is converted to a `ProgramRange`
    (`from_int_value_range` rejects steps other than 1 with `ValueError`).
    """
    int_values = range(self.start.int_value, self.stop.int_value)
    picked = int_values[__idx]
    if not isinstance(picked, int):
        return ProgramRange.from_int_value_range(picked)
    return ProgramPoint.from_int_value(picked)
301 return len(self
.int_value_range
)
304 # type: () -> Iterator[ProgramPoint]
305 return map(ProgramPoint
.from_int_value
, self
.int_value_range
)
309 start
= repr(self
.start
).lstrip("<").rstrip(">")
310 stop
= repr(self
.stop
).lstrip("<").rstrip(">")
311 return f
"<range:{start}..{stop}>"
314 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
316 class SSAValSubReg(metaclass
=InternedMeta
):
317 __slots__
= "ssa_val", "reg_idx"
def __init__(self, ssa_val, reg_idx):
    # type: (SSAVal, int) -> None
    """Store `ssa_val` and the index `reg_idx` of one of its registers.

    Raises `ValueError` unless `0 <= reg_idx < ssa_val.ty.reg_len`.
    """
    if not 0 <= reg_idx < ssa_val.ty.reg_len:
        raise ValueError("reg_idx out of range")
    self.ssa_val = ssa_val
    self.reg_idx = reg_idx
328 return f
"{self.ssa_val}[{self.reg_idx}]"
331 @plain_data(frozen
=True, eq
=False, repr=False)
334 __slots__
= ("fn", "uses", "op_indexes", "live_ranges", "live_at",
335 "def_program_ranges", "use_program_points",
336 "all_program_points")
338 def __init__(self
, fn
):
341 self
.op_indexes
= FMap((op
, idx
) for idx
, op
in enumerate(fn
.ops
))
342 self
.all_program_points
= ProgramRange(
343 start
=ProgramPoint(op_index
=0, stage
=OpStage
.Early
),
344 stop
=ProgramPoint(op_index
=len(fn
.ops
), stage
=OpStage
.Early
))
345 def_program_ranges
= {} # type: dict[SSAVal, ProgramRange]
346 use_program_points
= {} # type: dict[SSAUse, ProgramPoint]
347 uses
= {} # type: dict[SSAVal, OSet[SSAUse]]
348 live_range_stops
= {} # type: dict[SSAVal, ProgramPoint]
350 for use
in op
.input_uses
:
351 uses
[use
.ssa_val
].add(use
)
352 use_program_point
= self
.__get
_use
_program
_point
(use
)
353 use_program_points
[use
] = use_program_point
354 live_range_stops
[use
.ssa_val
] = max(
355 live_range_stops
[use
.ssa_val
], use_program_point
.next())
356 for out
in op
.outputs
:
358 def_program_range
= self
.__get
_def
_program
_range
(out
)
359 def_program_ranges
[out
] = def_program_range
360 live_range_stops
[out
] = def_program_range
.stop
361 self
.uses
= FMap((k
, OFSet(v
)) for k
, v
in uses
.items())
362 self
.def_program_ranges
= FMap(def_program_ranges
)
363 self
.use_program_points
= FMap(use_program_points
)
364 live_ranges
= {} # type: dict[SSAVal, ProgramRange]
365 live_at
= {i
: OSet
[SSAVal
]() for i
in self
.all_program_points
}
366 for ssa_val
in uses
.keys():
367 live_ranges
[ssa_val
] = live_range
= ProgramRange(
368 start
=self
.def_program_ranges
[ssa_val
].start
,
369 stop
=live_range_stops
[ssa_val
])
370 for program_point
in live_range
:
371 live_at
[program_point
].add(ssa_val
)
372 self
.live_ranges
= FMap(live_ranges
)
373 self
.live_at
= FMap((k
, OFSet(v
)) for k
, v
in live_at
.items())
374 self
.copies
# initialize
375 self
.const_ssa_vals
# initialize
376 self
.const_ssa_val_sub_regs
# initialize
def __get_def_program_range(self, ssa_val):
    # type: (SSAVal) -> ProgramRange
    """Return the range of program points over which `ssa_val` is defined.

    The range starts at the write stage of `ssa_val`'s defining descriptor
    within the Op that produces it.
    """
    write_stage = ssa_val.defining_descriptor.write_stage
    first_point = ProgramPoint(
        op_index=self.op_indexes[ssa_val.op], stage=write_stage)
    # The late stage of ssa_val.op is always included so that every output
    # overlaps all other outputs of the same Op.
    late_point = ProgramPoint(
        op_index=first_point.op_index, stage=OpStage.Late)
    # `stop` is exclusive, hence the program point after the late stage.
    return ProgramRange(start=first_point, stop=late_point.next())
389 def __get_use_program_point(self
, ssa_use
):
390 # type: (SSAUse) -> ProgramPoint
391 assert ssa_use
.defining_descriptor
.write_stage
is OpStage
.Early
, \
392 "assumed here, ensured by GenericOpProperties.__init__"
394 op_index
=self
.op_indexes
[ssa_use
.op
], stage
=OpStage
.Early
)
def __eq__(self, other):
    # type: (FnAnalysis | Any) -> bool
    """Two analyses are equal when their `fn` attributes compare equal.

    Returns `NotImplemented` when `other` is not a `FnAnalysis`.
    """
    if not isinstance(other, FnAnalysis):
        return NotImplemented
    return self.fn == other.fn
408 return "<FnAnalysis>"
412 # type: () -> FMap[SSAValSubReg, SSAValSubReg]
413 """ map from SSAValSubRegs to the original SSAValSubRegs that they are
414 a copy of, looking through all layers of copies. The map excludes all
415 SSAValSubRegs that aren't copies of other SSAValSubRegs.
417 retval
= {} # type: dict[SSAValSubReg, SSAValSubReg]
418 for op
in self
.op_indexes
.keys():
419 if not op
.properties
.is_copy
:
421 copy_reg_len
= op
.properties
.copy_reg_len
422 copy_inputs
= [] # type: list[SSAValSubReg]
423 for inp
in op
.input_vals
[:op
.properties
.copy_inputs_len
]:
424 for inp_sub_reg
in inp
.ssa_val_sub_regs
:
425 # propagate copies of copies
426 inp_sub_reg
= retval
.get(inp_sub_reg
, inp_sub_reg
)
427 copy_inputs
.append(inp_sub_reg
)
428 assert len(copy_inputs
) == copy_reg_len
, "logic error"
429 copy_outputs
= [] # type: list[SSAValSubReg]
430 for out
in op
.outputs
[:op
.properties
.copy_outputs_len
]:
431 copy_outputs
.extend(out
.ssa_val_sub_regs
)
432 assert len(copy_outputs
) == copy_reg_len
, "logic error"
433 for inp
, out
in zip(copy_inputs
, copy_outputs
):
438 def const_ssa_vals(self
):
439 # type: () -> FMap[SSAVal, tuple[int, ...]]
440 state
= ConstPropagationState(
441 ssa_vals
={}, memory
={}, skipped_ops
=OSet())
443 return FMap(state
.ssa_vals
)
446 def const_ssa_val_sub_regs(self
):
447 # type: () -> FMap[SSAValSubReg, int]
448 retval
= {} # type: dict[SSAValSubReg, int]
449 for ssa_val
, const_val
in self
.const_ssa_vals
.items():
450 assert ssa_val
.ty
.reg_len
== len(const_val
), "logic error"
451 for reg_idx
, v
in enumerate(const_val
):
452 retval
[SSAValSubReg(ssa_val
, reg_idx
)] = v
455 def is_always_equal(self
, a
, b
):
456 # type: (SSAValSubReg, SSAValSubReg) -> bool
457 """check if a and b are known to be always equal to each other.
458 This means they can be allocated to the same location if other
459 constraints don't prevent that.
461 this can happen for a number of reasons, such as:
462 * a and b are copies of the same thing
463 * a and b are known to be constants and they have the same value
465 if a
.ssa_val
.base_ty
!= b
.ssa_val
.base_ty
:
466 return False # can't be equal, they have different types
467 # look through copies
468 a
= self
.copies
.get(a
, a
)
469 b
= self
.copies
.get(b
, b
)
472 # check if they have the same constant value
474 a_const_val
= self
.const_ssa_val_sub_regs
[a
]
475 b_const_val
= self
.const_ssa_val_sub_regs
[b
]
476 if a_const_val
== b_const_val
:
488 VL_MAXVL
= enum
.auto()
491 def only_scalar(self
):
493 if self
is BaseTy
.I64
:
495 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
501 def max_reg_len(self
):
503 if self
is BaseTy
.I64
:
505 elif self
is BaseTy
.CA
or self
is BaseTy
.VL_MAXVL
:
511 return "BaseTy." + self
._name
_
514 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
516 class Ty(metaclass
=InternedMeta
):
517 __slots__
= "base_ty", "reg_len"
520 def validate(base_ty
, reg_len
):
521 # type: (BaseTy, int) -> str | None
522 """ return a string with the error if the combination is invalid,
523 otherwise return None
525 if base_ty
.only_scalar
and reg_len
!= 1:
526 return f
"can't create a vector of an only-scalar type: {base_ty}"
527 if reg_len
< 1 or reg_len
> base_ty
.max_reg_len
:
528 return "reg_len out of range"
531 def __init__(self
, base_ty
, reg_len
):
532 # type: (BaseTy, int) -> None
533 msg
= self
.validate(base_ty
=base_ty
, reg_len
=reg_len
)
535 raise ValueError(msg
)
536 self
.base_ty
= base_ty
537 self
.reg_len
= reg_len
541 if self
.reg_len
!= 1:
542 reg_len
= f
"*{self.reg_len}"
545 return f
"<{self.base_ty._name_}{reg_len}>"
552 StackI64
= enum
.auto()
554 VL_MAXVL
= enum
.auto()
559 if self
is LocKind
.GPR
or self
is LocKind
.StackI64
:
561 if self
is LocKind
.CA
:
563 if self
is LocKind
.VL_MAXVL
:
564 return BaseTy
.VL_MAXVL
571 if self
is LocKind
.StackI64
:
573 if self
is LocKind
.GPR
or self
is LocKind
.CA \
574 or self
is LocKind
.VL_MAXVL
:
575 return self
.base_ty
.max_reg_len
580 return "LocKind." + self
._name
_
585 class LocSubKind(Enum
):
586 BASE_GPR
= enum
.auto()
587 SV_EXTRA2_VGPR
= enum
.auto()
588 SV_EXTRA2_SGPR
= enum
.auto()
589 SV_EXTRA3_VGPR
= enum
.auto()
590 SV_EXTRA3_SGPR
= enum
.auto()
591 StackI64
= enum
.auto()
593 VL_MAXVL
= enum
.auto()
597 # type: () -> LocKind
598 # pyright fails typechecking when using `in` here:
599 # reported: https://github.com/microsoft/pyright/issues/4102
600 if self
in (LocSubKind
.BASE_GPR
, LocSubKind
.SV_EXTRA2_VGPR
,
601 LocSubKind
.SV_EXTRA2_SGPR
, LocSubKind
.SV_EXTRA3_VGPR
,
602 LocSubKind
.SV_EXTRA3_SGPR
):
604 if self
is LocSubKind
.StackI64
:
605 return LocKind
.StackI64
606 if self
is LocSubKind
.CA
:
608 if self
is LocSubKind
.VL_MAXVL
:
609 return LocKind
.VL_MAXVL
614 return self
.kind
.base_ty
617 def allocatable_locs(self
, ty
):
618 # type: (Ty) -> LocSet
619 if ty
.base_ty
!= self
.base_ty
:
620 raise ValueError("type mismatch")
621 if self
is LocSubKind
.BASE_GPR
:
623 elif self
is LocSubKind
.SV_EXTRA2_VGPR
:
624 starts
= range(0, 128, 2)
625 elif self
is LocSubKind
.SV_EXTRA2_SGPR
:
627 elif self
is LocSubKind
.SV_EXTRA3_VGPR \
628 or self
is LocSubKind
.SV_EXTRA3_SGPR
:
630 elif self
is LocSubKind
.StackI64
:
631 starts
= range(LocKind
.StackI64
.loc_count
)
632 elif self
is LocSubKind
.CA
or self
is LocSubKind
.VL_MAXVL
:
633 return LocSet([Loc(kind
=self
.kind
, start
=0, reg_len
=1)])
636 retval
= [] # type: list[Loc]
638 loc
= Loc
.try_make(kind
=self
.kind
, start
=start
, reg_len
=ty
.reg_len
)
642 for special_loc
in SPECIAL_GPRS
:
643 if loc
.conflicts(special_loc
):
648 return LocSet(retval
)
651 return "LocSubKind." + self
._name
_
654 @plain_data(frozen
=True, unsafe_hash
=True)
656 class GenericTy(metaclass
=InternedMeta
):
657 __slots__
= "base_ty", "is_vec"
659 def __init__(self
, base_ty
, is_vec
):
660 # type: (BaseTy, bool) -> None
661 self
.base_ty
= base_ty
662 if base_ty
.only_scalar
and is_vec
:
663 raise ValueError(f
"base_ty={base_ty} requires is_vec=False")
666 def instantiate(self
, maxvl
):
668 # here's where subvl and elwid would be accounted for
670 return Ty(self
.base_ty
, maxvl
)
671 return Ty(self
.base_ty
, 1)
673 def can_instantiate_to(self
, ty
):
675 if self
.base_ty
!= ty
.base_ty
:
679 return ty
.reg_len
== 1
682 @plain_data(frozen
=True, unsafe_hash
=True)
684 class Loc(metaclass
=InternedMeta
):
685 __slots__
= "kind", "start", "reg_len"
688 def validate(kind
, start
, reg_len
):
689 # type: (LocKind, int, int) -> str | None
690 msg
= Ty
.validate(base_ty
=kind
.base_ty
, reg_len
=reg_len
)
693 if reg_len
> kind
.loc_count
:
694 return "invalid reg_len"
695 if start
< 0 or start
+ reg_len
> kind
.loc_count
:
696 return "start not in valid range"
700 def try_make(kind
, start
, reg_len
):
701 # type: (LocKind, int, int) -> Loc | None
702 msg
= Loc
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
705 return Loc(kind
=kind
, start
=start
, reg_len
=reg_len
)
707 def __init__(self
, kind
, start
, reg_len
):
708 # type: (LocKind, int, int) -> None
709 msg
= self
.validate(kind
=kind
, start
=start
, reg_len
=reg_len
)
711 raise ValueError(msg
)
713 self
.reg_len
= reg_len
def conflicts(self, other):
    # type: (Loc) -> bool
    """Check whether `self` and `other` overlap.

    Locations conflict when they have the same `kind` and their half-open
    `start`..`stop` intervals intersect.
    """
    if not self.kind == other.kind:
        return False
    if not self.start < other.stop:
        return False
    return other.start < self.stop
def make_ty(kind, reg_len):
    # type: (LocKind, int) -> Ty
    """Return the `Ty` that pairs `kind.base_ty` with `reg_len`."""
    base_ty = kind.base_ty
    return Ty(base_ty=base_ty, reg_len=reg_len)
729 return self
.make_ty(kind
=self
.kind
, reg_len
=self
.reg_len
)
734 return self
.start
+ self
.reg_len
736 def try_concat(self
, *others
):
737 # type: (*Loc | None) -> Loc | None
738 reg_len
= self
.reg_len
741 if other
is None or other
.kind
!= self
.kind
:
743 if stop
!= other
.start
:
746 reg_len
+= other
.reg_len
747 return Loc(kind
=self
.kind
, start
=self
.start
, reg_len
=reg_len
)
def get_subloc_at_offset(self, subloc_ty, offset):
    # type: (Ty, int) -> Loc
    """Return the `Loc` of type `subloc_ty` located `offset` registers
    into `self`.

    Raises `ValueError` if `subloc_ty` has a different base type than
    `self.kind`, or if the requested sub-location doesn't fit inside `self`.
    """
    if subloc_ty.base_ty != self.kind.base_ty:
        raise ValueError("BaseTy mismatch")
    fits = offset >= 0 and offset + subloc_ty.reg_len <= self.reg_len
    if not fits:
        raise ValueError("invalid sub-Loc: offset and/or "
                         "subloc_ty.reg_len out of range")
    sub_start = self.start + offset
    return Loc(kind=self.kind, start=sub_start, reg_len=subloc_ty.reg_len)
761 Loc(kind
=LocKind
.GPR
, start
=0, reg_len
=1),
762 Loc(kind
=LocKind
.GPR
, start
=1, reg_len
=1),
763 Loc(kind
=LocKind
.GPR
, start
=2, reg_len
=1),
764 Loc(kind
=LocKind
.GPR
, start
=13, reg_len
=1),
769 class LocSet(OFSet
[Loc
], metaclass
=InternedMeta
):
770 def __init__(self
, __locs
=()):
771 # type: (Iterable[Loc]) -> None
772 super().__init
__(__locs
)
773 if isinstance(__locs
, LocSet
):
774 self
.__starts
= __locs
.starts
775 self
.__ty
= __locs
.ty
777 starts
= {i
: BitSet() for i
in LocKind
}
778 ty
= None # type: None | Ty
783 raise ValueError(f
"conflicting types: {ty} != {loc.ty}")
784 starts
[loc
.kind
].add(loc
.start
)
785 self
.__starts
= FMap(
786 (k
, FBitSet(v
)) for k
, v
in starts
.items() if len(v
) != 0)
791 # type: () -> FMap[LocKind, FBitSet]
796 # type: () -> Ty | None
801 # type: () -> FMap[LocKind, FBitSet]
806 (k
, FBitSet(bits
=v
.bits
<< sh
)) for k
, v
in self
.starts
.items())
810 # type: () -> AbstractSet[LocKind]
811 return self
.starts
.keys()
815 # type: () -> int | None
818 return self
.ty
.reg_len
822 # type: () -> BaseTy | None
825 return self
.ty
.base_ty
827 def concat(self
, *others
):
828 # type: (*LocSet) -> LocSet
831 base_ty
= self
.ty
.base_ty
832 reg_len
= self
.ty
.reg_len
833 starts
= {k
: BitSet(v
) for k
, v
in self
.starts
.items()}
837 if other
.ty
.base_ty
!= base_ty
:
839 for kind
, other_starts
in other
.starts
.items():
840 if kind
not in starts
:
842 starts
[kind
].bits
&= other_starts
.bits
>> reg_len
843 if starts
[kind
] == 0:
847 reg_len
+= other
.ty
.reg_len
850 # type: () -> Iterable[Loc]
851 for kind
, v
in starts
.items():
853 loc
= Loc
.try_make(kind
=kind
, start
=start
, reg_len
=reg_len
)
856 return LocSet(locs())
858 @lru_cache(maxsize
=None, typed
=True)
859 def max_conflicts_with(self
, other
):
860 # type: (LocSet | Loc) -> int
861 """the largest number of Locs in `self` that a single Loc
862 from `other` can conflict with
864 if isinstance(other
, LocSet
):
865 return max(self
.max_conflicts_with(i
) for i
in other
)
867 return sum(other
.conflicts(i
) for i
in self
)
870 return f
"LocSet(starts={self.starts!r}, ty={self.ty!r})"
873 @plain_data(frozen
=True, unsafe_hash
=True)
875 class GenericOperandDesc(metaclass
=InternedMeta
):
876 """generic Op operand descriptor"""
877 __slots__
= ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
881 self
, ty
, # type: GenericTy
882 sub_kinds
, # type: Iterable[LocSubKind]
884 fixed_loc
=None, # type: Loc | None
885 tied_input_index
=None, # type: int | None
886 spread
=False, # type: bool
887 write_stage
=OpStage
.Early
, # type: OpStage
889 # type: (...) -> None
891 self
.sub_kinds
= OFSet(sub_kinds
)
892 if len(self
.sub_kinds
) == 0:
893 raise ValueError("sub_kinds can't be empty")
894 self
.fixed_loc
= fixed_loc
895 if fixed_loc
is not None:
896 if tied_input_index
is not None:
897 raise ValueError("operand can't be both tied and fixed")
898 if not ty
.can_instantiate_to(fixed_loc
.ty
):
900 f
"fixed_loc has incompatible type for given generic "
901 f
"type: fixed_loc={fixed_loc} generic ty={ty}")
902 if len(self
.sub_kinds
) != 1:
904 "multiple sub_kinds not allowed for fixed operand")
905 for sub_kind
in self
.sub_kinds
:
906 if fixed_loc
not in sub_kind
.allocatable_locs(fixed_loc
.ty
):
908 f
"fixed_loc not in given sub_kind: "
909 f
"fixed_loc={fixed_loc} sub_kind={sub_kind}")
910 for sub_kind
in self
.sub_kinds
:
911 if sub_kind
.base_ty
!= ty
.base_ty
:
912 raise ValueError(f
"sub_kind is incompatible with type: "
913 f
"sub_kind={sub_kind} ty={ty}")
914 if tied_input_index
is not None and tied_input_index
< 0:
915 raise ValueError("invalid tied_input_index")
916 self
.tied_input_index
= tied_input_index
919 if self
.tied_input_index
is not None:
920 raise ValueError("operand can't be both spread and tied")
921 if self
.fixed_loc
is not None:
922 raise ValueError("operand can't be both spread and fixed")
924 raise ValueError("operand can't be both spread and vector")
925 self
.write_stage
= write_stage
928 def ty_before_spread(self
):
929 # type: () -> GenericTy
931 return GenericTy(base_ty
=self
.ty
.base_ty
, is_vec
=True)
def tied_to_input(self, tied_input_index):
    # type: (int) -> Self
    """Return a descriptor with `self`'s `ty`, `sub_kinds` and
    `write_stage`, tied to the input at `tied_input_index`.

    `fixed_loc` and `spread` are not forwarded, so the new descriptor gets
    the constructor's defaults for them.
    """
    ty = self.ty
    sub_kinds = self.sub_kinds
    return GenericOperandDesc(
        ty, sub_kinds,
        tied_input_index=tied_input_index,
        write_stage=self.write_stage)
def with_fixed_loc(self, fixed_loc):
    # type: (Loc) -> Self
    """Return a descriptor with `self`'s `ty`, `sub_kinds` and
    `write_stage`, pinned to the location `fixed_loc`.

    `tied_input_index` and `spread` are not forwarded, so the new descriptor
    gets the constructor's defaults for them.
    """
    ty = self.ty
    sub_kinds = self.sub_kinds
    return GenericOperandDesc(
        ty, sub_kinds,
        fixed_loc=fixed_loc,
        write_stage=self.write_stage)
945 def with_write_stage(self
, write_stage
):
946 # type: (OpStage) -> Self
947 return GenericOperandDesc(self
.ty
, self
.sub_kinds
,
948 fixed_loc
=self
.fixed_loc
,
949 tied_input_index
=self
.tied_input_index
,
951 write_stage
=write_stage
)
953 def instantiate(self
, maxvl
):
954 # type: (int) -> Iterable[OperandDesc]
955 # assumes all spread operands have ty.reg_len = 1
959 ty_before_spread
= self
.ty_before_spread
.instantiate(maxvl
=maxvl
)
961 def locs_before_spread():
962 # type: () -> Iterable[Loc]
963 if self
.fixed_loc
is not None:
964 if ty_before_spread
!= self
.fixed_loc
.ty
:
966 f
"instantiation failed: type mismatch with fixed_loc: "
967 f
"instantiated type: {ty_before_spread} "
968 f
"fixed_loc: {self.fixed_loc}")
971 for sub_kind
in self
.sub_kinds
:
972 yield from sub_kind
.allocatable_locs(ty_before_spread
)
973 loc_set_before_spread
= LocSet(locs_before_spread())
974 for idx
in range(rep_count
):
977 yield OperandDesc(loc_set_before_spread
=loc_set_before_spread
,
978 tied_input_index
=self
.tied_input_index
,
979 spread_index
=idx
, write_stage
=self
.write_stage
)
982 @plain_data(frozen
=True, unsafe_hash
=True)
984 class OperandDesc(metaclass
=InternedMeta
):
985 """Op operand descriptor"""
986 __slots__
= ("loc_set_before_spread", "tied_input_index", "spread_index",
989 def __init__(self
, loc_set_before_spread
, tied_input_index
, spread_index
,
991 # type: (LocSet, int | None, int | None, OpStage) -> None
992 if len(loc_set_before_spread
) == 0:
993 raise ValueError("loc_set_before_spread must not be empty")
994 self
.loc_set_before_spread
= loc_set_before_spread
995 self
.tied_input_index
= tied_input_index
996 if self
.tied_input_index
is not None and spread_index
is not None:
997 raise ValueError("operand can't be both spread and tied")
998 self
.spread_index
= spread_index
999 self
.write_stage
= write_stage
1002 def ty_before_spread(self
):
1004 ty
= self
.loc_set_before_spread
.ty
1005 assert ty
is not None, (
1006 "__init__ checked that the LocSet isn't empty, "
1007 "non-empty LocSets should always have ty set")
1012 """ Ty after any spread is applied """
1013 if self
.spread_index
is not None:
1014 # assumes all spread operands have ty.reg_len = 1
1015 return Ty(base_ty
=self
.ty_before_spread
.base_ty
, reg_len
=1)
1016 return self
.ty_before_spread
1019 def reg_offset_in_unspread(self
):
1020 """ the number of reg-sized slots in the unspread Loc before self's Loc
1022 e.g. if the unspread Loc containing self is:
1023 `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
1024 and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
1025 then reg_offset_in_unspread == 2 == 10 - 8
1027 if self
.spread_index
is None:
1029 return self
.spread_index
* self
.ty
.reg_len
# Predefined generic operand descriptors. Each one pairs a GenericTy with the
# single LocSubKind it may be allocated from.

# non-vector I64 operand in LocSubKind.BASE_GPR
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR],
)

# non-vector I64 operand in LocSubKind.SV_EXTRA3_SGPR
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
)

# vector I64 operand in LocSubKind.SV_EXTRA3_VGPR
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
)

# non-vector I64 operand in LocSubKind.SV_EXTRA2_SGPR
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
)

# vector I64 operand in LocSubKind.SV_EXTRA2_VGPR
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
)

# non-vector CA operand in LocSubKind.CA
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA],
)

# non-vector VL_MAXVL operand in LocSubKind.VL_MAXVL
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL],
)
1055 @plain_data(frozen
=True, unsafe_hash
=True)
1057 class GenericOpProperties(metaclass
=InternedMeta
):
1058 __slots__
= ("demo_asm", "inputs", "outputs", "immediates",
1059 "is_copy", "is_load_immediate", "has_side_effects")
1062 self
, demo_asm
, # type: str
1063 inputs
, # type: Iterable[GenericOperandDesc]
1064 outputs
, # type: Iterable[GenericOperandDesc]
1065 immediates
=(), # type: Iterable[range]
1066 is_copy
=False, # type: bool
1067 is_load_immediate
=False, # type: bool
1068 has_side_effects
=False, # type: bool
1070 # type: (...) -> None
1071 self
.demo_asm
= demo_asm
# type: str
1072 self
.inputs
= tuple(inputs
) # type: tuple[GenericOperandDesc, ...]
1073 for inp
in self
.inputs
:
1074 if inp
.tied_input_index
is not None:
1076 f
"tied_input_index is not allowed on inputs: {inp}")
1077 if inp
.write_stage
is not OpStage
.Early
:
1079 f
"write_stage is not allowed on inputs: {inp}")
1080 self
.outputs
= tuple(outputs
) # type: tuple[GenericOperandDesc, ...]
1081 fixed_locs
= [] # type: list[tuple[Loc, int]]
1082 for idx
, out
in enumerate(self
.outputs
):
1083 if out
.tied_input_index
is not None:
1084 if out
.tied_input_index
>= len(self
.inputs
):
1085 raise ValueError(f
"tied_input_index out of range: {out}")
1086 tied_inp
= self
.inputs
[out
.tied_input_index
]
1087 expected_out
= tied_inp
.tied_to_input(out
.tied_input_index
) \
1088 .with_write_stage(out
.write_stage
)
1089 if expected_out
!= out
:
1090 raise ValueError(f
"output can't be tied to non-equivalent "
1091 f
"input: {out} tied to {tied_inp}")
1092 if out
.fixed_loc
is not None:
1093 for other_fixed_loc
, other_idx
in fixed_locs
:
1094 if not other_fixed_loc
.conflicts(out
.fixed_loc
):
1097 f
"conflicting fixed_locs: outputs[{idx}] and "
1098 f
"outputs[{other_idx}]: {out.fixed_loc} conflicts "
1099 f
"with {other_fixed_loc}")
1100 fixed_locs
.append((out
.fixed_loc
, idx
))
1101 self
.immediates
= tuple(immediates
) # type: tuple[range, ...]
1102 self
.is_copy
= is_copy
# type: bool
1103 self
.is_load_immediate
= is_load_immediate
# type: bool
1104 self
.has_side_effects
= has_side_effects
# type: bool
1107 @plain_data(frozen
=True, unsafe_hash
=True)
1109 class OpProperties(metaclass
=InternedMeta
):
1110 __slots__
= "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
def __init__(self, kind, maxvl):
    # type: (OpKind, int) -> None
    """Instantiate the generic properties of `kind` for a given `maxvl`.

    Every generic input/output descriptor is instantiated with `maxvl` and
    the resulting `OperandDesc`s are stored, in order, in `self.inputs` and
    `self.outputs`.

    Raises `ValueError` if the summed `reg_len` of the first
    `copy_inputs_len` inputs differs from that of the first
    `copy_outputs_len` outputs.
    """
    self.kind = kind  # type: OpKind
    self.inputs = tuple(
        desc
        for generic_inp in self.generic.inputs
        for desc in generic_inp.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.outputs = tuple(
        desc
        for generic_out in self.generic.outputs
        for desc in generic_out.instantiate(maxvl=maxvl)
    )  # type: tuple[OperandDesc, ...]
    self.maxvl = maxvl  # type: int
    # total number of registers on each side of the copy
    copy_input_reg_len = sum(
        inp.ty.reg_len for inp in self.inputs[:self.copy_inputs_len])
    copy_output_reg_len = sum(
        out.ty.reg_len for out in self.outputs[:self.copy_outputs_len])
    if copy_input_reg_len != copy_output_reg_len:
        raise ValueError(f"invalid copy: copy's input reg len must "
                         f"match its output reg len: "
                         f"{copy_input_reg_len} != {copy_output_reg_len}")
    self.copy_reg_len = copy_input_reg_len
1138 # type: () -> GenericOpProperties
1139 return self
.kind
.properties
def immediates(self):
    # type: () -> tuple[range, ...]
    """The `immediates` tuple of ranges, forwarded unchanged from
    `self.generic`.
    """
    generic_props = self.generic
    return generic_props.immediates
1149 return self
.generic
.demo_asm
1154 return self
.generic
.is_copy
1157 def is_load_immediate(self
):
1159 return self
.generic
.is_load_immediate
1162 def has_side_effects(self
):
1164 return self
.generic
.has_side_effects
1167 def copy_inputs_len(self
):
1169 if not self
.is_copy
:
1171 if self
.inputs
[0].spread_index
is None:
1174 for i
, inp
in enumerate(self
.inputs
):
1175 if inp
.spread_index
!= i
:
1181 def copy_outputs_len(self
):
1183 if not self
.is_copy
:
1185 if self
.outputs
[0].spread_index
is None:
1188 for i
, out
in enumerate(self
.outputs
):
1189 if out
.spread_index
!= i
:
# range of a signed 16-bit immediate: -32768 .. 32767 inclusive
IMM_S16 = range(-1 << 15, 1 << 15)

# signature of a function that simulates one Op against a sim state
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# zero-argument callable that returns the simulation function
_SIM_FN2 = Callable[[], _SIM_FN]
# registry: properties object -> callable returning its simulation function
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# signature of a function that emits assembly for one Op
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# zero-argument callable that returns the assembly generator
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry: properties object -> callable returning its assembly generator
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1208 def __init__(self
, properties
):
1209 # type: (GenericOpProperties) -> None
1211 self
.__properties
= properties
def properties(self):
    # type: () -> GenericOpProperties
    """Return the GenericOpProperties stored by `__init__`."""
    return self.__properties
def instantiate(self, maxvl):
    # type: (int) -> OpProperties
    """Build the OpProperties of this kind for the given `maxvl`."""
    return OpProperties(self, maxvl=maxvl)
1224 return "OpKind." + self
._name
_
1228 # type: () -> _SIM_FN
1229 return _SIM_FNS
[self
.properties
]()
1233 # type: () -> _GEN_ASM_FN
1234 return _GEN_ASMS
[self
.properties
]()
def __clearca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate ClearCA: store the 1-tuple `(False,)` in output 0."""
    state[op.outputs[0]] = False,
def __clearca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the fixed instruction `addic 0, 0, 0`; `op` is not inspected."""
    state.writeln("addic 0, 0, 0")
1245 ClearCA
= GenericOpProperties(
1246 demo_asm
="addic 0, 0, 0",
1248 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1250 _SIM_FNS
[ClearCA
] = lambda: OpKind
.__clearca
_sim
1251 _GEN_ASMS
[ClearCA
] = lambda: OpKind
.__clearca
_gen
_asm
def __setca_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SetCA: store the 1-tuple `(True,)` in output 0."""
    state[op.outputs[0]] = True,
def __setca_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the fixed instruction `subfc 0, 0, 0`; `op` is not inspected."""
    state.writeln("subfc 0, 0, 0")
1262 SetCA
= GenericOpProperties(
1263 demo_asm
="subfc 0, 0, 0",
1265 outputs
=[OD_CA
.with_write_stage(OpStage
.Late
)],
1267 _SIM_FNS
[SetCA
] = lambda: OpKind
.__setca
_sim
1268 _GEN_ASMS
[SetCA
] = lambda: OpKind
.__setca
_gen
_asm
1271 def __svadde_sim(op
, state
):
1272 # type: (Op, BaseSimState) -> None
1273 RA
= state
[op
.input_vals
[0]]
1274 RB
= state
[op
.input_vals
[1]]
1275 carry
, = state
[op
.input_vals
[2]]
1276 VL
, = state
[op
.input_vals
[3]]
1277 RT
= [] # type: list[int]
1279 v
= RA
[i
] + RB
[i
] + carry
1280 RT
.append(v
& GPR_VALUE_MASK
)
1281 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1282 state
[op
.outputs
[0]] = tuple(RT
)
1283 state
[op
.outputs
[1]] = carry
,
def __svadde_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.adde` line for `op`.

    Output 0 and inputs 0 and 1 are formatted with `state.vgpr`; the
    remaining inputs of the op are not referenced in the emitted text.
    """
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.adde {RT}, {RA}, {RB}")
1292 SvAddE
= GenericOpProperties(
1293 demo_asm
="sv.adde *RT, *RA, *RB",
1294 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1295 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1297 _SIM_FNS
[SvAddE
] = lambda: OpKind
.__svadde
_sim
1298 _GEN_ASMS
[SvAddE
] = lambda: OpKind
.__svadde
_gen
_asm
1301 def __addze_sim(op
, state
):
1302 # type: (Op, BaseSimState) -> None
1303 RA
, = state
[op
.input_vals
[0]]
1304 carry
, = state
[op
.input_vals
[1]]
1306 RT
= v
& GPR_VALUE_MASK
1307 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1308 state
[op
.outputs
[0]] = RT
,
1309 state
[op
.outputs
[1]] = carry
,
def __addze_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one scalar `addze RT, RA` line for `op`.

    AddZE declares its register input and output as OD_BASE_SGPR, so both
    operands are formatted with `state.sgpr`, like the other scalar ops in
    this class (sradi, li, ld, std). The previous code used `state.vgpr`,
    the formatter the vector ops use for their vector operands.
    """
    RT = state.sgpr(op.outputs[0])
    RA = state.sgpr(op.input_vals[0])
    state.writeln(f"addze {RT}, {RA}")
1317 AddZE
= GenericOpProperties(
1318 demo_asm
="addze RT, RA",
1319 inputs
=[OD_BASE_SGPR
, OD_CA
],
1320 outputs
=[OD_BASE_SGPR
, OD_CA
.tied_to_input(1)],
1322 _SIM_FNS
[AddZE
] = lambda: OpKind
.__addze
_sim
1323 _GEN_ASMS
[AddZE
] = lambda: OpKind
.__addze
_gen
_asm
1326 def __svsubfe_sim(op
, state
):
1327 # type: (Op, BaseSimState) -> None
1328 RA
= state
[op
.input_vals
[0]]
1329 RB
= state
[op
.input_vals
[1]]
1330 carry
, = state
[op
.input_vals
[2]]
1331 VL
, = state
[op
.input_vals
[3]]
1332 RT
= [] # type: list[int]
1334 v
= (~RA
[i
] & GPR_VALUE_MASK
) + RB
[i
] + carry
1335 RT
.append(v
& GPR_VALUE_MASK
)
1336 carry
= (v
>> GPR_SIZE_IN_BITS
) != 0
1337 state
[op
.outputs
[0]] = tuple(RT
)
1338 state
[op
.outputs
[1]] = carry
,
def __svsubfe_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.subfe` line for `op`.

    Output 0 and inputs 0 and 1 are formatted with `state.vgpr`.
    """
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.vgpr(op.input_vals[1])
    state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
1347 SvSubFE
= GenericOpProperties(
1348 demo_asm
="sv.subfe *RT, *RA, *RB",
1349 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_VGPR
, OD_CA
, OD_VL
],
1350 outputs
=[OD_EXTRA3_VGPR
, OD_CA
.tied_to_input(2)],
1352 _SIM_FNS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_sim
1353 _GEN_ASMS
[SvSubFE
] = lambda: OpKind
.__svsubfe
_gen
_asm
1356 def __svandvs_sim(op
, state
):
1357 # type: (Op, BaseSimState) -> None
1358 RA
= state
[op
.input_vals
[0]]
1359 RB
, = state
[op
.input_vals
[1]]
1360 VL
, = state
[op
.input_vals
[2]]
1361 RT
= [] # type: list[int]
1363 RT
.append(RA
[i
] & RB
& GPR_VALUE_MASK
)
1364 state
[op
.outputs
[0]] = tuple(RT
)
def __svandvs_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.and` line for `op`.

    Output 0 and input 0 are formatted with `state.vgpr`, input 1 with
    `state.sgpr`.
    """
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    state.writeln(f"sv.and {RT}, {RA}, {RB}")
1373 SvAndVS
= GenericOpProperties(
1374 demo_asm
="sv.and *RT, *RA, RB",
1375 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1376 outputs
=[OD_EXTRA3_VGPR
],
1378 _SIM_FNS
[SvAndVS
] = lambda: OpKind
.__svandvs
_sim
1379 _GEN_ASMS
[SvAndVS
] = lambda: OpKind
.__svandvs
_gen
_asm
1382 def __svmaddedu_sim(op
, state
):
1383 # type: (Op, BaseSimState) -> None
1384 RA
= state
[op
.input_vals
[0]]
1385 RB
, = state
[op
.input_vals
[1]]
1386 carry
, = state
[op
.input_vals
[2]]
1387 VL
, = state
[op
.input_vals
[3]]
1388 RT
= [] # type: list[int]
1390 v
= RA
[i
] * RB
+ carry
1391 RT
.append(v
& GPR_VALUE_MASK
)
1392 carry
= v
>> GPR_SIZE_IN_BITS
1393 state
[op
.outputs
[0]] = tuple(RT
)
1394 state
[op
.outputs
[1]] = carry
,
def __svmaddedu_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.maddedu` line for `op`.

    Output 0 and input 0 are formatted with `state.vgpr`; inputs 1 and 2
    are formatted with `state.sgpr`.
    """
    RT = state.vgpr(op.outputs[0])
    RA = state.vgpr(op.input_vals[0])
    RB = state.sgpr(op.input_vals[1])
    RC = state.sgpr(op.input_vals[2])
    state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
1404 SvMAddEDU
= GenericOpProperties(
1405 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
1406 inputs
=[OD_EXTRA2_VGPR
, OD_EXTRA2_SGPR
, OD_EXTRA2_SGPR
, OD_VL
],
1407 outputs
=[OD_EXTRA3_VGPR
, OD_EXTRA2_SGPR
.tied_to_input(2)],
1409 _SIM_FNS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_sim
1410 _GEN_ASMS
[SvMAddEDU
] = lambda: OpKind
.__svmaddedu
_gen
_asm
1413 def __sradi_sim(op
, state
):
1414 # type: (Op, BaseSimState) -> None
1415 rs
, = state
[op
.input_vals
[0]]
1416 imm
= op
.immediates
[0]
1417 if rs
>= 1 << (GPR_SIZE_IN_BITS
- 1):
1418 rs
-= 1 << GPR_SIZE_IN_BITS
1420 RA
= v
& GPR_VALUE_MASK
1421 CA
= (RA
<< imm
) != rs
1422 state
[op
.outputs
[0]] = RA
,
1423 state
[op
.outputs
[1]] = CA
,
def __sradi_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sradi` line for `op`.

    Output 0 and input 0 are formatted with `state.sgpr`; the shift amount
    is the op's first immediate.
    """
    RA = state.sgpr(op.outputs[0])
    RS = state.sgpr(op.input_vals[0])
    imm = op.immediates[0]
    state.writeln(f"sradi {RA}, {RS}, {imm}")
1432 SRADI
= GenericOpProperties(
1433 demo_asm
="sradi RA, RS, imm",
1434 inputs
=[OD_BASE_SGPR
],
1435 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
),
1436 OD_CA
.with_write_stage(OpStage
.Late
)],
1437 immediates
=[range(GPR_SIZE_IN_BITS
)],
1439 _SIM_FNS
[SRADI
] = lambda: OpKind
.__sradi
_sim
1440 _GEN_ASMS
[SRADI
] = lambda: OpKind
.__sradi
_gen
_asm
def __setvli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SetVLI: store the first immediate, as a 1-tuple, in output 0."""
    state[op.outputs[0]] = op.immediates[0],
def __setvli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `setvl` line whose third operand is the op's first immediate."""
    imm = op.immediates[0]
    state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
1452 SetVLI
= GenericOpProperties(
1453 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
1455 outputs
=[OD_VL
.with_write_stage(OpStage
.Late
)],
1456 immediates
=[range(1, 65)],
1457 is_load_immediate
=True,
1459 _SIM_FNS
[SetVLI
] = lambda: OpKind
.__setvli
_sim
1460 _GEN_ASMS
[SetVLI
] = lambda: OpKind
.__setvli
_gen
_asm
def __svli_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate SvLI: fill output 0 with VL copies of the masked immediate.

    VL is the single element of input 0; the immediate is reduced with
    GPR_VALUE_MASK before being replicated.
    """
    (vector_len,) = state[op.input_vals[0]]
    masked_imm = op.immediates[0] & GPR_VALUE_MASK
    replicated = (masked_imm,) * vector_len
    state[op.outputs[0]] = replicated
def __svli_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.addi RT, 0, imm` line for `op`.

    Output 0 is formatted with `state.vgpr`; `imm` is the first immediate.
    """
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.addi {RT}, 0, {imm}")
1475 SvLI
= GenericOpProperties(
1476 demo_asm
="sv.addi *RT, 0, imm",
1478 outputs
=[OD_EXTRA3_VGPR
],
1479 immediates
=[IMM_S16
],
1480 is_load_immediate
=True,
1482 _SIM_FNS
[SvLI
] = lambda: OpKind
.__svli
_sim
1483 _GEN_ASMS
[SvLI
] = lambda: OpKind
.__svli
_gen
_asm
def __li_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate LI: store the first immediate, masked, as a 1-tuple."""
    # masking with GPR_VALUE_MASK maps a negative immediate to its
    # non-negative representation within the mask
    imm = op.immediates[0] & GPR_VALUE_MASK
    state[op.outputs[0]] = imm,
def __li_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `addi RT, 0, imm` line for `op`.

    Output 0 is formatted with `state.sgpr`; `imm` is the first immediate.
    """
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"addi {RT}, 0, {imm}")
1497 LI
= GenericOpProperties(
1498 demo_asm
="addi RT, 0, imm",
1500 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1501 immediates
=[IMM_S16
],
1502 is_load_immediate
=True,
1504 _SIM_FNS
[LI
] = lambda: OpKind
.__li
_sim
1505 _GEN_ASMS
[LI
] = lambda: OpKind
.__li
_gen
_asm
def __veccopytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyToReg: output 0 receives the value of input 0."""
    state[op.outputs[0]] = state[op.input_vals[0]]
1513 def __copy_to_from_reg_gen_asm(src_loc
, dest_loc
, is_vec
, state
):
1514 # type: (Loc, Loc, bool, GenAsmState) -> None
1515 sv
= "sv." if is_vec
else ""
1517 if src_loc
.conflicts(dest_loc
) and src_loc
.start
< dest_loc
.start
:
1519 if src_loc
== dest_loc
:
1521 if src_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1522 raise ValueError(f
"invalid src_loc.kind: {src_loc.kind}")
1523 if dest_loc
.kind
not in (LocKind
.GPR
, LocKind
.StackI64
):
1524 raise ValueError(f
"invalid dest_loc.kind: {dest_loc.kind}")
1525 if src_loc
.kind
is LocKind
.StackI64
:
1526 if dest_loc
.kind
is LocKind
.StackI64
:
1528 f
"can't copy from stack to stack: {src_loc} {dest_loc}")
1529 elif dest_loc
.kind
is not LocKind
.GPR
:
1530 assert_never(dest_loc
.kind
)
1531 src
= state
.stack(src_loc
)
1532 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1533 state
.writeln(f
"{sv}ld {dest}, {src}")
1534 elif dest_loc
.kind
is LocKind
.StackI64
:
1535 if src_loc
.kind
is not LocKind
.GPR
:
1536 assert_never(src_loc
.kind
)
1537 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1538 dest
= state
.stack(dest_loc
)
1539 state
.writeln(f
"{sv}std {src}, {dest}")
1540 elif src_loc
.kind
is LocKind
.GPR
:
1541 if dest_loc
.kind
is not LocKind
.GPR
:
1542 assert_never(dest_loc
.kind
)
1543 src
= state
.gpr(src_loc
, is_vec
=is_vec
)
1544 dest
= state
.gpr(dest_loc
, is_vec
=is_vec
)
1545 state
.writeln(f
"{sv}or{rev} {dest}, {src}, {src}")
1547 assert_never(src_loc
.kind
)
1550 def __veccopytoreg_gen_asm(op
, state
):
1551 # type: (Op, GenAsmState) -> None
1552 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1554 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1555 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1556 is_vec
=True, state
=state
)
1558 VecCopyToReg
= GenericOpProperties(
1559 demo_asm
="sv.mv dest, src",
1560 inputs
=[GenericOperandDesc(
1561 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1562 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1564 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1567 _SIM_FNS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_sim
1568 _GEN_ASMS
[VecCopyToReg
] = lambda: OpKind
.__veccopytoreg
_gen
_asm
def __veccopyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate VecCopyFromReg: output 0 receives the value of input 0."""
    state[op.outputs[0]] = state[op.input_vals[0]]
1576 def __veccopyfromreg_gen_asm(op
, state
):
1577 # type: (Op, GenAsmState) -> None
1578 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1579 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1581 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1582 is_vec
=True, state
=state
)
1583 VecCopyFromReg
= GenericOpProperties(
1584 demo_asm
="sv.mv dest, src",
1585 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1586 outputs
=[GenericOperandDesc(
1587 ty
=GenericTy(BaseTy
.I64
, is_vec
=True),
1588 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
, LocSubKind
.StackI64
],
1589 write_stage
=OpStage
.Late
,
1593 _SIM_FNS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_sim
1594 _GEN_ASMS
[VecCopyFromReg
] = lambda: OpKind
.__veccopyfromreg
_gen
_asm
def __copytoreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyToReg: output 0 receives the value of input 0."""
    state[op.outputs[0]] = state[op.input_vals[0]]
1602 def __copytoreg_gen_asm(op
, state
):
1603 # type: (Op, GenAsmState) -> None
1604 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1606 op
.input_vals
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1607 dest_loc
=state
.loc(op
.outputs
[0], LocKind
.GPR
),
1608 is_vec
=False, state
=state
)
1609 CopyToReg
= GenericOpProperties(
1610 demo_asm
="mv dest, src",
1611 inputs
=[GenericOperandDesc(
1612 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1613 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1614 LocSubKind
.StackI64
],
1616 outputs
=[GenericOperandDesc(
1617 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1618 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1619 write_stage
=OpStage
.Late
,
1623 _SIM_FNS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_sim
1624 _GEN_ASMS
[CopyToReg
] = lambda: OpKind
.__copytoreg
_gen
_asm
def __copyfromreg_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate CopyFromReg: output 0 receives the value of input 0."""
    state[op.outputs[0]] = state[op.input_vals[0]]
1632 def __copyfromreg_gen_asm(op
, state
):
1633 # type: (Op, GenAsmState) -> None
1634 OpKind
.__copy
_to
_from
_reg
_gen
_asm
(
1635 src_loc
=state
.loc(op
.input_vals
[0], LocKind
.GPR
),
1637 op
.outputs
[0], (LocKind
.GPR
, LocKind
.StackI64
)),
1638 is_vec
=False, state
=state
)
1639 CopyFromReg
= GenericOpProperties(
1640 demo_asm
="mv dest, src",
1641 inputs
=[GenericOperandDesc(
1642 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1643 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
],
1645 outputs
=[GenericOperandDesc(
1646 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1647 sub_kinds
=[LocSubKind
.SV_EXTRA3_SGPR
, LocSubKind
.BASE_GPR
,
1648 LocSubKind
.StackI64
],
1649 write_stage
=OpStage
.Late
,
1653 _SIM_FNS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_sim
1654 _GEN_ASMS
[CopyFromReg
] = lambda: OpKind
.__copyfromreg
_gen
_asm
def __concat_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Concat: pack element 0 of each input into output 0.

    The last input is excluded from the concatenation.
    """
    elements = []
    for scalar_val in op.input_vals[:-1]:
        elements.append(state[scalar_val][0])
    state[op.outputs[0]] = tuple(elements)
def __concat_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements Concat.

    The source location is looked up for all inputs except the last; the
    destination location is that of output 0. Both are requested as GPR
    locations and handed to the shared copy generator with is_vec=True.
    """
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
        dest_loc=state.loc(op.outputs[0], LocKind.GPR),
        is_vec=True, state=state)
1669 Concat
= GenericOpProperties(
1670 demo_asm
="sv.mv dest, src",
1671 inputs
=[GenericOperandDesc(
1672 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1673 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1676 outputs
=[OD_EXTRA3_VGPR
.with_write_stage(OpStage
.Late
)],
1679 _SIM_FNS
[Concat
] = lambda: OpKind
.__concat
_sim
1680 _GEN_ASMS
[Concat
] = lambda: OpKind
.__concat
_gen
_asm
def __spread_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Spread: element i of input 0 becomes 1-tuple output i."""
    vector_value = state[op.input_vals[0]]
    for out_idx, element in enumerate(vector_value):
        state[op.outputs[out_idx]] = (element,)
def __spread_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit the vector copy that implements Spread.

    The source location is that of input 0; the destination location is
    looked up for the whole `op.outputs` sequence. Both are requested as
    GPR locations and handed to the shared copy generator with is_vec=True.
    """
    OpKind.__copy_to_from_reg_gen_asm(
        src_loc=state.loc(op.input_vals[0], LocKind.GPR),
        dest_loc=state.loc(op.outputs, LocKind.GPR),
        is_vec=True, state=state)
1695 Spread
= GenericOpProperties(
1696 demo_asm
="sv.mv dest, src",
1697 inputs
=[OD_EXTRA3_VGPR
, OD_VL
],
1698 outputs
=[GenericOperandDesc(
1699 ty
=GenericTy(BaseTy
.I64
, is_vec
=False),
1700 sub_kinds
=[LocSubKind
.SV_EXTRA3_VGPR
],
1702 write_stage
=OpStage
.Late
,
1706 _SIM_FNS
[Spread
] = lambda: OpKind
.__spread
_sim
1707 _GEN_ASMS
[Spread
] = lambda: OpKind
.__spread
_gen
_asm
1710 def __svld_sim(op
, state
):
1711 # type: (Op, BaseSimState) -> None
1712 RA
, = state
[op
.input_vals
[0]]
1713 VL
, = state
[op
.input_vals
[1]]
1714 addr
= RA
+ op
.immediates
[0]
1715 RT
= [] # type: list[int]
1717 v
= state
.load(addr
+ GPR_SIZE_IN_BYTES
* i
)
1718 RT
.append(v
& GPR_VALUE_MASK
)
1719 state
[op
.outputs
[0]] = tuple(RT
)
def __svld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.ld RT, imm(RA)` line for `op`.

    Input 0 (the base address) is formatted with `state.sgpr`, output 0
    with `state.vgpr`; `imm` is the first immediate.
    """
    RA = state.sgpr(op.input_vals[0])
    RT = state.vgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"sv.ld {RT}, {imm}({RA})")
1728 SvLd
= GenericOpProperties(
1729 demo_asm
="sv.ld *RT, imm(RA)",
1730 inputs
=[OD_EXTRA3_SGPR
, OD_VL
],
1731 outputs
=[OD_EXTRA3_VGPR
],
1732 immediates
=[IMM_S16
],
1734 _SIM_FNS
[SvLd
] = lambda: OpKind
.__svld
_sim
1735 _GEN_ASMS
[SvLd
] = lambda: OpKind
.__svld
_gen
_asm
def __ld_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Ld: load from (input 0 + first immediate) into output 0.

    The loaded value is masked with GPR_VALUE_MASK and stored as a 1-tuple.
    """
    (base,) = state[op.input_vals[0]]
    loaded = state.load(base + op.immediates[0])
    state[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
def __ld_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `ld RT, imm(RA)` line for `op`.

    Input 0 and output 0 are formatted with `state.sgpr`; `imm` is the
    first immediate.
    """
    RA = state.sgpr(op.input_vals[0])
    RT = state.sgpr(op.outputs[0])
    imm = op.immediates[0]
    state.writeln(f"ld {RT}, {imm}({RA})")
1752 Ld
= GenericOpProperties(
1753 demo_asm
="ld RT, imm(RA)",
1754 inputs
=[OD_BASE_SGPR
],
1755 outputs
=[OD_BASE_SGPR
.with_write_stage(OpStage
.Late
)],
1756 immediates
=[IMM_S16
],
1758 _SIM_FNS
[Ld
] = lambda: OpKind
.__ld
_sim
1759 _GEN_ASMS
[Ld
] = lambda: OpKind
.__ld
_gen
_asm
1762 def __svstd_sim(op
, state
):
1763 # type: (Op, BaseSimState) -> None
1764 RS
= state
[op
.input_vals
[0]]
1765 RA
, = state
[op
.input_vals
[1]]
1766 VL
, = state
[op
.input_vals
[2]]
1767 addr
= RA
+ op
.immediates
[0]
1769 state
.store(addr
+ GPR_SIZE_IN_BYTES
* i
, value
=RS
[i
])
def __svstd_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `sv.std RS, imm(RA)` line for `op`.

    Input 0 (the stored value) is formatted with `state.vgpr`, input 1
    (the base address) with `state.sgpr`; `imm` is the first immediate.
    """
    RS = state.vgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"sv.std {RS}, {imm}({RA})")
1778 SvStd
= GenericOpProperties(
1779 demo_asm
="sv.std *RS, imm(RA)",
1780 inputs
=[OD_EXTRA3_VGPR
, OD_EXTRA3_SGPR
, OD_VL
],
1782 immediates
=[IMM_S16
],
1783 has_side_effects
=True,
1785 _SIM_FNS
[SvStd
] = lambda: OpKind
.__svstd
_sim
1786 _GEN_ASMS
[SvStd
] = lambda: OpKind
.__svstd
_gen
_asm
def __std_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate Std: store input 0 at address (input 1 + first immediate)."""
    (stored_value,) = state[op.input_vals[0]]
    (base,) = state[op.input_vals[1]]
    state.store(base + op.immediates[0], value=stored_value)
def __std_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Emit one `std RS, imm(RA)` line for `op`.

    Inputs 0 and 1 are formatted with `state.sgpr`; `imm` is the first
    immediate.
    """
    RS = state.sgpr(op.input_vals[0])
    RA = state.sgpr(op.input_vals[1])
    imm = op.immediates[0]
    state.writeln(f"std {RS}, {imm}({RA})")
1803 Std
= GenericOpProperties(
1804 demo_asm
="std RS, imm(RA)",
1805 inputs
=[OD_BASE_SGPR
, OD_BASE_SGPR
],
1807 immediates
=[IMM_S16
],
1808 has_side_effects
=True,
1810 _SIM_FNS
[Std
] = lambda: OpKind
.__std
_sim
1811 _GEN_ASMS
[Std
] = lambda: OpKind
.__std
_gen
_asm
def __funcargr3_sim(op, state):
    # type: (Op, BaseSimState) -> None
    """Simulate FuncArgR3: intentionally does nothing."""
    pass  # return value set before simulation
def __funcargr3_gen_asm(op, state):
    # type: (Op, GenAsmState) -> None
    """Generate assembly for FuncArgR3: intentionally emits nothing."""
    pass  # no instructions needed
1822 FuncArgR3
= GenericOpProperties(
1825 outputs
=[OD_BASE_SGPR
.with_fixed_loc(
1826 Loc(kind
=LocKind
.GPR
, start
=3, reg_len
=1))],
1828 _SIM_FNS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_sim
1829 _GEN_ASMS
[FuncArgR3
] = lambda: OpKind
.__funcargr
3_gen
_asm
1832 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1833 class SSAValOrUse(metaclass
=InternedMeta
):
1834 __slots__
= "op", "operand_idx"
1836 def __init__(self
, op
, operand_idx
):
1837 # type: (Op, int) -> None
1840 if operand_idx
< 0 or operand_idx
>= len(self
.descriptor_array
):
1841 raise ValueError("invalid operand_idx")
1842 self
.operand_idx
= operand_idx
1851 def descriptor_array(self
):
1852 # type: () -> tuple[OperandDesc, ...]
def defining_descriptor(self):
    # type: () -> OperandDesc
    """Return the entry of `descriptor_array` at `operand_idx`."""
    return self.descriptor_array[self.operand_idx]
1863 return self
.defining_descriptor
.ty
1866 def ty_before_spread(self
):
1868 return self
.defining_descriptor
.ty_before_spread
1872 # type: () -> BaseTy
1873 return self
.ty_before_spread
.base_ty
def reg_offset_in_unspread(self):
    """ the number of reg-sized slots in the unspread Loc before self's Loc

    e.g. if the unspread Loc containing self is:
    `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
    and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
    then reg_offset_in_unspread == 2 == 10 - 8
    """
    return self.defining_descriptor.reg_offset_in_unspread
1887 def unspread_start_idx(self
):
1889 return self
.operand_idx
- (self
.defining_descriptor
.spread_index
or 0)
1892 def unspread_start(self
):
1894 return self
.__class
__(op
=self
.op
, operand_idx
=self
.unspread_start_idx
)
1897 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1899 class SSAVal(SSAValOrUse
):
1904 return f
"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
def def_loc_set_before_spread(self):
    # type: () -> LocSet
    """Return `loc_set_before_spread` of this value's defining descriptor."""
    return self.defining_descriptor.loc_set_before_spread
def descriptor_array(self):
    # type: () -> tuple[OperandDesc, ...]
    """Return the output descriptors of the defining op's properties."""
    return self.op.properties.outputs
1917 def tied_input(self
):
1918 # type: () -> None | SSAUse
1919 if self
.defining_descriptor
.tied_input_index
is None:
1921 return SSAUse(op
=self
.op
,
1922 operand_idx
=self
.defining_descriptor
.tied_input_index
)
def write_stage(self):
    # type: () -> OpStage
    """Return `write_stage` of this value's defining descriptor."""
    return self.defining_descriptor.write_stage
1930 def current_debugging_value(self
):
1931 # type: () -> tuple[int, ...]
1932 """ get the current value for debugging in pdb or similar.
1934 This is intended for use with
1935 `PreRASimState.set_current_debugging_state`.
1937 This is only intended for debugging, do not use in unit tests or
1940 return PreRASimState
.get_current_debugging_state()[self
]
def ssa_val_sub_regs(self):
    # type: () -> tuple[SSAValSubReg, ...]
    """Return one SSAValSubReg per register of this value, in index order."""
    sub_regs = []
    for reg_idx in range(self.ty.reg_len):
        sub_regs.append(SSAValSubReg(self, reg_idx))
    return tuple(sub_regs)
1948 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
1950 class SSAUse(SSAValOrUse
):
def use_loc_set_before_spread(self):
    # type: () -> LocSet
    """Return `loc_set_before_spread` of this use's defining descriptor."""
    return self.defining_descriptor.loc_set_before_spread
def descriptor_array(self):
    # type: () -> tuple[OperandDesc, ...]
    """Return the input descriptors of the using op's properties."""
    return self.op.properties.inputs
1965 return f
"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
1969 # type: () -> SSAVal
1970 return self
.op
.input_vals
[self
.operand_idx
]
def ssa_val(self, ssa_val):
    # type: (SSAVal) -> None
    """Replace the op input at `operand_idx` with `ssa_val`.

    The assignment goes through `op.input_vals[...]`, so any checks done
    by that sequence's item assignment apply.
    """
    self.op.input_vals[self.operand_idx] = ssa_val
1979 _Desc
= TypeVar("_Desc")
1982 class OpInputSeq(Sequence
[_T
], Generic
[_T
, _Desc
]):
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, _T | Any, _Desc) -> None
    """Validate `item` against `desc` before it is written at `idx`.

    This base version always raises NotImplementedError; subclasses
    supply the actual check.
    """
    raise NotImplementedError
1989 def _verify_write(self
, idx
, item
):
1990 # type: (int | Any, _T | Any) -> int
1991 if not isinstance(idx
, int):
1992 if isinstance(idx
, slice):
1994 f
"can't write to slice of {self.__class__.__name__}")
1995 raise TypeError(f
"can't write with index {idx!r}")
1996 # normalize idx, raising IndexError if it is out of range
1997 idx
= range(len(self
.descriptors
))[idx
]
1998 desc
= self
.descriptors
[idx
]
1999 self
._verify
_write
_with
_desc
(idx
, item
, desc
)
2002 def _on_set(self
, idx
, new_item
, old_item
):
2003 # type: (int, _T, _T | None) -> None
def _get_descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """Return the descriptor tuple for this sequence.

    This base version always raises NotImplementedError; subclasses
    supply the descriptors.
    """
    raise NotImplementedError
def descriptors(self):
    # type: () -> tuple[_Desc, ...]
    """Return the result of `_get_descriptors()`."""
    return self._get_descriptors()
2022 def __init__(self
, items
, op
):
2023 # type: (Iterable[_T], Op) -> None
2026 self
.__items
= [] # type: list[_T]
2027 for idx
, item
in enumerate(items
):
2028 if idx
>= len(self
.descriptors
):
2029 raise ValueError("too many items")
2030 _
= self
._verify
_write
(idx
, item
)
2031 self
.__items
.append(item
)
2032 if len(self
.__items
) < len(self
.descriptors
):
2033 raise ValueError("not enough items")
2037 # type: () -> Iterator[_T]
2038 yield from self
.__items
2041 def __getitem__(self
, idx
):
2046 def __getitem__(self
, idx
):
2047 # type: (slice) -> list[_T]
def __getitem__(self, idx):
    # type: (int | slice) -> _T | list[_T]
    """Return the item at `idx`, or a list of items when `idx` is a slice."""
    return self.__items[idx]
def __setitem__(self, idx, item):
    # type: (int, _T) -> None
    """Validate `item` for position `idx`, then store it.

    `_verify_write` returns the index that is used for the actual store.
    """
    # NOTE(review): the `_on_set` hook is not invoked from this method in
    # the code shown here -- confirm whether that is intentional.
    idx = self._verify_write(idx, item)
    self.__items[idx] = item
2064 return len(self
.__items
)
2068 return f
"{self.__class__.__name__}({self.__items}, op=...)"
2072 class OpInputVals(OpInputSeq
[SSAVal
, OperandDesc
]):
def _get_descriptors(self):
    # type: () -> tuple[OperandDesc, ...]
    """Return the input descriptors of the owning op's properties."""
    return self.op.properties.inputs
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, SSAVal | Any, OperandDesc) -> None
    """Check that `item` is an SSAVal whose type matches `desc.ty`.

    Raises TypeError for a non-SSAVal item and ValueError for a type
    mismatch. `idx` is not used by this check.
    """
    if isinstance(item, SSAVal):
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")
        return
    raise TypeError("expected value of type SSAVal")
def _on_set(self, idx, new_item, old_item):
    # type: (int, SSAVal, SSAVal | None) -> None
    """Forward an input change to `SSAUses._on_op_input_set`."""
    SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore
def __init__(self, items, op):
    # type: (Iterable[SSAVal], Op) -> None
    """Create the input-value sequence for `op` from `items`.

    Raises ValueError if `op` already has its `input_vals` attribute set.
    The guard previously tested for an attribute named "inputs", but Op
    stores this sequence in `input_vals`, so the test could never be true.
    Everything else is delegated to the base-class initializer.
    """
    if hasattr(op, "input_vals"):
        raise ValueError("Op.input_vals already set")
    super().__init__(items, op)
2097 class OpImmediates(OpInputSeq
[int, range]):
def _get_descriptors(self):
    # type: () -> tuple[range, ...]
    """Return the immediates ranges of the owning op's properties."""
    return self.op.properties.immediates
def _verify_write_with_desc(self, idx, item, desc):
    # type: (int, int | Any, range) -> None
    """Check that `item` is an int contained in the range `desc`.

    Raises TypeError for a non-int item and ValueError for an int outside
    the range. `idx` is not used by this check.
    """
    if isinstance(item, int):
        if item not in desc:
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        return
    raise TypeError("expected value of type int")
def __init__(self, items, op):
    # type: (Iterable[int], Op) -> None
    """Create the immediates sequence for `op` from `items`.

    Raises ValueError if `op` already has an `immediates` attribute;
    everything else is delegated to the base-class initializer.
    """
    if hasattr(op, "immediates"):
        raise ValueError("Op.immediates already set")
    super().__init__(items, op)
2116 @plain_data(frozen
=True, eq
=False, repr=False)
2119 __slots__
= ("fn", "properties", "input_vals", "input_uses", "immediates",
2122 def __init__(self
, fn
, properties
, input_vals
, immediates
, name
=""):
2123 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
2125 self
.properties
= properties
2126 self
.input_vals
= OpInputVals(input_vals
, op
=self
)
2127 inputs_len
= len(self
.properties
.inputs
)
2128 self
.input_uses
= tuple(SSAUse(self
, i
) for i
in range(inputs_len
))
2129 self
.immediates
= OpImmediates(immediates
, op
=self
)
2130 outputs_len
= len(self
.properties
.outputs
)
2131 self
.outputs
= tuple(SSAVal(self
, i
) for i
in range(outputs_len
))
2132 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
) # type: ignore
2136 # type: () -> OpKind
2137 return self
.properties
.kind
def __eq__(self, other):
    # type: (Op | Any) -> bool
    """Compare by identity: two Ops are equal only if they are the same object.

    Returns NotImplemented when `other` is not an Op.
    """
    if not isinstance(other, Op):
        return NotImplemented
    return self is other
2147 return object.__hash
__(self
)
2149 def __repr__(self
, wrap_width
=63, indent
=" "):
2150 # type: (int, str) -> str
2151 WRAP_POINT
= "\u200B" # zero-width space
2152 items
= [f
"{self.name}:\n"]
2153 for i
, out
in enumerate(self
.outputs
):
2154 item
= f
"<...outputs[{i}]: {out.ty}>"
2156 item
= "(" + WRAP_POINT
+ item
2157 if i
!= len(self
.outputs
) - 1:
2158 item
+= ", " + WRAP_POINT
2160 item
+= WRAP_POINT
+ ") <= "
2162 items
.append(self
.kind
._name
_)
2163 if len(self
.input_vals
) + len(self
.immediates
) != 0:
2165 items
[-1] += WRAP_POINT
2166 for i
, inp
in enumerate(self
.input_vals
):
2168 if i
!= len(self
.input_vals
) - 1 or len(self
.immediates
) != 0:
2169 item
+= ", " + WRAP_POINT
2171 item
+= ") " + WRAP_POINT
2173 for i
, imm
in enumerate(self
.immediates
):
2175 if i
!= len(self
.immediates
) - 1:
2176 item
+= ", " + WRAP_POINT
2178 item
+= ") " + WRAP_POINT
2180 lines
= [] # type: list[str]
2181 for i
, line_in
in enumerate("".join(items
).splitlines()):
2183 line_in
= indent
+ line_in
2185 for part
in line_in
.split(WRAP_POINT
):
2189 trial_line_out
= line_out
+ part
2190 if len(trial_line_out
.rstrip()) > wrap_width
:
2191 lines
.append(line_out
.rstrip())
2192 line_out
= indent
+ part
2194 line_out
= trial_line_out
2195 lines
.append(line_out
.rstrip())
2196 return "\n".join(lines
)
2198 def sim(self
, state
):
2199 # type: (BaseSimState) -> None
2200 for inp
in self
.input_vals
:
2204 raise ValueError(f
"SSAVal {inp} not yet assigned when "
2208 if len(val
) != inp
.ty
.reg_len
:
2210 f
"value of SSAVal {inp} has wrong number of elements: "
2211 f
"expected {inp.ty.reg_len} found "
2212 f
"{len(val)}: {val!r}")
2213 if isinstance(state
, PreRASimState
):
2214 for out
in self
.outputs
:
2215 if out
in state
.ssa_vals
:
2216 if self
.kind
is OpKind
.FuncArgR3
:
2218 raise ValueError(f
"SSAVal {out} already assigned before "
2221 self
.kind
.sim(self
, state
)
2224 for out
in self
.outputs
:
2228 raise ValueError(f
"running {self} failed to assign to {out}")
2231 if len(val
) != out
.ty
.reg_len
:
2233 f
"value of SSAVal {out} has wrong number of elements: "
2234 f
"expected {out.ty.reg_len} found "
2235 f
"{len(val)}: {val!r}")
def gen_asm(self, state):
    # type: (GenAsmState) -> None
    """Generate assembly for this op into `state`.

    `state.loc` is first called for every input value and then every
    output, accepting any LocKind, before the kind-specific generator
    runs.
    """
    any_kind = tuple(LocKind)
    for ssa_val in (*self.input_vals, *self.outputs):
        state.loc(ssa_val, expected_kinds=any_kind)
    generate = self.kind.gen_asm
    generate(self, state)
2247 @plain_data(frozen
=True, repr=False)
2248 class BaseSimState(metaclass
=ABCMeta
):
2249 __slots__
= "memory",
2251 def __init__(self
, memory
):
2252 # type: (dict[int, int]) -> None
2254 self
.memory
= memory
# type: dict[int, int]
2256 def _default_memory_value(self
):
def on_skip(self, op):
    # type: (Op) -> None
    """Handle a skipped op; this implementation always raises ValueError."""
    raise ValueError("skipping instructions not supported")
2264 def load_byte(self
, addr
):
2265 # type: (int) -> int
2266 addr
&= GPR_VALUE_MASK
2268 return self
.memory
[addr
] & 0xFF
2270 return self
._default
_memory
_value
()
2272 def store_byte(self
, addr
, value
):
2273 # type: (int, int) -> None
2274 addr
&= GPR_VALUE_MASK
2276 self
.memory
[addr
] = value
2278 def load(self
, addr
, size_in_bytes
=GPR_SIZE_IN_BYTES
, signed
=False):
2279 # type: (int, int, bool) -> int
2280 if addr
% size_in_bytes
!= 0:
2281 raise ValueError(f
"address not aligned: {hex(addr)} "
2282 f
"required alignment: {size_in_bytes}")
2284 for i
in range(size_in_bytes
):
2285 retval |
= self
.load_byte(addr
+ i
) << i
* BITS_IN_BYTE
2286 if signed
and retval
>> (size_in_bytes
* BITS_IN_BYTE
- 1) != 0:
2287 retval
-= 1 << size_in_bytes
* BITS_IN_BYTE
2290 def store(self
, addr
, value
, size_in_bytes
=GPR_SIZE_IN_BYTES
):
2291 # type: (int, int, int) -> None
2292 if addr
% size_in_bytes
!= 0:
2293 raise ValueError(f
"address not aligned: {hex(addr)} "
2294 f
"required alignment: {size_in_bytes}")
2295 for i
in range(size_in_bytes
):
2296 self
.store_byte(addr
+ i
, (value
>> i
* BITS_IN_BYTE
) & 0xFF)
2298 def _memory__repr(self
):
2300 if len(self
.memory
) == 0:
2302 keys
= sorted(self
.memory
.keys(), reverse
=True)
2303 CHUNK_SIZE
= GPR_SIZE_IN_BYTES
2304 items
= [] # type: list[str]
2305 while len(keys
) != 0:
2307 if (len(keys
) >= CHUNK_SIZE
2308 and addr
% CHUNK_SIZE
== 0
2309 and keys
[-CHUNK_SIZE
:]
2310 == list(reversed(range(addr
, addr
+ CHUNK_SIZE
)))):
2311 value
= self
.load(addr
, size_in_bytes
=CHUNK_SIZE
)
2312 items
.append(f
"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
2313 keys
[-CHUNK_SIZE
:] = ()
2315 items
.append(f
"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
2317 return f
"{{{items[0]}}}"
2318 items_str
= ",\n".join(items
)
2319 return f
"{{\n{items_str}}}"
2323 field_vals
= [] # type: list[str]
2324 for name
in fields(self
):
2326 value
= getattr(self
, name
)
2327 except AttributeError:
2328 field_vals
.append(f
"{name}=<not set>")
2330 repr_fn
= getattr(self
, f
"_{name}__repr", None)
2331 if callable(repr_fn
):
2332 field_vals
.append(f
"{name}={repr_fn()}")
2334 field_vals
.append(f
"{name}={value!r}")
2335 field_vals_str
= ", ".join(field_vals
)
2336 return f
"{self.__class__.__name__}({field_vals_str})"
2339 def __getitem__(self
, ssa_val
):
2340 # type: (SSAVal) -> tuple[int, ...]
2344 def __setitem__(self
, ssa_val
, value
):
2345 # type: (SSAVal, Iterable[int]) -> None
@plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
    """ simulator state that stores values directly per SSAVal, in the
    `ssa_vals` dict.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        # stored as-is (not copied); `__setitem__` writes into this dict
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """ text for the `ssa_vals` field in `__repr__`.

        Each value is shown as a tuple of hex elements; values longer than
        `CHUNK_SIZE` elements are wrapped onto multiple lines.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        # NOTE(review): the assignment of CHUNK_SIZE was missing from the
        # text under review; 4 elements per line is an assumption --
        # confirm.
        CHUNK_SIZE = 4
        items = []  # type: list[str]
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # NOTE(review): literal copied as shown; runs of spaces
                    # may have been collapsed in the text under review.
                    element_strs.append("\n " + hex(el))
            # a value that fits on one line does not need its leading line
            # break. The emptiness guard keeps repr from raising IndexError
            # on an empty tuple.
            if element_strs and len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # gives the trailing comma of a 1-tuple: `(0x1,)`
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ return the value of `ssa_val`; an unassigned SSAVal is passed
        to `_handle_undefined_ssa_val`.
        """
        try:
            return self.ssa_vals[ssa_val]
        except KeyError:
            return self._handle_undefined_ssa_val(ssa_val)

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ called by `__getitem__` for an SSAVal with no value; this
        implementation raises `KeyError`.
        """
        raise KeyError("SSAVal has no value set", ssa_val)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ assign `value` (converted to a tuple of ints) to `ssa_val`.

        Raises `ValueError` if the number of elements differs from
        `ssa_val.ty.reg_len`.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
class SimSkipOp(Exception):
    """ exception type related to skipping an Op during simulation.

    NOTE(review): no raiser or handler is visible in the text under review;
    presumably this is raised to request that an Op be skipped -- confirm.
    """
# NOTE(review): one line between the decorator below and the `class`
# statement was missing from the text under review (possibly a second
# decorator); it is not reconstructed here -- confirm.
@plain_data(frozen=True, repr=False)
class ConstPropagationState(PreRABaseSimState):
    """ pre-register-allocation simulator state that also records, in
    `skipped_ops`, every Op passed to `on_skip`.
    """
    __slots__ = "skipped_ops",

    def __init__(self, ssa_vals, memory, skipped_ops):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int], OSet[Op]) -> None
        super().__init__(ssa_vals, memory)
        # stored as-is (not copied); `on_skip` adds to this set
        self.skipped_ops = skipped_ops

    def _default_memory_value(self):
        # type: () -> int
        """ override of the value used for unset memory. """
        # NOTE(review): this body was missing from the text under review;
        # raising SimSkipOp (rather than inventing a value for unknown
        # memory) is an assumption -- confirm.
        raise SimSkipOp

    def _handle_undefined_ssa_val(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ override of the handler for SSAVals that have no value. """
        # NOTE(review): this body was missing from the text under review;
        # raising SimSkipOp is an assumption -- confirm.
        raise SimSkipOp

    def on_skip(self, op):
        # type: (Op) -> None
        """ record `op` in `skipped_ops` instead of raising. """
        self.skipped_ops.add(op)
@plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
    """ pre-register-allocation simulator state, with an optional
    "current debugging state" facility for use from pdb or similar.
    """
    # NOTE(review): one line was missing here in the text under review;
    # an empty __slots__ (this class adds no fields) is an assumption --
    # confirm.
    __slots__ = ()

    # stack of the states made current by `set_as_current_debugging_state`;
    # shared by all instances.
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar, for as long as the `with` block runs.

        This is intended only for use with `get_current_debugging_state`,
        which should not be used in unit tests.
        """
        PreRASimState.__CURRENT_DEBUGGING_STATE.append(self)
        try:
            yield
        finally:
            # pop outside the assert so the stack is still unwound when
            # asserts are stripped (`python -O`).
            popped = PreRASimState.__CURRENT_DEBUGGING_STATE.pop()
            assert self is popped, \
                "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`;
        when those are nested, the innermost state is returned.

        This is only intended for debugging, do not use in unit tests.

        Raises `ValueError` if no state is currently set.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
# NOTE(review): one line between the decorator below and the `class`
# statement was missing from the text under review (possibly a second
# decorator); it is not reconstructed here -- confirm.
@plain_data(frozen=True, repr=False)
class PostRASimState(BaseSimState):
    """ simulator state where each SSAVal lives in the Loc given by
    `ssa_val_to_loc_map`, and values are stored per `reg_len=1` Loc in
    `loc_values`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """ Raises `ValueError` if an SSAVal's type differs from its Loc's
        type, or if `loc_values` has a key whose `reg_len` is not 1.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # stored as-is (not copied); `__setitem__` writes into this dict
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """ text for the `loc_values` field in `__repr__`: one
        `loc: 0xvalue` entry per line, ordered by kind name then start.
        """
        if len(self.loc_values) == 0:
            # same empty form the other field reprs use, instead of the
            # malformed "{\n,\n}" the join below would give.
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    @staticmethod
    def _sublocs(loc):
        # type: (Loc) -> Iterator[Loc]
        """ yield the `reg_len=1` sub-Locs of `loc`, offset 0 first. """
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for offset in range(loc.reg_len):
            yield loc.get_subloc_at_offset(subloc_ty=subloc_ty,
                                           offset=offset)

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ return the value of `ssa_val`, read element by element from the
        sub-Locs of its Loc; a sub-Loc with no stored value reads as 0.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        return tuple(self.loc_values.get(subloc, 0)
                     for subloc in self._sublocs(loc))

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, Iterable[int]) -> None
        """ assign `value` (converted to a tuple of ints) to `ssa_val` by
        writing each element to the matching sub-Loc of its Loc.

        Raises `ValueError` if the number of elements differs from
        `ssa_val.ty.reg_len`.
        """
        value = tuple(map(int, value))
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        for i, subloc in enumerate(self._sublocs(loc)):
            self.loc_values[subloc] = value[i]
# NOTE(review): the `class` line was missing from the text under review;
# the name is taken from the `(GenAsmState) -> None` type comment on
# `Op.gen_asm` -- confirm.
@plain_data(frozen=True)
class GenAsmState:
    """ state used while generating assembly: the Loc allocated to each
    SSAVal, plus the sink that generated lines are written to.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """ Raises `ValueError` if an SSAVal's type differs from the type of
        the Loc allocated to it.
        """
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        # NOTE(review): two lines were missing here in the text under
        # review; defaulting to a fresh list is an assumption (writeln
        # supports a list or an object with `.write`, never None) --
        # confirm.
        if output is None:
            output = []
        self.output = output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """ resolve an SSAVal, a Loc, or a sequence of them to a single Loc.

        SSAVals are replaced by their allocated Loc, then all Locs are
        combined with `try_concat`. `expected_kinds` is one `LocKind` or a
        tuple of them that the resulting Loc's kind must be in.

        Raises `ValueError` for an empty sequence, a failed concatenation,
        or a kind mismatch.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = []  # type: list[Loc]
        for item in ssa_val_or_locs:
            if isinstance(item, SSAVal):
                locs.append(self.allocated_locs[item])
            else:
                locs.append(item)
        if len(locs) == 0:
            raise ValueError("invalid Loc sequence: must not be empty")
        retval = locs[0].try_concat(*locs[1:])
        # NOTE(review): the failure test was missing from the text under
        # review; `try_concat` returning None on failure is an assumption
        # -- confirm.
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if retval.kind not in expected_kinds:
            if len(expected_kinds) == 1:
                # unwrap so the message shows the kind, not a 1-tuple
                expected_kinds = expected_kinds[0]
            raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                             f"{retval.kind} expected {expected_kinds}")
        return retval

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """ assembly operand text for a GPR Loc: its `start`, prefixed with
        `*` when `is_vec` is true.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        vec_str = "*" if is_vec else ""
        return vec_str + str(loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=False`. """
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ same as `gpr` with `is_vec=True`. """
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """ assembly operand text for a `StackI64` Loc, in the form
        `<start>(1)`.
        """
        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """ join `line_segments` with spaces and emit them as one line:
        appended to `output` when it is a list, otherwise written to it
        followed by a newline.
        """
        line = " ".join(line_segments)
        if isinstance(self.output, list):
            self.output.append(line)
        else:
            self.output.write(line + "\n")