0ebcb5d2b04500de537ccb1b36cb107bd188fd98
[nmutil.git] / src / nmutil / plain_data.py
# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
class FrozenPlainDataError(AttributeError):
    """Raised when code tries to assign to or delete an attribute of a
    ``@plain_data(frozen=True)`` instance after its ``__init__`` finished.

    Subclasses ``AttributeError`` so generic attribute-error handlers also
    catch it.
    """
6
7
class __NotSet:
    """Placeholder type whose ``repr`` stands in for an unassigned field.

    The generated ``__repr__`` methods show this instead of raising when a
    field was never set on an instance.
    """

    def __repr__(self):
        # angle brackets make it obvious this isn't a real field value
        return "<not set>"


# The single shared placeholder: generated ``__repr__`` methods pass it as
# the ``getattr`` default for fields that are missing on an instance.
__NOT_SET = __NotSet()
16
17
def __ignored_classes():
    """Yield every class that plain_data() skips while collecting __slots__.

    The result is the concatenated MROs of ``object``, ``abc.ABC``, a few
    ``typing`` helper classes, the ``collections.abc`` ABCs and, when the
    ``typing_extensions`` package can be imported, its ``Protocol``.
    Classes may be yielded more than once.
    """
    from abc import ABC
    from collections.abc import (
        AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, Callable,
        Collection, Container, Coroutine, Generator, Hashable, ItemsView,
        Iterable, Iterator, KeysView, Mapping, MappingView, MutableMapping,
        MutableSequence, MutableSet, Reversible, Sequence, Set, Sized,
        ValuesView)
    from typing import (
        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
        SupportsInt, SupportsRound)

    roots = [
        object,
        ABC,
        # typing helpers
        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
        SupportsInt, SupportsRound,
        # collections.abc ABCs
        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
        MappingView, KeysView, ItemsView, ValuesView, Sequence,
        MutableSequence,
    ]

    # Protocol isn't available from Python 3.7's typing module, so it is
    # taken from typing_extensions and simply left out if that package
    # isn't installed.
    try:
        from typing_extensions import Protocol
    except ImportError:
        pass
    else:
        roots.append(Protocol)

    for root in roots:
        yield from root.__mro__


# plain_data() skips every class in this set when it walks a decorated
# class's MRO looking for __slots__.
__IGNORED_CLASSES = frozenset(__ignored_classes())
62
63
def _update_class_cell(value, old_cls, new_cls):
    """Repoint a method's implicit ``__class__`` closure cell from
    `old_cls` to `new_cls`.

    Functions defined in a class body that use zero-argument ``super()`` or
    ``__class__`` close over a ``__class__`` cell referring to the class that
    body created. `_decorator` replaces that class with a new one, so the
    cell must be redirected, or ``super()`` would still refer to the old
    class.

    `value` may be a function, a ``classmethod``/``staticmethod`` wrapping
    one, or a ``property``. Anything without a ``__class__`` cell holding
    `old_cls` is left alone. Only the cell named ``__class__`` is examined,
    so other closure cells (which can still be empty while the class is
    being decorated, e.g. one referring to the class's own name in an
    enclosing function) are never read.
    """
    if isinstance(value, (classmethod, staticmethod)):
        value = value.__func__
    if isinstance(value, property):
        for accessor in (value.fget, value.fset, value.fdel):
            if accessor is not None:
                _update_class_cell(accessor, old_cls, new_cls)
        return
    closure = getattr(value, "__closure__", None)
    code = getattr(value, "__code__", None)
    if not isinstance(closure, tuple) or code is None:
        return
    try:
        cell = closure[code.co_freevars.index("__class__")]
    except (AttributeError, ValueError, IndexError):
        return  # no __class__ cell: doesn't use super()/__class__
    try:
        contents = cell.cell_contents
    except ValueError:
        return  # empty cell, nothing to fix
    if contents is old_cls:
        cell.cell_contents = new_cls


