add BitSet classes
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 27 Oct 2022 08:01:19 +0000 (01:01 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 27 Oct 2022 08:01:19 +0000 (01:01 -0700)
src/bigint_presentation_code/util.py
src/bigint_presentation_code/util.pyi

index ef8c36c3533aaf66cca867a36b0913ae92281bb1..aeea240c545aaf458772a515c78252ac1b928fc9 100644 (file)
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator,
                     Mapping, MutableSet, NoReturn, TypeVar, Union)
 
@@ -20,7 +21,21 @@ else:
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
 
-__all__ = ["final", "Literal", "Self", "assert_never", "OFSet", "OSet", "FMap"]
+__all__ = [
+    "assert_never",
+    "BaseBitSet",
+    "bit_count",
+    "BitSet",
+    "FBitSet",
+    "final",
+    "FMap",
+    "Literal",
+    "OFSet",
+    "OSet",
+    "Self",
+    "top_set_bit_index",
+    "trailing_zero_count",
+]
 
 
 # pyright currently doesn't like typing_extensions' definition
@@ -120,3 +135,230 @@ class FMap(Mapping[_T, _T_co]):
 
     def __repr__(self):
         return f"FMap({self.__items})"
+
+
+def trailing_zero_count(v, default=-1):
+    # type: (int, int) -> int
+    without_bit = v & (v - 1)  # clear lowest set bit
+    bit = v & ~without_bit  # extract lowest set bit
+    return top_set_bit_index(bit, default)
+
+
+def top_set_bit_index(v, default=-1):
+    # type: (int, int) -> int
+    if v <= 0:
+        return default
+    return v.bit_length() - 1
+
+
+try:
+    # added in cpython 3.10
+    bit_count = int.bit_count  # type: ignore[attr]
+except AttributeError:
+    def bit_count(v):
+        # type: (int) -> int
+        """returns the number of 1 bits in the absolute value of the input"""
+        return bin(abs(v)).count('1')
+
+
+class BaseBitSet(AbstractSet[int]):
+    __slots__ = "__bits",
+
+    @classmethod
+    @abstractmethod
+    def _frozen(cls):
+        # type: () -> bool
+        return False
+
+    @classmethod
+    def _from_bits(cls, bits):
+        # type: (int) -> Self
+        return cls(bits=bits)
+
+    def __init__(self, items=(), bits=0):
+        # type: (Iterable[int], int) -> None
+        for item in items:
+            if item < 0:
+                raise ValueError("can't store negative integers")
+            bits |= 1 << item
+        if bits < 0:
+            raise ValueError("can't store an infinite set")
+        self.__bits = bits
+
+    @property
+    def bits(self):
+        return self.__bits
+
+    @bits.setter
+    def bits(self, bits):
+        # type: (int) -> None
+        if self._frozen():
+            raise AttributeError("can't write to frozen bitset's bits")
+        if bits < 0:
+            raise ValueError("can't store an infinite set")
+        self.__bits = bits
+
+    def __contains__(self, x):
+        if isinstance(x, int) and x >= 0:
+            return (1 << x) & self.bits != 0
+        return False
+
+    def __iter__(self):
+        # type: () -> Iterator[int]
+        bits = self.bits
+        while bits != 0:
+            index = trailing_zero_count(bits)
+            yield index
+            bits -= 1 << index
+
+    def __reversed__(self):
+        # type: () -> Iterator[int]
+        bits = self.bits
+        while bits != 0:
+            index = top_set_bit_index(bits)
+            yield index
+            bits -= 1 << index
+
+    def __len__(self):
+        return bit_count(self.bits)
+
+    def __repr__(self):
+        if self.bits == 0:
+            return f"{self.__class__.__name__}()"
+        if self.bits > 0xFFFFFFFF and len(self) < 10:
+            v = list(self)
+            return f"{self.__class__.__name__}({v})"
+        return f"{self.__class__.__name__}(bits={hex(self.bits)})"
+
+    def __eq__(self, other):
+        # type: (object) -> bool
+        if not isinstance(other, BaseBitSet):
+            return super().__eq__(other)
+        return self.bits == other.bits
+
+    def __and__(self, other):
+        # type: (Iterable[Any]) -> Self
+        if isinstance(other, BaseBitSet):
+            return self._from_bits(self.bits & other.bits)
+        bits = 0
+        for item in other:
+            if isinstance(item, int) and item >= 0:
+                bits |= 1 << item
+        return self._from_bits(self.bits & bits)
+
+    __rand__ = __and__
+
+    def __or__(self, other):
+        # type: (Iterable[Any]) -> Self
+        if isinstance(other, BaseBitSet):
+            return self._from_bits(self.bits | other.bits)
+        bits = self.bits
+        for item in other:
+            if isinstance(item, int) and item >= 0:
+                bits |= 1 << item
+        return self._from_bits(bits)
+
+    __ror__ = __or__
+
+    def __xor__(self, other):
+        # type: (Iterable[Any]) -> Self
+        if isinstance(other, BaseBitSet):
+            return self._from_bits(self.bits ^ other.bits)
+        bits = self.bits
+        for item in other:
+            if isinstance(item, int) and item >= 0:
+                bits ^= 1 << item
+        return self._from_bits(bits)
+
+    __rxor__ = __xor__
+
+    def __sub__(self, other):
+        # type: (Iterable[Any]) -> Self
+        if isinstance(other, BaseBitSet):
+            return self._from_bits(self.bits & ~other.bits)
+        bits = self.bits
+        for item in other:
+            if isinstance(item, int) and item >= 0:
+                bits &= ~(1 << item)
+        return self._from_bits(bits)
+
+    def __rsub__(self, other):
+        # type: (Iterable[Any]) -> Self
+        if isinstance(other, BaseBitSet):
+            return self._from_bits(~self.bits & other.bits)
+        bits = 0
+        for item in other:
+            if isinstance(item, int) and item >= 0:
+                bits |= 1 << item
+        return self._from_bits(~self.bits & bits)
+
+    def isdisjoint(self, other):
+        # type: (Iterable[Any]) -> bool
+        if isinstance(other, BaseBitSet):
+            return self.bits & other.bits == 0
+        return super().isdisjoint(other)
+
+
+class BitSet(BaseBitSet, MutableSet[int]):
+    """Mutable Bit Set"""
+
+    @final
+    @classmethod
+    def _frozen(cls):
+        # type: () -> bool
+        return False
+
+    def add(self, value):
+        # type: (int) -> None
+        if value < 0:
+            raise ValueError("can't store negative integers")
+        self.bits |= 1 << value
+
+    def discard(self, value):
+        # type: (int) -> None
+        if value >= 0:
+            self.bits &= ~(1 << value)
+
+    def clear(self):
+        self.bits = 0
+
+    def __ior__(self, it):
+        # type: (AbstractSet[Any]) -> Self
+        if isinstance(it, BaseBitSet):
+            self.bits |= it.bits
+            return self
+        return super().__ior__(it)
+
+    def __iand__(self, it):
+        # type: (AbstractSet[Any]) -> Self
+        if isinstance(it, BaseBitSet):
+            self.bits &= it.bits
+            return self
+        return super().__iand__(it)
+
+    def __ixor__(self, it):
+        # type: (AbstractSet[Any]) -> Self
+        if isinstance(it, BaseBitSet):
+            self.bits ^= it.bits
+            return self
+        return super().__ixor__(it)
+
+    def __isub__(self, it):
+        # type: (AbstractSet[Any]) -> Self
+        if isinstance(it, BaseBitSet):
+            self.bits &= ~it.bits
+            return self
+        return super().__isub__(it)
+
+
+class FBitSet(BaseBitSet):
+    """Frozen Bit Set"""
+
+    @final
+    @classmethod
+    def _frozen(cls):
+        # type: () -> bool
+        return True
+
+    def __hash__(self):
+        return super()._hash()
index 0e0dbc72a21e734e6afa9282996b4dcc76d21a58..6315823fa3b23bf075dd24678e6601e122b95a19 100644 (file)
@@ -1,11 +1,27 @@
-from typing import (AbstractSet, Iterable, Iterator, Mapping,
-                    MutableSet, NoReturn, TypeVar, overload)
-from typing_extensions import final, Literal, Self
+from abc import abstractmethod
+from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
+                    NoReturn, TypeVar, overload)
+
+from typing_extensions import Literal, Self, final
 
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
 
