add fields and replace functions, like dataclasses.fields/replace
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 05:25:58 +0000 (22:25 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 05:25:58 +0000 (22:25 -0700)
src/nmutil/plain_data.py
src/nmutil/test/test_plain_data.py

index 92c303daf323c73d6f4a9ee3c27be583f2214861..1c170a4d28d0384e93321ad4c4a1ff2ad2770549 100644 (file)
@@ -61,7 +61,7 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     for name in slots.keys():
         retval_dict.pop(name, None)
 
-    retval_dict["_fields"] = fields
+    retval_dict["__plain_data_fields"] = fields
 
     def add_method_or_error(value, replace=False):
         name = value.__name__
@@ -220,3 +220,81 @@ def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
         return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
                           repr_=repr, frozen=frozen)
     return decorator
+
+
+def fields(pd):
+    """ get the tuple of field names of the passed-in
+    `@plain_data()`-decorated class.
+
+    This is similar to `dataclasses.fields`, except this returns a
+    different type.
+
+    Returns: tuple[str, ...]
+
+    e.g.:
+    ```
+    @plain_data()
+    class MyBaseClass:
+        __slots__ = "a_field", "field2"
+        def __init__(self, a_field, field2):
+            self.a_field = a_field
+            self.field2 = field2
+
+    assert fields(MyBaseClass) == ("a_field", "field2")
+    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
+
+    @plain_data()
+    class MyClass(MyBaseClass):
+        __slots__ = "child_field",
+        def __init__(self, a_field, field2, child_field):
+            super().__init__(a_field=a_field, field2=field2)
+            self.child_field = child_field
+
+    assert fields(MyClass) == ("a_field", "field2", "child_field")
+    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
+    ```
+    """
+    retval = getattr(pd, "__plain_data_fields", None)
+    if not isinstance(retval, tuple):
+        raise TypeError("the passed-in object must be a class or an instance"
+                        " of a class decorated with @plain_data()")
+    return retval
+
+
+__NOT_SPECIFIED = object()
+
+
+def replace(pd, **changes):
+    """ Return a new instance of the passed-in `@plain_data()`-decorated
+    object, but with the specified fields replaced with new values.
+    This is quite useful with frozen `@plain_data()` classes.
+
+    e.g.:
+    ```
+    @plain_data(frozen=True)
+    class MyClass:
+        __slots__ = "a", "b", "c"
+        def __init__(self, a, b, *, c):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    v1 = MyClass(1, 2, c=3)
+    v2 = replace(v1, b=4)
+    assert v2 == MyClass(a=1, b=4, c=3)
+    assert v2 is not v1
+    ```
+    """
+    kwargs = {}
+    ty = type(pd)
+    # call fields on ty rather than pd to ensure we're not called with a
+    # class rather than an instance.
+    for name in fields(ty):
+        value = changes.pop(name, __NOT_SPECIFIED)
+        if value is __NOT_SPECIFIED:
+            kwargs[name] = getattr(pd, name)
+        else:
+            kwargs[name] = value
+    if len(changes) != 0:
+        raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
+    return ty(**kwargs)
index 93facbb4eaa96b13d64c975e719c62d31640dc58..6952fe1ff1ca9a52eb14770db4fa8ee4b7406db5 100644 (file)
@@ -4,7 +4,8 @@
 import operator
 import pickle
 import unittest
-from nmutil.plain_data import FrozenPlainDataError, plain_data
+from nmutil.plain_data import (FrozenPlainDataError, plain_data,
+                               fields, replace)
 
 
 @plain_data(order=True)
@@ -59,12 +60,45 @@ class PlainDataF2(PlainDataF1):
 
 class TestPlainData(unittest.TestCase):
     def test_fields(self):
-        self.assertEqual(PlainData0._fields, ())
-        self.assertEqual(PlainData1._fields, ("a", "b", "x", "y"))
-        self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z"))
-        self.assertEqual(PlainDataF0._fields, ())
-        self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y"))
-        self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainData0), ())
+        self.assertEqual(fields(PlainData0()), ())
+        self.assertEqual(fields(PlainData1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainData2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF0), ())
+        self.assertEqual(fields(PlainDataF0()), ())
+        self.assertEqual(fields(PlainDataF1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            fields(type)
+
+    def test_replace(self):
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            replace(PlainData0)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"):
+            replace(PlainData0(), a=1)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"):
+            replace(PlainDataF1(1, 2, x="x", y="y"), a=3, z=1)
+        self.assertEqual(replace(PlainData0()), PlainData0())
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y")),
+                         PlainDataF1(1, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), a=3),
+                         PlainDataF1(3, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), x=5, a=3),
+                         PlainDataF1(3, 2, x=5, y="y"))
 
     def test_eq(self):
         self.assertTrue(PlainData0() == PlainData0())