remove dataclass and use of types
[nmutil.git] / src / nmutil / prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 import operator
9 from nmigen.hdl.ast import Value, Const
10
11
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances compare, hash and print by value (`out`, `lhs`, `rhs`, `row`),
    so two separately constructed but identical operations are equal, a set
    of `Op`s can be checked for duplicates, and an `Op` embedded in an
    assertion message is readable.  Because the hash is derived from the
    attributes, don't modify an `Op` after putting it in a set or using it
    as a dict key.
    """

    def __init__(self, *, out, lhs, rhs, row):
        # index of the item to output to
        self.out = out
        # index of item the left-hand-side input comes from
        self.lhs = lhs
        # index of item the right-hand-side input comes from
        self.rhs = rhs
        # row in the prefix-sum diagram
        self.row = row

    def _key(self):
        # the tuple of fields that defines this operation's identity
        return self.out, self.lhs, self.rhs, self.row

    def __eq__(self, other):
        if not isinstance(other, Op):
            # let Python try the reflected comparison / fall back to `is`
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # must agree with __eq__: equal operations hash equally
        return hash(self._key())

    def __repr__(self):
        return (f"Op(out={self.out!r}, lhs={self.lhs!r}, "
                f"rhs={self.rhs!r}, row={self.row!r})")
22
23
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Generate the associative operations that make up a parallel
    prefix-sum over `item_count` items.

    No commutativity is assumed: in every generated operation `lhs` is the
    lower index and `rhs` (which is also `out`) is the higher index.

    Depth is `O(log(N))`.  The number of operations is `O(N)` when
    `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to pick the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        the operations, generated in order of non-decreasing `row`.
    """
    assert isinstance(item_count, int)
    span = 1
    level = 0
    # Build partial sums with a set of binary trees.  This is the entire
    # non-work-efficient algorithm, and the first half of the
    # work-efficient one.
    while span < item_count:
        if work_efficient:
            first, stride = 2 * span - 1, 2 * span
        else:
            first, stride = span, 1
        for idx in reversed(range(first, item_count, stride)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=level)
        span *= 2
        level += 1
    if not work_efficient:
        return
    # Second half of the work-efficient algorithm: express every output
    # item in terms of the partial sums computed above.
    span //= 2
    while span >= 1:
        for idx in reversed(range(3 * span - 1, item_count, 2 * span)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=level)
        level += 1
        span //= 2
