autoformat all code
[nmutil.git] / src / nmutil / test / test_prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from functools import reduce
8 from nmutil.formaltest import FHDLTestCase
9 from nmutil.sim_util import write_il
10 from itertools import accumulate
11 import operator
12 from nmutil.popcount import pop_count
13 from nmutil.prefix_sum import (Op, prefix_sum,
14 render_prefix_sum_diagram,
15 tree_reduction, tree_reduction_ops)
16 import unittest
17 from nmigen.hdl.ast import Signal, AnyConst, Assert
18 from nmigen.hdl.dsl import Module
19
20
def reference_prefix_sum(items, fn):
    """Return the inclusive running reduction of `items` under `fn`.

    Element `i` of the returned list combines `items[0]` .. `items[i]`,
    left to right, with the two-argument callable `fn`.  An empty input
    gives an empty list.
    """
    running = accumulate(items, fn)
    return [*running]
23
24
25 class TestPrefixSum(FHDLTestCase):
# Tests for prefix_sum, tree_reduction, tree_reduction_ops,
# render_prefix_sum_diagram and pop_count.
# maxDiff = None: unittest's setting for "no limit on the length of the
# diff shown by failing assertions" (assumes FHDLTestCase derives from
# unittest.TestCase -- TODO confirm).
26 maxDiff = None
27
def test_prefix_sum_str(self):
    """Check `prefix_sum` on strings against `reference_prefix_sum`.

    Both the work-efficient and the non-work-efficient variants are run on
    the same nine single-letter strings.  `prefix_sum` is called without
    an explicit combining function; the expected values are built with
    `operator.add` (string concatenation).
    """
    input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
    expected = reference_prefix_sum(input_items, operator.add)
    for work_efficient in (False, True):
        # work_efficient is part of the subtest label so that a failure
        # report says which variant produced the wrong result
        with self.subTest(expected=repr(expected),
                          work_efficient=work_efficient):
            actual = prefix_sum(input_items, work_efficient=work_efficient)
            self.assertEqual(expected, actual)
37
def test_tree_reduction_str(self):
    """Check `tree_reduction` on strings against a plain left fold.

    The expected value is `reduce(operator.add, ...)` over the same nine
    single-letter strings.
    """
    letters = tuple("abcdefghi")
    expected = reduce(operator.add, letters)
    with self.subTest(expected=repr(expected)):
        actual = tree_reduction(letters)
        self.assertEqual(expected, actual)
44
def test_tree_reduction_ops_9(self):
    """Check the exact `Op` sequence from `tree_reduction_ops(9)`."""
    ops = list(tree_reduction_ops(9))
    expected_fields = (
        # (out, lhs, rhs, row)
        (8, 7, 8, 0),
        (6, 5, 6, 0),
        (4, 3, 4, 0),
        (2, 1, 2, 0),
        (8, 6, 8, 1),
        (4, 2, 4, 1),
        (8, 4, 8, 2),
        (8, 0, 8, 3),
    )
    expected_ops = [Op(out=out, lhs=lhs, rhs=rhs, row=row)
                    for out, lhs, rhs, row in expected_fields]
    self.assertEqual(ops, expected_ops)
57
def test_tree_reduction_ops_8(self):
    """Check the exact `Op` sequence from `tree_reduction_ops(8)`."""
    ops = list(tree_reduction_ops(8))
    expected_fields = (
        # (out, lhs, rhs, row)
        (7, 6, 7, 0),
        (5, 4, 5, 0),
        (3, 2, 3, 0),
        (1, 0, 1, 0),
        (7, 5, 7, 1),
        (3, 1, 3, 1),
        (7, 3, 7, 2),
    )
    expected_ops = [Op(out=out, lhs=lhs, rhs=rhs, row=row)
                    for out, lhs, rhs, row in expected_fields]
    self.assertEqual(ops, expected_ops)
69
def tst_pop_count_int(self, width):
    """Compare integer `pop_count` with a string-based bit count.

    Every value in `range(1 << width)` is checked in its own subtest; the
    expected count is the number of "1" characters in the value's binary
    text form.
    """
    assert isinstance(width, int)
    value_count = 1 << width
    for v in range(value_count):
        ones = format(v, "b").count("1")
        with self.subTest(v=v, expected=ones):
            actual = pop_count(v, width=width)
            self.assertEqual(ones, actual)
