change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
index 055ff7c137ba4aaefe626fe597108843777f991d..3f7c2480742d6913859461da120099385f99d18a 100644 (file)
@@ -4,13 +4,18 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
-from dataclasses import dataclass, field, fields, replace
+from collections import defaultdict
 import logging
 import math
 import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, signed, Const
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
+from nmutil.plain_data import plain_data, fields, replace
 
 try:
     from functools import cached_property
@@ -60,12 +65,13 @@ class RoundDir(enum.Enum):
     ERROR_IF_INEXACT = enum.auto()
 
 
-@dataclass(frozen=True)
+@plain_data(frozen=True, eq=False, repr=False)
 class FixedPoint:
-    bits: int
-    frac_wid: int
+    __slots__ = "bits", "frac_wid"
 
-    def __post_init__(self):
+    def __init__(self, bits, frac_wid):
+        self.bits = bits
+        self.frac_wid = frac_wid
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
@@ -315,35 +321,6 @@ class FixedPoint:
         return retval
 
 
-@dataclass
-class GoldschmidtDivState:
-    orig_n: int
-    """original numerator"""
-
-    orig_d: int
-    """original denominator"""
-
-    n: FixedPoint
-    """numerator -- N_prime[i] in the paper's algorithm 2"""
-
-    d: FixedPoint
-    """denominator -- D_prime[i] in the paper's algorithm 2"""
-
-    f: "FixedPoint | None" = None
-    """current factor -- F_prime[i] in the paper's algorithm 2"""
-
-    quotient: "int | None" = None
-    """final quotient"""
-
-    remainder: "int | None" = None
-    """final remainder"""
-
-    n_shift: "int | None" = None
-    """amount the numerator needs to be left-shifted at the end of the
-    algorithm.
-    """
-
-
 class ParamsNotAccurateEnough(Exception):
     """raised when the parameters aren't accurate enough to have goldschmidt
     division work."""
@@ -355,43 +332,47 @@ def _assert_accuracy(condition, msg="not accurate enough"):
     raise ParamsNotAccurateEnough(msg)
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParamsBase:
     """parameters for a Goldschmidt division algorithm, excluding derived
     parameters.
     """
 
-    io_width: int
-    """bit-width of the input divisor and the result.
-    the input numerator is `2 * io_width`-bits wide.
-    """
+    __slots__ = ("io_width", "extra_precision", "table_addr_bits",
+                 "table_data_bits", "iter_count")
+
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        assert isinstance(io_width, int)
+        assert isinstance(extra_precision, int)
+        assert isinstance(table_addr_bits, int)
+        assert isinstance(table_data_bits, int)
+        assert isinstance(iter_count, int)
+        self.io_width = io_width
+        """bit-width of the input divisor and the result.
+        the input numerator is `2 * io_width`-bits wide.
+        """
 
-    extra_precision: int
-    """number of bits of additional precision used inside the algorithm."""
+        self.extra_precision = extra_precision
+        """number of bits of additional precision used inside the algorithm."""
 
-    table_addr_bits: int
-    """the number of address bits used in the lookup-table."""
+        self.table_addr_bits = table_addr_bits
+        """the number of address bits used in the lookup-table."""
 
-    table_data_bits: int
-    """the number of data bits used in the lookup-table."""
+        self.table_data_bits = table_data_bits
+        """the number of data bits used in the lookup-table."""
 
-    iter_count: int
-    """the total number of iterations of the division algorithm's loop"""
+        self.iter_count = iter_count
+        """the total number of iterations of the division algorithm's loop"""
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     """parameters for a Goldschmidt division algorithm.
     Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
     """
 
