1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
import operator
from collections import defaultdict
from dataclasses import dataclass
@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.

    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances are immutable (`frozen=True`), hashable (`unsafe_hash=True`)
    and ordered field-by-field in declaration order (`order=True`).

    Attributes:
    out: int
        index of the item to output to
    lhs: int
        index of the item the left-hand-side input comes from
    rhs: int
        index of the item the right-hand-side input comes from
    row: int
        row in the prefix-sum diagram
    """

    out: int
    lhs: int
    rhs: int
    row: int
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations. This is a generator, so nothing
        (including the type check on `item_count`) runs until it is iterated.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # walk from the highest index down: every op reads `i - dist`, which
        # is below `i`, so when the ops are applied in place in the yielded
        # order each one still sees the value from the previous row.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist *= 2
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        # `dist` overshot `item_count` above, so step back down one level
        # at a time until every remaining index has been filled in.
        dist //= 2
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist //= 2
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items. Any iterable is accepted; it is copied, never mutated.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    # work on a private list: the ops are applied in place, and the input
    # may be a generator or a sequence the caller still needs unchanged.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Every item is drawn as a vertical line running down the diagram. Each
    operation is drawn as a diagonal that leaves the `lhs` column at a
    `connect` cell, crosses intermediate columns at `no_connect` cells and
    ends at a `plus` cell in the `out` column.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, lines joined by newlines with trailing whitespace
        stripped from every line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)

    # group the operations by diagram row, rejecting anything undrawable.
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    # Cell flags are kept as sets of (y, x) grid coordinates so that the
    # function does not depend on any helper type defined elsewhere.
    plus_at = set()   # cells where a diagonal ends in the operation
    slant_at = set()  # cells a diagonal merely passes through
    tee_at = set()    # cells where a diagonal branches off a vertical line
    grid_rows = 1     # the diagram always starts with one blank row

    for row in sorted(ops_by_row):
        ops = ops_by_row[row]
        # a diagonal advances one column per grid row, so the widest op in
        # this diagram row dictates how many grid rows are needed for it.
        grid_rows += max(op.rhs - op.lhs for op in ops)
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, \
                f"can't draw op: {op}"
            # start at the operation itself and walk up-left to the source.
            y = grid_rows - 1
            x = op.out
            plus_at.add((y, x))
            x -= 1
            y -= 1
            while op.lhs < x:
                slant_at.add((y, x))
                x -= 1
                y -= 1
            tee_at.add((y, x))

    lines = []
    for cell_y in range(grid_rows):
        # each grid row expands to `padding` text lines above the center
        # line, the center line itself, and `padding` lines below it.
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell_x in range(item_count):
            pos = (cell_y, cell_x)
            is_plus = pos in plus_at
            is_tee = pos in tee_at
            is_slant = pos in slant_at

            # top padding: a diagonal arrives from the top left when this
            # cell is an operation or a crossing.
            comes_in = is_plus or is_slant
            for y in range(padding):
                # top left padding
                for x in range(padding):
                    row_text[y].append(slant if comes_in and x == y else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                row_text[y].extend([sp] * padding)

            # center line; an operation takes priority over a connection,
            # which takes priority over a crossing.
            if is_plus:
                center = plus
            elif is_tee:
                center = connect
            elif is_slant:
                center = no_connect
            else:
                center = vbar
            row_text[padding].extend([sp] * padding)
            row_text[padding].append(center)
            row_text[padding].extend([sp] * padding)

            # bottom padding: a diagonal leaves towards the bottom right
            # when this cell is a connection or a crossing.
            goes_out = is_tee or is_slant
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                row_text[y].extend([sp] * padding)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding
                for x in range(padding + 1, 2 * padding + 1):
                    row_text[y].append(slant if goes_out and x == y else sp)
        lines.extend("".join(line) for line in row_text)

    return "\n".join(map(str.rstrip, lines))
if __name__ == "__main__":
    # Demo: print both diagram variants for 16 items, each preceded by a
    # heading and the URL of the reference picture it is meant to match.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n")
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))
    print()
    print()