87c6975425dbfd4fdec64df750a0110924c648ee
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import ABCMeta, abstractmethod
2 from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
3 TypeVar, overload)
4
5 from bigint_presentation_code.type_util import Self, final
6
# Type variables shared by the generic container classes in this module.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Public names exported by `from <this module> import *`.
__all__ = [
    "BaseBitSet",
    "bit_count",
    "BitSet",
    "FBitSet",
    "FMap",
    "OFSet",
    "OSet",
    "top_set_bit_index",
    "trailing_zero_count",
    "InternedMeta",
]
22
23
class InternedMeta(ABCMeta):
    """Metaclass that interns the instances of every class it creates.

    Calling such a class builds a new instance as usual and then looks it
    up in a table owned by that class: if an equal instance was created
    before, the earlier one is returned instead of the new one.

    Instances must be hashable and must have an instance `__dict__`.
    """

    def __init__(self, *args, **kwargs):
        # type: (*Any, **Any) -> None
        super().__init__(*args, **kwargs)
        # runs once per class created by this metaclass, so every class
        # gets its own table mapping an instance to its canonical instance
        self.__INTERN_TABLE = {}  # type: dict[Any, Any]

    def __intern(self, value):
        # type: (_T) -> _T
        """Return the canonical instance equal to `value`, registering
        `value` as canonical when no equal instance has been seen yet.
        """
        canonical = self.__INTERN_TABLE.setdefault(value, value)
        attrs = canonical.__dict__
        if not attrs.get("_InternedMeta__interned", False):
            # first time this canonical instance is handed out: decorate it
            attrs["_InternedMeta__interned"] = True
            cached_hash = hash(canonical)
            # NOTE: these live in the instance `__dict__`, so they are what
            # explicit `obj.__hash__()` / `obj.__eq__(x)` attribute access
            # finds; implicit `hash(obj)` / `obj == x` look the special
            # methods up on the type instead.
            attrs["__hash__"] = lambda: cached_hash
            # bound method fetched before the replacement below is stored
            fallback_eq = canonical.__eq__

            def __eq__(__o):
                # type: (_T) -> bool
                if canonical.__class__ is not __o.__class__:
                    return fallback_eq(__o)
                # same class: equal interned instances are the same object
                return canonical is __o
            attrs["__eq__"] = __eq__
        return canonical

    def __call__(self, *args, **kwargs):
        # type: (*Any, **Any) -> Any
        fresh = super().__call__(*args, **kwargs)
        return self.__intern(fresh)
51
52
class OFSet(AbstractSet[_T_co], metaclass=InternedMeta):
    """Ordered frozen set.

    Elements are kept as the keys of a dict, so iteration follows the
    order in which they were first supplied. Instances are created
    through `InternedMeta`.
    """
    __slots__ = "__items", "__dict__", "__weakref__"

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        super().__init__()
        # dict keys give de-duplication plus first-seen ordering
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T_co]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __hash__(self):
        # type: () -> int
        # hash computed by the `_hash` mixin of the abstract set base class
        return self._hash()

    def __repr__(self):
        # type: () -> str
        if len(self) != 0:
            return f"OFSet({list(self)})"
        return "OFSet()"
83
84
class OSet(MutableSet[_T]):
    """Ordered mutable set.

    Elements are kept as the keys of a dict, so iteration follows
    insertion order.
    """
    __slots__ = "__items", "__dict__"

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        super().__init__()
        # dict keys give de-duplication plus first-seen ordering
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        # type: (Any) -> bool
        return x in self.__items

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def add(self, value):
        # type: (_T) -> None
        """Insert `value`; an element already present keeps its position."""
        self.__items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        """Remove `value` if present; do nothing otherwise."""
        self.__items.pop(value, None)

    def __repr__(self):
        # type: () -> str
        if len(self) != 0:
            return f"OSet({list(self)})"
        return "OSet()"
