implement CLDivRemFSMStage
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
index b6e4cfb7b19d023fccc02e14a1e03dca28655e08..fa93812df04b301cc21e9d7dd0fdb5e988f41555 100644 (file)
@@ -8,10 +8,13 @@ import unittest
 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned
 from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
-from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference,
+from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
+                                    CLDivRemOutputData, CLDivRemShape, CLDivRemState,
+                                    equal_leading_zero_count_reference,
                                     EqualLeadingZeroCount)
-from nmigen.sim import Delay
+from nmigen.sim import Delay, Tick
 from nmutil.sim_util import do_sim, hash_256
+from nmigen_gf.reference.cldivrem import cldivrem
 
 
 class TestEqualLeadingZeroCount(FHDLTestCase):
@@ -100,7 +103,175 @@ class TestEqualLeadingZeroCount(FHDLTestCase):
     def test_formal_3(self):
         self.tst_formal(3)
 
-# TODO: add TestCLDivRem
+
+class TestCLDivRemComb(FHDLTestCase):
+    def tst(self, shape, full):
+        assert isinstance(shape, CLDivRemShape)
+        m = Module()
+        n_in = Signal(shape.n_width)
+        d_in = Signal(shape.width)
+        states: "list[CLDivRemState]" = []
+        for i in shape.step_range:
+            states.append(CLDivRemState(shape, name=f"state_{i}"))
+            if i == 0:
+                states[i].set_to_initial(m, n=n_in, d=d_in)
+            else:
+                states[i].set_to_next(m, states[i - 1])
+
+        def case(n, d):
+            assert isinstance(n, int)
+            assert isinstance(d, int)
+            max_width = max(shape.width, shape.n_width)
+            if d != 0:
+                expected_q, expected_r = cldivrem(n, d, width=max_width)
+            else:
+                expected_q = expected_r = 0
+            with self.subTest(n=hex(n), d=hex(d),
+                              expected_q=hex(expected_q),
+                              expected_r=hex(expected_r)):
+                yield n_in.eq(n)
+                yield d_in.eq(d)
+                yield Delay(1e-6)
+                for i in shape.step_range:
+                    with self.subTest(i=i):
+                        done = yield states[i].done
+                        step = yield states[i].step
+                        self.assertEqual(done, i >= shape.done_step)
+                        self.assertEqual(step, i)
+                q = yield states[-1].q
+                r = yield states[-1].r
+                with self.subTest(q=hex(q), r=hex(r)):
+                    # only check results when inputs are valid
+                    if d != 0 and (expected_q >> shape.width) == 0:
+                        self.assertEqual(q, expected_q)
+                        self.assertEqual(r, expected_r)
+
+        def process():
+            if full:
+                for n in range(1 << shape.n_width):
+                    for d in range(1 << shape.width):
+                        yield from case(n, d)
+            else:
+                for i in range(100):
+                    n = hash_256(f"cldivrem comb n {i}")
+                    n = Const.normalize(n, unsigned(shape.n_width))
+                    d = hash_256(f"cldivrem comb d {i}")
+                    d = Const.normalize(d, unsigned(shape.width))
+                    yield from case(n, d)
+        with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test_4(self):
+        self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+
+    def test_8_by_4(self):
+        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+
+
+class TestCLDivRemFSM(FHDLTestCase):
+    def tst(self, shape, full, steps_per_clock):
+        assert isinstance(shape, CLDivRemShape)
+        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+        pspec = {}
+        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
+        i_data: CLDivRemInputData = dut.p.i_data
+        o_data: CLDivRemOutputData = dut.n.o_data
+        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+        self.assertEqual(i_data.d.shape(), unsigned(shape.width))
+        self.assertEqual(o_data.q.shape(), unsigned(shape.width))
+        self.assertEqual(o_data.r.shape(), unsigned(shape.width))
+
+        def case(n, d):
+            assert isinstance(n, int)
+            assert isinstance(d, int)
+            max_width = max(shape.width, shape.n_width)
+            if d != 0:
+                expected_q, expected_r = cldivrem(n, d, width=max_width)
+            else:
+                expected_q = expected_r = 0
+            with self.subTest(n=hex(n), d=hex(d),
+                              expected_q=hex(expected_q),
+                              expected_r=hex(expected_r)):
+                yield dut.p.i_valid.eq(0)
+                yield Tick()
+                yield i_data.n.eq(n)
+                yield i_data.d.eq(d)
+                yield dut.p.i_valid.eq(1)
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield Tick()
+                yield i_data.n.eq(-1)
+                yield i_data.d.eq(-1)
+                yield dut.p.i_valid.eq(0)
+                for i in range(steps_per_clock * 2, shape.done_step,
+                               steps_per_clock):
+                    yield Delay(0.1e-6)
+                    valid = yield dut.n.o_valid
+                    ready = yield dut.p.o_ready
+                    with self.subTest():
+                        self.assertFalse(valid)
+                        self.assertFalse(ready)
+                    yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertTrue(valid)
+                    self.assertFalse(ready)
+                q = yield o_data.q
+                r = yield o_data.r
+                with self.subTest(q=hex(q), r=hex(r)):
+                    # only check results when inputs are valid
+                    if d != 0 and (expected_q >> shape.width) == 0:
+                        self.assertEqual(q, expected_q)
+                        self.assertEqual(r, expected_r)
+                yield dut.n.i_ready.eq(1)
+                yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield dut.n.i_ready.eq(0)
+
+        def process():
+            if full:
+                for n in range(1 << shape.n_width):
+                    for d in range(1 << shape.width):
+                        yield from case(n, d)
+            else:
+                for i in range(100):
+                    n = hash_256(f"cldivrem fsm n {i}")
+                    n = Const.normalize(n, unsigned(shape.n_width))
+                    d = hash_256(f"cldivrem fsm d {i}")
+                    d = Const.normalize(d, unsigned(shape.width))
+                    yield from case(n, d)
+
+        with do_sim(self, dut, list(dut.ports())) as sim:
+            sim.add_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_4_step_1(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=1)
+
+    def test_4_step_2(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=2)
+
+    def test_4_step_3(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=3)
 
 
 if __name__ == "__main__":