split step counter into clock and substep
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
index 438b547fca5cc822f91c6add9f71cadd2ddfe8cb..efcd23e08f8511e51e0e5e34a37ef10201ac5f9e 100644 (file)
@@ -17,39 +17,68 @@ from nmigen_gf.reference.cldivrem import cldivrem
 
 
 class TestCLDivRemShifting(FHDLTestCase):
-    def tst(self, width, full):
+    def tst(self, shape, full):
+        assert isinstance(shape, CLDivRemShape)
+
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=width)
-                q, r = cldivrem_shifting(n, d, width=width)
+                expected_q, expected_r = cldivrem(n, d, width=shape.width)
+                q, r = cldivrem_shifting(n, d, shape)
             else:
                 expected_q = expected_r = 0
                 q = r = 0
-            with self.subTest(n=hex(n), d=hex(d),
-                              expected_q=hex(expected_q),
+            with self.subTest(expected_q=hex(expected_q),
                               expected_r=hex(expected_r),
                               q=hex(q), r=hex(r)):
                 self.assertEqual(expected_q, q)
                 self.assertEqual(expected_r, r)
         if full:
-            for n in range(1 << width):
-                for d in range(1 << width):
-                    case(n, d)
+            for n in range(1 << shape.width):
+                for d in range(1 << shape.width):
+                    with self.subTest(n=hex(n), d=hex(d)):
+                        case(n, d)
         else:
             for i in range(100):
                 n = hash_256(f"cldivrem comb n {i}")
-                n = Const.normalize(n, unsigned(width))
+                n = Const.normalize(n, unsigned(shape.width))
                 d = hash_256(f"cldivrem comb d {i}")
-                d = Const.normalize(d, unsigned(width))
+                d = Const.normalize(d, unsigned(shape.width))
                 case(n, d)
 
-    def test_6(self):
-        self.tst(6, full=True)
+    def test_6_step_1(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=True)
+
+    def test_6_step_2(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=True)
+
+    def test_6_step_3(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=3), full=True)
 
-    def test_64(self):
-        self.tst(64, full=False)
+    def test_6_step_4(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=4), full=True)
+
+    def test_6_step_6(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=True)
+
+    def test_6_step_10(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=10), full=True)
+
+    def test_64_step_1(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=1), full=False)
+
+    def test_64_step_2(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=2), full=False)
+
+    def test_64_step_3(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=3), full=False)
+
+    def test_64_step_4(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
+
+    def test_64_step_8(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
 
 
 class TestCLDivRemComb(FHDLTestCase):
@@ -62,7 +91,7 @@ class TestCLDivRemComb(FHDLTestCase):
         q_out = Signal(width)
         r_out = Signal(width)
         states: "list[CLDivRemState]" = []
-        for i in shape.step_range:
+        for i in range(shape.width + 1):
             states.append(CLDivRemState(shape, name=f"state_{i}"))
             if i == 0:
                 states[i].set_to_initial(m, n=n_in, d=d_in)
@@ -75,7 +104,7 @@ class TestCLDivRemComb(FHDLTestCase):
             assert isinstance(n, int)
             assert isinstance(d, int)
             if d != 0:
-                expected_q, expected_r = cldivrem_shifting(n, d, width)
+                expected_q, expected_r = cldivrem_shifting(n, d, shape)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -84,12 +113,17 @@ class TestCLDivRemComb(FHDLTestCase):
                 yield n_in.eq(n)
                 yield d_in.eq(d)
                 yield Delay(1e-6)
-                for i in shape.step_range:
+                for i, state in enumerate(states):
                     with self.subTest(i=i):
-                        done = yield states[i].done
-                        step = yield states[i].step
-                        self.assertEqual(done, i >= shape.done_step)
-                        self.assertEqual(step, i)
+                        done = yield state.done
+                        substep = yield state.substep
+                        clock = yield state.clock
+                        self.assertEqual(done, i >= shape.width)
+                        if i % shape.steps_per_clock == 0:
+                            self.assertEqual(substep, 0)
+                        if i < shape.width:
+                            self.assertEqual(substep
+                                             + clock * shape.steps_per_clock, i)
                 q = yield q_out
                 r = yield r_out
                 with self.subTest(q=hex(q), r=hex(r)):
@@ -114,22 +148,45 @@ class TestCLDivRemComb(FHDLTestCase):
             sim.add_process(process)
             sim.run()
 
-    def test_4(self):
-        self.tst(CLDivRemShape(width=4), full=True)
+    def test_4_step_1(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
+
+    def test_4_step_2(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
+
+    def test_4_step_3(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
+
+    def test_4_step_4(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
 
-    def test_6(self):
-        self.tst(CLDivRemShape(width=6), full=True)
+    def test_6_step_1(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=False)
 
-    def test_8(self):
-        self.tst(CLDivRemShape(width=8), full=False)
+    def test_6_step_2(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=False)
+
+    def test_6_step_6(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=False)
+
+    def test_6_step_8(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=8), full=False)
+
+    def test_8_step_1(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=1), full=False)
+
+    def test_8_step_4(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
+
+    def test_8_step_8(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=8), full=False)
 
 
 class TestCLDivRemFSM(FHDLTestCase):
-    def tst(self, shape, full, steps_per_clock):
+    def tst(self, shape, full):
         assert isinstance(shape, CLDivRemShape)
-        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
         pspec = {}
-        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
+        dut = CLDivRemFSMStage(pspec, shape)
         i_data: CLDivRemInputData = dut.p.i_data
         o_data: CLDivRemOutputData = dut.n.o_data
         self.assertEqual(i_data.n.shape(), unsigned(shape.width))
@@ -162,7 +219,7 @@ class TestCLDivRemFSM(FHDLTestCase):
                 yield i_data.n.eq(-1)
                 yield i_data.d.eq(-1)
                 yield dut.p.i_valid.eq(0)
-                for step in range(0, shape.done_step, steps_per_clock):
+                for step in range(shape.done_clock):
                     yield Delay(0.1e-6)
                     valid = yield dut.n.o_valid
                     ready = yield dut.p.o_ready
@@ -212,39 +269,25 @@ class TestCLDivRemFSM(FHDLTestCase):
             sim.run()
 
     def test_4_step_1(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=1)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
 
     def test_4_step_2(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=2)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
 
     def test_4_step_3(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=3)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
 
     def test_4_step_4(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
 
     def test_8_step_4(self):
-        self.tst(CLDivRemShape(width=8),
-                 full=False,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
 
     def test_64_step_4(self):
-        self.tst(CLDivRemShape(width=64),
-                 full=False,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
 
     def test_64_step_8(self):
-        self.tst(CLDivRemShape(width=64),
-                 full=False,
-                 steps_per_clock=8)
+        self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
 
 
 if __name__ == "__main__":