# of Horizon 2020 EU Programme 957073.
from collections import defaultdict
-from dataclasses import dataclass
import operator
+from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
-@dataclass(order=True, unsafe_hash=True, frozen=True)
+@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
"""An associative operation in a prefix-sum.
The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
The operation is not assumed to be commutative.
"""
- out: int
- """index of the item to output to"""
- lhs: int
- """index of the item the left-hand-side input comes from"""
- rhs: int
- """index of the item the right-hand-side input comes from"""
- row: int
- """row in the prefix-sum diagram"""
+ __slots__ = "out", "lhs", "rhs", "row"
+
+ def __init__(self, out, lhs, rhs, row):
+ self.out = out
+ "index of the item to output to"
+
+ self.lhs = lhs
+ "index of the item the left-hand-side input comes from"
+
+ self.rhs = rhs
+ "index of the item the right-hand-side input comes from"
+
+ self.row = row
+ "row in the prefix-sum diagram"
def prefix_sum_ops(item_count, *, work_efficient=False):
- """ Get the associative operations needed to compute a parallel prefix-sum
+ """Get the associative operations needed to compute a parallel prefix-sum
of `item_count` items.
The operations aren't assumed to be commutative.
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
- """ Compute the parallel prefix-sum of `items`, using associative operator
+ """Compute the parallel prefix-sum of `items`, using associative operator
`fn` instead of addition.
This has a depth of `O(log(N))` and an operation count of `O(N)` if
return items
-@dataclass
+@plain_data()
class _Cell:
- slant: bool
- plus: bool
- tee: bool
+ __slots__ = "slant", "plus", "tee"
+
+ def __init__(self, slant, plus, tee):
+ self.slant = slant
+ self.plus = plus
+ self.tee = tee
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
Returns: str
rendered diagram
"""
- assert isinstance(item_count, int)
- assert isinstance(padding, int)
ops_by_row = defaultdict(set)
for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
assert op.out == op.rhs, f"can't draw op: {op}"
return "\n".join(map(str.rstrip, lines))
+def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
+ """ Get the associative operations needed to compute a parallel prefix-sum
+ of `len(needed_outputs)` items.
+
+ The operations aren't assumed to be commutative.
+
+ This has a depth of `O(log(N))` and an operation count of `O(N)` if
+ `work_efficient` is true, otherwise `O(N*log(N))`.
+
+ The algorithms used are derived from:
+ https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
+ https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
+
+ Parameters:
+ needed_outputs: Iterable[bool]
+ The length is the number of input/output items.
+ Each item is True if that corresponding output is needed.
+ Unneeded outputs have unspecified value.
+ work_efficient: bool
+ True if the algorithm used should be work-efficient -- has a larger
+ depth (about twice as large) but does only `O(N)` operations total
+ instead of `O(N*log(N))`.
+ Returns: Iterable[Op]
+ output associative operations.
+ """
+
+ # needed_outputs is an iterable, we need to construct a new list so we
+ # don't modify the passed-in value
+ items_live_flags = [bool(i) for i in needed_outputs]
+ ops = list(prefix_sum_ops(item_count=len(items_live_flags),
+ work_efficient=work_efficient))
+ ops_live_flags = [False] * len(ops)
+ for i in reversed(range(len(ops))):
+ op = ops[i]
+ out_live = items_live_flags[op.out]
+ items_live_flags[op.out] = False
+ items_live_flags[op.lhs] |= out_live
+ items_live_flags[op.rhs] |= out_live
+ ops_live_flags[i] = out_live
+ for op, live_flag in zip(ops, ops_live_flags):
+ if live_flag:
+ yield op
+
+
+def tree_reduction_ops(item_count):
+ assert item_count >= 1
+ needed_outputs = (i == item_count - 1 for i in range(item_count))
+ return partial_prefix_sum_ops(needed_outputs)
+
+
+def tree_reduction(items, fn=operator.add):
+ items = list(items)
+ for op in tree_reduction_ops(len(items)):
+ items[op.out] = fn(items[op.lhs], items[op.rhs])
+ return items[-1]
+
+
if __name__ == "__main__":
print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
"\n"