from abc import ABCMeta, abstractmethod
from collections import defaultdict
-from typing import (AbstractSet, Any, Callable, Iterable, Iterator, Mapping,
- MutableSet, TypeVar, overload)
+from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
+ Mapping, MutableSet, NewType, TypeVar, overload)
from bigint_presentation_code.type_util import Self, final
+from nmutil.plain_data import plain_data
_T_co = TypeVar("_T_co", covariant=True)
_T = TypeVar("_T")
"BaseBitSet",
"bit_count",
"BitSet",
+ "DisjointSets",
+ "DisjointSetsItem",
"FBitSet",
"FMap",
+ "Interned",
"OFSet",
"OSet",
"top_set_bit_index",
"trailing_zero_count",
- "Interned",
]
def __hash__(self):
# type: () -> int
return super()._hash()
+
+
+DisjointSetsItem = NewType("DisjointSetsItem", int)
+
+
+@plain_data()
+@final
+class _DisjointSetsEntry(Generic[_T_co]):
+ __slots__ = "value", "parent", "rank"
+
+ def __init__(self, value, parent, rank):
+ # type: (_T_co, DisjointSetsItem, int) -> None
+ self.value = value
+ self.parent = parent
+ self.rank = rank
+
+
+@final
+class DisjointSets(Generic[_T_co]):
+ """ Disjoint-set data structure, aka. union-find or merge-find
+ https://en.wikipedia.org/wiki/Disjoint-set_data_structure
+ """
+
+ def __init__(self):
+ self.__values = [] # type: list[_DisjointSetsEntry[_T_co]]
+
+ def __entry(self, __key):
+ # type: (DisjointSetsItem) -> _DisjointSetsEntry[_T_co]
+ if __key < 0 or __key >= len(self.__values):
+ raise KeyError(__key)
+ return self.__values[__key]
+
+ def __getitem__(self, __key):
+ # type: (DisjointSetsItem) -> _T_co
+ return self.__entry(__key).value
+
+ def __setitem__(self, __key, __value):
+ # type: (DisjointSetsItem, _T_co) -> None
+ self.__entry(__key).value = __value
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.__values)
+
+ @property
+ def representatives(self):
+ # type: () -> Iterator[DisjointSetsItem]
+ for i, entry in enumerate(self.__values):
+ item = DisjointSetsItem(i)
+ if entry.parent == item:
+ yield item
+
+ def __iter__(self):
+ # type: () -> Iterator[DisjointSetsItem]
+ return map(DisjointSetsItem, range(len(self.__values)))
+
+ def add_new_set(self, value):
+ # type: (_T_co) -> DisjointSetsItem
+ item = DisjointSetsItem(len(self.__values))
+ self.__values.append(_DisjointSetsEntry(
+ value=value, parent=item, rank=0))
+ return item
+
+ def find_representative(self, item):
+ # type: (DisjointSetsItem) -> DisjointSetsItem
+ entry = self.__entry(item)
+ while entry.parent != item:
+ parent_entry = self.__values[entry.parent]
+ item = entry.parent = parent_entry.parent
+ entry = self.__values[item]
+ return item
+
+ def merge(self, __x, __y):
+ # type: (DisjointSetsItem, DisjointSetsItem) -> DisjointSetsItem
+ __x = self.find_representative(__x)
+ __y = self.find_representative(__y)
+ if __x == __y:
+ return __x
+ x_entry = self.__values[__x]
+ y_entry = self.__values[__y]
+ if x_entry.rank < y_entry.rank:
+ __x, __y = __y, __x
+ x_entry, y_entry = y_entry, x_entry
+ y_entry.parent = __x
+ if x_entry.rank == y_entry.rank:
+ x_entry.rank += 1
+ return __x
+
+ def __repr__(self):
+ # type: () -> str
+ sets = defaultdict(
+ list) # type: dict[DisjointSetsItem, list[DisjointSetsItem]]
+ values = {}
+ for item in self:
+ sets[self.find_representative(item)].append(item)
+ values[item] = self[item]
+ return f"DisjointSets(sets={sets}, values={values})"