9d2b7b1af8fdc4cb099311d696965959305e3b2b
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import ABCMeta, abstractmethod
2 from collections import defaultdict
3 from typing import (AbstractSet, Any, Callable, Iterable, Iterator, Mapping,
4 MutableSet, TypeVar, overload)
5
6 from bigint_presentation_code.type_util import Self, final
7
# Type variables shared by the generic containers defined in this module.
_T = TypeVar("_T")  # invariant element / key type
_T_co = TypeVar("_T_co", covariant=True)  # covariant element / value type

# Names exported by ``from bigint_presentation_code.util import *``.
__all__ = [
    # bit sets and bit helpers
    "BaseBitSet",
    "bit_count",
    "BitSet",
    "FBitSet",
    # ordered / frozen containers
    "FMap",
    "OFSet",
    "OSet",
    # more bit helpers
    "top_set_bit_index",
    "trailing_zero_count",
    # interning base class
    "Interned",
]
23
24
25 class _InternedMeta(ABCMeta):
26 def __call__(self, *args: Any, **kwds: Any) -> Any:
27 return super().__call__(*args, **kwds)._Interned__intern()
28
29
class Interned(metaclass=_InternedMeta):
    """Base class that deduplicates ("interns") instances of a class.

    Construction goes through `_InternedMeta.__call__`, which calls the new
    instance's `_Interned__intern` method and returns its result.  That name
    is initially bound to `__init_intern`, which patches the concrete class
    the first time it is reached and then interns the instance.

    After a class has been patched:

    * `hash(obj)` returns the hash that was stored when `obj` was interned;
    * `a == b` is an identity test when `a.__class__ is b.__class__`, and
      falls back to the class's original `__eq__` otherwise.

    Instances must have a `__dict__`, because the stored hash is written to
    `self.__dict__["_Interned__hash"]`.
    """

    def __init_intern(self):
        # type: (Self) -> Self
        """Patch `type(self)` for interning, then intern `self`.

        Replaces the class's `__hash__` and `__eq__`, creates the table of
        interned instances, installs the real `_Interned__intern` on the
        class and returns the result of calling it on `self`.
        """
        cls = type(self)

        # If the class's current `__hash__`/`__eq__` are already interning
        # wrappers, unwrap them to get back to the original implementations.
        user_hash = cls.__hash__
        user_hash = getattr(user_hash, "_Interned__old_hash", user_hash)
        user_eq = cls.__eq__
        user_eq = getattr(user_eq, "_Interned__old_eq", user_eq)

        def __hash__(self):
            # type: (Self) -> int
            # the value stored by `__intern` below
            return self._Interned__hash  # type: ignore
        __hash__._Interned__old_hash = user_hash  # type: ignore
        cls.__hash__ = __hash__

        def __eq__(self,  # type: Self
                   __other,  # type: Any
                   *, __eq=user_eq,  # type: Callable[[Self, Any], bool]
                   ):
            # type: (...) -> bool
            if self.__class__ is not __other.__class__:
                return __eq(self, __other)
            # same class: interned instances are equal only if identical
            return self is __other
        __eq__._Interned__old_eq = user_eq  # type: ignore
        cls.__eq__ = __eq__

        # original hash value -> instances already interned with that hash;
        # the lists hold ordinary (strong) references
        buckets = defaultdict(list)  # type: dict[int, list[Self]]

        def __intern(self,  # type: Self
                     *, __hash=user_hash,  # type: Callable[[Self], int]
                     __eq=user_eq,  # type: Callable[[Self, Any], bool]
                     __table=buckets,  # type: dict[int, list[Self]]
                     __NotImplemented=NotImplemented,  # type: Any
                     ):
            # type: (...) -> Self
            digest = __hash(self)
            bucket = __table[digest]
            for candidate in bucket:
                verdict = __eq(self, candidate)
                if verdict is __NotImplemented:
                    continue
                if verdict:
                    # an equal instance already exists: reuse it
                    return candidate
            # first instance with this value: remember its hash and keep it
            self.__dict__["_Interned__hash"] = digest
            bucket.append(self)
            return self

        # NOTE(review): the replacement is stored on `cls` only, so a
        # subclass of `cls` that is first instantiated afterwards finds this
        # closure (and its table) through the MRO -- confirm that sharing is
        # intended.
        cls._Interned__intern = __intern
        return __intern(self)

    # Until a class has been patched, interning means "patch, then intern".
    _Interned__intern = __init_intern
