34bfefa0d77c2c967eb36a21fd144596d41c2e56
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
from collections import defaultdict
from dataclasses import dataclass
import operator
@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances are immutable, hashable and ordered (field by field, in the
    declaration order below), so they can be stored in sets and sorted.
    """

    # the index of the item to output to
    out: int
    # the index of the item the left-hand-side input comes from
    lhs: int
    # the index of the item the right-hand-side input comes from
    rhs: int
    # the row in the prefix-sum diagram
    row: int
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations, yielded in non-decreasing `row` order.
        Every op has `out == rhs` and `lhs < rhs`. Applying them one at a
        time, in the yielded order, computes the prefix-sum in place.
        A row may end up with no ops at all (e.g. the first step of the
        work-efficient down-sweep never emits any), so the `row` values that
        appear are not guaranteed to be contiguous.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # walk the indices from high to low so that, when the ops are applied
        # sequentially, `items[i - dist]` is read before any op in this same
        # row overwrites it.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist <<= 1
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        # `dist` is now the first power of two >= item_count; step back down
        # through the same tree levels in reverse.
        dist >>= 1
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist >>= 1
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        a new list whose element `i` combines `items[0]` through `items[i]`
        with `fn`, keeping their order (for an associative `fn` this equals
        `functools.reduce(fn, items[:i + 1])`). The input is not modified.
    """
    # Copy into a fresh list: accepts any iterable, provides `len()`, and
    # keeps the in-place updates below from touching the caller's sequence.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
@dataclass
class _Cell:
    """One grid cell of the diagram drawn by `render_prefix_sum_diagram`.

    Every cell has a vertical line through it; the flags say what else is
    drawn there.
    """

    # a diagonal line crosses the vertical line here without connecting
    slant: bool
    # an op's result is produced here; its diagonal arrives from the top-left
    plus: bool
    # an op's diagonal leaves the vertical line here towards the bottom-right
    tee: bool


def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        the rendered diagram, one text line per output line, with trailing
        whitespace stripped from every line.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        # the drawing only supports ops that update their right-hand input
        # from an item further to the left.
        assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    # grid of cells; the first row is the blank row holding the inputs
    cells = [blank_row()]

    for row in sorted(ops_by_row):
        ops = ops_by_row[row]
        # a diagonal spanning `rhs - lhs` columns needs that many grid rows,
        # so make room for the longest one; every plus lands on the last row.
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            # mark the output cell on the bottom row of this band ...
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            # ... then walk diagonally up and to the left to the lhs column
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    side = 2 * padding + 1  # width and height of one cell in characters
    for cells_row in cells:
        row_text = [[] for _ in range(side)]
        for cell in cells_row:
            # top padding: a diagonal (if any) enters from the top-left corner
            for y in range(padding):
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                row_text[y].append(vbar)
                row_text[y].extend([sp] * padding)
            # center line: a plus wins over a tee (a cell can be both)
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            else:
                center = vbar
            row_text[padding].extend([sp] * padding)
            row_text[padding].append(center)
            row_text[padding].extend([sp] * padding)
            # bottom padding: a diagonal (if any) leaves to the bottom-right
            for y in range(padding + 1, side):
                row_text[y].extend([sp] * padding)
                row_text[y].append(vbar)
                for x in range(padding + 1, side):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        lines.extend("".join(line) for line in row_text)

    return "\n".join(map(str.rstrip, lines))
if __name__ == "__main__":
    # Demo: print both diagrams so they can be compared with the Wikipedia
    # figures referenced below.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print(render_prefix_sum_diagram(16, work_efficient=True))