-__all__ = ["final", "Literal", "Self", "assert_never", "OFSet", "OSet", "FMap"]
+__all__ = [
+    "assert_never",
+    "BaseBitSet",
+    "bit_count",
+    "BitSet",
+    "FBitSet",
+    "final",
+    "FMap",
+    "Literal",
+    "OFSet",
+    "OSet",
+    "Self",
+    "top_set_bit_index",
+    "trailing_zero_count",
+]
 
 
 # pyright currently doesn't like typing_extensions' definition
@@ -88,3 +104,87 @@ class FMap(Mapping[_T, _T_co]):
 
     def __repr__(self) -> str:
         ...
+
+
+def trailing_zero_count(v: int, default: int = -1) -> int: ...
+def top_set_bit_index(v: int, default: int = -1) -> int: ...
+def bit_count(v: int) -> int: ...
+
+
+class BaseBitSet(AbstractSet[int]):
+    @classmethod
+    @abstractmethod
+    def _frozen(cls) -> bool: ...
+
+    @classmethod
+    def _from_bits(cls, bits: int) -> Self: ...
+
+    def __init__(self, items: Iterable[int] = (), bits: int = 0): ...
+
+    @property
+    def bits(self) -> int:
+        ...
+
+    @bits.setter
+    def bits(self, bits: int) -> None: ...
+
+    def __contains__(self, x: object) -> bool: ...
+
+    def __iter__(self) -> Iterator[int]: ...
+
+    def __reversed__(self) -> Iterator[int]: ...
+
+    def __len__(self) -> int: ...
+
+    def __repr__(self) -> str: ...
+
+    def __eq__(self, other: object) -> bool: ...
+
+    def __and__(self, other: Iterable[Any]) -> Self: ...
+
+    __rand__ = __and__
+
+    def __or__(self, other: Iterable[Any]) -> Self: ...
+
+    __ror__ = __or__
+
+    def __xor__(self, other: Iterable[Any]) -> Self: ...
+
+    __rxor__ = __xor__
+
+    def __sub__(self, other: Iterable[Any]) -> Self: ...
+
+    def __rsub__(self, other: Iterable[Any]) -> Self: ...
+
+    def isdisjoint(self, other: Iterable[Any]) -> bool: ...
+
+
+class BitSet(BaseBitSet, MutableSet[int]):
+    @final
+    @classmethod
+    def _frozen(cls) -> Literal[False]: ...
+
+    def add(self, value: int) -> None: ...
+
+    def discard(self, value: int) -> None: ...
+
+    def clear(self) -> None: ...
+
+    def __ior__(self, it: AbstractSet[Any]) -> Self: ...
+
+    def __iand__(self, it: AbstractSet[Any]) -> Self: ...
+
+    def __ixor__(self, it: AbstractSet[Any]) -> Self: ...
+
+    def __isub__(self, it: AbstractSet[Any]) -> Self: ...
+
+
+class FBitSet(BaseBitSet):
+    @property
+    def bits(self) -> int: ...
+
+    @final
+    @classmethod
+    def _frozen(cls) -> Literal[True]: ...
+
+    def __hash__(self) -> int: ...