split out n_hat as separate property
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
index 03378048810b08eb2e87e10d77d4576625338279..f57e044889fac9d4650d888b6b1f928dbf7b5e27 100644 (file)
@@ -631,6 +631,15 @@ class GoldschmidtDivParams:
             max_n_shift += 1
         return max_n_shift
 
+    @cached_property
+    def n_hat(self):
+        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
+        """
+        n_hat = Fraction(0)
+        for i in range(self.iter_count):
+            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
+        return self._shrink_max(n_hat)
+
     def __make_ops(self):
         """ Goldschmidt division algorithm.
 
@@ -665,11 +674,10 @@ class GoldschmidtDivParams:
         yield GoldschmidtDivOp.FEqTableLookup
 
         # we use Setting I (section 4.1 of the paper):
-        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
-        n_hat = Fraction(0)
+        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
+        # the conditions on n_hat are satisfied by construction.
         for i in range(self.iter_count):
             _assert_accuracy(self.max_f(i) == 0)
-            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
             yield GoldschmidtDivOp.MulNByF
             if i != self.iter_count - 1:
                 yield GoldschmidtDivOp.MulDByF
@@ -682,11 +690,11 @@ class GoldschmidtDivParams:
         # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
         i = self.iter_count - 1  # last used `i`
         # compute power manually to prevent huge intermediate values
-        power = self._shrink_max(self.max_abs_e0 + 3 * n_hat / 2)
+        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
         for _ in range(i):
             power = self._shrink_max(power * power)
 
-        max_rel_error = (2 * i) * n_hat + power
+        max_rel_error = (2 * i) * self.n_hat + power
 
         min_a_over_b = Fraction(1, 2)
         max_a_over_b = Fraction(2)
@@ -699,6 +707,26 @@ class GoldschmidtDivParams:
 
         yield GoldschmidtDivOp.CalcResult
 
+    def default_cost_fn(self):
+        """ calculate the estimated cost on an arbitrary scale of implementing
+        goldschmidt division with the specified parameters. larger cost
+        values mean worse parameters.
+
+        This is the default cost function for `GoldschmidtDivParams.get`.
+
+        returns: float
+        """
+        rom_cells = self.table_data_bits << self.table_addr_bits
+        cost = float(rom_cells)
+        for op in self.ops:
+            if op == GoldschmidtDivOp.MulNByF \
+                    or op == GoldschmidtDivOp.MulDByF:
+                mul_cost = self.expanded_width ** 2
+                mul_cost *= self.expanded_width.bit_length()
+                cost += mul_cost
+        cost += 1e6 * self.iter_count
+        return cost
+
     @staticmethod
     def get(io_width):
         """ find efficient parameters for a goldschmidt division algorithm