def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
    """Implementation of the `plain_data()` class decorator.

    Builds and returns a new class equivalent to `cls`. The new class has
    ``__slots__`` covering every field, plus ``__dict__``/``__weakref__``
    unless a base already provides them. It also gets the requested
    generated methods: ``__eq__``, ``__hash__``, ordering, ``__repr__``,
    frozen ``__setattr__``/``__delattr__``/``__init__``, and
    ``__getstate__``/``__setstate__`` for pickling.

    Fields are the names in the ``__slots__`` tuples of `cls` and its bases
    (classes in `__IGNORED_CLASSES` are skipped). They are collected from
    the most-base class down to `cls`, without duplicates.

    Raises:
        TypeError: if `cls` isn't a class; if a non-ignored class in its MRO
            has no ``__slots__``, or has ``__slots__`` that isn't a tuple of
            str; or if a method to be generated already exists in `cls`.
    """
    if not isinstance(cls, type):
        raise TypeError(
            "plain_data() can only be used as a class decorator")
    # slots is an ordered set by using dict keys.
    # always add __dict__ and __weakref__
    slots = {"__dict__": None, "__weakref__": None}
    if frozen:
        # NOTE: CPython name-mangles "__"-prefixed __slots__ entries, so the
        # actual descriptor is "_<ClassName>__plain_data_init_done". The
        # flag set via object.__setattr__ below therefore lives in the
        # instance __dict__, which always exists for these classes.
        slots["__plain_data_init_done"] = None
    # Slot names plain_data() itself adds to generated __slots__. They are
    # never fields: e.g. a frozen parent's __slots__ contains
    # "__plain_data_init_done", which must not become a field of a
    # non-frozen subclass.
    internal_slots = frozenset(
        ("__dict__", "__weakref__", "__plain_data_init_done"))
    fields = []
    any_parents_have_dict = False
    any_parents_have_weakref = False
    for cur_cls in reversed(cls.__mro__):
        d = getattr(cur_cls, "__dict__", {})
        if cur_cls is not cls:
            if "__dict__" in d:
                any_parents_have_dict = True
            if "__weakref__" in d:
                any_parents_have_weakref = True
        if cur_cls in __IGNORED_CLASSES:
            continue
        try:
            cur_slots = cur_cls.__slots__
        except AttributeError as e:
            raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
                            " must have __slots__ so plain_data() can "
                            "determine what fields exist in "
                            f"{cls.__module__}.{cls.__qualname__}") from e
        if not isinstance(cur_slots, tuple):
            raise TypeError("plain_data() requires __slots__ to be a "
                            "tuple of str")
        for field in cur_slots:
            if not isinstance(field, str):
                raise TypeError("plain_data() requires __slots__ to be a "
                                "tuple of str")
            if field not in slots and field not in internal_slots:
                fields.append(field)
                slots[field] = None

    fields = tuple(fields)  # fields needs to be immutable

    if any_parents_have_dict:
        # work around a CPython bug that unnecessarily checks if parent
        # classes already have the __dict__ slot.
        del slots["__dict__"]

    if any_parents_have_weakref:
        # work around a CPython bug that unnecessarily checks if parent
        # classes already have the __weakref__ slot.
        del slots["__weakref__"]

    # now create a new class having everything we need
    retval_dict = dict(cls.__dict__)
    # remove all old descriptors:
    for name in slots.keys():
        retval_dict.pop(name, None)

    retval_dict["__plain_data_fields"] = fields

    def add_method_or_error(value, replace=False):
        # store `value` in the new class dict under its __name__, refusing
        # to overwrite an existing attribute unless `replace` is true
        name = value.__name__
        if name in retval_dict and not replace:
            raise TypeError(
                f"can't generate {name} method: attribute already exists")
        value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
        retval_dict[name] = value

    if frozen:
        def __setattr__(self, name: str, value):
            if getattr(self, "__plain_data_init_done", False):
                raise FrozenPlainDataError(f"cannot assign to field {name!r}")
            elif name not in slots and not name.startswith("_"):
                raise AttributeError(
                    f"cannot assign to unknown field {name!r}")
            object.__setattr__(self, name, value)

        add_method_or_error(__setattr__)

        def __delattr__(self, name):
            if getattr(self, "__plain_data_init_done", False):
                raise FrozenPlainDataError(f"cannot delete field {name!r}")
            object.__delattr__(self, name)

        add_method_or_error(__delattr__)

        old_init = cls.__init__

        def __init__(self, *args, **kwargs):
            if hasattr(self, "__plain_data_init_done"):
                # we're already in an __init__ call (probably a
                # superclass's __init__), don't set
                # __plain_data_init_done too early
                return old_init(self, *args, **kwargs)
            object.__setattr__(self, "__plain_data_init_done", False)
            try:
                return old_init(self, *args, **kwargs)
            finally:
                object.__setattr__(self, "__plain_data_init_done", True)

        add_method_or_error(__init__, replace=True)
    else:
        old_init = None

    # set __slots__ to have everything we need in the preferred order
    retval_dict["__slots__"] = tuple(slots.keys())

    def __getstate__(self):
        # pickling support: field values in field order
        return [getattr(self, name) for name in fields]

    add_method_or_error(__getstate__)

    def __setstate__(self, state):
        # pickling support
        for name, value in zip(fields, state):
            # bypass frozen setattr
            object.__setattr__(self, name, value)
        if frozen:
            # pickle/copy create the instance without running __init__, so
            # mark it fully initialized; otherwise the restored object
            # would silently be mutable.
            object.__setattr__(self, "__plain_data_init_done", True)

    add_method_or_error(__setstate__)

    # get a tuple of all fields
    def fields_tuple(self):
        return tuple(getattr(self, name) for name in fields)

    if eq:
        def __eq__(self, other):
            if other.__class__ is not self.__class__:
                return NotImplemented
            return fields_tuple(self) == fields_tuple(other)

        add_method_or_error(__eq__)

    if unsafe_hash:
        def __hash__(self):
            return hash(fields_tuple(self))

        # Python implicitly sets `__hash__ = None` on a class whose body
        # defines __eq__ without __hash__. That implicit None isn't a
        # user-written __hash__, so it may be replaced.
        implicit_none_hash = ("__hash__" in cls.__dict__
                              and cls.__dict__["__hash__"] is None
                              and "__eq__" in cls.__dict__)
        add_method_or_error(__hash__, replace=implicit_none_hash)

    if order:
        def __lt__(self, other):
            if other.__class__ is not self.__class__:
                return NotImplemented
            return fields_tuple(self) < fields_tuple(other)

        add_method_or_error(__lt__)

        def __le__(self, other):
            if other.__class__ is not self.__class__:
                return NotImplemented
            return fields_tuple(self) <= fields_tuple(other)

        add_method_or_error(__le__)

        def __gt__(self, other):
            if other.__class__ is not self.__class__:
                return NotImplemented
            return fields_tuple(self) > fields_tuple(other)

        add_method_or_error(__gt__)

        def __ge__(self, other):
            if other.__class__ is not self.__class__:
                return NotImplemented
            return fields_tuple(self) >= fields_tuple(other)

        add_method_or_error(__ge__)

    if repr_:
        def __repr__(self):
            parts = []
            for name in fields:
                parts.append(f"{name}={getattr(self, name, __NOT_SET)!r}")
            return f"{self.__class__.__qualname__}({', '.join(parts)})"

        add_method_or_error(__repr__)

    # construct class
    retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)

    # add __qualname__
    retval.__qualname__ = cls.__qualname__

    # fix up the __class__ cell used by zero-argument super() and __class__
    for value in retval.__dict__.values():
        _update_class_cell(value, cls, retval)

    if old_init is not None:
        # the frozen __init__ wrapper replaced old_init in the class dict,
        # so the loop above doesn't reach it
        _update_class_cell(old_init, cls, retval)

    return retval