78
79
class OFSet(AbstractSet[_T_co], Interned):
    """Ordered frozen set.

    Elements are kept as the keys of a dict, so iteration follows the
    order in which they were first supplied.
    """
    __slots__ = "__items", "__dict__", "__weakref__"

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        """Build the set from `items`; another `OFSet` shares its storage."""
        super().__init__()
        if not isinstance(items, OFSet):
            # dict keys give uniqueness plus insertion order; values unused
            self.__items = dict.fromkeys(items)
        else:
            self.__items = items.__items

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T_co]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __hash__(self):
        # type: () -> int
        # `_hash` is not defined in this class; it comes from a base class.
        return self._hash()

    def __repr__(self):
        # type: () -> str
        return f"OFSet({list(self)})" if len(self) != 0 else "OFSet()"
113
114
class OSet(MutableSet[_T]):
    """Ordered mutable set.

    Elements are kept as the keys of a dict, so iteration follows
    insertion order.
    """
    __slots__ = "__items", "__dict__"

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        """Build the set from the elements of `items`."""
        super().__init__()
        # dict keys give uniqueness plus insertion order; values unused
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def add(self, value):
        # type: (_T) -> None
        """Insert `value`; an element already present keeps its position."""
        self.__items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        """Remove `value` if present; do nothing otherwise."""
        self.__items.pop(value, None)

    def __repr__(self):
        # type: () -> str
        return f"OSet({list(self)})" if len(self) != 0 else "OSet()"
149
150
class FMap(Mapping[_T, _T_co], Interned):
    """Ordered frozen hashable mapping.

    Items are copied into a private dict at construction time; the hash is
    computed from the set of items the first time it is needed and cached.
    """
    __slots__ = "__items", "__hash", "__dict__", "__weakref__"

    @overload
    def __init__(self, items):
        # type: (Mapping[_T, _T_co]) -> None
        ...

    @overload
    def __init__(self, items):
        # type: (Iterable[tuple[_T, _T_co]]) -> None
        ...

    @overload
    def __init__(self):
        # type: () -> None
        ...

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        """Build the map from a mapping or an iterable of key/value pairs."""
        super().__init__()
        self.__hash = None  # type: None | int
        self.__items = dict(items)  # type: dict[_T, _T_co]

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        return self.__items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __eq__(self, other):
        # type: (FMap[Any, Any] | Any) -> bool
        if not isinstance(other, FMap):
            # anything else is compared by the base-class implementation
            return super().__eq__(other)
        return self.__items == other.__items

    def __hash__(self):
        # type: () -> int
        cached = self.__hash
        if cached is None:
            # order-independent: hash the *set* of (key, value) pairs
            cached = self.__hash = hash(frozenset(self.items()))
        return cached

    def __repr__(self):
        # type: () -> str
        return f"FMap({self.__items})"
203
204
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the lowest set bit of `v`.

    That index equals the number of zero bits below the lowest one bit.
    `default` is returned when `v` has no set bit (`v == 0`).
    """
    # `v & -v` isolates the lowest set bit (0 when v is 0)
    lowest_bit = v & -v
    return top_set_bit_index(lowest_bit, default)
210
211
def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the highest set bit of `v`.

    `default` is returned when `v` is zero or negative.
    """
    return v.bit_length() - 1 if v > 0 else default
217
218
if hasattr(int, "bit_count"):
    # added in cpython 3.10
    bit_count = int.bit_count  # type: ignore
else:
    def bit_count(v):
        # type: (int) -> int
        """returns the number of 1 bits in the absolute value of the input"""
        magnitude = abs(v)
        return bin(magnitude).count('1')
