757f267ecdfabcc923980e3d863666b5c55d0305
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import abstractmethod
2 from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
3 TypeVar, overload)
4
5 from bigint_presentation_code.type_util import Self, final
6
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Public names exported by ``from ... import *`` (case-insensitive order).
__all__ = [
    "BaseBitSet", "bit_count", "BitSet", "FBitSet", "FMap",
    "OFSet", "OSet", "top_set_bit_index", "trailing_zero_count",
]
21
22
class OFSet(AbstractSet[_T_co]):
    """Ordered frozen set.

    Remembers the order in which the constructor first saw each item
    (later duplicates are dropped) and has no mutating methods. Equality
    comes from the ``AbstractSet`` mixin, and ``__hash__`` uses that
    mixin's ``_hash`` helper.
    """
    __slots__ = "__items",

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        super().__init__()
        # dict keys keep first-seen order and give O(1) membership tests;
        # the values are unused
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        # type: (Any) -> bool
        items = self.__items
        return x in items

    def __iter__(self):
        # type: () -> Iterator[_T_co]
        return iter(self.__items.keys())

    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def __hash__(self):
        # type: () -> int
        return self._hash()

    def __repr__(self):
        # type: () -> str
        return f"OFSet({list(self)})" if self else "OFSet()"
53
54
class OSet(MutableSet[_T]):
    """Ordered mutable set.

    Iteration yields items in insertion order. Adding an item that is
    already present keeps its original position; discarding it and adding
    it again moves it to the end.
    """
    __slots__ = "__items",

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        super().__init__()
        # dict keys give insertion order plus O(1) membership; values unused
        self.__items = {v: None for v in items}

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def add(self, value):
        # type: (_T) -> None
        """Add ``value``; its position is unchanged if already present."""
        self.__items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        """Remove ``value`` if present; do nothing otherwise."""
        self.__items.pop(value, None)

    def clear(self):
        # type: () -> None
        """Remove every item.

        The inherited ``MutableSet.clear`` removes items one ``pop()`` at a
        time, creating a new iterator per item; on the backing dict each of
        those iterators must skip the entries already deleted, so clearing a
        large set that way is quadratic. Clearing the dict directly avoids it.
        """
        self.__items.clear()

    def __repr__(self):
        # type: () -> str
        if len(self) == 0:
            return "OSet()"
        return f"OSet({list(self)})"
89
90
class FMap(Mapping[_T, _T_co]):
    """Ordered, frozen, hashable mapping.

    Items keep the order of the ``dict`` built from the constructor
    argument. The hash is computed on first use from the ``(key, value)``
    pairs and then cached, so every value must be hashable for ``hash()``
    to succeed.
    """
    __slots__ = "__items", "__hash"

    @overload
    def __init__(self, items):
        # type: (Mapping[_T, _T_co]) -> None
        ...

    @overload
    def __init__(self, items):
        # type: (Iterable[tuple[_T, _T_co]]) -> None
        ...

    @overload
    def __init__(self):
        # type: () -> None
        ...

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        super().__init__()
        # private copy, so later changes to ``items`` cannot leak in
        self.__items = dict(items)  # type: dict[_T, _T_co]
        # cached hash; ``None`` until ``__hash__`` first runs
        self.__hash = None  # type: None | int

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        return self.__items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items.keys())

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __eq__(self, other):
        # type: (FMap[Any, Any] | Any) -> bool
        if not isinstance(other, FMap):
            # defer to ``Mapping.__eq__`` for dicts and other mappings
            return super().__eq__(other)
        # order-insensitive, like ``dict`` equality
        return self.__items == other.__items

    def __hash__(self):
        # type: () -> int
        cached = self.__hash
        if cached is None:
            # frozenset makes the hash independent of item order, matching
            # the order-insensitive ``__eq__``
            cached = hash(frozenset(self.items()))
            self.__hash = cached
        return cached

    def __repr__(self):
        # type: () -> str
        return f"FMap({self.__items!r})"
