speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / plain_data.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3 import keyword
4
5
class FrozenPlainDataError(AttributeError):
    """ Raised when a field of a frozen `@plain_data()` instance is assigned
    to or deleted after the instance's `__init__` has finished.

    Subclasses `AttributeError`, so `except AttributeError` handlers also
    catch it.
    """
8
9
class __NotSet:
    """ helper for __repr__ for when fields aren't set

    The single instance below is used as the `getattr` default when a
    generated `__repr__` formats a field that has no value yet, so such a
    field is displayed as `<not set>` instead of raising.
    """

    def __repr__(self):
        return "<not set>"


# module-wide singleton placeholder for fields that have no value
__NOT_SET = __NotSet()
18
19
def __ignored_classes():
    """ Yield every class that `plain_data()` skips when it walks a
    decorated class's MRO looking for `__slots__`.

    The result covers `object`, `abc.ABC`, the listed `typing` and
    `collections.abc` base classes, and `typing_extensions.Protocol` when
    that package is importable, together with everything in each of those
    classes' own MROs.  Classes may be yielded more than once; the caller
    de-duplicates by collecting them into a frozenset.
    """
    # imported locally so these names don't end up in the module namespace
    from abc import ABC
    from typing import (
        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
        SupportsInt, SupportsRound)
    from collections.abc import (
        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
        MappingView, KeysView, ItemsView, ValuesView, Sequence,
        MutableSequence)

    classes = [
        object,
        ABC,
        Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
        SupportsInt, SupportsRound,
        Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
        Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
        Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
        MappingView, KeysView, ItemsView, ValuesView, Sequence,
        MutableSequence,
    ]  # type: list[type]

    # rest aren't supported by python 3.7, so try to import them and skip if
    # that errors

    try:
        # typing_extensions uses typing.Protocol if available
        from typing_extensions import Protocol
        classes.append(Protocol)
    except ImportError:
        pass

    for cls in classes:
        # include each class's bases too, e.g. Sequence also brings in
        # Reversible, Collection, ...
        yield from cls.__mro__


