# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
from collections import defaultdict
from dataclasses import dataclass
import operator


@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.

    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances are immutable and hashable, and they compare (and sort) as the
    tuple `(out, lhs, rhs, row)`, i.e. in field-declaration order.
    """
    out: int
    """index of the item to output to"""
    lhs: int
    """index of the item the left-hand-side input comes from"""
    rhs: int
    """index of the item the right-hand-side input comes from"""
    row: int
    """row in the prefix-sum diagram"""


def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations, generated lazily. Every op has
        `out == rhs` and `lhs < rhs`. Ops are produced row by row with
        non-decreasing `row`. Within a row they are ordered so that applying
        them one after another gives the same result as applying the whole
        row in parallel.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # Walk indexes from high to low: an op's `lhs` (i - dist) is lower
        # than its `out`, so it is read before any op in this same row can
        # overwrite it when the ops are applied sequentially.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist <<= 1
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        # `dist` is now the first power of two >= item_count; walk back down
        # through the tree levels, halving the distance each row.
        dist >>= 1
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist >>= 1


def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items. The argument itself is never modified.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        a new list of output items, where element `i` is the combination
        (via `fn`, left to right) of input items `0` through `i`.
    """
    # Copy into a fresh list so any iterable (including tuples and
    # generators) is accepted and the caller's sequence is never mutated.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items


@dataclass
class _Cell:
    """One grid position of a rendered prefix-sum diagram (mutable)."""
    # a diagonal line passes straight through this cell
    slant: bool
    # an operation's output is drawn here
    plus: bool
    # an operation's left-hand input leaves the vertical line here
    tee: bool


def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        the rendered diagram. Every grid row becomes `2 * padding + 1` text
        lines, and trailing whitespace is stripped from each line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        # one fresh, undecorated cell per item column
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        max_distance = max(op.rhs - op.lhs for op in ops)
        # add enough grid rows for the longest diagonal in this diagram row
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the output sits on the last grid row, in the output's column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            # walk the diagonal up and to the left back to the lhs column;
            # it never climbs above the previous group's last row because
            # `max_distance` rows were just added.
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding: a diagonal arriving from the top-left feeds a
            # plus or continues through a slant
            for y in range(padding):
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                row_text[y].append(vbar)
                row_text[y].extend([sp] * padding)
            # center: pick the glyph, plus > tee > slant > plain bar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            else:
                center = vbar
            row_text[padding].extend([sp] * padding)
            row_text[padding].append(center)
            row_text[padding].extend([sp] * padding)
            # bottom padding: a diagonal leaves toward the bottom-right from
            # a tee or continues through a slant
            for y in range(padding + 1, 2 * padding + 1):
                row_text[y].extend([sp] * padding)
                row_text[y].append(vbar)
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        lines.extend("".join(line) for line in row_text)

    return "\n".join(map(str.rstrip, lines))


if __name__ == "__main__":
    # Demo: print both algorithms' diagrams for 16 items so they can be
    # compared against the referenced Wikipedia/Wikimedia figures.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))