working on code some more
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
index 4f8fb79bbbf15788663ca0875fd6150db7679e51..2d1949219f04f4567c95d436d00b766e2511e1f7 100644 (file)
@@ -9,11 +9,12 @@ from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
 from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
-from nmutil.plain_data import fields, plain_data
+from dataclasses import dataclass
+from nmutil import plain_data  # type: ignore
 
 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
                                                 final)
-from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
+from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned,
                                            OFSet, OSet)
 
 GPR_SIZE_IN_BYTES = 8
@@ -209,16 +210,11 @@ class OpStage(Enum):
 assert OpStage.Early < OpStage.Late, "early must be less than late"
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
-@total_ordering
-class ProgramPoint(metaclass=InternedMeta):
-    __slots__ = "op_index", "stage"
-
-    def __init__(self, op_index, stage):
-        # type: (int, OpStage) -> None
-        self.op_index = op_index
-        self.stage = stage
+class ProgramPoint(Interned):
+    op_index: int
+    stage: OpStage
 
     @property
     def int_value(self):
@@ -250,20 +246,34 @@ class ProgramPoint(metaclass=InternedMeta):
             return self.op_index < other.op_index
         return self.stage < other.stage
 
+    def __gt__(self, other):
+        # type: (ProgramPoint | Any) -> bool
+        if not isinstance(other, ProgramPoint):
+            return NotImplemented
+        return other.__lt__(self)
+
+    def __le__(self, other):
+        # type: (ProgramPoint | Any) -> bool
+        if not isinstance(other, ProgramPoint):
+            return NotImplemented
+        return not self.__gt__(other)
+
+    def __ge__(self, other):
+        # type: (ProgramPoint | Any) -> bool
+        if not isinstance(other, ProgramPoint):
+            return NotImplemented
+        return not self.__lt__(other)
+
     def __repr__(self):
         # type: () -> str
         return f"<ops[{self.op_index}]:{self.stage._name_}>"
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
-class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
-    __slots__ = "start", "stop"
-
-    def __init__(self, start, stop):
-        # type: (ProgramPoint, ProgramPoint) -> None
-        self.start = start
-        self.stop = stop
+class ProgramRange(Sequence[ProgramPoint], Interned):
+    start: ProgramPoint
+    stop: ProgramPoint
 
     @cached_property
     def int_value_range(self):
@@ -311,24 +321,22 @@ class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
         return f"<range:{start}..{stop}>"
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
-class SSAValSubReg(metaclass=InternedMeta):
-    __slots__ = "ssa_val", "reg_idx"
+class SSAValSubReg(Interned):
+    ssa_val: "SSAVal"
+    reg_idx: int
 
-    def __init__(self, ssa_val, reg_idx):
-        # type: (SSAVal, int) -> None
-        if reg_idx < 0 or reg_idx >= ssa_val.ty.reg_len:
+    def __post_init__(self):
+        if self.reg_idx < 0 or self.reg_idx >= self.ssa_val.ty.reg_len:
             raise ValueError("reg_idx out of range")
-        self.ssa_val = ssa_val
-        self.reg_idx = reg_idx
 
     def __repr__(self):
         # type: () -> str
         return f"{self.ssa_val}[{self.reg_idx}]"
 
 
-@plain_data(frozen=True, eq=False, repr=False)
+@plain_data.plain_data(frozen=True, eq=False, repr=False)
 @final
 class FnAnalysis:
     __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
@@ -542,10 +550,11 @@ class BaseTy(Enum):
         return "BaseTy." + self._name_
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
-class Ty(metaclass=InternedMeta):
-    __slots__ = "base_ty", "reg_len"
+class Ty(Interned):
+    base_ty: BaseTy
+    reg_len: int
 
     @staticmethod
     def validate(base_ty, reg_len):
@@ -559,13 +568,10 @@ class Ty(metaclass=InternedMeta):
             return "reg_len out of range"
         return None
 
