speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / deduped.py
1 import functools
2 import weakref
3
4
class _KeyBuilder:
    """Accumulate the identity-based cache key for one ``deduped`` call.

    Every object added contributes its ``id()`` to the key. A reference to
    each distinct object is collected as well, so the caller can store it
    next to the cached result:

    * objects that support weak references are held weakly, with
      ``do_delete`` registered as the weakref callback;
    * all other objects (ints, strs, tuples, ``None``, ...) are held
      strongly, which keeps their ``id()`` from being reused while the
      references are kept.
    """

    def __init__(self, do_delete):
        """Create an empty builder.

        ``do_delete`` is registered as the callback of every weak reference
        this builder creates; it receives the dead weakref as its argument.
        """
        self.__keys = []
        # id(obj) -> weakref to obj, or obj itself when it can't be weakly
        # referenced; keyed by id so each object is recorded only once.
        self.__refs = {}
        self.__do_delete = do_delete

    def add_ref(self, v):
        """Record a reference to ``v``, unless ``v`` was already recorded."""
        v_id = id(v)
        if v_id in self.__refs:
            return
        try:
            # The callback must be passed positionally: weakref.ref does not
            # accept it as a keyword, and the resulting TypeError would be
            # swallowed below, silently turning every reference strong.
            v = weakref.ref(v, self.__do_delete)
        except TypeError:
            # v can't be weakly referenced; keep a strong reference instead.
            pass
        self.__refs[v_id] = v

    def add(self, k, v):
        """Append ``id(k)`` and ``id(v)`` to the key and record both."""
        self.__keys.append(id(k))
        self.__keys.append(id(v))
        self.add_ref(k)
        self.add_ref(v)

    def finish(self):
        """Return ``(key, refs)``.

        ``key`` is the tuple of collected ids, in insertion order. ``refs``
        is a tuple of the recorded references (weak or strong), one per
        distinct object.
        """
        return tuple(self.__keys), tuple(self.__refs.values())
29
30
def deduped(*, global_keys=()):
    """Decorator that causes functions to deduplicate their results based on
    their input args and the requested globals. For each set of arguments, it
    will always return the exact same object, by storing it internally.
    Arguments are compared by their identity, so they don't need to be
    hashable.

    The cache key is made from the ``id()`` of each positional argument, of
    each keyword argument's name and value (in the order they were passed, so
    ``f(a=1, b=2)`` and ``f(b=2, a=1)`` are cached separately, as are ``f(1)``
    and ``f(a=1)``), and of the value returned by each callable in
    ``global_keys``. Cached results are held strongly by the decorated
    function's cache.

    Usage:
    ```
    # for functions that don't depend on global variables
    @deduped()
    def my_fn1(a, b, *, c=1):
        return a + b * c

    my_global = 23

    # for functions that depend on global variables
    @deduped(global_keys=[lambda: my_global])
    def my_fn2(a, b, *, c=2):
        return a + b * c + my_global
    ```

    Parameters:
        global_keys: iterable of zero-argument callables. Each one is called
            on every invocation of the decorated function, and its result
            becomes part of the cache key.

    Raises:
        TypeError: if an entry of ``global_keys`` is not callable, or if the
            decorator is applied to a ``staticmethod``/``classmethod`` object
            or to something that isn't callable.
    """
    global_keys = tuple(global_keys)
    # Raise rather than assert, so the check survives ``python -O``.
    if not all(map(callable, global_keys)):
        raise TypeError("every entry of global_keys must be callable")

    def decorator(f):
        if isinstance(f, (staticmethod, classmethod)):
            raise TypeError("@staticmethod or @classmethod should be applied "
                            "to the result of @deduped, not the other way"
                            " around")
        if not callable(f):
            raise TypeError("@deduped can only be applied to a callable")

        # key tuple -> (cached result, references to the key's objects).
        # Named ``cache`` to avoid shadowing the builtin ``map``.
        cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Evaluate the global keys first and hold the results in a local
            # for the whole call. If a callable returns a freshly created
            # object, it must stay alive (so its id can't be reused, and no
            # callback can run before ``key`` is assigned) at least until
            # the cache entry referencing it has been stored.
            global_values = tuple(global_key() for global_key in global_keys)
            # Callback that drops this call's entry; it is handed to
            # _KeyBuilder for the references it creates to the key objects.
            key_builder = _KeyBuilder(lambda _: cache.pop(key, None))
            for arg in args:
                key_builder.add(None, arg)
            for name, value in kwargs.items():
                key_builder.add(name, value)
            for global_value in global_values:
                key_builder.add(None, global_value)
            key, refs = key_builder.finish()
            entry = cache.get(key)
            if entry is not None:  # entries are tuples, never None
                return entry[0]
            retval = f(*args, **kwargs)
            # keep reference to stuff used for key to avoid ids
            # getting reused for something else.
            cache[key] = retval, refs
            return retval
        return wrapper
    return decorator