add tests for non-power-of-2 shifts for MultiShift* classes
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Jul 2022 06:18:20 +0000 (23:18 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Jul 2022 06:18:20 +0000 (23:18 -0700)
src/ieee754/add/test_multishift.py

index 86483453e4f9eda81d668823d2ed601c7627a511..1859b8a76cea620e848e5b9762d2cf12258350b2 100644 (file)
@@ -1,10 +1,11 @@
-from random import randint
-from nmigen import Module, Signal
-from nmigen.compat.sim import run_simulation
-
+from nmigen import Module, Signal, Elaboratable
+from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import do_sim, hash_256
+import unittest
 from ieee754.fpcommon.fpbase import MultiShift, MultiShiftR, MultiShiftRMerge
 
-class MultiShiftModL:
+
+class MultiShiftModL(Elaboratable):
     def __init__(self, width):
         self.ms = MultiShift(width)
         self.a = Signal(width)
@@ -18,7 +19,8 @@ class MultiShiftModL:
 
         return m
 
-class MultiShiftModR:
+
+class MultiShiftModR(Elaboratable):
     def __init__(self, width):
         self.ms = MultiShift(width)
         self.a = Signal(width)
@@ -32,7 +34,8 @@ class MultiShiftModR:
 
         return m
 
-class MultiShiftModRMod:
+
+class MultiShiftModRMod(Elaboratable):
     def __init__(self, width):
         self.ms = MultiShiftR(width)
         self.a = Signal(width)
@@ -49,7 +52,8 @@ class MultiShiftModRMod:
 
         return m
 
-class MultiShiftRMergeMod:
+
+class MultiShiftRMergeMod(Elaboratable):
     def __init__(self, width):
         self.ms = MultiShiftRMerge(width)
         self.a = Signal(width)
@@ -67,68 +71,123 @@ class MultiShiftRMergeMod:
         return m
 
 
-def check_case(dut, width, a, b):
-    yield dut.a.eq(a)
-    yield dut.b.eq(b)
-    yield
+class TestMultiShift(FHDLTestCase):
+    def check_case(self, dut, width, a, b):
+        yield dut.a.eq(a)
+        yield dut.b.eq(b)
+        yield
 
-    x = (a << b) & ((1<<width)-1)
+        x = (a << b) & ((1 << width) - 1)
 
-    out_x = yield dut.x
-    assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
+        out_x = yield dut.x
+        self.assertEqual(
+            out_x, x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x))
 
-def check_caser(dut, width, a, b):
-    yield dut.a.eq(a)
-    yield dut.b.eq(b)
-    yield
+    def check_caser(self, dut, width, a, b):
+        yield dut.a.eq(a)
+        yield dut.b.eq(b)
+        yield
 
-    x = (a >> b) & ((1<<width)-1)
+        x = (a >> b) & ((1 << width) - 1)
 
-    out_x = yield dut.x
-    assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
+        out_x = yield dut.x
+        self.assertEqual(
+            out_x, x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x))
 
+    def check_case_merge(self, dut, width, a, b):
+        yield dut.a.eq(a)
+        yield dut.b.eq(b)
+        yield
 
-def check_case_merge(dut, width, a, b):
-    yield dut.a.eq(a)
-    yield dut.b.eq(b)
-    yield
+        x = (a >> b) & ((1 << width) - 1)  # actual shift
+        if (a & ((2 << b) - 1)) != 0:  # mask for sticky bit
+            x |= 1  # set LSB
 
-    x = (a >> b) & ((1<<width)-1) # actual shift
-    if (a & ((2<<b)-1)) != 0: # mask for sticky bit
-        x |= 1 # set LSB
+        out_x = yield dut.x
+        self.assertEqual(
+            out_x, x, "\nshift %d\nInput\n%+32s\nOutput x\n%+32s != \n%+32s" %
+            (b, bin(a), bin(out_x), bin(x)))
 
