549d772d78dc25aaa13aa8cf003e8cc90416971b
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from collections
import defaultdict
9 from nmigen
.hdl
.ast
import Value
, Const
13 """An associative operation in a prefix-sum.
14 The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
15 The operation is not assumed to be commutative.
def __init__(self, *, out, lhs, rhs, row):
    """Record one operation `items[out] = fn(items[lhs], items[rhs])`.

    All arguments are keyword-only and stored unchanged as attributes.
    """
    # out: index of the item to output to
    # lhs: index of item the left-hand-side input comes from
    # rhs: index of item the right-hand-side input comes from
    # row: row in the prefix-sum diagram
    self.out, self.lhs, self.rhs, self.row = out, lhs, rhs, row
24 def prefix_sum_ops(item_count
, *, work_efficient
=False):
25 """ Get the associative operations needed to compute a parallel prefix-sum
26 of `item_count` items.
28 The operations aren't assumed to be commutative.
30 This has a depth of `O(log(N))` and an operation count of `O(N)` if
31 `work_efficient` is true, otherwise `O(N*log(N))`.
33 The algorithms used are derived from:
34 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
35 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
39 number of input items.
41 True if the algorithm used should be work-efficient -- has a larger
42 depth (about twice as large) but does only `O(N)` operations total
43 instead of `O(N*log(N))`.
45 output associative operations.
47 # compute the partial sums using a set of binary trees
48 # first half of the work-efficient algorithm and the whole of
49 # the non-work-efficient algorithm.
52 while dist
< item_count
:
53 start
= dist
* 2 - 1 if work_efficient
else dist
54 step
= dist
* 2 if work_efficient
else 1
55 for i
in reversed(range(start
, item_count
, step
)):
56 yield Op(out
=i
, lhs
=i
- dist
, rhs
=i
, row
=row
)
60 # express all output items in terms of the computed partial sums.
63 for i
in reversed(range(dist
* 3 - 1, item_count
, dist
* 2)):
64 yield Op(out
=i
, lhs
=i
- dist
, rhs
=i
, row
=row
)
69 def prefix_sum(items
, fn
=operator
.add
, *, work_efficient
=False):
70 """ Compute the parallel prefix-sum of `items`, using associative operator
71 `fn` instead of addition.
73 This has a depth of `O(log(N))` and an operation count of `O(N)` if
74 `work_efficient` is true, otherwise `O(N*log(N))`.
76 The algorithms used are derived from:
77 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
78 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
83 fn: Callable[[_T, _T], _T]
84 Operation to use for the prefix-sum algorithm instead of addition.
85 Assumed to be associative not necessarily commutative.
87 True if the algorithm used should be work-efficient -- has a larger
88 depth (about twice as large) but does only `O(N)` operations total
89 instead of `O(N*log(N))`.
94 for op
in prefix_sum_ops(len(items
), work_efficient
=work_efficient
):
95 items
[op
.out
] = fn(items
[op
.lhs
], items
[op
.rhs
])
100 def __init__(self
, *, slant
, plus
, tee
):
106 def render_prefix_sum_diagram(item_count
, *, work_efficient
=False,
107 sp
=" ", vbar
="|", plus
="⊕",
108 slant
="\\", connect
="●", no_connect
="X",
111 """renders a prefix-sum diagram, matches `prefix_sum_ops`.
115 number of input items.
117 True if the algorithm used should be work-efficient -- has a larger
118 depth (about twice as large) but does only `O(N)` operations total
119 instead of `O(N*log(N))`.
121 character used for blank space
123 character used for a vertical bar
125 character used for the addition operation
127 character used to draw a line from the top left to the bottom right
129 character used to draw a connection between a vertical line and a line
130 going from the center of this character to the bottom right
132 character used to draw two lines crossing but not connecting, the lines
133 are vertical and diagonal from top left to the bottom right
135 amount of padding characters in the output cells.
139 ops_by_row
= defaultdict(set)
140 for op
in prefix_sum_ops(item_count
, work_efficient
=work_efficient
):
141 ops_by_row
[op
.row
].add(op
)
144 return [_Cell(slant
=False, plus
=False, tee
=False)
145 for _
in range(item_count
)]
147 cells
= [blank_row()]
149 for row
in sorted(ops_by_row
.keys()):
150 ops
= ops_by_row
[row
]
151 max_distance
= max(op
.rhs
- op
.lhs
for op
in ops
)
152 cells
.extend(blank_row() for _
in range(max_distance
))
156 cells
[y
][x
].plus
= True
160 cells
[y
][x
].slant
= True
163 cells
[y
][x
].tee
= True
166 for cells_row
in cells
:
167 row_text
= [[] for y
in range(2 * padding
+ 1)]
168 for cell
in cells_row
:
170 for y
in range(padding
):
172 for x
in range(padding
):
173 is_slant
= x
== y
and (cell
.plus
or cell
.slant
)
174 row_text
[y
].append(slant
if is_slant
else sp
)
176 row_text
[y
].append(vbar
)
178 for x
in range(padding
):
179 row_text
[y
].append(sp
)
180 # center left padding
181 for x
in range(padding
):
182 row_text
[padding
].append(sp
)
191 row_text
[padding
].append(center
)
192 # center right padding
193 for x
in range(padding
):
194 row_text
[padding
].append(sp
)
196 for y
in range(padding
+ 1, 2 * padding
+ 1):
197 # bottom left padding
198 for x
in range(padding
):
199 row_text
[y
].append(sp
)
200 # bottom vertical bar
201 row_text
[y
].append(vbar
)
202 # bottom right padding
203 for x
in range(padding
+ 1, 2 * padding
+ 1):
204 is_slant
= x
== y
and (cell
.tee
or cell
.slant
)
205 row_text
[y
].append(slant
if is_slant
else sp
)
206 for line
in row_text
:
207 lines
.append("".join(line
))
209 return "\n".join(map(str.rstrip
, lines
))
212 def partial_prefix_sum_ops(needed_outputs
, *, work_efficient
=False):
213 """ Get the associative operations needed to compute a parallel prefix-sum
214 of `len(needed_outputs)` items.
216 The operations aren't assumed to be commutative.
218 This has a depth of `O(log(N))` and an operation count of `O(N)` if
219 `work_efficient` is true, otherwise `O(N*log(N))`.
221 The algorithms used are derived from:
222 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
223 https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient
226 needed_outputs: Iterable[bool]
227 The length is the number of input/output items.
228 Each item is True if that corresponding output is needed.
229 Unneeded outputs have unspecified value.
231 True if the algorithm used should be work-efficient -- has a larger
232 depth (about twice as large) but does only `O(N)` operations total
233 instead of `O(N*log(N))`.
234 Returns: Iterable[Op]
235 output associative operations.
237 items_live_flags
= needed_outputs
238 ops
= list(prefix_sum_ops(item_count
=len(items_live_flags
),
239 work_efficient
=work_efficient
))
240 ops_live_flags
= [False] * len(ops
)
241 for i
in reversed(range(len(ops
))):
243 out_live
= items_live_flags
[op
.out
]
244 items_live_flags
[op
.out
] = False
245 items_live_flags
[op
.lhs
] |
= out_live
246 items_live_flags
[op
.rhs
] |
= out_live
247 ops_live_flags
[i
] = out_live
248 for op
, live_flag
in zip(ops
, ops_live_flags
):
def tree_reduction_ops(item_count):
    """ Get the associative operations needed to compute a tree reduction of
    `item_count` items: a partial prefix-sum where only the last output
    (index `item_count - 1`) is needed.

    Parameters:
    item_count: int
        number of input items.
    Returns: Iterable[Op]
        output associative operations, as produced by
        `partial_prefix_sum_ops`.
    """
    # Build a real list, not a generator: `partial_prefix_sum_ops` calls
    # `len()` on `needed_outputs` and indexes/assigns into it in place, so a
    # generator expression here raised `TypeError`.
    needed_outputs = [i == item_count - 1 for i in range(item_count)]
    return partial_prefix_sum_ops(needed_outputs)
258 def tree_reduction(items
, fn
=operator
.add
):
260 for op
in tree_reduction_ops(len(items
)):
261 items
[op
.out
] = fn(items
[op
.lhs
], items
[op
.rhs
])
265 def pop_count(v
, *, width
=None, process_temporary
=lambda v
: v
):
266 if isinstance(v
, Value
):
269 bits
= [v
[i
] for i
in range(width
)]
273 bits
= [(v
& (1 << i
)) != 0 for i
in range(width
)]
276 return tree_reduction(bits
, fn
=lambda a
, b
: process_temporary(a
+ b
))
279 if __name__
== "__main__":
280 print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
282 "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
284 print(render_prefix_sum_diagram(16, work_efficient
=False))
287 print("the work-efficient algorithm, matches the diagram in wikipedia:")
288 print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
290 print(render_prefix_sum_diagram(16, work_efficient
=True))