dispatcher: support arbitrary callables
[mdis.git] / src / mdis / dispatcher.py
1 import collections as _collections
2 import inspect as _inspect
3 import types as _types
4
5
class Hook(object):
    """Bind one or more dispatch keys to a handler.

    Every key must be callable: it is either a type or an arbitrary
    callable (which ``DispatcherMeta.dispatch`` calls as a predicate on a
    type). Used as a decorator, ``Hook(*keys)(func)`` returns a hook object
    for the same keys whose call forwards to
    ``func(self=dispatcher, instance=instance)``.

    Raises:
        ValueError: if any key is not callable.
    """

    def __init__(self, *typeids):
        for typeid in typeids:
            if not callable(typeid):
                raise ValueError(typeid)
        self.__typeids = typeids
        super().__init__()

    def __iter__(self):
        """Yield the dispatch keys in the order they were given."""
        yield from self.__typeids

    def __repr__(self):
        """Render the keys as ``<name, ...>``, qualifying non-builtin names."""
        names = []
        for typeid in self.__typeids:
            name = getattr(typeid, "__qualname__", None)
            if name is None:
                # Arbitrary callables (callable instances, functools.partial,
                # ...) have no __qualname__; fall back to their own repr
                # instead of raising AttributeError.
                names.append(repr(typeid))
                continue
            module = getattr(typeid, "__module__", None)
            if module not in (None, "builtins"):
                name = f"{module}.{name}"
            names.append(name)
        return f"<{', '.join(names)}>"

    def __call__(self, call):
        """Wrap ``call`` as a hook for the same keys.

        The returned hook invokes ``call`` with keyword arguments ``self``
        (the dispatcher) and ``instance``, so ``call`` must use those
        parameter names.
        """
        class ConcreteHook(Hook):
            def __call__(self, dispatcher, instance):
                return call(self=dispatcher, instance=instance)

        return ConcreteHook(*tuple(self))
33
34
class DispatcherMeta(type):
    """Metaclass that builds a per-class dispatch table from ``Hook`` members.

    Each class created with this metaclass gets a read-only ``__hooks__``
    mapping from every hook key (a type or a predicate callable) to the
    ``Hook`` registered for it.
    """

    # Fallback only: every class built by this metaclass gets its own table.
    __hooks__ = {}

    def __new__(metacls, name, bases, ns):
        """Create the class and attach its ``__hooks__`` table.

        Raises:
            ValueError: if the class body binds the same key under more
                than one attribute name.
        """
        conflicts = _collections.defaultdict(list)
        for (key, value) in ns.items():
            if isinstance(value, Hook):
                for typeid in value:
                    conflicts[typeid].append(key)
        for (typeid, keys) in conflicts.items():
            if len(keys) > 1:
                raise ValueError(f"dispatch conflict: {keys!r}")

        cls = super().__new__(metacls, name, bases, ns)

        # Walk the real MRO from least to most derived, so a hook declared
        # in a more-derived class overrides one for the same key in an
        # ancestor. (Previously inspect.getmembers was used, whose results
        # are sorted by attribute name, letting an ancestor's hook win for
        # grandchildren depending on alphabetical order.)
        # NOTE(review): the table is attached after type.__new__, so a
        # base's __init_subclass__ still sees the parent's table.
        hooks = {}
        for klass in reversed(cls.__mro__):
            for member in vars(klass).values():
                if isinstance(member, Hook):
                    hooks.update(dict.fromkeys(member, member))
        cls.__hooks__ = _types.MappingProxyType(hooks)
        return cls

    def dispatch(cls, typeid=object):
        """Return the hook registered for ``typeid``, or ``None``.

        An exact key match wins. Otherwise every key that is not a type is
        called as a predicate with ``typeid``, in table order, and the hook
        of the first truthy predicate is returned.
        """
        hook = cls.__hooks__.get(typeid)
        if hook is not None:
            return hook
        for (checker, hook) in cls.__hooks__.items():
            if not isinstance(checker, type) and checker(typeid):
                return hook
        return None
73
74
class Dispatcher(metaclass=DispatcherMeta):
    """Base class that routes an instance to the matching ``Hook`` handler.

    Subclasses declare handlers with ``@Hook(...)``. Calling a dispatcher
    looks up each class in the instance's MRO in turn and invokes the
    first hook found.
    """

    def __call__(self, instance):
        """Dispatch ``instance`` and return the selected handler's result."""
        owner = self.__class__
        candidates = (
            owner.dispatch(typeid=klass)
            for klass in instance.__class__.__mro__
        )
        hook = next((found for found in candidates if found is not None), None)
        if hook is None:
            # Last resort: the default lookup key of ``dispatch``.
            hook = owner.dispatch()
        return hook(dispatcher=self, instance=instance)

    @Hook(object)
    def dispatch_object(self, instance):
        """Catch-all handler for ``object``; always raises NotImplementedError."""
        raise NotImplementedError()