implement fixed_rsqrt
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 11:01:45 +0000 (04:01 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 11:01:45 +0000 (04:01 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index b5cde64d5b42167eacb650975d2fa6bd1f3cdcd8..580895bfdffdc4c0df803648873ae1507c05efbc 100644 (file)
@@ -490,8 +490,31 @@ class FixedSqrt:
 
 
 def fixed_rsqrt(radicand):
-    # FIXME: finish
-    raise NotImplementedError()
+    """ Compute the Reciprocal Square Root and Remainder.
+
+    Solves the polynomial ``1 - x * x * radicand == 0``
+
+    :param radicand: the ``Fixed`` to take the reciprocal square root of.
+    :returns RootRemainder:
+    """
+    # Written for correctness, not speed
+    if radicand <= 0:
+        return None
+    if not isinstance(radicand, Fixed):
+        raise TypeError()
+
+    def is_remainder_non_negative(root):
+        return 1 >= root * root * radicand
+
+    root = radicand.with_bits(0)
+    for i in reversed(range(root.bit_width)):
+        new_root = root.with_bits(root.bits | (1 << i))
+        if new_root < 0:  # skip sign bit
+            continue
+        if is_remainder_non_negative(new_root):
+            root = new_root
+    remainder = 1 - root * root * radicand
+    return RootRemainder(root, remainder)
 
 
 class FixedRSqrt:
index a93562639c2a9378cc693ccea2dce67517ca9fa4..91b7b5e70cac5359a3316f22d0d16308a5112d0b 100644 (file)
@@ -728,4 +728,57 @@ class TestFixedSqrtFn(unittest.TestCase):
                 self.assertEqual(str(fixed_sqrt(radicand)), expected)
 
 
-# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt
+# FIXME: add tests for FixedSqrt
+
+
+class TestFixedRSqrtFn(unittest.TestCase):
+    def test2(self):
+        for bits in range(1, 1 << 5):
+            radicand = Fixed.from_bits(bits, 5, 12, False)
+            float_root = 1 / math.sqrt(float(radicand))
+            root = radicand.with_value(float_root)
+            remainder = 1 - root * root * radicand
+            expected = RootRemainder(root, remainder)
+            with self.subTest(radicand=repr(radicand),
+                              expected=repr(expected)):
+                self.assertEqual(repr(fixed_rsqrt(radicand)),
+                                 repr(expected))
+
+    def test(self):
+        for signed in False, True:
+            for bit_width in range(1, 10):
+                for fract_width in range(bit_width):
+                    for bits in range(1 << bit_width):
+                        radicand = Fixed.from_bits(bits,
+                                                   fract_width,
+                                                   bit_width,
+                                                   signed)
+                        if radicand <= 0:
+                            continue
+                        float_root = 1 / math.sqrt(float(radicand))
+                        max_value = radicand.with_bits(
+                            (1 << (bit_width - signed)) - 1)
+                        if float_root > float(max_value):
+                            root = max_value
+                        else:
+                            root = radicand.with_value(float_root)
+                        remainder = 1 - root * root * radicand
+                        expected = RootRemainder(root, remainder)
+                        with self.subTest(radicand=repr(radicand),
+                                          expected=repr(expected)):
+                            self.assertEqual(repr(fixed_rsqrt(radicand)),
+                                             repr(expected))
+
+    def test_misc_cases(self):
+        test_cases = [
+            # radicand, expected
+            (Fixed(0.5, 30, 32, False),
+             "RootRemainder(fixed:0x1.6a09e664, "
+             "fixed:0x0.0000000596d014780000000)")
+        ]
+        for radicand, expected in test_cases:
+            with self.subTest(radicand=str(radicand), expected=expected):
+                self.assertEqual(str(fixed_rsqrt(radicand)), expected)
+
+
+# FIXME: add tests for FixedRSqrt