add fields and replace functions, like dataclasses.fields/replace
[nmutil.git] / src / nmutil / test / test_plain_data.py
index 93facbb4eaa96b13d64c975e719c62d31640dc58..6952fe1ff1ca9a52eb14770db4fa8ee4b7406db5 100644 (file)
@@ -4,7 +4,8 @@
 import operator
 import pickle
 import unittest
-from nmutil.plain_data import FrozenPlainDataError, plain_data
+from nmutil.plain_data import (FrozenPlainDataError, plain_data,
+                               fields, replace)
 
 
 @plain_data(order=True)
@@ -59,12 +60,45 @@ class PlainDataF2(PlainDataF1):
 
 class TestPlainData(unittest.TestCase):
     def test_fields(self):
-        self.assertEqual(PlainData0._fields, ())
-        self.assertEqual(PlainData1._fields, ("a", "b", "x", "y"))
-        self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z"))
-        self.assertEqual(PlainDataF0._fields, ())
-        self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y"))
-        self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainData0), ())
+        self.assertEqual(fields(PlainData0()), ())
+        self.assertEqual(fields(PlainData1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainData2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainData2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF0), ())
+        self.assertEqual(fields(PlainDataF0()), ())
+        self.assertEqual(fields(PlainDataF1), ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF1(1, 2, x="x", y="y")),
+                         ("a", "b", "x", "y"))
+        self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
+        self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
+                         ("a", "b", "x", "y", "z"))
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            fields(type)
+
+    def test_replace(self):
+        with self.assertRaisesRegex(
+                TypeError,
+                r"the passed-in object must be a class or an instance of a "
+                r"class decorated with @plain_data\(\)"):
+            replace(PlainData0)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"):
+            replace(PlainData0(), a=1)
+        with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"):
+            replace(PlainDataF1(1, 2, x="x", y="y"), a=3, z=1)
+        self.assertEqual(replace(PlainData0()), PlainData0())
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y")),
+                         PlainDataF1(1, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), a=3),
+                         PlainDataF1(3, 2, x="x", y="y"))
+        self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), x=5, a=3),
+                         PlainDataF1(3, 2, x=5, y="y"))
 
     def test_eq(self):
         self.assertTrue(PlainData0() == PlainData0())