From ac24ffbbaf17fb06d9eaea0b41c390a7f17a29cb Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 15 Aug 2022 22:25:58 -0700 Subject: [PATCH] add fields and replace functions, like dataclasses.fields/replace --- src/nmutil/plain_data.py | 80 +++++++++++++++++++++++++++++- src/nmutil/test/test_plain_data.py | 48 +++++++++++++++--- 2 files changed, 120 insertions(+), 8 deletions(-) diff --git a/src/nmutil/plain_data.py b/src/nmutil/plain_data.py index 92c303d..1c170a4 100644 --- a/src/nmutil/plain_data.py +++ b/src/nmutil/plain_data.py @@ -61,7 +61,7 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): for name in slots.keys(): retval_dict.pop(name, None) - retval_dict["_fields"] = fields + retval_dict["__plain_data_fields"] = fields def add_method_or_error(value, replace=False): name = value.__name__ @@ -220,3 +220,81 @@ def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True, return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order, repr_=repr, frozen=frozen) return decorator + + +def fields(pd): + """ get the tuple of field names of the passed-in + `@plain_data()`-decorated class. + + This is similar to `dataclasses.fields`, except this returns a + different type. + + Returns: tuple[str, ...] + + e.g.: + ``` + @plain_data() + class MyBaseClass: + __slots__ = "a_field", "field2" + def __init__(self, a_field, field2): + self.a_field = a_field + self.field2 = field2 + + assert fields(MyBaseClass) == ("a_field", "field2") + assert fields(MyBaseClass(1, 2)) == ("a_field", "field2") + + @plain_data() + class MyClass(MyBaseClass): + __slots__ = "child_field", + def __init__(self, a_field, field2, child_field): + super().__init__(a_field=a_field, field2=field2) + self.child_field = child_field + + assert fields(MyClass) == ("a_field", "field2", "child_field") + assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field") + ``` + """ + retval = getattr(pd, "__plain_data_fields", None) + if not isinstance(retval, tuple): + raise TypeError("the passed-in object must be a class or an instance" + " of a class decorated with @plain_data()") + return retval + + +__NOT_SPECIFIED = object() + + +def replace(pd, **changes): + """ Return a new instance of the passed-in `@plain_data()`-decorated + object, but with the specified fields replaced with new values. + This is quite useful with frozen `@plain_data()` classes. + + e.g.: + ``` + @plain_data(frozen=True) + class MyClass: + __slots__ = "a", "b", "c" + def __init__(self, a, b, *, c): + self.a = a + self.b = b + self.c = c + + v1 = MyClass(1, 2, c=3) + v2 = replace(v1, b=4) + assert v2 == MyClass(a=1, b=4, c=3) + assert v2 is not v1 + ``` + """ + kwargs = {} + ty = type(pd) + # call fields on ty rather than pd to ensure we're not called with a + # class rather than an instance. + for name in fields(ty): + value = changes.pop(name, __NOT_SPECIFIED) + if value is __NOT_SPECIFIED: + kwargs[name] = getattr(pd, name) + else: + kwargs[name] = value + if len(changes) != 0: + raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}") + return ty(**kwargs) diff --git a/src/nmutil/test/test_plain_data.py b/src/nmutil/test/test_plain_data.py index 93facbb..6952fe1 100644 --- a/src/nmutil/test/test_plain_data.py +++ b/src/nmutil/test/test_plain_data.py @@ -4,7 +4,8 @@ import operator import pickle import unittest -from nmutil.plain_data import FrozenPlainDataError, plain_data +from nmutil.plain_data import (FrozenPlainDataError, plain_data, + fields, replace) @plain_data(order=True) @@ -59,12 +60,45 @@ class PlainDataF2(PlainDataF1): class TestPlainData(unittest.TestCase): def test_fields(self): - self.assertEqual(PlainData0._fields, ()) - self.assertEqual(PlainData1._fields, ("a", "b", "x", "y")) - self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z")) - self.assertEqual(PlainDataF0._fields, ()) - self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y")) - self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z")) + self.assertEqual(fields(PlainData0), ()) + self.assertEqual(fields(PlainData0()), ()) + self.assertEqual(fields(PlainData1), ("a", "b", "x", "y")) + self.assertEqual(fields(PlainData1(1, 2, x="x", y="y")), + ("a", "b", "x", "y")) + self.assertEqual(fields(PlainData2), ("a", "b", "x", "y", "z")) + self.assertEqual(fields(PlainData2(1, 2, x="x", y="y", z=3)), + ("a", "b", "x", "y", "z")) + self.assertEqual(fields(PlainDataF0), ()) + self.assertEqual(fields(PlainDataF0()), ()) + self.assertEqual(fields(PlainDataF1), ("a", "b", "x", "y")) + self.assertEqual(fields(PlainDataF1(1, 2, x="x", y="y")), + ("a", "b", "x", "y")) + self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z")) + self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)), + ("a", "b", "x", "y", "z")) + with self.assertRaisesRegex( + TypeError, + r"the passed-in object must be a class or an instance of a " + r"class decorated with @plain_data\(\)"): + fields(type) + + def test_replace(self): + with self.assertRaisesRegex( + TypeError, + r"the passed-in object must be a class or an instance of a " + r"class decorated with @plain_data\(\)"): + replace(PlainData0) + with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"): + replace(PlainData0(), a=1) + with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"): + replace(PlainDataF1(1, 2, x="x", y="y"), a=3, z=1) + self.assertEqual(replace(PlainData0()), PlainData0()) + self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y")), + PlainDataF1(1, 2, x="x", y="y")) + self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), a=3), + PlainDataF1(3, 2, x="x", y="y")) + self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), x=5, a=3), + PlainDataF1(3, 2, x=5, y="y")) def test_eq(self): self.assertTrue(PlainData0() == PlainData0()) -- 2.30.2