68
69
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, combining items with the
    associative operator `fn` rather than addition.

    Depth is `O(log(N))`.  The number of `fn` calls is `O(N)` when
    `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation used by the prefix-sum algorithm in place of addition.
        Assumed to be associative, not necessarily commutative.
    work_efficient: bool
        True to pick the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: list[_T]
        output items (a new list; the input iterable is copied first).
    """
    results = list(items)
    schedule = prefix_sum_ops(len(results), work_efficient=work_efficient)
    for step in schedule:
        left = results[step.lhs]
        right = results[step.rhs]
        results[step.out] = fn(left, right)
    return results
98
99
class _Cell:
    """One cell of the prefix-sum diagram grid; a bundle of three flags
    (`slant`, `plus`, `tee`) recording what is drawn in the cell."""

    def __init__(self, *, slant, plus, tee):
        self.slant, self.plus, self.tee = slant, plus, tee
105
106
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """Render a text diagram of the operations produced by `prefix_sum_ops`.

    Each item is a vertical line; each operation is drawn as a diagonal
    running from its `lhs` column down and to the right into its `out`
    column.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to draw the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations total instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a
        line going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the
        lines are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, lines joined by newlines with trailing whitespace
        stripped from each line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)

    # bucket the operations by the diagram row they belong to
    row_to_ops = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        bucket = row_to_ops[op.row]
        assert op not in bucket, f"duplicate op: {op}"
        bucket.add(op)

    def new_grid_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # the grid starts with a single row containing only vertical lines
    grid = [new_grid_row()]

    for row_index in sorted(row_to_ops):
        row_ops = row_to_ops[row_index]
        # the longest diagonal decides how many grid rows this op row needs
        tallest = max(op.rhs - op.lhs for op in row_ops)
        grid.extend(new_grid_row() for _ in range(tallest))
        for op in row_ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the operation itself sits on the bottom grid row ...
            line = len(grid) - 1
            col = op.out
            grid[line][col].plus = True
            # ... and its diagonal is traced up and to the left until the
            # column of the left-hand-side input is reached
            line, col = line - 1, col - 1
            while col > op.lhs:
                grid[line][col].slant = True
                line, col = line - 1, col - 1
            grid[line][col].tee = True

    height = 2 * padding + 1
    lines = []
    for grid_row in grid:
        # every grid row becomes `height` lines of text
        text_rows = [[] for _ in range(height)]
        for cell in grid_row:
            # a diagonal enters through the top-left corner of the cell
            comes_in = cell.plus or cell.slant
            # a diagonal leaves through the bottom-right corner of the cell
            goes_out = cell.tee or cell.slant

            # lines above the center line
            for y in range(padding):
                text_rows[y].extend(
                    [slant if (x == y and comes_in) else sp
                     for x in range(padding)])
                text_rows[y].append(vbar)
                text_rows[y].extend([sp] * padding)

            # center line
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            else:
                center = vbar
            text_rows[padding].extend([sp] * padding)
            text_rows[padding].append(center)
            text_rows[padding].extend([sp] * padding)

            # lines below the center line
            for y in range(padding + 1, height):
                text_rows[y].extend([sp] * padding)
                text_rows[y].append(vbar)
                text_rows[y].extend(
                    [slant if (x == y and goes_out) else sp
                     for x in range(padding + 1, height)])
        lines.extend("".join(parts) for parts in text_rows)

    return "\n".join(text.rstrip() for text in lines)
216
217
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Generate the associative operations of a parallel prefix-sum,
    leaving out every operation that doesn't contribute to a needed output.

    The number of items is the number of flags in `needed_outputs`.
    The operations aren't assumed to be commutative.

    Depth is `O(log(N))`.  The number of operations is at most `O(N)` when
    `work_efficient` is true and at most `O(N*log(N))` otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True to pick the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations total instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        the retained operations, in the same relative order as
        `prefix_sum_ops` generates them.
    """
    item_is_live = []
    for flag in needed_outputs:
        assert isinstance(flag, bool)
        item_is_live.append(flag)

    all_ops = list(prefix_sum_ops(item_count=len(item_is_live),
                                  work_efficient=work_efficient))
    op_is_live = [False] * len(all_ops)

    # Walk the operations from last to first.  An operation is kept only
    # when the item it writes is still wanted by something later; a kept
    # operation in turn makes both of its inputs wanted.
    for idx in range(len(all_ops) - 1, -1, -1):
        op = all_ops[idx]
        wanted = item_is_live[op.out]
        # the old value at `out` is overwritten by this op, so clear the
        # flag first -- the input updates below may set it again
        item_is_live[op.out] = False
        item_is_live[op.lhs] |= wanted
        item_is_live[op.rhs] |= wanted
        op_is_live[idx] = wanted

    yield from (op for op, keep in zip(all_ops, op_is_live) if keep)
260
261
def tree_reduction_ops(item_count):
    """Get the operations needed to reduce `item_count` items into the last
    item: the prefix-sum operations with only the final output needed.

    Parameters:
    item_count: int
        number of input items, must be at least 1.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int) and item_count >= 1
    # only the very last output is wanted
    last_only = [False] * (item_count - 1) + [True]
    return partial_prefix_sum_ops(last_only)
266
267
def tree_reduction(items, fn=operator.add):
    """Reduce `items` to a single value by applying the operations from
    `tree_reduction_ops` with `fn` in place of addition.

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation used to combine two items.
    Returns: _T
        the last item after all operations have been applied.
    """
    values = list(items)
    for step in tree_reduction_ops(len(values)):
        left = values[step.lhs]
        right = values[step.rhs]
        values[step.out] = fn(left, right)
    return values[-1]
273
274
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """Sum the individual bits of `v` using `tree_reduction`.

    Parameters:
    v: Value | int
        the input.  For a `Value`, its bits `v[0]` .. `v[width - 1]` are
        summed.  Otherwise `v` must be an `int` and bits `0` .. `width - 1`
        of it are summed.
    width: int | None
        number of bits.  For a `Value` it defaults to `len(v)` and must
        equal `len(v)`.  For an `int` it must be a non-negative `int`.
    process_temporary: Callable
        called on the result of every intermediate addition; its return
        value is what is carried forward.  Defaults to the identity.
    Returns:
        `Const(0)` for a zero-width `Value`, `0` for a zero-width `int`,
        otherwise the result of the tree reduction.
    """
    if isinstance(v, Value):
        if width is None:
            width = len(v)
        assert width == len(v)
        bits = [v[index] for index in range(width)]
        if not bits:
            return Const(0)
    else:
        assert isinstance(width, int) and width >= 0
        assert isinstance(v, int)
        bits = [bool(v & (1 << index)) for index in range(width)]
        if not bits:
            return 0

    def add_bits(a, b):
        return process_temporary(a + b)

    return tree_reduction(bits, fn=add_bits)
290
291
if __name__ == "__main__":
    # demo: draw both algorithms for 16 items, each followed by two blank
    # lines
    hillis_steele_heading = (
        "the non-work-efficient algorithm, matches the diagram in wikipedia:"
        "\n"
        "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
        "\n\n")
    print(hillis_steele_heading)
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print("\n")

    for heading_line in (
            "the work-efficient algorithm, matches the diagram in wikipedia:",
            "https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg",
            ""):
        print(heading_line)
    print(render_prefix_sum_diagram(16, work_efficient=True))
    print("\n")