speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / prefix_sum.py
index 4aa89b2a3239fd3edf55792f236b30b3e486f60a..23eca36e2bb748c296c5a7ca88b9fa578258c653 100644 (file)
@@ -5,28 +5,35 @@
 # of Horizon 2020 EU Programme 957073.
 
 from collections import defaultdict
-from dataclasses import dataclass
 import operator
+from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
 
 
-@dataclass(order=True, unsafe_hash=True, frozen=True)
+@plain_data(order=True, unsafe_hash=True, frozen=True)
 class Op:
     """An associative operation in a prefix-sum.
     The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
     The operation is not assumed to be commutative.
     """
-    out: int
-    """index of the item to output to"""
-    lhs: int
-    """index of the item the left-hand-side input comes from"""
-    rhs: int
-    """index of the item the right-hand-side input comes from"""
-    row: int
-    """row in the prefix-sum diagram"""
+    __slots__ = "out", "lhs", "rhs", "row"
+
+    def __init__(self, out, lhs, rhs, row):
+        self.out = out
+        "index of the item to output to"
+
+        self.lhs = lhs
+        "index of the item the left-hand-side input comes from"
+
+        self.rhs = rhs
+        "index of the item the right-hand-side input comes from"
+
+        self.row = row
+        "row in the prefix-sum diagram"
 
 
 def prefix_sum_ops(item_count, *, work_efficient=False):
-    """ Get the associative operations needed to compute a parallel prefix-sum
+    """Get the associative operations needed to compute a parallel prefix-sum
     of `item_count` items.
 
     The operations aren't assumed to be commutative.
@@ -72,7 +79,7 @@ def prefix_sum_ops(item_count, *, work_efficient=False):
 
 
 def prefix_sum(items, fn=operator.add, *, work_efficient=False):
-    """ Compute the parallel prefix-sum of `items`, using associative operator
+    """Compute the parallel prefix-sum of `items`, using associative operator
     `fn` instead of addition.
 
     This has a depth of `O(log(N))` and an operation count of `O(N)` if
@@ -101,11 +108,14 @@ def prefix_sum(items, fn=operator.add, *, work_efficient=False):
     return items
 
 
-@dataclass
+@plain_data()
 class _Cell:
-    slant: bool
-    plus: bool
-    tee: bool
+    __slots__ = "slant", "plus", "tee"
+
+    def __init__(self, slant, plus, tee):
+        self.slant = slant
+        self.plus = plus
+        self.tee = tee
 
 
 def render_prefix_sum_diagram(item_count, *, work_efficient=False,
@@ -141,8 +151,6 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     Returns: str
         rendered diagram
     """
-    assert isinstance(item_count, int)
-    assert isinstance(padding, int)
     ops_by_row = defaultdict(set)
     for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
         assert op.out == op.rhs, f"can't draw op: {op}"
@@ -219,6 +227,63 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     return "\n".join(map(str.rstrip, lines))
 
 
+def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
+    """ Get the associative operations needed to compute a parallel prefix-sum
+    of `len(needed_outputs)` items.
+
+    The operations aren't assumed to be commutative.
+
+    This has a depth of `O(log(N))` and an operation count of `O(N)` if
+    `work_efficient` is true, otherwise `O(N*log(N))`.
+
+    The algorithms used are derived from:
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+    Parameters:
+    needed_outputs: Iterable[bool]
+        The length is the number of input/output items.
+        Each item is True if that corresponding output is needed.
+        Unneeded outputs have unspecified value.
+    work_efficient: bool
+        True if the algorithm used should be work-efficient -- has a larger
+        depth (about twice as large) but does only `O(N)` operations total
+        instead of `O(N*log(N))`.
+    Returns: Iterable[Op]
+        output associative operations.
+    """
+
+    # needed_outputs is an iterable, we need to construct a new list so we
+    # don't modify the passed-in value
+    items_live_flags = [bool(i) for i in needed_outputs]
+    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
+                              work_efficient=work_efficient))
+    ops_live_flags = [False] * len(ops)
+    for i in reversed(range(len(ops))):
+        op = ops[i]
+        out_live = items_live_flags[op.out]
+        items_live_flags[op.out] = False
+        items_live_flags[op.lhs] |= out_live
+        items_live_flags[op.rhs] |= out_live
+        ops_live_flags[i] = out_live
+    for op, live_flag in zip(ops, ops_live_flags):
+        if live_flag:
+            yield op
+
+
+def tree_reduction_ops(item_count):
+    assert item_count >= 1
+    needed_outputs = (i == item_count - 1 for i in range(item_count))
+    return partial_prefix_sum_ops(needed_outputs)
+
+
+def tree_reduction(items, fn=operator.add):
+    items = list(items)
+    for op in tree_reduction_ops(len(items)):
+        items[op.out] = fn(items[op.lhs], items[op.rhs])
+    return items[-1]
+
+
 if __name__ == "__main__":
     print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
           "\n"