add BitSet classes
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import abstractmethod
2 from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator,
3 Mapping, MutableSet, NoReturn, TypeVar, Union)
4
if TYPE_CHECKING:
    from typing_extensions import Literal, Self, final
else:
    # At runtime typing_extensions is never imported; these stand-ins
    # provide the same names for code that uses them outside annotations.

    def final(v):
        """No-op replacement for ``typing_extensions.final``."""
        return v

    class _Literal:
        """Replacement for ``Literal[...]`` that evaluates to the type of the
        literal value(s), e.g. ``Literal["a"]`` is ``str`` and
        ``Literal[1, "a"]`` is ``Union[int, str]``."""

        def __getitem__(self, v):
            if not isinstance(v, tuple):
                return type(v)
            return Union[tuple(map(type, v))]

    Literal = _Literal()

    Self = Any

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
23
# Public API of this module: the names exported by ``from ... import *``.
# Includes the typing shims (final, Literal, Self) defined or imported above.
__all__ = [
    "assert_never",
    "BaseBitSet",
    "bit_count",
    "BitSet",
    "FBitSet",
    "final",
    "FMap",
    "Literal",
    "OFSet",
    "OSet",
    "Self",
    "top_set_bit_index",
    "trailing_zero_count",
]
39
40
# pyright currently doesn't like typing_extensions' definition
# -- added to typing in python 3.11
def assert_never(arg):
    # type: (NoReturn) -> NoReturn
    """Mark a code path that static type checking proves unreachable.

    ``arg`` should have been narrowed to ``NoReturn`` by the type checker.
    If this is actually reached at runtime it raises ``AssertionError``.
    """
    message = "got to code that's supposed to be unreachable"
    raise AssertionError(message)
46
47
class OFSet(AbstractSet[_T_co]):
    """Ordered frozen set: an immutable, hashable set that iterates in the
    order its items were first seen."""
    __slots__ = "__items",

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        # a dict whose values are all None doubles as an insertion-ordered set
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        return x in self.__items

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        # Set._hash is the order-independent hash collections.abc supplies
        # for hashable Set implementations
        return self._hash()

    def __repr__(self):
        return f"OFSet({list(self)})" if self else "OFSet()"
72
73
class OSet(MutableSet[_T]):
    """Ordered mutable set: iterates in the order items were first added."""
    __slots__ = "__items",

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        # dict keys keep insertion order, so a dict with None values serves
        # as the ordered set storage
        self.__items = dict.fromkeys(items)  # type: dict[_T, None]

    def __contains__(self, x):
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def add(self, value):
        # type: (_T) -> None
        """Add ``value``; re-adding an existing item keeps its position."""
        self.__items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        """Remove ``value`` if present; do nothing otherwise."""
        self.__items.pop(value, None)

    def clear(self):
        # type: () -> None
        """Remove every item."""
        # MutableSet.clear() would pop() one item at a time, and each pop()
        # starts a fresh iteration from the front of the dict. In CPython that
        # iteration skips every previously deleted entry, so clearing becomes
        # quadratic. Clearing the underlying dict directly is linear.
        self.__items.clear()

    def __repr__(self):
        if len(self) == 0:
            return "OSet()"
        return f"OSet({list(self)})"
103
104
class FMap(Mapping[_T, _T_co]):
    """Ordered, frozen, hashable mapping backed by a private dict."""
    __slots__ = "__items", "__hash"

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        self.__items = dict(items)  # type: dict[_T, _T_co]
        # filled in lazily by __hash__ on first use
        self.__hash = None  # type: None | int

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        return self.__items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __eq__(self, other):
        # type: (object) -> bool
        # non-FMap operands fall back to the generic Mapping comparison
        if not isinstance(other, FMap):
            return super().__eq__(other)
        return self.__items == other.__items

    def __hash__(self):
        # order-independent hash over the (key, value) pairs, computed once
        cached = self.__hash
        if cached is None:
            cached = self.__hash = hash(frozenset(self.items()))
        return cached

    def __repr__(self):
        return f"FMap({self.__items})"
138
139
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the number of trailing zero bits of ``v`` (the index of its
    lowest set bit), or ``default`` when ``v`` is 0.

    Negative ``v`` is handled through its two's complement, e.g. -4 gives 2.
    """
    # in two's complement, v & -v isolates the lowest set bit of v
    lowest_bit = v & -v
    return top_set_bit_index(lowest_bit, default)
145
146
def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the highest set bit of ``v``, or ``default`` when
    ``v`` is zero or negative."""
    return v.bit_length() - 1 if v > 0 else default
152
153
try:
    # int.bit_count was added in cpython 3.10
    bit_count = int.bit_count  # type: ignore[attr-defined]
except AttributeError:
    def bit_count(v):
        # type: (int) -> int
        """Return how many 1 bits the absolute value of ``v`` has."""
        binary_digits = bin(abs(v))
        return binary_digits.count('1')
