add WIP HDL version of goldschmidt division -- it's currently broken
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 28 Apr 2022 09:19:01 +0000 (02:19 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 28 Apr 2022 09:24:42 +0000 (02:24 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index c3837e9afc2f661ffc864d088d3ead1a3b0cd2ab..3801b200abe5f04e5968e6565c5f307465650967 100644 (file)
@@ -4,6 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from collections import defaultdict
 from dataclasses import dataclass, field, fields, replace
 import logging
 import math
@@ -11,6 +12,10 @@ import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
 
 try:
     from functools import cached_property
@@ -66,6 +71,7 @@ class FixedPoint:
     frac_wid: int
 
     def __post_init__(self):
+        # called by the autogenerated __init__
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
@@ -463,6 +469,20 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         """the total number of bits of precision used inside the algorithm."""
         return self.io_width + self.extra_precision
 
+    @property
+    def n_d_f_int_wid(self):
+        """the number of bits in the integer part of `state.n`, `state.d`, and
+        `state.f` during the main iteration loop.
+        """
+        return 2
+
+    @property
+    def n_d_f_total_wid(self):
+        """the total number of bits (both integer and fraction bits) in
+        `state.n`, `state.d`, and `state.f` during the main iteration loop.
+        """
+        return self.n_d_f_int_wid + self.expanded_width
+
     @cache_on_self
     def max_neps(self, i):
         """maximum value of `neps[i]`.
@@ -920,6 +940,12 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         return cached_new(params)
 
 
+def clz(v, wid):
+    assert isinstance(wid, int)
+    assert isinstance(v, int) and 0 <= v < (1 << wid)
+    return (1 << wid).bit_length() - v.bit_length()
+
+
 @enum.unique
 class GoldschmidtDivOp(enum.Enum):
     Normalize = "n, d, n_shift = normalize(n, d)"
@@ -975,6 +1001,125 @@ class GoldschmidtDivOp(enum.Enum):
         else:
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
+    def gen_hdl(self, params, state, sync_rom):
+        # FIXME: finish getting hdl/simulation to work
+        """generate the hdl for this operation.
+
+        arguments:
+        params: GoldschmidtDivParams
+            the goldschmidt division parameters.
+        state: GoldschmidtDivHDLState
+            the input/output state
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        """
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(state, GoldschmidtDivHDLState)
+        m = state.m
+        expanded_width = params.expanded_width
+        table_addr_bits = params.table_addr_bits
+        if self == GoldschmidtDivOp.Normalize:
+            # normalize so 1 <= d < 2
+            assert state.d.width == params.io_width
+            assert state.n.width == 2 * params.io_width
+            d_leading_zeros = CLZ(params.io_width)
+            m.submodules.d_leading_zeros = d_leading_zeros
+            m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+            d_shift_out = Signal.like(state.d)
+            m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+            state.d = Signal(params.n_d_f_total_wid)
+            m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
+                                                   + params.n_d_f_int_wid))
+
+            # normalize so 1 <= n < 2
+            n_leading_zeros = CLZ(2 * params.io_width)
+            m.submodules.n_leading_zeros = n_leading_zeros
+            m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+            n_shift_s_v = (params.io_width + d_leading_zeros.lz
+                           - n_leading_zeros.lz)
+            n_shift_s = Signal.like(n_shift_s_v)
+            state.n_shift = Signal(d_leading_zeros.lz.width)
+            m.d.comb += [
+                n_shift_s.eq(n_shift_s_v),
+                state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+            ]
+            n = Signal(params.n_d_f_total_wid)
+            shifted_n = state.n << state.n_shift
+            fixed_shift = params.expanded_width - state.n.width
+            m.d.comb += n.eq(shifted_n << fixed_shift)
+            state.n = n
+        elif self == GoldschmidtDivOp.FEqTableLookup:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            # compute initial f by table lookup
+
+            # extra bit for table entries == 1.0
+            table_width = 1 + params.table_data_bits
+            table = Memory(width=table_width, depth=len(params.table),
+                           init=[i.bits for i in params.table])
+            addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+            if sync_rom:
+                table_read = table.read_port()
+                m.d.comb += table_read.addr.eq(addr)
+                state.insert_pipeline_register()
+            else:
+                table_read = table.read_port(domain="comb")
+                m.d.comb += table_read.addr.eq(addr)
+            m.submodules.table_read = table_read
+            state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+            data_shift = (table_width - params.table_data_bits
+                          + params.expanded_width)
+            m.d.comb += state.f.eq(table_read.data << data_shift)
+        elif self == GoldschmidtDivOp.MulNByF:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            n = Signal.like(state.n)
+            m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+            state.n = n
+        elif self == GoldschmidtDivOp.MulDByF:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            d = Signal.like(state.d)
+            m.d.comb += d.eq((state.d * state.f) >> params.expanded_width)
+            state.d = d
+        elif self == GoldschmidtDivOp.FEq2MinusD:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            f = Signal.like(state.d)
+            m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+            state.f = f
+        elif self == GoldschmidtDivOp.CalcResult:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.n_shift is not None
+            # scale to correct value
+            n = state.n * (1 << state.n_shift)
+            q_approx = Signal(params.io_width)
+            # extra bit for if it's bigger than orig_d
+            r_approx = Signal(params.io_width + 1)
+            adjusted_r = Signal(signed(1 + params.io_width))
+            m.d.comb += [
+                q_approx.eq((state.n << state.n_shift)
+                            >> params.expanded_width),
+                r_approx.eq(state.orig_n - q_approx * state.orig_d),
+                adjusted_r.eq(r_approx - state.orig_d),
+            ]
+            state.quotient = Signal(params.io_width)
+            state.remainder = Signal(params.io_width)
+
+            with m.If(adjusted_r >= 0):
+                m.d.comb += [
+                    state.quotient.eq(q_approx + 1),
+                    state.remainder.eq(adjusted_r),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.quotient.eq(q_approx),
+                    state.remainder.eq(r_approx),
+                ]
+        else:
+            assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
 
 @dataclass
 class GoldschmidtDivState:
@@ -1050,6 +1195,159 @@ def goldschmidt_div(n, d, params):
     return state.quotient, state.remainder
 
 
+@dataclass(eq=False)
+class GoldschmidtDivHDLState:
+    m: Module
+    """The HDL Module"""
+
+    orig_n: Signal
+    """original numerator"""
+
+    orig_d: Signal
+    """original denominator"""
+
+    n: Signal
+    """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+    d: Signal
+    """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+    f: "Signal | None" = None
+    """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+    quotient: "Signal | None" = None
+    """final quotient"""
+
+    remainder: "Signal | None" = None
+    """final remainder"""
+
+    n_shift: "Signal | None" = None
+    """amount the numerator needs to be left-shifted at the end of the
+    algorithm.
+    """
+
+    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
+                                                          init=False)
+
+    __signal_name_prefix: "str" = field(default="state_", repr=False,
+                                        init=False)
+
+    def __post_init__(self):
+        # called by the autogenerated __init__
+        self.old_signals = defaultdict(list)
+
+    def __setattr__(self, name, value):
+        assert isinstance(name, str)
+        if name.startswith("_"):
+            return super().__setattr__(name, value)
+        try:
+            old_signals = self.old_signals[name]
+        except AttributeError:
+            # haven't yet finished __post_init__
+            return super().__setattr__(name, value)
+        assert name != "m" and name != "old_signals", f"can't write to {name}"
+        assert isinstance(value, Signal)
+        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+        old_signal = getattr(self, name, None)
+        if old_signal is not None:
+            assert isinstance(old_signal, Signal)
+            old_signals.append(old_signal)
+        return super().__setattr__(name, value)
+
+    def insert_pipeline_register(self):
+        old_prefix = self.__signal_name_prefix
+        try:
+            for field in fields(GoldschmidtDivHDLState):
+                if field.name.startswith("_") or field.name == "m":
+                    continue
+                old_sig = getattr(self, field.name, None)
+                if old_sig is None:
+                    continue
+                assert isinstance(old_sig, Signal)
+                new_sig = Signal.like(old_sig)
+                setattr(self, field.name, new_sig)
+                self.m.d.sync += new_sig.eq(old_sig)
+        finally:
+            self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+    # FIXME: finish getting hdl/simulation to work
+    """ Goldschmidt division algorithm.
+
+        based on:
+        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+        attributes:
+        params: GoldschmidtDivParams
+            the goldschmidt division algorithm parameters.
+        pipe_reg_indexes: list[int]
+            the operation indexes where pipeline registers should be inserted.
+            duplicate values mean multiple registers should be inserted for
+            that operation index -- this is useful to allow yosys to spread a
+            multiplication across those multiple pipeline stages.
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        n: Signal(unsigned(2 * params.io_width))
+            input numerator. a `2 * params.io_width`-bit unsigned integer.
+            must be less than `d << params.io_width`, otherwise the quotient
+            wouldn't fit in `params.io_width` bits.
+        d: Signal(unsigned(params.io_width))
+            input denominator. a `params.io_width`-bit unsigned integer.
+            must not be zero.
+        q: Signal(unsigned(params.io_width))
+            output quotient. only valid when `n < (d << params.io_width)`.
+        r: Signal(unsigned(params.io_width))
+            output remainder. only valid when `n < (d << params.io_width)`.
+    """
+
+    @property
+    def total_pipeline_registers(self):
+        """the total number of pipeline registers"""
+        return len(self.pipe_reg_indexes) + self.sync_rom
+
+    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(sync_rom, bool)
+        self.params = params
+        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+        self.sync_rom = sync_rom
+        self.n = Signal(unsigned(2 * params.io_width))
+        self.d = Signal(unsigned(params.io_width))
+        self.q = Signal(unsigned(params.io_width))
+        self.r = Signal(unsigned(params.io_width))
+
+    def elaborate(self, platform):
+        m = Module()
+        state = GoldschmidtDivHDLState(
+            m=m,
+            orig_n=self.n,
+            orig_d=self.d,
+            n=self.n,
+            d=self.d)
+
+        # copy and reverse
+        pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+        for op_index, op in enumerate(self.params.ops):
+            while len(pipe_reg_indexes) > 0 \
+                    and pipe_reg_indexes[-1] <= op_index:
+                pipe_reg_indexes.pop()
+                state.insert_pipeline_register()
+            op.gen_hdl(self.params, state, self.sync_rom)
+
+        while len(pipe_reg_indexes) > 0:
+            pipe_reg_indexes.pop()
+            state.insert_pipeline_register()
+
+        m.d.comb += self.q.eq(state.quotient)
+        m.d.comb += self.r.eq(state.remainder)
+        return m
+
+
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
 
 
index e2984dc16db684cc31ea20973d235ba72e9aa01d..5b4c89ad037e0963006dd96b6ea6f87fc7e253c5 100644 (file)
@@ -7,9 +7,13 @@
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import do_sim
+from nmigen.sim import Tick, Delay
+from nmigen.hdl.ast import Signal
+from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
-    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+    GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
+    goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -84,6 +88,57 @@ class TestGoldschmidtDiv(FHDLTestCase):
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
+    @unittest.skip("hdl/simulation currently broken")
+    def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
+                sync_rom=False):
+        # FIXME: finish getting hdl/simulation to work
+        assert isinstance(io_width, int)
+        params = GoldschmidtDivParams.get(io_width)
+        m = Module()
+        dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
+                                sync_rom=sync_rom)
+        m.submodules.dut = dut
+        # make sync domain get added
+        m.d.sync += Signal().eq(0)
+
+        def iter_cases():
+            if cases is not None:
+                yield from cases
+                return
+            for d in range(1, 1 << io_width):
+                for n in range(d << io_width):
+                    yield (n, d)
+
+        def inputs_proc():
+            yield Tick()
+            for n, d in iter_cases():
+                yield dut.n.eq(n)
+                yield dut.d.eq(d)
+                yield Tick()
+
+        def check_outputs():
+            yield Tick()
+            for _ in range(dut.total_pipeline_registers):
+                yield Tick()
+            for n, d in iter_cases():
+                yield Delay(0.1e-6)
+                expected_q, expected_r = divmod(n, d)
+                with self.subTest(n=hex(n), d=hex(d),
+                                  expected_q=hex(expected_q),
+                                  expected_r=hex(expected_r)):
+                    q = yield dut.q
+                    r = yield dut.r
+                    with self.subTest(q=hex(q), r=hex(r)):
+                        self.assertEqual((q, r), (expected_q, expected_r))
+                yield Tick()
+
+        with self.subTest(params=str(params)):
+            with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
+                sim.add_clock(1e-6)
+                sim.add_process(inputs_proc)
+                sim.add_process(check_outputs)
+                sim.run()
+
     def test_1_through_4(self):
         for io_width in range(1, 4 + 1):
             with self.subTest(io_width=io_width):
@@ -95,6 +150,9 @@ class TestGoldschmidtDiv(FHDLTestCase):
     def test_6(self):
         self.tst(6)
 
+    def test_sim_5(self):
+        self.tst_sim(5)
+
     def tst_params(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)