From: Jacob Lifshay Date: Mon, 25 Apr 2022 08:44:31 +0000 (-0700) Subject: working on goldschmidt_div_sqrt.py X-Git-Url: https://git.libre-soc.org/?p=soc.git;a=commitdiff_plain;h=2c405908df52cad2c15767fed13873fe19b554df working on goldschmidt_div_sqrt.py --- diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index f4aee9da..dc363c5a 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -8,6 +8,46 @@ from dataclasses import dataclass, field import math import enum from fractions import Fraction +from types import FunctionType + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +# fix broken IDE type detection for cached_property +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from functools import cached_property + + +_NOT_FOUND = object() + + +def cache_on_self(func): + """like `functools.cached_property`, except for methods. unlike + `lru_cache` the cache is per-class instance rather than a global cache + per-method.""" + + assert isinstance(func, FunctionType), \ + "non-plain methods are not supported" + + cache_name = func.__name__ + "__cache" + + def wrapper(self, *args, **kwargs): + # specifically access through `__dict__` to bypass frozen=True + cache = self.__dict__.get(cache_name, _NOT_FOUND) + if cache is _NOT_FOUND: + self.__dict__[cache_name] = cache = {} + key = (args, *kwargs.items()) + retval = cache.get(key, _NOT_FOUND) + if retval is _NOT_FOUND: + retval = func(self, *args, **kwargs) + cache[key] = retval + return retval + + wrapper.__doc__ = func.__doc__ + return wrapper @enum.unique @@ -212,14 +252,27 @@ class FixedPoint: @dataclass class GoldschmidtDivState: + orig_n: int + """original numerator""" + + orig_d: int + """original denominator""" + n: FixedPoint """numerator -- N_prime[i] in the paper's algorithm 2""" + d: FixedPoint """denominator -- D_prime[i] in the paper's algorithm 2""" + f: "FixedPoint | None" = None """current factor -- F_prime[i] in the paper's algorithm 2""" - result: "int | None" = None - """final result""" + + quotient: "int | None" = None + """final quotient""" + + remainder: "int | None" = None + """final remainder""" + n_shift: "int | None" = None """amount the numerator needs to be left-shifted at the end of the algorithm. @@ -242,19 +295,28 @@ class GoldschmidtDivParams: """parameters for a Goldschmidt division algorithm. Use `GoldschmidtDivParams.get` to find a efficient set of parameters. """ + io_width: int """bit-width of the input divisor and the result. the input numerator is `2 * io_width`-bits wide. """ + extra_precision: int """number of bits of additional precision used inside the algorithm.""" + table_addr_bits: int """the number of address bits used in the lookup-table.""" + table_data_bits: int """the number of data bits used in the lookup-table.""" + + iter_count: int + """the total number of iterations of the division algorithm's loop""" + # tuple to be immutable table: "tuple[FixedPoint, ...]" = field(init=False) """the lookup-table""" + ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False) """the operations needed to perform the goldschmidt division algorithm.""" @@ -269,7 +331,7 @@ class GoldschmidtDivParams: with address `addr`.""" assert isinstance(addr, int) assert 0 <= addr < self.table_addr_count - assert self.io_width >= self.table_addr_bits + _assert_accuracy(self.io_width >= self.table_addr_bits) min_numerator = (1 << self.table_addr_bits) + addr denominator = 1 << self.table_addr_bits values_per_table_entry = 1 << (self.io_width - self.table_addr_bits) @@ -296,6 +358,7 @@ class GoldschmidtDivParams: assert self.extra_precision >= 0 assert self.table_addr_bits >= 1 assert self.table_data_bits >= 1 + assert self.iter_count >= 1 table = [] for addr in range(1 << self.table_addr_bits): table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr), @@ -311,17 +374,19 @@ class GoldschmidtDivParams: with `params.io_width == io_width`. """ assert isinstance(io_width, int) and io_width >= 1 - for extra_precision in range(io_width * 2): - for table_addr_bits in range(3, 7 + 1): + for extra_precision in range(io_width * 2 + 4): + for table_addr_bits in range(1, 7 + 1): table_data_bits = io_width + extra_precision - try: - return GoldschmidtDivParams( - io_width=io_width, - extra_precision=extra_precision, - table_addr_bits=table_addr_bits, - table_data_bits=table_data_bits) - except ParamsNotAccurateEnough: - pass + for iter_count in range(1, 2 * io_width.bit_length()): + try: + return GoldschmidtDivParams( + io_width=io_width, + extra_precision=extra_precision, + table_addr_bits=table_addr_bits, + table_data_bits=table_data_bits, + iter_count=iter_count) + except ParamsNotAccurateEnough: + pass raise ValueError(f"can't find working parameters for a goldschmidt " f"division algorithm with io_width={io_width}") @@ -330,6 +395,227 @@ class GoldschmidtDivParams: """the total number of bits of precision used inside the algorithm.""" return self.io_width + self.extra_precision + @cache_on_self + def max_neps(self, i): + """maximum value of `neps[i]`. + `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`. + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + return Fraction(1, 1 << self.expanded_width) + + @cache_on_self + def max_deps(self, i): + """maximum value of `deps[i]`. + `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`. + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + return Fraction(1, 1 << self.expanded_width) + + @cache_on_self + def max_feps(self, i): + """maximum value of `feps[i]`. + `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`. + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + # zero, because the computation of `F_prime[i]` in + # `GoldschmidtDivOp.MulDByF.run(...)` is exact. + return Fraction(0) + + @cached_property + def e0_range(self): + """minimum and maximum values of `e[0]` + (the relative error in `F_prime[-1]`) + """ + min_e0 = Fraction(0) + max_e0 = Fraction(0) + for addr in range(self.table_addr_count): + # `F_prime[-1] = (1 - e[0]) / B` + # => `e[0] = 1 - B * F_prime[-1]` + min_b, max_b = self.table_input_exact_range(addr) + f_prime_m1 = self.table[addr].as_fraction() + assert min_b >= 0 and f_prime_m1 >= 0, \ + "only positive quadrant of interval multiplication implemented" + min_product = min_b * f_prime_m1 + max_product = max_b * f_prime_m1 + # negation swaps min/max + cur_min_e0 = 1 - max_product + cur_max_e0 = 1 - min_product + min_e0 = min(min_e0, cur_min_e0) + max_e0 = max(max_e0, cur_max_e0) + return min_e0, max_e0 + + @cached_property + def min_e0(self): + """minimum value of `e[0]` (the relative error in `F_prime[-1]`) + """ + min_e0, max_e0 = self.e0_range + return min_e0 + + @cached_property + def max_e0(self): + """maximum value of `e[0]` (the relative error in `F_prime[-1]`) + """ + min_e0, max_e0 = self.e0_range + return max_e0 + + @cached_property + def max_abs_e0(self): + """maximum value of `abs(e[0])`.""" + return max(abs(self.min_e0), abs(self.max_e0)) + + @cached_property + def min_abs_e0(self): + """minimum value of `abs(e[0])`.""" + return Fraction(0) + + @cache_on_self + def max_n(self, i): + """maximum value of `n[i]` (the relative error in `N_prime[i]` + relative to the previous iteration) + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + if i == 0: + # from Claim 10 + # `n[0] = neps[0] / ((1 - e[0]) * (A / B))` + # `n[0] <= 2 * neps[0] / (1 - e[0])` + + assert self.max_e0 < 1 and self.max_neps(0) >= 0, \ + "only one quadrant of interval division implemented" + retval = 2 * self.max_neps(0) / (1 - self.max_e0) + elif i == 1: + # from Claim 10 + # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))` + min_mpd = 1 - self.max_pi(0) - self.max_delta(0) + assert self.max_f(0) <= 1 and min_mpd >= 0, \ + "only one quadrant of interval multiplication implemented" + prod = (1 - self.max_f(0)) * min_mpd + assert self.max_neps(1) >= 0 and prod > 0, \ + "only one quadrant of interval division implemented" + retval = self.max_neps(1) / prod + else: + # from Claim 6 + # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])` + min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1) + assert self.max_neps(i) >= 0 and min_mpd > 0, \ + "only one quadrant of interval division implemented" + retval = self.max_neps(i) / min_mpd + + # we need Fraction to avoid using float by accident + # -- it also hints to the IDE to give the correct type + return Fraction(retval) + + @cache_on_self + def max_d(self, i): + """maximum value of `d[i]` (the relative error in `D_prime[i]` + relative to the previous iteration) + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + if i == 0: + # from Claim 10 + # `d[0] = deps[0] / (1 - e[0])` + + assert self.max_e0 < 1 and self.max_deps(0) >= 0, \ + "only one quadrant of interval division implemented" + retval = self.max_deps(0) / (1 - self.max_e0) + elif i == 1: + # from Claim 10 + # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))` + assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \ + "only one quadrant of interval multiplication implemented" + divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2) + assert self.max_deps(1) >= 0 and divisor > 0, \ + "only one quadrant of interval division implemented" + retval = self.max_deps(1) / divisor + else: + # from Claim 6 + # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])` + assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \ + "only one quadrant of interval division implemented" + retval = self.max_deps(i) / (1 - self.max_delta(i - 1)) + + # we need Fraction to avoid using float by accident + # -- it also hints to the IDE to give the correct type + return Fraction(retval) + + @cache_on_self + def max_f(self, i): + """maximum value of `f[i]` (the relative error in `F_prime[i]` + relative to the previous iteration) + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + if i == 0: + # from Claim 10 + # `f[0] = feps[0] / (1 - delta[0])` + + assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \ + "only one quadrant of interval division implemented" + retval = self.max_feps(0) / (1 - self.max_delta(0)) + elif i == 1: + # from Claim 10 + # `f[1] = feps[1]` + retval = self.max_feps(1) + else: + # from Claim 6 + # `f[i] <= max_feps[i]` + retval = self.max_feps(i) + + # we need Fraction to avoid using float by accident + # -- it also hints to the IDE to give the correct type + return Fraction(retval) + + @cache_on_self + def max_delta(self, i): + """ maximum value of `delta[i]`. + `delta[i]` is defined in Definition 4 of paper. + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + if i == 0: + # `delta[0] = abs(e[0]) + 3 * d[0] / 2` + retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0) + else: + # `delta[i] = delta[i - 1] ** 2 + f[i - 1]` + prev_max_delta = self.max_delta(i - 1) + assert prev_max_delta >= 0 + retval = prev_max_delta ** 2 + self.max_f(i - 1) + + # we need Fraction to avoid using float by accident + # -- it also hints to the IDE to give the correct type + return Fraction(retval) + + @cache_on_self + def max_pi(self, i): + """ maximum value of `pi[i]`. + `pi[i]` is defined right below Theorem 5 of paper. + """ + assert isinstance(i, int) and 0 <= i < self.iter_count + # `pi[i] = 1 - (1 - n[i]) * prod` + # where `prod` is the product of, + # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])` + min_prod = Fraction(0) + for j in range(i): + max_n_j = self.max_n(j) + max_d_j = self.max_d(j) + assert max_n_j <= 1 and max_d_j > -1, \ + "only one quadrant of interval division implemented" + min_prod *= (1 - max_n_j) / (1 + max_d_j) + max_n_i = self.max_n(i) + assert max_n_i <= 1 and min_prod >= 0, \ + "only one quadrant of interval multiplication implemented" + return 1 - (1 - max_n_i) * min_prod + + @cached_property + def max_n_shift(self): + """ maximum value of `state.n_shift`. + """ + # input numerator is `2*io_width`-bits + max_n = (1 << (self.io_width * 2)) - 1 + max_n_shift = 0 + # normalize so 1 <= n < 2 + while max_n >= 2: + max_n >>= 1 + max_n_shift += 1 + return max_n_shift + @enum.unique class GoldschmidtDivOp(enum.Enum): @@ -378,9 +664,11 @@ class GoldschmidtDivOp(enum.Enum): # scale to correct value n = state.n * (1 << state.n_shift) - # avoid incorrectly rounding down - n = n.to_frac_wid(params.io_width, round_dir=RoundDir.UP) - state.result = math.floor(n) + state.quotient = math.floor(n) + state.remainder = state.orig_n - state.quotient * state.orig_d + if state.remainder >= state.orig_d: + state.quotient += 1 + state.remainder -= state.orig_d else: assert False, f"unimplemented GoldschmidtDivOp: {self}" @@ -412,37 +700,8 @@ def _goldschmidt_div_ops(params): _assert_accuracy(params.expanded_width > 4) # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`. - - # maximum `abs(e[0])` - max_abs_e0 = 0 - # maximum `d[0]` - max_d0 = 0 - # `f[i] = 0` for all `i` - fi = 0 - for addr in range(params.table_addr_count): - # `F_prime[-1] = (1 - e[0]) / B` - # => `e[0] = 1 - B * F_prime[-1]` - min_b, max_b = params.table_input_exact_range(addr) - f_prime_m1 = params.table[addr].as_fraction() - assert min_b >= 0 and f_prime_m1 >= 0, \ - "only positive quadrant of interval multiplication implemented" - min_product = min_b * f_prime_m1 - max_product = max_b * f_prime_m1 - # negation swaps min/max - min_e0 = 1 - max_product - max_e0 = 1 - min_product - max_abs_e0 = max(max_abs_e0, abs(min_e0), abs(max_e0)) - - # `D_prime[0] = (1 + d[0]) * B * F_prime[-1]` - # `D_prime[0] = abs_round_err + B * F_prime[-1]` - # => `d[0] = abs_round_err / (B * F_prime[-1])` - max_abs_round_err = Fraction(1, 1 << params.expanded_width) - assert min_product > 0 and max_abs_round_err >= 0, \ - "only positive quadrant of interval division implemented" - # division swaps divisor's min/max - max_d0 = max(max_d0, max_abs_round_err / min_product) - - _assert_accuracy(max_abs_e0 + 3 * max_d0 / 2 + fi < Fraction(1, 2)) + _assert_accuracy(params.max_abs_e0 + 3 * params.max_d(0) / 2 + + params.max_f(0) < Fraction(1, 2)) # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1]. # (B is the denominator) @@ -453,16 +712,32 @@ def _goldschmidt_div_ops(params): yield GoldschmidtDivOp.FEqTableLookup - # we use Setting I (section 4.1 of the paper) - - min_bits_of_precision = 1 - # FIXME: calculate error and check if it's small enough - while min_bits_of_precision < params.io_width * 2: + # we use Setting I (section 4.1 of the paper): + # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0` + n_hat = Fraction(0) + for i in range(params.iter_count): + _assert_accuracy(params.max_f(i) == 0) + n_hat = max(n_hat, params.max_n(i), params.max_d(i)) yield GoldschmidtDivOp.MulNByF - yield GoldschmidtDivOp.MulDByF - yield GoldschmidtDivOp.FEq2MinusD - - min_bits_of_precision *= 2 + if i != params.iter_count - 1: + yield GoldschmidtDivOp.MulDByF + yield GoldschmidtDivOp.FEq2MinusD + + # relative approximation error `p(N_prime[i])`: + # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)` + # `0 <= p(N_prime[i])` + # `p(N_prime[i]) <= (2 * i) * n_hat \` + # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)` + i = params.iter_count - 1 # last used `i` + max_rel_error = (2 * i) * n_hat + \ + (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i) + + min_a_over_b = Fraction(1, 2) + max_a_over_b = Fraction(2) + max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift) + max_allowed_rel_error = max_allowed_abs_error / min_a_over_b + + _assert_accuracy(max_rel_error < max_allowed_rel_error) yield GoldschmidtDivOp.CalcResult @@ -485,8 +760,9 @@ def goldschmidt_div(n, d, params): width: int the bit-width of the inputs/outputs. must be a positive integer. - returns: int - the quotient. a `width`-bit unsigned integer. + returns: tuple[int, int] + the quotient and remainder. a tuple of two `width`-bit unsigned + integers. """ assert isinstance(params, GoldschmidtDivParams) assert isinstance(d, int) and 0 < d < (1 << params.io_width) @@ -496,6 +772,8 @@ def goldschmidt_div(n, d, params): # have `width` fractional bits state = GoldschmidtDivState( + orig_n=n, + orig_d=d, n=FixedPoint(n, params.io_width), d=FixedPoint(d, params.io_width), ) @@ -503,6 +781,7 @@ def goldschmidt_div(n, d, params): for op in params.ops: op.run(params, state) - assert state.result is not None + assert state.quotient is not None + assert state.remainder is not None - return state.result + return state.quotient, state.remainder diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py index b4c9da7f..fd07d615 100644 --- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -25,18 +25,21 @@ class TestGoldschmidtDiv(FHDLTestCase): def tst(self, io_width): assert isinstance(io_width, int) params = GoldschmidtDivParams.get(io_width) - print(params) - for d in range(1, 1 << io_width): - for n in range(d << io_width): - expected = n // d - with self.subTest(io_width=io_width, n=hex(n), d=hex(d), - expected=hex(expected)): - result = goldschmidt_div(n, d, params) - self.assertEqual(result, expected, f"result={hex(result)}") + with self.subTest(params=str(params)): + for d in range(1, 1 << io_width): + for n in range(d << io_width): + expected_q, expected_r = divmod(n, d) + with self.subTest(n=hex(n), d=hex(d), + expected_q=hex(expected_q), + expected_r=hex(expected_r)): + q, r = goldschmidt_div(n, d, params) + with self.subTest(q=hex(q), r=hex(r)): + self.assertEqual((q, r), (expected_q, expected_r)) def test_1_through_5(self): - for width in range(1, 5 + 1): - self.tst(width) + for io_width in range(1, 5 + 1): + with self.subTest(io_width=io_width): + self.tst(io_width) def test_6(self): self.tst(6)