speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / prefix_sum.py
index 549d772d78dc25aaa13aa8cf003e8cc90416971b..23eca36e2bb748c296c5a7ca88b9fa578258c653 100644 (file)
@@ -7,22 +7,33 @@
 from collections import defaultdict
 import operator
 from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
 
 
+@plain_data(order=True, unsafe_hash=True, frozen=True)
 class Op:
     """An associative operation in a prefix-sum.
     The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
     The operation is not assumed to be commutative.
     """
-    def __init__(self,*, out, lhs, rhs, row):
-        self.out = out; "index of the item to output to"
-        self.lhs = lhs; "index of item the left-hand-side input comes from"
-        self.rhs = rhs; "index of item the right-hand-side input comes from"
-        self.row = row; "row in the prefix-sum diagram"
+    __slots__ = "out", "lhs", "rhs", "row"
+
+    def __init__(self, out, lhs, rhs, row):
+        self.out = out
+        "index of the item to output to"
+
+        self.lhs = lhs
+        "index of the item the left-hand-side input comes from"
+
+        self.rhs = rhs
+        "index of the item the right-hand-side input comes from"
+
+        self.row = row
+        "row in the prefix-sum diagram"
 
 
 def prefix_sum_ops(item_count, *, work_efficient=False):
-    """ Get the associative operations needed to compute a parallel prefix-sum
+    """Get the associative operations needed to compute a parallel prefix-sum
     of `item_count` items.
 
     The operations aren't assumed to be commutative.
@@ -44,8 +55,9 @@ def prefix_sum_ops(item_count, *, work_efficient=False):
     Returns: Iterable[Op]
         output associative operations.
     """
+    assert isinstance(item_count, int)
     # compute the partial sums using a set of binary trees
-    # first half of the work-efficient algorithm and the whole of
+    # this is the first half of the work-efficient algorithm and the whole of
     # the non-work-efficient algorithm.
     dist = 1
     row = 0
@@ -67,7 +79,7 @@ def prefix_sum_ops(item_count, *, work_efficient=False):
 
 
 def prefix_sum(items, fn=operator.add, *, work_efficient=False):
-    """ Compute the parallel prefix-sum of `items`, using associative operator
+    """Compute the parallel prefix-sum of `items`, using associative operator
     `fn` instead of addition.
 
     This has a depth of `O(log(N))` and an operation count of `O(N)` if
@@ -96,8 +108,11 @@ def prefix_sum(items, fn=operator.add, *, work_efficient=False):
     return items
 
 
+@plain_data()
 class _Cell:
-    def __init__(self, *, slant, plus, tee):
+    __slots__ = "slant", "plus", "tee"
+
+    def __init__(self, slant, plus, tee):
         self.slant = slant
         self.plus = plus
         self.tee = tee
@@ -138,6 +153,8 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     """
     ops_by_row = defaultdict(set)
     for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
+        assert op.out == op.rhs, f"can't draw op: {op}"
+        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
         ops_by_row[op.row].add(op)
 
     def blank_row():
@@ -151,6 +168,7 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
         max_distance = max(op.rhs - op.lhs for op in ops)
         cells.extend(blank_row() for _ in range(max_distance))
         for op in ops:
+            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
             y = len(cells) - 1
             x = op.out
             cells[y][x].plus = True
@@ -234,7 +252,10 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
     Returns: Iterable[Op]
         output associative operations.
     """
-    items_live_flags = needed_outputs
+
+    # needed_outputs is an iterable, we need to construct a new list so we
+    # don't modify the passed-in value
+    items_live_flags = [bool(i) for i in needed_outputs]
     ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                               work_efficient=work_efficient))
     ops_live_flags = [False] * len(ops)
@@ -251,6 +272,7 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
 
 
 def tree_reduction_ops(item_count):
+    assert item_count >= 1
     needed_outputs = (i == item_count - 1 for i in range(item_count))
     return partial_prefix_sum_ops(needed_outputs)
 
@@ -262,20 +284,6 @@ def tree_reduction(items, fn=operator.add):
     return items[-1]
 
 
-def pop_count(v, *, width=None, process_temporary=lambda v: v):
-    if isinstance(v, Value):
-        if width is None:
-            width = len(v)
-        bits = [v[i] for i in range(width)]
-        if len(bits) == 0:
-            return Const(0)
-    else:
-        bits = [(v & (1 << i)) != 0 for i in range(width)]
-        if len(bits) == 0:
-            return 0
-    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
-
-
 if __name__ == "__main__":
     print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
           "\n"