264
265
def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
               frozen=False):
    """Class decorator adding equality comparison, ordered comparison,
    `repr`, `hash` and frozen (read-only fields) support to classes that
    are just plain data.

    Similar in spirit to `dataclasses.dataclass`, except that fields come
    from `__slots__` rather than type annotations, and you write the
    class's `__init__` yourself. The keyword defaults match
    `dataclasses.dataclass`; there is no `init` option.

    All options are keyword-only, and the call returns the actual
    decorator, so apply it as `@plain_data()` (with parentheses).
    """
    settings = dict(eq=eq, unsafe_hash=unsafe_hash, order=order,
                    repr_=repr, frozen=frozen)

    def decorator(cls):
        return _decorator(cls, **settings)

    return decorator
280
281
def fields(pd):
    """ Return the field names of a `@plain_data()`-decorated class, or of
    an instance of one.

    Like `dataclasses.fields`, except the result is a plain tuple of field
    name strings, ordered from the most-base class's fields down to the
    class's own fields.

    Returns: tuple[str, ...]

    Raises:
        TypeError: if `pd` has no tuple `__plain_data_fields` attribute,
            i.e. it isn't a `@plain_data()` class or an instance of one.

    e.g.:
    ```
    @plain_data()
    class MyBaseClass:
        __slots__ = "a_field", "field2"
        def __init__(self, a_field, field2):
            self.a_field = a_field
            self.field2 = field2

    assert fields(MyBaseClass) == ("a_field", "field2")
    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")

    @plain_data()
    class MyClass(MyBaseClass):
        __slots__ = "child_field",
        def __init__(self, a_field, field2, child_field):
            super().__init__(a_field=a_field, field2=field2)
            self.child_field = child_field

    assert fields(MyClass) == ("a_field", "field2", "child_field")
    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
    ```
    """
    names = getattr(pd, "__plain_data_fields", None)
    if isinstance(names, tuple):
        return names
    raise TypeError("the passed-in object must be a class or an instance"
                    " of a class decorated with @plain_data()")
