speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / plain_data.py
index 92c303daf323c73d6f4a9ee3c27be583f2214861..7bde6ba2f27ff3786beab9c036019337e36bbfe7 100644 (file)
@@ -1,10 +1,68 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+import keyword
+
 
 class FrozenPlainDataError(AttributeError):
     pass
 
 
+class __NotSet:
+    """ helper for __repr__ for when fields aren't set """
+
+    def __repr__(self):
+        return "<not set>"
+
+
+__NOT_SET = __NotSet()
+
+
+def __ignored_classes():
+    classes = [object]  # type: list[type]
+
+    from abc import ABC
+
+    classes += [ABC]
+
+    from typing import (
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound)
+
+    classes += [
+        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+        SupportsInt, SupportsRound]
+
+    from collections.abc import (
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence)
+
+    classes += [
+        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+        MappingView, KeysView, ItemsView, ValuesView, Sequence,
+        MutableSequence]
+
+    # rest aren't supported by python 3.7, so try to import them and skip if
+    # that errors
+
+    try:
+        # typing_extensions uses typing.Protocol if available
+        from typing_extensions import Protocol
+        classes.append(Protocol)
+    except ImportError:
+        pass
+
+    for cls in classes:
+        yield from cls.__mro__
+
+
+__IGNORED_CLASSES = frozenset(__ignored_classes())
+
+
 def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     if not isinstance(cls, type):
         raise TypeError(
@@ -18,7 +76,13 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     any_parents_have_dict = False
     any_parents_have_weakref = False
     for cur_cls in reversed(cls.__mro__):
-        if cur_cls is object:
+        d = getattr(cur_cls, "__dict__", {})
+        if cur_cls is not cls:
+            if "__dict__" in d:
+                any_parents_have_dict = True
+            if "__weakref__" in d:
+                any_parents_have_weakref = True
+        if cur_cls in __IGNORED_CLASSES:
             continue
         try:
             cur_slots = cur_cls.__slots__
@@ -34,14 +98,13 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
             if not isinstance(field, str):
                 raise TypeError("plain_data() requires __slots__ to be a "
                                 "tuple of str")
+            if not field.isidentifier() or keyword.iskeyword(field):
+                raise TypeError(
+                    "plain_data() requires __slots__ entries to be valid "
+                    "Python identifiers and not keywords")
             if field not in slots:
                 fields.append(field)
             slots[field] = None
-            if cur_cls is not cls:
-                if field == "__dict__":
-                    any_parents_have_dict = True
-                elif field == "__weakref__":
-                    any_parents_have_weakref = True
 
     fields = tuple(fields)  # fields needs to be immutable
 
@@ -61,7 +124,7 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     for name in slots.keys():
         retval_dict.pop(name, None)
 
-    retval_dict["_fields"] = fields
+    retval_dict["__plain_data_fields"] = fields
 
     def add_method_or_error(value, replace=False):
         name = value.__name__
@@ -124,58 +187,68 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
 
     add_method_or_error(__setstate__)
 
-    # get a tuple of all fields
-    def fields_tuple(self):
-        return tuple(getattr(self, name) for name in fields)
+    # get source code that gets a tuple of all fields
+    def fields_tuple(var):
+        # type: (str) -> str
+        l = []
+        for name in fields:
+            l.append(f"{var}.{name}, ")
+        return "(" + "".join(l) + ")"
 
     if eq:
-        def __eq__(self, other):
-            if other.__class__ is not self.__class__:
-                return NotImplemented
-            return fields_tuple(self) == fields_tuple(other)
+        exec(f"""
+def __eq__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} == {fields_tuple('other')}
 
-        add_method_or_error(__eq__)
+add_method_or_error(__eq__)
+""")
 
     if unsafe_hash:
-        def __hash__(self):
-            return hash(fields_tuple(self))
+        exec(f"""
+def __hash__(self):
+    return hash({fields_tuple('self')})
 
-        add_method_or_error(__hash__)
+add_method_or_error(__hash__)
+""")
 
     if order:
-        def __lt__(self, other):
-            if other.__class__ is not self.__class__:
-                return NotImplemented
-            return fields_tuple(self) < fields_tuple(other)
+        exec(f"""
+def __lt__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} < {fields_tuple('other')}
 
-        add_method_or_error(__lt__)
+add_method_or_error(__lt__)
 
-        def __le__(self, other):
-            if other.__class__ is not self.__class__:
-                return NotImplemented
-            return fields_tuple(self) <= fields_tuple(other)
+def __le__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} <= {fields_tuple('other')}
 
-        add_method_or_error(__le__)
+add_method_or_error(__le__)
 
-        def __gt__(self, other):
-            if other.__class__ is not self.__class__:
-                return NotImplemented
-            return fields_tuple(self) > fields_tuple(other)
+def __gt__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} > {fields_tuple('other')}
 
-        add_method_or_error(__gt__)
+add_method_or_error(__gt__)
 
-        def __ge__(self, other):
-            if other.__class__ is not self.__class__:
-                return NotImplemented
-            return fields_tuple(self) >= fields_tuple(other)
+def __ge__(self, other):
+    if other.__class__ is not self.__class__:
+        return NotImplemented
+    return {fields_tuple('self')} >= {fields_tuple('other')}
 
-        add_method_or_error(__ge__)
+add_method_or_error(__ge__)
+""")
 
     if repr_:
         def __repr__(self):
             parts = []
             for name in fields:
-                parts.append(f"{name}={getattr(self, name)!r}")
+                parts.append(f"{name}={getattr(self, name, __NOT_SET)!r}")
             return f"{self.__class__.__qualname__}({', '.join(parts)})"
 
         add_method_or_error(__repr__)
@@ -220,3 +293,81 @@ def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
         return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
                           repr_=repr, frozen=frozen)
     return decorator
