added tests for rest of Fixed
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 07:01:32 +0000 (00:01 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 1 Jul 2019 07:01:32 +0000 (00:01 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 7ec21510b1eebb15864d93fa06cb41bdf1a9db11..199450ed2163bad857c5a7070bb063a5e8a75787 100644 (file)
@@ -333,6 +333,10 @@ class Fixed:
         bits = self.bits * rhs_bits
         return self.from_bits(bits, fract_width, bit_width, self.signed)
 
+    def __rmul__(self, rhs):
+        """ Reverse Multiplication. """
+        return self.__mul__(rhs)
+
     @staticmethod
     def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
         if lhs < rhs:
@@ -374,7 +378,7 @@ class Fixed:
         """ Greater Than or Equal."""
         return self.cmp(rhs) >= 0
 
-    def __bool__(self, rhs):
+    def __bool__(self):
         """ Convert to bool."""
         return bool(self.bits)
 
index b1a944d918ed99aa7801ed71301d5d1ecc95dd06..a72f9243feb53bb65d6201a03fe288dbe72d3cd9 100644 (file)
@@ -482,27 +482,147 @@ class TestFixed(unittest.TestCase):
             with self.subTest(value=repr(value)):
                 self.assertEqual(float(~value), (~i) / 4)
 
-    # TODO: add test for _binary_op
-    # TODO: add test for __add__
-    # TODO: add test for __radd__
-    # TODO: add test for __sub__
-    # TODO: add test for __rsub__
-    # TODO: add test for __and__
-    # TODO: add test for __rand__
-    # TODO: add test for __or__
-    # TODO: add test for __ror__
-    # TODO: add test for __xor__
-    # TODO: add test for __rxor__
-    # TODO: add test for __mul__
-    # TODO: add test for _cmp_impl
-    # TODO: add test for cmp
-    # TODO: add test for __lt__
-    # TODO: add test for __le__
-    # TODO: add test for __eq__
-    # TODO: add test for __ne__
-    # TODO: add test for __gt__
-    # TODO: add test for __ge__
-    # TODO: add test for __bool__
+    @staticmethod
+    def get_test_values(max_bit_width, include_int):
+        for signed in False, True:
+            if include_int:
+                for bits in range(1 << max_bit_width):
+                    int_value = Const.normalize(bits, (max_bit_width, signed))
+                    yield int_value
+            for bit_width in range(1, max_bit_width):
+                for fract_width in range(bit_width + 1):
+                    for bits in range(1 << bit_width):
+                        yield Fixed.from_bits(bits,
+                                              fract_width,
+                                              bit_width,
+                                              signed)
+
+    def binary_op_test_helper(self,
+                              operation,
+                              is_fixed=True,
+                              width_combine_op=max,
+                              adjust_bits_op=None):
+        def default_adjust_bits_op(bits, out_fract_width, in_fract_width):
+            return bits << (out_fract_width - in_fract_width)
+        if adjust_bits_op is None:
+            adjust_bits_op = default_adjust_bits_op
+        max_bit_width = 5
+        for lhs in self.get_test_values(max_bit_width, True):
+            lhs_is_int = isinstance(lhs, int)
+            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
+                rhs_is_int = isinstance(rhs, int)
+                if lhs_is_int:
+                    assert not rhs_is_int
+                    lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0)
+                    int_result = operation(lhs_int, rhs.bits)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   rhs.fract_width,
+                                                   rhs.bit_width,
+                                                   rhs.signed)
+                    else:
+                        expected = int_result
+                elif rhs_is_int:
+                    rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0)
+                    int_result = operation(lhs.bits, rhs_int)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   lhs.fract_width,
+                                                   lhs.bit_width,
+                                                   lhs.signed)
+                    else:
+                        expected = int_result
+                elif lhs.signed != rhs.signed:
+                    continue
+                else:
+                    fract_width = width_combine_op(lhs.fract_width,
+                                                   rhs.fract_width)
+                    int_width = width_combine_op(lhs.bit_width
+                                                 - lhs.fract_width,
+                                                 rhs.bit_width
+                                                 - rhs.fract_width)
+                    bit_width = fract_width + int_width
+                    lhs_int = adjust_bits_op(lhs.bits,
+                                             fract_width,
+                                             lhs.fract_width)
+                    rhs_int = adjust_bits_op(rhs.bits,
+                                             fract_width,
+                                             rhs.fract_width)
+                    int_result = operation(lhs_int, rhs_int)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   fract_width,
+                                                   bit_width,
+                                                   lhs.signed)
+                    else:
+                        expected = int_result
+                with self.subTest(lhs=repr(lhs),
+                                  rhs=repr(rhs),
+                                  expected=repr(expected)):
+                    result = operation(lhs, rhs)
+                    if is_fixed:
+                        self.assertEqual(result.bit_width, expected.bit_width)
+                        self.assertEqual(result.signed, expected.signed)
+                        self.assertEqual(result.fract_width,
+                                         expected.fract_width)
+                        self.assertEqual(result.bits, expected.bits)
+                    else:
+                        self.assertEqual(result, expected)
+
+    def test_add(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)
+
+    def test_sub(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)
+
+    def test_and(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)
+
+    def test_or(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)
+
+    def test_xor(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)
+
+    def test_mul(self):
+        def adjust_bits_op(bits, out_fract_width, in_fract_width):
+            return bits
+        self.binary_op_test_helper(lambda lhs, rhs: lhs * rhs,
+                                   True,
+                                   lambda l_width, r_width: l_width + r_width,
+                                   adjust_bits_op)
+
+    def test_cmp(self):
+        def cmp(lhs, rhs):
+            if lhs < rhs:
+                return -1
+            elif lhs > rhs:
+                return 1
+            return 0
+        self.binary_op_test_helper(cmp, False)
+
+    def test_lt(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)
+
+    def test_le(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)
+
+    def test_eq(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)
+
+    def test_ne(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)
+
+    def test_gt(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)
+
+    def test_ge(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)
+
+    def test_bool(self):
+        for v in self.get_test_values(6, False):
+            with self.subTest(v=repr(v)):
+                self.assertEqual(bool(v), bool(v.bits))
 
     def test_str(self):
         self.assertEqual(str(Fixed.from_bits(0x1234, 0, 16, False)),