working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import ABCMeta, abstractmethod
2 from functools import lru_cache
3 from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
4 TypeVar, overload)
5
6 from bigint_presentation_code.type_util import Self, final
7
# type variables shared by the generic containers defined in this module
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# names exported by `from ... import *`
__all__ = [
    "BaseBitSet",
    "bit_count",
    "BitSet",
    "FBitSet",
    "FMap",
    "OFSet",
    "OSet",
    "top_set_bit_index",
    "trailing_zero_count",
    "InternedMeta",
]
23
24
class InternedMeta(ABCMeta):
    """Metaclass that interns the instances of the classes it creates.

    Each instance produced by calling such a class is looked up (by hash and
    equality) in a per-class table. If an equal instance was produced
    earlier, that earlier instance is returned instead of the new one.

    Instances must be hashable and must have a `__dict__`, because markers
    and per-instance `__hash__`/`__eq__` entries are stored there.
    """

    def __init__(self, *args, **kwargs):
        # type: (*Any, **Any) -> None
        super().__init__(*args, **kwargs)
        # This runs once per class created with this metaclass (subclasses
        # included), so every such class gets its own table.
        self.__INTERN_TABLE = {}  # type: dict[Any, Any]

    def __intern(self, value):
        # type: (_T) -> _T
        """Return the canonical instance equal to `value`.

        If no equal instance exists yet, `value` itself becomes canonical.
        """
        canonical = self.__INTERN_TABLE.setdefault(value, value)
        inst_dict = canonical.__dict__
        if inst_dict.get("_InternedMeta__interned", False):
            # canonical instance was already set up on an earlier call
            return canonical
        inst_dict["_InternedMeta__interned"] = True
        frozen_hash = hash(canonical)
        # NOTE: these go into the *instance* dict. Python looks up special
        # methods on the type, so `hash(x)` and `x == y` keep using the
        # class's methods; only explicit `x.__hash__()` / `x.__eq__(y)` calls
        # reach these entries.
        inst_dict["__hash__"] = lambda: frozen_hash
        fallback_eq = canonical.__eq__

        def __eq__(other):
            # type: (_T) -> bool
            # same-class instances are interned, so identity decides equality
            if canonical.__class__ is other.__class__:
                return canonical is other
            return fallback_eq(other)
        inst_dict["__eq__"] = __eq__
        return canonical

    def __call__(self, *args, **kwargs):
        # type: (*Any, **Any) -> Any
        fresh = super().__call__(*args, **kwargs)
        return self.__intern(fresh)
52
53
class OFSet(AbstractSet[_T_co], metaclass=InternedMeta):
    """Ordered frozen set.

    An immutable, hashable set that iterates in insertion order; a
    duplicate item keeps the position of its first occurrence.
    """
    __slots__ = "__items", "__hash", "__dict__", "__weakref__"

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        super().__init__()
        # dict keys give O(1) membership while remembering insertion order
        self.__items = dict.fromkeys(items)  # type: dict[_T_co, None]
        # The hash is computed lazily and then cached: the contents never
        # change after construction.
        self.__hash = None  # type: None | int

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T_co]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __hash__(self):
        # type: () -> int
        """Return the `Set._hash` value, which is compatible with `frozenset`.

        Computing it walks every element, so it is cached after the first
        call. It is needed at least once per construction (interning) and
        again on every dict lookup.
        """
        if self.__hash is None:
            self.__hash = self._hash()
        return self.__hash

    def __repr__(self):
        # type: () -> str
        if len(self) == 0:
            return "OFSet()"
        return f"OFSet({list(self)})"
84
85
class OSet(MutableSet[_T]):
    """Ordered mutable set: iterates in insertion order."""
    __slots__ = "__items", "__dict__"

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        super().__init__()
        # dict keys give O(1) membership while remembering insertion order
        self.__items = dict.fromkeys(items)  # type: dict[_T, None]

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def add(self, value):
        # type: (_T) -> None
        # re-adding an existing element leaves it at its original position
        self.__items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        self.__items.pop(value, None)

    def clear(self):
        # type: () -> None
        """Remove every element in one step.

        Overrides `MutableSet.clear`, which pops one element at a time via
        `next(iter(self))`. On a CPython dict, each such pop rescans the
        deleted slots left at the front, making the generic version
        quadratic.
        """
        self.__items.clear()

    def __repr__(self):
        # type: () -> str
        if len(self) == 0:
            return "OSet()"
        return f"OSet({list(self)})"
