working on rewriting compiler ir to fix reg alloc issues
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
index aeea240c545aaf458772a515c78252ac1b928fc9..4b3978741eb5290cfd4257846ecbf5bd214aff07 100644 (file)
@@ -1,50 +1,25 @@
 from abc import abstractmethod
-from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator,
-                    Mapping, MutableSet, NoReturn, TypeVar, Union)
+from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
+                    TypeVar, overload)
 
-if TYPE_CHECKING:
-    from typing_extensions import Literal, Self, final
-else:
-    def final(v):
-        return v
-
-    class _Literal:
-        def __getitem__(self, v):
-            if isinstance(v, tuple):
-                return Union[tuple(type(i) for i in v)]
-            return type(v)
-
-    Literal = _Literal()
-
-    Self = Any
+from bigint_presentation_code.type_util import Self, final
 
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
 
 __all__ = [
-    "assert_never",
     "BaseBitSet",
     "bit_count",
     "BitSet",
     "FBitSet",
-    "final",
     "FMap",
-    "Literal",
     "OFSet",
     "OSet",
-    "Self",
     "top_set_bit_index",
     "trailing_zero_count",
 ]
 
 
-# pyright currently doesn't like typing_extensions' definition
-# -- added to typing in python 3.11
-def assert_never(arg):
-    # type: (NoReturn) -> NoReturn
-    raise AssertionError("got to code that's supposed to be unreachable")
-
-
 class OFSet(AbstractSet[_T_co]):
     """ ordered frozen set """
     __slots__ = "__items",
@@ -54,18 +29,23 @@ class OFSet(AbstractSet[_T_co]):
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         return x in self.__items
 
     def __iter__(self):
+        # type: () -> Iterator[_T_co]
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def __hash__(self):
+        # type: () -> int
         return self._hash()
 
     def __repr__(self):
+        # type: () -> str
         if len(self) == 0:
             return "OFSet()"
         return f"OFSet({list(self)})"
@@ -80,12 +60,15 @@ class OSet(MutableSet[_T]):
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         return x in self.__items
 
     def __iter__(self):
+        # type: () -> Iterator[_T]
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def add(self, value):
@@ -97,6 +80,7 @@ class OSet(MutableSet[_T]):
         self.__items.pop(value, None)
 
     def __repr__(self):
+        # type: () -> str
         if len(self) == 0:
             return "OSet()"
         return f"OSet({list(self)})"
@@ -106,6 +90,21 @@ class FMap(Mapping[_T, _T_co]):
     """ordered frozen hashable mapping"""
     __slots__ = "__items", "__hash"
 
+    @overload
+    def __init__(self, items):
+        # type: (Mapping[_T, _T_co]) -> None
+        ...
+
+    @overload
+    def __init__(self, items):
+        # type: (Iterable[tuple[_T, _T_co]]) -> None
+        ...
+
+    @overload
+    def __init__(self):
+        # type: () -> None
+        ...
+
     def __init__(self, items=()):
         # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
         self.__items = dict(items)  # type: dict[_T, _T_co]
@@ -120,20 +119,23 @@ class FMap(Mapping[_T, _T_co]):
         return iter(self.__items)
 
     def __len__(self):
+        # type: () -> int
         return len(self.__items)
 
     def __eq__(self, other):
-        # type: (object) -> bool
+        # type: (FMap[Any, Any] | Any) -> bool
         if isinstance(other, FMap):
             return self.__items == other.__items
         return super().__eq__(other)
 
     def __hash__(self):
+        # type: () -> int
         if self.__hash is None:
             self.__hash = hash(frozenset(self.items()))
         return self.__hash
 
     def __repr__(self):
+        # type: () -> str
         return f"FMap({self.__items})"
 
 
@@ -153,7 +155,7 @@ def top_set_bit_index(v, default=-1):
 
 try:
     # added in cpython 3.10
-    bit_count = int.bit_count  # type: ignore[attr]
+    bit_count = int.bit_count  # type: ignore
 except AttributeError:
     def bit_count(v):
         # type: (int) -> int
@@ -177,16 +179,20 @@ class BaseBitSet(AbstractSet[int]):
 
     def __init__(self, items=(), bits=0):
         # type: (Iterable[int], int) -> None
-        for item in items:
-            if item < 0:
-                raise ValueError("can't store negative integers")
-            bits |= 1 << item
+        if isinstance(items, BaseBitSet):
+            bits |= items.bits
+        else:
+            for item in items:
+                if item < 0:
+                    raise ValueError("can't store negative integers")
+                bits |= 1 << item
         if bits < 0:
             raise ValueError("can't store an infinite set")
         self.__bits = bits
 
     @property
     def bits(self):
+        # type: () -> int
         return self.__bits
 
     @bits.setter
@@ -199,6 +205,7 @@ class BaseBitSet(AbstractSet[int]):
         self.__bits = bits
 
     def __contains__(self, x):
+        # type: (Any) -> bool
         if isinstance(x, int) and x >= 0:
             return (1 << x) & self.bits != 0
         return False
@@ -220,9 +227,11 @@ class BaseBitSet(AbstractSet[int]):
             bits -= 1 << index
 
     def __len__(self):
+        # type: () -> int
         return bit_count(self.bits)
 
     def __repr__(self):
+        # type: () -> str
         if self.bits == 0:
             return f"{self.__class__.__name__}()"
         if self.bits > 0xFFFFFFFF and len(self) < 10:
@@ -231,7 +240,7 @@ class BaseBitSet(AbstractSet[int]):
         return f"{self.__class__.__name__}(bits={hex(self.bits)})"
 
     def __eq__(self, other):
-        # type: (object) -> bool
+        # type: (Any) -> bool
         if not isinstance(other, BaseBitSet):
             return super().__eq__(other)
         return self.bits == other.bits
@@ -320,6 +329,7 @@ class BitSet(BaseBitSet, MutableSet[int]):
             self.bits &= ~(1 << value)
 
     def clear(self):
+        # type: () -> None
         self.bits = 0
 
     def __ior__(self, it):
@@ -361,4 +371,5 @@ class FBitSet(BaseBitSet):
         return True
 
     def __hash__(self):
+        # type: () -> int
         return super()._hash()