# computed once at import time; membership tests against it are O(1)
__IGNORED_CLASSES = frozenset(__ignored_classes())
64
65
def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
    """ Implementation of `plain_data()`: build and return a replacement
    for `cls` with the requested generated methods.

    The fields are the `__slots__` entries of every class in `cls`'s MRO
    (base classes first), skipping the classes in `__IGNORED_CLASSES`.  A
    new class with the same name, bases and metaclass is created from a
    copy of `cls.__dict__`, so the returned class is a different object
    than `cls`.

    Arguments:
    cls: the class being decorated.
    eq: if true, generate `__eq__` comparing the tuple of all fields.
    unsafe_hash: if true, generate `__hash__` hashing the tuple of all
        fields.
    order: if true, generate `__lt__`, `__le__`, `__gt__` and `__ge__`
        comparing the tuple of all fields.
    repr_: if true, generate `__repr__`.
    frozen: if true, generate `__setattr__`/`__delattr__` that raise
        `FrozenPlainDataError` once `__init__` has finished, and wrap
        `__init__` to track when that is.

    `__getstate__` and `__setstate__` are always generated.

    Raises `TypeError` if `cls` isn't a class, if a non-ignored class in the
    MRO has no `__slots__` or has `__slots__` that isn't a tuple of
    non-keyword identifier strings, or if a method that would be generated
    is already defined in `cls` (`__init__` excepted).
    """
    if not isinstance(cls, type):
        raise TypeError(
            "plain_data() can only be used as a class decorator")
    # slots is an ordered set by using dict keys.
    # always add __dict__ and __weakref__
    slots = {"__dict__": None, "__weakref__": None}
    if frozen:
        slots["__plain_data_init_done"] = None
    fields = []
    any_parents_have_dict = False
    any_parents_have_weakref = False
    # walk base classes first so inherited fields come before cls's own
    for cur_cls in reversed(cls.__mro__):
        d = getattr(cur_cls, "__dict__", {})
        if cur_cls is not cls:
            if "__dict__" in d:
                any_parents_have_dict = True
            if "__weakref__" in d:
                any_parents_have_weakref = True
        if cur_cls in __IGNORED_CLASSES:
            continue
        try:
            cur_slots = cur_cls.__slots__
        except AttributeError as e:
            raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
                            " must have __slots__ so plain_data() can "
                            "determine what fields exist in "
                            f"{cls.__module__}.{cls.__qualname__}") from e
        if not isinstance(cur_slots, tuple):
            raise TypeError("plain_data() requires __slots__ to be a "
                            "tuple of str")
        for field in cur_slots:
            if not isinstance(field, str):
                raise TypeError("plain_data() requires __slots__ to be a "
                                "tuple of str")
            # field names get pasted into generated source code below, so
            # they must be plain identifiers
            if not field.isidentifier() or keyword.iskeyword(field):
                raise TypeError(
                    "plain_data() requires __slots__ entries to be valid "
                    "Python identifiers and not keywords")
            if field not in slots:
                fields.append(field)
            slots[field] = None

    fields = tuple(fields)  # fields needs to be immutable

    if any_parents_have_dict:
        # work around a CPython bug that unnecessarily checks if parent
        # classes already have the __dict__ slot.
        del slots["__dict__"]

    if any_parents_have_weakref:
        # work around a CPython bug that unnecessarily checks if parent
        # classes already have the __weakref__ slot.
        del slots["__weakref__"]

    # now create a new class having everything we need
    retval_dict = dict(cls.__dict__)
    # remove all old descriptors:
    for name in slots:
        retval_dict.pop(name, None)

    retval_dict["__plain_data_fields"] = fields

    def add_method_or_error(value, replace=False):
        # install `value` in the new class's namespace under its own
        # __name__, refusing to overwrite an existing attribute unless
        # `replace` is set
        name = value.__name__
        if name in retval_dict and not replace:
            raise TypeError(
                f"can't generate {name} method: attribute already exists")
        value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
        retval_dict[name] = value

    if frozen:
        def __setattr__(self, name: str, value):
            if getattr(self, "__plain_data_init_done", False):
                raise FrozenPlainDataError(f"cannot assign to field {name!r}")
            elif name not in slots and not name.startswith("_"):
                raise AttributeError(
                    f"cannot assign to unknown field {name!r}")
            object.__setattr__(self, name, value)

        add_method_or_error(__setattr__)

        def __delattr__(self, name):
            if getattr(self, "__plain_data_init_done", False):
                raise FrozenPlainDataError(f"cannot delete field {name!r}")
            object.__delattr__(self, name)

        add_method_or_error(__delattr__)

        old_init = cls.__init__

        def __init__(self, *args, **kwargs):
            if hasattr(self, "__plain_data_init_done"):
                # we're already in an __init__ call (probably a
                # superclass's __init__), don't set
                # __plain_data_init_done too early
                return old_init(self, *args, **kwargs)
            object.__setattr__(self, "__plain_data_init_done", False)
            try:
                return old_init(self, *args, **kwargs)
            finally:
                object.__setattr__(self, "__plain_data_init_done", True)

        add_method_or_error(__init__, replace=True)
    else:
        old_init = None

    # set __slots__ to have everything we need in the preferred order
    retval_dict["__slots__"] = tuple(slots.keys())

    def __getstate__(self):
        # pickling support
        return [getattr(self, name) for name in fields]

    add_method_or_error(__getstate__)

    def __setstate__(self, state):
        # pickling support
        for name, value in zip(fields, state):
            # bypass frozen setattr
            object.__setattr__(self, name, value)
        if frozen:
            # __init__ isn't run when unpickling/copying, so mark the
            # instance as initialized here -- otherwise the restored
            # object would silently be mutable.
            object.__setattr__(self, "__plain_data_init_done", True)

    add_method_or_error(__setstate__)

    # get source code that gets a tuple of all fields
    def fields_tuple(var):
        # type: (str) -> str
        # every element has a trailing comma, so a single field still gives
        # a tuple and no fields gives `()`
        return "(" + "".join(f"{var}.{name}, " for name in fields) + ")"

    self_tuple = fields_tuple("self")
    other_tuple = fields_tuple("other")

    def add_generated_method(name, source):
        # compile `source` (which must define a function called `name`)
        # with this module's globals, then install the resulting function.
        # `source` only ever contains field names validated above.
        namespace = {}
        exec(source, globals(), namespace)
        add_method_or_error(namespace[name])

    def add_comparison(name, op):
        # the comparison is spelled out in generated source rather than
        # looping over the fields at run time
        add_generated_method(
            name,
            f"def {name}(self, other):\n"
            f"    if other.__class__ is not self.__class__:\n"
            f"        return NotImplemented\n"
            f"    return {self_tuple} {op} {other_tuple}\n")

    if eq:
        add_comparison("__eq__", "==")

    if unsafe_hash:
        add_generated_method(
            "__hash__",
            f"def __hash__(self):\n"
            f"    return hash({self_tuple})\n")

    if order:
        add_comparison("__lt__", "<")
        add_comparison("__le__", "<=")
        add_comparison("__gt__", ">")
        add_comparison("__ge__", ">=")

    if repr_:
        def __repr__(self):
            parts = []
            for name in fields:
                parts.append(f"{name}={getattr(self, name, __NOT_SET)!r}")
            return f"{self.__class__.__qualname__}({', '.join(parts)})"

        add_method_or_error(__repr__)

    # construct class
    retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)

    # add __qualname__
    retval.__qualname__ = cls.__qualname__

    def fix_super_and_class(value):
        # fixup super() and __class__
        # derived from: https://stackoverflow.com/a/71666065/2597900
        #
        # look through classmethod/staticmethod/property wrappers to get at
        # the underlying functions
        if isinstance(value, (classmethod, staticmethod)):
            targets = (value.__func__,)
        elif isinstance(value, property):
            targets = (value.fget, value.fset, value.fdel)
        else:
            targets = (value,)
        for target in targets:
            closure = getattr(target, "__closure__", None)
            if not isinstance(closure, tuple):
                continue
            # the cell referring to the old class isn't necessarily the
            # first one, so check all of them
            for cell in closure:
                try:
                    if cell.cell_contents is cls:
                        cell.cell_contents = retval
                except (AttributeError, ValueError):
                    # ValueError: the cell is still empty, e.g. it's the
                    # enclosing function's variable that the decorated
                    # class is about to be assigned to
                    pass

    for value in retval.__dict__.values():
        fix_super_and_class(value)

    if old_init is not None:
        fix_super_and_class(old_init)

    return retval
