add NLnet Grant References
[nmutil.git] / src / nmutil / test / test_grev.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # XXX - this is insufficient See Notices.txt for copyright information - XXX
3 # XXX TODO: add individual copyright
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 import unittest
9 from nmigen.hdl.ast import AnyConst, Assert
10 from nmigen.hdl.dsl import Module
11 from nmutil.formaltest import FHDLTestCase
12 from nmutil.grev import GRev, grev
13 from nmigen.sim import Delay
14 from nmutil.sim_util import do_sim, hash_256
15
16
class TestGrev(FHDLTestCase):
    """Checks of the ``GRev`` module against the ``grev`` reference function.

    ``test`` drives the module in simulation and compares its output (and
    each entry of ``dut._steps``) with ``grev``; ``test_formal`` hands the
    module, fed with unconstrained constants, to ``assertFormal``.
    """

    # Simulation test: for every stimulus, the module output must equal
    # grev(input, chunk_sizes, log2_width), and step ``i`` must equal grev
    # applied with only the low ``i`` bits of ``chunk_sizes`` kept.
    def test(self):
        log2_width = 6
        width = 1 << log2_width
        dut = GRev(log2_width)
        self.assertEqual(width, dut.width)

        # Masks used to trim the hashed stimulus down to the port widths.
        value_mask = (1 << width) - 1
        sizes_mask = (1 << log2_width) - 1

        def case(value, sizes):
            """Apply one stimulus pair and check output and all steps."""
            expected = grev(value, sizes, log2_width)
            with self.subTest(input=hex(value), chunk_sizes=bin(sizes),
                              expected=hex(expected)):
                yield dut.input.eq(value)
                yield dut.chunk_sizes.eq(sizes)
                # presumably gives the combinatorial logic time to settle
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)
                for i, step_signal in enumerate(dut._steps):
                    # only the low ``i`` bits of the selection apply so far
                    cur_chunk_sizes = sizes & ((1 << i) - 1)
                    step_expected = grev(value, cur_chunk_sizes, log2_width)
                    step = yield step_signal
                    with self.subTest(i=i, step=hex(step),
                                      cur_chunk_sizes=bin(cur_chunk_sizes),
                                      step_expected=hex(step_expected)):
                        self.assertEqual(step, step_expected)

        def process():
            """Generate the full stimulus sequence for the simulator."""
            self.assertEqual(len(dut._steps), log2_width + 1)
            # Runs of 0..width low one-bits, each against every chunk_sizes.
            for count in range(width + 1):
                low_ones = (1 << count) - 1
                for sizes in range(1 << log2_width):
                    yield from case(low_ones, sizes)
            # 100 hash-derived (deterministic) stimulus pairs.
            for i in range(100):
                value = hash_256(f"grev input {i}") & value_mask
                sizes = hash_256(f"grev 2 {i}") & sizes_mask
                yield from case(value, sizes)

        traced = [dut.input, dut.chunk_sizes, dut.output]
        with do_sim(self, dut, traced) as sim:
            sim.add_process(process)
            sim.run()

    # Formal test: wrap a 2**4-bit wide GRev in a module whose inputs are
    # driven by AnyConst values, then run assertFormal on it.
    def test_formal(self):
        log2_width = 4
        width = 1 << log2_width
        dut = GRev(log2_width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += [
            dut.input.eq(AnyConst(width)),
            dut.chunk_sizes.eq(AnyConst(log2_width)),
        ]
        # actual formal correctness proof is inside the module itself, now
        self.assertFormal(m)
69
70
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()