-    def __init__(self, base_ty, reg_len):
-        # type: (BaseTy, int) -> None
-        msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+    def __post_init__(self):
+        msg = self.validate(base_ty=self.base_ty, reg_len=self.reg_len)
         if msg is not None:
             raise ValueError(msg)
-        self.base_ty = base_ty
-        self.reg_len = reg_len
 
     def __repr__(self):
         # type: () -> str
@@ -682,17 +688,15 @@ class LocSubKind(Enum):
         return "LocSubKind." + self._name_
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
 @final
-class GenericTy(metaclass=InternedMeta):
-    __slots__ = "base_ty", "is_vec"
+class GenericTy(Interned):
+    base_ty: BaseTy
+    is_vec: bool
 
-    def __init__(self, base_ty, is_vec):
-        # type: (BaseTy, bool) -> None
-        self.base_ty = base_ty
-        if base_ty.only_scalar and is_vec:
-            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
-        self.is_vec = is_vec
+    def __post_init__(self):
+        if self.base_ty.only_scalar and self.is_vec:
+            raise ValueError(f"base_ty={self.base_ty} requires is_vec=False")
 
     def instantiate(self, maxvl):
         # type: (int) -> Ty
@@ -710,10 +714,12 @@ class GenericTy(metaclass=InternedMeta):
         return ty.reg_len == 1
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
 @final
-class Loc(metaclass=InternedMeta):
-    __slots__ = "kind", "start", "reg_len"
+class Loc(Interned):
+    kind: LocKind
+    start: int
+    reg_len: int
 
     @staticmethod
     def validate(kind, start, reg_len):
@@ -735,14 +741,11 @@ class Loc(metaclass=InternedMeta):
             return None
         return Loc(kind=kind, start=start, reg_len=reg_len)
 
-    def __init__(self, kind, start, reg_len):
-        # type: (LocKind, int, int) -> None
-        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+    def __post_init__(self):
+        msg = self.validate(kind=self.kind, start=self.start,
+                            reg_len=self.reg_len)
         if msg is not None:
             raise ValueError(msg)
-        self.kind = kind
-        self.reg_len = reg_len
-        self.start = start
 
     def conflicts(self, other):
         # type: (Loc) -> bool
@@ -811,7 +814,7 @@ SPECIAL_GPRS = (
 
 
 @final
-class LocSet(OFSet[Loc], metaclass=InternedMeta):
+class LocSet(OFSet[Loc], Interned):
     def __init__(self, __locs=()):
         # type: (Iterable[Loc]) -> None
         super().__init__(__locs)
@@ -927,12 +930,16 @@ class LocSet(OFSet[Loc], metaclass=InternedMeta):
         return only_loc
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
 @final
-class GenericOperandDesc(metaclass=InternedMeta):
+class GenericOperandDesc(Interned):
     """generic Op operand descriptor"""
-    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
-                 "write_stage")
+    ty: GenericTy
+    sub_kinds: OFSet[LocSubKind]
+    fixed_loc: "Loc | None" = None
+    tied_input_index: "int | None" = None
+    spread: bool = False
+    write_stage: OpStage = OpStage.Early
 
     def __init__(
         self, ty,  # type: GenericTy
@@ -944,11 +951,11 @@ class GenericOperandDesc(metaclass=InternedMeta):
         write_stage=OpStage.Early,  # type: OpStage
     ):
         # type: (...) -> None
-        self.ty = ty
-        self.sub_kinds = OFSet(sub_kinds)
+        object.__setattr__(self, "ty", ty)
+        object.__setattr__(self, "sub_kinds", OFSet(sub_kinds))
         if len(self.sub_kinds) == 0:
             raise ValueError("sub_kinds can't be empty")
-        self.fixed_loc = fixed_loc
+        object.__setattr__(self, "fixed_loc", fixed_loc)
         if fixed_loc is not None:
             if tied_input_index is not None:
                 raise ValueError("operand can't be both tied and fixed")
@@ -970,8 +977,8 @@ class GenericOperandDesc(metaclass=InternedMeta):
                                  f"sub_kind={sub_kind} ty={ty}")
         if tied_input_index is not None and tied_input_index < 0:
             raise ValueError("invalid tied_input_index")
