# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
from collections import defaultdict
import operator

from nmigen.hdl.ast import Value, Const
class Op:
    """An associative operation in a prefix-sum.

    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.
    """

    def __init__(self, *, out, lhs, rhs, row):
        # index of the item to output to
        self.out = out
        # index of item the left-hand-side input comes from
        self.lhs = lhs
        # index of item the right-hand-side input comes from
        self.rhs = rhs
        # row in the prefix-sum diagram
        self.row = row

    def __repr__(self):
        # makes messages that format an op (e.g. f"can't draw op: {op}")
        # show the indices instead of an opaque object address.
        return (f"Op(out={self.out!r}, lhs={self.lhs!r}, "
                f"rhs={self.rhs!r}, row={self.row!r})")
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # highest index first: every op reads `i - dist`, which is below `i`,
        # so when the ops are applied in place one after another the
        # left-hand input has not been overwritten by this row yet.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist <<= 1
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        dist >>= 1
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist >>= 1
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    # work on a private list: accepts any iterable and leaves the caller's
    # sequence untouched while the ops are applied in place.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
class _Cell:
    """One character cell of a rendered prefix-sum diagram.

    The three flags are read and set by `render_prefix_sum_diagram`.
    """

    def __init__(self, *, slant, plus, tee):
        # True if a diagonal line passes straight through this cell
        self.slant = slant
        # True if this cell holds the operation a diagonal line ends at
        self.plus = plus
        # True if a diagonal line branches off the vertical line here
        self.tee = tee
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)
    # group the operations by diagram row
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # start with one blank row so the diagonals of the first op row have a
    # cell row to branch off from.
    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        # a diagonal moves one column per cell row, so the longest op in this
        # row decides how many cell rows have to be added.
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the operation sits in the last cell row, in the output column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            # walk up and to the left until the lhs column is reached
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        # every cell is `2 * padding + 1` text lines tall
        row_text = [[] for y in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding
            for y in range(padding):
                # top left padding -- holds the diagonal coming in from the
                # top left when this cell is on or at the end of one
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                for x in range(padding):
                    row_text[y].append(sp)
            # center left padding
            for x in range(padding):
                row_text[padding].append(sp)
            # center
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # center right padding
            for x in range(padding):
                row_text[padding].append(sp)
            # bottom padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                for x in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding -- holds the diagonal leaving towards
                # the bottom right when this cell is on or at the start of one
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    # trailing blanks carry no information, drop them from every line
    return "\n".join(map(str.rstrip, lines))
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `len(needed_outputs)` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    def assert_bool(v):
        assert isinstance(v, bool)
        return v

    items_live_flags = [assert_bool(i) for i in needed_outputs]
    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                              work_efficient=work_efficient))
    ops_live_flags = [False] * len(ops)
    # walk the ops last-to-first: an op is kept only if the item it writes is
    # still needed at that point.
    for i in reversed(range(len(ops))):
        op = ops[i]
        out_live = items_live_flags[op.out]
        # the op overwrites `out`, so the previous value of that item is not
        # needed for its own sake ...
        items_live_flags[op.out] = False
        # ... but both inputs are needed if the op is kept. This has to come
        # after the clear above since `rhs` can be the same index as `out`.
        items_live_flags[op.lhs] |= out_live
        items_live_flags[op.rhs] |= out_live
        ops_live_flags[i] = out_live
    for op, live_flag in zip(ops, ops_live_flags):
        if live_flag:
            yield op
def tree_reduction_ops(item_count):
    """ Get the associative operations needed to reduce `item_count` items
    to a single value, by requesting only the last output from
    `partial_prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items, must be at least 1.
    Returns:
        the operations produced by `partial_prefix_sum_ops`.
    """
    assert isinstance(item_count, int) and item_count >= 1
    # only the last item, which accumulates all the inputs, is needed
    needed_outputs = (i == item_count - 1 for i in range(item_count))
    return partial_prefix_sum_ops(needed_outputs)
def tree_reduction(items, fn=operator.add):
    """ Reduce `items` to a single value by applying the operations from
    `tree_reduction_ops`, using `fn` instead of addition.

    Parameters:
    items: Iterable[_T]
        input items, there must be at least one.
    fn: Callable[[_T, _T], _T]
        Operation used to combine two items.
    Returns: _T
        the last item after all operations were applied.
    """
    # work on a private list: accepts any iterable and leaves the caller's
    # sequence untouched while the ops are applied in place.
    items = list(items)
    for op in tree_reduction_ops(len(items)):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items[-1]
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """ Count the set bits in the low `width` bits of `v` by summing the
    individual bits with `tree_reduction`.

    Parameters:
    v: Value | int
        the value to count the bits of.
    width: int | None
        number of bits to look at. For a `Value` it defaults to `len(v)` and
        must equal it. For an `int` it must be given and be non-negative.
    process_temporary: Callable
        called with every intermediate sum `a + b`, its return value is used
        in place of that sum. The default returns the sum unchanged.
    Returns:
        the result of the reduction; `Const(0)` or `0` when `width` is zero.
    """
    if isinstance(v, Value):
        if width is None:
            width = len(v)
        assert width == len(v)
        bits = [v[i] for i in range(width)]
        # tree_reduction needs at least one item
        if len(bits) == 0:
            return Const(0)
    else:
        assert isinstance(width, int) and width >= 0
        assert isinstance(v, int)
        # extract each bit as an int (0 or 1) so the result is an int even
        # when there is only a single bit and no addition takes place.
        bits = [(v >> i) & 1 for i in range(width)]
        # tree_reduction needs at least one item
        if len(bits) == 0:
            return 0
    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
# demo: print both variants of the diagram for 16 items when run as a script
if __name__ == "__main__":
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n")
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))