all: simplify walking
[mdis.git] / src / mdis / dispatcher.py
# Names exported by ``from mdis.dispatcher import *``.
__all__ = ["Dispatcher", "DispatcherMeta", "Hook"]
6
7 import collections
8 import functools
9 import inspect
10 import types
11
12
class Hook(object):
    """Decorator binding a dispatcher method to one or more type ids.

    A type id is any callable. Decorating a function with ``Hook(*typeids)``
    returns a ``Hook`` instance (a ``ConcreteHook``) that remembers the same
    type ids and, when called as ``hook(dispatcher, node, ...)``, invokes the
    decorated function.
    """

    def __init__(self, *typeids):
        """Store *typeids*; raise ValueError for any non-callable one."""
        for typeid in typeids:
            if not callable(typeid):
                raise ValueError(typeid)
        self.__typeids = typeids
        super().__init__()

    def __iter__(self):
        """Yield the type ids this hook is registered for."""
        yield from self.__typeids

    def __repr__(self):
        """Render the type ids as ``<name, module.name, ...>``."""
        names = []
        for typeid in self.__typeids:
            name = getattr(typeid, "__qualname__", None)
            if name is None:
                # Callable instances (and e.g. functools.partial objects) have
                # no __qualname__; fall back to their own repr instead of
                # raising AttributeError.
                names.append(repr(typeid))
                continue
            module = getattr(typeid, "__module__", None)
            if module not in (None, "builtins"):
                name = f"{module}.{name}"
            names.append(name)
        return f"<{', '.join(names)}>"

    @staticmethod
    def _forwarding(call):
        """Validate *call*'s signature and decide which extras it receives.

        Returns a ``(pass_args, pass_kwargs)`` pair: whether extra positional
        and extra keyword arguments should be passed on to *call*.

        Raises TypeError if *call* cannot take ``(self, node)`` positionally.
        """
        Parameter = inspect.Parameter
        parameters = tuple(inspect.signature(call).parameters.values())
        if len(parameters) < 2:
            raise TypeError(f"{call.__name__}: missing required arguments")
        positional = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
        if parameters[0].kind not in positional:
            raise TypeError(f"{call.__name__}: incorrect self argument")
        if parameters[1].kind not in positional:
            raise TypeError(f"{call.__name__}: incorrect node argument")
        # Only parameters after (self, node) matter here: those two are always
        # positional and must not make the hook look like it accepts *args.
        kinds = {parameter.kind for parameter in parameters[2:]}
        pass_args = not kinds.isdisjoint((
            Parameter.POSITIONAL_ONLY,
            Parameter.POSITIONAL_OR_KEYWORD,
            Parameter.VAR_POSITIONAL,
        ))
        # A POSITIONAL_OR_KEYWORD parameter can be filled either way, so it
        # enables both kinds of forwarding.
        pass_kwargs = not kinds.isdisjoint((
            Parameter.POSITIONAL_OR_KEYWORD,
            Parameter.KEYWORD_ONLY,
            Parameter.VAR_KEYWORD,
        ))
        return (pass_args, pass_kwargs)

    def __call__(self, call):
        """Wrap *call* into a hook registered for this hook's type ids."""

        class ConcreteHook(Hook):
            @functools.wraps(call)
            def __call__(self, dispatcher, node, *args, **kwargs):
                # We do not force specific arguments other than node.
                # API users can introduce additional *args and **kwargs.
                # However, in case they choose not to, this is fine too:
                # extras the hook cannot accept are dropped.
                (pass_args, pass_kwargs) = Hook._forwarding(call)
                if not pass_args:
                    args = ()
                if not pass_kwargs:
                    kwargs = {}
                return call(dispatcher, node, *args, **kwargs)

        return ConcreteHook(*self)
74
75
class DispatcherMeta(type):
    """Metaclass that collects ``Hook`` members into a dispatch table.

    Every class it creates receives a read-only ``__hooks__`` mapping from
    type id to hook, combining the tables inherited from its bases with the
    hooks defined in the class body.
    """

    # Default (empty) table; each class built by this metaclass replaces it
    # with its own read-only mapping in __new__.
    __hooks__ = {}

    def __new__(metacls, name, bases, ns):
        hooks = {}

        def ishook(member):
            return isinstance(member, Hook)

        # Later bases are merged first, so entries from earlier bases win.
        for basecls in reversed(bases):
            if isinstance(basecls, DispatcherMeta):
                # Reuse the base's already-resolved table. Rebuilding it via
                # inspect.getmembers() (sorted by attribute name) could let a
                # hook the base overrode under a different name win again.
                hooks.update(basecls.__hooks__)
            else:
                # Plain mixins have no table; collect their Hook attributes.
                for (_, hook) in inspect.getmembers(basecls, predicate=ishook):
                    hooks.update(dict.fromkeys(hook, hook))

        # Hooks defined directly in the class body override inherited ones;
        # two body hooks claiming the same type id are an error.
        conflicts = collections.defaultdict(list)
        for (key, value) in ns.items():
            if not ishook(value):
                continue
            for typeid in value:
                hooks[typeid] = value
                conflicts[typeid].append(key)

        for (typeid, keys) in conflicts.items():
            if len(keys) > 1:
                raise ValueError(f"dispatch conflict: {keys!r}")

        ns["__hooks__"] = types.MappingProxyType(hooks)

        return super().__new__(metacls, name, bases, ns)

    # NOTE: the unbounded cache is keyed on (cls, typeid) and keeps those
    # objects alive for the life of the process.
    @functools.lru_cache(maxsize=None)
    def dispatch(cls, typeid=object):
        """Return the hook for *typeid*, or None if nothing matches.

        An exact key match in ``__hooks__`` wins; otherwise the first type id
        that is not a class is called as a predicate with *typeid*, in table
        order, and the first truthy result selects its hook.
        """
        hook = cls.__hooks__.get(typeid)
        if hook is not None:
            return hook
        for (checker, hook) in cls.__hooks__.items():
            if not isinstance(checker, type) and checker(typeid):
                return hook
        return None
115
116
class Dispatcher(metaclass=DispatcherMeta):
    """Base class that routes each node to the hook registered for its type."""

    def __call__(self, node, *args, **kwargs):
        """Dispatch *node* to the hook for the most specific class in its MRO.

        Extra positional and keyword arguments are forwarded to the hook
        (previously they were accepted here but silently discarded).
        """
        dispatch = self.__class__.dispatch
        for typeid in node.__class__.__mro__:
            hook = dispatch(typeid=typeid)
            if hook is not None:
                break
        else:
            # Nothing matched any class in the MRO: use the default lookup.
            hook = dispatch()
        return hook(self, node, *args, **kwargs)

    @Hook(object)
    def dispatch_object(self, node, *args, **kwargs):
        # Fallback for nodes no more specific hook handles. It accepts (and
        # ignores) extra arguments so that unhandled nodes always raise
        # NotImplementedError rather than a signature TypeError.
        raise NotImplementedError()