add support for plain_data __repr__ with fields that aren't set
[nmutil.git] / src / nmutil / plain_data.py
# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
class FrozenPlainDataError(AttributeError):
    """ Raised by `@plain_data(frozen=True)` classes when code tries to
    assign to or delete an attribute after `__init__` has finished.

    It subclasses `AttributeError`, so existing `except AttributeError`
    handlers also catch it.
    """
6
7
8 class __NotSet:
9 """ helper for __repr__ for when fields aren't set """
10
11 def __repr__(self):
12 return "<not set>"
13
14
15 __NOT_SET = __NotSet()
16
17
18 def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
19 if not isinstance(cls, type):
20 raise TypeError(
21 "plain_data() can only be used as a class decorator")
22 # slots is an ordered set by using dict keys.
23 # always add __dict__ and __weakref__
24 slots = {"__dict__": None, "__weakref__": None}
25 if frozen:
26 slots["__plain_data_init_done"] = None
27 fields = []
28 any_parents_have_dict = False
29 any_parents_have_weakref = False
30 for cur_cls in reversed(cls.__mro__):
31 if cur_cls is object:
32 continue
33 try:
34 cur_slots = cur_cls.__slots__
35 except AttributeError as e:
36 raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
37 " must have __slots__ so plain_data() can "
38 "determine what fields exist in "
39 f"{cls.__module__}.{cls.__qualname__}") from e
40 if not isinstance(cur_slots, tuple):
41 raise TypeError("plain_data() requires __slots__ to be a "
42 "tuple of str")
43 for field in cur_slots:
44 if not isinstance(field, str):
45 raise TypeError("plain_data() requires __slots__ to be a "
46 "tuple of str")
47 if field not in slots:
48 fields.append(field)
49 slots[field] = None
50 if cur_cls is not cls:
51 if field == "__dict__":
52 any_parents_have_dict = True
53 elif field == "__weakref__":
54 any_parents_have_weakref = True
55
56 fields = tuple(fields) # fields needs to be immutable
57
58 if any_parents_have_dict:
59 # work around a CPython bug that unnecessarily checks if parent
60 # classes already have the __dict__ slot.
61 del slots["__dict__"]
62
63 if any_parents_have_weakref:
64 # work around a CPython bug that unnecessarily checks if parent
65 # classes already have the __weakref__ slot.
66 del slots["__weakref__"]
67
68 # now create a new class having everything we need
69 retval_dict = dict(cls.__dict__)
70 # remove all old descriptors:
71 for name in slots.keys():
72 retval_dict.pop(name, None)
73
74 retval_dict["__plain_data_fields"] = fields
75
76 def add_method_or_error(value, replace=False):
77 name = value.__name__
78 if name in retval_dict and not replace:
79 raise TypeError(
80 f"can't generate {name} method: attribute already exists")
81 value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
82 retval_dict[name] = value
83
84 if frozen:
85 def __setattr__(self, name: str, value):
86 if getattr(self, "__plain_data_init_done", False):
87 raise FrozenPlainDataError(f"cannot assign to field {name!r}")
88 elif name not in slots and not name.startswith("_"):
89 raise AttributeError(
90 f"cannot assign to unknown field {name!r}")
91 object.__setattr__(self, name, value)
92
93 add_method_or_error(__setattr__)
94
95 def __delattr__(self, name):
96 if getattr(self, "__plain_data_init_done", False):
97 raise FrozenPlainDataError(f"cannot delete field {name!r}")
98 object.__delattr__(self, name)
99
100 add_method_or_error(__delattr__)
101
102 old_init = cls.__init__
103
104 def __init__(self, *args, **kwargs):
105 if hasattr(self, "__plain_data_init_done"):
106 # we're already in an __init__ call (probably a
107 # superclass's __init__), don't set
108 # __plain_data_init_done too early
109 return old_init(self, *args, **kwargs)
110 object.__setattr__(self, "__plain_data_init_done", False)
111 try:
112 return old_init(self, *args, **kwargs)
113 finally:
114 object.__setattr__(self, "__plain_data_init_done", True)
115
116 add_method_or_error(__init__, replace=True)
117 else:
118 old_init = None
119
120 # set __slots__ to have everything we need in the preferred order
121 retval_dict["__slots__"] = tuple(slots.keys())
122
123 def __getstate__(self):
124 # pickling support
125 return [getattr(self, name) for name in fields]
126
127 add_method_or_error(__getstate__)
128
129 def __setstate__(self, state):
130 # pickling support
131 for name, value in zip(fields, state):
132 # bypass frozen setattr
133 object.__setattr__(self, name, value)
134
135 add_method_or_error(__setstate__)
136
137 # get a tuple of all fields
138 def fields_tuple(self):
139 return tuple(getattr(self, name) for name in fields)
140
141 if eq:
142 def __eq__(self, other):
143 if other.__class__ is not self.__class__:
144 return NotImplemented
145 return fields_tuple(self) == fields_tuple(other)
146
147 add_method_or_error(__eq__)
148
149 if unsafe_hash:
150 def __hash__(self):
151 return hash(fields_tuple(self))
152
153 add_method_or_error(__hash__)
154
155 if order:
156 def __lt__(self, other):
157 if other.__class__ is not self.__class__:
158 return NotImplemented
159 return fields_tuple(self) < fields_tuple(other)
160
161 add_method_or_error(__lt__)
162
163 def __le__(self, other):
164 if other.__class__ is not self.__class__:
165 return NotImplemented
166 return fields_tuple(self) <= fields_tuple(other)
167
168 add_method_or_error(__le__)
169
170 def __gt__(self, other):
171 if other.__class__ is not self.__class__:
172 return NotImplemented
173 return fields_tuple(self) > fields_tuple(other)
174
175 add_method_or_error(__gt__)
176
177 def __ge__(self, other):
178 if other.__class__ is not self.__class__:
179 return NotImplemented
180 return fields_tuple(self) >= fields_tuple(other)
181
182 add_method_or_error(__ge__)
183
184 if repr_:
185 def __repr__(self):
186 parts = []
187 for name in fields:
188 parts.append(f"{name}={getattr(self, name, __NOT_SET)!r}")
189 return f"{self.__class__.__qualname__}({', '.join(parts)})"
190
191 add_method_or_error(__repr__)
192
193 # construct class
194 retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)
195
196 # add __qualname__
197 retval.__qualname__ = cls.__qualname__
198
199 def fix_super_and_class(value):
200 # fixup super() and __class__
201 # derived from: https://stackoverflow.com/a/71666065/2597900
202 try:
203 closure = value.__closure__
204 if isinstance(closure, tuple):
205 if closure[0].cell_contents is cls:
206 closure[0].cell_contents = retval
207 except (AttributeError, IndexError):
208 pass
209
210 for value in retval.__dict__.values():
211 fix_super_and_class(value)
212
213 if old_init is not None:
214 fix_super_and_class(old_init)
215
216 return retval
217
218
def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
               frozen=False):
    """ Class decorator that generates plain-data methods from `__slots__`.

    Similar to `dataclasses.dataclass`, with two differences: fields come
    from the `__slots__` of the class and its bases rather than from type
    annotations, and you must write your own `__init__`. Every class in
    the MRO other than `object` must define `__slots__` as a tuple of str.

    Keyword-only options. Defaults match `dataclasses.dataclass`, except
    that there is no `init` option.
    * eq: generate `__eq__` comparing the fields of same-class instances.
    * unsafe_hash: generate `__hash__` hashing the tuple of fields.
    * order: generate `__lt__`, `__le__`, `__gt__` and `__ge__`.
    * repr: generate `__repr__`. Fields that aren't set are shown as
      `<not set>`.
    * frozen: make fields read-only once `__init__` returns. Assignment
      or deletion then raises `FrozenPlainDataError`.
    """
    options = dict(eq=eq, unsafe_hash=unsafe_hash, order=order,
                   repr_=repr, frozen=frozen)

    def decorator(cls):
        return _decorator(cls, **options)

    return decorator
