7f1c45a6d0503f5fd012248ce9ecde72f1f3be07
[nmutil.git] / src / nmutil / test / test_prefix_sum.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from functools import reduce
8 from nmutil.formaltest import FHDLTestCase
9 from nmutil.sim_util import write_il
10 from itertools import accumulate
11 import operator
12 from nmutil.prefix_sum import (Op, pop_count, prefix_sum,
13 render_prefix_sum_diagram,
14 tree_reduction, tree_reduction_ops)
15 import unittest
16 from nmigen.hdl.ast import Signal, AnyConst, Assert
17 from nmigen.hdl.dsl import Module
18
19
def reference_prefix_sum(items, fn):
    """Sequential reference for the inclusive prefix sum of ``items``.

    Entry ``i`` of the returned list is ``items[0]`` folded left-to-right
    with ``fn`` through ``items[i]``; the running total is always passed
    as the left operand, so non-commutative ``fn`` keeps input order.
    An empty input yields an empty list.
    """
    running = []
    for item in items:
        if running:
            running.append(fn(running[-1], item))
        else:
            # the first element is taken as-is, without calling fn
            running.append(item)
    return running
22
23
class TestPrefixSum(FHDLTestCase):
    """Tests for ``nmutil.prefix_sum``.

    Covers ``prefix_sum`` in both variants, ``tree_reduction`` and its op
    schedule, ``pop_count`` (exhaustively on ints and formally on nmigen
    Signals), and ``render_prefix_sum_diagram``.
    """

    # show complete diffs; the expected diagrams below are long strings
    maxDiff = None

    def test_prefix_sum_str(self):
        """Both prefix-sum variants must match the sequential reference.

        The inputs are strings and the reference concatenates them with
        ``operator.add``, which is order-sensitive, so operand order is
        checked too.
        """
        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
        expected = reference_prefix_sum(input_items, operator.add)
        for work_efficient in (False, True):
            # include work_efficient so a failure identifies the variant
            with self.subTest(work_efficient=work_efficient,
                              expected=repr(expected)):
                actual = prefix_sum(input_items,
                                    work_efficient=work_efficient)
                self.assertEqual(expected, actual)

    def test_tree_reduction_str(self):
        """``tree_reduction`` must match a left-to-right ``reduce``."""
        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
        expected = reduce(operator.add, input_items)
        with self.subTest(expected=repr(expected)):
            actual = tree_reduction(input_items)
            self.assertEqual(expected, actual)

    def test_tree_reduction_ops_9(self):
        """Exact op schedule emitted by ``tree_reduction_ops`` for 9 items.

        With an odd count, index 0 only appears in the final op (row 3).
        """
        ops = list(tree_reduction_ops(9))
        self.assertEqual(ops, [
            Op(out=8, lhs=7, rhs=8, row=0),
            Op(out=6, lhs=5, rhs=6, row=0),
            Op(out=4, lhs=3, rhs=4, row=0),
            Op(out=2, lhs=1, rhs=2, row=0),
            Op(out=8, lhs=6, rhs=8, row=1),
            Op(out=4, lhs=2, rhs=4, row=1),
            Op(out=8, lhs=4, rhs=8, row=2),
            Op(out=8, lhs=0, rhs=8, row=3),
        ])

    def test_tree_reduction_ops_8(self):
        """Exact op schedule emitted by ``tree_reduction_ops`` for 8 items."""
        ops = list(tree_reduction_ops(8))
        self.assertEqual(ops, [
            Op(out=7, lhs=6, rhs=7, row=0),
            Op(out=5, lhs=4, rhs=5, row=0),
            Op(out=3, lhs=2, rhs=3, row=0),
            Op(out=1, lhs=0, rhs=1, row=0),
            Op(out=7, lhs=5, rhs=7, row=1),
            Op(out=3, lhs=1, rhs=3, row=1),
            Op(out=7, lhs=3, rhs=7, row=2),
        ])

    def tst_pop_count_int(self, width):
        """Exhaustively compare ``pop_count`` on every ``width``-bit int.

        Named ``tst_`` rather than ``test_`` so unittest does not collect
        it directly; the ``test_pop_count_int_*`` wrappers call it.
        """
        assert isinstance(width, int)
        for v in range(1 << width):
            expected = f"{v:b}".count("1")
            with self.subTest(v=v, expected=expected):
                self.assertEqual(expected, pop_count(v, width=width))

    def test_pop_count_int_0(self):
        self.tst_pop_count_int(0)

    def test_pop_count_int_1(self):
        self.tst_pop_count_int(1)

    def test_pop_count_int_2(self):
        self.tst_pop_count_int(2)

    def test_pop_count_int_3(self):
        self.tst_pop_count_int(3)

    def test_pop_count_int_4(self):
        self.tst_pop_count_int(4)

    def test_pop_count_int_5(self):
        self.tst_pop_count_int(5)

    def test_pop_count_int_6(self):
        self.tst_pop_count_int(6)

    def test_pop_count_int_7(self):
        self.tst_pop_count_int(7)

    def test_pop_count_int_8(self):
        self.tst_pop_count_int(8)

    def test_pop_count_int_9(self):
        self.tst_pop_count_int(9)

    def test_pop_count_int_10(self):
        self.tst_pop_count_int(10)

    def tst_pop_count_formal(self, width):
        """Formally check ``pop_count`` on a ``width``-bit nmigen Signal.

        ``v`` is driven by an ``AnyConst`` and the circuit output is
        asserted equal to the plain sum of ``v``'s bits, then the design
        is handed to ``assertFormal``. Called by the
        ``test_pop_count_formal_*`` wrappers.
        """
        assert isinstance(width, int)
        m = Module()
        v = Signal(width)
        out = Signal(16)

        def process_temporary(v):
            # copy each intermediate value into a fresh Signal driven
            # combinationally, instead of keeping it as a bare expression
            sig = Signal.like(v)
            m.d.comb += sig.eq(v)
            return sig

        m.d.comb += out.eq(pop_count(v, process_temporary=process_temporary))
        # write_il runs before v is tied to AnyConst and before the
        # Assert is added, so those are not part of what it writes
        write_il(self, m, [v, out])
        m.d.comb += v.eq(AnyConst(width))
        expected = Signal(16)
        m.d.comb += expected.eq(reduce(operator.add,
                                       (v[i] for i in range(width)),
                                       0))
        m.d.comb += Assert(out == expected)
        self.assertFormal(m)

    def test_pop_count_formal_0(self):
        self.tst_pop_count_formal(0)

    def test_pop_count_formal_1(self):
        self.tst_pop_count_formal(1)

    def test_pop_count_formal_2(self):
        self.tst_pop_count_formal(2)

    def test_pop_count_formal_3(self):
        self.tst_pop_count_formal(3)

    def test_pop_count_formal_4(self):
        self.tst_pop_count_formal(4)

    def test_pop_count_formal_5(self):
        self.tst_pop_count_formal(5)

    def test_pop_count_formal_6(self):
        self.tst_pop_count_formal(6)

    def test_pop_count_formal_7(self):
        self.tst_pop_count_formal(7)

    def test_pop_count_formal_8(self):
        self.tst_pop_count_formal(8)

    def test_pop_count_formal_9(self):
        self.tst_pop_count_formal(9)

    def test_pop_count_formal_10(self):
        self.tst_pop_count_formal(10)

    def _assert_diagram(self, expected, text):
        """Assert a rendered diagram equals a golden raw-string literal.

        ``expected`` is the literal as written in the test, starting and
        ending with a newline; those two characters are trimmed first.
        On mismatch the actual rendering is printed before failing, so
        the raw ASCII art can be inspected.
        """
        expected = expected[1:-1]  # trim newline at start and end
        if text != expected:
            print("text:")
            print(text)
            print()
        self.assertEqual(expected, text)

    # NOTE(review): in this copy of the file the golden diagram rows have
    # differing lengths (column spacing appears collapsed); confirm them
    # against render_prefix_sum_diagram's actual output.
    def test_render_work_efficient(self):
        """Golden diagram of the work-efficient network for 16 items."""
        text = render_prefix_sum_diagram(16, work_efficient=True, plus="@")
        expected = r"""
| | | | | | | | | | | | | | | |
● | ● | ● | ● | ● | ● | ● | ● |
|\ | |\ | |\ | |\ | |\ | |\ | |\ | |\ |
| \| | \| | \| | \| | \| | \| | \| | \|
| @ | @ | @ | @ | @ | @ | @ | @
| |\ | | | |\ | | | |\ | | | |\ | |
| | \| | | | \| | | | \| | | | \| |
| | X | | | X | | | X | | | X |
| | |\ | | | |\ | | | |\ | | | |\ |
| | | \| | | | \| | | | \| | | | \|
| | | @ | | | @ | | | @ | | | @
| | | |\ | | | | | | | |\ | | | |
| | | | \| | | | | | | | \| | | |
| | | | X | | | | | | | X | | |
| | | | |\ | | | | | | | |\ | | |
| | | | | \| | | | | | | | \| | |
| | | | | X | | | | | | | X | |
| | | | | |\ | | | | | | | |\ | |
| | | | | | \| | | | | | | | \| |
| | | | | | X | | | | | | | X |
| | | | | | |\ | | | | | | | |\ |
| | | | | | | \| | | | | | | | \|
| | | | | | | @ | | | | | | | @
| | | | | | | |\ | | | | | | | |
| | | | | | | | \| | | | | | | |
| | | | | | | | X | | | | | | |
| | | | | | | | |\ | | | | | | |
| | | | | | | | | \| | | | | | |
| | | | | | | | | X | | | | | |
| | | | | | | | | |\ | | | | | |
| | | | | | | | | | \| | | | | |
| | | | | | | | | | X | | | | |
| | | | | | | | | | |\ | | | | |
| | | | | | | | | | | \| | | | |
| | | | | | | | | | | X | | | |
| | | | | | | | | | | |\ | | | |
| | | | | | | | | | | | \| | | |
| | | | | | | | | | | | X | | |
| | | | | | | | | | | | |\ | | |
| | | | | | | | | | | | | \| | |
| | | | | | | | | | | | | X | |
| | | | | | | | | | | | | |\ | |
| | | | | | | | | | | | | | \| |
| | | | | | | | | | | | | | X |
| | | | | | | | | | | | | | |\ |
| | | | | | | | | | | | | | | \|
| | | | | | | ● | | | | | | | @
| | | | | | | |\ | | | | | | | |
| | | | | | | | \| | | | | | | |
| | | | | | | | X | | | | | | |
| | | | | | | | |\ | | | | | | |
| | | | | | | | | \| | | | | | |
| | | | | | | | | X | | | | | |
| | | | | | | | | |\ | | | | | |
| | | | | | | | | | \| | | | | |
| | | | | | | | | | X | | | | |
| | | | | | | | | | |\ | | | | |
| | | | | | | | | | | \| | | | |
| | | ● | | | ● | | | @ | | | |
| | | |\ | | | |\ | | | |\ | | | |
| | | | \| | | | \| | | | \| | | |
| | | | X | | | X | | | X | | |
| | | | |\ | | | |\ | | | |\ | | |
| | | | | \| | | | \| | | | \| | |
| ● | ● | @ | ● | @ | ● | @ | |
| |\ | |\ | |\ | |\ | |\ | |\ | |\ | |
| | \| | \| | \| | \| | \| | \| | \| |
| | @ | @ | @ | @ | @ | @ | @ |
| | | | | | | | | | | | | | | |
"""
        self._assert_diagram(expected, text)

    def test_render_not_work_efficient(self):
        """Golden diagram of the non-work-efficient network for 16 items."""
        text = render_prefix_sum_diagram(16, work_efficient=False, plus="@")
        expected = r"""
| | | | | | | | | | | | | | | |
● ● ● ● ● ● ● ● ● ● ● ● ● ● ● |
|\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
| \| \| \| \| \| \| \| \| \| \| \| \| \| \| \|
● @ @ @ @ @ @ @ @ @ @ @ @ @ @ @
|\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | |
| \| \| \| \| \| \| \| \| \| \| \| \| \| \| |
| X X X X X X X X X X X X X X |
| |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
| | \| \| \| \| \| \| \| \| \| \| \| \| \| \|
● ● @ @ @ @ @ @ @ @ @ @ @ @ @ @
|\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | | | |
| \| \| \| \| \| \| \| \| \| \| \| \| | | |
| X X X X X X X X X X X X | | |
| |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | | |
| | \| \| \| \| \| \| \| \| \| \| \| \| | |
| | X X X X X X X X X X X X | |
| | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ | |
| | | \| \| \| \| \| \| \| \| \| \| \| \| |
| | | X X X X X X X X X X X X |
| | | |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |\ |
| | | | \| \| \| \| \| \| \| \| \| \| \| \|
● ● ● ● @ @ @ @ @ @ @ @ @ @ @ @
|\ |\ |\ |\ |\ |\ |\ |\ | | | | | | | |
| \| \| \| \| \| \| \| \| | | | | | | |
| X X X X X X X X | | | | | | |
| |\ |\ |\ |\ |\ |\ |\ |\ | | | | | | |
| | \| \| \| \| \| \| \| \| | | | | | |
| | X X X X X X X X | | | | | |
| | |\ |\ |\ |\ |\ |\ |\ |\ | | | | | |
| | | \| \| \| \| \| \| \| \| | | | | |
| | | X X X X X X X X | | | | |
| | | |\ |\ |\ |\ |\ |\ |\ |\ | | | | |
| | | | \| \| \| \| \| \| \| \| | | | |
| | | | X X X X X X X X | | | |
| | | | |\ |\ |\ |\ |\ |\ |\ |\ | | | |
| | | | | \| \| \| \| \| \| \| \| | | |
| | | | | X X X X X X X X | | |
| | | | | |\ |\ |\ |\ |\ |\ |\ |\ | | |
| | | | | | \| \| \| \| \| \| \| \| | |
| | | | | | X X X X X X X X | |
| | | | | | |\ |\ |\ |\ |\ |\ |\ |\ | |
| | | | | | | \| \| \| \| \| \| \| \| |
| | | | | | | X X X X X X X X |
| | | | | | | |\ |\ |\ |\ |\ |\ |\ |\ |
| | | | | | | | \| \| \| \| \| \| \| \|
| | | | | | | | @ @ @ @ @ @ @ @
| | | | | | | | | | | | | | | |
"""
        self._assert_diagram(expected, text)

    # TODO: add more tests
303
304
if __name__ == "__main__":
    # Script entry point: run every TestCase defined in this module.
    # module="__main__" is unittest.main's default, spelled out here.
    unittest.main(module="__main__")