add grev test and formal proof
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 17 Dec 2021 03:12:54 +0000 (19:12 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 17 Dec 2021 03:12:54 +0000 (19:12 -0800)
src/nmutil/test/test_grev.py [new file with mode: 0644]

diff --git a/src/nmutil/test/test_grev.py b/src/nmutil/test/test_grev.py
new file mode 100644 (file)
index 0000000..18e1917
--- /dev/null
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# See Notices.txt for copyright information
+
+import unittest
+from nmigen.hdl.ast import AnyConst, Assert
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.grev import GRev, grev
+from nmigen.sim import Delay
+from nmutil.sim_util import do_sim, hash_256
+
+
+class TestGrev(FHDLTestCase):
+    def test(self):
+        log2_width = 6
+        width = 2 ** log2_width
+        dut = GRev(log2_width)
+        self.assertEqual(width, dut.width)
+        self.assertEqual(len(dut._steps), log2_width + 1)
+
+        def case(input, chunk_sizes):
+            expected = grev(input, chunk_sizes, log2_width)
+            with self.subTest(input=hex(input), chunk_sizes=bin(chunk_sizes),
+                              expected=hex(expected)):
+                yield dut.input.eq(input)
+                yield dut.chunk_sizes.eq(chunk_sizes)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(expected, output)
+                for i, step in enumerate(dut._steps):
+                    cur_chunk_sizes = chunk_sizes & (2 ** i - 1)
+                    step_expected = grev(input, cur_chunk_sizes, log2_width)
+                    step = yield step
+                    with self.subTest(i=i, step=hex(step),
+                                      cur_chunk_sizes=bin(cur_chunk_sizes),
+                                      step_expected=hex(step_expected)):
+                        self.assertEqual(step, step_expected)
+
+        def process():
+            for count in range(width + 1):
+                input = (1 << count) - 1
+                for chunk_sizes in range(2 ** log2_width):
+                    yield from case(input, chunk_sizes)
+            for i in range(100):
+                input = hash_256(f"grev input {i}")
+                input &= 2 ** width - 1
+                chunk_sizes = hash_256(f"grev 2 {i}")
+                chunk_sizes &= 2 ** log2_width - 1
+                yield from case(input, chunk_sizes)
+        with do_sim(self, dut, [dut.input, dut.chunk_sizes,
+                                *dut._steps, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test_formal(self):
+        log2_width = 4
+        dut = GRev(log2_width)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
+        m.d.comb += dut.chunk_sizes.eq(AnyConst(log2_width))
+        m.d.comb += Assert(dut.output == grev(dut.input,
+                                              dut.chunk_sizes, log2_width))
+        for i, step in enumerate(dut._steps):
+            cur_chunk_sizes = dut.chunk_sizes & (2 ** i - 1)
+            step_expected = grev(dut.input, cur_chunk_sizes, log2_width)
+            m.d.comb += Assert(step == step_expected)
+        self.assertFormal(m)
+
+
+if __name__ == "__main__":
+    unittest.main()