-    # tuple to be immutable, repr=False so repr() works for debugging even when
-    # __post_init__ hasn't finished running yet
-    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
-    """the lookup-table"""
-
-    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
-    """the operations needed to perform the goldschmidt division algorithm."""
+    __slots__ = "table", "ops"
 
     def _shrink_bound(self, bound, round_dir):
         """prevent fractions from having huge numerators/denominators by
@@ -468,8 +449,13 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         # we round down
         return min_value
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        super().__init__(io_width=io_width,
+                         extra_precision=extra_precision,
+                         table_addr_bits=table_addr_bits,
+                         table_data_bits=table_data_bits,
+                         iter_count=iter_count)
         _assert_accuracy(self.io_width >= 1, "io_width out of range")
         _assert_accuracy(self.extra_precision >= 0,
                          "extra_precision out of range")
@@ -483,15 +469,34 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                   self.table_data_bits,
                                                   RoundDir.DOWN))
-        # we have to use object.__setattr__ since frozen=True
-        object.__setattr__(self, "table", tuple(table))
-        object.__setattr__(self, "ops", tuple(self.__make_ops()))
+
+        self.table = tuple(table)
+        """ the lookup-table.
+        type: tuple[FixedPoint, ...]
+        """
+
+        self.ops = tuple(self.__make_ops())
+        "the operations needed to perform the goldschmidt division algorithm."
 
     @property
     def expanded_width(self):
         """the total number of bits of precision used inside the algorithm."""
         return self.io_width + self.extra_precision
 
+    @property
+    def n_d_f_int_wid(self):
+        """the number of bits in the integer part of `state.n`, `state.d`, and
+        `state.f` during the main iteration loop.
+        """
+        return 2
+
+    @property
+    def n_d_f_total_wid(self):
+        """the total number of bits (both integer and fraction bits) in
+        `state.n`, `state.d`, and `state.f` during the main iteration loop.
+        """
+        return self.n_d_f_int_wid + self.expanded_width
+
     @cache_on_self
     def max_neps(self, i):
         """maximum value of `neps[i]`.
@@ -703,14 +708,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     def max_n_shift(self):
         """ maximum value of `state.n_shift`.
         """
-        # input numerator is `2*io_width`-bits
-        max_n = (1 << (self.io_width * 2)) - 1
-        max_n_shift = 0
-        # normalize so 1 <= n < 2
-        while max_n >= 2:
-            max_n >>= 1
-            max_n_shift += 1
-        return max_n_shift
+        # numerator must be less than `denominator << self.io_width`, so
+        # `n_shift` is at most `self.io_width`
+        return self.io_width
 
     @cached_property
     def n_hat(self):
@@ -778,13 +778,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         max_rel_error = (2 * i) * self.n_hat + power
 
         min_a_over_b = Fraction(1, 2)
-        max_a_over_b = Fraction(2)
-        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
-        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+        min_abs_error_for_correctness = min_a_over_b / (1 << self.max_n_shift)
+        min_rel_error_for_correctness = (min_abs_error_for_correctness
+                                         / min_a_over_b)
 
-        _assert_accuracy(max_rel_error < max_allowed_rel_error,
-                         f"not accurate enough: max_rel_error={max_rel_error}"
-                         f" max_allowed_rel_error={max_allowed_rel_error}")
+        _assert_accuracy(
+            max_rel_error < min_rel_error_for_correctness,
+            f"not accurate enough: max_rel_error={max_rel_error}"
+            f" min_rel_error_for_correctness={min_rel_error_for_correctness}")
 
         yield GoldschmidtDivOp.CalcResult
 
@@ -813,11 +814,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     @lru_cache(maxsize=1 << 16)
     def __cached_new(base_params):
         assert isinstance(base_params, GoldschmidtDivParamsBase)
-        # can't use dataclasses.asdict, since it's recursive and will also give
-        # child class fields too, which we don't want.
         kwargs = {}
         for field in fields(GoldschmidtDivParamsBase):
-            kwargs[field.name] = getattr(base_params, field.name)
+            kwargs[field] = getattr(base_params, field)
         try:
             return GoldschmidtDivParams(**kwargs), None
         except ParamsNotAccurateEnough as e:
@@ -946,7 +945,16 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             else:
                 break
 
-        return cached_new(params)
+        retval = cached_new(params)
+        assert isinstance(retval, GoldschmidtDivParams)
+        return retval
+
+
+def clz(v, wid):
+    """count leading zeros -- handy for debugging."""
+    assert isinstance(wid, int)
+    assert isinstance(v, int) and 0 <= v < (1 << wid)
+    return (1 << wid).bit_length() - v.bit_length()
 
 
 @enum.unique
@@ -973,7 +981,8 @@ class GoldschmidtDivOp(enum.Enum):
             state.n_shift = 0
             # normalize so 1 <= n < 2
             while state.n >= 2:
