X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fsoc%2Ffu%2Fdiv%2Fexperiment%2Fgoldschmidt_div_sqrt.py;h=62156ee5ec418d9490e31edb89a6f401d587a06f;hb=6af1a415a6885a2645fc1bc3adc8a3a3cc3aaaf7;hp=dc363c5a4bebe56df804193c36bf71fa11db76d7;hpb=97643943bcf8cf079c67f4e13e317dc1c748c0dd;p=soc.git diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index dc363c5a..62156ee5 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -125,13 +125,6 @@ class FixedPoint: denominator = value.denominator else: value = FixedPoint.cast(value) - # compute number of bits that should be removed from value - del_bits = value.frac_wid - frac_wid - if del_bits == 0: - return value - if del_bits < 0: # add bits - return FixedPoint(value.bits << -del_bits, - frac_wid) numerator = value.bits denominator = 1 << value.frac_wid if denominator < 0: @@ -313,13 +306,49 @@ class GoldschmidtDivParams: iter_count: int """the total number of iterations of the division algorithm's loop""" - # tuple to be immutable - table: "tuple[FixedPoint, ...]" = field(init=False) + # tuple to be immutable, default so repr() works for debugging even when + # __post_init__ hasn't finished running yet + table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented) """the lookup-table""" - ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False) + ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, + default=NotImplemented) """the operations needed to perform the goldschmidt division algorithm.""" + def _shrink_bound(self, bound, round_dir): + """prevent fractions from having huge numerators/denominators by + rounding to a `FixedPoint` and converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + assert isinstance(bound, (Fraction, int)) + assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \ + "you shouldn't use that round_dir on bounds" + frac_wid = self.io_width * 4 + 100 # should be enough precision + fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir) + return fixed.as_fraction() + + def _shrink_min(self, min_bound): + """prevent fractions used as minimum bounds from having huge + numerators/denominators by rounding down to a `FixedPoint` and + converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + return self._shrink_bound(min_bound, RoundDir.DOWN) + + def _shrink_max(self, max_bound): + """prevent fractions used as maximum bounds from having huge + numerators/denominators by rounding up to a `FixedPoint` and + converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + return self._shrink_bound(max_bound, RoundDir.UP) + @property def table_addr_count(self): """number of distinct addresses in the lookup-table.""" @@ -332,20 +361,29 @@ class GoldschmidtDivParams: assert isinstance(addr, int) assert 0 <= addr < self.table_addr_count _assert_accuracy(self.io_width >= self.table_addr_bits) - min_numerator = (1 << self.table_addr_bits) + addr - denominator = 1 << self.table_addr_bits - values_per_table_entry = 1 << (self.io_width - self.table_addr_bits) - max_numerator = min_numerator + values_per_table_entry + addr_shift = self.io_width - self.table_addr_bits + min_numerator = (1 << self.io_width) + (addr << addr_shift) + denominator = 1 << self.io_width + values_per_table_entry = 1 << addr_shift + max_numerator = min_numerator + values_per_table_entry - 1 min_input = Fraction(min_numerator, denominator) max_input = Fraction(max_numerator, denominator) + min_input = self._shrink_min(min_input) + max_input = self._shrink_max(max_input) + assert 1 <= min_input <= max_input < 2 return min_input, max_input def table_value_exact_range(self, addr): """return the range of values as `Fraction`s used for the table entry with address `addr`.""" - min_value, max_value = self.table_input_exact_range(addr) + min_input, max_input = self.table_input_exact_range(addr) # division swaps min/max - return 1 / max_value, 1 / min_value + min_value = 1 / max_input + max_value = 1 / min_input + min_value = self._shrink_min(min_value) + max_value = self._shrink_max(max_value) + assert 0.5 < min_value <= max_value <= 1 + return min_value, max_value def table_exact_value(self, index): min_value, max_value = self.table_value_exact_range(index) @@ -374,6 +412,8 @@ class GoldschmidtDivParams: with `params.io_width == io_width`. """ assert isinstance(io_width, int) and io_width >= 1 + last_params = None + last_error = None for extra_precision in range(io_width * 2 + 4): for table_addr_bits in range(1, 7 + 1): table_data_bits = io_width + extra_precision @@ -385,10 +425,17 @@ class GoldschmidtDivParams: table_addr_bits=table_addr_bits, table_data_bits=table_data_bits, iter_count=iter_count) - except ParamsNotAccurateEnough: - pass + except ParamsNotAccurateEnough as e: + last_params = (f"GoldschmidtDivParams(" + f"io_width={io_width!r}, " + f"extra_precision={extra_precision!r}, " + f"table_addr_bits={table_addr_bits!r}, " + f"table_data_bits={table_data_bits!r}, " + f"iter_count={iter_count!r})") + last_error = e raise ValueError(f"can't find working parameters for a goldschmidt " - f"division algorithm with io_width={io_width}") + f"division algorithm: last params: {last_params}" + ) from last_error @property def expanded_width(self): @@ -442,6 +489,8 @@ class GoldschmidtDivParams: cur_max_e0 = 1 - min_product min_e0 = min(min_e0, cur_min_e0) max_e0 = max(max_e0, cur_max_e0) + min_e0 = self._shrink_min(min_e0) + max_e0 = self._shrink_max(max_e0) return min_e0, max_e0 @cached_property @@ -500,9 +549,7 @@ class GoldschmidtDivParams: "only one quadrant of interval division implemented" retval = self.max_neps(i) / min_mpd - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_d(self, i): @@ -533,9 +580,7 @@ class GoldschmidtDivParams: "only one quadrant of interval division implemented" retval = self.max_deps(i) / (1 - self.max_delta(i - 1)) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_f(self, i): @@ -559,9 +604,7 @@ class GoldschmidtDivParams: # `f[i] <= max_feps[i]` retval = self.max_feps(i) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_delta(self, i): @@ -578,9 +621,11 @@ class GoldschmidtDivParams: assert prev_max_delta >= 0 retval = prev_max_delta ** 2 + self.max_f(i - 1) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + # `delta[i]` has to be smaller than one otherwise errors would go off + # to infinity + _assert_accuracy(retval < 1) + + return self._shrink_max(retval) @cache_on_self def max_pi(self, i): @@ -591,7 +636,7 @@ class GoldschmidtDivParams: # `pi[i] = 1 - (1 - n[i]) * prod` # where `prod` is the product of, # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])` - min_prod = Fraction(0) + min_prod = Fraction(1) for j in range(i): max_n_j = self.max_n(j) max_d_j = self.max_d(j) @@ -601,7 +646,8 @@ class GoldschmidtDivParams: max_n_i = self.max_n(i) assert max_n_i <= 1 and min_prod >= 0, \ "only one quadrant of interval multiplication implemented" - return 1 - (1 - max_n_i) * min_prod + retval = 1 - (1 - max_n_i) * min_prod + return self._shrink_max(retval) @cached_property def max_n_shift(self): @@ -729,15 +775,21 @@ def _goldschmidt_div_ops(params): # `p(N_prime[i]) <= (2 * i) * n_hat \` # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)` i = params.iter_count - 1 # last used `i` - max_rel_error = (2 * i) * n_hat + \ - (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i) + # compute power manually to prevent huge intermediate values + power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2) + for _ in range(i): + power = params._shrink_max(power * power) + + max_rel_error = (2 * i) * n_hat + power min_a_over_b = Fraction(1, 2) max_a_over_b = Fraction(2) max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift) max_allowed_rel_error = max_allowed_abs_error / min_a_over_b - _assert_accuracy(max_rel_error < max_allowed_rel_error) + _assert_accuracy(max_rel_error < max_allowed_rel_error, + f"not accurate enough: max_rel_error={max_rel_error} " + f"max_allowed_rel_error={max_allowed_rel_error}") yield GoldschmidtDivOp.CalcResult