143
144
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the number of trailing zero bits of ``v``.

    This is the index of the lowest set bit. Negative ``v`` is handled in
    its two's-complement form. Returns ``default`` when ``v`` is 0, which
    has no set bit.
    """
    # ``v & -v`` isolates the lowest set bit (and is 0 when ``v == 0``)
    lowest = v & -v
    if lowest <= 0:
        return default
    return lowest.bit_length() - 1
150
151
def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the most significant set bit of ``v``.

    Returns ``default`` when ``v`` is zero or negative.
    """
    return v.bit_length() - 1 if v > 0 else default
157
158
try:
    # ``int.bit_count`` only exists on Python 3.10 and newer
    bit_count = int.bit_count  # type: ignore
except AttributeError:
    def bit_count(v):
        # type: (int) -> int
        """Return how many bits are set in ``abs(v)`` (pre-3.10 fallback)."""
        return format(abs(v), "b").count("1")
167
168
class BaseBitSet(AbstractSet[int]):
    """Common base for sets of non-negative ints stored as a bit mask.

    The whole set lives in one non-negative Python int, ``bits``; the int
    ``i`` is a member exactly when bit ``i`` of ``bits`` is set. Subclasses
    choose mutability by overriding ``_frozen``.
    """
    __slots__ = "__bits",

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """Return True if writes to ``bits`` must be rejected."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Build an instance of ``cls`` directly from a bit mask."""
        return cls(bits=bits)

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create a set holding ``items`` plus every bit already in ``bits``.

        Raises ``ValueError`` if an item is negative or if the combined
        mask is negative (which would describe an infinite set).
        """
        super().__init__()
        if isinstance(items, BaseBitSet):
            # fast path: reuse the other set's mask instead of iterating
            bits |= items.bits
        else:
            for item in items:
                if item < 0:
                    raise ValueError("can't store negative integers")
                bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        # write the slot directly: the ``bits`` setter would refuse to run
        # for frozen subclasses
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """The raw bit mask; bit ``i`` set means ``i`` is in the set."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def __contains__(self, x):
        # type: (Any) -> bool
        if isinstance(x, int) and x >= 0:
            bits = self.bits
            # Compare with bit_length first: bits at or above it are all 0,
            # and this avoids building ``1 << x`` for huge ``x`` (which can
            # raise OverflowError/MemoryError instead of returning False).
            return x < bits.bit_length() and (bits >> x) & 1 != 0
        return False

    def __iter__(self):
        # type: () -> Iterator[int]
        """Yield members in ascending order."""
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        """Yield members in descending order."""
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # type: () -> int
        return bit_count(self.bits)

    def __repr__(self):
        # type: () -> str
        """Return a constructor-call representation of the set.

        Up to 3 members are shown as a list. Larger sets are compressed into
        ``range`` objects for evenly spaced runs (at most ``MAX_RANGES`` of
        them). If that doesn't fit, the raw mask is shown as ``bits=0x...``,
        except that sparse sets (fewer than 10 members) with bits above
        32 are listed element by element.
        """
        if self.bits == 0:
            return f"{self.__class__.__name__}()"
        len_self = len(self)
        if len_self <= 3:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        ranges = []  # type: list[range]
        MAX_RANGES = 5
        for i in self:
            if len(ranges) != 0 and ranges[-1].stop == i:
                # ``i`` continues the current run with the same step
                ranges[-1] = range(
                    ranges[-1].start, i + ranges[-1].step, ranges[-1].step)
            elif len(ranges) != 0 and len(ranges[-1]) == 1:
                # a lone element plus ``i`` defines a new step
                start = ranges[-1][0]
                step = i - start
                stop = i + step
                ranges[-1] = range(start, stop, step)
            elif len(ranges) != 0 and len(ranges[-1]) == 2:
                # a 2-element run broke: keep its first element alone and
                # start a new run from its second element through ``i``
                single = ranges[-1][0]
                start = ranges[-1][1]
                ranges[-1] = range(single, single + 1)
                step = i - start
                stop = i + step
                ranges.append(range(start, stop, step))
            else:
                ranges.append(range(i, i + 1))
            if len(ranges) > MAX_RANGES:
                break
        if len(ranges) == 1:
            return f"{self.__class__.__name__}({ranges[0]})"
        if len(ranges) <= MAX_RANGES:
            range_strs = []  # type: list[str]
            for r in ranges:
                if len(r) == 1:
                    range_strs.append(str(r[0]))
                else:
                    range_strs.append(f"*{r}")
            ranges_str = ", ".join(range_strs)
            return f"{self.__class__.__name__}([{ranges_str}])"
        if self.bits > 0xFFFFFFFF and len_self < 10:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        return f"{self.__class__.__name__}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (Any) -> bool
        if not isinstance(other, BaseBitSet):
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        """Intersection; items of ``other`` that can't be members are ignored."""
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & other.bits)
        # items at or above ``limit`` can't be in ``self``, so skipping them
        # gives the same result without building huge ``1 << item`` masks
        limit = self.bits.bit_length()
        bits = 0
        for item in other:
            if isinstance(item, int) and 0 <= item < limit:
                bits |= 1 << item
        return self._from_bits(self.bits & bits)

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        """Union; items of ``other`` that aren't non-negative ints are ignored."""
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits | other.bits)
        bits = self.bits
        for item in other:
            if isinstance(item, int) and item >= 0:
                bits |= 1 << item
        return self._from_bits(bits)

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        """Symmetric difference; non-storable items of ``other`` are ignored."""
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits ^ other.bits)
        # Collect ``other`` into a mask first so duplicate items count once
        # (set semantics); toggling once per occurrence would make items
        # that appear an even number of times cancel out.
        other_bits = 0
        for item in other:
            if isinstance(item, int) and item >= 0:
                other_bits |= 1 << item
        return self._from_bits(self.bits ^ other_bits)

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        """Difference ``self - other``."""
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & ~other.bits)
        bits = self.bits
        # items at or above ``limit`` are not in ``self``; skip them rather
        # than build a huge ``1 << item`` mask
        limit = bits.bit_length()
        for item in other:
            if isinstance(item, int) and 0 <= item < limit:
                bits &= ~(1 << item)
        return self._from_bits(bits)

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        """Reflected difference ``other - self``.

        Items of ``other`` that aren't non-negative ints are ignored.
        """
        if isinstance(other, BaseBitSet):
            return self._from_bits(~self.bits & other.bits)
        bits = 0
        for item in other:
            if isinstance(item, int) and item >= 0:
                bits |= 1 << item
        return self._from_bits(~self.bits & bits)

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
350
351
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable Bit Set

    Every change is made by assigning to the ``bits`` property. In-place
    operators take a whole-mask fast path when the operand is another bit
    set and otherwise fall back to the ``MutableSet`` mixin implementation.
    """

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return False

    def add(self, value):
        # type: (int) -> None
        """Add ``value``; raises ``ValueError`` for negative ints."""
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits |= 1 << value

    def discard(self, value):
        # type: (int) -> None
        """Remove ``value`` if present; otherwise do nothing.

        As with ``set.discard``, values that cannot be members (non-ints,
        negative ints, ints at or beyond the highest set bit) are ignored
        instead of raising, and no huge ``1 << value`` mask is built.
        """
        if isinstance(value, int) and 0 <= value < self.bits.bit_length():
            self.bits &= ~(1 << value)

    def clear(self):
        # type: () -> None
        self.bits = 0

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits |= it.bits
            return self
        return super().__ior__(it)

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits &= it.bits
            return self
        return super().__iand__(it)

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits ^= it.bits
            return self
        return super().__ixor__(it)

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits &= ~it.bits
            return self
        return super().__isub__(it)
403
404
class FBitSet(BaseBitSet):
    """Frozen Bit Set

    ``_frozen`` always reports True, so the inherited ``bits`` setter
    refuses writes; unlike ``BitSet`` instances, these are hashable.
    """

    def __hash__(self):
        # type: () -> int
        """Order-independent hash from the ``AbstractSet`` mixin's helper."""
        value = super()._hash()
        return value

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        """Always True: instances of this class never change."""
        return True