finish implementing @plain_data()
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 12 Aug 2022 06:08:44 +0000 (23:08 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 12 Aug 2022 06:22:23 +0000 (23:22 -0700)
src/nmutil/plain_data.py
src/nmutil/test/test_plain_data.py

index b1912f1532439a58316785bf8e9607789b775a20..92c303daf323c73d6f4a9ee3c27be583f2214861 100644 (file)
@@ -12,6 +12,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     # slots is an ordered set by using dict keys.
     # always add __dict__ and __weakref__
     slots = {"__dict__": None, "__weakref__": None}
+    if frozen:
+        slots["__plain_data_init_done"] = None
     fields = []
     any_parents_have_dict = False
     any_parents_have_weakref = False
@@ -41,6 +43,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
                 elif field == "__weakref__":
                     any_parents_have_weakref = True
 
+    fields = tuple(fields)  # fields needs to be immutable
+
     if any_parents_have_dict:
         # work around a CPython bug that unnecessarily checks if parent
         # classes already have the __dict__ slot.
@@ -57,6 +61,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     for name in slots.keys():
         retval_dict.pop(name, None)
 
+    retval_dict["_fields"] = fields
+
     def add_method_or_error(value, replace=False):
         name = value.__name__
         if name in retval_dict and not replace:
@@ -66,8 +72,6 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
         retval_dict[name] = value
 
     if frozen:
-        slots["__plain_data_init_done"] = None
-
         def __setattr__(self, name: str, value):
             if getattr(self, "__plain_data_init_done", False):
                 raise FrozenPlainDataError(f"cannot assign to field {name!r}")
@@ -100,16 +104,12 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
                 object.__setattr__(self, "__plain_data_init_done", True)
 
         add_method_or_error(__init__, replace=True)
+    else:
+        old_init = None
 
     # set __slots__ to have everything we need in the preferred order
     retval_dict["__slots__"] = tuple(slots.keys())
 
-    def __dir__(self):
-        # don't return fields un-copied so users can't mess with it
-        return fields.copy()
-
-    add_method_or_error(__dir__)
-
     def __getstate__(self):
         # pickling support
         return [getattr(self, name) for name in fields]
@@ -186,9 +186,9 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     # add __qualname__
     retval.__qualname__ = cls.__qualname__
 
-    # fixup super() and __class__
-    # derived from: https://stackoverflow.com/a/71666065/2597900
-    for value in retval.__dict__.values():
+    def fix_super_and_class(value):
+        # fixup super() and __class__
+        # derived from: https://stackoverflow.com/a/71666065/2597900
         try:
             closure = value.__closure__
             if isinstance(closure, tuple):
@@ -197,11 +197,18 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
         except (AttributeError, IndexError):
             pass
 
+    for value in retval.__dict__.values():
+        fix_super_and_class(value)
+
+    if old_init is not None:
+        fix_super_and_class(old_init)
+
     return retval
 
 
-def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True,
+def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
                frozen=False):
