add fields and replace functions, like dataclasses.fields/replace
[nmutil.git] / src / nmutil / plain_data.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
class FrozenPlainDataError(AttributeError):
    """Error raised when a frozen `@plain_data()` instance is modified.

    The `__setattr__` and `__delattr__` methods generated for
    `frozen=True` classes raise this once `__init__` has finished.
    It derives from `AttributeError`, so handlers written for ordinary
    attribute errors also catch it.
    """
6
7
8 def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
9 if not isinstance(cls, type):
10 raise TypeError(
11 "plain_data() can only be used as a class decorator")
12 # slots is an ordered set by using dict keys.
13 # always add __dict__ and __weakref__
14 slots = {"__dict__": None, "__weakref__": None}
15 if frozen:
16 slots["__plain_data_init_done"] = None
17 fields = []
18 any_parents_have_dict = False
19 any_parents_have_weakref = False
20 for cur_cls in reversed(cls.__mro__):
21 if cur_cls is object:
22 continue
23 try:
24 cur_slots = cur_cls.__slots__
25 except AttributeError as e:
26 raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
27 " must have __slots__ so plain_data() can "
28 "determine what fields exist in "
29 f"{cls.__module__}.{cls.__qualname__}") from e
30 if not isinstance(cur_slots, tuple):
31 raise TypeError("plain_data() requires __slots__ to be a "
32 "tuple of str")
33 for field in cur_slots:
34 if not isinstance(field, str):
35 raise TypeError("plain_data() requires __slots__ to be a "
36 "tuple of str")
37 if field not in slots:
38 fields.append(field)
39 slots[field] = None
40 if cur_cls is not cls:
41 if field == "__dict__":
42 any_parents_have_dict = True
43 elif field == "__weakref__":
44 any_parents_have_weakref = True
45
46 fields = tuple(fields) # fields needs to be immutable
47
48 if any_parents_have_dict:
49 # work around a CPython bug that unnecessarily checks if parent
50 # classes already have the __dict__ slot.
51 del slots["__dict__"]
52
53 if any_parents_have_weakref:
54 # work around a CPython bug that unnecessarily checks if parent
55 # classes already have the __weakref__ slot.
56 del slots["__weakref__"]
57
58 # now create a new class having everything we need
59 retval_dict = dict(cls.__dict__)
60 # remove all old descriptors:
61 for name in slots.keys():
62 retval_dict.pop(name, None)
63
64 retval_dict["__plain_data_fields"] = fields
65
66 def add_method_or_error(value, replace=False):
67 name = value.__name__
68 if name in retval_dict and not replace:
69 raise TypeError(
70 f"can't generate {name} method: attribute already exists")
71 value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
72 retval_dict[name] = value
73
74 if frozen:
75 def __setattr__(self, name: str, value):
76 if getattr(self, "__plain_data_init_done", False):
77 raise FrozenPlainDataError(f"cannot assign to field {name!r}")
78 elif name not in slots and not name.startswith("_"):
79 raise AttributeError(
80 f"cannot assign to unknown field {name!r}")
81 object.__setattr__(self, name, value)
82
83 add_method_or_error(__setattr__)
84
85 def __delattr__(self, name):
86 if getattr(self, "__plain_data_init_done", False):
87 raise FrozenPlainDataError(f"cannot delete field {name!r}")
88 object.__delattr__(self, name)
89
90 add_method_or_error(__delattr__)
91
92 old_init = cls.__init__
93
94 def __init__(self, *args, **kwargs):
95 if hasattr(self, "__plain_data_init_done"):
96 # we're already in an __init__ call (probably a
97 # superclass's __init__), don't set
98 # __plain_data_init_done too early
99 return old_init(self, *args, **kwargs)
100 object.__setattr__(self, "__plain_data_init_done", False)
101 try:
102 return old_init(self, *args, **kwargs)
103 finally:
104 object.__setattr__(self, "__plain_data_init_done", True)
105
106 add_method_or_error(__init__, replace=True)
107 else:
108 old_init = None
109
110 # set __slots__ to have everything we need in the preferred order
111 retval_dict["__slots__"] = tuple(slots.keys())
112
113 def __getstate__(self):
114 # pickling support
115 return [getattr(self, name) for name in fields]
116
117 add_method_or_error(__getstate__)
118
119 def __setstate__(self, state):
120 # pickling support
121 for name, value in zip(fields, state):
122 # bypass frozen setattr
123 object.__setattr__(self, name, value)
124
125 add_method_or_error(__setstate__)
126
127 # get a tuple of all fields
128 def fields_tuple(self):
129 return tuple(getattr(self, name) for name in fields)
130
131 if eq:
132 def __eq__(self, other):
133 if other.__class__ is not self.__class__:
134 return NotImplemented
135 return fields_tuple(self) == fields_tuple(other)
136
137 add_method_or_error(__eq__)
138
139 if unsafe_hash:
140 def __hash__(self):
141 return hash(fields_tuple(self))
142
143 add_method_or_error(__hash__)
144
145 if order:
146 def __lt__(self, other):
147 if other.__class__ is not self.__class__:
148 return NotImplemented
149 return fields_tuple(self) < fields_tuple(other)
150
151 add_method_or_error(__lt__)
152
153 def __le__(self, other):
154 if other.__class__ is not self.__class__:
155 return NotImplemented
156 return fields_tuple(self) <= fields_tuple(other)
157
158 add_method_or_error(__le__)
159
160 def __gt__(self, other):
161 if other.__class__ is not self.__class__:
162 return NotImplemented
163 return fields_tuple(self) > fields_tuple(other)
164
165 add_method_or_error(__gt__)
166
167 def __ge__(self, other):
168 if other.__class__ is not self.__class__:
169 return NotImplemented
170 return fields_tuple(self) >= fields_tuple(other)
171
172 add_method_or_error(__ge__)
173
174 if repr_:
175 def __repr__(self):
176 parts = []
177 for name in fields:
178 parts.append(f"{name}={getattr(self, name)!r}")
179 return f"{self.__class__.__qualname__}({', '.join(parts)})"
180
181 add_method_or_error(__repr__)
182
183 # construct class
184 retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)
185
186 # add __qualname__
187 retval.__qualname__ = cls.__qualname__
188
189 def fix_super_and_class(value):
190 # fixup super() and __class__
191 # derived from: https://stackoverflow.com/a/71666065/2597900
192 try:
193 closure = value.__closure__
194 if isinstance(closure, tuple):
195 if closure[0].cell_contents is cls:
196 closure[0].cell_contents = retval
197 except (AttributeError, IndexError):
198 pass
199
200 for value in retval.__dict__.values():
201 fix_super_and_class(value)
202
203 if old_init is not None:
204 fix_super_and_class(old_init)
205
206 return retval
207
208
def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
               frozen=False):
    """ Class decorator factory for classes that are just plain data.

    Depending on the options, the decorated class gains equality
    comparison (`eq`), ordered comparison (`order`), `repr` support
    (`repr`), `hash` support (`unsafe_hash`), and read-only fields
    (`frozen`).

    This is kinda like dataclasses, but uses `__slots__` instead of type
    annotations, as well as requiring you to write your own `__init__`.
    The defaults match dataclass, with the exception of `init`.

    Returns: a decorator to apply to the class.
    """
    # gather the options once; note `repr` is passed on as `repr_`
    options = {
        "eq": eq,
        "unsafe_hash": unsafe_hash,
        "order": order,
        "repr_": repr,
        "frozen": frozen,
    }

    def decorator(cls):
        return _decorator(cls, **options)

    return decorator
