fix so HDL works for 5, 8, 16, 32, and 64-bits.
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
index 5b4c89ad037e0963006dd96b6ea6f87fc7e253c5..66345fe20c77f9505b3dfade84d45bff15d40d0e 100644 (file)
@@ -4,16 +4,18 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from dataclasses import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
-from nmutil.sim_util import do_sim
+from nmutil.sim_util import do_sim, hash_256
 from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
-    goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
+    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -74,24 +76,45 @@ class TestGoldschmidtDiv(FHDLTestCase):
                                  table_addr_bits=1, table_data_bits=5,
                                  iter_count=1)
 
-    def tst(self, io_width):
+    @staticmethod
+    def cases(io_width, cases=None):
+        assert isinstance(io_width, int) and io_width >= 1
+        if cases is not None:
+            for n, d in cases:
+                assert isinstance(d, int) \
+                    and 0 < d < (1 << io_width), "invalid case"
+                assert isinstance(n, int) \
+                    and 0 <= n < (d << io_width), "invalid case"
+                yield (n, d)
+        elif io_width > 6:
+            assert io_width * 2 <= 256, \
+                "can't generate big enough numbers for test cases"
+            for i in range(10000):
+                d = hash_256(f'd {i}') % (1 << io_width)
+                if d == 0:
+                    d = 1
+                n = hash_256(f'n {i}') % (d << io_width)
+                yield (n, d)
+        else:
+            for d in range(1, 1 << io_width):
+                for n in range(d << io_width):
+                    yield (n, d)
+
+    def tst(self, io_width, cases=None):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         with self.subTest(params=str(params)):
-            for d in range(1, 1 << io_width):
-                for n in range(d << io_width):
-                    expected_q, expected_r = divmod(n, d)
-                    with self.subTest(n=hex(n), d=hex(d),
-                                      expected_q=hex(expected_q),
-                                      expected_r=hex(expected_r)):
-                        q, r = goldschmidt_div(n, d, params)
-                        with self.subTest(q=hex(q), r=hex(r)):
-                            self.assertEqual((q, r), (expected_q, expected_r))
-
-    @unittest.skip("hdl/simulation currently broken")
+            for n, d in self.cases(io_width, cases):
+                expected_q, expected_r = divmod(n, d)
+                with self.subTest(n=hex(n), d=hex(d),
+                                  expected_q=hex(expected_q),
+                                  expected_r=hex(expected_r)):
+                    q, r = goldschmidt_div(n, d, params)
+                    with self.subTest(q=hex(q), r=hex(r)):
+                        self.assertEqual((q, r), (expected_q, expected_r))
+
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                 sync_rom=False):
-        # FIXME: finish getting hdl/simulation to work
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         m = Module()
@@ -101,26 +124,57 @@ class TestGoldschmidtDiv(FHDLTestCase):
         # make sync domain get added
         m.d.sync += Signal().eq(0)
 
-        def iter_cases():
-            if cases is not None:
-                yield from cases
-                return
-            for d in range(1, 1 << io_width):
-                for n in range(d << io_width):
-                    yield (n, d)
-
         def inputs_proc():
             yield Tick()
-            for n, d in iter_cases():
+            for n, d in self.cases(io_width, cases):
                 yield dut.n.eq(n)
                 yield dut.d.eq(d)
                 yield Tick()
 
+        def check_interals(n, d):
+            # check internals only if dut is completely combinatorial
+            # so we don't have to figure out how to read values in
+            # previous clock cycles
+            if dut.total_pipeline_registers != 0:
+                return
+            ref_trace = []
+
+            def ref_trace_fn(state):
+                assert isinstance(state, GoldschmidtDivState)
+                ref_trace.append((replace(state)))
+            goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
+            self.assertEqual(len(dut.trace), len(ref_trace))
+            for index, state in enumerate(dut.trace):
+                ref_state = ref_trace[index]
+                last_op = None if index == 0 else params.ops[index - 1]
+                with self.subTest(index=index, state=repr(state),
+                                  ref_state=repr(ref_state),
+                                  last_op=str(last_op)):
+                    for field in fields(GoldschmidtDivHDLState):
+                        sig = getattr(state, field.name)
+                        if not isinstance(sig, Signal):
+                            continue
+                        ref_value = getattr(ref_state, field.name)
+                        ref_value_str = repr(ref_value)
+                        if isinstance(ref_value, int):
+                            ref_value_str = hex(ref_value)
+                        value = yield sig
+                        with self.subTest(field_name=field.name,
+                                          sig=repr(sig),
+                                          sig_shape=repr(sig.shape()),
+                                          value=hex(value),
+                                          ref_value=ref_value_str):
+                            if isinstance(ref_value, int):
+                                self.assertEqual(value, ref_value)
+                            else:
+                                assert isinstance(ref_value, FixedPoint)
+                                self.assertEqual(value, ref_value.bits)
+
         def check_outputs():
             yield Tick()
             for _ in range(dut.total_pipeline_registers):
                 yield Tick()
-            for n, d in iter_cases():
+            for n, d in self.cases(io_width, cases):
                 yield Delay(0.1e-6)
                 expected_q, expected_r = divmod(n, d)
                 with self.subTest(n=hex(n), d=hex(d),
@@ -130,6 +184,8 @@ class TestGoldschmidtDiv(FHDLTestCase):
                     r = yield dut.r
                     with self.subTest(q=hex(q), r=hex(r)):
                         self.assertEqual((q, r), (expected_q, expected_r))
+                    yield from check_interals(n, d)
+
                 yield Tick()
 
         with self.subTest(params=str(params)):
@@ -150,9 +206,33 @@ class TestGoldschmidtDiv(FHDLTestCase):
     def test_6(self):
         self.tst(6)
 
+    def test_8(self):
+        self.tst(8)
+
+    def test_16(self):
+        self.tst(16)
+
+    def test_32(self):
+        self.tst(32)
+
+    def test_64(self):
+        self.tst(64)
+
     def test_sim_5(self):
         self.tst_sim(5)
 
+    def test_sim_8(self):
+        self.tst_sim(8)
+
+    def test_sim_16(self):
+        self.tst_sim(16)
+
+    def test_sim_32(self):
+        self.tst_sim(32)
+
+    def test_sim_64(self):
+        self.tst_sim(64)
+
     def tst_params(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)