From 1b02a4882e2eaacd1a2fdb31f9ee302e346b33d1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 4 Aug 2022 23:15:30 -0700 Subject: [PATCH] add partial_prefix_sum_ops --- src/nmutil/prefix_sum.py | 42 +++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/src/nmutil/prefix_sum.py b/src/nmutil/prefix_sum.py index 3908d1b..fc2f3d1 100644 --- a/src/nmutil/prefix_sum.py +++ b/src/nmutil/prefix_sum.py @@ -220,11 +220,37 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False, return "\n".join(map(str.rstrip, lines)) -def tree_reduction_ops(item_count): - assert isinstance(item_count, int) and item_count >= 1 - ops = list(prefix_sum_ops(item_count=item_count)) - items_live_flags = [False] * item_count - items_live_flags[-1] = True +def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False): + """ Get the associative operations needed to compute a parallel prefix-sum + of `len(needed_outputs)` items. + + The operations aren't assumed to be commutative. + + This has a depth of `O(log(N))` and an operation count of `O(N)` if + `work_efficient` is true, otherwise `O(N*log(N))`. + + The algorithms used are derived from: + https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel + https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient + + Parameters: + needed_outputs: Iterable[bool] + The length is the number of input/output items. + Each item is True if that corresponding output is needed. + Unneeded outputs have unspecified value. + work_efficient: bool + True if the algorithm used should be work-efficient -- has a larger + depth (about twice as large) but does only `O(N)` operations total + instead of `O(N*log(N))`. + Returns: Iterable[Op] + output associative operations. + """ + def assert_bool(v): + assert isinstance(v, bool) + return v + items_live_flags = [assert_bool(i) for i in needed_outputs] + ops = list(prefix_sum_ops(item_count=len(items_live_flags), + work_efficient=work_efficient)) ops_live_flags = [False] * len(ops) for i in reversed(range(len(ops))): op = ops[i] @@ -238,6 +264,12 @@ def tree_reduction_ops(item_count): yield op +def tree_reduction_ops(item_count): + assert isinstance(item_count, int) and item_count >= 1 + needed_outputs = (i == item_count - 1 for i in range(item_count)) + return partial_prefix_sum_ops(needed_outputs) + + def tree_reduction(items, fn=operator.add): items = list(items) for op in tree_reduction_ops(len(items)): -- 2.30.2