From 861dc2996a6a9e6feb25ca384a8fa0d44982d80e Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 28 Oct 2022 02:24:23 -0700 Subject: [PATCH] working on rewriting compiler ir to fix reg alloc issues --- .../_tests/__init__.py | 0 .../{ => _tests}/test_compiler_ir.py | 0 .../{ => _tests}/test_matrix.py | 0 .../{ => _tests}/test_register_allocator.py | 0 .../{ => _tests}/test_toom_cook.py | 0 src/bigint_presentation_code/compiler_ir.py | 4 +- src/bigint_presentation_code/compiler_ir2.py | 583 ++++++++++++------ src/bigint_presentation_code/matrix.py | 19 +- src/bigint_presentation_code/py.typed | 0 .../register_allocator.py | 3 +- src/bigint_presentation_code/toom_cook.py | 15 +- src/bigint_presentation_code/type_util.py | 32 + src/bigint_presentation_code/type_util.pyi | 19 + src/bigint_presentation_code/util.py | 81 +-- src/bigint_presentation_code/util.pyi | 190 ------ typings/cached_property.pyi | 16 +- 16 files changed, 510 insertions(+), 452 deletions(-) create mode 100644 src/bigint_presentation_code/_tests/__init__.py rename src/bigint_presentation_code/{ => _tests}/test_compiler_ir.py (100%) rename src/bigint_presentation_code/{ => _tests}/test_matrix.py (100%) rename src/bigint_presentation_code/{ => _tests}/test_register_allocator.py (100%) rename src/bigint_presentation_code/{ => _tests}/test_toom_cook.py (100%) create mode 100644 src/bigint_presentation_code/py.typed create mode 100644 src/bigint_presentation_code/type_util.py create mode 100644 src/bigint_presentation_code/type_util.pyi delete mode 100644 src/bigint_presentation_code/util.pyi diff --git a/src/bigint_presentation_code/_tests/__init__.py b/src/bigint_presentation_code/_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py similarity index 100% rename from src/bigint_presentation_code/test_compiler_ir.py rename to src/bigint_presentation_code/_tests/test_compiler_ir.py diff --git a/src/bigint_presentation_code/test_matrix.py b/src/bigint_presentation_code/_tests/test_matrix.py similarity index 100% rename from src/bigint_presentation_code/test_matrix.py rename to src/bigint_presentation_code/_tests/test_matrix.py diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py similarity index 100% rename from src/bigint_presentation_code/test_register_allocator.py rename to src/bigint_presentation_code/_tests/test_register_allocator.py diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py similarity index 100% rename from src/bigint_presentation_code/test_toom_cook.py rename to src/bigint_presentation_code/_tests/test_toom_cook.py diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 77e44a2..c574174 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -1,3 +1,4 @@ +# type: ignore """ Compiler IR for Toom-Cook algorithm generator for SVP64 @@ -12,7 +13,8 @@ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.util import FMap, OFSet, OSet, final +from bigint_presentation_code.type_util import final +from bigint_presentation_code.util import FMap, OFSet, OSet class ABCEnumMeta(EnumMeta, ABCMeta): diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index eacceb4..666df14 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -1,20 +1,23 @@ +from collections import defaultdict import enum from enum import Enum, unique -from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload +from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload +from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final -from weakref import WeakValueDictionary +from bigint_presentation_code.type_util import Self, assert_never, final +from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet, + OSet, FMap) +from functools import lru_cache @final class Fn: def __init__(self): self.ops = [] # type: list[Op] - op_names = WeakValueDictionary() - self.__op_names = op_names # type: WeakValueDictionary[str, Op] + self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op] self.__next_name_suffix = 2 def _add_op_with_unused_name(self, op, name=""): @@ -32,278 +35,464 @@ class Fn: self.__next_name_suffix += 1 def __repr__(self): + # type: () -> str return "" @unique @final -class RegKind(Enum): - GPR = enum.auto() +class BaseTy(Enum): + I64 = enum.auto() CA = enum.auto() VL_MAXVL = enum.auto() @cached_property def only_scalar(self): - if self is RegKind.GPR: + # type: () -> bool + if self is BaseTy.I64: return False - elif self is RegKind.CA or self is RegKind.VL_MAXVL: + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: return True else: assert_never(self) @cached_property - def reg_count(self): - if self is RegKind.GPR: + def max_reg_len(self): + # type: () -> int + if self is BaseTy.I64: return 128 - elif self is RegKind.CA or self is RegKind.VL_MAXVL: + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: return 1 else: assert_never(self) def __repr__(self): - return "RegKind." + self._name_ + return "BaseTy." + self._name_ @plain_data(frozen=True, unsafe_hash=True) @final -class OperandType: - __slots__ = "kind", "vec" +class Ty: + __slots__ = "base_ty", "reg_len" - def __init__(self, kind, vec): - # type: (RegKind, bool) -> None - self.kind = kind - if kind.only_scalar and vec: - raise ValueError(f"kind={kind} must have vec=False") - self.vec = vec - - def get_length(self, maxvl): - # type: (int) -> int - # here's where subvl and elwid would be accounted for - if self.vec: - return maxvl - return 1 + @staticmethod + def validate(base_ty, reg_len): + # type: (BaseTy, int) -> str | None + """ return a string with the error if the combination is invalid, + otherwise return None + """ + if base_ty.only_scalar and reg_len != 1: + return f"can't create a vector of an only-scalar type: {base_ty}" + if reg_len < 1 or reg_len > base_ty.max_reg_len: + return "reg_len out of range" + return None + + def __init__(self, base_ty, reg_len): + # type: (BaseTy, int) -> None + msg = self.validate(base_ty=base_ty, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.base_ty = base_ty + self.reg_len = reg_len -@plain_data(frozen=True, unsafe_hash=True) +@unique @final -class RegShape: - __slots__ = "kind", "length" +class LocKind(Enum): + GPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() - def __init__(self, kind, length=1): - # type: (RegKind, int) -> None - self.kind = kind - if length < 1 or length > kind.reg_count: - raise ValueError("invalid length") - self.length = length + @cached_property + def base_ty(self): + # type: () -> BaseTy + if self is LocKind.GPR or self is LocKind.StackI64: + return BaseTy.I64 + if self is LocKind.CA: + return BaseTy.CA + if self is LocKind.VL_MAXVL: + return BaseTy.VL_MAXVL + else: + assert_never(self) - def try_concat(self, *others): - # type: (*RegShape | Reg | RegClass | None) -> RegShape | None - kind = self.kind - length = self.length - for other in others: - if isinstance(other, (Reg, RegClass)): - other = other.shape - if other is None: - return None - if other.kind != self.kind: - return None - length += other.length - if length > kind.reg_count: - return None - return RegShape(kind=kind, length=length) + @cached_property + def loc_count(self): + # type: () -> int + if self is LocKind.StackI64: + return 1024 + if self is LocKind.GPR or self is LocKind.CA \ + or self is LocKind.VL_MAXVL: + return self.base_ty.max_reg_len + else: + assert_never(self) + + def __repr__(self): + return "LocKind." + self._name_ -@plain_data(frozen=True, unsafe_hash=True) @final -class Reg: - __slots__ = "shape", "start" - - def __init__(self, shape, start): - # type: (RegShape, int) -> None - self.shape = shape - if start < 0 or start + shape.length > shape.kind.reg_count: - raise ValueError("start not in valid range") - self.start = start +@unique +class LocSubKind(Enum): + BASE_GPR = enum.auto() + SV_EXTRA2_VGPR = enum.auto() + SV_EXTRA2_SGPR = enum.auto() + SV_EXTRA3_VGPR = enum.auto() + SV_EXTRA3_SGPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() - @property + @cached_property def kind(self): - return self.shape.kind + # type: () -> LocKind + # pyright fails typechecking when using `in` here: + # reported: https://github.com/microsoft/pyright/issues/4102 + if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \ + or self is LocSubKind.SV_EXTRA2_SGPR \ + or self is LocSubKind.SV_EXTRA3_VGPR \ + or self is LocSubKind.SV_EXTRA3_SGPR: + return LocKind.GPR + if self is LocSubKind.StackI64: + return LocKind.StackI64 + if self is LocSubKind.CA: + return LocKind.CA + if self is LocSubKind.VL_MAXVL: + return LocKind.VL_MAXVL + assert_never(self) @property - def length(self): - return self.shape.length + def base_ty(self): + return self.kind.base_ty + + @lru_cache() + def allocatable_locs(self, ty): + # type: (Ty) -> LocSet + if ty.base_ty != self.base_ty: + raise ValueError("type mismatch") + raise NotImplementedError # FIXME: finish + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GenericTy: + __slots__ = "base_ty", "is_vec" + + def __init__(self, base_ty, is_vec): + # type: (BaseTy, bool) -> None + self.base_ty = base_ty + if base_ty.only_scalar and is_vec: + raise ValueError(f"base_ty={base_ty} requires is_vec=False") + self.is_vec = is_vec + + def instantiate(self, maxvl): + # type: (int) -> Ty + # here's where subvl and elwid would be accounted for + if self.is_vec: + return Ty(self.base_ty, maxvl) + return Ty(self.base_ty, 1) + + def can_instantiate_to(self, ty): + # type: (Ty) -> bool + if self.base_ty != ty.base_ty: + return False + if self.is_vec: + return True + return ty.reg_len == 1 + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class Loc: + __slots__ = "kind", "start", "reg_len" + + @staticmethod + def validate(kind, start, reg_len): + # type: (LocKind, int, int) -> str | None + msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len) + if msg is not None: + return msg + if reg_len > kind.loc_count: + return "invalid reg_len" + if start < 0 or start + reg_len > kind.loc_count: + return "start not in valid range" + return None + + @staticmethod + def try_make(kind, start, reg_len): + # type: (LocKind, int, int) -> Loc | None + msg = Loc.validate(kind=kind, start=start, reg_len=reg_len) + if msg is None: + return None + return Loc(kind=kind, start=start, reg_len=reg_len) + + def __init__(self, kind, start, reg_len): + # type: (LocKind, int, int) -> None + msg = self.validate(kind=kind, start=start, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.kind = kind + self.reg_len = reg_len + self.start = start def conflicts(self, other): - # type: (Reg) -> bool - return (self.kind == other.kind + # type: (Loc) -> bool + return (self.kind != other.kind and self.start < other.stop and other.start < self.stop) + @staticmethod + def make_ty(kind, reg_len): + # type: (LocKind, int) -> Ty + return Ty(base_ty=kind.base_ty, reg_len=reg_len) + + @cached_property + def ty(self): + # type: () -> Ty + return self.make_ty(kind=self.kind, reg_len=self.reg_len) + @property def stop(self): - return self.start + self.length + # type: () -> int + return self.start + self.reg_len def try_concat(self, *others): - # type: (*Reg | None) -> Reg | None - shape = self.shape.try_concat(*others) - if shape is None: - return None + # type: (*Loc | None) -> Loc | None + reg_len = self.reg_len stop = self.stop for other in others: - assert other is not None, "already caught by RegShape.try_concat" + if other is None or other.kind != self.kind: + return None if stop != other.start: return None stop = other.stop - return Reg(shape, self.start) + reg_len += other.reg_len + return Loc(kind=self.kind, start=self.start, reg_len=reg_len) +@plain_data(frozen=True, eq=False, repr=False) @final -class RegClass(AbstractSet[Reg]): - def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0): - # type: (Iterable[Reg | int], RegShape | None, int) -> None - for reg_or_start in regs_or_starts: - if isinstance(reg_or_start, Reg): - if shape is None: - shape = reg_or_start.shape - elif shape != reg_or_start.shape: - raise ValueError(f"conflicting RegShapes: {shape} and " - f"{reg_or_start.shape}") - start = reg_or_start.start - else: - start = reg_or_start - if start < 0: - raise ValueError("a Reg's start is out of range") - starts_bitset |= 1 << start - if starts_bitset == 0: - shape = None - self.__shape = shape - self.__starts_bitset = starts_bitset - if shape is None: - if starts_bitset != 0: - raise ValueError("non-empty RegClass must have non-None shape") +class LocSet(AbstractSet[Loc]): + __slots__ = "starts", "ty" + + def __init__(self, __locs=()): + # type: (Iterable[Loc]) -> None + if isinstance(__locs, LocSet): + self.starts = __locs.starts # type: FMap[LocKind, FBitSet] + self.ty = __locs.ty # type: Ty | None return - if self.stops_bitset >= 1 << shape.kind.reg_count: - raise ValueError("a Reg's start is out of range") - - @property - def shape(self): - # type: () -> RegShape | None - return self.__shape - - @property - def starts_bitset(self): - # type: () -> int - return self.__starts_bitset - - @property - def stops_bitset(self): - # type: () -> int - if self.__shape is None: - return 0 - return self.__starts_bitset << self.__shape.length - - @cached_property - def starts(self): - # type: () -> OFSet[int] - if self.length is None: - return OFSet() - # TODO: fixme - # return OFSet(for i in range(self.length)) + starts = {i: BitSet() for i in LocKind} + ty = None + for loc in __locs: + if ty is None: + ty = loc.ty + if ty != loc.ty: + raise ValueError(f"conflicting types: {ty} != {loc.ty}") + starts[loc.kind].add(loc.start) + self.starts = FMap( + (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) + self.ty = ty @cached_property def stops(self): - # type: () -> OFSet[int] - if self.__shape is None: - return OFSet() - return OFSet(i + self.__shape.length for i in self.__starts) + # type: () -> FMap[LocKind, FBitSet] + if self.ty is None: + return FMap() + sh = self.ty.reg_len + return FMap( + (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items()) @property - def kind(self): - if self.__shape is None: + def kinds(self): + # type: () -> AbstractSet[LocKind] + return self.starts.keys() + + @property + def reg_len(self): + # type: () -> int | None + if self.ty is None: return None - return self.__shape.kind + return self.ty.reg_len @property - def length(self): - """length of registers in this RegClass, not to be confused with the number of `Reg`s in self""" - if self.__shape is None: + def base_ty(self): + # type: () -> BaseTy | None + if self.ty is None: return None - return self.__shape.length + return self.ty.base_ty def concat(self, *others): - # type: (*RegClass) -> RegClass - shape = self.__shape - if shape is None: - return RegClass() - shape = shape.try_concat(*others) - if shape is None: - return RegClass() - starts = OSet(self.starts) - offset = shape.length + # type: (*LocSet) -> LocSet + if self.ty is None: + return LocSet() + base_ty = self.ty.base_ty + reg_len = self.ty.reg_len + starts = {k: BitSet(v) for k, v in self.starts.items()} for other in others: - assert other.__shape is not None, \ - "already caught by RegShape.try_concat" - starts &= OSet(i - offset for i in other.starts) - offset += other.__shape.length - return RegClass(starts, shape=shape) - - def __contains__(self, reg): - # type: (Reg) -> bool - return reg.shape == self.shape and reg.start in self.starts + if other.ty is None: + return LocSet() + if other.ty.base_ty != base_ty: + return LocSet() + for kind, other_starts in other.starts.items(): + if kind not in starts: + continue + starts[kind].bits &= other_starts.bits >> reg_len + if starts[kind] == 0: + del starts[kind] + if len(starts) == 0: + return LocSet() + reg_len += other.ty.reg_len + + def locs(): + # type: () -> Iterable[Loc] + for kind, v in starts.items(): + for start in v: + loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len) + if loc is not None: + yield loc + return LocSet(locs()) + + def __contains__(self, loc): + # type: (Loc | Any) -> bool + if not isinstance(loc, Loc) or loc.ty == self.ty: + return False + if loc.kind not in self.starts: + return False + return loc.start in self.starts[loc.kind] def __iter__(self): - # type: () -> Iterator[Reg] - if self.shape is None: + # type: () -> Iterator[Loc] + if self.ty is None: return - for start in self.starts: - yield Reg(shape=self.shape, start=start) + for kind, starts in self.starts.items(): + for start in starts: + yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len) + + @cached_property + def __len(self): + return sum((len(v) for v in self.starts.values()), 0) def __len__(self): - return len(self.starts) + return self.__len - def __hash__(self): + @cached_property + def __hash(self): return super()._hash() + def __hash__(self): + return self.__hash + @plain_data(frozen=True, unsafe_hash=True) @final -class Operand: - __slots__ = "ty", "regs" - - def __init__(self, ty, regs=None): - # type: (OperandType, OFSet[int] | None) -> None - pass - - -OT_VGPR = OperandType(RegKind.GPR, vec=True) -OT_SGPR = OperandType(RegKind.GPR, vec=False) -OT_CA = OperandType(RegKind.CA, vec=False) -OT_VL = OperandType(RegKind.VL_MAXVL, vec=False) +class GenericOperandDesc: + """generic Op operand descriptor""" + __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index" + + def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None): + # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None + self.ty = ty + self.sub_kinds = OFSet(sub_kinds) + if len(self.sub_kinds) == 0: + raise ValueError("sub_kinds can't be empty") + self.fixed_loc = fixed_loc + if fixed_loc is not None: + if tied_input_index is not None: + raise ValueError("operand can't be both tied and fixed") + if not ty.can_instantiate_to(fixed_loc.ty): + raise ValueError( + f"fixed_loc has incompatible type for given generic " + f"type: fixed_loc={fixed_loc} generic ty={ty}") + if len(self.sub_kinds) != 1: + raise ValueError( + "multiple sub_kinds not allowed for fixed operand") + for sub_kind in self.sub_kinds: + if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty): + raise ValueError( + f"fixed_loc not in given sub_kind: " + f"fixed_loc={fixed_loc} sub_kind={sub_kind}") + for sub_kind in self.sub_kinds: + if sub_kind.base_ty != ty.base_ty: + raise ValueError(f"sub_kind is incompatible with type: " + f"sub_kind={sub_kind} ty={ty}") + if tied_input_index is not None and tied_input_index < 0: + raise ValueError("invalid tied_input_index") + self.tied_input_index = tied_input_index + + def tied_to_input(self, tied_input_index): + # type: (int) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, + tied_input_index=tied_input_index) + + def with_fixed_loc(self, fixed_loc): + # type: (Loc) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc) + + def instantiate(self, maxvl): + # type: (int) -> OperandDesc + ty = self.ty.instantiate(maxvl=maxvl) + + def locs(): + # type: () -> Iterable[Loc] + if self.fixed_loc is not None: + if ty != self.fixed_loc.ty: + raise ValueError( + f"instantiation failed: type mismatch with fixed_loc: " + f"instantiated type: {ty} fixed_loc: {self.fixed_loc}") + yield self.fixed_loc + return + for sub_kind in self.sub_kinds: + yield from sub_kind.allocatable_locs(ty) + return OperandDesc(loc_set=LocSet(locs()), + tied_input_index=self.tied_input_index) @plain_data(frozen=True, unsafe_hash=True) -class TiedOutput: - __slots__ = "input_index", "output_index" - - def __init__(self, input_index, output_index): - # type: (int, int) -> None - self.input_index = input_index - self.output_index = output_index - - -Constraint = Union[TiedOutput, NoReturn] +@final +class OperandDesc: + """Op operand descriptor""" + __slots__ = "loc_set", "tied_input_index" + + def __init__(self, loc_set, tied_input_index): + # type: (LocSet, int | None) -> None + if len(loc_set) == 0: + raise ValueError("loc_set must not be empty") + self.loc_set = loc_set + self.tied_input_index = tied_input_index + + +OD_BASE_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.BASE_GPR]) +OD_EXTRA3_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR]) +OD_EXTRA3_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR]) +OD_EXTRA2_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA2_SGPR]) +OD_EXTRA2_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA2_VGPR]) +OD_CA = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.CA, is_vec=False), + sub_kinds=[LocSubKind.CA]) +OD_VL = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False), + sub_kinds=[LocSubKind.VL_MAXVL]) @plain_data(frozen=True, unsafe_hash=True) @final -class OpProperties: - __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints", +class GenericOpProperties: + __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "is_copy", "is_load_immediate", "has_side_effects") def __init__(self, demo_asm, # type: str - inputs, # type: Iterable[OperandType] - outputs, # type: Iterable[OperandType] + inputs, # type: Iterable[GenericOperandDesc] + outputs, # type: Iterable[GenericOperandDesc] immediates, # type: Iterable[range] - constraints, # type: Iterable[Constraint] is_copy=False, # type: bool is_load_immediate=False, # type: bool has_side_effects=False, # type: bool @@ -313,17 +502,22 @@ class OpProperties: self.inputs = tuple(inputs) self.outputs = tuple(outputs) self.immediates = tuple(immediates) - self.constraints = tuple(constraints) self.is_copy = is_copy self.is_load_immediate = is_load_immediate self.has_side_effects = has_side_effects + def instantiate(self, maxvl): + # type: (int) -> OpProperties + raise NotImplementedError # FIXME: finish + + +# FIXME: add OpProperties @unique @final class OpKind(Enum): def __init__(self, properties): - # type: (OpProperties) -> None + # type: (GenericOpProperties) -> None super().__init__() self.properties = properties @@ -451,13 +645,13 @@ class SSAVal: self.sliced_op_outputs = tuple(processed) def __add__(self, other): - # type: (SSAVal) -> SSAVal + # type: (SSAVal | Any) -> SSAVal if not isinstance(other, SSAVal): return NotImplemented return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs) def __radd__(self, other): - # type: (SSAVal) -> SSAVal + # type: (SSAVal | Any) -> SSAVal if isinstance(other, SSAVal): return other.__add__(self) return NotImplemented @@ -465,7 +659,7 @@ class SSAVal: @cached_property def expanded_sliced_op_outputs(self): # type: () -> tuple[tuple[Op, int, int], ...] - retval = [] + retval = [] # type: list[tuple[Op, int, int]] for op, output_index, range_ in self.sliced_op_outputs: for i in range_: retval.append((op, output_index, i)) @@ -490,7 +684,7 @@ class SSAVal: # type: () -> str if len(self.sliced_op_outputs) == 0: return "SSAVal([])" - parts = [] + parts = [] # type: list[str] for op, output_index, range_ in self.sliced_op_outputs: out_len = op.properties.outputs[output_index].get_length(op.maxvl) parts.append(f"<{op.name}#{output_index}>") @@ -513,13 +707,14 @@ class Op: self.maxvl = maxvl outputs_len = len(self.properties.outputs) self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len)) - self.name = fn._add_op_with_unused_name(self, name) + self.name = fn._add_op_with_unused_name(self, name) # type: ignore @property def properties(self): return self.kind.properties def __eq__(self, other): + # type: (Op | Any) -> bool if isinstance(other, Op): return self is other return NotImplemented diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py index 89c3ea2..0674c02 100644 --- a/src/bigint_presentation_code/matrix.py +++ b/src/bigint_presentation_code/matrix.py @@ -1,10 +1,9 @@ -import operator from enum import Enum, unique from fractions import Fraction -from numbers import Rational +import operator from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar -from bigint_presentation_code.util import final +from bigint_presentation_code.type_util import final _T = TypeVar("_T") _T2 = TypeVar("_T2") @@ -103,7 +102,7 @@ class Matrix(Generic[_T]): return retval def __truediv__(self, rhs): - # type: (Rational | int) -> Matrix + # type: (_T | int) -> Matrix[_T] retval = self.copy() for i in self.indexes(): retval[i] /= rhs # type: ignore @@ -128,7 +127,7 @@ class Matrix(Generic[_T]): return lhs.__matmul__(self) def __elementwise_bin_op(self, rhs, op): - # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T] + # type: (Matrix[_T], Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T] if self.height != rhs.height or self.width != rhs.width: raise ValueError( "matrix dimensions must match for element-wise operations") @@ -172,8 +171,8 @@ class Matrix(Generic[_T]): # type: () -> str if self.height == 0 or self.width == 0: return f"Matrix(height={self.height}, width={self.width})" - lines = [] - line = [] + lines = [] # type: list[str] + line = [] # type: list[str] for row in range(self.height): line.clear() for col in range(self.width): @@ -183,16 +182,16 @@ class Matrix(Generic[_T]): else: line.append(repr(el)) lines.append(", ".join(line)) - lines = ",\n ".join(lines) + lines_str = ",\n ".join(lines) element_type = "" if self.element_type is not Fraction: element_type = f"element_type={self.element_type}, " return (f"Matrix(height={self.height}, width={self.width}, " f"{element_type}data=[\n" - f" {lines},\n])") + f" {lines_str},\n])") def __eq__(self, rhs): - # type: (object) -> bool + # type: (Matrix[Any] | Any) -> bool if not isinstance(rhs, Matrix): return NotImplemented return (self.height == rhs.height diff --git a/src/bigint_presentation_code/py.typed b/src/bigint_presentation_code/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index b8269e4..cc794e9 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -12,7 +12,8 @@ from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, RegLoc, RegType, SSAVal) -from bigint_presentation_code.util import OFSet, OSet, final +from bigint_presentation_code.type_util import final +from bigint_presentation_code.util import OFSet, OSet _RegType = TypeVar("_RegType", bound=RegType) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 9e3ec74..246f654 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -4,13 +4,16 @@ Toom-Cook multiplication algorithm generator for SVP64 from abc import abstractmethod from enum import Enum from fractions import Fraction -from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union +from typing import Any, Generic, Iterable, Mapping, TypeVar, Union from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange +from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub, + OpBigIntMulDiv, OpConcat, + OpLI, OpSetCA, OpSetVLImm, + OpSplit, SSAGPRRange) from bigint_presentation_code.matrix import Matrix -from bigint_presentation_code.util import Literal, OSet, final +from bigint_presentation_code.type_util import Literal, final @final @@ -158,8 +161,8 @@ class EvalOpPoly: return f"EvalOpPoly({self.coefficients})" -_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp") -_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp") +_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]") +_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]") @plain_data(frozen=True, unsafe_hash=True) @@ -238,7 +241,7 @@ class EvalOpInput(EvalOp[int, Literal[0]]): __slots__ = () def __init__(self, lhs, rhs=0): - # type: (...) -> None + # type: (int, int) -> None if lhs < 0: raise ValueError("Input part_index (lhs) must be >= 0") if rhs != 0: diff --git a/src/bigint_presentation_code/type_util.py b/src/bigint_presentation_code/type_util.py new file mode 100644 index 0000000..ed7296f --- /dev/null +++ b/src/bigint_presentation_code/type_util.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING, Any, NoReturn, Union + +if TYPE_CHECKING: + from typing_extensions import Literal, Self, final +else: + def final(v): + return v + + class _Literal: + def __getitem__(self, v): + if isinstance(v, tuple): + return Union[tuple(type(i) for i in v)] + return type(v) + + Literal = _Literal() + + Self = Any + + +# pyright currently doesn't like typing_extensions' definition +# -- added to typing in python 3.11 +def assert_never(arg): + # type: (NoReturn) -> NoReturn + raise AssertionError("got to code that's supposed to be unreachable") + + +__all__ = [ + "assert_never", + "final", + "Literal", + "Self", +] diff --git a/src/bigint_presentation_code/type_util.pyi b/src/bigint_presentation_code/type_util.pyi new file mode 100644 index 0000000..630ca20 --- /dev/null +++ b/src/bigint_presentation_code/type_util.pyi @@ -0,0 +1,19 @@ +from typing import NoReturn, TypeVar + +from typing_extensions import Literal, Self, final + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") + + +# pyright currently doesn't like typing_extensions' definition +# -- added to typing in python 3.11 +def assert_never(arg: NoReturn) -> NoReturn: ... + + +__all__ = [ + "assert_never", + "final", + "Literal", + "Self", +] diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index aeea240..4b39787 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -1,50 +1,25 @@ from abc import abstractmethod -from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator, - Mapping, MutableSet, NoReturn, TypeVar, Union) +from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, + TypeVar, overload) -if TYPE_CHECKING: - from typing_extensions import Literal, Self, final -else: - def final(v): - return v - - class _Literal: - def __getitem__(self, v): - if isinstance(v, tuple): - return Union[tuple(type(i) for i in v)] - return type(v) - - Literal = _Literal() - - Self = Any +from bigint_presentation_code.type_util import Self, final _T_co = TypeVar("_T_co", covariant=True) _T = TypeVar("_T") __all__ = [ - "assert_never", "BaseBitSet", "bit_count", "BitSet", "FBitSet", - "final", "FMap", - "Literal", "OFSet", "OSet", - "Self", "top_set_bit_index", "trailing_zero_count", ] -# pyright currently doesn't like typing_extensions' definition -# -- added to typing in python 3.11 -def assert_never(arg): - # type: (NoReturn) -> NoReturn - raise AssertionError("got to code that's supposed to be unreachable") - - class OFSet(AbstractSet[_T_co]): """ ordered frozen set """ __slots__ = "__items", @@ -54,18 +29,23 @@ class OFSet(AbstractSet[_T_co]): self.__items = {v: None for v in items} def __contains__(self, x): + # type: (Any) -> bool return x in self.__items def __iter__(self): + # type: () -> Iterator[_T_co] return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def __hash__(self): + # type: () -> int return self._hash() def __repr__(self): + # type: () -> str if len(self) == 0: return "OFSet()" return f"OFSet({list(self)})" @@ -80,12 +60,15 @@ class OSet(MutableSet[_T]): self.__items = {v: None for v in items} def __contains__(self, x): + # type: (Any) -> bool return x in self.__items def __iter__(self): + # type: () -> Iterator[_T] return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def add(self, value): @@ -97,6 +80,7 @@ class OSet(MutableSet[_T]): self.__items.pop(value, None) def __repr__(self): + # type: () -> str if len(self) == 0: return "OSet()" return f"OSet({list(self)})" @@ -106,6 +90,21 @@ class FMap(Mapping[_T, _T_co]): """ordered frozen hashable mapping""" __slots__ = "__items", "__hash" + @overload + def __init__(self, items): + # type: (Mapping[_T, _T_co]) -> None + ... + + @overload + def __init__(self, items): + # type: (Iterable[tuple[_T, _T_co]]) -> None + ... + + @overload + def __init__(self): + # type: () -> None + ... + def __init__(self, items=()): # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None self.__items = dict(items) # type: dict[_T, _T_co] @@ -120,20 +119,23 @@ class FMap(Mapping[_T, _T_co]): return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def __eq__(self, other): - # type: (object) -> bool + # type: (FMap[Any, Any] | Any) -> bool if isinstance(other, FMap): return self.__items == other.__items return super().__eq__(other) def __hash__(self): + # type: () -> int if self.__hash is None: self.__hash = hash(frozenset(self.items())) return self.__hash def __repr__(self): + # type: () -> str return f"FMap({self.__items})" @@ -153,7 +155,7 @@ def top_set_bit_index(v, default=-1): try: # added in cpython 3.10 - bit_count = int.bit_count # type: ignore[attr] + bit_count = int.bit_count # type: ignore except AttributeError: def bit_count(v): # type: (int) -> int @@ -177,16 +179,20 @@ class BaseBitSet(AbstractSet[int]): def __init__(self, items=(), bits=0): # type: (Iterable[int], int) -> None - for item in items: - if item < 0: - raise ValueError("can't store negative integers") - bits |= 1 << item + if isinstance(items, BaseBitSet): + bits |= items.bits + else: + for item in items: + if item < 0: + raise ValueError("can't store negative integers") + bits |= 1 << item if bits < 0: raise ValueError("can't store an infinite set") self.__bits = bits @property def bits(self): + # type: () -> int return self.__bits @bits.setter @@ -199,6 +205,7 @@ class BaseBitSet(AbstractSet[int]): self.__bits = bits def __contains__(self, x): + # type: (Any) -> bool if isinstance(x, int) and x >= 0: return (1 << x) & self.bits != 0 return False @@ -220,9 +227,11 @@ class BaseBitSet(AbstractSet[int]): bits -= 1 << index def __len__(self): + # type: () -> int return bit_count(self.bits) def __repr__(self): + # type: () -> str if self.bits == 0: return f"{self.__class__.__name__}()" if self.bits > 0xFFFFFFFF and len(self) < 10: @@ -231,7 +240,7 @@ class BaseBitSet(AbstractSet[int]): return f"{self.__class__.__name__}(bits={hex(self.bits)})" def __eq__(self, other): - # type: (object) -> bool + # type: (Any) -> bool if not isinstance(other, BaseBitSet): return super().__eq__(other) return self.bits == other.bits @@ -320,6 +329,7 @@ class BitSet(BaseBitSet, MutableSet[int]): self.bits &= ~(1 << value) def clear(self): + # type: () -> None self.bits = 0 def __ior__(self, it): @@ -361,4 +371,5 @@ class FBitSet(BaseBitSet): return True def __hash__(self): + # type: () -> int return super()._hash() diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi deleted file mode 100644 index 6315823..0000000 --- a/src/bigint_presentation_code/util.pyi +++ /dev/null @@ -1,190 +0,0 @@ -from abc import abstractmethod -from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, - NoReturn, TypeVar, overload) - -from typing_extensions import Literal, Self, final - -_T_co = TypeVar("_T_co", covariant=True) -_T = TypeVar("_T") - -__all__ = [ - "assert_never", - "BaseBitSet", - "bit_count", - "BitSet", - "FBitSet", - "final", - "FMap", - "Literal", - "OFSet", - "OSet", - "Self", - "top_set_bit_index", - "trailing_zero_count", -] - - -# pyright currently doesn't like typing_extensions' definition -# -- added to typing in python 3.11 -def assert_never(arg): - # type: (NoReturn) -> NoReturn - raise AssertionError("got to code that's supposed to be unreachable") - - -class OFSet(AbstractSet[_T_co]): - """ ordered frozen set """ - - def __init__(self, items: Iterable[_T_co] = ()): - ... - - def __contains__(self, x: object) -> bool: - ... - - def __iter__(self) -> Iterator[_T_co]: - ... - - def __len__(self) -> int: - ... - - def __hash__(self) -> int: - ... - - def __repr__(self) -> str: - ... - - -class OSet(MutableSet[_T]): - """ ordered mutable set """ - - def __init__(self, items: Iterable[_T] = ()): - ... - - def __contains__(self, x: object) -> bool: - ... - - def __iter__(self) -> Iterator[_T]: - ... - - def __len__(self) -> int: - ... - - def add(self, value: _T) -> None: - ... - - def discard(self, value: _T) -> None: - ... - - def __repr__(self) -> str: - ... - - -class FMap(Mapping[_T, _T_co]): - """ordered frozen hashable mapping""" - @overload - def __init__(self, items: Mapping[_T, _T_co]): ... - @overload - def __init__(self, items: Iterable[tuple[_T, _T_co]]): ... - @overload - def __init__(self): ... - - def __getitem__(self, item: _T) -> _T_co: - ... - - def __iter__(self) -> Iterator[_T]: - ... - - def __len__(self) -> int: - ... - - def __eq__(self, other: object) -> bool: - ... - - def __hash__(self) -> int: - ... - - def __repr__(self) -> str: - ... - - -def trailing_zero_count(v: int, default: int = -1) -> int: ... -def top_set_bit_index(v: int, default: int = -1) -> int: ... -def bit_count(v: int) -> int: ... - - -class BaseBitSet(AbstractSet[int]): - @classmethod - @abstractmethod - def _frozen(cls) -> bool: ... - - @classmethod - def _from_bits(cls, bits: int) -> Self: ... - - def __init__(self, items: Iterable[int] = (), bits: int = 0): ... - - @property - def bits(self) -> int: - ... - - @bits.setter - def bits(self, bits: int) -> None: ... - - def __contains__(self, x: object) -> bool: ... - - def __iter__(self) -> Iterator[int]: ... - - def __reversed__(self) -> Iterator[int]: ... - - def __len__(self) -> int: ... - - def __repr__(self) -> str: ... - - def __eq__(self, other: object) -> bool: ... - - def __and__(self, other: Iterable[Any]) -> Self: ... - - __rand__ = __and__ - - def __or__(self, other: Iterable[Any]) -> Self: ... - - __ror__ = __or__ - - def __xor__(self, other: Iterable[Any]) -> Self: ... - - __rxor__ = __xor__ - - def __sub__(self, other: Iterable[Any]) -> Self: ... - - def __rsub__(self, other: Iterable[Any]) -> Self: ... - - def isdisjoint(self, other: Iterable[Any]) -> bool: ... - - -class BitSet(BaseBitSet, MutableSet[int]): - @final - @classmethod - def _frozen(cls) -> Literal[False]: ... - - def add(self, value: int) -> None: ... - - def discard(self, value: int) -> None: ... - - def clear(self) -> None: ... - - def __ior__(self, it: AbstractSet[Any]) -> Self: ... - - def __iand__(self, it: AbstractSet[Any]) -> Self: ... - - def __ixor__(self, it: AbstractSet[Any]) -> Self: ... - - def __isub__(self, it: AbstractSet[Any]) -> Self: ... - - -class FBitSet(BaseBitSet): - @property - def bits(self) -> int: ... - - @final - @classmethod - def _frozen(cls) -> Literal[True]: ... - - def __hash__(self) -> int: ... diff --git a/typings/cached_property.pyi b/typings/cached_property.pyi index b8b1f30..5ec7085 100644 --- a/typings/cached_property.pyi +++ b/typings/cached_property.pyi @@ -1,15 +1 @@ -from typing import Any, Callable, Generic, TypeVar, overload - -_T = TypeVar("_T") - - -class cached_property(Generic[_T]): - def __init__(self, func: Callable[[Any], _T]) -> None: ... - - @overload - def __get__(self, instance: None, - owner: type[Any] | None = ...) -> cached_property[_T]: ... - - @overload - def __get__(self, instance: object, - owner: type[Any] | None = ...) -> _T: ... +cached_property = property -- 2.30.2