add fields and replace functions, like dataclasses.fields/replace
[nmutil.git] / src / nmutil / test / test_plain_data.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 import operator
5 import pickle
6 import unittest
7 from nmutil.plain_data import (FrozenPlainDataError, plain_data,
8 fields, replace)
9
10
@plain_data(order=True)
class PlainData0:
    """Mutable (non-frozen) plain-data class that declares no fields."""

    __slots__ = ()
14
15
@plain_data(order=True)
class PlainData1:
    """Mutable plain-data class with fields ``a``, ``b``, ``x`` and ``y``.

    ``a`` and ``b`` may be passed positionally; ``x`` and ``y`` are
    keyword-only.  PlainData2 derives from this class.
    """

    __slots__ = ("a", "b", "x", "y")

    def __init__(self, a, b, *, x, y):
        # store the fields in the same order they are declared in __slots__
        self.a, self.b = a, b
        self.x, self.y = x, y
25
26
@plain_data(order=True)
class PlainData2(PlainData1):
    """Subclass of PlainData1 that adds the keyword-only field ``z``.

    ``a`` is listed in ``__slots__`` again even though PlainData1 already
    declares it; test_fields expects fields(PlainData2) to report ``a`` only
    once, so the duplicate appears to exercise that de-duplication.
    """

    __slots__ = ("a", "z")

    def __init__(self, a, b, *, x, y, z):
        # let the base class store a, b, x and y, then add z
        super().__init__(a, b, x=x, y=y)
        self.z = z
34
35
@plain_data(order=True, frozen=True, unsafe_hash=True)
class PlainDataF0:
    """Frozen, hash-enabled plain-data class that declares no fields."""

    __slots__ = ()
39
40
@plain_data(order=True, frozen=True, unsafe_hash=True)
class PlainDataF1:
    """Frozen, hash-enabled counterpart of PlainData1.

    Fields are ``a``, ``b`` (positional or keyword) and ``x``, ``y``
    (keyword-only).  The assignments below happen inside __init__;
    test_frozen expects assignments made after construction to raise
    FrozenPlainDataError.
    """

    __slots__ = ("a", "b", "x", "y")

    def __init__(self, a, b, *, x, y):
        # store the fields in the same order they are declared in __slots__
        self.a, self.b = a, b
        self.x, self.y = x, y
50
51
@plain_data(order=True, frozen=True, unsafe_hash=True)
class PlainDataF2(PlainDataF1):
    """Frozen subclass of PlainDataF1 adding the keyword-only field ``z``.

    As with PlainData2, ``a`` is redeclared in ``__slots__``; test_fields
    expects it to be reported only once by fields().
    """

    __slots__ = ("a", "z")

    def __init__(self, a, b, *, x, y, z):
        # let the base class store a, b, x and y, then add z
        super().__init__(a, b, x=x, y=y)
        self.z = z
