549d772d78dc25aaa13aa8cf003e8cc90416971b
[nmutil.git] / src / nmutil / prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 import operator
9 from nmigen.hdl.ast import Value, Const
10
11
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances compare and hash by identity, so they can be stored in sets.

    Attributes:
    out: int
        index of the item to output to
    lhs: int
        index of item the left-hand-side input comes from
    rhs: int
        index of item the right-hand-side input comes from
    row: int
        row in the prefix-sum diagram
    """

    def __init__(self, *, out, lhs, rhs, row):
        self.out = out  # index of the item to output to
        self.lhs = lhs  # index of item the left-hand-side input comes from
        self.rhs = rhs  # index of item the right-hand-side input comes from
        self.row = row  # row in the prefix-sum diagram

    def __repr__(self):
        # readable form so lists of ops can be inspected when debugging
        return "{}(out={!r}, lhs={!r}, rhs={!r}, row={!r})".format(
            type(self).__name__, self.out, self.lhs, self.rhs, self.row)
22
23
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Generate the associative operations that make up a parallel
    prefix-sum over `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    span = 1
    level = 0
    # Build the partial sums with a set of binary trees. This is the first
    # half of the work-efficient algorithm and all of the
    # non-work-efficient one.
    while span < item_count:
        if work_efficient:
            first = 2 * span - 1
            stride = 2 * span
        else:
            first = span
            stride = 1
        # highest index first, so an op never reads an item that an earlier
        # op in the same row has already overwritten
        for index in reversed(range(first, item_count, stride)):
            yield Op(out=index, lhs=index - span, rhs=index, row=level)
        span *= 2
        level += 1

    if not work_efficient:
        return

    # Second half of the work-efficient algorithm: express all remaining
    # output items in terms of the partial sums computed above.
    span //= 2
    while span >= 1:
        for index in reversed(range(3 * span - 1, item_count, 2 * span)):
            yield Op(out=index, lhs=index - span, rhs=index, row=level)
        level += 1
        span //= 2
67
68
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, combining items with the
    associative operator `fn` rather than addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    # work on a private list so the input iterable is left untouched
    results = list(items)
    schedule = prefix_sum_ops(len(results), work_efficient=work_efficient)
    for step in schedule:
        left = results[step.lhs]
        right = results[step.rhs]
        results[step.out] = fn(left, right)
    return results
97
98
class _Cell:
    """Mutable set of drawing flags for one position in a rendered
    prefix-sum diagram.

    Attributes:
    slant: bool
        a diagonal line passes through this cell.
    plus: bool
        this cell holds the output of an operation.
    tee: bool
        a diagonal line branches off the vertical line in this cell.
    """

    def __init__(self, *, slant, plus, tee):
        self.slant, self.plus, self.tee = slant, plus, tee
104
105
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """draws the prefix-sum diagram for the ops produced by `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram
    """
    # group the operations by the diagram row they belong to
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # --- stage 1: mark which cells hold a plus / slant / tee ---
    grid = [blank_row()]
    for row_index in sorted(ops_by_row):
        row_ops = ops_by_row[row_index]
        # the longest diagonal in this row decides how many grid lines it uses
        tallest = max(op.rhs - op.lhs for op in row_ops)
        for _ in range(tallest):
            grid.append(blank_row())
        bottom = len(grid) - 1
        for op in row_ops:
            grid[bottom][op.out].plus = True
            # walk the diagonal up and to the left, back to the lhs column
            col = op.out - 1
            line = bottom - 1
            while col > op.lhs:
                grid[line][col].slant = True
                col -= 1
                line -= 1
            grid[line][col].tee = True

    # --- stage 2: turn each grid line into `2 * padding + 1` text lines ---
    height = 2 * padding + 1
    text_lines = []
    for grid_row in grid:
        strips = [[] for _ in range(height)]
        for cell in grid_row:
            # diagonal arriving from the top left
            enters = cell.plus or cell.slant
            # diagonal leaving towards the bottom right
            leaves = cell.tee or cell.slant

            # lines above the center: left padding, vertical bar, right padding
            for y in range(padding):
                strips[y].extend(slant if (x == y and enters) else sp
                                 for x in range(padding))
                strips[y].append(vbar)
                strips[y].extend([sp] * padding)

            # center line
            if cell.plus:
                glyph = plus
            elif cell.tee:
                glyph = connect
            elif cell.slant:
                glyph = no_connect
            else:
                glyph = vbar
            middle = strips[padding]
            middle.extend([sp] * padding)
            middle.append(glyph)
            middle.extend([sp] * padding)

            # lines below the center: left padding, vertical bar, right padding
            for y in range(padding + 1, height):
                strips[y].extend([sp] * padding)
                strips[y].append(vbar)
                strips[y].extend(slant if (x == y and leaves) else sp
                                 for x in range(padding + 1, height))
        text_lines.extend("".join(strip) for strip in strips)

    # trailing blanks carry no information, drop them from every line
    return "\n".join(text.rstrip() for text in text_lines)
210
211
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `len(needed_outputs)` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
        Any iterable is accepted (including generators); it is copied and
        never modified.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    # Copy into a fresh list: `needed_outputs` may be an iterable without
    # `len()` or indexing, and the liveness pass below overwrites the flags,
    # which must not leak back into the caller's data.
    items_live_flags = [bool(flag) for flag in needed_outputs]
    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                              work_efficient=work_efficient))
    ops_live_flags = [False] * len(ops)
    # Walk the ops backwards: an op is live only if the item it writes is
    # needed by something later (a live op or a needed output).
    for i in reversed(range(len(ops))):
        op = ops[i]
        out_live = items_live_flags[op.out]
        # the previous value of the output item is overwritten by this op,
        # so it is only live if it is also one of this op's inputs
        items_live_flags[op.out] = False
        items_live_flags[op.lhs] |= out_live
        items_live_flags[op.rhs] |= out_live
        ops_live_flags[i] = out_live
    # emit the surviving ops in their original order
    for op, live_flag in zip(ops, ops_live_flags):
        if live_flag:
            yield op
251
252
def tree_reduction_ops(item_count):
    """ Get the associative operations needed to reduce `item_count` items
    into the last item, using a tree of operations.

    Only the last output of the prefix-sum is requested, all other outputs
    are left with unspecified values.

    Parameters:
    item_count: int
        number of input items.
    Returns: Iterable[Op]
        output associative operations.
    """
    # Build a concrete list rather than a generator expression, so the
    # flags have a length and can be indexed by the callee.
    needed_outputs = [i == item_count - 1 for i in range(item_count)]
    return partial_prefix_sum_ops(needed_outputs)
256
257
def tree_reduction(items, fn=operator.add):
    """ Reduce `items` to a single value by applying `fn` in the order given
    by `tree_reduction_ops`.

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation used to combine two items, addition by default.
    Returns: _T
        the last item after all operations have been applied.
        Indexing the last item raises IndexError if `items` is empty.
    """
    values = list(items)
    for step in tree_reduction_ops(len(values)):
        left = values[step.lhs]
        right = values[step.rhs]
        values[step.out] = fn(left, right)
    # the reduction result accumulates in the last item
    return values[-1]
263
264
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """ Count the set bits in the lowest `width` bits of `v` by summing the
    individual bits with `tree_reduction`.

    Parameters:
    v: Value | int
        input to count the bits of.
    width: int | None
        number of bits to examine. For a `Value` this defaults to `len(v)`;
        for any other input the code passes it straight to `range()`, so it
        must be given.
    process_temporary: Callable
        called on every intermediate sum `a + b`, the returned value is
        used in place of the sum. Defaults to the identity function.
    Returns:
        the reduced sum of the bits; `Const(0)` for a `Value` with no bits
        and `0` for any other input with no bits.
    """
    is_hdl_value = isinstance(v, Value)
    if is_hdl_value:
        if width is None:
            width = len(v)
        bits = [v[i] for i in range(width)]
    else:
        bits = [(v & (1 << i)) != 0 for i in range(width)]

    # nothing to add up: return a zero of the matching kind
    if not bits:
        return Const(0) if is_hdl_value else 0

    def add(a, b):
        return process_temporary(a + b)

    return tree_reduction(bits, fn=add)
277
278
if __name__ == "__main__":
    # Demo: print both diagram variants for 16 items, each one followed by
    # two blank lines.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n\n")
    print(render_prefix_sum_diagram(16, work_efficient=False), end="\n\n\n")
    print("the work-efficient algorithm, matches the diagram in wikipedia:",
          "https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg",
          sep="\n", end="\n\n")
    print(render_prefix_sum_diagram(16, work_efficient=True), end="\n\n\n")