From: Jacob Lifshay Date: Fri, 12 Aug 2022 06:37:02 +0000 (-0700) Subject: fix prefix_sum.py after 63ffb1aa and d7288021 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=fcbd2f816210cd48879f05ac6d342f93409e03b4;p=nmutil.git fix prefix_sum.py after 63ffb1aa and d7288021 --- diff --git a/src/nmutil/prefix_sum.py b/src/nmutil/prefix_sum.py index 549d772..65f2b33 100644 --- a/src/nmutil/prefix_sum.py +++ b/src/nmutil/prefix_sum.py @@ -7,18 +7,29 @@ from collections import defaultdict import operator from nmigen.hdl.ast import Value, Const +from nmutil.plain_data import plain_data +@plain_data(order=True, unsafe_hash=True, frozen=True) class Op: """An associative operation in a prefix-sum. The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`. The operation is not assumed to be commutative. """ - def __init__(self,*, out, lhs, rhs, row): - self.out = out; "index of the item to output to" - self.lhs = lhs; "index of item the left-hand-side input comes from" - self.rhs = rhs; "index of item the right-hand-side input comes from" - self.row = row; "row in the prefix-sum diagram" + __slots__ = "out", "lhs", "rhs", "row" + + def __init__(self, out, lhs, rhs, row): + self.out = out + """index of the item to output to""" + + self.lhs = lhs + """index of the item the left-hand-side input comes from""" + + self.rhs = rhs + """index of the item the right-hand-side input comes from""" + + self.row = row + """row in the prefix-sum diagram""" def prefix_sum_ops(item_count, *, work_efficient=False): @@ -44,8 +55,9 @@ def prefix_sum_ops(item_count, *, work_efficient=False): Returns: Iterable[Op] output associative operations. """ + assert isinstance(item_count, int) # compute the partial sums using a set of binary trees - # first half of the work-efficient algorithm and the whole of + # this is the first half of the work-efficient algorithm and the whole of # the non-work-efficient algorithm. dist = 1 row = 0 @@ -96,8 +108,11 @@ def prefix_sum(items, fn=operator.add, *, work_efficient=False): return items +@plain_data() class _Cell: - def __init__(self, *, slant, plus, tee): + __slots__ = "slant", "plus", "tee" + + def __init__(self, slant, plus, tee): self.slant = slant self.plus = plus self.tee = tee @@ -138,6 +153,8 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False, """ ops_by_row = defaultdict(set) for op in prefix_sum_ops(item_count, work_efficient=work_efficient): + assert op.out == op.rhs, f"can't draw op: {op}" + assert op not in ops_by_row[op.row], f"duplicate op: {op}" ops_by_row[op.row].add(op) def blank_row(): @@ -151,6 +168,7 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False, max_distance = max(op.rhs - op.lhs for op in ops) cells.extend(blank_row() for _ in range(max_distance)) for op in ops: + assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}" y = len(cells) - 1 x = op.out cells[y][x].plus = True @@ -234,7 +252,10 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False): Returns: Iterable[Op] output associative operations. """ - items_live_flags = needed_outputs + + # needed_outputs is an iterable, we need to construct a new list so we + # don't modify the passed-in value + items_live_flags = [bool(i) for i in needed_outputs] ops = list(prefix_sum_ops(item_count=len(items_live_flags), work_efficient=work_efficient)) ops_live_flags = [False] * len(ops) @@ -251,6 +272,7 @@ def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False): def tree_reduction_ops(item_count): + assert item_count >= 1 needed_outputs = (i == item_count - 1 for i in range(item_count)) return partial_prefix_sum_ops(needed_outputs) @@ -263,13 +285,36 @@ def tree_reduction(items, fn=operator.add): def pop_count(v, *, width=None, process_temporary=lambda v: v): + """ return the population count (number of 1 bits) of `v`. + Arguments: + v: nmigen.Value | int + the value to calculate the pop-count of. + width: int | None + the bit-width of `v`. + If `width` is None, then `v` must be a nmigen Value or + match `v`'s width. + process_temporary: function of (type(v)) -> type(v) + called after every addition operation, can be used to introduce + `Signal`s for the intermediate values in the pop-count computation + like so: + + ``` + def process_temporary(v): + sig = Signal.like(v) + m.d.comb += sig.eq(v) + return sig + ``` + """ if isinstance(v, Value): if width is None: width = len(v) + assert width == len(v) bits = [v[i] for i in range(width)] if len(bits) == 0: return Const(0) else: + assert width is not None, "width must be given" + # v and width are ints bits = [(v & (1 << i)) != 0 for i in range(width)] if len(bits) == 0: return 0