add treereduce function
[nmutil.git] / src / nmutil / util.py
1 from collections.abc import Iterable
2
# XXX this already exists in nmigen._utils
# see https://bugs.libre-soc.org/show_bug.cgi?id=297
def flatten(v):
    """Recursively flatten nested iterables, yielding the leaf items.

    Anything that is a ``collections.abc.Iterable`` is descended into,
    depth-first; anything else is yielded unchanged.  A non-iterable ``v``
    therefore yields just ``v`` itself.

    ``str`` is treated as a leaf: iterating a string produces
    one-character strings, which are themselves iterable strings, so
    descending into them would recurse without end.

    :param v: a single item or an arbitrarily nested structure of iterables
    :returns: a generator over the leaf items, in depth-first order
    """
    if isinstance(v, Iterable) and not isinstance(v, str):
        for i in v:
            yield from flatten(i)
    else:
        yield v
11
# tree reduction function. operates recursively.
def treereduce(tree, op, attr="data_o"):
    """Reduce a list by combining its elements as a balanced binary tree.

    The attribute named ``attr`` is read from each element of ``tree`` and
    the values are combined with the two-argument function ``op``.  The
    list is split in half, each half is reduced recursively, and the two
    partial results are combined with ``op``.  This gives a tree of ``op``
    applications whose depth grows with log2(len(tree)) rather than a
    linear chain.  Element order is preserved (the left half is always the
    first argument to ``op``), so ``op`` need not be commutative; it should
    be associative for the result to equal a left-to-right reduction.

    :param tree: a list of objects, each having attribute ``attr``.  A
                 non-list argument is returned unchanged, without any
                 attribute lookup.
    :param op: two-argument combining function, e.g. ``operator.or_``
    :param attr: name of the attribute read from each list element
    :returns: the combined value
    :raises ValueError: if ``tree`` is an empty list
    """
    if not isinstance(tree, list):
        return tree
    if not tree:
        # nothing to reduce; without this check the split below would
        # have s == 0 and never shrink the problem
        raise ValueError("treereduce() of empty list")
    if len(tree) == 1:
        return getattr(tree[0], attr)
    if len(tree) == 2:
        return op(getattr(tree[0], attr), getattr(tree[1], attr))
    s = len(tree) // 2  # splitpoint
    # reduce each half independently, then combine the two partial results
    return op(treereduce(tree[:s], op, attr), treereduce(tree[s:], op, attr))