speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 import operator
9 from nmigen.hdl.ast import Value, Const
10 from nmutil.plain_data import plain_data
11
12
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """A single associative operation in a parallel prefix-sum network.

    Applying it performs:
    `items[self.out] = fn(items[self.lhs], items[self.rhs])`
    for some associative function `fn`. The function is not assumed to be
    commutative, so `lhs`/`rhs` order matters.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        # index of the item this operation writes to
        self.out = out
        # index the left-hand-side input is read from
        self.lhs = lhs
        # index the right-hand-side input is read from
        self.rhs = rhs
        # row of the prefix-sum diagram this operation belongs to
        self.row = row
33
34
def prefix_sum_ops(item_count, *, work_efficient=False):
    """Yield the associative operations that compute a parallel prefix-sum
    over `item_count` items.

    The operations aren't assumed to be commutative.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int)
    # Phase 1 (up-sweep): combine items using a set of binary trees.
    # This is the whole of the non-work-efficient algorithm and the first
    # half of the work-efficient one.
    depth = 0
    span = 1
    while span < item_count:
        if work_efficient:
            first, stride = span * 2 - 1, span * 2
        else:
            first, stride = span, 1
        for out in reversed(range(first, item_count, stride)):
            yield Op(out=out, lhs=out - span, rhs=out, row=depth)
        span *= 2
        depth += 1
    if work_efficient:
        # Phase 2 (down-sweep): express the remaining outputs in terms of
        # the partial sums computed by phase 1.
        span //= 2
        while span:
            for out in reversed(range(span * 3 - 1, item_count, span * 2)):
                yield Op(out=out, lhs=out - span, rhs=out, row=depth)
            depth += 1
            span //= 2
79
80
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """Compute the parallel prefix-sum of `items`, combining with the
    associative operator `fn` rather than plain addition.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    # work on a copy so the caller's iterable/list is never mutated
    result = list(items)
    ops = prefix_sum_ops(len(result), work_efficient=work_efficient)
    for op in ops:
        result[op.out] = fn(result[op.lhs], result[op.rhs])
    return result
109
110
@plain_data()
class _Cell:
    """One character-cell of the rendered prefix-sum diagram; each flag
    records a glyph feature that passes through this cell."""
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        self.slant, self.plus, self.tee = slant, plus, tee
119
120
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram
    """
    # group the ops by diagram row so each row can be laid out independently
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        # drawing assumes each op writes back onto its right-hand input,
        # which is what `prefix_sum_ops` produces
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        # one empty _Cell per item column
        return [_Cell(slant=False, plus=False, tee=False)
            for _ in range(item_count)]

    cells = [blank_row()]

    # build the logical cell grid: for each op row, add enough blank rows
    # for the longest diagonal in that row, then mark the cells each op
    # touches (⊕ at the output, a diagonal of slants, a tee at the source)
    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # walk the diagonal up-left from the output cell to the source
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    # render each logical cell as a (2*padding+1)-tall block of characters;
    # row_text[y] collects the characters of text line y within the block
    lines = []
    for cells_row in cells:
        row_text = [[] for y in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding
            for y in range(padding):
                # top left padding
                for x in range(padding):
                    # the incoming diagonal occupies the x == y positions
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                for x in range(padding):
                    row_text[y].append(sp)
            # center left padding
            for x in range(padding):
                row_text[padding].append(sp)
            # center -- pick the most specific glyph for this cell
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # center right padding
            for x in range(padding):
                row_text[padding].append(sp)
            # bottom padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                for x in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding -- the outgoing diagonal continues
                # down-right from the center, again on the x == y positions
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    # strip trailing spaces so the output is clean when pasted/diffed
    return "\n".join(map(str.rstrip, lines))
228
229
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Yield only the associative operations actually needed to compute
    the requested subset of a parallel prefix-sum of `len(needed_outputs)`
    items.

    The operations aren't assumed to be commutative.

    Depth is `O(log(N))`; operation count is `O(N)` when `work_efficient`
    is True, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations.
    """

    # copy into a fresh list: `needed_outputs` is only an iterable and we
    # must not modify the caller's data
    live = [bool(flag) for flag in needed_outputs]
    all_ops = list(prefix_sum_ops(item_count=len(live),
                                  work_efficient=work_efficient))
    # backward liveness pass: an op is kept iff its output is needed, and
    # a kept op makes both of its inputs needed in turn.
    keep = []
    for op in reversed(all_ops):
        needed = live[op.out]
        # clear before propagating, so out == lhs/rhs is handled correctly
        live[op.out] = False
        live[op.lhs] |= needed
        live[op.rhs] |= needed
        keep.append(needed)
    keep.reverse()
    # emit surviving ops in their original (forward) order
    for op, needed in zip(all_ops, keep):
        if needed:
            yield op
272
273
def tree_reduction_ops(item_count, *, work_efficient=False):
    """Get the associative operations needed to reduce `item_count` items
    down to a single value stored in the last item.

    Implemented as a prefix-sum where only the final output is needed, so
    all operations not contributing to it are eliminated.

    Parameters:
    item_count: int
        number of input items; must be at least 1.
    work_efficient: bool
        passed through to `partial_prefix_sum_ops`; selects the
        work-efficient prefix-sum network as the starting point.
        Defaults to False, matching the previous behavior.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert item_count >= 1
    # only the last output is live; everything else is dead-code-eliminated
    needed_outputs = (i == item_count - 1 for i in range(item_count))
    return partial_prefix_sum_ops(needed_outputs,
                                  work_efficient=work_efficient)
278
279
def tree_reduction(items, fn=operator.add):
    """Reduce `items` to a single value by combining them with the
    associative operator `fn` in a balanced tree of operations.

    Parameters:
    items: Iterable[_T]
        input items; must be non-empty.
    fn: Callable[[_T, _T], _T]
        associative combining operation; defaults to addition.
    Returns: _T
        the reduction of all input items.
    """
    scratch = list(items)
    for step in tree_reduction_ops(len(scratch)):
        scratch[step.out] = fn(scratch[step.lhs], scratch[step.rhs])
    # the reduction network leaves the final value in the last slot
    return scratch[-1]
285
286
if __name__ == "__main__":
    # demo: render both diagram variants, matching the wikipedia figures
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print()
    print()
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))
    print()
    print()