b1912f1532439a58316785bf8e9607789b775a20
[nmutil.git] / src / nmutil / plain_data.py
# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
class FrozenPlainDataError(AttributeError):
    """Raised when code assigns to or deletes a field of a frozen
    ``plain_data`` instance after its ``__init__`` has finished.

    It subclasses ``AttributeError``, so ``except AttributeError``
    handlers catch it too.
    """
6
7
8 def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
9 if not isinstance(cls, type):
10 raise TypeError(
11 "plain_data() can only be used as a class decorator")
12 # slots is an ordered set by using dict keys.
13 # always add __dict__ and __weakref__
14 slots = {"__dict__": None, "__weakref__": None}
15 fields = []
16 any_parents_have_dict = False
17 any_parents_have_weakref = False
18 for cur_cls in reversed(cls.__mro__):
19 if cur_cls is object:
20 continue
21 try:
22 cur_slots = cur_cls.__slots__
23 except AttributeError as e:
24 raise TypeError(f"{cur_cls.__module__}.{cur_cls.__qualname__}"
25 " must have __slots__ so plain_data() can "
26 "determine what fields exist in "
27 f"{cls.__module__}.{cls.__qualname__}") from e
28 if not isinstance(cur_slots, tuple):
29 raise TypeError("plain_data() requires __slots__ to be a "
30 "tuple of str")
31 for field in cur_slots:
32 if not isinstance(field, str):
33 raise TypeError("plain_data() requires __slots__ to be a "
34 "tuple of str")
35 if field not in slots:
36 fields.append(field)
37 slots[field] = None
38 if cur_cls is not cls:
39 if field == "__dict__":
40 any_parents_have_dict = True
41 elif field == "__weakref__":
42 any_parents_have_weakref = True
43
44 if any_parents_have_dict:
45 # work around a CPython bug that unnecessarily checks if parent
46 # classes already have the __dict__ slot.
47 del slots["__dict__"]
48
49 if any_parents_have_weakref:
50 # work around a CPython bug that unnecessarily checks if parent
51 # classes already have the __weakref__ slot.
52 del slots["__weakref__"]
53
54 # now create a new class having everything we need
55 retval_dict = dict(cls.__dict__)
56 # remove all old descriptors:
57 for name in slots.keys():
58 retval_dict.pop(name, None)
59
60 def add_method_or_error(value, replace=False):
61 name = value.__name__
62 if name in retval_dict and not replace:
63 raise TypeError(
64 f"can't generate {name} method: attribute already exists")
65 value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
66 retval_dict[name] = value
67
68 if frozen:
69 slots["__plain_data_init_done"] = None
70
71 def __setattr__(self, name: str, value):
72 if getattr(self, "__plain_data_init_done", False):
73 raise FrozenPlainDataError(f"cannot assign to field {name!r}")
74 elif name not in slots and not name.startswith("_"):
75 raise AttributeError(
76 f"cannot assign to unknown field {name!r}")
77 object.__setattr__(self, name, value)
78
79 add_method_or_error(__setattr__)
80
81 def __delattr__(self, name):
82 if getattr(self, "__plain_data_init_done", False):
83 raise FrozenPlainDataError(f"cannot delete field {name!r}")
84 object.__delattr__(self, name)
85
86 add_method_or_error(__delattr__)
87
88 old_init = cls.__init__
89
90 def __init__(self, *args, **kwargs):
91 if hasattr(self, "__plain_data_init_done"):
92 # we're already in an __init__ call (probably a
93 # superclass's __init__), don't set
94 # __plain_data_init_done too early
95 return old_init(self, *args, **kwargs)
96 object.__setattr__(self, "__plain_data_init_done", False)
97 try:
98 return old_init(self, *args, **kwargs)
99 finally:
100 object.__setattr__(self, "__plain_data_init_done", True)
101
102 add_method_or_error(__init__, replace=True)
103
104 # set __slots__ to have everything we need in the preferred order
105 retval_dict["__slots__"] = tuple(slots.keys())
106
107 def __dir__(self):
108 # don't return fields un-copied so users can't mess with it
109 return fields.copy()
110
111 add_method_or_error(__dir__)
112
113 def __getstate__(self):
114 # pickling support
115 return [getattr(self, name) for name in fields]
116
117 add_method_or_error(__getstate__)
118
119 def __setstate__(self, state):
120 # pickling support
121 for name, value in zip(fields, state):
122 # bypass frozen setattr
123 object.__setattr__(self, name, value)
124
125 add_method_or_error(__setstate__)
126
127 # get a tuple of all fields
128 def fields_tuple(self):
129 return tuple(getattr(self, name) for name in fields)
130
131 if eq:
132 def __eq__(self, other):
133 if other.__class__ is not self.__class__:
134 return NotImplemented
135 return fields_tuple(self) == fields_tuple(other)
136
137 add_method_or_error(__eq__)
138
139 if unsafe_hash:
140 def __hash__(self):
141 return hash(fields_tuple(self))
142
143 add_method_or_error(__hash__)
144
145 if order:
146 def __lt__(self, other):
147 if other.__class__ is not self.__class__:
148 return NotImplemented
149 return fields_tuple(self) < fields_tuple(other)
150
151 add_method_or_error(__lt__)
152
153 def __le__(self, other):
154 if other.__class__ is not self.__class__:
155 return NotImplemented
156 return fields_tuple(self) <= fields_tuple(other)
157
158 add_method_or_error(__le__)
159
160 def __gt__(self, other):
161 if other.__class__ is not self.__class__:
162 return NotImplemented
163 return fields_tuple(self) > fields_tuple(other)
164
165 add_method_or_error(__gt__)
166
167 def __ge__(self, other):
168 if other.__class__ is not self.__class__:
169 return NotImplemented
170 return fields_tuple(self) >= fields_tuple(other)
171
172 add_method_or_error(__ge__)
173
174 if repr_:
175 def __repr__(self):
176 parts = []
177 for name in fields:
178 parts.append(f"{name}={getattr(self, name)!r}")
179 return f"{self.__class__.__qualname__}({', '.join(parts)})"
180
181 add_method_or_error(__repr__)
182
183 # construct class
184 retval = type(cls)(cls.__name__, cls.__bases__, retval_dict)
185
186 # add __qualname__
187 retval.__qualname__ = cls.__qualname__
188
189 # fixup super() and __class__
190 # derived from: https://stackoverflow.com/a/71666065/2597900
191 for value in retval.__dict__.values():
192 try:
193 closure = value.__closure__
194 if isinstance(closure, tuple):
195 if closure[0].cell_contents is cls:
196 closure[0].cell_contents = retval
197 except (AttributeError, IndexError):
198 pass
199
200 return retval
201
202
def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True,
               frozen=False):
    """Class decorator for classes that are just plain data.

    Similar in spirit to ``dataclasses``, but the fields are taken from
    ``__slots__`` (of the class and all of its bases) instead of type
    annotations, and you write ``__init__`` yourself.

    Args:
        eq: generate ``__eq__`` comparing all fields.
        unsafe_hash: generate ``__hash__`` hashing all fields.
        order: generate ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__``
            comparing all fields.
        repr: generate ``__repr__`` listing all fields.
        frozen: make the fields read-only once ``__init__`` has finished.

    Returns:
        a decorator that takes the class and returns a newly built class.
    """
    options = {
        "eq": eq,
        "unsafe_hash": unsafe_hash,
        "order": order,
        "repr_": repr,
        "frozen": frozen,
    }

    def decorator(cls):
        return _decorator(cls, **options)

    return decorator