120
121
class FMap(Mapping[_T, _T_co], metaclass=InternedMeta):
    """Immutable, hashable mapping that preserves insertion order.

    Equality and hashing ignore order; the hash is computed once and then
    cached.
    """
    __slots__ = "__items", "__hash", "__dict__", "__weakref__"

    @overload
    def __init__(self, items):
        # type: (Mapping[_T, _T_co]) -> None
        ...

    @overload
    def __init__(self, items):
        # type: (Iterable[tuple[_T, _T_co]]) -> None
        ...

    @overload
    def __init__(self):
        # type: () -> None
        ...

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        super().__init__()
        # private copy, so later changes to `items` can't leak in
        self.__items = dict(items)  # type: dict[_T, _T_co]
        # filled in lazily by __hash__
        self.__hash = None  # type: None | int

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        return self.__items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __eq__(self, other):
        # type: (FMap[Any, Any] | Any) -> bool
        if not isinstance(other, FMap):
            # generic Mapping comparison for everything else
            return super().__eq__(other)
        return self.__items == other.__items

    def __hash__(self):
        # type: () -> int
        cached = self.__hash
        if cached is None:
            # order-insensitive, matching dict equality
            cached = hash(frozenset(self.items()))
            self.__hash = cached
        return cached

    def __repr__(self):
        # type: () -> str
        return f"FMap({self.__items})"
174
175
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the number of trailing zero bits of `v`.

    This is the index of the lowest set bit. `default` is returned for
    `v == 0`, which has no set bits. Negative values use infinite two's
    complement, e.g. -4 gives 2.
    """
    # `v & -v` isolates the lowest set bit (two's complement identity); it
    # equals `v & ~(v & (v - 1))`.
    return top_set_bit_index(v & -v, default)
181
182
def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the most significant set bit of `v`.

    `default` is returned for `v <= 0`: zero has no set bits, and a negative
    value has infinitely many in two's complement.
    """
    return v.bit_length() - 1 if v > 0 else default
188
189
try:
    # int.bit_count only exists from CPython 3.10 onwards
    bit_count = int.bit_count  # type: ignore
except AttributeError:
    def bit_count(v):
        # type: (int) -> int
        """Count the 1 bits in abs(v); fallback for Python < 3.10."""
        binary_digits = bin(-v if v < 0 else v)
        return binary_digits.count('1')