-        self.tied_input_index = tied_input_index
-        self.spread = spread
+        object.__setattr__(self, "tied_input_index", tied_input_index)
+        object.__setattr__(self, "spread", spread)
         if spread:
             if self.tied_input_index is not None:
                 raise ValueError("operand can't be both spread and tied")
@@ -979,7 +986,7 @@ class GenericOperandDesc(metaclass=InternedMeta):
                 raise ValueError("operand can't be both spread and fixed")
             if self.ty.is_vec:
                 raise ValueError("operand can't be both spread and vector")
-        self.write_stage = write_stage
+        object.__setattr__(self, "write_stage", write_stage)
 
     @cached_property
     def ty_before_spread(self):
@@ -1036,24 +1043,20 @@ class GenericOperandDesc(metaclass=InternedMeta):
                               spread_index=idx, write_stage=self.write_stage)
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
 @final
-class OperandDesc(metaclass=InternedMeta):
+class OperandDesc(Interned):
     """Op operand descriptor"""
-    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
-                 "write_stage")
+    loc_set_before_spread: LocSet
+    tied_input_index: "int | None"
+    spread_index: "int | None"
+    write_stage: "OpStage"
 
-    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
-                 write_stage):
-        # type: (LocSet, int | None, int | None, OpStage) -> None
-        if len(loc_set_before_spread) == 0:
+    def __post_init__(self):
+        if len(self.loc_set_before_spread) == 0:
             raise ValueError("loc_set_before_spread must not be empty")
-        self.loc_set_before_spread = loc_set_before_spread
-        self.tied_input_index = tied_input_index
-        if self.tied_input_index is not None and spread_index is not None:
+        if self.tied_input_index is not None and self.spread_index is not None:
             raise ValueError("operand can't be both spread and tied")
-        self.spread_index = spread_index
-        self.write_stage = write_stage
 
     @cached_property
     def ty_before_spread(self):
@@ -1109,11 +1112,16 @@ OD_VL = GenericOperandDesc(
     sub_kinds=[LocSubKind.VL_MAXVL])
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@dataclass(frozen=True, unsafe_hash=True)
 @final
-class GenericOpProperties(metaclass=InternedMeta):
-    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
-                 "is_copy", "is_load_immediate", "has_side_effects")
+class GenericOpProperties(Interned):
+    demo_asm: str
+    inputs: "tuple[GenericOperandDesc, ...]"
+    outputs: "tuple[GenericOperandDesc, ...]"
+    immediates: "tuple[range, ...]"
+    is_copy: bool
+    is_load_immediate: bool
+    has_side_effects: bool
 
     def __init__(
         self, demo_asm,  # type: str
@@ -1125,8 +1133,8 @@ class GenericOpProperties(metaclass=InternedMeta):
         has_side_effects=False,  # type: bool
     ):
         # type: (...) -> None
-        self.demo_asm = demo_asm  # type: str
-        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
+        object.__setattr__(self, "demo_asm", demo_asm)
+        object.__setattr__(self, "inputs", tuple(inputs))
         for inp in self.inputs:
             if inp.tied_input_index is not None:
                 raise ValueError(
@@ -1134,7 +1142,7 @@ class GenericOpProperties(metaclass=InternedMeta):
             if inp.write_stage is not OpStage.Early:
                 raise ValueError(
                     f"write_stage is not allowed on inputs: {inp}")
-        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
+        object.__setattr__(self, "outputs", tuple(outputs))
         fixed_locs = []  # type: list[tuple[Loc, int]]
         for idx, out in enumerate(self.outputs):
             if out.tied_input_index is not None:
@@ -1155,15 +1163,15 @@ class GenericOpProperties(metaclass=InternedMeta):
                         f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                         f"with {other_fixed_loc}")
                 fixed_locs.append((out.fixed_loc, idx))