280
281
def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
               frozen=False):
    """ Decorator for adding equality comparison, ordered comparison,
    `repr` support, `hash` support, and frozen type (read-only fields)
    support to classes that are just plain data.

    This is kinda like dataclasses, but uses `__slots__` instead of type
    annotations, as well as requiring you to write your own `__init__`

    All arguments are keyword-only; the defaults match dataclass, with the
    exception of `init`, which doesn't exist here.

    Arguments:
    eq: generate `__eq__`.
    unsafe_hash: generate `__hash__`.
    order: generate `__lt__`, `__le__`, `__gt__` and `__ge__`.
    repr: generate `__repr__`.
    frozen: make fields read-only once `__init__` has finished.

    Returns: the class decorator, so this must be used with parentheses:
    `@plain_data()`.
    """
    def decorator(cls):
        return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
                          repr_=repr, frozen=frozen)
    return decorator
296
297
def fields(pd):
    """ get the tuple of field names of the passed-in
    `@plain_data()`-decorated class.

    This is similar to `dataclasses.fields`, except this returns a
    different type.

    `pd` may be either the class itself or an instance of it.

    Returns: tuple[str, ...]

    Raises `TypeError` if `pd` is neither.

    e.g.:
    ```
    @plain_data()
    class MyBaseClass:
        __slots__ = "a_field", "field2"
        def __init__(self, a_field, field2):
            self.a_field = a_field
            self.field2 = field2

    assert fields(MyBaseClass) == ("a_field", "field2")
    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")

    @plain_data()
    class MyClass(MyBaseClass):
        __slots__ = "child_field",
        def __init__(self, a_field, field2, child_field):
            super().__init__(a_field=a_field, field2=field2)
            self.child_field = child_field

    assert fields(MyClass) == ("a_field", "field2", "child_field")
    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
    ```
    """
    # the name is passed as a string on purpose: a literal
    # `pd.__plain_data_fields` would get name-mangled if this code were
    # ever moved inside a class body
    retval = getattr(pd, "__plain_data_fields", None)
    if not isinstance(retval, tuple):
        raise TypeError("the passed-in object must be a class or an instance"
                        " of a class decorated with @plain_data()")
    return retval
335
336
# sentinel marking "no replacement given", so that `None` can be used as a
# replacement value
__NOT_SPECIFIED = object()


def replace(pd, **changes):
    """ Return a new instance of the passed-in `@plain_data()`-decorated
    object, but with the specified fields replaced with new values.
    This is quite useful with frozen `@plain_data()` classes.

    The new instance is created by calling `type(pd)` with every field
    passed as a keyword argument, so the class's `__init__` must accept
    all of its fields by keyword.

    Raises `TypeError` if `pd` isn't an instance of a `@plain_data()` class
    or if `changes` names something that isn't a field.

    e.g.:
    ```
    @plain_data(frozen=True)
    class MyClass:
        __slots__ = "a", "b", "c"
        def __init__(self, a, b, *, c):
            self.a = a
            self.b = b
            self.c = c

    v1 = MyClass(1, 2, c=3)
    v2 = replace(v1, b=4)
    assert v2 == MyClass(a=1, b=4, c=3)
    assert v2 is not v1
    ```
    """
    ty = type(pd)
    kwargs = {}
    # call fields on ty rather than pd to ensure we're not called with a
    # class rather than an instance.
    for name in fields(ty):
        value = changes.pop(name, __NOT_SPECIFIED)
        kwargs[name] = getattr(pd, name) if value is __NOT_SPECIFIED else value
    if changes:
        # anything left over didn't match a field
        raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
    return ty(**kwargs)