119
120
class FMap(Mapping[_T, _T_co], metaclass=InternedMeta):
    """Ordered frozen hashable mapping.

    The entries are copied into a private dict at construction time. The
    hash is computed from the set of items on first use and then cached.
    Instances are created through `InternedMeta`.
    """
    __slots__ = "__items", "__hash", "__dict__", "__weakref__"

    @overload
    def __init__(self, items):
        # type: (Mapping[_T, _T_co]) -> None
        ...

    @overload
    def __init__(self, items):
        # type: (Iterable[tuple[_T, _T_co]]) -> None
        ...

    @overload
    def __init__(self):
        # type: () -> None
        ...

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        super().__init__()
        self.__items = dict(items)  # type: dict[_T, _T_co]
        # filled in lazily by __hash__
        self.__hash = None  # type: None | int

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        return self.__items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        return iter(self.__items)

    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __eq__(self, other):
        # type: (FMap[Any, Any] | Any) -> bool
        if not isinstance(other, FMap):
            # defer to the generic Mapping comparison
            return super().__eq__(other)
        # two FMaps: compare the backing dicts directly
        return self.__items == other.__items

    def __hash__(self):
        # type: () -> int
        cached = self.__hash
        if cached is None:
            cached = hash(frozenset(self.items()))
            self.__hash = cached
        return cached

    def __repr__(self):
        # type: () -> str
        return f"FMap({self.__items})"
173
174
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the lowest set bit of `v`.

    `default` is returned when `v` has no set bit, i.e. when `v` is 0.
    """
    lowest = v & -v  # isolate the lowest set bit (0 when v == 0)
    if lowest <= 0:
        return default
    return lowest.bit_length() - 1
180
181
def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the highest set bit of `v`.

    `default` is returned when `v` is zero or negative.
    """
    return v.bit_length() - 1 if v > 0 else default
187
188
if hasattr(int, "bit_count"):
    # added in cpython 3.10
    bit_count = int.bit_count  # type: ignore
else:
    def bit_count(v):
        # type: (int) -> int
        """returns the number of 1 bits in the absolute value of the input"""
        magnitude = abs(v)
        return bin(magnitude).count('1')
