add tree_reduction and pop_count based off of dead-code-elimination of prefix_sum_ops
[nmutil.git] / src / nmutil / prefix_sum.py
index 4aa89b2a3239fd3edf55792f236b30b3e486f60a..3908d1b5495c402059dca54040827e9fae941ee6 100644 (file)
@@ -7,6 +7,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 import operator
+from nmigen.hdl.ast import Value, Const
 
 
 @dataclass(order=True, unsafe_hash=True, frozen=True)
@@ -219,6 +220,90 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     return "\n".join(map(str.rstrip, lines))
 
 
+def tree_reduction_ops(item_count):
+    """Yield the subset of `prefix_sum_ops` needed to compute the last item.
+
+    The full op list from `prefix_sum_ops(item_count=item_count)` is walked
+    backwards with only the final item marked live; ops whose output is not
+    live at that point are dropped.  The surviving ops are yielded in their
+    original order.
+
+    item_count: int
+        number of items being reduced, must be at least 1.
+    """
+    assert isinstance(item_count, int) and item_count >= 1, \
+        "item_count must be an int >= 1"
+    ops = list(prefix_sum_ops(item_count=item_count))
+    # only the last item is wanted once every op has run
+    live_items = [False] * item_count
+    live_items[-1] = True
+    live_ops = []
+    for op in reversed(ops):
+        out_live = live_items[op.out]
+        # clear the output before marking the inputs, so that an input
+        # sharing the output's index ends up live exactly when the op is
+        live_items[op.out] = False
+        live_items[op.lhs] |= out_live
+        live_items[op.rhs] |= out_live
+        live_ops.append(out_live)
+    live_ops.reverse()
+    for op, is_live in zip(ops, live_ops):
+        if is_live:
+            yield op
+
+
+def tree_reduction(items, fn=operator.add):
+    """Reduce `items` to one value by running `tree_reduction_ops`.
+
+    items: iterable
+        the values to combine; copied into a list, so the argument itself
+        is never modified.  Must contain at least one item.
+    fn: function of (item, item) -> item
+        binary operation used to combine two values, defaults to addition.
+
+    Returns the last slot of the working list after every op has run.
+    """
+    values = list(items)
+    for op in tree_reduction_ops(len(values)):
+        values[op.out] = fn(values[op.lhs], values[op.rhs])
+    return values[-1]
+
+
+def pop_count(v, *, width=None, process_temporary=lambda v: v):
+    """Return the population count (number of 1 bits) of `v`.
+
+    v: nmigen Value | int
+        the value to count the set bits of.
+    width: int | None
+        the number of low bits of `v` to examine.  Required when `v` is an
+        int; when `v` is a `Value` it may be omitted and must otherwise
+        equal `len(v)`.
+    process_temporary: function of (sum) -> sum
+        called on the result of every addition; the returned value is used
+        in place of that sum.  Defaults to the identity function.
+
+    Returns `Const(0)` for a zero-width `Value`, `0` for `width == 0` with
+    an int, otherwise the tree-reduced sum of the individual bits.
+    """
+    if isinstance(v, Value):
+        if width is None:
+            width = len(v)
+        assert width == len(v), "width must match len(v)"
+        bits = [v[i] for i in range(width)]
+        if not bits:
+            return Const(0)
+    else:
+        assert isinstance(width, int) and width >= 0, \
+            "width must be a non-negative int when v is an int"
+        assert isinstance(v, int), "v must be a nmigen Value or an int"
+        # extract each bit as an int 0/1 rather than a bool, so that the
+        # result is an int even when width == 1 and no addition happens
+        bits = [(v >> i) & 1 for i in range(width)]
+        if not bits:
+            return 0
+    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
+
+
 if __name__ == "__main__":
     print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
           "\n"