76
def test_pop_count_int_0(self):
    """Run the integer pop_count check with width 0."""
    self.tst_pop_count_int(width=0)
79
def test_pop_count_int_1(self):
    """Run the integer pop_count check with width 1."""
    self.tst_pop_count_int(width=1)
82
def test_pop_count_int_2(self):
    """Run the integer pop_count check with width 2."""
    self.tst_pop_count_int(width=2)
85
def test_pop_count_int_3(self):
    """Run the integer pop_count check with width 3."""
    self.tst_pop_count_int(width=3)
88
def test_pop_count_int_4(self):
    """Run the integer pop_count check with width 4."""
    self.tst_pop_count_int(width=4)
91
def test_pop_count_int_5(self):
    """Run the integer pop_count check with width 5."""
    self.tst_pop_count_int(width=5)
94
def test_pop_count_int_6(self):
    """Run the integer pop_count check with width 6."""
    self.tst_pop_count_int(width=6)
97
def test_pop_count_int_7(self):
    """Run the integer pop_count check with width 7."""
    self.tst_pop_count_int(width=7)
100
def test_pop_count_int_8(self):
    """Run the integer pop_count check with width 8."""
    self.tst_pop_count_int(width=8)
103
def test_pop_count_int_9(self):
    """Run the integer pop_count check with width 9."""
    self.tst_pop_count_int(width=9)
106
def test_pop_count_int_10(self):
    """Run the integer pop_count check with width 10."""
    self.tst_pop_count_int(width=10)
109
110 def tst_pop_count_formal(self, width):
# Formal check that the HDL form of pop_count(v), for a `width`-bit
# signal v, equals the sum of v's individual bits.
111 assert isinstance(width, int)
112 m = Module()
113 v = Signal(width)
# the pop_count result is held in a 16-bit signal
114 out = Signal(16)
115
# Callback handed to pop_count: copies the value it is given into a new
# Signal created with Signal.like() and returns that Signal.
116 def process_temporary(v):
117 sig = Signal.like(v)
118 m.d.comb += sig.eq(v)
119 return sig
120
121 m.d.comb += out.eq(pop_count(v, process_temporary=process_temporary))
# write_il is called before the AnyConst/Assert statements below are
# added to the module (presumably so the written design contains only
# the pop_count logic -- TODO confirm against nmutil.sim_util.write_il).
122 write_il(self, m, [v, out])
# drive v from AnyConst so the check is not tied to one input value
# (nmigen formal primitive -- see nmigen.hdl.ast.AnyConst)
123 m.d.comb += v.eq(AnyConst(width))
124 expected = Signal(16)
# reference value: fold the bits v[0] .. v[width - 1] with +, starting
# from 0 (the explicit start value keeps width == 0 valid for reduce)
125 m.d.comb += expected.eq(reduce(operator.add,
126 (v[i] for i in range(width)),
127 0))
128 m.d.comb += Assert(out == expected)
129 self.assertFormal(m)
130
def test_pop_count_formal_0(self):
    """Run the formal pop_count check with width 0."""
    self.tst_pop_count_formal(width=0)
133
def test_pop_count_formal_1(self):
    """Run the formal pop_count check with width 1."""
    self.tst_pop_count_formal(width=1)
136
def test_pop_count_formal_2(self):
    """Run the formal pop_count check with width 2."""
    self.tst_pop_count_formal(width=2)
139
def test_pop_count_formal_3(self):
    """Run the formal pop_count check with width 3."""
    self.tst_pop_count_formal(width=3)
142
def test_pop_count_formal_4(self):
    """Run the formal pop_count check with width 4."""
    self.tst_pop_count_formal(width=4)
145
def test_pop_count_formal_5(self):
    """Run the formal pop_count check with width 5."""
    self.tst_pop_count_formal(width=5)
148
def test_pop_count_formal_6(self):
    """Run the formal pop_count check with width 6."""
    self.tst_pop_count_formal(width=6)
151
def test_pop_count_formal_7(self):
    """Run the formal pop_count check with width 7."""
    self.tst_pop_count_formal(width=7)
154
def test_pop_count_formal_8(self):
    """Run the formal pop_count check with width 8."""
    self.tst_pop_count_formal(width=8)
