8455039ce298d31637bd6b8c584bc3840c53b5ec
[mdis.git] / src / mdis / dispatcher.py
# Names exported by ``from mdis.dispatcher import *``.
__all__ = ["Dispatcher", "DispatcherMeta", "Hook"]
6
7 import collections
8 import functools
9 import inspect
10 import types
11
12
class Hook(object):
    """Marker pairing a dispatcher method with the typeids it handles.

    ``Hook(*typeids)`` acts as a decorator factory: applying the instance to
    a function returns a concrete hook that carries the same typeids and,
    when invoked as ``hook(dispatcher, node, *arguments)``, calls the
    wrapped function.

    Every typeid must be callable (a class, or a predicate callable).
    """

    def __init__(self, *typeids):
        """Store *typeids*; raise ``ValueError`` for any non-callable one."""
        for typeid in typeids:
            if not callable(typeid):
                raise ValueError(typeid)
        self.__typeids = typeids
        super().__init__()

    def __iter__(self):
        """Yield the registered typeids in declaration order."""
        yield from self.__typeids

    def __repr__(self):
        """Render as ``<name, ...>``; non-builtin names are module-qualified."""
        names = []
        for typeid in self.__typeids:
            name = getattr(typeid, "__qualname__", None)
            if name is None:
                # Callable instances (e.g. functools.partial objects) have no
                # __qualname__; fall back to their own repr instead of raising.
                names.append(repr(typeid))
                continue
            module = getattr(typeid, "__module__", None)
            if module not in (None, "builtins"):
                name = f"{module}.{name}"
            names.append(name)
        return f"<{', '.join(names)}>"

    def __call__(self, call):
        """Wrap *call* in a concrete hook registered for the same typeids.

        The returned hook is invoked as ``hook(dispatcher, node, *arguments)``.
        The extra *arguments* are forwarded only when *call* can accept more
        than two positional arguments; otherwise *call* receives just
        ``(dispatcher, node)`` and the extras are dropped.
        """
        # Introspect once, at decoration time, rather than on every dispatch.
        forward = self._accepts_extra_arguments(call)

        class ConcreteHook(Hook):
            def __call__(self, dispatcher, node, *arguments):
                if forward:
                    return call(dispatcher, node, *arguments)
                return call(dispatcher, node)

        return ConcreteHook(*self)

    @staticmethod
    def _accepts_extra_arguments(call):
        """Return True if *call* can take more than two positional arguments.

        A ``*args`` parameter always qualifies. Keyword-only parameters and
        ``**kwargs`` do not, since the extras are passed positionally.
        Callables whose signature cannot be introspected are assumed to
        accept them (calling a non-callable then fails at dispatch time with
        ``TypeError``, as before).
        """
        try:
            parameters = inspect.signature(call).parameters.values()
        except (TypeError, ValueError):
            return True
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional = 0
        for parameter in parameters:
            if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
                return True
            if parameter.kind in positional_kinds:
                positional += 1
        return positional > 2
43
44
class DispatcherMeta(type):
    """Metaclass that builds a per-class dispatch table from Hook members.

    Each class created by this metaclass receives a read-only ``__hooks__``
    mapping from typeid (a class, or another callable used as a predicate
    over classes) to the hook registered for it.
    """

    # Class-level default; every class built by ``__new__`` shadows it with
    # its own read-only mapping.
    __hooks__ = {}

    def __new__(metacls, name, bases, ns, **kwargs):
        """Create the class and attach its ``__hooks__`` table.

        Hooks are inherited from *bases* (earlier bases take priority) and
        then overridden by the hooks declared in the class body *ns*. Any
        class keyword arguments are passed on to ``type.__new__`` and so
        reach ``__init_subclass__``.

        Raises:
            ValueError: if two different attributes of the class body
                register the same typeid.
        """
        def ishook(member):
            return isinstance(member, Hook)

        hooks = {}

        # Apply bases last-to-first so the first base wins on overlap.
        for basecls in reversed(bases):
            if isinstance(basecls, DispatcherMeta):
                # The base already resolved its table, including its own
                # overrides of inherited hooks; reuse it verbatim. Rebuilding
                # it via inspect.getmembers() visits members alphabetically
                # and can let a hook the base overrode win again.
                hooks.update(basecls.__hooks__)
            else:
                # Plain (non-dispatcher) base such as a mixin: collect any
                # Hook attributes it exposes.
                for (_, hook) in inspect.getmembers(basecls, predicate=ishook):
                    hooks.update(dict.fromkeys(hook, hook))

        # Hooks declared in this class body override inherited ones. Record
        # which attribute names claim each typeid to detect clashes.
        conflicts = collections.defaultdict(list)
        for (key, value) in ns.items():
            if not ishook(value):
                continue
            # dict.fromkeys() de-duplicates while preserving order, so a hook
            # listing the same typeid twice does not conflict with itself.
            for typeid in dict.fromkeys(value):
                hooks[typeid] = value
                conflicts[typeid].append(key)

        for (typeid, keys) in conflicts.items():
            if len(keys) > 1:
                raise ValueError(f"dispatch conflict: {keys!r}")

        ns["__hooks__"] = types.MappingProxyType(hooks)

        return super().__new__(metacls, name, bases, ns, **kwargs)

    # NOTE: results are memoised per (cls, typeid) with no size bound and are
    # never invalidated; the cache holds strong references to every
    # dispatcher class and typeid it has been called with.
    @functools.lru_cache(maxsize=None)
    def dispatch(cls, typeid=object):
        """Return the hook ``cls`` registers for *typeid*, or ``None``.

        An exact key match in ``cls.__hooks__`` wins. Otherwise each key
        that is not a class is called as a predicate with *typeid*, in
        ``__hooks__`` order, and the first truthy result selects its hook.
        """
        hook = cls.__hooks__.get(typeid)
        if hook is not None:
            return hook
        for (checker, hook) in cls.__hooks__.items():
            # Class keys only match exactly (handled above); any other key
            # is a predicate over the class being dispatched.
            if not isinstance(checker, type) and checker(typeid):
                return hook
        return None
84
85
class Dispatcher(metaclass=DispatcherMeta):
    """Base dispatcher: calling an instance routes *node* to a hook.

    The hook is chosen by walking the MRO of ``node``'s class from most to
    least derived and taking the first class for which ``dispatch`` returns
    a hook; the hook is then called as ``hook(self, node, *arguments)``.
    """

    def __call__(self, node, *arguments):
        lookup = self.__class__.dispatch
        # Lazily evaluated: stops querying once a hook has been found.
        candidates = (lookup(typeid=typeid) for typeid in node.__class__.__mro__)
        hook = next(
            (candidate for candidate in candidates if candidate is not None),
            None,
        )
        if hook is None:
            # Last resort: the default lookup (typeid=object).
            hook = lookup()
        return hook(self, node, *arguments)

    @Hook(object)
    def dispatch_object(self, node, *arguments):
        """Catch-all hook for any node; the base class does not implement it."""
        raise NotImplementedError()