Revert "add reduce_only option to prefix_sum_ops"
[nmutil.git] / src / nmutil / prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import operator
10
11
@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """One step of a prefix-sum: a single use of the associative operator.

    Carrying out the step means
    `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operator is not assumed to be commutative, so the left-hand and
    right-hand inputs must be passed to `fn` in that order.

    Instances are immutable and hashable, and compare field-by-field in
    declaration order (`out`, `lhs`, `rhs`, `row`).
    """
    out: int
    """index of the item that receives the result"""
    lhs: int
    """index of the item supplying the left-hand-side input"""
    rhs: int
    """index of the item supplying the right-hand-side input"""
    row: int
    """row of the prefix-sum diagram this step belongs to"""
26
27
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Generate the associative operations that make up a parallel
    prefix-sum over `item_count` items.

    The operations aren't assumed to be commutative.

    The result has a depth of `O(log(N))`; the number of operations is
    `O(N)` when `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Within one row the operations are produced from the highest output
    index down to the lowest. The row counter advances once per pass even
    when that pass produces no operations, so the `row` values seen by the
    caller are increasing but may have gaps.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to use the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        the operations, lazily generated.
    """
    assert isinstance(item_count, int)
    # Phase 1: build partial sums with a set of binary trees. This is the
    # whole of the non-work-efficient algorithm and the first half of the
    # work-efficient one.
    span = 1
    level = 0
    while span < item_count:
        if work_efficient:
            first, stride = 2 * span - 1, 2 * span
        else:
            first, stride = span, 1
        # highest index first: when the ops are applied in place in the
        # order generated, no op reads an item that an earlier op of the
        # same row has already overwritten.
        for idx in reversed(range(first, item_count, stride)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=level)
        span *= 2
        level += 1
    if not work_efficient:
        return
    # Phase 2 (work-efficient only): express every remaining output item
    # in terms of the partial sums computed in phase 1.
    span //= 2
    while span >= 1:
        for idx in reversed(range(3 * span - 1, item_count, 2 * span)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=level)
        level += 1
        span //= 2
72
73
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, combining them with the
    associative operator `fn` rather than plain addition.

    The computation has a depth of `O(log(N))`; the number of calls to `fn`
    is `O(N)` when `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items; they are copied into a new list, the argument itself
        is not modified.
    fn: Callable[[_T, _T], _T]
        operation used in place of addition. Assumed to be associative,
        not necessarily commutative.
    work_efficient: bool
        True to use the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    results = list(items)
    schedule = prefix_sum_ops(len(results), work_efficient=work_efficient)
    # apply the steps in place, in exactly the order they are generated
    for step in schedule:
        left = results[step.lhs]
        right = results[step.rhs]
        results[step.out] = fn(left, right)
    return results
102
103
@dataclass
class _Cell:
    """Mutable drawing flags for one grid position of a rendered
    prefix-sum diagram."""
    # a diagonal line passes through this position
    slant: bool
    # an operation's output is drawn at this position
    plus: bool
    # a diagonal line starts at this position
    tee: bool
109
110
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """draws the operations produced by `prefix_sum_ops` as a text diagram.

    Every item gets one column containing a vertical line. Each operation
    is drawn as a diagonal line that leaves the left-hand-side item's
    column and runs down and to the right until it reaches the output
    item's column, where the operation symbol is placed.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to draw the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used where a vertical line connects to a line going from
        the center of this character to the bottom right
    no_connect: str
        character used where a vertical line and a diagonal line (top left
        to bottom right) cross without connecting
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, lines joined by newlines, with trailing
        whitespace removed from every line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)

    # bucket the operations by diagram row, rejecting anything undrawable
    grouped = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in grouped[op.row], f"duplicate op: {op}"
        grouped[op.row].add(op)

    def new_grid_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # the grid starts with one row that only holds the vertical lines
    grid = [new_grid_row()]

    for row_key in sorted(grouped):
        row_ops = grouped[row_key]
        # the longest diagonal in this diagram row decides how many grid
        # rows have to be appended for it
        tallest = max(op.rhs - op.lhs for op in row_ops)
        for _ in range(tallest):
            grid.append(new_grid_row())
        bottom = len(grid) - 1
        for op in row_ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            grid[bottom][op.out].plus = True
            # walk up and to the left from the output, marking the
            # diagonal, and finish with the branch-off point
            x, y = op.out - 1, bottom - 1
            while x > op.lhs:
                grid[y][x].slant = True
                x, y = x - 1, y - 1
            grid[y][x].tee = True

    # every grid row is rendered as `height` text lines; a cell is
    # `height` characters wide with its vertical line in the middle
    height = 2 * padding + 1
    blank_pad = [sp] * padding
    rendered = []
    for grid_row in grid:
        buffers = [[] for _ in range(height)]
        for cell in grid_row:
            # a diagonal arrives from the top left / leaves to the
            # bottom right
            from_above = cell.plus or cell.slant
            to_below = cell.tee or cell.slant

            # text lines above the center
            for y in range(padding):
                buf = buffers[y]
                for x in range(padding):
                    buf.append(slant if from_above and x == y else sp)
                buf.append(vbar)
                buf.extend(blank_pad)

            # center text line
            if cell.plus:
                glyph = plus
            elif cell.tee:
                glyph = connect
            elif cell.slant:
                glyph = no_connect
            else:
                glyph = vbar
            middle = buffers[padding]
            middle.extend(blank_pad)
            middle.append(glyph)
            middle.extend(blank_pad)

            # text lines below the center
            for y in range(padding + 1, height):
                buf = buffers[y]
                buf.extend(blank_pad)
                buf.append(vbar)
                for x in range(padding + 1, height):
                    buf.append(slant if to_below and x == y else sp)
        rendered.extend("".join(buf) for buf in buffers)

    return "\n".join(line.rstrip() for line in rendered)
220
221
if __name__ == "__main__":
    # Demo: print the 16-item diagram for both algorithms, each preceded
    # by a caption and a link to the matching wikipedia diagram.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n\n")
    diagram = render_prefix_sum_diagram(16, work_efficient=False)
    # the diagram is followed by two blank lines
    print(diagram, end="\n\n\n")

    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    diagram = render_prefix_sum_diagram(16, work_efficient=True)
    print(diagram, end="\n\n\n")