157
def test_pop_count_formal_9(self):
    """Run the formal pop_count check with width 9."""
    self.tst_pop_count_formal(width=9)
160
def test_pop_count_formal_10(self):
    """Run the formal pop_count check with width 10."""
    self.tst_pop_count_formal(width=10)
163
164 def test_render_work_efficient(self):
# Render the diagram for 16 items with work_efficient=True and compare it,
# character for character, with the literal below.  plus="@" is passed to
# the renderer; "@" is the marker that appears in the expected text.
165 text = render_prefix_sum_diagram(16, work_efficient=True, plus="@")
# NOTE(review): the spacing inside this raw string is part of the expected
# value -- do not reflow or re-indent it.
166 expected = r"""
167 | | | | | | | | | | | | | | | |
168 ● | ● | ● | ● | ● | ● | ● | ● |
169 |\ | |\ | |\ | |\ | |\ | |\ | |\ | |\ |
170 | \| | \| | \| | \| | \| | \| | \| | \|
171 | @ | @ | @ | @ | @ | @ | @ | @
172 | |\ | | | |\ | | | |\ | | | |\ | |
173 | | \| | | | \| | | | \| | | | \| |
174 | | X | | | X | | | X | | | X |
175 | | |\ | | | |\ | | | |\ | | | |\ |
176 | | | \| | | | \| | | | \| | | | \|
177 | | | @ | | | @ | | | @ | | | @
178 | | | |\ | | | | | | | |\ | | | |
179 | | | | \| | | | | | | | \| | | |
180 | | | | X | | | | | | | X | | |
181 | | | | |\ | | | | | | | |\ | | |
182 | | | | | \| | | | | | | | \| | |
183 | | | | | X | | | | | | | X | |
184 | | | | | |\ | | | | | | | |\ | |
185 | | | | | | \| | | | | | | | \| |
186 | | | | | | X | | | | | | | X |
187 | | | | | | |\ | | | | | | | |\ |
188 | | | | | | | \| | | | | | | | \|
189 | | | | | | | @ | | | | | | | @
190 | | | | | | | |\ | | | | | | | |
191 | | | | | | | | \| | | | | | | |
192 | | | | | | | | X | | | | | | |
193 | | | | | | | | |\ | | | | | | |
194 | | | | | | | | | \| | | | | | |
195 | | | | | | | | | X | | | | | |
196 | | | | | | | | | |\ | | | | | |
197 | | | | | | | | | | \| | | | | |
198 | | | | | | | | | | X | | | | |
199 | | | | | | | | | | |\ | | | | |
200 | | | | | | | | | | | \| | | | |
201 | | | | | | | | | | | X | | | |
202 | | | | | | | | | | | |\ | | | |
203 | | | | | | | | | | | | \| | | |
204 | | | | | | | | | | | | X | | |
205 | | | | | | | | | | | | |\ | | |
206 | | | | | | | | | | | | | \| | |
207 | | | | | | | | | | | | | X | |
208 | | | | | | | | | | | | | |\ | |
209 | | | | | | | | | | | | | | \| |
210 | | | | | | | | | | | | | | X |
211 | | | | | | | | | | | | | | |\ |
212 | | | | | | | | | | | | | | | \|
213 | | | | | | | ● | | | | | | | @
214 | | | | | | | |\ | | | | | | | |
215 | | | | | | | | \| | | | | | | |
216 | | | | | | | | X | | | | | | |
217 | | | | | | | | |\ | | | | | | |
218 | | | | | | | | | \| | | | | | |
219 | | | | | | | | | X | | | | | |
220 | | | | | | | | | |\ | | | | | |
221 | | | | | | | | | | \| | | | | |
222 | | | | | | | | | | X | | | | |
223 | | | | | | | | | | |\ | | | | |
224 | | | | | | | | | | | \| | | | |
225 | | | ● | | | ● | | | @ | | | |
226 | | | |\ | | | |\ | | | |\ | | | |
227 | | | | \| | | | \| | | | \| | | |
228 | | | | X | | | X | | | X | | |
229 | | | | |\ | | | |\ | | | |\ | | |
230 | | | | | \| | | | \| | | | \| | |
231 | ● | ● | @ | ● | @ | ● | @ | |
232 | |\ | |\ | |\ | |\ | |\ | |\ | |\ | |
233 | | \| | \| | \| | \| | \| | \| | \| |
234 | | @ | @ | @ | @ | @ | @ | @ |
235 | | | | | | | | | | | | | | | |
236 """
237 expected = expected[1:-1] # trim newline at start and end
# on a mismatch, print the actual rendering to stdout before the assertion
# below fails
238 if text != expected:
239 print("text:")
240 print(text)
241 print()
242 self.assertEqual(expected, text)
243
244 def test_render_not_work_efficient(self):
# Render the diagram for 16 items with work_efficient=False and compare it,
# character for character, with the literal below.  plus="@" is passed to
# the renderer; "@" is the marker that appears in the expected text.
245 text = render_prefix_sum_diagram(16, work_efficient=False, plus="@")
# NOTE(review): the spacing inside this raw string is part of the expected
# value -- do not reflow or re-indent it.
246 expected = r"""
247 | | | | | | | | | | | | | | | |
248 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ● |
249 |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
250 | \| \| \| \| \| \| \| \| \| \| \| \| \| \| \|
251 ● @ @ @ @ @ @ @ @ @ @ @ @ @ @ @
252 |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | |
253 | \| \| \| \| \| \| \| \| \| \| \| \| \| \| |
254 | X X X X X X X X X X X X X X |
255 | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
256 | | \| \| \| \| \| \| \| \| \| \| \| \| \| \|
257 ● ● @ @ @ @ @ @ @ @ @ @ @ @ @ @
258 |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | | | |
259 | \| \| \| \| \| \| \| \| \| \| \| \| | | |
260 | X X X X X X X X X X X X | | |
261 | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | | |
262 | | \| \| \| \| \| \| \| \| \| \| \| \| | |
263 | | X X X X X X X X X X X X | |
264 | | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | |
265 | | | \| \| \| \| \| \| \| \| \| \| \| \| |
266 | | | X X X X X X X X X X X X |
267 | | | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
268 | | | | \| \| \| \| \| \| \| \| \| \| \| \|
269 ● ● ● ● @ @ @ @ @ @ @ @ @ @ @ @
270 |\ |\ |\ |\ |\ |\ |\ |\ | | | | | | | |
271 | \| \| \| \| \| \| \| \| | | | | | | |
272 | X X X X X X X X | | | | | | |
273 | |\ |\ |\ |\ |\ |\ |\ |\ | | | | | | |
274 | | \| \| \| \| \| \| \| \| | | | | | |
275 | | X X X X X X X X | | | | | |
276 | | |\ |\ |\ |\ |\ |\ |\ |\ | | | | | |
277 | | | \| \| \| \| \| \| \| \| | | | | |
278 | | | X X X X X X X X | | | | |
279 | | | |\ |\ |\ |\ |\ |\ |\ |\ | | | | |
280 | | | | \| \| \| \| \| \| \| \| | | | |
281 | | | | X X X X X X X X | | | |
282 | | | | |\ |\ |\ |\ |\ |\ |\ |\ | | | |
283 | | | | | \| \| \| \| \| \| \| \| | | |
284 | | | | | X X X X X X X X | | |
285 | | | | | |\ |\ |\ |\ |\ |\ |\ |\ | | |
286 | | | | | | \| \| \| \| \| \| \| \| | |
287 | | | | | | X X X X X X X X | |
288 | | | | | | |\ |\ |\ |\ |\ |\ |\ |\ | |
289 | | | | | | | \| \| \| \| \| \| \| \| |
290 | | | | | | | X X X X X X X X |
291 | | | | | | | |\ |\ |\ |\ |\ |\ |\ |\ |
292 | | | | | | | | \| \| \| \| \| \| \| \|
293 | | | | | | | | @ @ @ @ @ @ @ @
294 | | | | | | | | | | | | | | | |
295 """
296 expected = expected[1:-1] # trim newline at start and end
# on a mismatch, print the actual rendering to stdout before the assertion
# below fails
297 if text != expected:
298 print("text:")
299 print(text)
300 print()
301 self.assertEqual(expected, text)
302
303 # TODO: add more tests
304
305
# Run this module's tests via unittest when the file is executed directly.
306 if __name__ == "__main__":
307 unittest.main()