whitespace in docstrings
[nmutil.git] / src / nmutil / prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 import operator
9 from nmigen.hdl.ast import Value, Const
10 from nmutil.plain_data import plain_data
11
12
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        # NOTE: the trailing bare strings are attribute docstrings -- a
        # convention this project uses with plain_data; they are no-ops at
        # runtime but document each field.
        self.out = out; "index of the item to output to"
        self.lhs = lhs; "index of the item the left-hand-side input comes from"
        self.rhs = rhs; "index of the item the right-hand-side input comes from"
        self.row = row; "row in the prefix-sum diagram"
26
27
def prefix_sum_ops(item_count, *, work_efficient=False):
    """Yield the associative operations that compute a parallel prefix-sum
    over `item_count` items.

    The operations are not assumed to be commutative.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to use the work-efficient algorithm -- roughly twice the
        depth, but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int)
    # up-sweep: combine items over a set of binary trees. This is the
    # entire non-work-efficient algorithm, and the first half of the
    # work-efficient one.
    span = 1
    row = 0
    while span < item_count:
        if work_efficient:
            first, stride = span * 2 - 1, span * 2
        else:
            first, stride = span, 1
        for out in reversed(range(first, item_count, stride)):
            yield Op(out=out, lhs=out - span, rhs=out, row=row)
        row += 1
        span *= 2
    if work_efficient:
        # down-sweep: express the remaining outputs in terms of the
        # partial sums computed above.
        span //= 2
        while span > 0:
            for out in reversed(range(span * 3 - 1, item_count, span * 2)):
                yield Op(out=out, lhs=out - span, rhs=out, row=row)
            span //= 2
            row += 1
72
73
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """Compute the parallel prefix-sum of `items` using the associative
    operator `fn` in place of addition.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        operation used instead of addition; assumed associative, not
        necessarily commutative.
    work_efficient: bool
        True to use the work-efficient algorithm -- roughly twice the
        depth, but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    result = list(items)
    # apply each scheduled operation in order, accumulating in place
    for op in prefix_sum_ops(len(result), work_efficient=work_efficient):
        result[op.out] = fn(result[op.lhs], result[op.rhs])
    return result
102
103
@plain_data()
class _Cell:
    """One character-cell of a rendered prefix-sum diagram; the boolean
    flags select which glyph `render_prefix_sum_diagram` draws here.
    """
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        # a diagonal line passes through this cell (drawn crossing, not
        # connecting, the cell's vertical line)
        self.slant = slant
        # this cell holds the associative-operation symbol (e.g. ⊕)
        self.plus = plus
        # a diagonal branches off this cell's vertical line here
        self.tee = tee
112
113
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram
    """
    # group the operations by diagram row; each op is drawn as a diagonal
    # running from the lhs column down-right to a `plus` in the out column
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        # drawing only supports ops that accumulate in place into rhs
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        # one _Cell per item column, all glyph flags clear
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        # a diagonal spanning `d` columns needs `d` extra cell-rows
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # walk up-left from the operation symbol at (out, bottom row)
            # to the branch point at the lhs column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    # render: each cell expands to a (2*padding+1)-line square of characters
    lines = []
    for cells_row in cells:
        row_text = [[] for y in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding
            for y in range(padding):
                # top left padding
                for x in range(padding):
                    # continue an incoming diagonal along the x == y axis
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                for x in range(padding):
                    row_text[y].append(sp)
            # center left padding
            for x in range(padding):
                row_text[padding].append(sp)
            # center -- pick the glyph by flag priority: plus, tee, slant,
            # falling back to a plain vertical bar
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # center right padding
            for x in range(padding):
                row_text[padding].append(sp)
            # bottom padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                for x in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding
                for x in range(padding + 1, 2 * padding + 1):
                    # continue an outgoing diagonal along the x == y axis
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    # strip trailing blanks so the output diffs cleanly
    return "\n".join(map(str.rstrip, lines))
221
222
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Yield the associative operations needed to compute the requested
    subset of a parallel prefix-sum of `len(needed_outputs)` items.

    The operations are not assumed to be commutative.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True to use the work-efficient algorithm -- roughly twice the
        depth, but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """

    # copy into a fresh list -- needed_outputs is any iterable and we
    # must not modify the caller's value
    live = [bool(flag) for flag in needed_outputs]
    all_ops = list(prefix_sum_ops(item_count=len(live),
                                  work_efficient=work_efficient))
    # dead-code elimination, walking the schedule backwards: an op is kept
    # iff its output is still live, and a kept op makes both inputs live
    keep = []
    for op in reversed(all_ops):
        needed = live[op.out]
        live[op.out] = False
        live[op.lhs] |= needed
        live[op.rhs] |= needed
        keep.append(needed)
    keep.reverse()
    for op, needed in zip(all_ops, keep):
        if needed:
            yield op