227
228
class BaseBitSet(AbstractSet[int]):
    """Set of non-negative integers stored as the bits of a single `int`.

    Bit `i` of `bits` is set exactly when `i` is a member.  Subclasses
    decide, through `_frozen()`, whether `bits` may be reassigned.

    The binary set operators accept another bit set or any iterable; items
    of a plain iterable that are not non-negative ints are ignored.
    """
    __slots__ = "__bits", "__dict__", "__weakref__"

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """Return True if writing to `bits` must be rejected."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Build an instance of `cls` directly from a bit mask."""
        return cls(bits=bits)

    @staticmethod
    def _operand_bits(other):
        # type: (Iterable[Any]) -> int
        """Return the bit mask equivalent to the operand `other`.

        A bit set contributes its `bits`.  Any other iterable contributes
        one bit per non-negative int it yields; everything else is skipped
        and repeated items count once.
        """
        if isinstance(other, BaseBitSet):
            return other.bits
        mask = 0
        for item in other:
            if isinstance(item, int) and item >= 0:
                mask |= 1 << item
        return mask

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create the union of `items` and the members encoded by `bits`.

        Raises:
            ValueError: if an item is negative, or if the resulting mask is
                negative (which would encode infinitely many members).
        """
        super().__init__()
        if isinstance(items, BaseBitSet):
            bits |= items.bits
        else:
            for item in items:
                if item < 0:
                    raise ValueError("can't store negative integers")
                bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """The members encoded as a non-negative integer bit mask."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def __contains__(self, x):
        # type: (Any) -> bool
        if not isinstance(x, int) or x < 0:
            return False
        bits = self.bits
        # Check against the mask's length first, so that an absurdly large
        # `x` never makes us build or shift by a gigantic integer.
        return x < bits.bit_length() and (bits >> x) & 1 == 1

    def __iter__(self):
        # type: () -> Iterator[int]
        """Yield the members in ascending order."""
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        """Yield the members in descending order."""
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # type: () -> int
        return bit_count(self.bits)

    def __repr__(self):
        # type: () -> str
        """Return a compact constructor-style representation.

        Small sets are shown as a list, sets made of a few arithmetic
        progressions as `range(...)` pieces, and anything else either as a
        list (wide but sparse masks) or as a hexadecimal `bits=` value.
        """
        name = self.__class__.__name__
        if self.bits == 0:
            return f"{name}()"
        len_self = len(self)
        if len_self <= 3:
            v = list(self)
            return f"{name}({v})"
        # Greedily group the ascending members into arithmetic progressions.
        ranges = []  # type: list[range]
        MAX_RANGES = 5
        for i in self:
            last = ranges[-1] if len(ranges) != 0 else None
            if last is not None and last.stop == i:
                # `i` continues the current progression
                ranges[-1] = range(last.start, i + last.step, last.step)
            elif last is not None and len(last) == 1:
                # a lone value: start a progression whose step reaches `i`
                start = last[0]
                step = i - start
                ranges[-1] = range(start, i + step, step)
            elif last is not None and len(last) == 2:
                # a two-element progression that `i` does not continue:
                # keep its first value alone and retry from its second
                single, start = last
                ranges[-1] = range(single, single + 1)
                step = i - start
                ranges.append(range(start, i + step, step))
            else:
                ranges.append(range(i, i + 1))
            if len(ranges) > MAX_RANGES:
                # too fragmented for the range form; fall through below
                break
        if len(ranges) == 1:
            return f"{name}({ranges[0]})"
        if len(ranges) <= MAX_RANGES:
            range_strs = [
                str(r[0]) if len(r) == 1 else f"*{r}"
                for r in ranges
            ]  # type: list[str]
            ranges_str = ", ".join(range_strs)
            return f"{name}([{ranges_str}])"
        if self.bits > 0xFFFFFFFF and len_self < 10:
            v = list(self)
            return f"{name}({v})"
        return f"{name}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (Any) -> bool
        if not isinstance(other, BaseBitSet):
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self.bits & self._operand_bits(other))

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self.bits | self._operand_bits(other))

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        # The operand is reduced to a mask first, so an item that occurs
        # several times in a plain iterable toggles its bit only once.
        return self._from_bits(self.bits ^ self._operand_bits(other))

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self.bits & ~self._operand_bits(other))

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        # other - self: the members of `other` that are not in `self`
        return self._from_bits(~self.bits & self._operand_bits(other))

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
410
411
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable Bit Set"""

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        """Mutable: writing to `bits` is allowed."""
        return False

    def add(self, value):
        # type: (int) -> None
        """Insert `value`.

        Raises:
            ValueError: if `value` is negative.
        """
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits |= 1 << value

    def discard(self, value):
        # type: (int) -> None
        """Remove `value` if it is a member; do nothing otherwise.

        Values that can never be members (non-ints, negative ints) and
        values beyond the highest set bit are ignored, so this never
        raises and never builds a huge `1 << value` mask.
        """
        if isinstance(value, int) and 0 <= value < self.bits.bit_length():
            self.bits &= ~(1 << value)

    def clear(self):
        # type: () -> None
        """Remove every member."""
        self.bits = 0

    # In-place operators: combine the masks directly when the operand is a
    # bit set, otherwise defer to the generic base-class implementation.

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ior__(it)
        self.bits |= it.bits
        return self

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__iand__(it)
        self.bits &= it.bits
        return self

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ixor__(it)
        self.bits ^= it.bits
        return self

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__isub__(it)
        self.bits &= ~it.bits
        return self
463
464
class FBitSet(BaseBitSet, Interned):
    """Frozen Bit Set"""

    def __hash__(self):
        # type: () -> int
        # `_hash` is not defined in this module; it is inherited from the
        # set base classes.
        return super()._hash()

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        """Frozen: writing to `bits` is rejected."""
        return True