improved goldschmidt division algorithm parameter optimization algorithm
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
index 62156ee5ec418d9490e31edb89a6f401d587a06f..3af5320d83373e106db74bb7c5394e356036b10d 100644 (file)
@@ -4,11 +4,13 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields, replace
+import logging
 import math
 import enum
 from fractions import Fraction
 from types import FunctionType
+from functools import lru_cache
 
 try:
     from functools import cached_property
@@ -16,7 +18,7 @@ except ImportError:
     from cached_property import cached_property
 
 # fix broken IDE type detection for cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from functools import cached_property
 
@@ -284,9 +286,9 @@ def _assert_accuracy(condition, msg="not accurate enough"):
 
 
 @dataclass(frozen=True, unsafe_hash=True)
-class GoldschmidtDivParams:
-    """parameters for a Goldschmidt division algorithm.
-    Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+class GoldschmidtDivParamsBase:
+    """parameters for a Goldschmidt division algorithm, excluding derived
+    parameters.
     """
 
     io_width: int
@@ -306,13 +308,19 @@ class GoldschmidtDivParams:
     iter_count: int
     """the total number of iterations of the division algorithm's loop"""
 
-    # tuple to be immutable, default so repr() works for debugging even when
+
+@dataclass(frozen=True, unsafe_hash=True)
+class GoldschmidtDivParams(GoldschmidtDivParamsBase):
+    """parameters for a Goldschmidt division algorithm.
+    Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
+    """
+
+    # tuple to be immutable, repr=False so repr() works for debugging even when
     # __post_init__ hasn't finished running yet
-    table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented)
+    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
     """the lookup-table"""
 
-    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
-                                                default=NotImplemented)
+    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
     """the operations needed to perform the goldschmidt division algorithm."""
 
     def _shrink_bound(self, bound, round_dir):
@@ -392,11 +400,14 @@ class GoldschmidtDivParams:
 
     def __post_init__(self):
         # called by the autogenerated __init__
-        assert self.io_width >= 1
-        assert self.extra_precision >= 0
-        assert self.table_addr_bits >= 1
-        assert self.table_data_bits >= 1
-        assert self.iter_count >= 1
+        _assert_accuracy(self.io_width >= 1, "io_width out of range")
+        _assert_accuracy(self.extra_precision >= 0,
+                         "extra_precision out of range")
+        _assert_accuracy(self.table_addr_bits >= 1,
+                         "table_addr_bits out of range")
+        _assert_accuracy(self.table_data_bits >= 1,
+                         "table_data_bits out of range")
+        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
         table = []
         for addr in range(1 << self.table_addr_bits):
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
@@ -404,38 +415,7 @@ class GoldschmidtDivParams:
                                                   RoundDir.DOWN))
         # we have to use object.__setattr__ since frozen=True
         object.__setattr__(self, "table", tuple(table))
-        object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))
-
-    @staticmethod
-    def get(io_width):
-        """ find efficient parameters for a goldschmidt division algorithm
-        with `params.io_width == io_width`.
-        """
-        assert isinstance(io_width, int) and io_width >= 1
-        last_params = None
-        last_error = None
-        for extra_precision in range(io_width * 2 + 4):
-            for table_addr_bits in range(1, 7 + 1):
-                table_data_bits = io_width + extra_precision
-                for iter_count in range(1, 2 * io_width.bit_length()):
-                    try:
-                        return GoldschmidtDivParams(
-                            io_width=io_width,
-                            extra_precision=extra_precision,
-                            table_addr_bits=table_addr_bits,
-                            table_data_bits=table_data_bits,
-                            iter_count=iter_count)
-                    except ParamsNotAccurateEnough as e:
-                        last_params = (f"GoldschmidtDivParams("
-                                       f"io_width={io_width!r}, "
-                                       f"extra_precision={extra_precision!r}, "
-                                       f"table_addr_bits={table_addr_bits!r}, "
-                                       f"table_data_bits={table_data_bits!r}, "
-                                       f"iter_count={iter_count!r})")
-                        last_error = e
-        raise ValueError(f"can't find working parameters for a goldschmidt "
-                         f"division algorithm: last params: {last_params}"
-                         ) from last_error
+        object.__setattr__(self, "ops", tuple(self.__make_ops()))
 
     @property
     def expanded_width(self):
@@ -662,6 +642,242 @@ class GoldschmidtDivParams:
             max_n_shift += 1
         return max_n_shift
 
