from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
-from nmutil.plain_data import fields, plain_data
+from dataclasses import dataclass
+from nmutil import plain_data # type: ignore
from bigint_presentation_code.type_util import (Literal, Self, assert_never,
final)
-from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
+from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
OFSet, OSet)
GPR_SIZE_IN_BYTES = 8
assert OpStage.Early < OpStage.Late, "early must be less than late"
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
-@total_ordering
-class ProgramPoint(metaclass=InternedMeta):
- __slots__ = "op_index", "stage"
-
- def __init__(self, op_index, stage):
- # type: (int, OpStage) -> None
- self.op_index = op_index
- self.stage = stage
+class ProgramPoint(Interned):
+ op_index: int
+ stage: OpStage
@property
def int_value(self):
return self.op_index < other.op_index
return self.stage < other.stage
+ def __gt__(self, other):
+ # type: (ProgramPoint | Any) -> bool
+ if not isinstance(other, ProgramPoint):
+ return NotImplemented
+ return other.__lt__(self)
+
+ def __le__(self, other):
+ # type: (ProgramPoint | Any) -> bool
+ if not isinstance(other, ProgramPoint):
+ return NotImplemented
+ return not self.__gt__(other)
+
+ def __ge__(self, other):
+ # type: (ProgramPoint | Any) -> bool
+ if not isinstance(other, ProgramPoint):
+ return NotImplemented
+ return not self.__lt__(other)
+
def __repr__(self):
# type: () -> str
return f"<ops[{self.op_index}]:{self.stage._name_}>"
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
-class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
- __slots__ = "start", "stop"
-
- def __init__(self, start, stop):
- # type: (ProgramPoint, ProgramPoint) -> None
- self.start = start
- self.stop = stop
+class ProgramRange(Sequence[ProgramPoint], Interned):
+ start: ProgramPoint
+ stop: ProgramPoint
@cached_property
def int_value_range(self):
return f"<range:{start}..{stop}>"
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
-class SSAValSubReg(metaclass=InternedMeta):
- __slots__ = "ssa_val", "reg_idx"
+class SSAValSubReg(Interned):
+ ssa_val: "SSAVal"
+ reg_idx: int
- def __init__(self, ssa_val, reg_idx):
- # type: (SSAVal, int) -> None
- if reg_idx < 0 or reg_idx >= ssa_val.ty.reg_len:
+ def __post_init__(self):
+ if self.reg_idx < 0 or self.reg_idx >= self.ssa_val.ty.reg_len:
raise ValueError("reg_idx out of range")
- self.ssa_val = ssa_val
- self.reg_idx = reg_idx
def __repr__(self):
# type: () -> str
return f"{self.ssa_val}[{self.reg_idx}]"
-@plain_data(frozen=True, eq=False, repr=False)
+@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
__slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
return "BaseTy." + self._name_
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
-class Ty(metaclass=InternedMeta):
- __slots__ = "base_ty", "reg_len"
+class Ty(Interned):
+ base_ty: BaseTy
+ reg_len: int
@staticmethod
def validate(base_ty, reg_len):
return "reg_len out of range"
return None
- def __init__(self, base_ty, reg_len):
- # type: (BaseTy, int) -> None
- msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+ def __post_init__(self):
+ msg = self.validate(base_ty=self.base_ty, reg_len=self.reg_len)
if msg is not None:
raise ValueError(msg)
- self.base_ty = base_ty
- self.reg_len = reg_len
def __repr__(self):
# type: () -> str
return "LocSubKind." + self._name_
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
@final
-class GenericTy(metaclass=InternedMeta):
- __slots__ = "base_ty", "is_vec"
+class GenericTy(Interned):
+ base_ty: BaseTy
+ is_vec: bool
- def __init__(self, base_ty, is_vec):
- # type: (BaseTy, bool) -> None
- self.base_ty = base_ty
- if base_ty.only_scalar and is_vec:
- raise ValueError(f"base_ty={base_ty} requires is_vec=False")
- self.is_vec = is_vec
+ def __post_init__(self):
+ if self.base_ty.only_scalar and self.is_vec:
+ raise ValueError(f"base_ty={self.base_ty} requires is_vec=False")
def instantiate(self, maxvl):
# type: (int) -> Ty
return ty.reg_len == 1
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
@final
-class Loc(metaclass=InternedMeta):
- __slots__ = "kind", "start", "reg_len"
+class Loc(Interned):
+ kind: LocKind
+ start: int
+ reg_len: int
@staticmethod
def validate(kind, start, reg_len):
return None
return Loc(kind=kind, start=start, reg_len=reg_len)
- def __init__(self, kind, start, reg_len):
- # type: (LocKind, int, int) -> None
- msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+ def __post_init__(self):
+ msg = self.validate(kind=self.kind, start=self.start,
+ reg_len=self.reg_len)
if msg is not None:
raise ValueError(msg)
- self.kind = kind
- self.reg_len = reg_len
- self.start = start
def conflicts(self, other):
# type: (Loc) -> bool
@final
-class LocSet(OFSet[Loc], metaclass=InternedMeta):
+class LocSet(OFSet[Loc], Interned):
def __init__(self, __locs=()):
# type: (Iterable[Loc]) -> None
super().__init__(__locs)
return only_loc
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
@final
-class GenericOperandDesc(metaclass=InternedMeta):
+class GenericOperandDesc(Interned):
"""generic Op operand descriptor"""
- __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
- "write_stage")
+ ty: GenericTy
+ sub_kinds: OFSet[LocSubKind]
+ fixed_loc: "Loc | None" = None
+ tied_input_index: "int | None" = None
+ spread: bool = False
+ write_stage: OpStage = OpStage.Early
def __init__(
self, ty, # type: GenericTy
write_stage=OpStage.Early, # type: OpStage
):
# type: (...) -> None
- self.ty = ty
- self.sub_kinds = OFSet(sub_kinds)
+ object.__setattr__(self, "ty", ty)
+ object.__setattr__(self, "sub_kinds", OFSet(sub_kinds))
if len(self.sub_kinds) == 0:
raise ValueError("sub_kinds can't be empty")
- self.fixed_loc = fixed_loc
+ object.__setattr__(self, "fixed_loc", fixed_loc)
if fixed_loc is not None:
if tied_input_index is not None:
raise ValueError("operand can't be both tied and fixed")
f"sub_kind={sub_kind} ty={ty}")
if tied_input_index is not None and tied_input_index < 0:
raise ValueError("invalid tied_input_index")
- self.tied_input_index = tied_input_index
- self.spread = spread
+ object.__setattr__(self, "tied_input_index", tied_input_index)
+ object.__setattr__(self, "spread", spread)
if spread:
if self.tied_input_index is not None:
raise ValueError("operand can't be both spread and tied")
raise ValueError("operand can't be both spread and fixed")
if self.ty.is_vec:
raise ValueError("operand can't be both spread and vector")
- self.write_stage = write_stage
+ object.__setattr__(self, "write_stage", write_stage)
@cached_property
def ty_before_spread(self):
spread_index=idx, write_stage=self.write_stage)
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
@final
-class OperandDesc(metaclass=InternedMeta):
+class OperandDesc(Interned):
"""Op operand descriptor"""
- __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
- "write_stage")
+ loc_set_before_spread: LocSet
+ tied_input_index: "int | None"
+ spread_index: "int | None"
+ write_stage: "OpStage"
- def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
- write_stage):
- # type: (LocSet, int | None, int | None, OpStage) -> None
- if len(loc_set_before_spread) == 0:
+ def __post_init__(self):
+ if len(self.loc_set_before_spread) == 0:
raise ValueError("loc_set_before_spread must not be empty")
- self.loc_set_before_spread = loc_set_before_spread
- self.tied_input_index = tied_input_index
- if self.tied_input_index is not None and spread_index is not None:
+ if self.tied_input_index is not None and self.spread_index is not None:
raise ValueError("operand can't be both spread and tied")
- self.spread_index = spread_index
- self.write_stage = write_stage
@cached_property
def ty_before_spread(self):
sub_kinds=[LocSubKind.VL_MAXVL])
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
@final
-class GenericOpProperties(metaclass=InternedMeta):
- __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
- "is_copy", "is_load_immediate", "has_side_effects")
+class GenericOpProperties(Interned):
+ demo_asm: str
+ inputs: "tuple[GenericOperandDesc, ...]"
+ outputs: "tuple[GenericOperandDesc, ...]"
+ immediates: "tuple[range, ...]"
+ is_copy: bool
+ is_load_immediate: bool
+ has_side_effects: bool
def __init__(
self, demo_asm, # type: str
has_side_effects=False, # type: bool
):
# type: (...) -> None
- self.demo_asm = demo_asm # type: str
- self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...]
+ object.__setattr__(self, "demo_asm", demo_asm)
+ object.__setattr__(self, "inputs", tuple(inputs))
for inp in self.inputs:
if inp.tied_input_index is not None:
raise ValueError(
if inp.write_stage is not OpStage.Early:
raise ValueError(
f"write_stage is not allowed on inputs: {inp}")
- self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...]
+ object.__setattr__(self, "outputs", tuple(outputs))
fixed_locs = [] # type: list[tuple[Loc, int]]
for idx, out in enumerate(self.outputs):
if out.tied_input_index is not None:
f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
f"with {other_fixed_loc}")
fixed_locs.append((out.fixed_loc, idx))
- self.immediates = tuple(immediates) # type: tuple[range, ...]
- self.is_copy = is_copy # type: bool
- self.is_load_immediate = is_load_immediate # type: bool
- self.has_side_effects = has_side_effects # type: bool
+ object.__setattr__(self, "immediates", tuple(immediates))
+ object.__setattr__(self, "is_copy", is_copy)
+ object.__setattr__(self, "is_load_immediate", is_load_immediate)
+ object.__setattr__(self, "has_side_effects", has_side_effects)
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data.plain_data(frozen=True, unsafe_hash=True)
@final
-class OpProperties(metaclass=InternedMeta):
+class OpProperties:
__slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
def __init__(self, kind, maxvl):
_GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-class SSAValOrUse(metaclass=InternedMeta):
- __slots__ = "op", "operand_idx"
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
+class SSAValOrUse(Interned):
+ op: "Op"
+ operand_idx: int
- def __init__(self, op, operand_idx):
- # type: (Op, int) -> None
- super().__init__()
- self.op = op
- if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
+ def __post_init__(self):
+ if self.operand_idx < 0 or \
+ self.operand_idx >= len(self.descriptor_array):
raise ValueError("invalid operand_idx")
- self.operand_idx = operand_idx
@abstractmethod
def __repr__(self):
return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
__slots__ = ()
return tuple(SSAValSubReg(self, i) for i in range(self.ty.reg_len))
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
__slots__ = ()
super().__init__(items, op)
-@plain_data(frozen=True, eq=False, repr=False)
+@plain_data.plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
__slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
self.kind.gen_asm(self, state)
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
__slots__ = "memory",
def __repr__(self):
# type: () -> str
field_vals = [] # type: list[str]
- for name in fields(self):
+ for name in plain_data.fields(self):
try:
value = getattr(self, name)
except AttributeError:
...
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
class PreRABaseSimState(BaseSimState):
__slots__ = "ssa_vals",
pass
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
@final
class ConstPropagationState(PreRABaseSimState):
__slots__ = "skipped_ops",
self.skipped_ops.add(op)
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
class PreRASimState(PreRABaseSimState):
__slots__ = ()
return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
__slots__ = "ssa_val_to_loc_map", "loc_values"
self.loc_values[subloc] = value[i]
-@plain_data(frozen=True)
+@plain_data.plain_data(frozen=True)
class GenAsmState:
__slots__ = "allocated_locs", "output"