+from collections import defaultdict
import enum
from enum import Enum, unique
-from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
from nmutil.plain_data import plain_data
-from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
-from weakref import WeakValueDictionary
+from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
+ OSet, FMap)
+from functools import lru_cache
@final
class Fn:
def __init__(self):
self.ops = [] # type: list[Op]
- op_names = WeakValueDictionary()
- self.__op_names = op_names # type: WeakValueDictionary[str, Op]
+ self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op]
self.__next_name_suffix = 2
def _add_op_with_unused_name(self, op, name=""):
self.__next_name_suffix += 1
def __repr__(self):
+ # type: () -> str
return "<Fn>"
@unique
@final
-class RegKind(Enum):
- GPR = enum.auto()
+class BaseTy(Enum):
+ I64 = enum.auto()
CA = enum.auto()
VL_MAXVL = enum.auto()
@cached_property
def only_scalar(self):
- if self is RegKind.GPR:
+ # type: () -> bool
+ if self is BaseTy.I64:
return False
- elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
return True
else:
assert_never(self)
@cached_property
- def reg_count(self):
- if self is RegKind.GPR:
+ def max_reg_len(self):
+ # type: () -> int
+ if self is BaseTy.I64:
return 128
- elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
return 1
else:
assert_never(self)
def __repr__(self):
- return "RegKind." + self._name_
+ return "BaseTy." + self._name_
@plain_data(frozen=True, unsafe_hash=True)
@final
-class OperandType:
- __slots__ = "kind", "vec"
+class Ty:
+ __slots__ = "base_ty", "reg_len"
- def __init__(self, kind, vec):
- # type: (RegKind, bool) -> None
- self.kind = kind
- if kind.only_scalar and vec:
- raise ValueError(f"kind={kind} must have vec=False")
- self.vec = vec
-
- def get_length(self, maxvl):
- # type: (int) -> int
- # here's where subvl and elwid would be accounted for
- if self.vec:
- return maxvl
- return 1
+ @staticmethod
+ def validate(base_ty, reg_len):
+ # type: (BaseTy, int) -> str | None
+ """ return a string with the error if the combination is invalid,
+ otherwise return None
+ """
+ if base_ty.only_scalar and reg_len != 1:
+ return f"can't create a vector of an only-scalar type: {base_ty}"
+ if reg_len < 1 or reg_len > base_ty.max_reg_len:
+ return "reg_len out of range"
+ return None
+
+ def __init__(self, base_ty, reg_len):
+ # type: (BaseTy, int) -> None
+ msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.base_ty = base_ty
+ self.reg_len = reg_len
-@plain_data(frozen=True, unsafe_hash=True)
+@unique
@final
-class RegShape:
- __slots__ = "kind", "length"
+class LocKind(Enum):
+ GPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
- def __init__(self, kind, length=1):
- # type: (RegKind, int) -> None
- self.kind = kind
- if length < 1 or length > kind.reg_count:
- raise ValueError("invalid length")
- self.length = length
+ @cached_property
+ def base_ty(self):
+ # type: () -> BaseTy
+ if self is LocKind.GPR or self is LocKind.StackI64:
+ return BaseTy.I64
+ if self is LocKind.CA:
+ return BaseTy.CA
+ if self is LocKind.VL_MAXVL:
+ return BaseTy.VL_MAXVL
+ else:
+ assert_never(self)
- def try_concat(self, *others):
- # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
- kind = self.kind
- length = self.length
- for other in others:
- if isinstance(other, (Reg, RegClass)):
- other = other.shape
- if other is None:
- return None
- if other.kind != self.kind:
- return None
- length += other.length
- if length > kind.reg_count:
- return None
- return RegShape(kind=kind, length=length)
+ @cached_property
+ def loc_count(self):
+ # type: () -> int
+ if self is LocKind.StackI64:
+ return 1024
+ if self is LocKind.GPR or self is LocKind.CA \
+ or self is LocKind.VL_MAXVL:
+ return self.base_ty.max_reg_len
+ else:
+ assert_never(self)
+
+ def __repr__(self):
+ return "LocKind." + self._name_
-@plain_data(frozen=True, unsafe_hash=True)
@final
-class Reg:
- __slots__ = "shape", "start"
-
- def __init__(self, shape, start):
- # type: (RegShape, int) -> None
- self.shape = shape
- if start < 0 or start + shape.length > shape.kind.reg_count:
- raise ValueError("start not in valid range")
- self.start = start
+@unique
+class LocSubKind(Enum):
+ BASE_GPR = enum.auto()
+ SV_EXTRA2_VGPR = enum.auto()
+ SV_EXTRA2_SGPR = enum.auto()
+ SV_EXTRA3_VGPR = enum.auto()
+ SV_EXTRA3_SGPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
- @property
+ @cached_property
def kind(self):
- return self.shape.kind
+ # type: () -> LocKind
+ # pyright fails typechecking when using `in` here:
+ # reported: https://github.com/microsoft/pyright/issues/4102
+ if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
+ or self is LocSubKind.SV_EXTRA2_SGPR \
+ or self is LocSubKind.SV_EXTRA3_VGPR \
+ or self is LocSubKind.SV_EXTRA3_SGPR:
+ return LocKind.GPR
+ if self is LocSubKind.StackI64:
+ return LocKind.StackI64
+ if self is LocSubKind.CA:
+ return LocKind.CA
+ if self is LocSubKind.VL_MAXVL:
+ return LocKind.VL_MAXVL
+ assert_never(self)
@property
- def length(self):
- return self.shape.length
+ def base_ty(self):
+ return self.kind.base_ty
+
+ @lru_cache()
+ def allocatable_locs(self, ty):
+ # type: (Ty) -> LocSet
+ if ty.base_ty != self.base_ty:
+ raise ValueError("type mismatch")
+ raise NotImplementedError # FIXME: finish
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericTy:
+ __slots__ = "base_ty", "is_vec"
+
+ def __init__(self, base_ty, is_vec):
+ # type: (BaseTy, bool) -> None
+ self.base_ty = base_ty
+ if base_ty.only_scalar and is_vec:
+ raise ValueError(f"base_ty={base_ty} requires is_vec=False")
+ self.is_vec = is_vec
+
+ def instantiate(self, maxvl):
+ # type: (int) -> Ty
+ # here's where subvl and elwid would be accounted for
+ if self.is_vec:
+ return Ty(self.base_ty, maxvl)
+ return Ty(self.base_ty, 1)
+
+ def can_instantiate_to(self, ty):
+ # type: (Ty) -> bool
+ if self.base_ty != ty.base_ty:
+ return False
+ if self.is_vec:
+ return True
+ return ty.reg_len == 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Loc:
+ __slots__ = "kind", "start", "reg_len"
+
+ @staticmethod
+ def validate(kind, start, reg_len):
+ # type: (LocKind, int, int) -> str | None
+ msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
+ if msg is not None:
+ return msg
+ if reg_len > kind.loc_count:
+ return "invalid reg_len"
+ if start < 0 or start + reg_len > kind.loc_count:
+ return "start not in valid range"
+ return None
+
+ @staticmethod
+ def try_make(kind, start, reg_len):
+ # type: (LocKind, int, int) -> Loc | None
+ msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is None:
+ return None
+ return Loc(kind=kind, start=start, reg_len=reg_len)
+
+ def __init__(self, kind, start, reg_len):
+ # type: (LocKind, int, int) -> None
+ msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.kind = kind
+ self.reg_len = reg_len
+ self.start = start
def conflicts(self, other):
- # type: (Reg) -> bool
- return (self.kind == other.kind
+ # type: (Loc) -> bool
+ return (self.kind != other.kind
and self.start < other.stop and other.start < self.stop)
+ @staticmethod
+ def make_ty(kind, reg_len):
+ # type: (LocKind, int) -> Ty
+ return Ty(base_ty=kind.base_ty, reg_len=reg_len)
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ return self.make_ty(kind=self.kind, reg_len=self.reg_len)
+
@property
def stop(self):
- return self.start + self.length
+ # type: () -> int
+ return self.start + self.reg_len
def try_concat(self, *others):
- # type: (*Reg | None) -> Reg | None
- shape = self.shape.try_concat(*others)
- if shape is None:
- return None
+ # type: (*Loc | None) -> Loc | None
+ reg_len = self.reg_len
stop = self.stop
for other in others:
- assert other is not None, "already caught by RegShape.try_concat"
+ if other is None or other.kind != self.kind:
+ return None
if stop != other.start:
return None
stop = other.stop
- return Reg(shape, self.start)
+ reg_len += other.reg_len
+ return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
+@plain_data(frozen=True, eq=False, repr=False)
@final
-class RegClass(AbstractSet[Reg]):
- def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
- # type: (Iterable[Reg | int], RegShape | None, int) -> None
- for reg_or_start in regs_or_starts:
- if isinstance(reg_or_start, Reg):
- if shape is None:
- shape = reg_or_start.shape
- elif shape != reg_or_start.shape:
- raise ValueError(f"conflicting RegShapes: {shape} and "
- f"{reg_or_start.shape}")
- start = reg_or_start.start
- else:
- start = reg_or_start
- if start < 0:
- raise ValueError("a Reg's start is out of range")
- starts_bitset |= 1 << start
- if starts_bitset == 0:
- shape = None
- self.__shape = shape
- self.__starts_bitset = starts_bitset
- if shape is None:
- if starts_bitset != 0:
- raise ValueError("non-empty RegClass must have non-None shape")
+class LocSet(AbstractSet[Loc]):
+ __slots__ = "starts", "ty"
+
+ def __init__(self, __locs=()):
+ # type: (Iterable[Loc]) -> None
+ if isinstance(__locs, LocSet):
+ self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
+ self.ty = __locs.ty # type: Ty | None
return
- if self.stops_bitset >= 1 << shape.kind.reg_count:
- raise ValueError("a Reg's start is out of range")
-
- @property
- def shape(self):
- # type: () -> RegShape | None
- return self.__shape
-
- @property
- def starts_bitset(self):
- # type: () -> int
- return self.__starts_bitset
-
- @property
- def stops_bitset(self):
- # type: () -> int
- if self.__shape is None:
- return 0
- return self.__starts_bitset << self.__shape.length
-
- @cached_property
- def starts(self):
- # type: () -> OFSet[int]
- if self.length is None:
- return OFSet()
- # TODO: fixme
- # return OFSet(for i in range(self.length))
+ starts = {i: BitSet() for i in LocKind}
+ ty = None
+ for loc in __locs:
+ if ty is None:
+ ty = loc.ty
+ if ty != loc.ty:
+ raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+ starts[loc.kind].add(loc.start)
+ self.starts = FMap(
+ (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
+ self.ty = ty
@cached_property
def stops(self):
- # type: () -> OFSet[int]
- if self.__shape is None:
- return OFSet()
- return OFSet(i + self.__shape.length for i in self.__starts)
+ # type: () -> FMap[LocKind, FBitSet]
+ if self.ty is None:
+ return FMap()
+ sh = self.ty.reg_len
+ return FMap(
+ (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
@property
- def kind(self):
- if self.__shape is None:
+ def kinds(self):
+ # type: () -> AbstractSet[LocKind]
+ return self.starts.keys()
+
+ @property
+ def reg_len(self):
+ # type: () -> int | None
+ if self.ty is None:
return None
- return self.__shape.kind
+ return self.ty.reg_len
@property
- def length(self):
- """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
- if self.__shape is None:
+ def base_ty(self):
+ # type: () -> BaseTy | None
+ if self.ty is None:
return None
- return self.__shape.length
+ return self.ty.base_ty
def concat(self, *others):
- # type: (*RegClass) -> RegClass
- shape = self.__shape
- if shape is None:
- return RegClass()
- shape = shape.try_concat(*others)
- if shape is None:
- return RegClass()
- starts = OSet(self.starts)
- offset = shape.length
+ # type: (*LocSet) -> LocSet
+ if self.ty is None:
+ return LocSet()
+ base_ty = self.ty.base_ty
+ reg_len = self.ty.reg_len
+ starts = {k: BitSet(v) for k, v in self.starts.items()}
for other in others:
- assert other.__shape is not None, \
- "already caught by RegShape.try_concat"
- starts &= OSet(i - offset for i in other.starts)
- offset += other.__shape.length
- return RegClass(starts, shape=shape)
-
- def __contains__(self, reg):
- # type: (Reg) -> bool
- return reg.shape == self.shape and reg.start in self.starts
+ if other.ty is None:
+ return LocSet()
+ if other.ty.base_ty != base_ty:
+ return LocSet()
+ for kind, other_starts in other.starts.items():
+ if kind not in starts:
+ continue
+ starts[kind].bits &= other_starts.bits >> reg_len
+ if starts[kind] == 0:
+ del starts[kind]
+ if len(starts) == 0:
+ return LocSet()
+ reg_len += other.ty.reg_len
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ for kind, v in starts.items():
+ for start in v:
+ loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
+ if loc is not None:
+ yield loc
+ return LocSet(locs())
+
+ def __contains__(self, loc):
+ # type: (Loc | Any) -> bool
+ if not isinstance(loc, Loc) or loc.ty == self.ty:
+ return False
+ if loc.kind not in self.starts:
+ return False
+ return loc.start in self.starts[loc.kind]
def __iter__(self):
- # type: () -> Iterator[Reg]
- if self.shape is None:
+ # type: () -> Iterator[Loc]
+ if self.ty is None:
return
- for start in self.starts:
- yield Reg(shape=self.shape, start=start)
+ for kind, starts in self.starts.items():
+ for start in starts:
+ yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
+
+ @cached_property
+ def __len(self):
+ return sum((len(v) for v in self.starts.values()), 0)
def __len__(self):
- return len(self.starts)
+ return self.__len
- def __hash__(self):
+ @cached_property
+ def __hash(self):
return super()._hash()
+ def __hash__(self):
+ return self.__hash
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class Operand:
- __slots__ = "ty", "regs"
-
- def __init__(self, ty, regs=None):
- # type: (OperandType, OFSet[int] | None) -> None
- pass
-
-
-OT_VGPR = OperandType(RegKind.GPR, vec=True)
-OT_SGPR = OperandType(RegKind.GPR, vec=False)
-OT_CA = OperandType(RegKind.CA, vec=False)
-OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+class GenericOperandDesc:
+ """generic Op operand descriptor"""
+ __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index"
+
+ def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None):
+ # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None
+ self.ty = ty
+ self.sub_kinds = OFSet(sub_kinds)
+ if len(self.sub_kinds) == 0:
+ raise ValueError("sub_kinds can't be empty")
+ self.fixed_loc = fixed_loc
+ if fixed_loc is not None:
+ if tied_input_index is not None:
+ raise ValueError("operand can't be both tied and fixed")
+ if not ty.can_instantiate_to(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc has incompatible type for given generic "
+ f"type: fixed_loc={fixed_loc} generic ty={ty}")
+ if len(self.sub_kinds) != 1:
+ raise ValueError(
+ "multiple sub_kinds not allowed for fixed operand")
+ for sub_kind in self.sub_kinds:
+ if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc not in given sub_kind: "
+ f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
+ for sub_kind in self.sub_kinds:
+ if sub_kind.base_ty != ty.base_ty:
+ raise ValueError(f"sub_kind is incompatible with type: "
+ f"sub_kind={sub_kind} ty={ty}")
+ if tied_input_index is not None and tied_input_index < 0:
+ raise ValueError("invalid tied_input_index")
+ self.tied_input_index = tied_input_index
+
+ def tied_to_input(self, tied_input_index):
+ # type: (int) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds,
+ tied_input_index=tied_input_index)
+
+ def with_fixed_loc(self, fixed_loc):
+ # type: (Loc) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+
+ def instantiate(self, maxvl):
+ # type: (int) -> OperandDesc
+ ty = self.ty.instantiate(maxvl=maxvl)
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ if self.fixed_loc is not None:
+ if ty != self.fixed_loc.ty:
+ raise ValueError(
+ f"instantiation failed: type mismatch with fixed_loc: "
+ f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+ yield self.fixed_loc
+ return
+ for sub_kind in self.sub_kinds:
+ yield from sub_kind.allocatable_locs(ty)
+ return OperandDesc(loc_set=LocSet(locs()),
+ tied_input_index=self.tied_input_index)
@plain_data(frozen=True, unsafe_hash=True)
-class TiedOutput:
- __slots__ = "input_index", "output_index"
-
- def __init__(self, input_index, output_index):
- # type: (int, int) -> None
- self.input_index = input_index
- self.output_index = output_index
-
-
-Constraint = Union[TiedOutput, NoReturn]
+@final
+class OperandDesc:
+ """Op operand descriptor"""
+ __slots__ = "loc_set", "tied_input_index"
+
+ def __init__(self, loc_set, tied_input_index):
+ # type: (LocSet, int | None) -> None
+ if len(loc_set) == 0:
+ raise ValueError("loc_set must not be empty")
+ self.loc_set = loc_set
+ self.tied_input_index = tied_input_index
+
+
+OD_BASE_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.BASE_GPR])
+OD_EXTRA3_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
+OD_EXTRA3_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
+OD_EXTRA2_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
+OD_EXTRA2_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
+OD_CA = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
+ sub_kinds=[LocSubKind.CA])
+OD_VL = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
+ sub_kinds=[LocSubKind.VL_MAXVL])
@plain_data(frozen=True, unsafe_hash=True)
@final
-class OpProperties:
- __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+class GenericOpProperties:
+ __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
"is_copy", "is_load_immediate", "has_side_effects")
def __init__(self, demo_asm, # type: str
- inputs, # type: Iterable[OperandType]
- outputs, # type: Iterable[OperandType]
+ inputs, # type: Iterable[GenericOperandDesc]
+ outputs, # type: Iterable[GenericOperandDesc]
immediates, # type: Iterable[range]
- constraints, # type: Iterable[Constraint]
is_copy=False, # type: bool
is_load_immediate=False, # type: bool
has_side_effects=False, # type: bool
self.inputs = tuple(inputs)
self.outputs = tuple(outputs)
self.immediates = tuple(immediates)
- self.constraints = tuple(constraints)
self.is_copy = is_copy
self.is_load_immediate = is_load_immediate
self.has_side_effects = has_side_effects
+ def instantiate(self, maxvl):
+ # type: (int) -> OpProperties
+ raise NotImplementedError # FIXME: finish
+
+
+# FIXME: add OpProperties
@unique
@final
class OpKind(Enum):
def __init__(self, properties):
- # type: (OpProperties) -> None
+ # type: (GenericOpProperties) -> None
super().__init__()
self.properties = properties
self.sliced_op_outputs = tuple(processed)
def __add__(self, other):
- # type: (SSAVal) -> SSAVal
+ # type: (SSAVal | Any) -> SSAVal
if not isinstance(other, SSAVal):
return NotImplemented
return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
def __radd__(self, other):
- # type: (SSAVal) -> SSAVal
+ # type: (SSAVal | Any) -> SSAVal
if isinstance(other, SSAVal):
return other.__add__(self)
return NotImplemented
@cached_property
def expanded_sliced_op_outputs(self):
# type: () -> tuple[tuple[Op, int, int], ...]
- retval = []
+ retval = [] # type: list[tuple[Op, int, int]]
for op, output_index, range_ in self.sliced_op_outputs:
for i in range_:
retval.append((op, output_index, i))
# type: () -> str
if len(self.sliced_op_outputs) == 0:
return "SSAVal([])"
- parts = []
+ parts = [] # type: list[str]
for op, output_index, range_ in self.sliced_op_outputs:
out_len = op.properties.outputs[output_index].get_length(op.maxvl)
parts.append(f"<{op.name}#{output_index}>")
self.maxvl = maxvl
outputs_len = len(self.properties.outputs)
self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
- self.name = fn._add_op_with_unused_name(self, name)
+ self.name = fn._add_op_with_unused_name(self, name) # type: ignore
@property
def properties(self):
return self.kind.properties
def __eq__(self, other):
+ # type: (Op | Any) -> bool
if isinstance(other, Op):
return self is other
return NotImplemented