implement FixedRSqrt
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 04:44:56 +0000 (21:44 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 04:44:56 +0000 (21:44 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index c073dd6fc4c4addd87a8ea8f820beabae22c3860..7fa051f2fdcfed335e04580d713a61199cc63ead 100644 (file)
@@ -592,5 +592,87 @@ def fixed_rsqrt(radicand):
 
 
 class FixedRSqrt:
-    # FIXME: finish
-    pass
+    """ Fixed-point Reciprocal-Square-Root/Remainder.
+
+    :attribute radicand: the radicand
+    :attribute root: the reciprocal square root
+    :attribute radicand_root: ``radicand * root``
+    :attribute radicand_root_squared: ``radicand * root * root``
+    :attribute remainder: the remainder
+    :attribute log2_radix: the base-2 log of the operation radix. The number of
+        bits of root that are calculated per pipeline stage.
+    :attribute current_shift: the current bit index
+    """
+
+    def __init__(self, radicand, log2_radix=3):
+        """ Create an FixedRSqrt.
+
+        :param radicand: the radicand.
+        :param log2_radix: the base-2 log of the operation radix. The number of
+            bits of root that are calculated per pipeline stage.
+        """
+        assert isinstance(radicand, Fixed)
+        assert radicand.signed is False
+        self.radicand = radicand
+        self.root = radicand.with_bits(0)
+        self.radicand_root = radicand.with_bits(0) * self.root
+        self.radicand_root_squared = self.radicand_root * self.root
+        self.remainder = radicand.with_bits(0) - self.radicand_root_squared
+        self.log2_radix = log2_radix
+        self.current_shift = self.root.bit_width
+
+    def calculate_stage(self):
+        """ Calculate the next pipeline stage of the operation.
+
+        :returns bool: True if this is the last pipeline stage.
+        """
+        if self.current_shift == 0:
+            return True
+        log2_radix = min(self.log2_radix, self.current_shift)
+        assert log2_radix > 0
+        self.current_shift -= log2_radix
+        radix = 1 << log2_radix
+        trial_values = []
+        for i in range(radix):
+            v = self.radicand_root_squared
+            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
+                                      self.root.fract_width,
+                                      self.root.bit_width + 1 + log2_radix,
+                                      False)
+            v += self.radicand_root * factor1
+            factor2 = Fixed.from_bits(i << self.current_shift,
+                                      self.root.fract_width,
+                                      self.root.bit_width + log2_radix,
+                                      False)
+            v += self.radicand * factor2 * factor2
+            trial_values.append(self.radicand_root_squared.with_value(v))
+        root_bits = 0
+        new_radicand_root_squared = self.radicand_root_squared
+        for i in range(radix):
+            if 1 >= trial_values[i]:
+                root_bits = i
+                new_radicand_root_squared = trial_values[i]
+        v = self.radicand_root
+        v += self.radicand * Fixed.from_bits(root_bits << self.current_shift,
+                                             self.root.fract_width,
+                                             self.root.bit_width + log2_radix,
+                                             False)
+        self.radicand_root = self.radicand_root.with_value(v)
+        self.root |= Fixed.from_bits(root_bits << self.current_shift,
+                                     self.root.fract_width,
+                                     self.root.bit_width + log2_radix,
+                                     False)
+        self.radicand_root_squared = new_radicand_root_squared
+        if self.current_shift == 0:
+            self.remainder = 1 - self.radicand_root_squared
+            return True
+        return False
+
+    def calculate(self):
+        """ Calculate the results of the reciprocal square root.
+
+        :returns: self
+        """
+        while not self.calculate_stage():
+            pass
+        return self
index f633f1a4428dbf956289dc9466bd636774a5c3ba..c82641248b7e371510a7f4356bc675e30c63d31c 100644 (file)
@@ -824,4 +824,51 @@ class TestFixedRSqrtFn(unittest.TestCase):
                 self.assertEqual(str(fixed_rsqrt(radicand)), expected)
 
 
-# FIXME: add tests for FixedRSqrt
+class TestFixedRSqrt(unittest.TestCase):
+    def helper(self, log2_radix):
+        for bit_width in range(1, 8):
+            for fract_width in range(bit_width):
+                for radicand_bits in range(1, 1 << bit_width):
+                    radicand = Fixed.from_bits(radicand_bits,
+                                               fract_width,
+                                               bit_width,
+                                               False)
+                    root_remainder = fixed_rsqrt(radicand)
+                    with self.subTest(radicand=repr(radicand),
+                                      root_remainder=repr(root_remainder),
+                                      log2_radix=log2_radix):
+                        obj = FixedRSqrt(radicand, log2_radix)
+                        for _ in range(250 * bit_width):
+                            self.assertEqual(obj.radicand * obj.root,
+                                             obj.radicand_root)
+                            self.assertEqual(obj.radicand_root * obj.root,
+                                             obj.radicand_root_squared)
+                            self.assertGreaterEqual(1,
+                                                    obj.radicand_root_squared)
+                            if obj.calculate_stage():
+                                break
+                        else:
+                            self.fail("infinite loop")
+                        self.assertEqual(obj.radicand * obj.root,
+                                         obj.radicand_root)
+                        self.assertEqual(obj.radicand_root * obj.root,
+                                         obj.radicand_root_squared)
+                        self.assertGreaterEqual(1,
+                                                obj.radicand_root_squared)
+                        self.assertEqual(obj.remainder,
+                                         1 - obj.radicand_root_squared)
+                        self.assertEqual(obj.root, root_remainder.root)
+                        self.assertEqual(obj.remainder,
+                                         root_remainder.remainder)
+
+    def test_radix_2(self):
+        self.helper(1)
+
+    def test_radix_4(self):
+        self.helper(2)
+
+    def test_radix_8(self):
+        self.helper(3)
+
+    def test_radix_16(self):
+        self.helper(4)