add DisjointSets, a disjoint-set data structure
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 6 Jan 2023 21:20:11 +0000 (13:20 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 6 Jan 2023 21:20:11 +0000 (13:20 -0800)
src/bigint_presentation_code/util.py

index 9f1e5ab8c92165476d99594ad1267af53b51a1d4..c697e3001cce0c0e6467ba436ae990b59ba81a6c 100644 (file)
@@ -1,9 +1,10 @@
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
-from typing import (AbstractSet, Any, Callable, Iterable, Iterator, Mapping,
-                    MutableSet, TypeVar, overload)
+from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
+                    Mapping, MutableSet, NewType, TypeVar, overload)
 
 from bigint_presentation_code.type_util import Self, final
+from nmutil.plain_data import plain_data
 
 _T_co = TypeVar("_T_co", covariant=True)
 _T = TypeVar("_T")
@@ -13,13 +14,15 @@ __all__ = [
     "BaseBitSet",
     "bit_count",
     "BitSet",
+    "DisjointSets",
+    "DisjointSetsItem",
     "FBitSet",
     "FMap",
+    "Interned",
     "OFSet",
     "OSet",
     "top_set_bit_index",
     "trailing_zero_count",
-    "Interned",
 ]
 
 
@@ -495,3 +498,100 @@ class FBitSet(BaseBitSet, Interned):
     def __hash__(self):
         # type: () -> int
         # delegate to the `_hash()` helper resolved through `super()`
         return super()._hash()
+
+
+# Handle type naming one item stored in a `DisjointSets`. It is a distinct
+# type for type checkers only; `DisjointSets` uses the underlying `int` as
+# an index into its internal list of entries.
+DisjointSetsItem = NewType("DisjointSetsItem", int)
+
+
+@plain_data()
+@final
+class _DisjointSetsEntry(Generic[_T]):
+    """ One node of the forest stored by `DisjointSets`.
+
+    value: the payload stored for this item.
+    parent: the item this node links to; `DisjointSets` treats a node whose
+        `parent` is the node's own item as the representative of its set.
+    rank: compared by `DisjointSets.merge` to decide which of two
+        representatives becomes the parent of the other.
+    """
+    # invariant `_T` rather than `_T_co`: `value` is reassigned after
+    # construction, so treating entries as covariant would be unsound.
+    __slots__ = "value", "parent", "rank"
+
+    def __init__(self, value, parent, rank):
+        # type: (_T, DisjointSetsItem, int) -> None
+        self.value = value
+        self.parent = parent
+        self.rank = rank
+
+
+@final
+class DisjointSets(Generic[_T]):
+    """ Disjoint-set data structure, aka. union-find or merge-find
+    https://en.wikipedia.org/wiki/Disjoint-set_data_structure
+
+    Every item is named by the `DisjointSetsItem` returned from
+    `add_new_set` and carries one value, readable and writable through
+    `self[item]`. Items start out in a set of their own; `merge` joins two
+    sets and `find_representative` tells which set an item belongs to.
+
+    The value type is invariant (`_T`, not `_T_co`) because values are
+    accepted as arguments by `__setitem__` and `add_new_set`.
+    """
+
+    def __init__(self):
+        # one entry per item; an item's integer value is its index here
+        self.__values = []  # type: list[_DisjointSetsEntry[_T]]
+
+    def __entry(self, __key):
+        # type: (DisjointSetsItem) -> _DisjointSetsEntry[_T]
+        """ Return the entry for `__key`.
+
+        Raises `KeyError` if `__key` is negative or not less than the
+        number of items, so negative indexes never wrap around.
+        """
+        if __key < 0 or __key >= len(self.__values):
+            raise KeyError(__key)
+        return self.__values[__key]
+
+    def __getitem__(self, __key):
+        # type: (DisjointSetsItem) -> _T
+        """ Return the value stored for item `__key`.
+        Raises `KeyError` for an item that was never added.
+        """
+        return self.__entry(__key).value
+
+    def __setitem__(self, __key, __value):
+        # type: (DisjointSetsItem, _T) -> None
+        """ Replace the value stored for item `__key`.
+        Raises `KeyError` for an item that was never added.
+        """
+        self.__entry(__key).value = __value
+
+    def __len__(self):
+        # type: () -> int
+        """ Return the number of items (not the number of sets). """
+        return len(self.__values)
+
+    @property
+    def representatives(self):
+        # type: () -> Iterator[DisjointSetsItem]
+        """ Iterate over the items that are their own parent, which is
+        one item per set.
+        """
+        for i, entry in enumerate(self.__values):
+            item = DisjointSetsItem(i)
+            if entry.parent == item:
+                yield item
+
+    def __iter__(self):
+        # type: () -> Iterator[DisjointSetsItem]
+        """ Iterate over all items, in the order they were added. """
+        return map(DisjointSetsItem, range(len(self.__values)))
+
+    def add_new_set(self, value):
+        # type: (_T) -> DisjointSetsItem
+        """ Add a new item holding `value`, in a new set containing only
+        that item, and return the new item.
+        """
+        item = DisjointSetsItem(len(self.__values))
+        # a fresh item is its own parent, i.e. its own representative
+        self.__values.append(_DisjointSetsEntry(
+            value=value, parent=item, rank=0))
+        return item
+
+    def find_representative(self, item):
+        # type: (DisjointSetsItem) -> DisjointSetsItem
+        """ Return the representative of the set that contains `item`.
+
+        Two items are in the same set exactly when this returns the same
+        representative for both. Raises `KeyError` if `item` was never
+        added. As a side effect, parent links along the walked path are
+        shortened (path halving), which does not change set membership.
+        """
+        entry = self.__entry(item)
+        while entry.parent != item:
+            parent_entry = self.__values[entry.parent]
+            # re-point this node at its grandparent, then continue from
+            # the grandparent
+            item = entry.parent = parent_entry.parent
+            entry = self.__values[item]
+        return item
+
+    def merge(self, __x, __y):
+        # type: (DisjointSetsItem, DisjointSetsItem) -> DisjointSetsItem
+        """ Merge the sets containing `__x` and `__y` and return the
+        representative of the resulting set.
+
+        If both items are already in the same set nothing is changed.
+        Raises `KeyError` if either item was never added.
+        """
+        __x = self.find_representative(__x)
+        __y = self.find_representative(__y)
+        if __x == __y:
+            return __x
+        x_entry = self.__values[__x]
+        y_entry = self.__values[__y]
+        # union by rank: make `__x` the root with the larger rank so the
+        # lower-rank root gets attached underneath it
+        if x_entry.rank < y_entry.rank:
+            __x, __y = __y, __x
+            x_entry, y_entry = y_entry, x_entry
+        y_entry.parent = __x
+        # joining two roots of equal rank bumps the surviving root's rank
+        if x_entry.rank == y_entry.rank:
+            x_entry.rank += 1
+        return __x
+
+    def __repr__(self):
+        # type: () -> str
+        """ Return a debug string listing the items grouped by
+        representative, plus the value of every item.
+
+        Calls `find_representative`, so it may shorten parent links.
+        """
+        sets = defaultdict(
+            list)  # type: dict[DisjointSetsItem, list[DisjointSetsItem]]
+        values = {}
+        for item in self:
+            sets[self.find_representative(item)].append(item)
+            values[item] = self[item]
+        return f"DisjointSets(sets={sets}, values={values})"