HDL works for io_width=5
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 05:40:32 +0000 (22:40 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 05:40:32 +0000 (22:40 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index ea0ddda0763c6847336589ec38cc8b5426013ba5..a86fa78d111e3957b31e45a645b0ba8ac5e85174 100644 (file)
@@ -12,7 +12,7 @@ import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
@@ -937,7 +937,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             else:
                 break
 
-        return cached_new(params)
+        retval = cached_new(params)
+        assert isinstance(retval, GoldschmidtDivParams)
+        return retval
 
 
 def clz(v, wid):
@@ -979,6 +981,8 @@ class GoldschmidtDivOp(enum.Enum):
             d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
             assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
             state.f = params.table[d_m_1.bits]
+            state.f = state.f.to_frac_wid(expanded_width,
+                                          round_dir=RoundDir.DOWN)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.f is not None
             n = state.n * state.f
@@ -1003,7 +1007,6 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
     def gen_hdl(self, params, state, sync_rom):
-        # FIXME: finish getting hdl/simulation to work
         """generate the hdl for this operation.
 
         arguments:
@@ -1018,8 +1021,6 @@ class GoldschmidtDivOp(enum.Enum):
         assert isinstance(params, GoldschmidtDivParams)
         assert isinstance(state, GoldschmidtDivHDLState)
         m = state.m
-        expanded_width = params.expanded_width
-        table_addr_bits = params.table_addr_bits
         if self == GoldschmidtDivOp.Normalize:
             # normalize so 1 <= d < 2
             assert state.d.width == params.io_width
@@ -1029,27 +1030,41 @@ class GoldschmidtDivOp(enum.Enum):
             m.d.comb += d_leading_zeros.sig_in.eq(state.d)
             d_shift_out = Signal.like(state.d)
             m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
-            state.d = Signal(params.n_d_f_total_wid)
-            m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
-                                                   + params.n_d_f_int_wid))
+            d = Signal(params.n_d_f_total_wid)
+            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+                             >> state.d.width)
 
             # normalize so 1 <= n < 2
             n_leading_zeros = CLZ(2 * params.io_width)
             m.submodules.n_leading_zeros = n_leading_zeros
             m.d.comb += n_leading_zeros.sig_in.eq(state.n)
-            n_shift_s_v = (params.io_width + d_leading_zeros.lz
+            signed_zero = Const(0, signed(1))  # force subtraction to be signed
+            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                            - n_leading_zeros.lz)
             n_shift_s = Signal.like(n_shift_s_v)
-            state.n_shift = Signal(d_leading_zeros.lz.width)
+            n_shift_n_lz_out = Signal.like(state.n)
+            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
             m.d.comb += [
                 n_shift_s.eq(n_shift_s_v),
-                state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
             ]
+            state.n_shift = Signal(d_leading_zeros.lz.width)
             n = Signal(params.n_d_f_total_wid)
-            shifted_n = state.n << state.n_shift
-            fixed_shift = params.expanded_width - state.n.width
-            m.d.comb += n.eq(shifted_n << fixed_shift)
+            with m.If(n_shift_s < 0):
+                m.d.comb += [
+                    state.n_shift.eq(0),
+                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+                         >> state.d.width),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.n_shift.eq(n_shift_s),
+                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+                         >> state.n.width),
+                ]
             state.n = n
+            state.d = d
         elif self == GoldschmidtDivOp.FEqTableLookup:
             assert state.d.width == params.n_d_f_total_wid, "invalid d width"
             # compute initial f by table lookup
@@ -1058,7 +1073,7 @@ class GoldschmidtDivOp(enum.Enum):
             table_width = 1 + params.table_data_bits
             table = Memory(width=table_width, depth=len(params.table),
                            init=[i.bits for i in params.table])
-            addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
             if sync_rom:
                 table_read = table.read_port()
                 m.d.comb += table_read.addr.eq(addr)
@@ -1068,8 +1083,7 @@ class GoldschmidtDivOp(enum.Enum):
                 m.d.comb += table_read.addr.eq(addr)
             m.submodules.table_read = table_read
             state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
-            data_shift = (table_width - params.table_data_bits
-                          + params.expanded_width)
+            data_shift = params.expanded_width - params.table_data_bits
             m.d.comb += state.f.eq(table_read.data << data_shift)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.n.width == params.n_d_f_total_wid, "invalid n width"
@@ -1150,8 +1164,20 @@ class GoldschmidtDivState:
     algorithm.
     """
 
+    def __repr__(self):
+        fields_str = []
+        for field in fields(GoldschmidtDivState):
+            value = getattr(self, field.name)
+            if value is None:
+                continue
+            if isinstance(value, int) and field.name != "n_shift":
+                fields_str.append(f"{field.name}={hex(value)}")
+            else:
+                fields_str.append(f"{field.name}={value!r}")
+        return f"GoldschmidtDivState({', '.join(fields_str)})"
+
 
-def goldschmidt_div(n, d, params):
+def goldschmidt_div(n, d, params, trace=lambda state: None):
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1168,6 +1194,9 @@ def goldschmidt_div(n, d, params):
             denominator. a `width`-bit unsigned integer. must not be zero.
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
+        trace: Function[[GoldschmidtDivState], None]
+            called with the initial state and the state after executing each
+            operation in `params.ops`.
 
         returns: tuple[int, int]
             the quotient and remainder. a tuple of two `width`-bit unsigned
@@ -1187,8 +1216,10 @@ def goldschmidt_div(n, d, params):
         d=FixedPoint(d, params.io_width),
     )
 
+    trace(state)
     for op in params.ops:
         op.run(params, state)
+        trace(state)
 
     assert state.quotient is not None
     assert state.remainder is not None
@@ -1273,7 +1304,6 @@ class GoldschmidtDivHDLState:
 
 
 class GoldschmidtDivHDL(Elaboratable):
-    # FIXME: finish getting hdl/simulation to work
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1303,6 +1333,9 @@ class GoldschmidtDivHDL(Elaboratable):
             output quotient. only valid when `n < (d << params.io_width)`.
         r: Signal(unsigned(params.io_width))
             output remainder. only valid when `n < (d << params.io_width)`.
+        trace: list[GoldschmidtDivHDLState]
+            list of the initial state and the state after executing each
+            operation in `params.ops`.
     """
 
     @property
@@ -1321,15 +1354,16 @@ class GoldschmidtDivHDL(Elaboratable):
         self.q = Signal(unsigned(params.io_width))
         self.r = Signal(unsigned(params.io_width))
 
-    def elaborate(self, platform):
-        m = Module()
+        # in constructor so we get trace without needing to call elaborate
         state = GoldschmidtDivHDLState(
-            m=m,
+            m=Module(),
             orig_n=self.n,
             orig_d=self.d,
             n=self.n,
             d=self.d)
 
+        self.trace = [replace(state)]
+
         # copy and reverse
         pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
 
@@ -1339,14 +1373,19 @@ class GoldschmidtDivHDL(Elaboratable):
                 pipe_reg_indexes.pop()
                 state.insert_pipeline_register()
             op.gen_hdl(self.params, state, self.sync_rom)
+            self.trace.append(replace(state))
 
         while len(pipe_reg_indexes) > 0:
             pipe_reg_indexes.pop()
             state.insert_pipeline_register()
 
-        m.d.comb += self.q.eq(state.quotient)
-        m.d.comb += self.r.eq(state.remainder)
-        return m
+        state.m.d.comb += [
+            self.q.eq(state.quotient),
+            self.r.eq(state.remainder),
+        ]
+
+    def elaborate(self, platform):
+        return self.trace[0].m
 
 
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
index 5b4c89ad037e0963006dd96b6ea6f87fc7e253c5..b4d4fb85cd232b12b8eb28f40226e0679b23855a 100644 (file)
@@ -4,6 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from dataclasses import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
@@ -12,8 +13,9 @@ from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
-    goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
+    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -88,10 +90,8 @@ class TestGoldschmidtDiv(FHDLTestCase):
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
-    @unittest.skip("hdl/simulation currently broken")
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                 sync_rom=False):
-        # FIXME: finish getting hdl/simulation to work
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         m = Module()
@@ -103,7 +103,12 @@ class TestGoldschmidtDiv(FHDLTestCase):
 
         def iter_cases():
             if cases is not None:
-                yield from cases
+                for n, d in cases:
+                    assert isinstance(d, int) \
+                        and 0 < d < (1 << params.io_width), "invalid case"
+                    assert isinstance(n, int) \
+                        and 0 <= n < (d << params.io_width), "invalid case"
+                    yield (n, d)
                 return
             for d in range(1, 1 << io_width):
                 for n in range(d << io_width):
@@ -116,6 +121,45 @@ class TestGoldschmidtDiv(FHDLTestCase):
                 yield dut.d.eq(d)
                 yield Tick()
 
+        def check_interals(n, d):
+            # check internals only if dut is completely combinatorial
+            # so we don't have to figure out how to read values in
+            # previous clock cycles
+            if dut.total_pipeline_registers != 0:
+                return
+            ref_trace = []
+
+            def ref_trace_fn(state):
+                assert isinstance(state, GoldschmidtDivState)
+                ref_trace.append((replace(state)))
+            goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
+            self.assertEqual(len(dut.trace), len(ref_trace))
+            for index, state in enumerate(dut.trace):
+                ref_state = ref_trace[index]
+                last_op = None if index == 0 else params.ops[index - 1]
+                with self.subTest(index=index, state=repr(state),
+                                  ref_state=repr(ref_state),
+                                  last_op=str(last_op)):
+                    for field in fields(GoldschmidtDivHDLState):
+                        sig = getattr(state, field.name)
+                        if not isinstance(sig, Signal):
+                            continue
+                        ref_value = getattr(ref_state, field.name)
+                        ref_value_str = repr(ref_value)
+                        if isinstance(ref_value, int):
+                            ref_value_str = hex(ref_value)
+                        value = yield sig
+                        with self.subTest(field_name=field.name,
+                                          sig=repr(sig),
+                                          sig_shape=repr(sig.shape()),
+                                          value=hex(value),
+                                          ref_value=ref_value_str):
+                            if isinstance(ref_value, int):
+                                self.assertEqual(value, ref_value)
+                            else:
+                                assert isinstance(ref_value, FixedPoint)
+                                self.assertEqual(value, ref_value.bits)
+
         def check_outputs():
             yield Tick()
             for _ in range(dut.total_pipeline_registers):
@@ -130,6 +174,8 @@ class TestGoldschmidtDiv(FHDLTestCase):
                     r = yield dut.r
                     with self.subTest(q=hex(q), r=hex(r)):
                         self.assertEqual((q, r), (expected_q, expected_r))
+                    yield from check_interals(n, d)
+
                 yield Tick()
 
         with self.subTest(params=str(params)):