# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
-from dataclasses import dataclass, field, fields, replace
+from collections import defaultdict
import logging
import math
import enum
from fractions import Fraction
from types import FunctionType
from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, signed, Const
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
+from nmutil.plain_data import plain_data, fields, replace
try:
from functools import cached_property
ERROR_IF_INEXACT = enum.auto()
-@dataclass(frozen=True)
+@plain_data(frozen=True, eq=False, repr=False)
class FixedPoint:
- bits: int
- frac_wid: int
+ __slots__ = "bits", "frac_wid"
- def __post_init__(self):
+ def __init__(self, bits, frac_wid):
+ self.bits = bits
+ self.frac_wid = frac_wid
assert isinstance(self.bits, int)
assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
return retval
-@dataclass
-class GoldschmidtDivState:
- orig_n: int
- """original numerator"""
-
- orig_d: int
- """original denominator"""
-
- n: FixedPoint
- """numerator -- N_prime[i] in the paper's algorithm 2"""
-
- d: FixedPoint
- """denominator -- D_prime[i] in the paper's algorithm 2"""
-
- f: "FixedPoint | None" = None
- """current factor -- F_prime[i] in the paper's algorithm 2"""
-
- quotient: "int | None" = None
- """final quotient"""
-
- remainder: "int | None" = None
- """final remainder"""
-
- n_shift: "int | None" = None
- """amount the numerator needs to be left-shifted at the end of the
- algorithm.
- """
-
-
class ParamsNotAccurateEnough(Exception):
"""raised when the parameters aren't accurate enough to have goldschmidt
division work."""
raise ParamsNotAccurateEnough(msg)
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
"""parameters for a Goldschmidt division algorithm, excluding derived
parameters.
"""
- io_width: int
- """bit-width of the input divisor and the result.
- the input numerator is `2 * io_width`-bits wide.
- """
+ __slots__ = ("io_width", "extra_precision", "table_addr_bits",
+ "table_data_bits", "iter_count")
+
+ def __init__(self, io_width, extra_precision, table_addr_bits,
+ table_data_bits, iter_count):
+ assert isinstance(io_width, int)
+ assert isinstance(extra_precision, int)
+ assert isinstance(table_addr_bits, int)
+ assert isinstance(table_data_bits, int)
+ assert isinstance(iter_count, int)
+ self.io_width = io_width
+ """bit-width of the input divisor and the result.
+ the input numerator is `2 * io_width`-bits wide.
+ """
- extra_precision: int
- """number of bits of additional precision used inside the algorithm."""
+ self.extra_precision = extra_precision
+ """number of bits of additional precision used inside the algorithm."""
- table_addr_bits: int
- """the number of address bits used in the lookup-table."""
+ self.table_addr_bits = table_addr_bits
+ """the number of address bits used in the lookup-table."""
- table_data_bits: int
- """the number of data bits used in the lookup-table."""
+ self.table_data_bits = table_data_bits
+ """the number of data bits used in the lookup-table."""
- iter_count: int
- """the total number of iterations of the division algorithm's loop"""
+ self.iter_count = iter_count
+ """the total number of iterations of the division algorithm's loop"""
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
"""parameters for a Goldschmidt division algorithm.
Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
"""
- # tuple to be immutable, repr=False so repr() works for debugging even when
- # __post_init__ hasn't finished running yet
- table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
- """the lookup-table"""
-
- ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
- """the operations needed to perform the goldschmidt division algorithm."""
+ __slots__ = "table", "ops"
def _shrink_bound(self, bound, round_dir):
"""prevent fractions from having huge numerators/denominators by
# we round down
return min_value
- def __post_init__(self):
- # called by the autogenerated __init__
+ def __init__(self, io_width, extra_precision, table_addr_bits,
+ table_data_bits, iter_count):
+ super().__init__(io_width=io_width,
+ extra_precision=extra_precision,
+ table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits,
+ iter_count=iter_count)
_assert_accuracy(self.io_width >= 1, "io_width out of range")
_assert_accuracy(self.extra_precision >= 0,
"extra_precision out of range")
table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
self.table_data_bits,
RoundDir.DOWN))
- # we have to use object.__setattr__ since frozen=True
- object.__setattr__(self, "table", tuple(table))
- object.__setattr__(self, "ops", tuple(self.__make_ops()))
+
+ self.table = tuple(table)
+ """ the lookup-table.
+ type: tuple[FixedPoint, ...]
+ """
+
+ self.ops = tuple(self.__make_ops())
+ "the operations needed to perform the goldschmidt division algorithm."
@property
def expanded_width(self):
"""the total number of bits of precision used inside the algorithm."""
return self.io_width + self.extra_precision
+ @property
+ def n_d_f_int_wid(self):
+ """the number of bits in the integer part of `state.n`, `state.d`, and
+ `state.f` during the main iteration loop.
+ """
+ return 2
+
+ @property
+ def n_d_f_total_wid(self):
+ """the total number of bits (both integer and fraction bits) in
+ `state.n`, `state.d`, and `state.f` during the main iteration loop.
+ """
+ return self.n_d_f_int_wid + self.expanded_width
+
@cache_on_self
def max_neps(self, i):
"""maximum value of `neps[i]`.
def max_n_shift(self):
""" maximum value of `state.n_shift`.
"""
- # input numerator is `2*io_width`-bits
- max_n = (1 << (self.io_width * 2)) - 1
- max_n_shift = 0
- # normalize so 1 <= n < 2
- while max_n >= 2:
- max_n >>= 1
- max_n_shift += 1
- return max_n_shift
+ # numerator must be less than `denominator << self.io_width`, so
+ # `n_shift` is at most `self.io_width`
+ return self.io_width
@cached_property
def n_hat(self):
max_rel_error = (2 * i) * self.n_hat + power
min_a_over_b = Fraction(1, 2)
- max_a_over_b = Fraction(2)
- max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
- max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+ min_abs_error_for_correctness = min_a_over_b / (1 << self.max_n_shift)
+ min_rel_error_for_correctness = (min_abs_error_for_correctness
+ / min_a_over_b)
- _assert_accuracy(max_rel_error < max_allowed_rel_error,
- f"not accurate enough: max_rel_error={max_rel_error}"
- f" max_allowed_rel_error={max_allowed_rel_error}")
+ _assert_accuracy(
+ max_rel_error < min_rel_error_for_correctness,
+ f"not accurate enough: max_rel_error={max_rel_error}"
+ f" min_rel_error_for_correctness={min_rel_error_for_correctness}")
yield GoldschmidtDivOp.CalcResult
@lru_cache(maxsize=1 << 16)
def __cached_new(base_params):
assert isinstance(base_params, GoldschmidtDivParamsBase)
- # can't use dataclasses.asdict, since it's recursive and will also give
- # child class fields too, which we don't want.
kwargs = {}
for field in fields(GoldschmidtDivParamsBase):
- kwargs[field.name] = getattr(base_params, field.name)
+ kwargs[field] = getattr(base_params, field)
try:
return GoldschmidtDivParams(**kwargs), None
except ParamsNotAccurateEnough as e:
else:
break
- return cached_new(params)
+ retval = cached_new(params)
+ assert isinstance(retval, GoldschmidtDivParams)
+ return retval
+
+
+def clz(v, wid):
+ """count leading zeros -- handy for debugging."""
+ assert isinstance(wid, int)
+ assert isinstance(v, int) and 0 <= v < (1 << wid)
+ return (1 << wid).bit_length() - v.bit_length()
@enum.unique
state.n_shift = 0
# normalize so 1 <= n < 2
while state.n >= 2:
- state.n = (state.n * 0.5).to_frac_wid(expanded_width)
+ state.n = (state.n * 0.5).to_frac_wid(expanded_width,
+ round_dir=RoundDir.DOWN)
state.n_shift += 1
elif self == GoldschmidtDivOp.FEqTableLookup:
# compute initial f by table lookup
d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
state.f = params.table[d_m_1.bits]
+ state.f = state.f.to_frac_wid(expanded_width,
+ round_dir=RoundDir.DOWN)
elif self == GoldschmidtDivOp.MulNByF:
assert state.f is not None
n = state.n * state.f
else:
assert False, f"unimplemented GoldschmidtDivOp: {self}"
+ def gen_hdl(self, params, state, sync_rom):
+ """generate the hdl for this operation.
-def goldschmidt_div(n, d, params):
+ arguments:
+ params: GoldschmidtDivParams
+ the goldschmidt division parameters.
+ state: GoldschmidtDivHDLState
+ the input/output state
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ """
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(state, GoldschmidtDivHDLState)
+ m = state.m
+ if self == GoldschmidtDivOp.Normalize:
+ # normalize so 1 <= d < 2
+ assert state.d.width == params.io_width
+ assert state.n.width == 2 * params.io_width
+ d_leading_zeros = CLZ(params.io_width)
+ m.submodules.d_leading_zeros = d_leading_zeros
+ m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+ d_shift_out = Signal.like(state.d)
+ m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+ d = Signal(params.n_d_f_total_wid)
+ m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+ >> state.d.width)
+
+ # normalize so 1 <= n < 2
+ n_leading_zeros = CLZ(2 * params.io_width)
+ m.submodules.n_leading_zeros = n_leading_zeros
+ m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+ signed_zero = Const(0, signed(1)) # force subtraction to be signed
+ n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
+ - n_leading_zeros.lz)
+ n_shift_s = Signal.like(n_shift_s_v)
+ n_shift_n_lz_out = Signal.like(state.n)
+ n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
+ m.d.comb += [
+ n_shift_s.eq(n_shift_s_v),
+ n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+ n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
+ ]
+ state.n_shift = Signal(d_leading_zeros.lz.width)
+ n = Signal(params.n_d_f_total_wid)
+ with m.If(n_shift_s < 0):
+ m.d.comb += [
+ state.n_shift.eq(0),
+ n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+ >> state.d.width),
+ ]
+ with m.Else():
+ m.d.comb += [
+ state.n_shift.eq(n_shift_s),
+ n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+ >> state.n.width),
+ ]
+ state.n = n
+ state.d = d
+ elif self == GoldschmidtDivOp.FEqTableLookup:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ # compute initial f by table lookup
+
+ # extra bit for table entries == 1.0
+ table_width = 1 + params.table_data_bits
+ table = Memory(width=table_width, depth=len(params.table),
+ init=[i.bits for i in params.table])
+ addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
+ if sync_rom:
+ table_read = table.read_port()
+ m.d.comb += table_read.addr.eq(addr)
+ state.insert_pipeline_register()
+ else:
+ table_read = table.read_port(domain="comb")
+ m.d.comb += table_read.addr.eq(addr)
+ m.submodules.table_read = table_read
+ state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+ data_shift = params.expanded_width - params.table_data_bits
+ m.d.comb += state.f.eq(table_read.data << data_shift)
+ elif self == GoldschmidtDivOp.MulNByF:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ n = Signal.like(state.n)
+ m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+ state.n = n
+ elif self == GoldschmidtDivOp.MulDByF:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ d = Signal.like(state.d)
+ d_times_f = Signal.like(state.d * state.f)
+ m.d.comb += [
+ d_times_f.eq(state.d * state.f),
+ # round the multiplication up
+ d.eq((d_times_f >> params.expanded_width)
+ + (d_times_f[:params.expanded_width] != 0)),
+ ]
+ state.d = d
+ elif self == GoldschmidtDivOp.FEq2MinusD:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ f = Signal.like(state.d)
+ m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+ state.f = f
+ elif self == GoldschmidtDivOp.CalcResult:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.n_shift is not None
+ # scale to correct value
+ n = state.n * (1 << state.n_shift)
+ q_approx = Signal(params.io_width)
+ # extra bit for if it's bigger than orig_d
+ r_approx = Signal(params.io_width + 1)
+ adjusted_r = Signal(signed(1 + params.io_width))
+ m.d.comb += [
+ q_approx.eq((state.n << state.n_shift)
+ >> params.expanded_width),
+ r_approx.eq(state.orig_n - q_approx * state.orig_d),
+ adjusted_r.eq(r_approx - state.orig_d),
+ ]
+ state.quotient = Signal(params.io_width)
+ state.remainder = Signal(params.io_width)
+
+ with m.If(adjusted_r >= 0):
+ m.d.comb += [
+ state.quotient.eq(q_approx + 1),
+ state.remainder.eq(adjusted_r),
+ ]
+ with m.Else():
+ m.d.comb += [
+ state.quotient.eq(q_approx),
+ state.remainder.eq(r_approx),
+ ]
+ else:
+ assert False, f"unimplemented GoldschmidtDivOp: {self}"
+
+
+@plain_data(repr=False)
+class GoldschmidtDivState:
+ __slots__ = ("orig_n", "orig_d", "n", "d",
+ "f", "quotient", "remainder", "n_shift")
+
+ def __init__(self, orig_n, orig_d, n, d,
+ f=None, quotient=None, remainder=None, n_shift=None):
+ assert isinstance(orig_n, int)
+ assert isinstance(orig_d, int)
+ assert isinstance(n, FixedPoint)
+ assert isinstance(d, FixedPoint)
+ assert f is None or isinstance(f, FixedPoint)
+ assert quotient is None or isinstance(quotient, int)
+ assert remainder is None or isinstance(remainder, int)
+ assert n_shift is None or isinstance(n_shift, int)
+ self.orig_n = orig_n
+ """original numerator"""
+
+ self.orig_d = orig_d
+ """original denominator"""
+
+ self.n = n
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+ self.d = d
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+ self.f = f
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+ self.quotient = quotient
+ """final quotient"""
+
+ self.remainder = remainder
+ """final remainder"""
+
+ self.n_shift = n_shift
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
+
+ def __repr__(self):
+ fields_str = []
+ for field in fields(GoldschmidtDivState):
+ value = getattr(self, field)
+ if value is None:
+ continue
+ if isinstance(value, int) and field != "n_shift":
+ fields_str.append(f"{field}={hex(value)}")
+ else:
+ fields_str.append(f"{field}={value!r}")
+ return f"GoldschmidtDivState({', '.join(fields_str)})"
+
+
+def goldschmidt_div(n, d, params, trace=lambda state: None):
""" Goldschmidt division algorithm.
based on:
denominator. a `width`-bit unsigned integer. must not be zero.
width: int
the bit-width of the inputs/outputs. must be a positive integer.
+ trace: Function[[GoldschmidtDivState], None]
+ called with the initial state and the state after executing each
+ operation in `params.ops`.
returns: tuple[int, int]
the quotient and remainder. a tuple of two `width`-bit unsigned
d=FixedPoint(d, params.io_width),
)
+ trace(state)
for op in params.ops:
op.run(params, state)
+ trace(state)
assert state.quotient is not None
assert state.remainder is not None
return state.quotient, state.remainder
+@plain_data(eq=False)
+class GoldschmidtDivHDLState:
+ __slots__ = ("m", "orig_n", "orig_d", "n", "d",
+ "f", "quotient", "remainder", "n_shift")
+
+ __signal_name_prefix = "state_"
+
+ def __init__(self, m, orig_n, orig_d, n, d,
+ f=None, quotient=None, remainder=None, n_shift=None):
+ assert isinstance(m, Module)
+ assert isinstance(orig_n, Signal)
+ assert isinstance(orig_d, Signal)
+ assert isinstance(n, Signal)
+ assert isinstance(d, Signal)
+ assert f is None or isinstance(f, Signal)
+ assert quotient is None or isinstance(quotient, Signal)
+ assert remainder is None or isinstance(remainder, Signal)
+ assert n_shift is None or isinstance(n_shift, Signal)
+
+ self.m = m
+ """The HDL Module"""
+
+ self.orig_n = orig_n
+ """original numerator"""
+
+ self.orig_d = orig_d
+ """original denominator"""
+
+ self.n = n
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+ self.d = d
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+ self.f = f
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+ self.quotient = quotient
+ """final quotient"""
+
+ self.remainder = remainder
+ """final remainder"""
+
+ self.n_shift = n_shift
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
+
+ # old_signals must be set last
+ self.old_signals = defaultdict(list)
+
+ def __setattr__(self, name, value):
+ assert isinstance(name, str)
+ if name.startswith("_"):
+ return super().__setattr__(name, value)
+ try:
+ old_signals = self.old_signals[name]
+ except AttributeError:
+ # haven't yet finished __post_init__
+ return super().__setattr__(name, value)
+ assert name != "m" and name != "old_signals", f"can't write to {name}"
+ assert isinstance(value, Signal)
+ value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+ old_signal = getattr(self, name, None)
+ if old_signal is not None:
+ assert isinstance(old_signal, Signal)
+ old_signals.append(old_signal)
+ return super().__setattr__(name, value)
+
+ def insert_pipeline_register(self):
+ old_prefix = self.__signal_name_prefix
+ try:
+ for field in fields(GoldschmidtDivHDLState):
+ if field.startswith("_") or field == "m":
+ continue
+ old_sig = getattr(self, field, None)
+ if old_sig is None:
+ continue
+ assert isinstance(old_sig, Signal)
+ new_sig = Signal.like(old_sig)
+ setattr(self, field, new_sig)
+ self.m.d.sync += new_sig.eq(old_sig)
+ finally:
+ self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+ """ Goldschmidt division algorithm.
+
+ based on:
+ Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+ A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+ attributes:
+ params: GoldschmidtDivParams
+ the goldschmidt division algorithm parameters.
+ pipe_reg_indexes: list[int]
+ the operation indexes where pipeline registers should be inserted.
+ duplicate values mean multiple registers should be inserted for
+ that operation index -- this is useful to allow yosys to spread a
+ multiplication across those multiple pipeline stages.
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ n: Signal(unsigned(2 * params.io_width))
+ input numerator. a `2 * params.io_width`-bit unsigned integer.
+ must be less than `d << params.io_width`, otherwise the quotient
+ wouldn't fit in `params.io_width` bits.
+ d: Signal(unsigned(params.io_width))
+ input denominator. a `params.io_width`-bit unsigned integer.
+ must not be zero.
+ q: Signal(unsigned(params.io_width))
+ output quotient. only valid when `n < (d << params.io_width)`.
+ r: Signal(unsigned(params.io_width))
+ output remainder. only valid when `n < (d << params.io_width)`.
+ trace: list[GoldschmidtDivHDLState]
+ list of the initial state and the state after executing each
+ operation in `params.ops`.
+ """
+
+ @property
+ def total_pipeline_registers(self):
+ """the total number of pipeline registers"""
+ return len(self.pipe_reg_indexes) + self.sync_rom
+
+ def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(sync_rom, bool)
+ self.params = params
+ self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+ self.sync_rom = sync_rom
+ self.n = Signal(unsigned(2 * params.io_width))
+ self.d = Signal(unsigned(params.io_width))
+ self.q = Signal(unsigned(params.io_width))
+ self.r = Signal(unsigned(params.io_width))
+
+ # in constructor so we get trace without needing to call elaborate
+ state = GoldschmidtDivHDLState(
+ m=Module(),
+ orig_n=self.n,
+ orig_d=self.d,
+ n=self.n,
+ d=self.d)
+
+ self.trace = [replace(state)]
+
+ # copy and reverse
+ pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+ for op_index, op in enumerate(self.params.ops):
+ while len(pipe_reg_indexes) > 0 \
+ and pipe_reg_indexes[-1] <= op_index:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+ op.gen_hdl(self.params, state, self.sync_rom)
+ self.trace.append(replace(state))
+
+ while len(pipe_reg_indexes) > 0:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+
+ state.m.d.comb += [
+ self.q.eq(state.quotient),
+ self.r.eq(state.remainder),
+ ]
+
+ def elaborate(self, platform):
+ return self.trace[0].m
+
+
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
# tuple for immutability
return tuple(table)
+# FIXME: add code to calculate error bounds and check that the algorithm will
+# actually work (like in the goldschmidt division algorithm).
+# FIXME: add code to calculate a good set of parameters based on the error
+# bounds checking.
+
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
table_addr_bits, table_data_bits, iter_count):