319
320
# private sentinel meaning "no replacement value was given for this field"
__NOT_SPECIFIED = object()


def replace(pd, **changes):
    """ Return a new instance of `@plain_data()`-decorated object `pd`, with
    the fields named in `changes` replaced by the given values.

    Every field not named in `changes` keeps its current value from `pd`.
    The new object is built by calling ``type(pd)(**kwargs)`` with one
    keyword argument per field, so the class's ``__init__`` must accept
    every field as a keyword argument. This is especially handy with
    frozen `@plain_data()` classes.

    Raises:
        TypeError: if `pd` is not an instance of a `@plain_data()` class
            (passing the class itself is rejected too), or if `changes`
            names a field that doesn't exist.

    e.g.:
    ```
    @plain_data(frozen=True)
    class MyClass:
        __slots__ = "a", "b", "c"
        def __init__(self, a, b, *, c):
            self.a = a
            self.b = b
            self.c = c

    v1 = MyClass(1, 2, c=3)
    v2 = replace(v1, b=4)
    assert v2 == MyClass(a=1, b=4, c=3)
    assert v2 is not v1
    ```
    """
    target_type = type(pd)
    kwargs = {}
    # Look fields up on the type rather than on pd: if pd were a class, its
    # type (a metaclass) has no plain_data fields, so fields() rejects it.
    for name in fields(target_type):
        new_value = changes.pop(name, __NOT_SPECIFIED)
        kwargs[name] = (getattr(pd, name) if new_value is __NOT_SPECIFIED
                        else new_value)
    if changes:
        unknown_name, _ = changes.popitem()
        raise TypeError(f"can't set unknown field {unknown_name!r}")
    return target_type(**kwargs)