59
60
class TestPlainData(unittest.TestCase):
    """Tests for the @plain_data decorator and the fields()/replace()
    helpers imported from nmutil.plain_data."""

    # Error text fields() and replace() are expected to raise when given
    # something that is not accepted as a @plain_data class or instance.
    _NOT_PLAIN_DATA_RE = (r"the passed-in object must be a class or an "
                          r"instance of a class decorated with "
                          r"@plain_data\(\)")

    # (a, b, x, y, z) values shared by test_hash and test_order.  After the
    # first entry, each pair of entries changes one position of
    # (1, 2, "x", "y", "z"): first to a value sorting before the original,
    # then to one sorting after it.
    _FIELD_VALUES = (
        (1, 2, "x", "y", "z"),
        (1, 2, "x", "y", "a"),
        (1, 2, "x", "y", "zz"),
        (1, 2, "x", "a", "z"),
        (1, 2, "x", "zz", "z"),
        (1, 2, "a", "y", "z"),
        (1, 2, "zz", "y", "z"),
        (1, -10, "x", "y", "z"),
        (1, 10, "x", "y", "z"),
        (-10, 2, "x", "y", "z"),
        (10, 2, "x", "y", "z"),
    )

    def test_fields(self):
        abxy = ("a", "b", "x", "y")
        abxyz = abxy + ("z",)

        def expect_fields(target, expected):
            self.assertEqual(fields(target), expected)

        # fields() accepts either a class or an instance.  PlainData2 and
        # PlainDataF2 redeclare "a" in __slots__, yet it is listed only once.
        expect_fields(PlainData0, ())
        expect_fields(PlainData0(), ())
        expect_fields(PlainData1, abxy)
        expect_fields(PlainData1(1, 2, x="x", y="y"), abxy)
        expect_fields(PlainData2, abxyz)
        expect_fields(PlainData2(1, 2, x="x", y="y", z=3), abxyz)
        expect_fields(PlainDataF0, ())
        expect_fields(PlainDataF0(), ())
        expect_fields(PlainDataF1, abxy)
        expect_fields(PlainDataF1(1, 2, x="x", y="y"), abxy)
        expect_fields(PlainDataF2, abxyz)
        expect_fields(PlainDataF2(1, 2, x="x", y="y", z=3), abxyz)
        with self.assertRaisesRegex(TypeError, self._NOT_PLAIN_DATA_RE):
            fields(type)

    def test_replace(self):
        def f1(a=1, x="x"):
            # a fresh PlainDataF1 whose b and y are always 2 and "y"
            return PlainDataF1(a, 2, x=x, y="y")

        # passing the class itself (rather than an instance) is rejected
        with self.assertRaisesRegex(TypeError, self._NOT_PLAIN_DATA_RE):
            replace(PlainData0)
        with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"):
            replace(PlainData0(), a=1)
        # "a" is a real field of PlainDataF1, so only "z" is reported
        with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"):
            replace(f1(), a=3, z=1)
        self.assertEqual(replace(PlainData0()), PlainData0())
        self.assertEqual(replace(f1()), f1())
        self.assertEqual(replace(f1(), a=3), f1(a=3))
        self.assertEqual(replace(f1(), x=5, a=3), f1(a=3, x=5))

    def test_eq(self):
        # Note: assertNotEqual would evaluate != rather than ==, so the
        # negative cases are written as assertFalse(lhs == rhs).
        def p1(y="y"):
            return PlainData1(1, 2, x="x", y=y)

        self.assertTrue(PlainData0() == PlainData0())
        self.assertFalse('a' == PlainData0())
        # both have no fields, but they are different classes
        self.assertFalse(PlainDataF0() == PlainData0())
        self.assertTrue(p1() == p1())
        self.assertFalse(p1() == p1(y="z"))
        # a subclass instance sharing a, b, x and y is still not equal
        self.assertFalse(p1() == PlainData2(1, 2, x="x", y="y", z=3))

    def test_hash(self):
        # a frozen, unsafe_hash instance must hash like the tuple of its
        # field values
        for values in self._FIELD_VALUES:
            a, b, x, y, z = values
            obj = PlainDataF2(a=a, b=b, x=x, y=y, z=z)
            with self.subTest(v=obj, tuple_v=values):
                self.assertEqual(hash(obj), hash(values))

    def test_order(self):
        # every comparison operator must agree with comparing the tuples of
        # field values, in both operand orders
        comparisons = (operator.eq, operator.ne, operator.lt,
                       operator.le, operator.gt, operator.ge)
        for tuple_r in self._FIELD_VALUES:
            tuple_l = (1, 2, "x", "y", "z")
            lhs = PlainData2(a=1, b=2, x="x", y="y", z="z")
            a, b, x, y, z = tuple_r
            rhs = PlainData2(a=a, b=b, x=x, y=y, z=z)
            for op in comparisons:
                with self.subTest(l=lhs, r=rhs, tuple_l=tuple_l,
                                  tuple_r=tuple_r, op=op):
                    self.assertEqual(op(lhs, rhs), op(tuple_l, tuple_r))
                    self.assertEqual(op(rhs, lhs), op(tuple_r, tuple_l))

    def test_repr(self):
        def expect_repr(obj, text):
            self.assertEqual(repr(obj), text)

        expect_repr(PlainData0(), "PlainData0()")
        expect_repr(PlainData1(1, 2, x="x", y="y"),
                    "PlainData1(a=1, b=2, x='x', y='y')")
        expect_repr(PlainData2(1, 2, x="x", y="y", z=3),
                    "PlainData2(a=1, b=2, x='x', y='y', z=3)")
        expect_repr(PlainDataF2(1, 2, x="x", y="y", z=3),
                    "PlainDataF2(a=1, b=2, x='x', y='y', z=3)")

    def test_frozen(self):
        # a non-frozen instance accepts assignment of a non-field attribute
        PlainData0().a = 1
        frozen0 = PlainDataF0()
        with self.assertRaises(AttributeError):
            frozen0.a = 1
        # assigning a field of a frozen instance raises FrozenPlainDataError
        frozen1 = PlainDataF1(1, 2, x="x", y="y")
        with self.assertRaises(FrozenPlainDataError):
            frozen1.a = 1

    def test_pickle(self):
        def assert_round_trips(value):
            with self.subTest(v=value):
                restored = pickle.loads(pickle.dumps(value))
                self.assertEqual(value, restored)

        assert_round_trips(PlainData0())
        assert_round_trips(PlainData1(a=1, b=2, x="x", y="y"))
        assert_round_trips(PlainData2(a=1, b=2, x="x", y="y", z="z"))
        assert_round_trips(PlainDataF0())
        assert_round_trips(PlainDataF1(a=1, b=2, x="x", y="y"))
        assert_round_trips(PlainDataF2(a=1, b=2, x="x", y="y", z="z"))
207
208
if __name__ == "__main__":
    # executing this file directly runs every TestCase defined in it
    unittest.main()