bd13219eb3f0f3d0a027aeb9dc3c83e804a14ce2
# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
from collections import defaultdict
import operator

from nmigen.hdl.ast import Value, Const

from nmutil.plain_data import plain_data
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.

    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        self.out = out
        "index of the item to output to"

        self.lhs = lhs
        "index of the item the left-hand-side input comes from"

        self.rhs = rhs
        "index of the item the right-hand-side input comes from"

        self.row = row
        "row in the prefix-sum diagram"
def prefix_sum_ops(item_count, *, work_efficient=False):
    """Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # iterate from high index to low so that, when the ops are applied
        # in order, each op reads `items[i - dist]` before the op that
        # overwrites that item is applied.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist *= 2
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        # `dist` is now the first power of two >= item_count (or 1), so step
        # back down through the distances; passes whose range is empty yield
        # nothing but still advance `row`.
        dist //= 2
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist //= 2
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    # copy into a new list so the ops can be applied in place without
    # modifying the caller's sequence (and so any iterable is accepted).
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
@plain_data()
class _Cell:
    """One cell of the diagram grid built by `render_prefix_sum_diagram`.

    Each flag records which piece of an operation is drawn in the cell.
    """
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        # True if a diagonal line passes through this cell
        self.slant = slant
        # True if an operation's output (the `plus` glyph) is in this cell
        self.plus = plus
        # True if a diagonal line starts at this cell
        self.tee = tee
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, lines joined by newlines with trailing whitespace
        stripped from each line.
    """
    # group the operations by diagram row
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        # a diagonal spanning `n` columns needs `n` grid rows below its start
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the output goes in the last grid row, then walk up and to the
            # left along the diagonal until reaching the lhs column.
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        # each grid row renders to `2 * padding + 1` text lines
        row_text = [[] for y in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding
            for y in range(padding):
                # top left padding
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                for x in range(padding):
                    row_text[y].append(sp)
            # center left padding
            for x in range(padding):
                row_text[padding].append(sp)
            # center
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # center right padding
            for x in range(padding):
                row_text[padding].append(sp)
            # bottom padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                for x in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    return "\n".join(map(str.rstrip, lines))
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `len(needed_outputs)` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """

    # needed_outputs is an iterable, we need to construct a new list so we
    # don't modify the passed-in value
    items_live_flags = [bool(i) for i in needed_outputs]
    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                              work_efficient=work_efficient))
    ops_live_flags = [False] * len(ops)
    # walk the ops backwards, propagating liveness from each op's output to
    # its inputs; an op is live only if its output is needed later on.
    for i in reversed(range(len(ops))):
        op = ops[i]
        out_live = items_live_flags[op.out]
        # the value in `out` before this op is dead unless it is also one
        # of this op's inputs, which is handled by the `|=` lines below.
        items_live_flags[op.out] = False
        items_live_flags[op.lhs] |= out_live
        items_live_flags[op.rhs] |= out_live
        ops_live_flags[i] = out_live
    for op, live_flag in zip(ops, ops_live_flags):
        if live_flag:
            yield op
def tree_reduction_ops(item_count):
    """Get the associative operations needed to reduce `item_count` items
    into the last item.

    Parameters:
    item_count: int
        number of input items, must be at least 1.
    Returns: Iterable[Op]
        the operations from `partial_prefix_sum_ops` when only the last
        output is marked as needed.
    """
    assert item_count >= 1
    # only the last output is needed
    needed_outputs = (i == item_count - 1 for i in range(item_count))
    return partial_prefix_sum_ops(needed_outputs)
def tree_reduction(items, fn=operator.add):
    """Reduce `items` to a single value by applying `fn` in the order given
    by `tree_reduction_ops`.

    Parameters:
    items: Iterable[_T]
        input items, there must be at least one.
    fn: Callable[[_T, _T], _T]
        Operation used to combine two items, defaults to addition.
    Returns: _T
        the last item after all the operations have been applied.
    """
    # copy into a new list so the ops can be applied in place without
    # modifying the caller's sequence (and so any iterable is accepted).
    items = list(items)
    for op in tree_reduction_ops(len(items)):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items[-1]
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """return the population count (number of 1 bits) of `v`.

    Arguments:
    v: nmigen.Value | int
        the value to calculate the pop-count of.
    width: int | None
        the bit-width of `v`.
        If `width` is None, then `v` must be a nmigen Value and `len(v)` is
        used. If `v` is a nmigen Value and `width` is given, it must equal
        `len(v)`.
    process_temporary: function of (type(v)) -> type(v)
        called after every addition operation, can be used to introduce
        `Signal`s for the intermediate values in the pop-count computation
        like so:

        ```
        def process_temporary(v):
            sig = Signal.like(v)
            m.d.comb += sig.eq(v)
            return sig
        ```
    """
    if isinstance(v, Value):
        if width is None:
            width = len(v)
        assert width == len(v)
        bits = [v[i] for i in range(width)]
        if len(bits) == 0:
            # tree_reduction needs at least one item
            return Const(0)
    else:
        assert width is not None, "width must be given"
        # v and width are ints
        # extract each bit as a 0/1 int (rather than a bool) so the result
        # is an int even when `width` is 1 and no addition takes place.
        bits = [(v >> i) & 1 for i in range(width)]
        if len(bits) == 0:
            # tree_reduction needs at least one item
            return 0
    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
if __name__ == "__main__":
    # demo: render both prefix-sum variants for 16 items
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n")
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))