-                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
+                state.n = (state.n * 0.5).to_frac_wid(expanded_width,
+                                                      round_dir=RoundDir.DOWN)
                 state.n_shift += 1
         elif self == GoldschmidtDivOp.FEqTableLookup:
             # compute initial f by table lookup
@@ -981,6 +990,8 @@ class GoldschmidtDivOp(enum.Enum):
             d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
             assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
             state.f = params.table[d_m_1.bits]
+            state.f = state.f.to_frac_wid(expanded_width,
+                                          round_dir=RoundDir.DOWN)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.f is not None
             n = state.n * state.f
@@ -1004,8 +1015,197 @@ class GoldschmidtDivOp(enum.Enum):
         else:
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
+    def gen_hdl(self, params, state, sync_rom):
+        """generate the hdl for this operation.
 
-def goldschmidt_div(n, d, params):
+        arguments:
+        params: GoldschmidtDivParams
+            the goldschmidt division parameters.
+        state: GoldschmidtDivHDLState
+            the input/output state
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        """
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(state, GoldschmidtDivHDLState)
+        m = state.m
+        if self == GoldschmidtDivOp.Normalize:
+            # normalize so 1 <= d < 2
+            assert state.d.width == params.io_width
+            assert state.n.width == 2 * params.io_width
+            d_leading_zeros = CLZ(params.io_width)
+            m.submodules.d_leading_zeros = d_leading_zeros
+            m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+            d_shift_out = Signal.like(state.d)
+            m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+            d = Signal(params.n_d_f_total_wid)
+            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+                             >> state.d.width)
+
+            # normalize so 1 <= n < 2
+            n_leading_zeros = CLZ(2 * params.io_width)
+            m.submodules.n_leading_zeros = n_leading_zeros
+            m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+            signed_zero = Const(0, signed(1))  # force subtraction to be signed
+            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
+                           - n_leading_zeros.lz)
+            n_shift_s = Signal.like(n_shift_s_v)
+            n_shift_n_lz_out = Signal.like(state.n)
+            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
+            m.d.comb += [
+                n_shift_s.eq(n_shift_s_v),
+                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
+            ]
+            state.n_shift = Signal(d_leading_zeros.lz.width)
+            n = Signal(params.n_d_f_total_wid)
+            with m.If(n_shift_s < 0):
+                m.d.comb += [
+                    state.n_shift.eq(0),
+                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+                         >> state.d.width),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.n_shift.eq(n_shift_s),
+                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+                         >> state.n.width),
+                ]
+            state.n = n
+            state.d = d
+        elif self == GoldschmidtDivOp.FEqTableLookup:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            # compute initial f by table lookup
+
+            # extra bit for table entries == 1.0
+            table_width = 1 + params.table_data_bits
+            table = Memory(width=table_width, depth=len(params.table),
+                           init=[i.bits for i in params.table])
+            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
+            if sync_rom:
+                table_read = table.read_port()
+                m.d.comb += table_read.addr.eq(addr)
+                state.insert_pipeline_register()
+            else:
+                table_read = table.read_port(domain="comb")
+                m.d.comb += table_read.addr.eq(addr)
+            m.submodules.table_read = table_read
+            state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+            data_shift = params.expanded_width - params.table_data_bits
+            m.d.comb += state.f.eq(table_read.data << data_shift)
+        elif self == GoldschmidtDivOp.MulNByF:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            n = Signal.like(state.n)
+            m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+            state.n = n
+        elif self == GoldschmidtDivOp.MulDByF:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            assert state.f is not None
+            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+            d = Signal.like(state.d)
+            d_times_f = Signal.like(state.d * state.f)
+            m.d.comb += [
+                d_times_f.eq(state.d * state.f),
+                # round the multiplication up
+                d.eq((d_times_f >> params.expanded_width)
+                     + (d_times_f[:params.expanded_width] != 0)),
+            ]
+            state.d = d
+        elif self == GoldschmidtDivOp.FEq2MinusD:
+            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+            f = Signal.like(state.d)
+            m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+            state.f = f
+        elif self == GoldschmidtDivOp.CalcResult:
+            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+            assert state.n_shift is not None
+            # scale to correct value
+            n = state.n * (1 << state.n_shift)
+            q_approx = Signal(params.io_width)
+            # extra bit for if it's bigger than orig_d
+            r_approx = Signal(params.io_width + 1)
+            adjusted_r = Signal(signed(1 + params.io_width))
+            m.d.comb += [
+                q_approx.eq((state.n << state.n_shift)
+                            >> params.expanded_width),
+                r_approx.eq(state.orig_n - q_approx * state.orig_d),
+                adjusted_r.eq(r_approx - state.orig_d),
+            ]
+            state.quotient = Signal(params.io_width)
+            state.remainder = Signal(params.io_width)
+
+            with m.If(adjusted_r >= 0):
+                m.d.comb += [
+                    state.quotient.eq(q_approx + 1),
+                    state.remainder.eq(adjusted_r),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.quotient.eq(q_approx),
+                    state.remainder.eq(r_approx),
+                ]
+        else:
+            assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
+
+@plain_data(repr=False)
+class GoldschmidtDivState:
+    __slots__ = ("orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
+
+    def __init__(self, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(orig_n, int)
+        assert isinstance(orig_d, int)
+        assert isinstance(n, FixedPoint)
+        assert isinstance(d, FixedPoint)
+        assert f is None or isinstance(f, FixedPoint)
+        assert quotient is None or isinstance(quotient, int)
+        assert remainder is None or isinstance(remainder, int)
+        assert n_shift is None or isinstance(n_shift, int)
+        self.orig_n = orig_n
+        """original numerator"""
+
+        self.orig_d = orig_d
+        """original denominator"""
+
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+        self.quotient = quotient
+        """final quotient"""
+
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
+
+    def __repr__(self):
+        fields_str = []
+        for field in fields(GoldschmidtDivState):
+            value = getattr(self, field)
+            if value is None:
+                continue
+            if isinstance(value, int) and field != "n_shift":
+                fields_str.append(f"{field}={hex(value)}")
+            else:
+                fields_str.append(f"{field}={value!r}")
+        return f"GoldschmidtDivState({', '.join(fields_str)})"
+
+
+def goldschmidt_div(n, d, params, trace=lambda state: None):
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1022,6 +1222,9 @@ def goldschmidt_div(n, d, params):
             denominator. a `width`-bit unsigned integer. must not be zero.
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
+        trace: Function[[GoldschmidtDivState], None]
+            called with the initial state and the state after executing each
+            operation in `params.ops`.
 
         returns: tuple[int, int]
             the quotient and remainder. a tuple of two `width`-bit unsigned
@@ -1041,8 +1244,10 @@ def goldschmidt_div(n, d, params):
         d=FixedPoint(d, params.io_width),
     )
 
+    trace(state)
     for op in params.ops:
         op.run(params, state)
+        trace(state)
 
     assert state.quotient is not None
     assert state.remainder is not None
@@ -1050,6 +1255,177 @@ def goldschmidt_div(n, d, params):
     return state.quotient, state.remainder
 
 
+@plain_data(eq=False)
+class GoldschmidtDivHDLState:
+    __slots__ = ("m", "orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
+
+    __signal_name_prefix = "state_"
+
+    def __init__(self, m, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(m, Module)
+        assert isinstance(orig_n, Signal)
+        assert isinstance(orig_d, Signal)
+        assert isinstance(n, Signal)
+        assert isinstance(d, Signal)
+        assert f is None or isinstance(f, Signal)
+        assert quotient is None or isinstance(quotient, Signal)
+        assert remainder is None or isinstance(remainder, Signal)
+        assert n_shift is None or isinstance(n_shift, Signal)
+
+        self.m = m
+        """The HDL Module"""
+
+        self.orig_n = orig_n
+        """original numerator"""
+
+        self.orig_d = orig_d
+        """original denominator"""
+
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+        self.quotient = quotient
+        """final quotient"""
+
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
+
+        # old_signals must be set last
+        self.old_signals = defaultdict(list)
+
+    def __setattr__(self, name, value):
+        assert isinstance(name, str)
+        if name.startswith("_"):
+            return super().__setattr__(name, value)
+        try:
+            old_signals = self.old_signals[name]
+        except AttributeError:
+            # haven't yet finished __post_init__
+            return super().__setattr__(name, value)
+        assert name != "m" and name != "old_signals", f"can't write to {name}"
+        assert isinstance(value, Signal)
+        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+        old_signal = getattr(self, name, None)
+        if old_signal is not None:
+            assert isinstance(old_signal, Signal)
+            old_signals.append(old_signal)
+        return super().__setattr__(name, value)
+
+    def insert_pipeline_register(self):
+        old_prefix = self.__signal_name_prefix
+        try:
+            for field in fields(GoldschmidtDivHDLState):
+                if field.startswith("_") or field == "m":
+                    continue
+                old_sig = getattr(self, field, None)
+                if old_sig is None:
+                    continue
+                assert isinstance(old_sig, Signal)
+                new_sig = Signal.like(old_sig)
+                setattr(self, field, new_sig)
+                self.m.d.sync += new_sig.eq(old_sig)
+        finally:
+            self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+    """ Goldschmidt division algorithm.
+
+        based on:
+        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+        attributes:
+        params: GoldschmidtDivParams
+            the goldschmidt division algorithm parameters.
+        pipe_reg_indexes: list[int]
+            the operation indexes where pipeline registers should be inserted.
+            duplicate values mean multiple registers should be inserted for
+            that operation index -- this is useful to allow yosys to spread a
+            multiplication across those multiple pipeline stages.
+        sync_rom: bool
+            true if the rom should be read synchronously rather than
+            combinatorially, incurring an extra clock cycle of latency.
+        n: Signal(unsigned(2 * params.io_width))
+            input numerator. a `2 * params.io_width`-bit unsigned integer.
+            must be less than `d << params.io_width`, otherwise the quotient
+            wouldn't fit in `params.io_width` bits.
+        d: Signal(unsigned(params.io_width))
+            input denominator. a `params.io_width`-bit unsigned integer.
+            must not be zero.
+        q: Signal(unsigned(params.io_width))
+            output quotient. only valid when `n < (d << params.io_width)`.
+        r: Signal(unsigned(params.io_width))
+            output remainder. only valid when `n < (d << params.io_width)`.
+        trace: list[GoldschmidtDivHDLState]
+            list of the initial state and the state after executing each
+            operation in `params.ops`.
+    """
+
+    @property
+    def total_pipeline_registers(self):
+        """the total number of pipeline registers"""
+        return len(self.pipe_reg_indexes) + self.sync_rom
+
+    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+        assert isinstance(params, GoldschmidtDivParams)
+        assert isinstance(sync_rom, bool)
+        self.params = params
+        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+        self.sync_rom = sync_rom
+        self.n = Signal(unsigned(2 * params.io_width))
+        self.d = Signal(unsigned(params.io_width))
+        self.q = Signal(unsigned(params.io_width))
+        self.r = Signal(unsigned(params.io_width))
+
+        # in constructor so we get trace without needing to call elaborate
+        state = GoldschmidtDivHDLState(
+            m=Module(),
+            orig_n=self.n,
+            orig_d=self.d,
+            n=self.n,
+            d=self.d)
+
+        self.trace = [replace(state)]
+
+        # copy and reverse
+        pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+        for op_index, op in enumerate(self.params.ops):
+            while len(pipe_reg_indexes) > 0 \
+                    and pipe_reg_indexes[-1] <= op_index:
+                pipe_reg_indexes.pop()
+                state.insert_pipeline_register()
+            op.gen_hdl(self.params, state, self.sync_rom)
+            self.trace.append(replace(state))
+
+        while len(pipe_reg_indexes) > 0:
+            pipe_reg_indexes.pop()
+            state.insert_pipeline_register()
+
+        state.m.d.comb += [
+            self.q.eq(state.quotient),
+            self.r.eq(state.remainder),
+        ]
+
+    def elaborate(self, platform):
+        return self.trace[0].m
+
+
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
 
 
@@ -1087,6 +1463,11 @@ def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
     # tuple for immutability
     return tuple(table)
 
+# FIXME: add code to calculate error bounds and check that the algorithm will
+# actually work (like in the goldschmidt division algorithm).
+# FIXME: add code to calculate a good set of parameters based on the error
+# bounds checking.
+
 
 def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                            table_addr_bits, table_data_bits, iter_count):