finish implementing @plain_data()
[nmutil.git] / src / nmutil / plain_data.py
index b1912f1532439a58316785bf8e9607789b775a20..92c303daf323c73d6f4a9ee3c27be583f2214861 100644 (file)
@@ -12,6 +12,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     # slots is an ordered set by using dict keys.
     # always add __dict__ and __weakref__
     slots = {"__dict__": None, "__weakref__": None}
+    if frozen:
+        slots["__plain_data_init_done"] = None
     fields = []
     any_parents_have_dict = False
     any_parents_have_weakref = False
@@ -41,6 +43,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
                 elif field == "__weakref__":
                     any_parents_have_weakref = True
 
+    fields = tuple(fields)  # fields needs to be immutable
+
     if any_parents_have_dict:
         # work around a CPython bug that unnecessarily checks if parent
         # classes already have the __dict__ slot.
@@ -57,6 +61,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     for name in slots.keys():
         retval_dict.pop(name, None)
 
+    retval_dict["_fields"] = fields
+
     def add_method_or_error(value, replace=False):
         name = value.__name__
         if name in retval_dict and not replace:
@@ -66,8 +72,6 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
         retval_dict[name] = value
 
     if frozen:
-        slots["__plain_data_init_done"] = None
-
         def __setattr__(self, name: str, value):
             if getattr(self, "__plain_data_init_done", False):
                 raise FrozenPlainDataError(f"cannot assign to field {name!r}")
@@ -100,16 +104,12 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
                 object.__setattr__(self, "__plain_data_init_done", True)
 
         add_method_or_error(__init__, replace=True)
+    else:
+        old_init = None
 
     # set __slots__ to have everything we need in the preferred order
     retval_dict["__slots__"] = tuple(slots.keys())
 
-    def __dir__(self):
-        # don't return fields un-copied so users can't mess with it
-        return fields.copy()
-
-    add_method_or_error(__dir__)
-
     def __getstate__(self):
         # pickling support
         return [getattr(self, name) for name in fields]
@@ -186,9 +186,9 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
     # add __qualname__
     retval.__qualname__ = cls.__qualname__
 
-    # fixup super() and __class__
-    # derived from: https://stackoverflow.com/a/71666065/2597900
-    for value in retval.__dict__.values():
+    def fix_super_and_class(value):
+        # fixup super() and __class__
+        # derived from: https://stackoverflow.com/a/71666065/2597900
         try:
             closure = value.__closure__
             if isinstance(closure, tuple):
@@ -197,11 +197,18 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
         except (AttributeError, IndexError):
             pass
 
+    for value in retval.__dict__.values():
+        fix_super_and_class(value)
+
+    if old_init is not None:
+        fix_super_and_class(old_init)
+
     return retval
 
 
-def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True,
+def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
                frozen=False):
+    # defaults match dataclass, with the exception of `init`
     """ Decorator for adding equality comparison, ordered comparison,
     `repr` support, `hash` support, and frozen type (read-only fields)
     support to classes that are just plain data.