implement fixed_sqrt
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 10:21:45 +0000 (03:21 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 10:21:45 +0000 (03:21 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 199450ed2163bad857c5a7070bb063a5e8a75787..b5cde64d5b42167eacb650975d2fa6bd1f3cdcd8 100644 (file)
@@ -200,6 +200,28 @@ class Fixed:
         self.bit_width = bit_width
         self.signed = signed
 
+    def with_bits(self, bits):
+        """ Create a new Fixed with the specified bits.
+
+        :param bits: the new bits.
+        :returns Fixed: the new Fixed.
+        """
+        return self.from_bits(bits,
+                              self.fract_width,
+                              self.bit_width,
+                              self.signed)
+
+    def with_value(self, value):
+        """ Create a new Fixed with the specified value.
+
+        :param value: the new value.
+        :returns Fixed: the new Fixed.
+        """
+        return Fixed(value,
+                     self.fract_width,
+                     self.bit_width,
+                     self.signed)
+
     def __repr__(self):
         """ Get representation."""
         retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, "
@@ -217,7 +239,7 @@ class Fixed:
 
     def __float__(self):
         """ Convert to float."""
-        return self.bits * 2 ** -self.fract_width
+        return self.bits * 2.0 ** -self.fract_width
 
     def __floor__(self):
         """ Floor to integer."""
@@ -403,9 +425,63 @@ class Fixed:
         return retval
 
 
-def fixed_sqrt():
-    # FIXME: finish
-    raise NotImplementedError()
+class RootRemainder:
+    """ A polynomial root and remainder.
+
+    :attribute root: the polynomial root.
+    :attribute remainder: the remainder.
+    """
+
+    def __init__(self, root, remainder):
+        """ Create a new RootRemainder.
+
+        :param root: the polynomial root.
+        :param remainder: the remainder.
+        """
+        self.root = root
+        self.remainder = remainder
+
+    def __repr__(self):
+        """ Get the representation as a string. """
+        return f"RootRemainder({repr(self.root)}, {repr(self.remainder)})"
+
+    def __str__(self):
+        """ Convert to a string. """
+        return f"RootRemainder({str(self.root)}, {str(self.remainder)})"
+
+
+def fixed_sqrt(radicand):
+    """ Compute the Square Root and Remainder.
+
+    Solves the polynomial ``radicand - x * x == 0``
+
+    :param radicand: the ``Fixed`` to take the square root of.
+    :returns RootRemainder:
+    """
+    # Written for correctness, not speed
+    if radicand < 0:
+        return None
+    is_int = isinstance(radicand, int)
+    if is_int:
+        radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)
+    elif not isinstance(radicand, Fixed):
+        raise TypeError()
+
+    def is_remainder_non_negative(root):
+        return radicand >= root * root
+
+    root = radicand.with_bits(0)
+    for i in reversed(range(root.bit_width)):
+        new_root = root.with_bits(root.bits | (1 << i))
+        if new_root < 0:  # skip sign bit
+            continue
+        if is_remainder_non_negative(new_root):
+            root = new_root
+    remainder = radicand - root * root
+    if is_int:
+        root = int(root)
+        remainder = int(remainder)
+    return RootRemainder(root, remainder)
 
 
 class FixedSqrt:
@@ -413,7 +489,7 @@ class FixedSqrt:
     pass
 
 
-def fixed_rsqrt():
+def fixed_rsqrt(radicand):
     # FIXME: finish
     raise NotImplementedError()
 
index a72f9243feb53bb65d6201a03fe288dbe72d3cd9..a93562639c2a9378cc693ccea2dce67517ca9fa4 100644 (file)
@@ -3,7 +3,8 @@
 
 from nmigen.hdl.ast import Const
 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
-                        Fixed, fixed_sqrt, FixedSqrt, fixed_rsqrt, FixedRSqrt)
+                        Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
+                        fixed_rsqrt, FixedRSqrt)
 import unittest
 import math
 
@@ -678,4 +679,53 @@ class TestFixed(unittest.TestCase):
                          "fixed:0x1.23450")
 
 
-# FIXME: add tests for fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt
+class TestFixedSqrtFn(unittest.TestCase):
+    def test_on_ints(self):
+        for radicand in range(-1, 32):
+            if radicand < 0:
+                expected = None
+            else:
+                root = math.floor(math.sqrt(radicand))
+                remainder = radicand - root * root
+                expected = RootRemainder(root, remainder)
+            with self.subTest(radicand=radicand, expected=expected):
+                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+        radicand = 2 << 64
+        root = 0x16A09E667
+        remainder = radicand - root * root
+        expected = RootRemainder(root, remainder)
+        with self.subTest(radicand=radicand, expected=expected):
+            self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+
+    def test_on_fixed(self):
+        for signed in False, True:
+            for bit_width in range(1, 10):
+                for fract_width in range(bit_width):
+                    for bits in range(1 << bit_width):
+                        radicand = Fixed.from_bits(bits,
+                                                   fract_width,
+                                                   bit_width,
+                                                   signed)
+                        if radicand < 0:
+                            continue
+                        root = radicand.with_value(math.sqrt(float(radicand)))
+                        remainder = radicand - root * root
+                        expected = RootRemainder(root, remainder)
+                        with self.subTest(radicand=repr(radicand),
+                                          expected=repr(expected)):
+                            self.assertEqual(repr(fixed_sqrt(radicand)),
+                                             repr(expected))
+
+    def test_misc_cases(self):
+        test_cases = [
+            # radicand, expected
+            (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
+            (Fixed(2, 30, 32, False),
+             "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
+        ]
+        for radicand, expected in test_cases:
+            with self.subTest(radicand=str(radicand), expected=expected):
+                self.assertEqual(str(fixed_sqrt(radicand)), expected)
+
+
+# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt