From: Jacob Lifshay Date: Thu, 27 Oct 2022 08:01:19 +0000 (-0700) Subject: add BitSet classes X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=def8a912ee1fe40aba16219443824969d67ee13b;p=bigint-presentation-code.git add BitSet classes --- diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index ef8c36c..aeea240 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, NoReturn, TypeVar, Union) @@ -20,7 +21,21 @@ else: _T_co = TypeVar("_T_co", covariant=True) _T = TypeVar("_T") -__all__ = ["final", "Literal", "Self", "assert_never", "OFSet", "OSet", "FMap"] +__all__ = [ + "assert_never", + "BaseBitSet", + "bit_count", + "BitSet", + "FBitSet", + "final", + "FMap", + "Literal", + "OFSet", + "OSet", + "Self", + "top_set_bit_index", + "trailing_zero_count", +] # pyright currently doesn't like typing_extensions' definition @@ -120,3 +135,230 @@ class FMap(Mapping[_T, _T_co]): def __repr__(self): return f"FMap({self.__items})" + + +def trailing_zero_count(v, default=-1): + # type: (int, int) -> int + without_bit = v & (v - 1) # clear lowest set bit + bit = v & ~without_bit # extract lowest set bit + return top_set_bit_index(bit, default) + + +def top_set_bit_index(v, default=-1): + # type: (int, int) -> int + if v <= 0: + return default + return v.bit_length() - 1 + + +try: + # added in cpython 3.10 + bit_count = int.bit_count # type: ignore[attr] +except AttributeError: + def bit_count(v): + # type: (int) -> int + """returns the number of 1 bits in the absolute value of the input""" + return bin(abs(v)).count('1') + + +class BaseBitSet(AbstractSet[int]): + __slots__ = "__bits", + + @classmethod + @abstractmethod + def _frozen(cls): + # type: () -> bool + return False + + @classmethod + def _from_bits(cls, bits): + # type: (int) -> Self + return cls(bits=bits) + + def __init__(self, items=(), bits=0): + # type: (Iterable[int], int) -> None + for item in items: + if item < 0: + raise ValueError("can't store negative integers") + bits |= 1 << item + if bits < 0: + raise ValueError("can't store an infinite set") + self.__bits = bits + + @property + def bits(self): + return self.__bits + + @bits.setter + def bits(self, bits): + # type: (int) -> None + if self._frozen(): + raise AttributeError("can't write to frozen bitset's bits") + if bits < 0: + raise ValueError("can't store an infinite set") + self.__bits = bits + + def __contains__(self, x): + if isinstance(x, int) and x >= 0: + return (1 << x) & self.bits != 0 + return False + + def __iter__(self): + # type: () -> Iterator[int] + bits = self.bits + while bits != 0: + index = trailing_zero_count(bits) + yield index + bits -= 1 << index + + def __reversed__(self): + # type: () -> Iterator[int] + bits = self.bits + while bits != 0: + index = top_set_bit_index(bits) + yield index + bits -= 1 << index + + def __len__(self): + return bit_count(self.bits) + + def __repr__(self): + if self.bits == 0: + return f"{self.__class__.__name__}()" + if self.bits > 0xFFFFFFFF and len(self) < 10: + v = list(self) + return f"{self.__class__.__name__}({v})" + return f"{self.__class__.__name__}(bits={hex(self.bits)})" + + def __eq__(self, other): + # type: (object) -> bool + if not isinstance(other, BaseBitSet): + return super().__eq__(other) + return self.bits == other.bits + + def __and__(self, other): + # type: (Iterable[Any]) -> Self + if isinstance(other, BaseBitSet): + return self._from_bits(self.bits & other.bits) + bits = 0 + for item in other: + if isinstance(item, int) and item >= 0: + bits |= 1 << item + return self._from_bits(self.bits & bits) + + __rand__ = __and__ + + def __or__(self, other): + # type: (Iterable[Any]) -> Self + if isinstance(other, BaseBitSet): + return self._from_bits(self.bits | other.bits) + bits = self.bits + for item in other: + if isinstance(item, int) and item >= 0: + bits |= 1 << item + return self._from_bits(bits) + + __ror__ = __or__ + + def __xor__(self, other): + # type: (Iterable[Any]) -> Self + if isinstance(other, BaseBitSet): + return self._from_bits(self.bits ^ other.bits) + bits = self.bits + for item in other: + if isinstance(item, int) and item >= 0: + bits ^= 1 << item + return self._from_bits(bits) + + __rxor__ = __xor__ + + def __sub__(self, other): + # type: (Iterable[Any]) -> Self + if isinstance(other, BaseBitSet): + return self._from_bits(self.bits & ~other.bits) + bits = self.bits + for item in other: + if isinstance(item, int) and item >= 0: + bits &= ~(1 << item) + return self._from_bits(bits) + + def __rsub__(self, other): + # type: (Iterable[Any]) -> Self + if isinstance(other, BaseBitSet): + return self._from_bits(~self.bits & other.bits) + bits = 0 + for item in other: + if isinstance(item, int) and item >= 0: + bits |= 1 << item + return self._from_bits(~self.bits & bits) + + def isdisjoint(self, other): + # type: (Iterable[Any]) -> bool + if isinstance(other, BaseBitSet): + return self.bits & other.bits == 0 + return super().isdisjoint(other) + + +class BitSet(BaseBitSet, MutableSet[int]): + """Mutable Bit Set""" + + @final + @classmethod + def _frozen(cls): + # type: () -> bool + return False + + def add(self, value): + # type: (int) -> None + if value < 0: + raise ValueError("can't store negative integers") + self.bits |= 1 << value + + def discard(self, value): + # type: (int) -> None + if value >= 0: + self.bits &= ~(1 << value) + + def clear(self): + self.bits = 0 + + def __ior__(self, it): + # type: (AbstractSet[Any]) -> Self + if isinstance(it, BaseBitSet): + self.bits |= it.bits + return self + return super().__ior__(it) + + def __iand__(self, it): + # type: (AbstractSet[Any]) -> Self + if isinstance(it, BaseBitSet): + self.bits &= it.bits + return self + return super().__iand__(it) + + def __ixor__(self, it): + # type: (AbstractSet[Any]) -> Self + if isinstance(it, BaseBitSet): + self.bits ^= it.bits + return self + return super().__ixor__(it) + + def __isub__(self, it): + # type: (AbstractSet[Any]) -> Self + if isinstance(it, BaseBitSet): + self.bits &= ~it.bits + return self + return super().__isub__(it) + + +class FBitSet(BaseBitSet): + """Frozen Bit Set""" + + @final + @classmethod + def _frozen(cls): + # type: () -> bool + return True + + def __hash__(self): + return super()._hash() diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi index 0e0dbc7..6315823 100644 --- a/src/bigint_presentation_code/util.pyi +++ b/src/bigint_presentation_code/util.pyi @@ -1,11 +1,27 @@ -from typing import (AbstractSet, Iterable, Iterator, Mapping, - MutableSet, NoReturn, TypeVar, overload) -from typing_extensions import final, Literal, Self +from abc import abstractmethod +from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, + NoReturn, TypeVar, overload) + +from typing_extensions import Literal, Self, final _T_co = TypeVar("_T_co", covariant=True) _T = TypeVar("_T") -__all__ = ["final", "Literal", "Self", "assert_never", "OFSet", "OSet", "FMap"] +__all__ = [ + "assert_never", + "BaseBitSet", + "bit_count", + "BitSet", + "FBitSet", + "final", + "FMap", + "Literal", + "OFSet", + "OSet", + "Self", + "top_set_bit_index", + "trailing_zero_count", +] # pyright currently doesn't like typing_extensions' definition @@ -88,3 +104,87 @@ class FMap(Mapping[_T, _T_co]): def __repr__(self) -> str: ... + + +def trailing_zero_count(v: int, default: int = -1) -> int: ... +def top_set_bit_index(v: int, default: int = -1) -> int: ... +def bit_count(v: int) -> int: ... + + +class BaseBitSet(AbstractSet[int]): + @classmethod + @abstractmethod + def _frozen(cls) -> bool: ... + + @classmethod + def _from_bits(cls, bits: int) -> Self: ... + + def __init__(self, items: Iterable[int] = (), bits: int = 0): ... + + @property + def bits(self) -> int: + ... + + @bits.setter + def bits(self, bits: int) -> None: ... + + def __contains__(self, x: object) -> bool: ... + + def __iter__(self) -> Iterator[int]: ... + + def __reversed__(self) -> Iterator[int]: ... + + def __len__(self) -> int: ... + + def __repr__(self) -> str: ... + + def __eq__(self, other: object) -> bool: ... + + def __and__(self, other: Iterable[Any]) -> Self: ... + + __rand__ = __and__ + + def __or__(self, other: Iterable[Any]) -> Self: ... + + __ror__ = __or__ + + def __xor__(self, other: Iterable[Any]) -> Self: ... + + __rxor__ = __xor__ + + def __sub__(self, other: Iterable[Any]) -> Self: ... + + def __rsub__(self, other: Iterable[Any]) -> Self: ... + + def isdisjoint(self, other: Iterable[Any]) -> bool: ... + + +class BitSet(BaseBitSet, MutableSet[int]): + @final + @classmethod + def _frozen(cls) -> Literal[False]: ... + + def add(self, value: int) -> None: ... + + def discard(self, value: int) -> None: ... + + def clear(self) -> None: ... + + def __ior__(self, it: AbstractSet[Any]) -> Self: ... + + def __iand__(self, it: AbstractSet[Any]) -> Self: ... + + def __ixor__(self, it: AbstractSet[Any]) -> Self: ... + + def __isub__(self, it: AbstractSet[Any]) -> Self: ... + + +class FBitSet(BaseBitSet): + @property + def bits(self) -> int: ... + + @final + @classmethod + def _frozen(cls) -> Literal[True]: ... + + def __hash__(self) -> int: ...