+    # defaults match dataclass, with the exception of `init`
     """ Decorator for adding equality comparison, ordered comparison,
     `repr` support, `hash` support, and frozen type (read-only fields)
     support to classes that are just plain data.
index b07d64bd7cbe8e63702e3ae474763d6873859388..93facbb4eaa96b13d64c975e719c62d31640dc58 100644 (file)
@@ -1,16 +1,18 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
 
+import operator
+import pickle
 import unittest
 from nmutil.plain_data import FrozenPlainDataError, plain_data
 
 
-@plain_data()
+@plain_data(order=True)
 class PlainData0:
     __slots__ = ()
 
 
-@plain_data()
+@plain_data(order=True)
 class PlainData1:
     __slots__ = "a", "b", "x", "y"
 
@@ -21,7 +23,7 @@ class PlainData1:
         self.y = y
 
 
-@plain_data()
+@plain_data(order=True)
 class PlainData2(PlainData1):
     __slots__ = "a", "z"
 
@@ -30,12 +32,12 @@ class PlainData2(PlainData1):
         self.z = z
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(order=True, frozen=True, unsafe_hash=True)
 class PlainDataF0:
     __slots__ = ()
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(order=True, frozen=True, unsafe_hash=True)
 class PlainDataF1:
     __slots__ = "a", "b", "x", "y"
 
@@ -46,17 +48,28 @@ class PlainDataF1:
         self.y = y
 
 
+@plain_data(order=True, frozen=True, unsafe_hash=True)
+class PlainDataF2(PlainDataF1):
+    __slots__ = "a", "z"
+
+    def __init__(self, a, b, *, x, y, z):
+        super().__init__(a, b, x=x, y=y)
+        self.z = z
+
+
 class TestPlainData(unittest.TestCase):
-    def test_repr(self):
-        self.assertEqual(repr(PlainData0()), "PlainData0()")
-        self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
-                         "PlainData1(a=1, b=2, x='x', y='y')")
-        self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
-                         "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+    def test_fields(self):
+        self.assertEqual(PlainData0._fields, ())
+        self.assertEqual(PlainData1._fields, ("a", "b", "x", "y"))
+        self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z"))
+        self.assertEqual(PlainDataF0._fields, ())
+        self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y"))
+        self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z"))
 
     def test_eq(self):
         self.assertTrue(PlainData0() == PlainData0())
         self.assertFalse('a' == PlainData0())
+        self.assertFalse(PlainDataF0() == PlainData0())
         self.assertTrue(PlainData1(1, 2, x="x", y="y")
                         == PlainData1(1, 2, x="x", y="y"))
         self.assertFalse(PlainData1(1, 2, x="x", y="y")
@@ -64,6 +77,78 @@ class TestPlainData(unittest.TestCase):
         self.assertFalse(PlainData1(1, 2, x="x", y="y")
                          == PlainData2(1, 2, x="x", y="y", z=3))
 
+    def test_hash(self):
+        def check_op(v, tuple_v):
+            with self.subTest(v=v, tuple_v=tuple_v):
+                self.assertEqual(hash(v), hash(tuple_v))
+
+        def check(a, b, x, y, z):
+            tuple_v = a, b, x, y, z
+            v = PlainDataF2(a=a, b=b, x=x, y=y, z=z)
+            check_op(v, tuple_v)
+
+        check(1, 2, "x", "y", "z")
+
+        check(1, 2, "x", "y", "a")
+        check(1, 2, "x", "y", "zz")
+
+        check(1, 2, "x", "a", "z")
+        check(1, 2, "x", "zz", "z")
+
+        check(1, 2, "a", "y", "z")
+        check(1, 2, "zz", "y", "z")
+
+        check(1, -10, "x", "y", "z")
+        check(1, 10, "x", "y", "z")
+
+        check(-10, 2, "x", "y", "z")
+        check(10, 2, "x", "y", "z")
+
+    def test_order(self):
+        def check_op(l, r, tuple_l, tuple_r, op):
+            with self.subTest(l=l, r=r,
+                              tuple_l=tuple_l, tuple_r=tuple_r, op=op):
+                self.assertEqual(op(l, r), op(tuple_l, tuple_r))
+                self.assertEqual(op(r, l), op(tuple_r, tuple_l))
+
+        def check(a, b, x, y, z):
+            tuple_l = 1, 2, "x", "y", "z"
+            l = PlainData2(a=1, b=2, x="x", y="y", z="z")
+            tuple_r = a, b, x, y, z
+            r = PlainData2(a=a, b=b, x=x, y=y, z=z)
+            check_op(l, r, tuple_l, tuple_r, operator.eq)
+            check_op(l, r, tuple_l, tuple_r, operator.ne)
+            check_op(l, r, tuple_l, tuple_r, operator.lt)
+            check_op(l, r, tuple_l, tuple_r, operator.le)
+            check_op(l, r, tuple_l, tuple_r, operator.gt)
+            check_op(l, r, tuple_l, tuple_r, operator.ge)
+
+        check(1, 2, "x", "y", "z")
+
+        check(1, 2, "x", "y", "a")
+        check(1, 2, "x", "y", "zz")
+
+        check(1, 2, "x", "a", "z")
+        check(1, 2, "x", "zz", "z")
+
+        check(1, 2, "a", "y", "z")
+        check(1, 2, "zz", "y", "z")
+
+        check(1, -10, "x", "y", "z")
+        check(1, 10, "x", "y", "z")
+
+        check(-10, 2, "x", "y", "z")
+        check(10, 2, "x", "y", "z")
+
+    def test_repr(self):
+        self.assertEqual(repr(PlainData0()), "PlainData0()")
+        self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
+                         "PlainData1(a=1, b=2, x='x', y='y')")
+        self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
+                         "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+        self.assertEqual(repr(PlainDataF2(1, 2, x="x", y="y", z=3)),
+                         "PlainDataF2(a=1, b=2, x='x', y='y', z=3)")
+
     def test_frozen(self):
         not_frozen = PlainData0()
         not_frozen.a = 1
@@ -74,7 +159,17 @@ class TestPlainData(unittest.TestCase):
         with self.assertRaises(FrozenPlainDataError):
             frozen1.a = 1
 
-    # FIXME: add more tests
+    def test_pickle(self):
+        def check(v):
+            with self.subTest(v=v):
+                self.assertEqual(v, pickle.loads(pickle.dumps(v)))
+
+        check(PlainData0())
+        check(PlainData1(a=1, b=2, x="x", y="y"))
+        check(PlainData2(a=1, b=2, x="x", y="y", z="z"))
+        check(PlainDataF0())
+        check(PlainDataF1(a=1, b=2, x="x", y="y"))
+        check(PlainDataF2(a=1, b=2, x="x", y="y", z="z"))
 
 
 if __name__ == "__main__":