265
266
def tree_reduction_ops(item_count, *, work_efficient=False):
    """Get the associative operations needed to reduce `item_count` items
    to a single result, left in the last item.

    Implemented by pruning a prefix-sum network down to its final output,
    which leaves exactly a binary-tree reduction.

    Parameters:
    item_count: int
        number of input items; must be at least 1.
    work_efficient: bool
        which underlying prefix-sum network to prune; kept for
        consistency with `prefix_sum_ops`/`partial_prefix_sum_ops`
        (defaults to False, matching the previous behavior).
    Returns: Iterable[Op]
        output associative operations.
    """
    assert item_count >= 1, "can't reduce zero items"
    # only the last output is needed -- everything else gets pruned
    needed_outputs = (i == item_count - 1 for i in range(item_count))
    return partial_prefix_sum_ops(needed_outputs,
                                  work_efficient=work_efficient)
271
272
def tree_reduction(items, fn=operator.add):
    """Reduce `items` to a single value by applying the associative
    operator `fn` in a binary tree, and return that value.

    Parameters:
    items: Iterable[_T]
        input items; must be non-empty.
    fn: Callable[[_T, _T], _T]
        operation used instead of addition; assumed associative.
    Returns: _T
        the reduction result (accumulated into the last slot).
    """
    work = list(items)
    for op in tree_reduction_ops(len(work)):
        work[op.out] = fn(work[op.lhs], work[op.rhs])
    return work[-1]
278
279
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """return the population count (number of 1 bits) of `v`.

    Arguments:
    v: nmigen.Value | int
        the value to calculate the pop-count of.
    width: int | None
        the bit-width of `v`.
        If `v` is a nmigen `Value`, `width` may be None (it is taken from
        `len(v)`); when given it must equal `len(v)`.
        If `v` is an int, `width` is required.
    process_temporary: function of (type(v)) -> type(v)
        called after every addition operation, can be used to introduce
        `Signal`s for the intermediate values in the pop-count computation
        like so:

        ```
        def process_temporary(v):
            sig = Signal.like(v)
            m.d.comb += sig.eq(v)
            return sig
        ```
    """
    if isinstance(v, Value):
        if width is None:
            width = len(v)
        assert width == len(v)
        # one single-bit Value per bit position
        bit_list = [v[i] for i in range(width)]
        if not bit_list:
            return Const(0)
    else:
        assert width is not None, "width must be given"
        # v and width are ints; extract each bit as a bool so that a
        # width-1 input still reduces to a bool/int as before
        bit_list = [(v & (1 << i)) != 0 for i in range(width)]
        if not bit_list:
            return 0
    # sum the bits pairwise in a tree of O(log(width)) depth
    return tree_reduction(bit_list, fn=lambda a, b: process_temporary(a + b))
315
316
if __name__ == "__main__":
    # demo: render both 16-input diagram variants to stdout
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print()
    print()
    print(render_prefix_sum_diagram(16, work_efficient=False), end="\n\n\n")
    print("the work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg"
          "\n")
    print(render_prefix_sum_diagram(16, work_efficient=True), end="\n\n\n")