162
163
class BaseBitSet(AbstractSet[int]):
    """Set of non-negative ints stored as the bits of a single ``int``.

    Element ``i`` is a member exactly when bit ``i`` of ``bits`` is 1.
    Subclasses implement ``_frozen`` to say whether ``bits`` may be
    reassigned after construction.
    """
    __slots__ = "__bits",

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """Return True if ``bits`` must not be reassigned after construction."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Build an instance of ``cls`` directly from the bit mask ``bits``."""
        return cls(bits=bits)

    @staticmethod
    def _other_bits(other):
        # type: (object) -> None | int
        """Return the bit mask of the members of ``other`` a bit set can hold.

        A ``BaseBitSet`` contributes its ``bits`` directly. For any other
        iterable, every non-negative ``int`` item sets its bit and all other
        items are ignored; building the whole mask first means an item that
        appears several times counts once. Returns None when ``other`` is not
        iterable, so the binary operators can return ``NotImplemented`` and
        let Python try the other operand's reflected method.
        """
        if isinstance(other, BaseBitSet):
            return other.bits
        try:
            items = iter(other)  # type: ignore
        except TypeError:
            return None
        bits = 0
        for item in items:
            if isinstance(item, int) and item >= 0:
                bits |= 1 << item
        return bits

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create a set of ``items`` plus the members already encoded in ``bits``.

        Raises ValueError if an item is negative, or if the resulting mask is
        negative (which would describe an infinite set).
        """
        for item in items:
            if item < 0:
                raise ValueError("can't store negative integers")
            bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """Bit mask of the members: bit ``i`` is set iff ``i`` is in the set."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        # AttributeError for frozen subclasses; ValueError for a negative
        # (infinite) mask.
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def __contains__(self, x):
        # type: (object) -> bool
        if isinstance(x, int) and x >= 0:
            # Shift the mask down rather than building ``1 << x``, which would
            # allocate an x-bit int just to test a single bit.
            return ((self.bits >> x) & 1) != 0
        return False

    def __iter__(self):
        # type: () -> Iterator[int]
        """Yield the members in ascending order."""
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        """Yield the members in descending order."""
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # the number of members is the number of set bits
        return bit_count(self.bits)

    def __repr__(self):
        if self.bits == 0:
            return f"{self.__class__.__name__}()"
        # a large but sparse mask reads better as a member list than as hex
        if self.bits > 0xFFFFFFFF and len(self) < 10:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        return f"{self.__class__.__name__}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (object) -> bool
        # bit sets are equal iff their masks are; anything else uses the
        # generic Set comparison
        if not isinstance(other, BaseBitSet):
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        bits = self._other_bits(other)
        if bits is None:
            return NotImplemented
        return self._from_bits(self.bits & bits)

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        # NOTE(review): items a bit set can't hold (negative ints, non-ints)
        # are silently dropped from the union rather than raising like
        # __init__ does; kept for compatibility -- confirm this is intended.
        bits = self._other_bits(other)
        if bits is None:
            return NotImplemented
        return self._from_bits(self.bits | bits)

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        # other's full mask is built before the single XOR, so a value that
        # repeats in a plain iterable toggles its bit once, not per occurrence
        bits = self._other_bits(other)
        if bits is None:
            return NotImplemented
        return self._from_bits(self.bits ^ bits)

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        bits = self._other_bits(other)
        if bits is None:
            return NotImplemented
        return self._from_bits(self.bits & ~bits)

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        # computes ``other - self``
        bits = self._other_bits(other)
        if bits is None:
            return NotImplemented
        return self._from_bits(bits & ~self.bits)

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
300
301
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable Bit Set"""
    # The bit storage slot is declared on BaseBitSet; an empty __slots__ here
    # keeps instances from also carrying a per-instance __dict__.
    __slots__ = ()

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return False

    def add(self, value):
        # type: (int) -> None
        """Add ``value``; raises ValueError if it is negative."""
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits |= 1 << value

    def discard(self, value):
        # type: (int) -> None
        """Remove ``value`` if present; do nothing otherwise."""
        # Same membership condition as BaseBitSet.__contains__: anything that
        # is not a non-negative int can never be a member, so discarding it is
        # a no-op (as with set.discard) instead of a TypeError. This also lets
        # MutableSet.__isub__ accept iterables with non-int items, matching
        # the behaviour of the ``-`` operator.
        if isinstance(value, int) and value >= 0:
            self.bits &= ~(1 << value)

    def clear(self):
        """Remove every member with a single assignment."""
        self.bits = 0

    # Each in-place operator below has a mask-based fast path for bit sets;
    # any other operand goes through MutableSet's element-wise implementation.

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits |= it.bits
            return self
        return super().__ior__(it)

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits &= it.bits
            return self
        return super().__iand__(it)

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits ^= it.bits
            return self
        return super().__ixor__(it)

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits &= ~it.bits
            return self
        return super().__isub__(it)
352
353
class FBitSet(BaseBitSet):
    """Frozen Bit Set"""
    # Holds the lazily computed hash (see __hash__); declaring slots here also
    # keeps instances from carrying a per-instance __dict__.
    __slots__ = "__hash",

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        super().__init__(items, bits)
        self.__hash = None  # type: None | int

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return True

    def __hash__(self):
        # Set._hash walks every member, and iterating a bit set costs big-int
        # operations per member, so compute the hash only once. Caching is
        # safe: _frozen() returning True makes the ``bits`` setter refuse
        # writes.
        if self.__hash is None:
            self.__hash = self._hash()
        return self.__hash