198
199
class BaseBitSet(AbstractSet[int]):
    """Set of non-negative ints stored as the set bits of one Python int.

    Element `i` is a member iff bit `i` of `bits` is set. Subclasses decide
    mutability through `_frozen`. Binary operators build their result with
    `self._from_bits`, i.e. as an instance of `type(self)`.
    """
    __slots__ = "__bits", "__dict__", "__weakref__"

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """True if instances are immutable (assigning `bits` then raises)."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Build an instance of `cls` directly from a bit mask."""
        return cls(bits=bits)

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create a set holding every element of `items`.

        It also holds every index whose bit is set in `bits`.

        Raises ValueError for a negative item, or when the combined mask is
        negative (an infinite set in two's complement).
        """
        super().__init__()
        if isinstance(items, BaseBitSet):
            bits |= items.bits
        else:
            for item in items:
                if item < 0:
                    raise ValueError("can't store negative integers")
                bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """The bit mask backing this set (never negative)."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def _member_bits(self, items):
        # type: (Iterable[Any]) -> int
        """Return the mask of the elements of `items` that are in `self`.

        Non-members (non-ints, negatives, ints beyond the top bit) are
        skipped. That is exactly right for intersection and difference, and
        it means no mask wider than `self.bits` is ever built.
        """
        if isinstance(items, BaseBitSet):
            return self.bits & items.bits
        bits = 0
        for item in items:
            if item in self:
                bits |= 1 << item
        return bits

    @staticmethod
    def _bits_of(items):
        # type: (Iterable[Any]) -> int
        """Return a mask with one bit per distinct element of `items`.

        Used where every element of `items` must end up in (or be toggled
        in) the result. Elements a bit set can't hold therefore raise
        instead of being silently dropped. Duplicates are merged, so each
        element counts once.

        Raises TypeError for a non-int element, ValueError for a negative one.
        """
        if isinstance(items, BaseBitSet):
            return items.bits
        bits = 0
        for item in items:
            if not isinstance(item, int):
                raise TypeError(f"can't store non-integer: {item!r}")
            if item < 0:
                raise ValueError("can't store negative integers")
            bits |= 1 << item
        return bits

    def __contains__(self, x):
        # type: (Any) -> bool
        if not isinstance(x, int) or x < 0:
            return False
        bits = self.bits
        # Test bit `x` by shifting the mask down rather than building
        # `1 << x`, which is a huge (or impossible) allocation for a large
        # `x`. Anything at or above the top bit is simply absent.
        return x < bits.bit_length() and ((bits >> x) & 1) == 1

    def __iter__(self):
        # type: () -> Iterator[int]
        """Yield the members in increasing order."""
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        """Yield the members in decreasing order."""
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # type: () -> int
        return bit_count(self.bits)

    def __repr__(self):
        # type: () -> str
        """Return a compact, evaluable repr.

        Uses a plain list for at most 3 members, and `range` / `*range` runs
        when at most 5 arithmetic runs cover the set. Otherwise it uses a
        plain list when there are fewer than 10 members but some lie above
        bit 31, and the raw mask in hex in all other cases.
        """
        if self.bits == 0:
            return f"{self.__class__.__name__}()"
        len_self = len(self)
        if len_self <= 3:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        ranges = []  # type: list[range]
        MAX_RANGES = 5
        for i in self:
            if len(ranges) != 0 and ranges[-1].stop == i:
                # `i` continues the current run
                ranges[-1] = range(
                    ranges[-1].start, i + ranges[-1].step, ranges[-1].step)
            elif len(ranges) != 0 and len(ranges[-1]) == 1:
                # a lone value plus `i` defines a new run's step
                start = ranges[-1][0]
                step = i - start
                stop = i + step
                ranges[-1] = range(start, stop, step)
            elif len(ranges) != 0 and len(ranges[-1]) == 2:
                # a 2-element run that doesn't extend: split off its first
                # value and re-pair its second value with `i`
                single = ranges[-1][0]
                start = ranges[-1][1]
                ranges[-1] = range(single, single + 1)
                step = i - start
                stop = i + step
                ranges.append(range(start, stop, step))
            else:
                ranges.append(range(i, i + 1))
            if len(ranges) > MAX_RANGES:
                break
        if len(ranges) == 1:
            return f"{self.__class__.__name__}({ranges[0]})"
        if len(ranges) <= MAX_RANGES:
            range_strs = []  # type: list[str]
            for r in ranges:
                if len(r) == 1:
                    range_strs.append(str(r[0]))
                else:
                    range_strs.append(f"*{r}")
            ranges_str = ", ".join(range_strs)
            return f"{self.__class__.__name__}([{ranges_str}])"
        if self.bits > 0xFFFFFFFF and len_self < 10:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        return f"{self.__class__.__name__}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (Any) -> bool
        # bit sets compare by mask; anything else uses generic Set equality
        if not isinstance(other, BaseBitSet):
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self._member_bits(other))

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        # elements of `other` that can't be stored raise (see `_bits_of`)
        return self._from_bits(self.bits | self._bits_of(other))

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        # `_bits_of` merges duplicates first; toggling once per *occurrence*
        # would wrongly cancel out an element listed twice
        return self._from_bits(self.bits ^ self._bits_of(other))

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self.bits & ~self._member_bits(other))

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        # `other - self`: every element of `other` not in `self` survives,
        # so unstorable elements raise rather than vanish
        return self._from_bits(self._bits_of(other) & ~self.bits)

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
381
382
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable Bit Set"""

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return False

    def add(self, value):
        # type: (int) -> None
        """Add `value`; raises ValueError if it is negative."""
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits |= 1 << value

    def discard(self, value):
        # type: (int) -> None
        """Remove `value` if present and do nothing otherwise.

        Per the `MutableSet.discard` contract, absent values never raise.
        Only a non-negative int below `bits.bit_length()` can be a member,
        so anything else (non-ints, negatives, huge ints) is a no-op.
        Checking first also avoids building `1 << value` for a huge `value`.
        """
        if isinstance(value, int) and 0 <= value < self.bits.bit_length():
            self.bits &= ~(1 << value)

    def clear(self):
        # type: () -> None
        self.bits = 0

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        # fast path for another bit set; otherwise the generic MutableSet code
        if isinstance(it, BaseBitSet):
            self.bits |= it.bits
            return self
        return super().__ior__(it)

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits &= it.bits
            return self
        return super().__iand__(it)

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if isinstance(it, BaseBitSet):
            self.bits ^= it.bits
            return self
        return super().__ixor__(it)

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        # the generic fallback calls `discard` per element of `it`
        if isinstance(it, BaseBitSet):
            self.bits &= ~it.bits
            return self
        return super().__isub__(it)
434
435
class FBitSet(BaseBitSet, metaclass=InternedMeta):
    """Frozen Bit Set"""
    # holds the cached hash; left unset until the first `hash()` call
    __slots__ = ("__hash",)

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return True

    def __hash__(self):
        # type: () -> int
        """Return the `Set._hash` value, which is compatible with `frozenset`.

        Computing it walks every member, so it is cached: a frozen bit set's
        mask can never change (the `bits` setter raises).
        """
        try:
            return self.__hash
        except AttributeError:
            pass  # first call: slot not filled in yet
        hash_v = super()._hash()
        self.__hash = hash_v
        return hash_v