HDL works for io_width=5
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 05:40:32 +0000 (22:40 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 05:40:32 +0000 (22:40 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index ea0ddda0763c6847336589ec38cc8b5426013ba5..a86fa78d111e3957b31e45a645b0ba8ac5e85174 100644 (file)
@@ -12,7 +12,7 @@ import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
@@ -937,7 +937,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             else:
                 break
 
             else:
                 break
 
-        return cached_new(params)
+        retval = cached_new(params)
+        assert isinstance(retval, GoldschmidtDivParams)
+        return retval
 
 
 def clz(v, wid):
 
 
 def clz(v, wid):
@@ -979,6 +981,8 @@ class GoldschmidtDivOp(enum.Enum):
             d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
             assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
             state.f = params.table[d_m_1.bits]
             d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
             assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
             state.f = params.table[d_m_1.bits]
+            state.f = state.f.to_frac_wid(expanded_width,
+                                          round_dir=RoundDir.DOWN)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.f is not None
             n = state.n * state.f
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.f is not None
             n = state.n * state.f
@@ -1003,7 +1007,6 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
     def gen_hdl(self, params, state, sync_rom):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
     def gen_hdl(self, params, state, sync_rom):
-        # FIXME: finish getting hdl/simulation to work
         """generate the hdl for this operation.
 
         arguments:
         """generate the hdl for this operation.
 
         arguments:
@@ -1018,8 +1021,6 @@ class GoldschmidtDivOp(enum.Enum):
         assert isinstance(params, GoldschmidtDivParams)
         assert isinstance(state, GoldschmidtDivHDLState)
         m = state.m
         assert isinstance(params, GoldschmidtDivParams)
         assert isinstance(state, GoldschmidtDivHDLState)
         m = state.m
-        expanded_width = params.expanded_width
-        table_addr_bits = params.table_addr_bits
         if self == GoldschmidtDivOp.Normalize:
             # normalize so 1 <= d < 2
             assert state.d.width == params.io_width
         if self == GoldschmidtDivOp.Normalize:
             # normalize so 1 <= d < 2
             assert state.d.width == params.io_width
@@ -1029,27 +1030,41 @@ class GoldschmidtDivOp(enum.Enum):
             m.d.comb += d_leading_zeros.sig_in.eq(state.d)
             d_shift_out = Signal.like(state.d)
             m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
             m.d.comb += d_leading_zeros.sig_in.eq(state.d)
             d_shift_out = Signal.like(state.d)
             m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
-            state.d = Signal(params.n_d_f_total_wid)
-            m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
-                                                   + params.n_d_f_int_wid))
+            d = Signal(params.n_d_f_total_wid)
+            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+                             >> state.d.width)
 
             # normalize so 1 <= n < 2
             n_leading_zeros = CLZ(2 * params.io_width)
             m.submodules.n_leading_zeros = n_leading_zeros
             m.d.comb += n_leading_zeros.sig_in.eq(state.n)
 
             # normalize so 1 <= n < 2
             n_leading_zeros = CLZ(2 * params.io_width)
             m.submodules.n_leading_zeros = n_leading_zeros
             m.d.comb += n_leading_zeros.sig_in.eq(state.n)
-            n_shift_s_v = (params.io_width + d_leading_zeros.lz
+            signed_zero = Const(0, signed(1))  # force subtraction to be signed
+            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                            - n_leading_zeros.lz)
             n_shift_s = Signal.like(n_shift_s_v)
                            - n_leading_zeros.lz)
             n_shift_s = Signal.like(n_shift_s_v)
-            state.n_shift = Signal(d_leading_zeros.lz.width)
+            n_shift_n_lz_out = Signal.like(state.n)
+            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
             m.d.comb += [
                 n_shift_s.eq(n_shift_s_v),
             m.d.comb += [
                 n_shift_s.eq(n_shift_s_v),
-                state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
             ]
             ]
+            state.n_shift = Signal(d_leading_zeros.lz.width)
             n = Signal(params.n_d_f_total_wid)
             n = Signal(params.n_d_f_total_wid)
-            shifted_n = state.n << state.n_shift
-            fixed_shift = params.expanded_width - state.n.width
-            m.d.comb += n.eq(shifted_n << fixed_shift)
+            with m.If(n_shift_s < 0):
+                m.d.comb += [
+                    state.n_shift.eq(0),
+                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+                         >> state.d.width),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.n_shift.eq(n_shift_s),
+                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+                         >> state.n.width),
+                ]
             state.n = n
             state.n = n
+            state.d = d
         elif self == GoldschmidtDivOp.FEqTableLookup:
             assert state.d.width == params.n_d_f_total_wid, "invalid d width"
             # compute initial f by table lookup
         elif self == GoldschmidtDivOp.FEqTableLookup:
             assert state.d.width == params.n_d_f_total_wid, "invalid d width"
             # compute initial f by table lookup
