# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields, replace
+import logging
import math
import enum
from fractions import Fraction
from types import FunctionType
+from functools import lru_cache
try:
from functools import cached_property
from cached_property import cached_property
# fix broken IDE type detection for cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from functools import cached_property
def __floor__(self):
return self.bits >> self.frac_wid
+ def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
+ assert isinstance(frac_wid, int) and frac_wid >= 0
+ assert isinstance(round_dir, RoundDir)
+ rhs = FixedPoint.cast(rhs)
+ return FixedPoint.with_frac_wid(self.as_fraction()
+ / rhs.as_fraction(),
+ frac_wid, round_dir)
+
+ def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+ assert isinstance(round_dir, RoundDir)
+ if self < 0:
+ raise ValueError("can't compute sqrt of negative number")
+ if self == 0:
+ return self
+ retval = FixedPoint(0, self.frac_wid)
+ int_part_wid = self.bits.bit_length() - self.frac_wid
+ first_bit_index = -(-int_part_wid // 2) # division rounds up
+ last_bit_index = -self.frac_wid
+ for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+ trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+ self.frac_wid)
+ if trial * trial <= self:
+ retval = trial
+ if round_dir == RoundDir.DOWN:
+ pass
+ elif round_dir == RoundDir.UP:
+ if retval * retval < self:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.NEAREST_TIES_UP:
+ half_way = retval + FixedPoint(1, self.frac_wid + 1)
+ if half_way * half_way <= self:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.ERROR_IF_INEXACT:
+ if retval * retval != self:
+ raise ValueError("inexact sqrt")
+ else:
+ assert False, "unimplemented round_dir"
+ return retval
+
+ def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+ """compute the reciprocal-sqrt of `self`"""
+ assert isinstance(round_dir, RoundDir)
+ if self < 0:
+ raise ValueError("can't compute rsqrt of negative number")
+ if self == 0:
+ raise ZeroDivisionError("can't compute rsqrt of zero")
+ retval = FixedPoint(0, self.frac_wid)
+ first_bit_index = -(-self.frac_wid // 2) # division rounds up
+ last_bit_index = -self.frac_wid
+ for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+ trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+ self.frac_wid)
+ if trial * trial * self <= 1:
+ retval = trial
+ if round_dir == RoundDir.DOWN:
+ pass
+ elif round_dir == RoundDir.UP:
+ if retval * retval * self < 1:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.NEAREST_TIES_UP:
+ half_way = retval + FixedPoint(1, self.frac_wid + 1)
+ if half_way * half_way * self <= 1:
+ retval += FixedPoint(1, self.frac_wid)
+ elif round_dir == RoundDir.ERROR_IF_INEXACT:
+ if retval * retval * self != 1:
+ raise ValueError("inexact rsqrt")
+ else:
+ assert False, "unimplemented round_dir"
+ return retval
+
@dataclass
class GoldschmidtDivState:
@dataclass(frozen=True, unsafe_hash=True)
-class GoldschmidtDivParams:
- """parameters for a Goldschmidt division algorithm.
- Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+class GoldschmidtDivParamsBase:
+ """parameters for a Goldschmidt division algorithm, excluding derived
+ parameters.
"""
io_width: int
iter_count: int
"""the total number of iterations of the division algorithm's loop"""
- # tuple to be immutable, default so repr() works for debugging even when
+
+@dataclass(frozen=True, unsafe_hash=True)
+class GoldschmidtDivParams(GoldschmidtDivParamsBase):
+ """parameters for a Goldschmidt division algorithm.
+ Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+ """
+
+ # tuple to be immutable, repr=False so repr() works for debugging even when
# __post_init__ hasn't finished running yet
- table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented)
+ table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
"""the lookup-table"""
- ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
- default=NotImplemented)
+ ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
"""the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
def __post_init__(self):
# called by the autogenerated __init__
- assert self.io_width >= 1
- assert self.extra_precision >= 0
- assert self.table_addr_bits >= 1
- assert self.table_data_bits >= 1
- assert self.iter_count >= 1
+ _assert_accuracy(self.io_width >= 1, "io_width out of range")
+ _assert_accuracy(self.extra_precision >= 0,
+ "extra_precision out of range")
+ _assert_accuracy(self.table_addr_bits >= 1,
+ "table_addr_bits out of range")
+ _assert_accuracy(self.table_data_bits >= 1,
+ "table_data_bits out of range")
+ _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
table = []
for addr in range(1 << self.table_addr_bits):
table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
max_n_shift += 1
return max_n_shift
+ @cached_property
+ def n_hat(self):
+ """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
+ """
+ n_hat = Fraction(0)
+ for i in range(self.iter_count):
+ n_hat = max(n_hat, self.max_n(i), self.max_d(i))
+ return self._shrink_max(n_hat)
+
def __make_ops(self):
""" Goldschmidt division algorithm.
yield GoldschmidtDivOp.FEqTableLookup
# we use Setting I (section 4.1 of the paper):
- # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
- n_hat = Fraction(0)
+ # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
+ # the conditions on n_hat are satisfied by construction.
for i in range(self.iter_count):
_assert_accuracy(self.max_f(i) == 0)
- n_hat = max(n_hat, self.max_n(i), self.max_d(i))
yield GoldschmidtDivOp.MulNByF
if i != self.iter_count - 1:
yield GoldschmidtDivOp.MulDByF
# ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
i = self.iter_count - 1 # last used `i`
# compute power manually to prevent huge intermediate values
- power = self._shrink_max(self.max_abs_e0 + 3 * n_hat / 2)
+ power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
for _ in range(i):
power = self._shrink_max(power * power)
- max_rel_error = (2 * i) * n_hat + power
+ max_rel_error = (2 * i) * self.n_hat + power
min_a_over_b = Fraction(1, 2)
max_a_over_b = Fraction(2)
yield GoldschmidtDivOp.CalcResult
+ @cache_on_self
def default_cost_fn(self):
""" calculate the estimated cost on an arbitrary scale of implementing
goldschmidt division with the specified parameters. larger cost
mul_cost = self.expanded_width ** 2
mul_cost *= self.expanded_width.bit_length()
cost += mul_cost
- cost += 1e6 * self.iter_count
+ cost += 5e7 * self.iter_count
return cost
@staticmethod
- def get(io_width):
+ @lru_cache(maxsize=1 << 16)
+ def __cached_new(base_params):
+ assert isinstance(base_params, GoldschmidtDivParamsBase)
+ # can't use dataclasses.asdict, since it's recursive and will also give
+ # child class fields too, which we don't want.
+ kwargs = {}
+ for field in fields(GoldschmidtDivParamsBase):
+ kwargs[field.name] = getattr(base_params, field.name)
+ try:
+ return GoldschmidtDivParams(**kwargs), None
+ except ParamsNotAccurateEnough as e:
+ return None, e
+
+ @staticmethod
+ def __raise(e): # type: (ParamsNotAccurateEnough) -> Any
+ raise e
+
+ @staticmethod
+ def cached_new(base_params, handle_error=__raise):
+ assert isinstance(base_params, GoldschmidtDivParamsBase)
+ params, error = GoldschmidtDivParams.__cached_new(base_params)
+ if error is None:
+ return params
+ else:
+ return handle_error(error)
+
+ @staticmethod
+ def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
""" find efficient parameters for a goldschmidt division algorithm
with `params.io_width == io_width`.
+
+ arguments:
+ io_width: int
+ bit-width of the input divisor and the result.
+ the input numerator is `2 * io_width`-bits wide.
+ cost_fn: Callable[[GoldschmidtDivParams], float]
+ return the estimated cost on an arbitrary scale of implementing
+ goldschmidt division with the specified parameters. larger cost
+ values mean worse parameters.
+ max_table_addr_bits: int
+ maximum allowable value of `table_addr_bits`
"""
assert isinstance(io_width, int) and io_width >= 1
- last_params = None
+ assert callable(cost_fn)
+
last_error = None
- for extra_precision in range(io_width * 2 + 4):
- for table_addr_bits in range(1, 7 + 1):
- table_data_bits = io_width + extra_precision
- for iter_count in range(1, 2 * io_width.bit_length()):
- try:
- return GoldschmidtDivParams(
- io_width=io_width,
- extra_precision=extra_precision,
- table_addr_bits=table_addr_bits,
- table_data_bits=table_data_bits,
- iter_count=iter_count)
- except ParamsNotAccurateEnough as e:
- last_params = (f"GoldschmidtDivParams("
- f"io_width={io_width!r}, "
- f"extra_precision={extra_precision!r}, "
- f"table_addr_bits={table_addr_bits!r}, "
- f"table_data_bits={table_data_bits!r}, "
- f"iter_count={iter_count!r})")
- last_error = e
- raise ValueError(f"can't find working parameters for a goldschmidt "
- f"division algorithm: last params: {last_params}"
- ) from last_error
+ last_error_params = None
+
+ def cached_new(base_params):
+ def handle_error(e):
+ nonlocal last_error, last_error_params
+ last_error = e
+ last_error_params = base_params
+ return None
+
+ retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
+ if retval is None:
+ logging.debug(f"GoldschmidtDivParams.get: err: {base_params}")
+ else:
+ logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}")
+ return retval
+
+ @lru_cache(maxsize=None)
+ def get_cost(base_params):
+ params = cached_new(base_params)
+ if params is None:
+ return math.inf
+ retval = cost_fn(params)
+ logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}")
+ return retval
+
+ # start with parameters big enough to always work.
+ initial_extra_precision = io_width * 2 + 4
+ initial_params = GoldschmidtDivParamsBase(
+ io_width=io_width,
+ extra_precision=initial_extra_precision,
+ table_addr_bits=min(max_table_addr_bits, io_width),
+ table_data_bits=io_width + initial_extra_precision,
+ iter_count=1 + io_width.bit_length())
+
+ if cached_new(initial_params) is None:
+ raise ValueError(f"initial goldschmidt division algorithm "
+ f"parameters are invalid: {initial_params}"
+ ) from last_error
+
+ # find good initial `iter_count`
+ params = initial_params
+ for iter_count in range(1, initial_params.iter_count):
+ trial_params = replace(params, iter_count=iter_count)
+ if cached_new(trial_params) is not None:
+ params = trial_params
+ break
+
+ # now find `table_addr_bits`
+ cost = get_cost(params)
+ for table_addr_bits in range(1, max_table_addr_bits):
+ trial_params = replace(params, table_addr_bits=table_addr_bits)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # check one higher `iter_count` to see if it has lower cost
+ for table_addr_bits in range(1, max_table_addr_bits + 1):
+ trial_params = replace(params,
+ table_addr_bits=table_addr_bits,
+ iter_count=params.iter_count + 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # now shrink `table_data_bits`
+ while True:
+ trial_params = replace(params,
+ table_data_bits=params.table_data_bits - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ # and shrink `extra_precision`
+ while True:
+ trial_params = replace(params,
+ extra_precision=params.extra_precision - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ return cached_new(params)
@enum.unique
assert state.remainder is not None
return state.quotient, state.remainder
+
+
+GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
+
+
+@lru_cache()
+def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
+ """Generate the look-up table needed for Goldschmidt's square-root and
+ reciprocal-square-root algorithm.
+
+ arguments:
+ table_addr_bits: int
+ the number of address bits for the look-up table.
+ table_data_bits: int
+ the number of data bits for the look-up table.
+ """
+ assert isinstance(table_addr_bits, int) and \
+ table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ assert isinstance(table_data_bits, int) and table_data_bits >= 1
+ table = []
+ table_len = 1 << table_addr_bits
+ for addr in range(table_len):
+ if addr == 0:
+ value = FixedPoint(0, table_data_bits)
+ elif (addr << 2) < table_len:
+ value = None # table entries should be unused
+ else:
+ table_addr_frac_wid = table_addr_bits
+ table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ max_input_value = FixedPoint(addr + 1, table_addr_bits - 2)
+ max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
+ value = max_input_value.to_frac_wid(max_frac_wid)
+ value = value.rsqrt(RoundDir.DOWN)
+ value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
+ table.append(value)
+
+ # tuple for immutability
+ return tuple(table)
+
+
+def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
+ table_addr_bits, table_data_bits, iter_count):
+ """Goldschmidt's square-root and reciprocal-square-root algorithm.
+
+ uses algorithm based on second method at:
+ https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+
+ arguments:
+ radicand: FixedPoint(frac_wid=frac_wid)
+ the input value to take the square-root and reciprocal-square-root of.
+ io_width: int
+ the number of bits in the input (`radicand`) and output values.
+ frac_wid: int
+ the number of fraction bits in the input (`radicand`) and output
+ values.
+ extra_precision: int
+ the number of bits of internal extra precision.
+ table_addr_bits: int
+ the number of address bits for the look-up table.
+ table_data_bits: int
+ the number of data bits for the look-up table.
+
+ returns: tuple[FixedPoint, FixedPoint]
+ the square-root and reciprocal-square-root, rounded down to the
+ nearest representable value. If `radicand == 0`, then the
+ reciprocal-square-root value returned is zero.
+ """
+ assert (isinstance(radicand, FixedPoint)
+ and radicand.frac_wid == frac_wid
+ and 0 <= radicand.bits < (1 << io_width))
+ assert isinstance(io_width, int) and io_width >= 1
+ assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
+ assert isinstance(extra_precision, int) and extra_precision >= io_width
+ assert isinstance(table_addr_bits, int) and table_addr_bits >= 1
+ assert isinstance(table_data_bits, int) and table_data_bits >= 1
+ assert isinstance(iter_count, int) and iter_count >= 0
+ expanded_frac_wid = frac_wid + extra_precision
+ s = radicand.to_frac_wid(expanded_frac_wid)
+ sqrt_rshift = extra_precision
+ rsqrt_rshift = extra_precision
+ while s != 0 and s < 1:
+ s = (s * 4).to_frac_wid(expanded_frac_wid)
+ sqrt_rshift += 1
+ rsqrt_rshift -= 1
+ while s >= 4:
+ s = s.div(4, expanded_frac_wid)
+ sqrt_rshift -= 1
+ rsqrt_rshift += 1
+ table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits)
+ # core goldschmidt sqrt/rsqrt algorithm:
+ # initial setup:
+ table_addr_frac_wid = table_addr_bits
+ table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+ addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
+ assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
+ f = table[addr.bits]
+ assert f is not None, "accessed invalid table entry"
+ # use with_frac_wid to fix IDE type deduction
+ f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
+ x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ for _ in range(iter_count):
+ # iteration step:
+ f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+ r = 2 * h
+ # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
+
+ sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
+ rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)
+
+ next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
+ if next_sqrt * next_sqrt <= radicand:
+ sqrt = next_sqrt
+
+ next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
+ if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
+ rsqrt = next_rsqrt
+ return sqrt, rsqrt