fix waay-too-precise error requirements
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
index 62a82706964dc839939388af8ca22d02f9bfb7f9..6f739c33db8d6a97f55015b795709c5b4ee34b2e 100644 (file)
@@ -694,14 +694,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     def max_n_shift(self):
         """ maximum value of `state.n_shift`.
         """
-        # input numerator is `2*io_width`-bits
-        max_n = (1 << (self.io_width * 2)) - 1
-        max_n_shift = 0
-        # normalize so 1 <= n < 2
-        while max_n >= 2:
-            max_n >>= 1
-            max_n_shift += 1
-        return max_n_shift
+        # numerator must be less than `denominator << self.io_width`, so
+        # `n_shift` is at most `self.io_width`
+        return self.io_width
 
     @cached_property
     def n_hat(self):
@@ -769,13 +764,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         max_rel_error = (2 * i) * self.n_hat + power
 
         min_a_over_b = Fraction(1, 2)
-        max_a_over_b = Fraction(2)
-        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
-        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+        min_abs_error_for_correctness = min_a_over_b / (1 << self.max_n_shift)
+        min_rel_error_for_correctness = (min_abs_error_for_correctness
+                                         / min_a_over_b)
 
-        _assert_accuracy(max_rel_error < max_allowed_rel_error,
-                         f"not accurate enough: max_rel_error={max_rel_error}"
-                         f" max_allowed_rel_error={max_allowed_rel_error}")
+        _assert_accuracy(
+            max_rel_error < min_rel_error_for_correctness,
+            f"not accurate enough: max_rel_error={max_rel_error}"
+            f" min_rel_error_for_correctness={min_rel_error_for_correctness}")
 
         yield GoldschmidtDivOp.CalcResult
 
@@ -973,7 +969,8 @@ class GoldschmidtDivOp(enum.Enum):
             state.n_shift = 0
             # normalize so 1 <= n < 2
             while state.n >= 2:
-                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
+                state.n = (state.n * 0.5).to_frac_wid(expanded_width,
+                                                      round_dir=RoundDir.DOWN)
                 state.n_shift += 1
         elif self == GoldschmidtDivOp.FEqTableLookup:
             # compute initial f by table lookup