-    out_x = yield dut.x
-    assert out_x == x, \
-                "\nshift %d\nInput\n%+32s\nOutput x\n%+32s != \n%+32s" % \
-                        (b, bin(a), bin(out_x), bin(x))
+    def tst_multi_shift_r_merge(self, width):
+        dut = MultiShiftRMergeMod(width=width)
 
-def testmerge(dut):
-    for i in range(32):
-        for j in range(1000):
-            a = randint(0, (1<<32)-1)
-            yield from check_case_merge(dut, 32, a, i)
+        def process():
+            for i in range(width):
+                for j in range(1000):
+                    a = hash_256(f"MultiShiftRMerge {i} {j}") % (1 << width)
+                    yield from self.check_case_merge(dut, width, a, i)
 
-def testbench(dut):
-    for i in range(32):
-        for j in range(1000):
-            a = randint(0, (1<<32)-1)
-            yield from check_case(dut, 32, a, i)
+        with do_sim(self, dut, [dut.a, dut.b, dut.x]) as sim:
+            sim.add_sync_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
 
-def testbenchr(dut):
-    for i in range(32):
-        for j in range(1000):
-            a = randint(0, (1<<32)-1)
-            yield from check_caser(dut, 32, a, i)
+    def test_multi_shift_r_merge_32(self):
+        self.tst_multi_shift_r_merge(32)
 
-if __name__ == '__main__':
-    dut = MultiShiftRMergeMod(width=32)
-    run_simulation(dut, testmerge(dut), vcd_name="test_multishiftmerge.vcd")
-    dut = MultiShiftModRMod(width=32)
-    run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")
+    def test_multi_shift_r_merge_24(self):
+        self.tst_multi_shift_r_merge(24)
+
+    def tst_multi_shift_r_mod(self, width):
+        dut = MultiShiftModRMod(width=width)
+
+        def process():
+            for i in range(width):
+                for j in range(1000):
+                    a = hash_256(f"MultiShiftRMod {i} {j}") % (1 << width)
+                    yield from self.check_caser(dut, width, a, i)
+
+        with do_sim(self, dut, [dut.a, dut.b, dut.x]) as sim:
+            sim.add_sync_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_multi_shift_r_mod_32(self):
+        self.tst_multi_shift_r_mod(32)
+
+    def test_multi_shift_r_mod_24(self):
+        self.tst_multi_shift_r_mod(24)
 
-    dut = MultiShiftModR(width=32)
-    run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")
+    def tst_multi_shift_r(self, width):
+        dut = MultiShiftModR(width=width)
 
-    dut = MultiShiftModL(width=32)
-    run_simulation(dut, testbench(dut), vcd_name="test_multishift.vcd")
+        def process():
+            for i in range(width):
+                for j in range(1000):
+                    a = hash_256(f"MultiShiftModR {i} {j}") % (1 << width)
+                    yield from self.check_caser(dut, width, a, i)
 
+        with do_sim(self, dut, [dut.a, dut.b, dut.x]) as sim:
+            sim.add_sync_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_multi_shift_r_32(self):
+        self.tst_multi_shift_r(32)
+
+    def test_multi_shift_r_24(self):
+        self.tst_multi_shift_r(24)
+
+    def tst_multi_shift_l(self, width):
+        dut = MultiShiftModL(width=width)
+
+        def process():
+            for i in range(width):
+                for j in range(1000):
+                    a = hash_256(f"MultiShiftModL {i} {j}") % (1 << width)
+                    yield from self.check_case(dut, width, a, i)
+
+        with do_sim(self, dut, [dut.a, dut.b, dut.x]) as sim:
+            sim.add_sync_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_multi_shift_l_32(self):
+        self.tst_multi_shift_l(32)
+
+    def test_multi_shift_l_24(self):
+        self.tst_multi_shift_l(24)
+
+
+if __name__ == '__main__':
+    unittest.main()