1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 class FrozenPlainDataError(AttributeError):
9 """ helper for __repr__ for when fields aren't set """
15 __NOT_SET
= __NotSet()
18 def __ignored_classes():
19 classes
= [object] # type: list[type]
26 Generic
, SupportsAbs
, SupportsBytes
, SupportsComplex
, SupportsFloat
,
27 SupportsInt
, SupportsRound
)
30 Generic
, SupportsAbs
, SupportsBytes
, SupportsComplex
, SupportsFloat
,
31 SupportsInt
, SupportsRound
]
33 from collections
.abc
import (
34 Awaitable
, Coroutine
, AsyncIterable
, AsyncIterator
, AsyncGenerator
,
35 Hashable
, Iterable
, Iterator
, Generator
, Reversible
, Sized
, Container
,
36 Callable
, Collection
, Set
, MutableSet
, Mapping
, MutableMapping
,
37 MappingView
, KeysView
, ItemsView
, ValuesView
, Sequence
,
41 Awaitable
, Coroutine
, AsyncIterable
, AsyncIterator
, AsyncGenerator
,
42 Hashable
, Iterable
, Iterator
, Generator
, Reversible
, Sized
, Container
,
43 Callable
, Collection
, Set
, MutableSet
, Mapping
, MutableMapping
,
44 MappingView
, KeysView
, ItemsView
, ValuesView
, Sequence
,
47 # rest aren't supported by python 3.7, so try to import them and skip if
51 # typing_extensions uses typing.Protocol if available
52 from typing_extensions
import Protocol
53 classes
.append(Protocol
)
58 yield from cls
.__mro
__
61 __IGNORED_CLASSES
= frozenset(__ignored_classes())
64 def _decorator(cls
, *, eq
, unsafe_hash
, order
, repr_
, frozen
):
65 if not isinstance(cls
, type):
67 "plain_data() can only be used as a class decorator")
68 # slots is an ordered set by using dict keys.
69 # always add __dict__ and __weakref__
70 slots
= {"__dict__": None, "__weakref__": None}
72 slots
["__plain_data_init_done"] = None
74 any_parents_have_dict
= False
75 any_parents_have_weakref
= False
76 for cur_cls
in reversed(cls
.__mro
__):
77 d
= getattr(cur_cls
, "__dict__", {})
78 if cur_cls
is not cls
:
80 any_parents_have_dict
= True
81 if "__weakref__" in d
:
82 any_parents_have_weakref
= True
83 if cur_cls
in __IGNORED_CLASSES
:
86 cur_slots
= cur_cls
.__slots
__
87 except AttributeError as e
:
88 raise TypeError(f
"{cur_cls.__module__}.{cur_cls.__qualname__}"
89 " must have __slots__ so plain_data() can "
90 "determine what fields exist in "
91 f
"{cls.__module__}.{cls.__qualname__}") from e
92 if not isinstance(cur_slots
, tuple):
93 raise TypeError("plain_data() requires __slots__ to be a "
95 for field
in cur_slots
:
96 if not isinstance(field
, str):
97 raise TypeError("plain_data() requires __slots__ to be a "
99 if field
not in slots
:
103 fields
= tuple(fields
) # fields needs to be immutable
105 if any_parents_have_dict
:
106 # work around a CPython bug that unnecessarily checks if parent
107 # classes already have the __dict__ slot.
108 del slots
["__dict__"]
110 if any_parents_have_weakref
:
111 # work around a CPython bug that unnecessarily checks if parent
112 # classes already have the __weakref__ slot.
113 del slots
["__weakref__"]
115 # now create a new class having everything we need
116 retval_dict
= dict(cls
.__dict
__)
117 # remove all old descriptors:
118 for name
in slots
.keys():
119 retval_dict
.pop(name
, None)
121 retval_dict
["__plain_data_fields"] = fields
123 def add_method_or_error(value
, replace
=False):
124 name
= value
.__name
__
125 if name
in retval_dict
and not replace
:
127 f
"can't generate {name} method: attribute already exists")
128 value
.__qualname
__ = f
"{cls.__qualname__}.{value.__name__}"
129 retval_dict
[name
] = value
132 def __setattr__(self
, name
: str, value
):
133 if getattr(self
, "__plain_data_init_done", False):
134 raise FrozenPlainDataError(f
"cannot assign to field {name!r}")
135 elif name
not in slots
and not name
.startswith("_"):
136 raise AttributeError(
137 f
"cannot assign to unknown field {name!r}")
138 object.__setattr
__(self
, name
, value
)
140 add_method_or_error(__setattr__
)
142 def __delattr__(self
, name
):
143 if getattr(self
, "__plain_data_init_done", False):
144 raise FrozenPlainDataError(f
"cannot delete field {name!r}")
145 object.__delattr
__(self
, name
)
147 add_method_or_error(__delattr__
)
149 old_init
= cls
.__init
__
151 def __init__(self
, *args
, **kwargs
):
152 if hasattr(self
, "__plain_data_init_done"):
153 # we're already in an __init__ call (probably a
154 # superclass's __init__), don't set
155 # __plain_data_init_done too early
156 return old_init(self
, *args
, **kwargs
)
157 object.__setattr
__(self
, "__plain_data_init_done", False)
159 return old_init(self
, *args
, **kwargs
)
161 object.__setattr
__(self
, "__plain_data_init_done", True)
163 add_method_or_error(__init__
, replace
=True)
167 # set __slots__ to have everything we need in the preferred order
168 retval_dict
["__slots__"] = tuple(slots
.keys())
170 def __getstate__(self
):
172 return [getattr(self
, name
) for name
in fields
]
174 add_method_or_error(__getstate__
)
176 def __setstate__(self
, state
):
178 for name
, value
in zip(fields
, state
):
179 # bypass frozen setattr
180 object.__setattr
__(self
, name
, value
)
182 add_method_or_error(__setstate__
)
184 # get a tuple of all fields
185 def fields_tuple(self
):
186 return tuple(getattr(self
, name
) for name
in fields
)
189 def __eq__(self
, other
):
190 if other
.__class
__ is not self
.__class
__:
191 return NotImplemented
192 return fields_tuple(self
) == fields_tuple(other
)
194 add_method_or_error(__eq__
)
198 return hash(fields_tuple(self
))
200 add_method_or_error(__hash__
)
203 def __lt__(self
, other
):
204 if other
.__class
__ is not self
.__class
__:
205 return NotImplemented
206 return fields_tuple(self
) < fields_tuple(other
)
208 add_method_or_error(__lt__
)
210 def __le__(self
, other
):
211 if other
.__class
__ is not self
.__class
__:
212 return NotImplemented
213 return fields_tuple(self
) <= fields_tuple(other
)
215 add_method_or_error(__le__
)
217 def __gt__(self
, other
):
218 if other
.__class
__ is not self
.__class
__:
219 return NotImplemented
220 return fields_tuple(self
) > fields_tuple(other
)
222 add_method_or_error(__gt__
)
224 def __ge__(self
, other
):
225 if other
.__class
__ is not self
.__class
__:
226 return NotImplemented
227 return fields_tuple(self
) >= fields_tuple(other
)
229 add_method_or_error(__ge__
)
235 parts
.append(f
"{name}={getattr(self, name, __NOT_SET)!r}")
236 return f
"{self.__class__.__qualname__}({', '.join(parts)})"
238 add_method_or_error(__repr__
)
241 retval
= type(cls
)(cls
.__name
__, cls
.__bases
__, retval_dict
)
244 retval
.__qualname
__ = cls
.__qualname
__
246 def fix_super_and_class(value
):
247 # fixup super() and __class__
248 # derived from: https://stackoverflow.com/a/71666065/2597900
250 closure
= value
.__closure
__
251 if isinstance(closure
, tuple):
252 if closure
[0].cell_contents
is cls
:
253 closure
[0].cell_contents
= retval
254 except (AttributeError, IndexError):
257 for value
in retval
.__dict
__.values():
258 fix_super_and_class(value
)
260 if old_init
is not None:
261 fix_super_and_class(old_init
)
266 def plain_data(*, eq
=True, unsafe_hash
=False, order
=False, repr=True,
268 # defaults match dataclass, with the exception of `init`
269 """ Decorator for adding equality comparison, ordered comparison,
270 `repr` support, `hash` support, and frozen type (read-only fields)
271 support to classes that are just plain data.
273 This is kinda like dataclasses, but uses `__slots__` instead of type
274 annotations, as well as requiring you to write your own `__init__`
277 return _decorator(cls
, eq
=eq
, unsafe_hash
=unsafe_hash
, order
=order
,
278 repr_
=repr, frozen
=frozen
)
283 """ get the tuple of field names of the passed-in
284 `@plain_data()`-decorated class.
286 This is similar to `dataclasses.fields`, except this returns a
289 Returns: tuple[str, ...]
295 __slots__ = "a_field", "field2"
296 def __init__(self, a_field, field2):
297 self.a_field = a_field
300 assert fields(MyBaseClass) == ("a_field", "field2")
301 assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
304 class MyClass(MyBaseClass):
305 __slots__ = "child_field",
306 def __init__(self, a_field, field2, child_field):
307 super().__init__(a_field=a_field, field2=field2)
308 self.child_field = child_field
310 assert fields(MyClass) == ("a_field", "field2", "child_field")
311 assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
314 retval
= getattr(pd
, "__plain_data_fields", None)
315 if not isinstance(retval
, tuple):
316 raise TypeError("the passed-in object must be a class or an instance"
317 " of a class decorated with @plain_data()")
321 __NOT_SPECIFIED
= object()
324 def replace(pd
, **changes
):
325 """ Return a new instance of the passed-in `@plain_data()`-decorated
326 object, but with the specified fields replaced with new values.
327 This is quite useful with frozen `@plain_data()` classes.
331 @plain_data(frozen=True)
333 __slots__ = "a", "b", "c"
334 def __init__(self, a, b, *, c):
339 v1 = MyClass(1, 2, c=3)
340 v2 = replace(v1, b=4)
341 assert v2 == MyClass(a=1, b=4, c=3)
347 # call fields on ty rather than pd to ensure we're not called with a
348 # class rather than an instance.
349 for name
in fields(ty
):
350 value
= changes
.pop(name
, __NOT_SPECIFIED
)
351 if value
is __NOT_SPECIFIED
:
352 kwargs
[name
] = getattr(pd
, name
)
355 if len(changes
) != 0:
356 raise TypeError(f
"can't set unknown field {changes.popitem()[0]!r}")