From: Jacob Lifshay Date: Wed, 27 Apr 2022 06:50:37 +0000 (-0700) Subject: improved goldschmidt division algorithm parameter optimization algorithm X-Git-Url: https://git.libre-soc.org/?p=soc.git;a=commitdiff_plain;h=756b6be9c2f867e3bb34724d2e16d1f93b43cd56 improved goldschmidt division algorithm parameter optimization algorithm --- diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index bdce6fd9..3af5320d 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -4,11 +4,13 @@ # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part # of Horizon 2020 EU Programme 957073. -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields, replace +import logging import math import enum from fractions import Fraction from types import FunctionType +from functools import lru_cache try: from functools import cached_property @@ -16,7 +18,7 @@ except ImportError: from cached_property import cached_property # fix broken IDE type detection for cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from functools import cached_property @@ -313,13 +315,12 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): Use `GoldschmidtDivParams.get` to find a efficient set of parameters. """ - # tuple to be immutable, default so repr() works for debugging even when + # tuple to be immutable, repr=False so repr() works for debugging even when # __post_init__ hasn't finished running yet - table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented) + table: "tuple[FixedPoint, ...]" = field(init=False, repr=False) """the lookup-table""" - ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, - default=NotImplemented) + ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False) """the operations needed to perform the goldschmidt division algorithm.""" def _shrink_bound(self, bound, round_dir): @@ -399,11 +400,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): def __post_init__(self): # called by the autogenerated __init__ - assert self.io_width >= 1 - assert self.extra_precision >= 0 - assert self.table_addr_bits >= 1 - assert self.table_data_bits >= 1 - assert self.iter_count >= 1 + _assert_accuracy(self.io_width >= 1, "io_width out of range") + _assert_accuracy(self.extra_precision >= 0, + "extra_precision out of range") + _assert_accuracy(self.table_addr_bits >= 1, + "table_addr_bits out of range") + _assert_accuracy(self.table_data_bits >= 1, + "table_data_bits out of range") + _assert_accuracy(self.iter_count >= 1, "iter_count out of range") table = [] for addr in range(1 << self.table_addr_bits): table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr), @@ -714,6 +718,7 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): yield GoldschmidtDivOp.CalcResult + @cache_on_self def default_cost_fn(self): """ calculate the estimated cost on an arbitrary scale of implementing goldschmidt division with the specified parameters. larger cost @@ -731,39 +736,147 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): mul_cost = self.expanded_width ** 2 mul_cost *= self.expanded_width.bit_length() cost += mul_cost - cost += 1e6 * self.iter_count + cost += 5e7 * self.iter_count return cost @staticmethod - def get(io_width): + @lru_cache(maxsize=1 << 16) + def __cached_new(base_params): + assert isinstance(base_params, GoldschmidtDivParamsBase) + # can't use dataclasses.asdict, since it's recursive and will also give + # child class fields too, which we don't want. + kwargs = {} + for field in fields(GoldschmidtDivParamsBase): + kwargs[field.name] = getattr(base_params, field.name) + try: + return GoldschmidtDivParams(**kwargs), None + except ParamsNotAccurateEnough as e: + return None, e + + @staticmethod + def __raise(e): # type: (ParamsNotAccurateEnough) -> Any + raise e + + @staticmethod + def cached_new(base_params, handle_error=__raise): + assert isinstance(base_params, GoldschmidtDivParamsBase) + params, error = GoldschmidtDivParams.__cached_new(base_params) + if error is None: + return params + else: + return handle_error(error) + + @staticmethod + def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12): """ find efficient parameters for a goldschmidt division algorithm with `params.io_width == io_width`. + + arguments: + io_width: int + bit-width of the input divisor and the result. + the input numerator is `2 * io_width`-bits wide. + cost_fn: Callable[[GoldschmidtDivParams], float] + return the estimated cost on an arbitrary scale of implementing + goldschmidt division with the specified parameters. larger cost + values mean worse parameters. + max_table_addr_bits: int + maximum allowable value of `table_addr_bits` """ assert isinstance(io_width, int) and io_width >= 1 - last_params = None + assert callable(cost_fn) + last_error = None - for extra_precision in range(io_width * 2 + 4): - for table_addr_bits in range(1, 7 + 1): - table_data_bits = io_width + extra_precision - for iter_count in range(1, 2 * io_width.bit_length()): - try: - return GoldschmidtDivParams( - io_width=io_width, - extra_precision=extra_precision, - table_addr_bits=table_addr_bits, - table_data_bits=table_data_bits, - iter_count=iter_count) - except ParamsNotAccurateEnough as e: - last_params = (f"GoldschmidtDivParams(" - f"io_width={io_width!r}, " - f"extra_precision={extra_precision!r}, " - f"table_addr_bits={table_addr_bits!r}, " - f"table_data_bits={table_data_bits!r}, " - f"iter_count={iter_count!r})") - last_error = e - raise ValueError(f"can't find working parameters for a goldschmidt " - f"division algorithm: last params: {last_params}" - ) from last_error + last_error_params = None + + def cached_new(base_params): + def handle_error(e): + nonlocal last_error, last_error_params + last_error = e + last_error_params = base_params + return None + + retval = GoldschmidtDivParams.cached_new(base_params, handle_error) + if retval is None: + logging.debug(f"GoldschmidtDivParams.get: err: {base_params}") + else: + logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}") + return retval + + @lru_cache(maxsize=None) + def get_cost(base_params): + params = cached_new(base_params) + if params is None: + return math.inf + retval = cost_fn(params) + logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}") + return retval + + # start with parameters big enough to always work. + initial_extra_precision = io_width * 2 + 4 + initial_params = GoldschmidtDivParamsBase( + io_width=io_width, + extra_precision=initial_extra_precision, + table_addr_bits=min(max_table_addr_bits, io_width), + table_data_bits=io_width + initial_extra_precision, + iter_count=1 + io_width.bit_length()) + + if cached_new(initial_params) is None: + raise ValueError(f"initial goldschmidt division algorithm " + f"parameters are invalid: {initial_params}" + ) from last_error + + # find good initial `iter_count` + params = initial_params + for iter_count in range(1, initial_params.iter_count): + trial_params = replace(params, iter_count=iter_count) + if cached_new(trial_params) is not None: + params = trial_params + break + + # now find `table_addr_bits` + cost = get_cost(params) + for table_addr_bits in range(1, max_table_addr_bits): + trial_params = replace(params, table_addr_bits=table_addr_bits) + trial_cost = get_cost(trial_params) + if trial_cost < cost: + params = trial_params + cost = trial_cost + break + + # check one higher `iter_count` to see if it has lower cost + for table_addr_bits in range(1, max_table_addr_bits + 1): + trial_params = replace(params, + table_addr_bits=table_addr_bits, + iter_count=params.iter_count + 1) + trial_cost = get_cost(trial_params) + if trial_cost < cost: + params = trial_params + cost = trial_cost + break + + # now shrink `table_data_bits` + while True: + trial_params = replace(params, + table_data_bits=params.table_data_bits - 1) + trial_cost = get_cost(trial_params) + if trial_cost < cost: + params = trial_params + cost = trial_cost + else: + break + + # and shrink `extra_precision` + while True: + trial_params = replace(params, + extra_precision=params.extra_precision - 1) + trial_cost = get_cost(trial_params) + if trial_cost < cost: + params = trial_params + cost = trial_cost + else: + break + + return cached_new(params) @enum.unique