233
234
def fields(pd):
    """ Return the field names of a `@plain_data()`-decorated class, or of
    an instance of one, as a `tuple[str, ...]`.

    This is similar to `dataclasses.fields`, except that it returns the
    field names rather than `dataclasses.Field` objects.

    Raises TypeError if `pd` is neither such a class nor an instance of
    one.

    e.g.:
    ```
    @plain_data()
    class MyBaseClass:
        __slots__ = "a_field", "field2"
        def __init__(self, a_field, field2):
            self.a_field = a_field
            self.field2 = field2

    assert fields(MyBaseClass) == ("a_field", "field2")
    assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")

    @plain_data()
    class MyClass(MyBaseClass):
        __slots__ = "child_field",
        def __init__(self, a_field, field2, child_field):
            super().__init__(a_field=a_field, field2=field2)
            self.child_field = child_field

    assert fields(MyClass) == ("a_field", "field2", "child_field")
    assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
    ```
    """
    # the decorator stores the field names on the class under this key;
    # anything else (missing or not a tuple) isn't a plain_data class
    field_names = getattr(pd, "__plain_data_fields", None)
    if isinstance(field_names, tuple):
        return field_names
    raise TypeError("the passed-in object must be a class or an instance"
                    " of a class decorated with @plain_data()")
272
273
# private sentinel: distinguishes "field not given in changes" from a
# field explicitly replaced with None
__NOT_SPECIFIED = object()


def replace(pd, **changes):
    """ Return a new instance of the class of the `@plain_data()`-decorated
    object `pd`. Fields named in `changes` take the given values; every
    other field is copied from `pd`.

    The new object is created by calling the class with every field passed
    as a keyword argument, so its `__init__` must accept all fields by
    keyword. This is quite useful with frozen `@plain_data()` classes.

    Raises TypeError in either of these cases:
    * `pd` is not an instance of a `@plain_data()` class. Passing the
      class itself is also rejected.
    * `changes` names a field that doesn't exist.

    e.g.:
    ```
    @plain_data(frozen=True)
    class MyClass:
        __slots__ = "a", "b", "c"
        def __init__(self, a, b, *, c):
            self.a = a
            self.b = b
            self.c = c

    v1 = MyClass(1, 2, c=3)
    v2 = replace(v1, b=4)
    assert v2 == MyClass(a=1, b=4, c=3)
    assert v2 is not v1
    ```
    """
    pd_type = type(pd)
    new_kwargs = {}
    # Look the fields up on the type rather than on pd itself. If pd is a
    # class, type(pd) is its metaclass, which fields() rejects.
    for name in fields(pd_type):
        value = changes.pop(name, __NOT_SPECIFIED)
        if value is __NOT_SPECIFIED:
            value = getattr(pd, name)
        new_kwargs[name] = value
    # anything left over in changes didn't match a field
    if changes:
        unknown_name, _ = changes.popitem()
        raise TypeError(f"can't set unknown field {unknown_name!r}")
    return pd_type(**new_kwargs)