implemented FixedSqrt
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 04:21:05 +0000 (21:21 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 04:21:05 +0000 (21:21 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 580895bfdffdc4c0df803648873ae1507c05efbc..6b9c2b19231338126ea6a9fe45dcbcade8c4dc8f 100644 (file)
@@ -485,8 +485,82 @@ def fixed_sqrt(radicand):
 
 
 class FixedSqrt:
-    # FIXME: finish
-    pass
+    """ Fixed-point Square-Root/Remainder.
+
+    :attribute radicand: the radicand
+    :attribute root: the square root
+    :attribute root_squared: the square of ``root``
+    :attribute remainder: the remainder
+    :attribute log2_radix: the base-2 log of the division radix. The number of
+        bits of quotient that are calculated per pipeline stage.
+    :attribute current_shift: the current bit index
+    """
+
+    def __init__(self, radicand, log2_radix=3):
+        """ Create an FixedSqrt.
+
+        :param radicand: the radicand.
+        :param log2_radix: the base-2 log of the division radix. The number of
+            bits of result that are calculated per pipeline stage.
+        """
+        assert isinstance(radicand, Fixed)
+        assert radicand.signed is False
+        self.radicand = radicand
+        self.root = radicand.with_bits(0)
+        self.root_squared = self.root * self.root
+        self.remainder = radicand.with_bits(0) - self.root_squared
+        self.log2_radix = log2_radix
+        self.current_shift = self.root.bit_width
+
+    def calculate_stage(self):
+        """ Calculate the next pipeline stage of the division.
+
+        :returns bool: True if this is the last pipeline stage.
+        """
+        if self.current_shift == 0:
+            return True
+        log2_radix = min(self.log2_radix, self.current_shift)
+        assert log2_radix > 0
+        self.current_shift -= log2_radix
+        radix = 1 << log2_radix
+        trial_squares = []
+        for i in range(radix):
+            v = self.root_squared
+            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
+                                      self.root.fract_width,
+                                      self.root.bit_width + 1 + log2_radix,
+                                      False)
+            v += self.root * factor1
+            factor2 = Fixed.from_bits(i << self.current_shift,
+                                      self.root.fract_width,
+                                      self.root.bit_width + log2_radix,
+                                      False)
+            v += factor2 * factor2
+            trial_squares.append(self.root_squared.with_value(v))
+        root_bits = 0
+        new_root_squared = self.root_squared
+        for i in range(radix):
+            if self.radicand >= trial_squares[i]:
+                root_bits = i
+                new_root_squared = trial_squares[i]
+        self.root |= Fixed.from_bits(root_bits << self.current_shift,
+                                     self.root.fract_width,
+                                     self.root.bit_width + log2_radix,
+                                     False)
+        self.root_squared = new_root_squared
+        if self.current_shift == 0:
+            self.remainder = self.radicand - self.root_squared
+            return True
+        return False
+
+    def calculate(self):
+        """ Calculate the results of the square root.
+
+        :returns: self
+        """
+        while not self.calculate_stage():
+            pass
+        return self
 
 
 def fixed_rsqrt(radicand):
index 91b7b5e70cac5359a3316f22d0d16308a5112d0b..f633f1a4428dbf956289dc9466bd636774a5c3ba 100644 (file)
@@ -728,7 +728,50 @@ class TestFixedSqrtFn(unittest.TestCase):
                 self.assertEqual(str(fixed_sqrt(radicand)), expected)
 
 
-# FIXME: add tests for FixedSqrt
+class TestFixedSqrt(unittest.TestCase):
+    def helper(self, log2_radix):
+        for bit_width in range(1, 8):
+            for fract_width in range(bit_width):
+                for radicand_bits in range(1 << bit_width):
+                    radicand = Fixed.from_bits(radicand_bits,
+                                               fract_width,
+                                               bit_width,
+                                               False)
+                    root_remainder = fixed_sqrt(radicand)
+                    with self.subTest(radicand=repr(radicand),
+                                      root_remainder=repr(root_remainder),
+                                      log2_radix=log2_radix):
+                        obj = FixedSqrt(radicand, log2_radix)
+                        for _ in range(250 * bit_width):
+                            self.assertEqual(obj.root * obj.root,
+                                             obj.root_squared)
+                            self.assertGreaterEqual(obj.radicand,
+                                                    obj.root_squared)
+                            if obj.calculate_stage():
+                                break
+                        else:
+                            self.fail("infinite loop")
+                        self.assertEqual(obj.root * obj.root,
+                                         obj.root_squared)
+                        self.assertGreaterEqual(obj.radicand,
+                                                obj.root_squared)
+                        self.assertEqual(obj.remainder,
+                                         obj.radicand - obj.root_squared)
+                        self.assertEqual(obj.root, root_remainder.root)
+                        self.assertEqual(obj.remainder,
+                                         root_remainder.remainder)
+
+    def test_radix_2(self):
+        self.helper(1)
+
+    def test_radix_4(self):
+        self.helper(2)
+
+    def test_radix_8(self):
+        self.helper(3)
+
+    def test_radix_16(self):
+        self.helper(4)
 
 
 class TestFixedRSqrtFn(unittest.TestCase):