@@ -1058,7 +1073,7 @@ class GoldschmidtDivOp(enum.Enum):
             table_width = 1 + params.table_data_bits
             table = Memory(width=table_width, depth=len(params.table),
                            init=[i.bits for i in params.table])
             table_width = 1 + params.table_data_bits
             table = Memory(width=table_width, depth=len(params.table),
                            init=[i.bits for i in params.table])
-            addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
             if sync_rom:
                 table_read = table.read_port()
                 m.d.comb += table_read.addr.eq(addr)
             if sync_rom:
                 table_read = table.read_port()
                 m.d.comb += table_read.addr.eq(addr)
@@ -1068,8 +1083,7 @@ class GoldschmidtDivOp(enum.Enum):
                 m.d.comb += table_read.addr.eq(addr)
             m.submodules.table_read = table_read
             state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
                 m.d.comb += table_read.addr.eq(addr)
             m.submodules.table_read = table_read
             state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
-            data_shift = (table_width - params.table_data_bits
-                          + params.expanded_width)
+            data_shift = params.expanded_width - params.table_data_bits
             m.d.comb += state.f.eq(table_read.data << data_shift)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.n.width == params.n_d_f_total_wid, "invalid n width"
             m.d.comb += state.f.eq(table_read.data << data_shift)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.n.width == params.n_d_f_total_wid, "invalid n width"
@@ -1150,8 +1164,20 @@ class GoldschmidtDivState:
     algorithm.
     """
 
     algorithm.
     """
 
+    def __repr__(self):
+        fields_str = []
+        for field in fields(GoldschmidtDivState):
+            value = getattr(self, field.name)
+            if value is None:
+                continue
+            if isinstance(value, int) and field.name != "n_shift":
+                fields_str.append(f"{field.name}={hex(value)}")
+            else:
+                fields_str.append(f"{field.name}={value!r}")
+        return f"GoldschmidtDivState({', '.join(fields_str)})"
+
 
 
-def goldschmidt_div(n, d, params):
+def goldschmidt_div(n, d, params, trace=lambda state: None):
     """ Goldschmidt division algorithm.
 
         based on:
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1168,6 +1194,9 @@ def goldschmidt_div(n, d, params):
             denominator. a `width`-bit unsigned integer. must not be zero.
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
             denominator. a `width`-bit unsigned integer. must not be zero.
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
+        trace: Function[[GoldschmidtDivState], None]
+            called with the initial state and the state after executing each
+            operation in `params.ops`.
 
         returns: tuple[int, int]
             the quotient and remainder. a tuple of two `width`-bit unsigned
 
         returns: tuple[int, int]
             the quotient and remainder. a tuple of two `width`-bit unsigned
@@ -1187,8 +1216,10 @@ def goldschmidt_div(n, d, params):
         d=FixedPoint(d, params.io_width),
     )
 
         d=FixedPoint(d, params.io_width),
     )
 
+    trace(state)
     for op in params.ops:
         op.run(params, state)
     for op in params.ops:
         op.run(params, state)
+        trace(state)
 
     assert state.quotient is not None
     assert state.remainder is not None
 
     assert state.quotient is not None
     assert state.remainder is not None
@@ -1273,7 +1304,6 @@ class GoldschmidtDivHDLState:
 
 
 class GoldschmidtDivHDL(Elaboratable):
 
 
 class GoldschmidtDivHDL(Elaboratable):
-    # FIXME: finish getting hdl/simulation to work
     """ Goldschmidt division algorithm.
 
         based on:
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1303,6 +1333,9 @@ class GoldschmidtDivHDL(Elaboratable):
             output quotient. only valid when `n < (d << params.io_width)`.
         r: Signal(unsigned(params.io_width))
             output remainder. only valid when `n < (d << params.io_width)`.
             output quotient. only valid when `n < (d << params.io_width)`.
         r: Signal(unsigned(params.io_width))
             output remainder. only valid when `n < (d << params.io_width)`.
+        trace: list[GoldschmidtDivHDLState]
+            list of the initial state and the state after executing each
+            operation in `params.ops`.
     """
 
     @property
     """
 
     @property
@@ -1321,15 +1354,16 @@ class GoldschmidtDivHDL(Elaboratable):
         self.q = Signal(unsigned(params.io_width))
         self.r = Signal(unsigned(params.io_width))
 
         self.q = Signal(unsigned(params.io_width))
         self.r = Signal(unsigned(params.io_width))
 
-    def elaborate(self, platform):
-        m = Module()
+        # in constructor so we get trace without needing to call elaborate
         state = GoldschmidtDivHDLState(
         state = GoldschmidtDivHDLState(
-            m=m,
+            m=Module(),
             orig_n=self.n,
             orig_d=self.d,
             n=self.n,
             d=self.d)
 
             orig_n=self.n,
             orig_d=self.d,
             n=self.n,
             d=self.d)
 
+        self.trace = [replace(state)]
+
         # copy and reverse
         pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
 
         # copy and reverse
         pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
 