197
198
class BaseBitSet(AbstractSet[int]):
    """Set of non-negative integers stored as the bits of a single `int`.

    Integer `i` is a member exactly when bit `i` of `bits` is set.
    Subclasses implement `_frozen()` to say whether `bits` may be
    reassigned after construction.
    """
    __slots__ = "__bits", "__dict__", "__weakref__"

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """Return True if instances of this class must not be mutated."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Build an instance of `cls` directly from a bit mask."""
        return cls(bits=bits)

    @staticmethod
    def _bits_of(items):
        # type: (Iterable[Any]) -> int
        """Return the bit mask of the storable members of `items`.

        Elements that are not non-negative integers are ignored; repeated
        elements contribute their bit only once.
        """
        if isinstance(items, BaseBitSet):
            return items.bits
        mask = 0
        for item in items:
            if isinstance(item, int) and item >= 0:
                mask |= 1 << item
        return mask

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create a set holding `items` plus the members encoded in `bits`.

        Raises:
            ValueError: if an item is negative, or if the resulting bit
                mask is negative (which would denote an infinite set).
        """
        super().__init__()
        if isinstance(items, BaseBitSet):
            bits |= items.bits
        else:
            for item in items:
                if item < 0:
                    raise ValueError("can't store negative integers")
                bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """The bit mask: bit `i` is set iff `i` is in the set."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def __contains__(self, x):
        # type: (Any) -> bool
        if isinstance(x, int) and x >= 0:
            bits = self.bits
            # Test the bit by shifting the mask down rather than building
            # `1 << x`, which would allocate an x-bit integer for a large
            # `x`; indexes beyond the top bit are rejected without shifting.
            return x < bits.bit_length() and (bits >> x) & 1 != 0
        return False

    def __iter__(self):
        # type: () -> Iterator[int]
        """Yield the members in ascending order."""
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        """Yield the members in descending order."""
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # type: () -> int
        return bit_count(self.bits)

    def __repr__(self):
        # type: () -> str
        """Return a compact constructor-like representation.

        Small sets are shown as a list, sets that decompose into a few
        arithmetic progressions as `range`s, anything else as a list or
        as the hexadecimal bit mask.
        """
        if self.bits == 0:
            return f"{self.__class__.__name__}()"
        len_self = len(self)
        if len_self <= 3:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        # greedily group the ascending members into arithmetic progressions
        ranges = []  # type: list[range]
        MAX_RANGES = 5
        for i in self:
            if len(ranges) != 0 and ranges[-1].stop == i:
                # `i` continues the last progression
                ranges[-1] = range(
                    ranges[-1].start, i + ranges[-1].step, ranges[-1].step)
            elif len(ranges) != 0 and len(ranges[-1]) == 1:
                # a lone element: pair it with `i`, which fixes the step
                start = ranges[-1][0]
                step = i - start
                stop = i + step
                ranges[-1] = range(start, stop, step)
            elif len(ranges) != 0 and len(ranges[-1]) == 2:
                # a pair that `i` does not continue: keep its first element
                # alone and start a new progression from its second
                single = ranges[-1][0]
                start = ranges[-1][1]
                ranges[-1] = range(single, single + 1)
                step = i - start
                stop = i + step
                ranges.append(range(start, stop, step))
            else:
                ranges.append(range(i, i + 1))
            if len(ranges) > MAX_RANGES:
                break
        if len(ranges) == 1:
            return f"{self.__class__.__name__}({ranges[0]})"
        if len(ranges) <= MAX_RANGES:
            range_strs = []  # type: list[str]
            for r in ranges:
                if len(r) == 1:
                    range_strs.append(str(r[0]))
                else:
                    range_strs.append(f"*{r}")
            ranges_str = ", ".join(range_strs)
            return f"{self.__class__.__name__}([{ranges_str}])"
        if self.bits > 0xFFFFFFFF and len_self < 10:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        return f"{self.__class__.__name__}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (Any) -> bool
        if not isinstance(other, BaseBitSet):
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & other.bits)
        bits = 0
        for item in other:
            # only members of `self` can be in the result, so bits are
            # never built for arbitrarily large foreign items
            if isinstance(item, int) and item in self:
                bits |= 1 << item
        return self._from_bits(bits)

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(self.bits | self._bits_of(other))

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        # `other` is reduced to a mask first so that an element repeated
        # in a plain iterable cannot toggle its bit more than once
        return self._from_bits(self.bits ^ self._bits_of(other))

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & ~other.bits)
        bits = self.bits
        for item in other:
            # items outside `self` cannot change the result; skipping them
            # avoids building huge masks
            if isinstance(item, int) and item in self:
                bits &= ~(1 << item)
        return self._from_bits(bits)

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        return self._from_bits(~self.bits & self._bits_of(other))

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
380
381
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable bit set: a `BaseBitSet` whose `bits` may be modified."""

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return False

    def add(self, value):
        # type: (int) -> None
        """Insert `value`; raises ValueError if it is negative."""
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits = self.bits | (1 << value)

    def discard(self, value):
        # type: (int) -> None
        """Remove `value` if present; negative values are ignored."""
        if value < 0:
            return
        self.bits = self.bits & ~(1 << value)

    def clear(self):
        # type: () -> None
        self.bits = 0

    # The in-place operators below combine bit masks directly when the
    # other operand is a bit set, and otherwise defer to the generic
    # MutableSet implementations.

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ior__(it)
        self.bits = self.bits | it.bits
        return self

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__iand__(it)
        self.bits = self.bits & it.bits
        return self

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ixor__(it)
        self.bits = self.bits ^ it.bits
        return self

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__isub__(it)
        self.bits = self.bits & ~it.bits
        return self
433
434
class FBitSet(BaseBitSet, metaclass=InternedMeta):
    """Frozen bit set: a hashable `BaseBitSet` whose `bits` cannot be
    reassigned. Instances are created through `InternedMeta`.
    """

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        return True

    def __hash__(self):
        # type: () -> int
        # use the `_hash` helper inherited from the base classes
        inherited_hash = super()._hash
        return inherited_hash()