+
+
+def fields(pd):
+    """ get the tuple of field names of the passed-in
+    `@plain_data()`-decorated class.
+
+    This is similar to `dataclasses.fields`, except this returns a
+    different type.
+
+    Returns: tuple[str, ...]
+
+    e.g.:
+    ```
+    @plain_data()
+    class MyBaseClass:
+        __slots__ = "a_field", "field2"
+        def __init__(self, a_field, field2):
+            self.a_field = a_field
+            self.field2 = field2
+
+    assert fields(MyBaseClass) == ("a_field", "field2")
+    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
+
+    @plain_data()
+    class MyClass(MyBaseClass):
+        __slots__ = "child_field",
+        def __init__(self, a_field, field2, child_field):
+            super().__init__(a_field=a_field, field2=field2)
+            self.child_field = child_field
+
+    assert fields(MyClass) == ("a_field", "field2", "child_field")
+    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
+    ```
+    """
+    retval = getattr(pd, "__plain_data_fields", None)
+    if not isinstance(retval, tuple):
+        raise TypeError("the passed-in object must be a class or an instance"
+                        " of a class decorated with @plain_data()")
+    return retval
+
+
+__NOT_SPECIFIED = object()
+
+
+def replace(pd, **changes):
+    """ Return a new instance of the passed-in `@plain_data()`-decorated
+    object, but with the specified fields replaced with new values.
+    This is quite useful with frozen `@plain_data()` classes.
+
+    e.g.:
+    ```
+    @plain_data(frozen=True)
+    class MyClass:
+        __slots__ = "a", "b", "c"
+        def __init__(self, a, b, *, c):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    v1 = MyClass(1, 2, c=3)
+    v2 = replace(v1, b=4)
+    assert v2 == MyClass(a=1, b=4, c=3)
+    assert v2 is not v1
+    ```
+    """
+    kwargs = {}
+    ty = type(pd)
+    # call fields on ty rather than pd to ensure we're not called with a
+    # class rather than an instance.
+    for name in fields(ty):
+        value = changes.pop(name, __NOT_SPECIFIED)
+        if value is __NOT_SPECIFIED:
+            kwargs[name] = getattr(pd, name)
+        else:
+            kwargs[name] = value
+    if len(changes) != 0:
+        raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
+    return ty(**kwargs)