223
224
def fields(pd):
    """ Return the field names of a `@plain_data()`-decorated class.

    `pd` may be either the decorated class itself or an instance of it.

    This is similar to `dataclasses.fields`, except this returns a
    different type.

    Returns: tuple[str, ...]

    Raises: TypeError if `pd` has no tuple-valued `__plain_data_fields`
    attribute.

    e.g.:
    ```
    @plain_data()
    class MyBaseClass:
        __slots__ = "a_field", "field2"
        def __init__(self, a_field, field2):
            self.a_field = a_field
            self.field2 = field2

    assert fields(MyBaseClass) == ("a_field", "field2")
    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")

    @plain_data()
    class MyClass(MyBaseClass):
        __slots__ = "child_field",
        def __init__(self, a_field, field2, child_field):
            super().__init__(a_field=a_field, field2=field2)
            self.child_field = child_field

    assert fields(MyClass) == ("a_field", "field2", "child_field")
    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
    ```
    """
    try:
        names = getattr(pd, "__plain_data_fields")
    except AttributeError:
        names = None
    if isinstance(names, tuple):
        return names
    raise TypeError("the passed-in object must be a class or an instance"
                    " of a class decorated with @plain_data()")
262
263
# unique sentinel used by `replace` to tell "no replacement value given"
# apart from any real value (including None).
__NOT_SPECIFIED = object()
265
266
def replace(pd, **changes):
    """ Build a copy of a `@plain_data()`-decorated object with some
    fields swapped for new values.

    Every field is passed to the class as a keyword argument: fields
    named in `changes` use the given value, all others are read from
    `pd`. This is quite useful with frozen `@plain_data()` classes.

    Raises: TypeError if `pd` isn't an instance of a `@plain_data()`
    class, or if `changes` names something that isn't a field.

    e.g.:
    ```
    @plain_data(frozen=True)
    class MyClass:
        __slots__ = "a", "b", "c"
        def __init__(self, a, b, *, c):
            self.a = a
            self.b = b
            self.c = c

    v1 = MyClass(1, 2, c=3)
    v2 = replace(v1, b=4)
    assert v2 == MyClass(a=1, b=4, c=3)
    assert v2 is not v1
    ```
    """
    ty = type(pd)
    # look the fields up on the type rather than on pd itself, so that
    # passing a class instead of an instance is rejected.
    requested = {
        name: changes.pop(name, __NOT_SPECIFIED) for name in fields(ty)
    }
    kwargs = {
        name: getattr(pd, name) if value is __NOT_SPECIFIED else value
        for name, value in requested.items()
    }
    # anything still left over didn't match a field
    if changes:
        raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
    return ty(**kwargs)