+    @cached_property
+    def n_hat(self):
+        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
+        """
+        n_hat = Fraction(0)
+        for i in range(self.iter_count):
+            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
+        return self._shrink_max(n_hat)
+
+    def __make_ops(self):
+        """ Goldschmidt division algorithm.
+
+            based on:
+            Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+            A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+            https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+            yields: GoldschmidtDivOp
+                the operations needed to perform the division.
+        """
+        # establish assumptions of the paper's error analysis (section 3.1):
+
+        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
+        yield GoldschmidtDivOp.Normalize
+
+        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
+        # the assumption is met by multipliers with > 4-bits precision
+        _assert_accuracy(self.expanded_width > 4)
+
+        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
+        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
+                         + self.max_f(0) < Fraction(1, 2))
+
+        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
+        # (B is the denominator)
+
+        for addr in range(self.table_addr_count):
+            f_prime_m1 = self.table[addr]
+            _assert_accuracy(0.5 <= f_prime_m1 <= 1)
+
+        yield GoldschmidtDivOp.FEqTableLookup
+
+        # we use Setting I (section 4.1 of the paper):
+        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
+        # the conditions on n_hat are satisfied by construction.
+        for i in range(self.iter_count):
+            _assert_accuracy(self.max_f(i) == 0)
+            yield GoldschmidtDivOp.MulNByF
+            if i != self.iter_count - 1:
+                yield GoldschmidtDivOp.MulDByF
+                yield GoldschmidtDivOp.FEq2MinusD
+
+        # relative approximation error `p(N_prime[i])`:
+        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
+        # `0 <= p(N_prime[i])`
+        # `p(N_prime[i]) <= (2 * i) * n_hat \`
+        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
+        i = self.iter_count - 1  # last used `i`
+        # compute power manually to prevent huge intermediate values
+        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
+        for _ in range(i):
+            power = self._shrink_max(power * power)
+
+        max_rel_error = (2 * i) * self.n_hat + power
+
+        min_a_over_b = Fraction(1, 2)
+        max_a_over_b = Fraction(2)
+        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
+        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+
+        _assert_accuracy(max_rel_error < max_allowed_rel_error,
+                         f"not accurate enough: max_rel_error={max_rel_error}"
+                         f" max_allowed_rel_error={max_allowed_rel_error}")
+
+        yield GoldschmidtDivOp.CalcResult
+
+    @cache_on_self
+    def default_cost_fn(self):
+        """ calculate the estimated cost on an arbitrary scale of implementing
+        goldschmidt division with the specified parameters. larger cost
+        values mean worse parameters.
+
+        This is the default cost function for `GoldschmidtDivParams.get`.
+
+        returns: float
+        """
+        rom_cells = self.table_data_bits << self.table_addr_bits
+        cost = float(rom_cells)
+        for op in self.ops:
+            if op == GoldschmidtDivOp.MulNByF \
+                    or op == GoldschmidtDivOp.MulDByF:
+                mul_cost = self.expanded_width ** 2
+                mul_cost *= self.expanded_width.bit_length()
+                cost += mul_cost
+        cost += 5e7 * self.iter_count
+        return cost
+
+    @staticmethod
+    @lru_cache(maxsize=1 << 16)
+    def __cached_new(base_params):
+        assert isinstance(base_params, GoldschmidtDivParamsBase)
+        # can't use dataclasses.asdict, since it's recursive and will also give
+        # child class fields too, which we don't want.
+        kwargs = {}
+        for field in fields(GoldschmidtDivParamsBase):
+            kwargs[field.name] = getattr(base_params, field.name)
+        try:
+            return GoldschmidtDivParams(**kwargs), None
+        except ParamsNotAccurateEnough as e:
+            return None, e
+
+    @staticmethod
+    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
+        raise e
+
+    @staticmethod
+    def cached_new(base_params, handle_error=__raise):
+        assert isinstance(base_params, GoldschmidtDivParamsBase)
+        params, error = GoldschmidtDivParams.__cached_new(base_params)
+        if error is None:
+            return params
+        else:
+            return handle_error(error)
+
+    @staticmethod
+    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
+        """ find efficient parameters for a goldschmidt division algorithm
+        with `params.io_width == io_width`.
+
+        arguments:
+        io_width: int
+            bit-width of the input divisor and the result.
+            the input numerator is `2 * io_width`-bits wide.
+        cost_fn: Callable[[GoldschmidtDivParams], float]
+            return the estimated cost on an arbitrary scale of implementing
+            goldschmidt division with the specified parameters. larger cost
+            values mean worse parameters.
+        max_table_addr_bits: int
+            maximum allowable value of `table_addr_bits`
+        """
+        assert isinstance(io_width, int) and io_width >= 1
+        assert callable(cost_fn)
+
+        last_error = None
+        last_error_params = None
+
+        def cached_new(base_params):
+            def handle_error(e):
+                nonlocal last_error, last_error_params
+                last_error = e
+                last_error_params = base_params
+                return None
+
+            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
+            if retval is None:
+                logging.debug(f"GoldschmidtDivParams.get: err: {base_params}")
+            else:
+                logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}")
+            return retval
+
+        @lru_cache(maxsize=None)
+        def get_cost(base_params):
+            params = cached_new(base_params)
+            if params is None:
+                return math.inf
+            retval = cost_fn(params)
+            logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}")
+            return retval
+
+        # start with parameters big enough to always work.
+        initial_extra_precision = io_width * 2 + 4
+        initial_params = GoldschmidtDivParamsBase(
+            io_width=io_width,
+            extra_precision=initial_extra_precision,
+            table_addr_bits=min(max_table_addr_bits, io_width),
+            table_data_bits=io_width + initial_extra_precision,
+            iter_count=1 + io_width.bit_length())
+
+        if cached_new(initial_params) is None:
+            raise ValueError(f"initial goldschmidt division algorithm "
+                             f"parameters are invalid: {initial_params}"
+                             ) from last_error
+
+        # find good initial `iter_count`
+        params = initial_params
+        for iter_count in range(1, initial_params.iter_count):
+            trial_params = replace(params, iter_count=iter_count)
+            if cached_new(trial_params) is not None:
+                params = trial_params
+                break
+
+        # now find `table_addr_bits`
+        cost = get_cost(params)
+        for table_addr_bits in range(1, max_table_addr_bits):
+            trial_params = replace(params, table_addr_bits=table_addr_bits)
+            trial_cost = get_cost(trial_params)
+            if trial_cost < cost:
+                params = trial_params
+                cost = trial_cost
+                break
+
+        # check one higher `iter_count` to see if it has lower cost
+        for table_addr_bits in range(1, max_table_addr_bits + 1):
+            trial_params = replace(params,
+                                   table_addr_bits=table_addr_bits,
+                                   iter_count=params.iter_count + 1)
+            trial_cost = get_cost(trial_params)
+            if trial_cost < cost:
+                params = trial_params
+                cost = trial_cost
+                break
+
+        # now shrink `table_data_bits`
+        while True:
+            trial_params = replace(params,
+                                   table_data_bits=params.table_data_bits - 1)
+            trial_cost = get_cost(trial_params)
+            if trial_cost < cost:
+                params = trial_params
+                cost = trial_cost
+            else:
+                break
+
+        # and shrink `extra_precision`
+        while True:
+            trial_params = replace(params,
+                                   extra_precision=params.extra_precision - 1)
+            trial_cost = get_cost(trial_params)
+            if trial_cost < cost:
+                params = trial_params
+                cost = trial_cost
+            else:
+                break
+
+        return cached_new(params)
+
 
 @enum.unique
 class GoldschmidtDivOp(enum.Enum):
@@ -719,81 +935,6 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
 
-def _goldschmidt_div_ops(params):
-    """ Goldschmidt division algorithm.
-
-        based on:
-        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
-        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
-        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
-
-        arguments:
-        params: GoldschmidtDivParams
-            the parameters for the algorithm
-
-        yields: GoldschmidtDivOp
-            the operations needed to perform the division.
-    """
-    assert isinstance(params, GoldschmidtDivParams)
-
-    # establish assumptions of the paper's error analysis (section 3.1):
-
-    # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
-    yield GoldschmidtDivOp.Normalize
-
-    # 2. ensure all relative errors from directed rounding are <= 1 / 4.
-    # the assumption is met by multipliers with > 4-bits precision
-    _assert_accuracy(params.expanded_width > 4)
-
-    # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
-    _assert_accuracy(params.max_abs_e0 + 3 * params.max_d(0) / 2
-                     + params.max_f(0) < Fraction(1, 2))
-
-    # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
-    # (B is the denominator)
-
-    for addr in range(params.table_addr_count):
-        f_prime_m1 = params.table[addr]
-        _assert_accuracy(0.5 <= f_prime_m1 <= 1)
-
-    yield GoldschmidtDivOp.FEqTableLookup
-
-    # we use Setting I (section 4.1 of the paper):
-    # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
-    n_hat = Fraction(0)
-    for i in range(params.iter_count):
-        _assert_accuracy(params.max_f(i) == 0)
-        n_hat = max(n_hat, params.max_n(i), params.max_d(i))
-        yield GoldschmidtDivOp.MulNByF
-        if i != params.iter_count - 1:
-            yield GoldschmidtDivOp.MulDByF
-            yield GoldschmidtDivOp.FEq2MinusD
-
-    # relative approximation error `p(N_prime[i])`:
-    # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
-    # `0 <= p(N_prime[i])`
-    # `p(N_prime[i]) <= (2 * i) * n_hat \`
-    # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
-    i = params.iter_count - 1  # last used `i`
-    # compute power manually to prevent huge intermediate values
-    power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2)
-    for _ in range(i):
-        power = params._shrink_max(power * power)
-
-    max_rel_error = (2 * i) * n_hat + power
-
-    min_a_over_b = Fraction(1, 2)
-    max_a_over_b = Fraction(2)
-    max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
-    max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
-
-    _assert_accuracy(max_rel_error < max_allowed_rel_error,
-                     f"not accurate enough: max_rel_error={max_rel_error} "
-                     f"max_allowed_rel_error={max_allowed_rel_error}")
-
-    yield GoldschmidtDivOp.CalcResult
-
-
 def goldschmidt_div(n, d, params):
     """ Goldschmidt division algorithm.