change plain_data to ignore more base classes, so it'll work with ABCs and stuff
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 13 Oct 2022 01:33:17 +0000 (18:33 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 13 Oct 2022 01:33:17 +0000 (18:33 -0700)
src/nmutil/plain_data.py
src/nmutil/test/test_plain_data.py

index ace4d61a890f1b80cfbd2bc11288715ca650fd35..0ebcb5d2b04500de537ccb1b36cb107bd188fd98 100644 (file)
@@ -15,6 +15,52 @@ class __NotSet:
 __NOT_SET = __NotSet()
 
 
+def __ignored_classes():
+    classes = [object]  # type: list[type]
+
+    from abc import ABC
+
+    classes += [ABC]
+
+    from typing import (
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound)
+
+    classes += [
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound]
+
+    from collections.abc import (
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence)
+
+    classes += [
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence]
+
+    # rest aren't supported by python 3.7, so try to import them and skip if
+    # that errors
+
+    try:
+        # typing_extensions uses typing.Protocol if available
+        from typing_extensions import Protocol
+        classes.append(Protocol)
+    except ImportError:
+        pass
+
+    for cls in classes:
+        yield from cls.__mro__
+
+
+__IGNORED_CLASSES = frozenset(__ignored_classes())
+
+
 def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     if not isinstance(cls, type):
         raise TypeError(
@@ -28,7 +74,13 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     any_parents_have_dict = False
     any_parents_have_weakref = False
     for cur_cls in reversed(cls.__mro__):
-        if cur_cls is object:
+        d = getattr(cur_cls, "__dict__", {})
+        if cur_cls is not cls:
+            if "__dict__" in d:
+                any_parents_have_dict = True
+            if "__weakref__" in d:
+                any_parents_have_weakref = True
+        if cur_cls in __IGNORED_CLASSES:
             continue
         try:
             cur_slots = cur_cls.__slots__
@@ -47,11 +99,6 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
             if field not in slots:
                 fields.append(field)
             slots[field] = None
-            if cur_cls is not cls:
-                if field == "__dict__":
-                    any_parents_have_dict = True
-                elif field == "__weakref__":
-                    any_parents_have_weakref = True
 
     fields = tuple(fields)  # fields needs to be immutable
 
index a087e6e5b04f7f05e6f74c911a9dfb9b79b7eea0..f16faba342a4c38d18d74c547591b56f9f8ffc7a 100644 (file)
@@ -4,9 +4,18 @@
 import operator
 import pickle
 import unittest
+import typing
 from nmutil.plain_data import (FrozenPlainDataError, plain_data,
                                fields, replace)
 
+try:
+    from typing import Protocol
+except ImportError:
+    try:
+        from typing_extensions import Protocol
+    except ImportError:
+        Protocol = None
+
 
 @plain_data(order=True)
 class PlainData0:
@@ -67,6 +76,51 @@ class UnsetField:
             setattr(self, name, value)
 
 
+T = typing.TypeVar("T")
+
+
+@plain_data()
+class GenericClass(typing.Generic[T]):
+    __slots__ = "a",
+
+    def __init__(self, a):
+        self.a = a
+
+
+@plain_data()
+class MySet(typing.AbstractSet[int]):
+    __slots__ = ()
+
+    def __contains__(self, x):
+        raise NotImplementedError
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+@plain_data()
+class MyIntLike(typing.SupportsInt):
+    __slots__ = ()
+
+    def __int__(self):
+        return 1
+
+
+if Protocol is not None:
+    class MyProtocol(Protocol):
+        def my_method(self): ...
+
+    @plain_data()
+    class MyProtocolImpl(MyProtocol):
+        __slots__ = ()
+
+        def my_method(self):
+            pass
+
+
 class TestPlainData(unittest.TestCase):
     def test_fields(self):
         self.assertEqual(fields(PlainData0), ())
@@ -85,6 +139,11 @@ class TestPlainData(unittest.TestCase):
         self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
         self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
                          ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(GenericClass(1)), ("a",))
+        self.assertEqual(fields(MySet()), ())
+        self.assertEqual(fields(MyIntLike()), ())
+        if Protocol is not None:
+            self.assertEqual(fields(MyProtocolImpl()), ())
         with self.assertRaisesRegex(
                 TypeError,
                 r"the passed-in object must be a class or an instance of a "