# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
-from dataclasses import dataclass, field
+from collections import defaultdict
+from dataclasses import dataclass, field, fields, replace
+import logging
import math
import enum
from fractions import Fraction
from types import FunctionType
+from functools import lru_cache
+from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.dsl import Module, Elaboratable
+from nmigen.hdl.mem import Memory
+from nmutil.clz import CLZ
try:
from functools import cached_property
from cached_property import cached_property
# fix broken IDE type detection for cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from functools import cached_property
frac_wid: int
def __post_init__(self):
+ # called by the autogenerated __init__
assert isinstance(self.bits, int)
assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
denominator = value.denominator
else:
value = FixedPoint.cast(value)
- # compute number of bits that should be removed from value
- del_bits = value.frac_wid - frac_wid
- if del_bits == 0:
- return value
- if del_bits < 0: # add bits
- return FixedPoint(value.bits << -del_bits,
- frac_wid)
numerator = value.bits
denominator = 1 << value.frac_wid
if denominator < 0:
def __floor__(self):
return self.bits >> self.frac_wid
+ def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
+ assert isinstance(frac_wid, int) and frac_wid >= 0
+ assert isinstance(round_dir, RoundDir)
+ rhs = FixedPoint.cast(rhs)
+ return FixedPoint.with_frac_wid(self.as_fraction()
+ / rhs.as_fraction(),
+ frac_wid, round_dir)
-@dataclass
-class GoldschmidtDivState:
- orig_n: int
- """original numerator"""
-
- orig_d: int
- """original denominator"""
-
- n: FixedPoint
- """numerator -- N_prime[i] in the paper's algorithm 2"""
-
- d: FixedPoint
- """denominator -- D_prime[i] in the paper's algorithm 2"""
-
- f: "FixedPoint | None" = None
- """current factor -- F_prime[i] in the paper's algorithm 2"""
-
- quotient: "int | None" = None
- """final quotient"""
-
- remainder: "int | None" = None
- """final remainder"""
+ def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+ assert isinstance(round_dir, RoundDir)
+ if self < 0:
+ raise ValueError("can't compute sqrt of negative number")
+ if self == 0:
+ return self
+ retval = FixedPoint(0, self.frac_wid)
+ int_part_wid = self.bits.bit_length() - self.frac_wid
+ first_bit_index = -(-int_part_wid // 2) # division rounds up
+ last_bit_index = -self.frac_wid
+ for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+ trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+ self.frac_wid)
+ if trial * trial <= self:
+ retval = trial
+ if round_dir == RoundDir.DOWN:
+ pass
+ elif round_dir == RoundDir.UP:
+ if retval * retval < self:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.NEAREST_TIES_UP:
+ half_way = retval + FixedPoint(1, self.frac_wid + 1)
+ if half_way * half_way <= self:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.ERROR_IF_INEXACT:
+ if retval * retval != self:
+ raise ValueError("inexact sqrt")
+ else:
+ assert False, "unimplemented round_dir"
+ return retval
- n_shift: "int | None" = None
- """amount the numerator needs to be left-shifted at the end of the
- algorithm.
- """
+ def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+ """compute the reciprocal-sqrt of `self`"""
+ assert isinstance(round_dir, RoundDir)
+ if self < 0:
+ raise ValueError("can't compute rsqrt of negative number")
+ if self == 0:
+ raise ZeroDivisionError("can't compute rsqrt of zero")
+ retval = FixedPoint(0, self.frac_wid)
+ first_bit_index = -(-self.frac_wid // 2) # division rounds up
+ last_bit_index = -self.frac_wid
+ for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+ trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+ self.frac_wid)
+ if trial * trial * self <= 1:
+ retval = trial
+ if round_dir == RoundDir.DOWN:
+ pass
+ elif round_dir == RoundDir.UP:
+ if retval * retval * self < 1:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.NEAREST_TIES_UP:
+ half_way = retval + FixedPoint(1, self.frac_wid + 1)
+ if half_way * half_way * self <= 1:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.ERROR_IF_INEXACT:
+ if retval * retval * self != 1:
+ raise ValueError("inexact rsqrt")
+ else:
+ assert False, "unimplemented round_dir"
+ return retval
class ParamsNotAccurateEnough(Exception):
@dataclass(frozen=True, unsafe_hash=True)
-class GoldschmidtDivParams:
- """parameters for a Goldschmidt division algorithm.
- Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+class GoldschmidtDivParamsBase:
+ """parameters for a Goldschmidt division algorithm, excluding derived
+ parameters.
"""
io_width: int
iter_count: int
"""the total number of iterations of the division algorithm's loop"""
- # tuple to be immutable
- table: "tuple[FixedPoint, ...]" = field(init=False)
+
+@dataclass(frozen=True, unsafe_hash=True)
+class GoldschmidtDivParams(GoldschmidtDivParamsBase):
+ """parameters for a Goldschmidt division algorithm.
+ Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+ """
+
+ # tuple to be immutable, repr=False so repr() works for debugging even when
+ # __post_init__ hasn't finished running yet
+ table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
"""the lookup-table"""
- ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
+ ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
"""the operations needed to perform the goldschmidt division algorithm."""
+ def _shrink_bound(self, bound, round_dir):
+ """prevent fractions from having huge numerators/denominators by
+ rounding to a `FixedPoint` and converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ assert isinstance(bound, (Fraction, int))
+ assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
+ "you shouldn't use that round_dir on bounds"
+ frac_wid = self.io_width * 4 + 100 # should be enough precision
+ fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
+ return fixed.as_fraction()
+
+ def _shrink_min(self, min_bound):
+ """prevent fractions used as minimum bounds from having huge
+ numerators/denominators by rounding down to a `FixedPoint` and
+ converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ return self._shrink_bound(min_bound, RoundDir.DOWN)
+
+ def _shrink_max(self, max_bound):
+ """prevent fractions used as maximum bounds from having huge
+ numerators/denominators by rounding up to a `FixedPoint` and
+ converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ return self._shrink_bound(max_bound, RoundDir.UP)
+
@property
def table_addr_count(self):
"""number of distinct addresses in the lookup-table."""
assert isinstance(addr, int)
assert 0 <= addr < self.table_addr_count
_assert_accuracy(self.io_width >= self.table_addr_bits)
- min_numerator = (1 << self.table_addr_bits) + addr
- denominator = 1 << self.table_addr_bits
- values_per_table_entry = 1 << (self.io_width - self.table_addr_bits)
- max_numerator = min_numerator + values_per_table_entry
+ addr_shift = self.io_width - self.table_addr_bits
+ min_numerator = (1 << self.io_width) + (addr << addr_shift)
+ denominator = 1 << self.io_width
+ values_per_table_entry = 1 << addr_shift
+ max_numerator = min_numerator + values_per_table_entry - 1
min_input = Fraction(min_numerator, denominator)
max_input = Fraction(max_numerator, denominator)
+ min_input = self._shrink_min(min_input)
+ max_input = self._shrink_max(max_input)
+ assert 1 <= min_input <= max_input < 2
return min_input, max_input
def table_value_exact_range(self, addr):
"""return the range of values as `Fraction`s used for the table entry
with address `addr`."""
- min_value, max_value = self.table_input_exact_range(addr)
+ min_input, max_input = self.table_input_exact_range(addr)
# division swaps min/max
- return 1 / max_value, 1 / min_value
+ min_value = 1 / max_input
+ max_value = 1 / min_input
+ min_value = self._shrink_min(min_value)
+ max_value = self._shrink_max(max_value)
+ assert 0.5 < min_value <= max_value <= 1
+ return min_value, max_value
def table_exact_value(self, index):
min_value, max_value = self.table_value_exact_range(index)
def __post_init__(self):
# called by the autogenerated __init__
- assert self.io_width >= 1
- assert self.extra_precision >= 0
- assert self.table_addr_bits >= 1
- assert self.table_data_bits >= 1
- assert self.iter_count >= 1
+ _assert_accuracy(self.io_width >= 1, "io_width out of range")
+ _assert_accuracy(self.extra_precision >= 0,
+ "extra_precision out of range")
+ _assert_accuracy(self.table_addr_bits >= 1,
+ "table_addr_bits out of range")
+ _assert_accuracy(self.table_data_bits >= 1,
+ "table_data_bits out of range")
+ _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
table = []
for addr in range(1 << self.table_addr_bits):
table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
RoundDir.DOWN))
# we have to use object.__setattr__ since frozen=True
object.__setattr__(self, "table", tuple(table))
- object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))
-
- @staticmethod
- def get(io_width):
- """ find efficient parameters for a goldschmidt division algorithm
- with `params.io_width == io_width`.
- """
- assert isinstance(io_width, int) and io_width >= 1
- for extra_precision in range(io_width * 2 + 4):
- for table_addr_bits in range(1, 7 + 1):
- table_data_bits = io_width + extra_precision
- for iter_count in range(1, 2 * io_width.bit_length()):
- try:
- return GoldschmidtDivParams(
- io_width=io_width,
- extra_precision=extra_precision,
- table_addr_bits=table_addr_bits,
- table_data_bits=table_data_bits,
- iter_count=iter_count)
- except ParamsNotAccurateEnough:
- pass
- raise ValueError(f"can't find working parameters for a goldschmidt "
- f"division algorithm with io_width={io_width}")
+ object.__setattr__(self, "ops", tuple(self.__make_ops()))
@property
def expanded_width(self):
"""the total number of bits of precision used inside the algorithm."""
return self.io_width + self.extra_precision
+ @property
+ def n_d_f_int_wid(self):
+ """the number of bits in the integer part of `state.n`, `state.d`, and
+ `state.f` during the main iteration loop.
+ """
+ return 2
+
+ @property
+ def n_d_f_total_wid(self):
+ """the total number of bits (both integer and fraction bits) in
+ `state.n`, `state.d`, and `state.f` during the main iteration loop.
+ """
+ return self.n_d_f_int_wid + self.expanded_width
+
@cache_on_self
def max_neps(self, i):
"""maximum value of `neps[i]`.
cur_max_e0 = 1 - min_product
min_e0 = min(min_e0, cur_min_e0)
max_e0 = max(max_e0, cur_max_e0)
+ min_e0 = self._shrink_min(min_e0)
+ max_e0 = self._shrink_max(max_e0)
return min_e0, max_e0
@cached_property
"only one quadrant of interval division implemented"
retval = self.max_neps(i) / min_mpd
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_d(self, i):
"only one quadrant of interval division implemented"
retval = self.max_deps(i) / (1 - self.max_delta(i - 1))
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_f(self, i):
# `f[i] <= max_feps[i]`
retval = self.max_feps(i)
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_delta(self, i):
assert prev_max_delta >= 0
retval = prev_max_delta ** 2 + self.max_f(i - 1)
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ # `delta[i]` has to be smaller than one otherwise errors would go off
+ # to infinity
+ _assert_accuracy(retval < 1)
+
+ return self._shrink_max(retval)
@cache_on_self
def max_pi(self, i):
# `pi[i] = 1 - (1 - n[i]) * prod`
# where `prod` is the product of,
# for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
- min_prod = Fraction(0)
+ min_prod = Fraction(1)
for j in range(i):
max_n_j = self.max_n(j)
max_d_j = self.max_d(j)
max_n_i = self.max_n(i)
assert max_n_i <= 1 and min_prod >= 0, \
"only one quadrant of interval multiplication implemented"
- return 1 - (1 - max_n_i) * min_prod
+ retval = 1 - (1 - max_n_i) * min_prod
+ return self._shrink_max(retval)
@cached_property
def max_n_shift(self):
max_n_shift += 1
return max_n_shift
+ @cached_property
+ def n_hat(self):
+ """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
+ """
+ n_hat = Fraction(0)
+ for i in range(self.iter_count):
+ n_hat = max(n_hat, self.max_n(i), self.max_d(i))
+ return self._shrink_max(n_hat)
+
+ def __make_ops(self):
+ """ Goldschmidt division algorithm.
+
+ based on:
+ Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+ A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+ yields: GoldschmidtDivOp
+ the operations needed to perform the division.
+ """
+ # establish assumptions of the paper's error analysis (section 3.1):
+
+ # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
+ yield GoldschmidtDivOp.Normalize
+
+ # 2. ensure all relative errors from directed rounding are <= 1 / 4.
+ # the assumption is met by multipliers with > 4-bits precision
+ _assert_accuracy(self.expanded_width > 4)
+
+ # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
+ _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
+ + self.max_f(0) < Fraction(1, 2))
+
+ # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
+ # (B is the denominator)
+
+ for addr in range(self.table_addr_count):
+ f_prime_m1 = self.table[addr]
+ _assert_accuracy(0.5 <= f_prime_m1 <= 1)
+
+ yield GoldschmidtDivOp.FEqTableLookup
+
+ # we use Setting I (section 4.1 of the paper):
+ # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
+ # the conditions on n_hat are satisfied by construction.
+ for i in range(self.iter_count):
+ _assert_accuracy(self.max_f(i) == 0)
+ yield GoldschmidtDivOp.MulNByF
+ if i != self.iter_count - 1:
+ yield GoldschmidtDivOp.MulDByF
+ yield GoldschmidtDivOp.FEq2MinusD
+
+ # relative approximation error `p(N_prime[i])`:
+ # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
+ # `0 <= p(N_prime[i])`
+ # `p(N_prime[i]) <= (2 * i) * n_hat \`
+ # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
+ i = self.iter_count - 1 # last used `i`
+ # compute power manually to prevent huge intermediate values
+ power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
+ for _ in range(i):
+ power = self._shrink_max(power * power)
+
+ max_rel_error = (2 * i) * self.n_hat + power
+
+ min_a_over_b = Fraction(1, 2)
+ max_a_over_b = Fraction(2)
+ max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
+ max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+
+ _assert_accuracy(max_rel_error < max_allowed_rel_error,
+ f"not accurate enough: max_rel_error={max_rel_error}"
+ f" max_allowed_rel_error={max_allowed_rel_error}")
+
+ yield GoldschmidtDivOp.CalcResult
+
+ @cache_on_self
+ def default_cost_fn(self):
+ """ calculate the estimated cost on an arbitrary scale of implementing
+ goldschmidt division with the specified parameters. larger cost
+ values mean worse parameters.
+
+ This is the default cost function for `GoldschmidtDivParams.get`.
+
+ returns: float
+ """
+ rom_cells = self.table_data_bits << self.table_addr_bits
+ cost = float(rom_cells)
+ for op in self.ops:
+ if op == GoldschmidtDivOp.MulNByF \
+ or op == GoldschmidtDivOp.MulDByF:
+ mul_cost = self.expanded_width ** 2
+ mul_cost *= self.expanded_width.bit_length()
+ cost += mul_cost
+ cost += 5e7 * self.iter_count
+ return cost
+
+ @staticmethod
+ @lru_cache(maxsize=1 << 16)
+ def __cached_new(base_params):
+ assert isinstance(base_params, GoldschmidtDivParamsBase)
+ # can't use dataclasses.asdict, since it's recursive and will also give
+ # child class fields too, which we don't want.
+ kwargs = {}
+ for field in fields(GoldschmidtDivParamsBase):
+ kwargs[field.name] = getattr(base_params, field.name)
+ try:
+ return GoldschmidtDivParams(**kwargs), None
+ except ParamsNotAccurateEnough as e:
+ return None, e
+
+ @staticmethod
+ def __raise(e): # type: (ParamsNotAccurateEnough) -> Any
+ raise e
+
+ @staticmethod
+ def cached_new(base_params, handle_error=__raise):
+ assert isinstance(base_params, GoldschmidtDivParamsBase)
+ params, error = GoldschmidtDivParams.__cached_new(base_params)
+ if error is None:
+ return params
+ else:
+ return handle_error(error)
+
+ @staticmethod
+ def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
+ """ find efficient parameters for a goldschmidt division algorithm
+ with `params.io_width == io_width`.
+
+ arguments:
+ io_width: int
+ bit-width of the input divisor and the result.
+ the input numerator is `2 * io_width`-bits wide.
+ cost_fn: Callable[[GoldschmidtDivParams], float]
+ return the estimated cost on an arbitrary scale of implementing
+ goldschmidt division with the specified parameters. larger cost
+ values mean worse parameters.
+ max_table_addr_bits: int
+ maximum allowable value of `table_addr_bits`
+ """
+ assert isinstance(io_width, int) and io_width >= 1
+ assert callable(cost_fn)
+
+ last_error = None
+ last_error_params = None
+
+ def cached_new(base_params):
+ def handle_error(e):
+ nonlocal last_error, last_error_params
+ last_error = e
+ last_error_params = base_params
+ return None
+
+ retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
+ if retval is None:
+ logging.debug(f"GoldschmidtDivParams.get: err: {base_params}")
+ else:
+ logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}")
+ return retval
+
+ @lru_cache(maxsize=None)
+ def get_cost(base_params):
+ params = cached_new(base_params)
+ if params is None:
+ return math.inf
+ retval = cost_fn(params)
+ logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}")
+ return retval
+
+ # start with parameters big enough to always work.
+ initial_extra_precision = io_width * 2 + 4
+ initial_params = GoldschmidtDivParamsBase(
+ io_width=io_width,
+ extra_precision=initial_extra_precision,
+ table_addr_bits=min(max_table_addr_bits, io_width),
+ table_data_bits=io_width + initial_extra_precision,
+ iter_count=1 + io_width.bit_length())
+
+ if cached_new(initial_params) is None:
+ raise ValueError(f"initial goldschmidt division algorithm "
+ f"parameters are invalid: {initial_params}"
+ ) from last_error
+
+ # find good initial `iter_count`
+ params = initial_params
+ for iter_count in range(1, initial_params.iter_count):
+ trial_params = replace(params, iter_count=iter_count)
+ if cached_new(trial_params) is not None:
+ params = trial_params
+ break
+
+ # now find `table_addr_bits`
+ cost = get_cost(params)
+ for table_addr_bits in range(1, max_table_addr_bits):
+ trial_params = replace(params, table_addr_bits=table_addr_bits)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # check one higher `iter_count` to see if it has lower cost
+ for table_addr_bits in range(1, max_table_addr_bits + 1):
+ trial_params = replace(params,
+ table_addr_bits=table_addr_bits,
+ iter_count=params.iter_count + 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # now shrink `table_data_bits`
+ while True:
+ trial_params = replace(params,
+ table_data_bits=params.table_data_bits - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ # and shrink `extra_precision`
+ while True:
+ trial_params = replace(params,
+ extra_precision=params.extra_precision - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ return cached_new(params)
+
+
+def clz(v, wid):
+ assert isinstance(wid, int)
+ assert isinstance(v, int) and 0 <= v < (1 << wid)
+ return (1 << wid).bit_length() - v.bit_length()
+
@enum.unique
class GoldschmidtDivOp(enum.Enum):
else:
assert False, f"unimplemented GoldschmidtDivOp: {self}"
-
-def _goldschmidt_div_ops(params):
- """ Goldschmidt division algorithm.
-
- based on:
- Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
- A Parametric Error Analysis of Goldschmidt's Division Algorithm.
- https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+ def gen_hdl(self, params, state, sync_rom):
+ # FIXME: finish getting hdl/simulation to work
+ """generate the hdl for this operation.
arguments:
params: GoldschmidtDivParams
- the parameters for the algorithm
-
- yields: GoldschmidtDivOp
- the operations needed to perform the division.
- """
- assert isinstance(params, GoldschmidtDivParams)
-
- # establish assumptions of the paper's error analysis (section 3.1):
+ the goldschmidt division parameters.
+ state: GoldschmidtDivHDLState
+ the input/output state
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ """
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(state, GoldschmidtDivHDLState)
+ m = state.m
+ expanded_width = params.expanded_width
+ table_addr_bits = params.table_addr_bits
+ if self == GoldschmidtDivOp.Normalize:
+ # normalize so 1 <= d < 2
+ assert state.d.width == params.io_width
+ assert state.n.width == 2 * params.io_width
+ d_leading_zeros = CLZ(params.io_width)
+ m.submodules.d_leading_zeros = d_leading_zeros
+ m.d.comb += d_leading_zeros.sig_in.eq(state.d)
+ d_shift_out = Signal.like(state.d)
+ m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
+ state.d = Signal(params.n_d_f_total_wid)
+ m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
+ + params.n_d_f_int_wid))
- # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
- yield GoldschmidtDivOp.Normalize
+ # normalize so 1 <= n < 2
+ n_leading_zeros = CLZ(2 * params.io_width)
+ m.submodules.n_leading_zeros = n_leading_zeros
+ m.d.comb += n_leading_zeros.sig_in.eq(state.n)
+ n_shift_s_v = (params.io_width + d_leading_zeros.lz
+ - n_leading_zeros.lz)
+ n_shift_s = Signal.like(n_shift_s_v)
+ state.n_shift = Signal(d_leading_zeros.lz.width)
+ m.d.comb += [
+ n_shift_s.eq(n_shift_s_v),
+ state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+ ]
+ n = Signal(params.n_d_f_total_wid)
+ shifted_n = state.n << state.n_shift
+ fixed_shift = params.expanded_width - state.n.width
+ m.d.comb += n.eq(shifted_n << fixed_shift)
+ state.n = n
+ elif self == GoldschmidtDivOp.FEqTableLookup:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ # compute initial f by table lookup
- # 2. ensure all relative errors from directed rounding are <= 1 / 4.
- # the assumption is met by multipliers with > 4-bits precision
- _assert_accuracy(params.expanded_width > 4)
+ # extra bit for table entries == 1.0
+ table_width = 1 + params.table_data_bits
+ table = Memory(width=table_width, depth=len(params.table),
+ init=[i.bits for i in params.table])
+ addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+ if sync_rom:
+ table_read = table.read_port()
+ m.d.comb += table_read.addr.eq(addr)
+ state.insert_pipeline_register()
+ else:
+ table_read = table.read_port(domain="comb")
+ m.d.comb += table_read.addr.eq(addr)
+ m.submodules.table_read = table_read
+ state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
+ data_shift = (table_width - params.table_data_bits
+ + params.expanded_width)
+ m.d.comb += state.f.eq(table_read.data << data_shift)
+ elif self == GoldschmidtDivOp.MulNByF:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ n = Signal.like(state.n)
+ m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
+ state.n = n
+ elif self == GoldschmidtDivOp.MulDByF:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ assert state.f is not None
+ assert state.f.width == params.n_d_f_total_wid, "invalid f width"
+ d = Signal.like(state.d)
+ m.d.comb += d.eq((state.d * state.f) >> params.expanded_width)
+ state.d = d
+ elif self == GoldschmidtDivOp.FEq2MinusD:
+ assert state.d.width == params.n_d_f_total_wid, "invalid d width"
+ f = Signal.like(state.d)
+ m.d.comb += f.eq((2 << params.expanded_width) - state.d)
+ state.f = f
+ elif self == GoldschmidtDivOp.CalcResult:
+ assert state.n.width == params.n_d_f_total_wid, "invalid n width"
+ assert state.n_shift is not None
+ # scale to correct value
+ n = state.n * (1 << state.n_shift)
+ q_approx = Signal(params.io_width)
+ # extra bit for if it's bigger than orig_d
+ r_approx = Signal(params.io_width + 1)
+ adjusted_r = Signal(signed(1 + params.io_width))
+ m.d.comb += [
+ q_approx.eq((state.n << state.n_shift)
+ >> params.expanded_width),
+ r_approx.eq(state.orig_n - q_approx * state.orig_d),
+ adjusted_r.eq(r_approx - state.orig_d),
+ ]
+ state.quotient = Signal(params.io_width)
+ state.remainder = Signal(params.io_width)
+
+ with m.If(adjusted_r >= 0):
+ m.d.comb += [
+ state.quotient.eq(q_approx + 1),
+ state.remainder.eq(adjusted_r),
+ ]
+ with m.Else():
+ m.d.comb += [
+ state.quotient.eq(q_approx),
+ state.remainder.eq(r_approx),
+ ]
+ else:
+ assert False, f"unimplemented GoldschmidtDivOp: {self}"
- # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
- _assert_accuracy(params.max_abs_e0 + 3 * params.max_d(0) / 2
- + params.max_f(0) < Fraction(1, 2))
- # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
- # (B is the denominator)
+@dataclass
+class GoldschmidtDivState:
+ orig_n: int
+ """original numerator"""
- for addr in range(params.table_addr_count):
- f_prime_m1 = params.table[addr]
- _assert_accuracy(0.5 <= f_prime_m1 <= 1)
+ orig_d: int
+ """original denominator"""
- yield GoldschmidtDivOp.FEqTableLookup
+ n: FixedPoint
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
- # we use Setting I (section 4.1 of the paper):
- # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
- n_hat = Fraction(0)
- for i in range(params.iter_count):
- _assert_accuracy(params.max_f(i) == 0)
- n_hat = max(n_hat, params.max_n(i), params.max_d(i))
- yield GoldschmidtDivOp.MulNByF
- if i != params.iter_count - 1:
- yield GoldschmidtDivOp.MulDByF
- yield GoldschmidtDivOp.FEq2MinusD
+ d: FixedPoint
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
- # relative approximation error `p(N_prime[i])`:
- # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
- # `0 <= p(N_prime[i])`
- # `p(N_prime[i]) <= (2 * i) * n_hat \`
- # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
- i = params.iter_count - 1 # last used `i`
- max_rel_error = (2 * i) * n_hat + \
- (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i)
+ f: "FixedPoint | None" = None
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
- min_a_over_b = Fraction(1, 2)
- max_a_over_b = Fraction(2)
- max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
- max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+ quotient: "int | None" = None
+ """final quotient"""
- _assert_accuracy(max_rel_error < max_allowed_rel_error)
+ remainder: "int | None" = None
+ """final remainder"""
- yield GoldschmidtDivOp.CalcResult
+ n_shift: "int | None" = None
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
def goldschmidt_div(n, d, params):
assert state.remainder is not None
return state.quotient, state.remainder
+
+
+@dataclass(eq=False)
+class GoldschmidtDivHDLState:
+ m: Module
+ """The HDL Module"""
+
+ orig_n: Signal
+ """original numerator"""
+
+ orig_d: Signal
+ """original denominator"""
+
+ n: Signal
+ """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+ d: Signal
+ """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+ f: "Signal | None" = None
+ """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+ quotient: "Signal | None" = None
+ """final quotient"""
+
+ remainder: "Signal | None" = None
+ """final remainder"""
+
+ n_shift: "Signal | None" = None
+ """amount the numerator needs to be left-shifted at the end of the
+ algorithm.
+ """
+
+ old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
+ init=False)
+
+ __signal_name_prefix: "str" = field(default="state_", repr=False,
+ init=False)
+
+ def __post_init__(self):
+ # called by the autogenerated __init__
+ self.old_signals = defaultdict(list)
+
+ def __setattr__(self, name, value):
+ assert isinstance(name, str)
+ if name.startswith("_"):
+ return super().__setattr__(name, value)
+ try:
+ old_signals = self.old_signals[name]
+ except AttributeError:
+ # haven't yet finished __post_init__
+ return super().__setattr__(name, value)
+ assert name != "m" and name != "old_signals", f"can't write to {name}"
+ assert isinstance(value, Signal)
+ value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
+ old_signal = getattr(self, name, None)
+ if old_signal is not None:
+ assert isinstance(old_signal, Signal)
+ old_signals.append(old_signal)
+ return super().__setattr__(name, value)
+
+ def insert_pipeline_register(self):
+ old_prefix = self.__signal_name_prefix
+ try:
+ for field in fields(GoldschmidtDivHDLState):
+ if field.name.startswith("_") or field.name == "m":
+ continue
+ old_sig = getattr(self, field.name, None)
+ if old_sig is None:
+ continue
+ assert isinstance(old_sig, Signal)
+ new_sig = Signal.like(old_sig)
+ setattr(self, field.name, new_sig)
+ self.m.d.sync += new_sig.eq(old_sig)
+ finally:
+ self.__signal_name_prefix = old_prefix
+
+
+class GoldschmidtDivHDL(Elaboratable):
+ # FIXME: finish getting hdl/simulation to work
+ """ Goldschmidt division algorithm.
+
+ based on:
+ Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+ A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+ https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+ attributes:
+ params: GoldschmidtDivParams
+ the goldschmidt division algorithm parameters.
+ pipe_reg_indexes: list[int]
+ the operation indexes where pipeline registers should be inserted.
+ duplicate values mean multiple registers should be inserted for
+ that operation index -- this is useful to allow yosys to spread a
+ multiplication across those multiple pipeline stages.
+ sync_rom: bool
+ true if the rom should be read synchronously rather than
+ combinatorially, incurring an extra clock cycle of latency.
+ n: Signal(unsigned(2 * params.io_width))
+ input numerator. a `2 * params.io_width`-bit unsigned integer.
+ must be less than `d << params.io_width`, otherwise the quotient
+ wouldn't fit in `params.io_width` bits.
+ d: Signal(unsigned(params.io_width))
+ input denominator. a `params.io_width`-bit unsigned integer.
+ must not be zero.
+ q: Signal(unsigned(params.io_width))
+ output quotient. only valid when `n < (d << params.io_width)`.
+ r: Signal(unsigned(params.io_width))
+ output remainder. only valid when `n < (d << params.io_width)`.
+ """
+
+ @property
+ def total_pipeline_registers(self):
+ """the total number of pipeline registers"""
+ return len(self.pipe_reg_indexes) + self.sync_rom
+
+ def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
+ assert isinstance(params, GoldschmidtDivParams)
+ assert isinstance(sync_rom, bool)
+ self.params = params
+ self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
+ self.sync_rom = sync_rom
+ self.n = Signal(unsigned(2 * params.io_width))
+ self.d = Signal(unsigned(params.io_width))
+ self.q = Signal(unsigned(params.io_width))
+ self.r = Signal(unsigned(params.io_width))
+
+ def elaborate(self, platform):
+ m = Module()
+ state = GoldschmidtDivHDLState(
+ m=m,
+ orig_n=self.n,
+ orig_d=self.d,
+ n=self.n,
+ d=self.d)
+
+ # copy and reverse
+ pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
+
+ for op_index, op in enumerate(self.params.ops):
+ while len(pipe_reg_indexes) > 0 \
+ and pipe_reg_indexes[-1] <= op_index:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+ op.gen_hdl(self.params, state, self.sync_rom)
+
+ while len(pipe_reg_indexes) > 0:
+ pipe_reg_indexes.pop()
+ state.insert_pipeline_register()
+
+ m.d.comb += self.q.eq(state.quotient)
+ m.d.comb += self.r.eq(state.remainder)
+ return m
+
+
+GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
+
+
+@lru_cache()
+def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
+ """Generate the look-up table needed for Goldschmidt's square-root and
+ reciprocal-square-root algorithm.
+
+ arguments:
+ table_addr_bits: int
+ the number of address bits for the look-up table.
+ table_data_bits: int
+ the number of data bits for the look-up table.
+ """
+ assert isinstance(table_addr_bits, int) and \
+ table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ assert isinstance(table_data_bits, int) and table_data_bits >= 1
+ table = []
+ table_len = 1 << table_addr_bits
+ for addr in range(table_len):
+ if addr == 0:
+ value = FixedPoint(0, table_data_bits)
+ elif (addr << 2) < table_len:
+ value = None # table entries should be unused
+ else:
+ table_addr_frac_wid = table_addr_bits
+ table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ max_input_value = FixedPoint(addr + 1, table_addr_bits - 2)
+ max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
+ value = max_input_value.to_frac_wid(max_frac_wid)
+ value = value.rsqrt(RoundDir.DOWN)
+ value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
+ table.append(value)
+
+ # tuple for immutability
+ return tuple(table)
+
+# FIXME: add code to calculate error bounds and check that the algorithm will
+# actually work (like in the goldschmidt division algorithm).
+# FIXME: add code to calculate a good set of parameters based on the error
+# bounds checking.
+
+
+def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
+ table_addr_bits, table_data_bits, iter_count):
+ """Goldschmidt's square-root and reciprocal-square-root algorithm.
+
+ uses algorithm based on second method at:
+ https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+
+ arguments:
+ radicand: FixedPoint(frac_wid=frac_wid)
+ the input value to take the square-root and reciprocal-square-root of.
+ io_width: int
+ the number of bits in the input (`radicand`) and output values.
+ frac_wid: int
+ the number of fraction bits in the input (`radicand`) and output
+ values.
+ extra_precision: int
+ the number of bits of internal extra precision.
+ table_addr_bits: int
+ the number of address bits for the look-up table.
+ table_data_bits: int
+ the number of data bits for the look-up table.
+
+ returns: tuple[FixedPoint, FixedPoint]
+ the square-root and reciprocal-square-root, rounded down to the
+ nearest representable value. If `radicand == 0`, then the
+ reciprocal-square-root value returned is zero.
+ """
+ assert (isinstance(radicand, FixedPoint)
+ and radicand.frac_wid == frac_wid
+ and 0 <= radicand.bits < (1 << io_width))
+ assert isinstance(io_width, int) and io_width >= 1
+ assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
+ assert isinstance(extra_precision, int) and extra_precision >= io_width
+ assert isinstance(table_addr_bits, int) and table_addr_bits >= 1
+ assert isinstance(table_data_bits, int) and table_data_bits >= 1
+ assert isinstance(iter_count, int) and iter_count >= 0
+ expanded_frac_wid = frac_wid + extra_precision
+ s = radicand.to_frac_wid(expanded_frac_wid)
+ sqrt_rshift = extra_precision
+ rsqrt_rshift = extra_precision
+ while s != 0 and s < 1:
+ s = (s * 4).to_frac_wid(expanded_frac_wid)
+ sqrt_rshift += 1
+ rsqrt_rshift -= 1
+ while s >= 4:
+ s = s.div(4, expanded_frac_wid)
+ sqrt_rshift -= 1
+ rsqrt_rshift += 1
+ table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits)
+ # core goldschmidt sqrt/rsqrt algorithm:
+ # initial setup:
+ table_addr_frac_wid = table_addr_bits
+ table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
+ assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
+ f = table[addr.bits]
+ assert f is not None, "accessed invalid table entry"
+ # use with_frac_wid to fix IDE type deduction
+ f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
+ x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ for _ in range(iter_count):
+ # iteration step:
+ f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ r = 2 * h
+ # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
+
+ sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
+ rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)
+
+ next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
+ if next_sqrt * next_sqrt <= radicand:
+ sqrt = next_sqrt
+
+ next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
+ if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
+ rsqrt = next_rsqrt
+ return sqrt, rsqrt