-        self.immediates = tuple(immediates)  # type: tuple[range, ...]
-        self.is_copy = is_copy  # type: bool
-        self.is_load_immediate = is_load_immediate  # type: bool
-        self.has_side_effects = has_side_effects  # type: bool
+        object.__setattr__(self, "immediates", tuple(immediates))
+        object.__setattr__(self, "is_copy", is_copy)
+        object.__setattr__(self, "is_load_immediate", is_load_immediate)
+        object.__setattr__(self, "has_side_effects", has_side_effects)
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data.plain_data(frozen=True, unsafe_hash=True)
 @final
-class OpProperties(metaclass=InternedMeta):
+class OpProperties:
     __slots__ = "kind", "inputs", "outputs", "maxvl", "copy_reg_len"
 
     def __init__(self, kind, maxvl):
@@ -1886,17 +1894,15 @@ class OpKind(Enum):
     _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-class SSAValOrUse(metaclass=InternedMeta):
-    __slots__ = "op", "operand_idx"
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
+class SSAValOrUse(Interned):
+    op: "Op"
+    operand_idx: int
 
-    def __init__(self, op, operand_idx):
-        # type: (Op, int) -> None
-        super().__init__()
-        self.op = op
-        if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
+    def __post_init__(self):
+        if self.operand_idx < 0 or \
+                self.operand_idx >= len(self.descriptor_array):
             raise ValueError("invalid operand_idx")
-        self.operand_idx = operand_idx
 
     @abstractmethod
     def __repr__(self):
@@ -1951,7 +1957,7 @@ class SSAValOrUse(metaclass=InternedMeta):
         return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAVal(SSAValOrUse):
     __slots__ = ()
@@ -2002,7 +2008,7 @@ class SSAVal(SSAValOrUse):
         return tuple(SSAValSubReg(self, i) for i in range(self.ty.reg_len))
 
 
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@dataclass(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAUse(SSAValOrUse):
     __slots__ = ()
@@ -2170,7 +2176,7 @@ class OpImmediates(OpInputSeq[int, range]):
         super().__init__(items, op)
 
 
-@plain_data(frozen=True, eq=False, repr=False)
+@plain_data.plain_data(frozen=True, eq=False, repr=False)
 @final
 class Op:
     __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
@@ -2301,7 +2307,7 @@ class Op:
         self.kind.gen_asm(self, state)
 
 
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
 class BaseSimState(metaclass=ABCMeta):
     __slots__ = "memory",
 
@@ -2378,7 +2384,7 @@ class BaseSimState(metaclass=ABCMeta):
     def __repr__(self):
         # type: () -> str
         field_vals = []  # type: list[str]
-        for name in fields(self):
+        for name in plain_data.fields(self):
             try:
                 value = getattr(self, name)
             except AttributeError:
@@ -2403,7 +2409,7 @@ class BaseSimState(metaclass=ABCMeta):
         ...
 
 
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
 class PreRABaseSimState(BaseSimState):
     __slots__ = "ssa_vals",
 
@@ -2459,7 +2465,7 @@ class SimSkipOp(Exception):
     pass
 
 
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
 @final
 class ConstPropagationState(PreRABaseSimState):
     __slots__ = "skipped_ops",
@@ -2482,7 +2488,7 @@ class ConstPropagationState(PreRABaseSimState):
         self.skipped_ops.add(op)
 
 
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
 class PreRASimState(PreRABaseSimState):
     __slots__ = ()
 
@@ -2517,7 +2523,7 @@ class PreRASimState(PreRABaseSimState):
         return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
 
 
-@plain_data(frozen=True, repr=False)
+@plain_data.plain_data(frozen=True, repr=False)
 @final
 class PostRASimState(BaseSimState):
     __slots__ = "ssa_val_to_loc_map", "loc_values"
@@ -2569,7 +2575,7 @@ class PostRASimState(BaseSimState):
             self.loc_values[subloc] = value[i]
 
 
-@plain_data(frozen=True)
+@plain_data.plain_data(frozen=True)
 class GenAsmState:
     __slots__ = "allocated_locs", "output"