start adding @plain_data() decorator
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 11 Aug 2022 05:04:44 +0000 (22:04 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 11 Aug 2022 05:04:44 +0000 (22:04 -0700)
src/nmutil/plain_data.py [new file with mode: 0644]
src/nmutil/test/test_plain_data.py [new file with mode: 0644]

diff --git a/src/nmutil/plain_data.py b/src/nmutil/plain_data.py
new file mode 100644 (file)
index 0000000..b1912f1
--- /dev/null
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+class FrozenPlainDataError(AttributeError):
+    pass
+
+
+def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
+    if not isinstance(cls, type):
+        raise TypeError(
+            "plain_data() can only be used as a class decorator")
+    # slots is an ordered set by using dict keys.
+    # always add __dict__ and __weakref__
+    slots = {"__dict__": None, "__weakref__": None}
+    fields = []
+    any_parents_have_dict = False
+    any_parents_have_weakref = False
+    for cur_cls in reversed(cls.__mro__):
+        if cur_cls is object:
+            continue
+        try:
+            cur_slots = cur_cls.__slots__
+        except AttributeError as e:
+            raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
+                            " must have __slots__ so plain_data() can "
+                            "determine what fields exist in "
+                            f"{cls.__module__}.{cls.__qualname__}") from e
+        if not isinstance(cur_slots, tuple):
+            raise TypeError("plain_data() requires __slots__ to be a "
+                            "tuple of str")
+        for field in cur_slots:
+            if not isinstance(field, str):
+                raise TypeError("plain_data() requires __slots__ to be a "
+                                "tuple of str")
+            if field not in slots:
+                fields.append(field)
+            slots[field] = None
+            if cur_cls is not cls:
+                if field == "__dict__":
+                    any_parents_have_dict = True
+                elif field == "__weakref__":
+                    any_parents_have_weakref = True
+
+    if any_parents_have_dict:
+        # work around a CPython bug that unnecessarily checks if parent
+        # classes already have the __dict__ slot.
+        del slots["__dict__"]
+
+    if any_parents_have_weakref:
+        # work around a CPython bug that unnecessarily checks if parent
+        # classes already have the __weakref__ slot.
+        del slots["__weakref__"]
+
+    # now create a new class having everything we need
+    retval_dict = dict(cls.__dict__)
+    # remove all old descriptors:
+    for name in slots.keys():
+        retval_dict.pop(name, None)
+
+    def add_method_or_error(value, replace=False):
+        name = value.__name__
+        if name in retval_dict and not replace:
+            raise TypeError(
+                f"can't generate {name} method: attribute already exists")
+        value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
+        retval_dict[name] = value
+
+    if frozen:
+        slots["__plain_data_init_done"] = None
+
+        def __setattr__(self, name: str, value):
+            if getattr(self, "__plain_data_init_done", False):
+                raise FrozenPlainDataError(f"cannot assign to field {name!r}")
+            elif name not in slots and not name.startswith("_"):
+                raise AttributeError(
+                    f"cannot assign to unknown field {name!r}")
+            object.__setattr__(self, name, value)
+
+        add_method_or_error(__setattr__)
+
+        def __delattr__(self, name):
+            if getattr(self, "__plain_data_init_done", False):
+                raise FrozenPlainDataError(f"cannot delete field {name!r}")
+            object.__delattr__(self, name)
+
+        add_method_or_error(__delattr__)
+
+        old_init = cls.__init__
+
+        def __init__(self, *args, **kwargs):
+            if hasattr(self, "__plain_data_init_done"):
+                # we're already in an __init__ call (probably a
+                # superclass's __init__), don't set
+                # __plain_data_init_done too early
+                return old_init(self, *args, **kwargs)
+            object.__setattr__(self, "__plain_data_init_done", False)
+            try:
+                return old_init(self, *args, **kwargs)
+            finally:
+                object.__setattr__(self, "__plain_data_init_done", True)
+
+        add_method_or_error(__init__, replace=True)
+
+    # set __slots__ to have everything we need in the preferred order
+    retval_dict["__slots__"] = tuple(slots.keys())
+
+    def __dir__(self):
+        # don't return fields un-copied so users can't mess with it
+        return fields.copy()
+
+    add_method_or_error(__dir__)
+
+    def __getstate__(self):
+        # pickling support
+        return [getattr(self, name) for name in fields]
+
+    add_method_or_error(__getstate__)
+
+    def __setstate__(self, state):
+        # pickling support
+        for name, value in zip(fields, state):
+            # bypass frozen setattr
+            object.__setattr__(self, name, value)
+
+    add_method_or_error(__setstate__)
+
+    # get a tuple of all fields
+    def fields_tuple(self):
+        return tuple(getattr(self, name) for name in fields)
+
+    if eq:
+        def __eq__(self, other):
+            if other.__class__ is not self.__class__:
+                return NotImplemented
+            return fields_tuple(self) == fields_tuple(other)
+
+        add_method_or_error(__eq__)
+
+    if unsafe_hash:
+        def __hash__(self):
+            return hash(fields_tuple(self))
+
+        add_method_or_error(__hash__)
+
+    if order:
+        def __lt__(self, other):
+            if other.__class__ is not self.__class__:
+                return NotImplemented
+            return fields_tuple(self) < fields_tuple(other)
+
+        add_method_or_error(__lt__)
+
+        def __le__(self, other):
+            if other.__class__ is not self.__class__:
+                return NotImplemented
+            return fields_tuple(self) <= fields_tuple(other)
+
+        add_method_or_error(__le__)
+
+        def __gt__(self, other):
+            if other.__class__ is not self.__class__:
+                return NotImplemented
+            return fields_tuple(self) > fields_tuple(other)
+
+        add_method_or_error(__gt__)
+
+        def __ge__(self, other):
+            if other.__class__ is not self.__class__:
+                return NotImplemented
+            return fields_tuple(self) >= fields_tuple(other)
+
+        add_method_or_error(__ge__)
+
+    if repr_:
+        def __repr__(self):
+            parts = []
+            for name in fields:
+                parts.append(f"{name}={getattr(self, name)!r}")
+            return f"{self.__class__.__qualname__}({', '.join(parts)})"
+
+        add_method_or_error(__repr__)
+
+    # construct class
+    retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)
+
+    # add __qualname__
+    retval.__qualname__ = cls.__qualname__
+
+    # fixup super() and __class__
+    # derived from: https://stackoverflow.com/a/71666065/2597900
+    for value in retval.__dict__.values():
+        try:
+            closure = value.__closure__
+            if isinstance(closure, tuple):
+                if closure[0].cell_contents is cls:
+                    closure[0].cell_contents = retval
+        except (AttributeError, IndexError):
+            pass
+
+    return retval
+
+
+def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True,
+               frozen=False):
+    """ Decorator for adding equality comparison, ordered comparison,
+    `repr` support, `hash` support, and frozen type (read-only fields)
+    support to classes that are just plain data.
+
+    This is kinda like dataclasses, but uses `__slots__` instead of type
+    annotations, as well as requiring you to write your own `__init__`
+    """
+    def decorator(cls):
+        return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
+                          repr_=repr, frozen=frozen)
+    return decorator
diff --git a/src/nmutil/test/test_plain_data.py b/src/nmutil/test/test_plain_data.py
new file mode 100644 (file)
index 0000000..b07d64b
--- /dev/null
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+import unittest
+from nmutil.plain_data import FrozenPlainDataError, plain_data
+
+
+@plain_data()
+class PlainData0:
+    __slots__ = ()
+
+
+@plain_data()
+class PlainData1:
+    __slots__ = "a", "b", "x", "y"
+
+    def __init__(self, a, b, *, x, y):
+        self.a = a
+        self.b = b
+        self.x = x
+        self.y = y
+
+
+@plain_data()
+class PlainData2(PlainData1):
+    __slots__ = "a", "z"
+
+    def __init__(self, a, b, *, x, y, z):
+        super().__init__(a, b, x=x, y=y)
+        self.z = z
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class PlainDataF0:
+    __slots__ = ()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class PlainDataF1:
+    __slots__ = "a", "b", "x", "y"
+
+    def __init__(self, a, b, *, x, y):
+        self.a = a
+        self.b = b
+        self.x = x
+        self.y = y
+
+
+class TestPlainData(unittest.TestCase):
+    def test_repr(self):
+        self.assertEqual(repr(PlainData0()), "PlainData0()")
+        self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
+                         "PlainData1(a=1, b=2, x='x', y='y')")
+        self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
+                         "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+
+    def test_eq(self):
+        self.assertTrue(PlainData0() == PlainData0())
+        self.assertFalse('a' == PlainData0())
+        self.assertTrue(PlainData1(1, 2, x="x", y="y")
+                        == PlainData1(1, 2, x="x", y="y"))
+        self.assertFalse(PlainData1(1, 2, x="x", y="y")
+                         == PlainData1(1, 2, x="x", y="z"))
+        self.assertFalse(PlainData1(1, 2, x="x", y="y")
+                         == PlainData2(1, 2, x="x", y="y", z=3))
+
+    def test_frozen(self):
+        not_frozen = PlainData0()
+        not_frozen.a = 1
+        frozen0 = PlainDataF0()
+        with self.assertRaises(AttributeError):
+            frozen0.a = 1
+        frozen1 = PlainDataF1(1, 2, x="x", y="y")
+        with self.assertRaises(FrozenPlainDataError):
+            frozen1.a = 1
+
+    # FIXME: add more tests
+
+
+if __name__ == "__main__":
+    unittest.main()