add copyright notices
[nmutil.git] / src / nmutil / test / test_grev.py
# SPDX-License-Identifier: LGPL-3.0-or-later
# Copyright 2021 Jacob Lifshay

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure, part
# of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import AnyConst, Assert
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmutil.grev import GRev, grev
12 from nmigen.sim import Delay
13 from nmutil.sim_util import do_sim, hash_256
14
15
class TestGrev(FHDLTestCase):
    """Simulation and formal tests for the ``GRev`` module."""

    def test(self):
        """Simulate ``GRev(6)`` and compare its output and steps to ``grev()``.

        Two groups of inputs are driven through the DUT:

        * every low-bit mask ``(1 << count) - 1`` for ``count`` in
          ``0..width`` combined with every ``chunk_sizes`` value, and
        * 100 pseudo-random ``(input, chunk_sizes)`` pairs produced by
          ``hash_256`` from fixed seed strings.

        For each pair the final ``dut.output`` is checked against
        ``grev(inval, chunk_sizes, log2_width)``. Each ``dut._steps[i]`` is
        checked against ``grev`` with only the low ``i`` bits of
        ``chunk_sizes`` applied.
        """
        log2_width = 6
        width = 2 ** log2_width
        dut = GRev(log2_width)
        self.assertEqual(width, dut.width)

        def case(inval, chunk_sizes):
            """Drive one input pair, then check the output and every step."""
            expected = grev(inval, chunk_sizes, log2_width)
            with self.subTest(inval=hex(inval), chunk_sizes=bin(chunk_sizes),
                              expected=hex(expected)):
                yield dut.input.eq(inval)
                yield dut.chunk_sizes.eq(chunk_sizes)
                # advance simulated time so the new inputs propagate
                # before any value is sampled
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)
                for i, step in enumerate(dut._steps):
                    # step i is expected to reflect only the low i bits of
                    # chunk_sizes (step 0 therefore uses no bits at all)
                    cur_chunk_sizes = chunk_sizes & ((1 << i) - 1)
                    step_expected = grev(inval, cur_chunk_sizes, log2_width)
                    # sample into a new name instead of rebinding the loop
                    # variable that holds the step signal
                    step_value = yield step
                    with self.subTest(i=i, step=hex(step_value),
                                      cur_chunk_sizes=bin(cur_chunk_sizes),
                                      step_expected=hex(step_expected)):
                        self.assertEqual(step_expected, step_value)

        def process():
            # checked here rather than before simulation, in case _steps is
            # only populated once the design has been elaborated
            self.assertEqual(len(dut._steps), log2_width + 1)
            # exhaustive over chunk_sizes for every contiguous low-bit mask
            for count in range(width + 1):
                inval = (1 << count) - 1
                for chunk_sizes in range(2 ** log2_width):
                    yield from case(inval, chunk_sizes)
            # pseudo-random pairs, truncated to the DUT's input widths
            for i in range(100):
                inval = hash_256(f"grev input {i}")
                inval &= 2 ** width - 1
                chunk_sizes = hash_256(f"grev 2 {i}")
                chunk_sizes &= 2 ** log2_width - 1
                yield from case(inval, chunk_sizes)

        with do_sim(self, dut, [dut.input, dut.chunk_sizes,
                                dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def test_formal(self):
        """Formally check ``GRev(4)`` with arbitrary constant inputs.

        The correctness assertions live inside ``GRev`` itself. This test
        only drives the DUT's inputs from ``AnyConst`` values and hands the
        wrapping module to ``assertFormal``.
        """
        log2_width = 4
        dut = GRev(log2_width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
        m.d.comb += dut.chunk_sizes.eq(AnyConst(log2_width))
        # actual formal correctness proof is inside the module itself, now
        self.assertFormal(m)
68
69
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()