@@ -1339,14 +1373,19 @@ class GoldschmidtDivHDL(Elaboratable):
                 pipe_reg_indexes.pop()
                 state.insert_pipeline_register()
             op.gen_hdl(self.params, state, self.sync_rom)
                 pipe_reg_indexes.pop()
                 state.insert_pipeline_register()
             op.gen_hdl(self.params, state, self.sync_rom)
+            self.trace.append(replace(state))
 
         while len(pipe_reg_indexes) > 0:
             pipe_reg_indexes.pop()
             state.insert_pipeline_register()
 
 
         while len(pipe_reg_indexes) > 0:
             pipe_reg_indexes.pop()
             state.insert_pipeline_register()
 
-        m.d.comb += self.q.eq(state.quotient)
-        m.d.comb += self.r.eq(state.remainder)
-        return m
+        state.m.d.comb += [
+            self.q.eq(state.quotient),
+            self.r.eq(state.remainder),
+        ]
+
+    def elaborate(self, platform):
+        return self.trace[0].m
 
 
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
 
 
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
index 5b4c89ad037e0963006dd96b6ea6f87fc7e253c5..b4d4fb85cd232b12b8eb28f40226e0679b23855a 100644 (file)
@@ -4,6 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from dataclasses import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
@@ -12,8 +13,9 @@ from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
-    goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
+    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -88,10 +90,8 @@ class TestGoldschmidtDiv(FHDLTestCase):
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
-    @unittest.skip("hdl/simulation currently broken")
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                 sync_rom=False):
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                 sync_rom=False):
-        # FIXME: finish getting hdl/simulation to work
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         m = Module()
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         m = Module()
@@ -103,7 +103,12 @@ class TestGoldschmidtDiv(FHDLTestCase):
 
         def iter_cases():
             if cases is not None:
 
         def iter_cases():
             if cases is not None:
-                yield from cases
+                for n, d in cases:
+                    assert isinstance(d, int) \
+                        and 0 < d < (1 << params.io_width), "invalid case"
+                    assert isinstance(n, int) \
+                        and 0 <= n < (d << params.io_width), "invalid case"
+                    yield (n, d)
                 return
             for d in range(1, 1 << io_width):
                 for n in range(d << io_width):
                 return
             for d in range(1, 1 << io_width):
                 for n in range(d << io_width):
@@ -116,6 +121,45 @@ class TestGoldschmidtDiv(FHDLTestCase):
                 yield dut.d.eq(d)
                 yield Tick()
 
                 yield dut.d.eq(d)
                 yield Tick()
 
+        def check_interals(n, d):
+            # check internals only if dut is completely combinatorial
+            # so we don't have to figure out how to read values in
+            # previous clock cycles
+            if dut.total_pipeline_registers != 0:
+                return
+            ref_trace = []
+
+            def ref_trace_fn(state):
+                assert isinstance(state, GoldschmidtDivState)
+                ref_trace.append((replace(state)))
+            goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
+            self.assertEqual(len(dut.trace), len(ref_trace))
+            for index, state in enumerate(dut.trace):
+                ref_state = ref_trace[index]
+                last_op = None if index == 0 else params.ops[index - 1]
+                with self.subTest(index=index, state=repr(state),
+                                  ref_state=repr(ref_state),
+                                  last_op=str(last_op)):
+                    for field in fields(GoldschmidtDivHDLState):
+                        sig = getattr(state, field.name)
+                        if not isinstance(sig, Signal):
+                            continue
+                        ref_value = getattr(ref_state, field.name)
+                        ref_value_str = repr(ref_value)
+                        if isinstance(ref_value, int):
+                            ref_value_str = hex(ref_value)
+                        value = yield sig
+                        with self.subTest(field_name=field.name,
+                                          sig=repr(sig),
+                                          sig_shape=repr(sig.shape()),
+                                          value=hex(value),
+                                          ref_value=ref_value_str):
+                            if isinstance(ref_value, int):
+                                self.assertEqual(value, ref_value)
+                            else:
+                                assert isinstance(ref_value, FixedPoint)
+                                self.assertEqual(value, ref_value.bits)
+
         def check_outputs():
             yield Tick()
             for _ in range(dut.total_pipeline_registers):
         def check_outputs():
             yield Tick()
             for _ in range(dut.total_pipeline_registers):
@@ -130,6 +174,8 @@ class TestGoldschmidtDiv(FHDLTestCase):
                     r = yield dut.r
                     with self.subTest(q=hex(q), r=hex(r)):
                         self.assertEqual((q, r), (expected_q, expected_r))
                     r = yield dut.r
                     with self.subTest(q=hex(q), r=hex(r)):
                         self.assertEqual((q, r), (expected_q, expected_r))
+                    yield from check_interals(n, d)
+
                 yield Tick()
 
         with self.subTest(params=str(params)):
                 yield Tick()
 
         with self.subTest(params=str(params)):