add tree_reduction and pop_count based off of dead-code-elimination of prefix_sum_ops
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 5 Aug 2022 05:53:08 +0000 (22:53 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 5 Aug 2022 05:53:08 +0000 (22:53 -0700)
src/nmutil/prefix_sum.py
src/nmutil/test/test_prefix_sum.py

index 4aa89b2a3239fd3edf55792f236b30b3e486f60a..3908d1b5495c402059dca54040827e9fae941ee6 100644 (file)
@@ -7,6 +7,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 import operator
+from nmigen.hdl.ast import Value, Const
 
 
 @dataclass(order=True, unsafe_hash=True, frozen=True)
@@ -219,6 +220,48 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     return "\n".join(map(str.rstrip, lines))
 
 
+def tree_reduction_ops(item_count):
+    assert isinstance(item_count, int) and item_count >= 1
+    ops = list(prefix_sum_ops(item_count=item_count))
+    items_live_flags = [False] * item_count
+    items_live_flags[-1] = True
+    ops_live_flags = [False] * len(ops)
+    for i in reversed(range(len(ops))):
+        op = ops[i]
+        out_live = items_live_flags[op.out]
+        items_live_flags[op.out] = False
+        items_live_flags[op.lhs] |= out_live
+        items_live_flags[op.rhs] |= out_live
+        ops_live_flags[i] = out_live
+    for op, live_flag in zip(ops, ops_live_flags):
+        if live_flag:
+            yield op
+
+
+def tree_reduction(items, fn=operator.add):
+    items = list(items)
+    for op in tree_reduction_ops(len(items)):
+        items[op.out] = fn(items[op.lhs], items[op.rhs])
+    return items[-1]
+
+
+def pop_count(v, *, width=None, process_temporary=lambda v: v):
+    if isinstance(v, Value):
+        if width is None:
+            width = len(v)
+        assert width == len(v)
+        bits = [v[i] for i in range(width)]
+        if len(bits) == 0:
+            return Const(0)
+    else:
+        assert isinstance(width, int) and width >= 0
+        assert isinstance(v, int)
+        bits = [(v & (1 << i)) != 0 for i in range(width)]
+        if len(bits) == 0:
+            return 0
+    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
+
+
 if __name__ == "__main__":
     print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
           "\n"
index 63aa68ee94c998102f40d3c16ab161c40f08dcb3..7f1c45a6d0503f5fd012248ce9ecde72f1f3be07 100644 (file)
@@ -4,11 +4,17 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from functools import reduce
 from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import write_il
 from itertools import accumulate
 import operator
-from nmutil.prefix_sum import prefix_sum, render_prefix_sum_diagram
+from nmutil.prefix_sum import (Op, pop_count, prefix_sum,
+                               render_prefix_sum_diagram,
+                               tree_reduction, tree_reduction_ops)
 import unittest
+from nmigen.hdl.ast import Signal, AnyConst, Assert
+from nmigen.hdl.dsl import Module
 
 
 def reference_prefix_sum(items, fn):
@@ -28,6 +34,132 @@ class TestPrefixSum(FHDLTestCase):
             work_efficient = prefix_sum(input_items, work_efficient=True)
             self.assertEqual(expected, work_efficient)
 
+    def test_tree_reduction_str(self):
+        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
+        expected = reduce(operator.add, input_items)
+        with self.subTest(expected=repr(expected)):
+            work_efficient = tree_reduction(input_items)
+            self.assertEqual(expected, work_efficient)
+
+    def test_tree_reduction_ops_9(self):
+        ops = list(tree_reduction_ops(9))
+        self.assertEqual(ops, [
+            Op(out=8, lhs=7, rhs=8, row=0),
+            Op(out=6, lhs=5, rhs=6, row=0),
+            Op(out=4, lhs=3, rhs=4, row=0),
+            Op(out=2, lhs=1, rhs=2, row=0),
+            Op(out=8, lhs=6, rhs=8, row=1),
+            Op(out=4, lhs=2, rhs=4, row=1),
+            Op(out=8, lhs=4, rhs=8, row=2),
+            Op(out=8, lhs=0, rhs=8, row=3),
+        ])
+
+    def test_tree_reduction_ops_8(self):
+        ops = list(tree_reduction_ops(8))
+        self.assertEqual(ops, [
+            Op(out=7, lhs=6, rhs=7, row=0),
+            Op(out=5, lhs=4, rhs=5, row=0),
+            Op(out=3, lhs=2, rhs=3, row=0),
+            Op(out=1, lhs=0, rhs=1, row=0),
+            Op(out=7, lhs=5, rhs=7, row=1),
+            Op(out=3, lhs=1, rhs=3, row=1),
+            Op(out=7, lhs=3, rhs=7, row=2),
+        ])
+
+    def tst_pop_count_int(self, width):
+        assert isinstance(width, int)
+        for v in range(1 << width):
+            expected = f"{v:b}".count("1")
+            with self.subTest(v=v, expected=expected):
+                self.assertEqual(expected, pop_count(v, width=width))
+
+    def test_pop_count_int_0(self):
+        self.tst_pop_count_int(0)
+
+    def test_pop_count_int_1(self):
+        self.tst_pop_count_int(1)
+
+    def test_pop_count_int_2(self):
+        self.tst_pop_count_int(2)
+
+    def test_pop_count_int_3(self):
+        self.tst_pop_count_int(3)
+
+    def test_pop_count_int_4(self):
+        self.tst_pop_count_int(4)
+
+    def test_pop_count_int_5(self):
+        self.tst_pop_count_int(5)
+
+    def test_pop_count_int_6(self):
+        self.tst_pop_count_int(6)
+
+    def test_pop_count_int_7(self):
+        self.tst_pop_count_int(7)
+
+    def test_pop_count_int_8(self):
+        self.tst_pop_count_int(8)
+
+    def test_pop_count_int_9(self):
+        self.tst_pop_count_int(9)
+
+    def test_pop_count_int_10(self):
+        self.tst_pop_count_int(10)
+
+    def tst_pop_count_formal(self, width):
+        assert isinstance(width, int)
+        m = Module()
+        v = Signal(width)
+        out = Signal(16)
+
+        def process_temporary(v):
+            sig = Signal.like(v)
+            m.d.comb += sig.eq(v)
+            return sig
+
+        m.d.comb += out.eq(pop_count(v, process_temporary=process_temporary))
+        write_il(self, m, [v, out])
+        m.d.comb += v.eq(AnyConst(width))
+        expected = Signal(16)
+        m.d.comb += expected.eq(reduce(operator.add,
+                                       (v[i] for i in range(width)),
+                                       0))
+        m.d.comb += Assert(out == expected)
+        self.assertFormal(m)
+
+    def test_pop_count_formal_0(self):
+        self.tst_pop_count_formal(0)
+
+    def test_pop_count_formal_1(self):
+        self.tst_pop_count_formal(1)
+
+    def test_pop_count_formal_2(self):
+        self.tst_pop_count_formal(2)
+
+    def test_pop_count_formal_3(self):
+        self.tst_pop_count_formal(3)
+
+    def test_pop_count_formal_4(self):
+        self.tst_pop_count_formal(4)
+
+    def test_pop_count_formal_5(self):
+        self.tst_pop_count_formal(5)
+
+    def test_pop_count_formal_6(self):
+        self.tst_pop_count_formal(6)
+
+    def test_pop_count_formal_7(self):
+        self.tst_pop_count_formal(7)
+
+    def test_pop_count_formal_8(self):
+        self.tst_pop_count_formal(8)
+
+    def test_pop_count_formal_9(self):
+        self.tst_pop_count_formal(9)
+
+    def test_pop_count_formal_10(self):
+        self.tst_pop_count_formal(10)
+
     def test_render_work_efficient(self):
         text = render_prefix_sum_diagram(16, work_efficient=True, plus="@")
         expected = r"""