from collections import defaultdict
import operator
from nmigen.hdl.ast import Value, Const
+from nmutil.plain_data import plain_data
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.

    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        # out: index of the item to output to
        # lhs: index of the item the left-hand-side input comes from
        # rhs: index of the item the right-hand-side input comes from
        # row: row in the prefix-sum diagram
        self.out = out
        self.lhs = lhs
        self.rhs = rhs
        self.row = row
def prefix_sum_ops(item_count, *, work_efficient=False):
- """ Get the associative operations needed to compute a parallel prefix-sum
+ """Get the associative operations needed to compute a parallel prefix-sum
of `item_count` items.
The operations aren't assumed to be commutative.
Returns: Iterable[Op]
output associative operations.
"""
+ assert isinstance(item_count, int)
# compute the partial sums using a set of binary trees
- # first half of the work-efficient algorithm and the whole of
+ # this is the first half of the work-efficient algorithm and the whole of
# the non-work-efficient algorithm.
dist = 1
row = 0
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
- """ Compute the parallel prefix-sum of `items`, using associative operator
+ """Compute the parallel prefix-sum of `items`, using associative operator
`fn` instead of addition.
This has a depth of `O(log(N))` and an operation count of `O(N)` if
return items
@plain_data()
class _Cell:
    """A single cell of the rendered prefix-sum diagram.

    NOTE(review): fields appear to be per-cell drawing flags (slant stroke,
    '+' operation marker, tee junction) — confirm against the drawing code.
    """
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        self.slant, self.plus, self.tee = slant, plus, tee
"""
ops_by_row = defaultdict(set)
for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
+ assert op.out == op.rhs, f"can't draw op: {op}"
+ assert op not in ops_by_row[op.row], f"duplicate op: {op}"
ops_by_row[op.row].add(op)
def blank_row():
max_distance = max(op.rhs - op.lhs for op in ops)
cells.extend(blank_row() for _ in range(max_distance))
for op in ops:
+ assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
y = len(cells) - 1
x = op.out
cells[y][x].plus = True
Returns: Iterable[Op]
output associative operations.
"""
- items_live_flags = needed_outputs
+
+ # needed_outputs is an iterable, we need to construct a new list so we
+ # don't modify the passed-in value
+ items_live_flags = [bool(i) for i in needed_outputs]
ops = list(prefix_sum_ops(item_count=len(items_live_flags),
work_efficient=work_efficient))
ops_live_flags = [False] * len(ops)
def tree_reduction_ops(item_count):
    """Return the associative ops for reducing `item_count` items to one.

    Implemented as a partial prefix-sum where only the final (last-index)
    output is requested, so the network is pruned to just that result.
    """
    assert item_count >= 1
    # only the very last output is needed
    last = item_count - 1
    return partial_prefix_sum_ops(i == last for i in range(item_count))
return items[-1]
-def pop_count(v, *, width=None, process_temporary=lambda v: v):
- if isinstance(v, Value):
- if width is None:
- width = len(v)
- bits = [v[i] for i in range(width)]
- if len(bits) == 0:
- return Const(0)
- else:
- bits = [(v & (1 << i)) != 0 for i in range(width)]
- if len(bits) == 0:
- return 0
- return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
-
-
if __name__ == "__main__":
print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
"\n"