implement fixed_sqrt
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
index a72f9243feb53bb65d6201a03fe288dbe72d3cd9..a93562639c2a9378cc693ccea2dce67517ca9fa4 100644 (file)
@@ -3,7 +3,8 @@
 
 from nmigen.hdl.ast import Const
 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
-                        Fixed, fixed_sqrt, FixedSqrt, fixed_rsqrt, FixedRSqrt)
+                        Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
+                        fixed_rsqrt, FixedRSqrt)
 import unittest
 import math
 
@@ -678,4 +679,53 @@ class TestFixed(unittest.TestCase):
                          "fixed:0x1.23450")
 
 
-# FIXME: add tests for fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt
+class TestFixedSqrtFn(unittest.TestCase):
+    def test_on_ints(self):
+        for radicand in range(-1, 32):
+            if radicand < 0:
+                expected = None
+            else:
+                root = math.floor(math.sqrt(radicand))
+                remainder = radicand - root * root
+                expected = RootRemainder(root, remainder)
+            with self.subTest(radicand=radicand, expected=expected):
+                self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+        radicand = 2 << 64
+        root = 0x16A09E667
+        remainder = radicand - root * root
+        expected = RootRemainder(root, remainder)
+        with self.subTest(radicand=radicand, expected=expected):
+            self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+
+    def test_on_fixed(self):
+        for signed in False, True:
+            for bit_width in range(1, 10):
+                for fract_width in range(bit_width):
+                    for bits in range(1 << bit_width):
+                        radicand = Fixed.from_bits(bits,
+                                                   fract_width,
+                                                   bit_width,
+                                                   signed)
+                        if radicand < 0:
+                            continue
+                        root = radicand.with_value(math.sqrt(float(radicand)))
+                        remainder = radicand - root * root
+                        expected = RootRemainder(root, remainder)
+                        with self.subTest(radicand=repr(radicand),
+                                          expected=repr(expected)):
+                            self.assertEqual(repr(fixed_sqrt(radicand)),
+                                             repr(expected))
+
+    def test_misc_cases(self):
+        test_cases = [
+            # radicand, expected
+            (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
+            (Fixed(2, 30, 32, False),
+             "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
+        ]
+        for radicand, expected in test_cases:
+            with self.subTest(radicand=str(radicand), expected=expected):
+                self.assertEqual(str(fixed_sqrt(radicand)), expected)
+
+
+# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt