added tests for rest of Fixed
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
index b1a944d918ed99aa7801ed71301d5d1ecc95dd06..a72f9243feb53bb65d6201a03fe288dbe72d3cd9 100644 (file)
@@ -482,27 +482,147 @@ class TestFixed(unittest.TestCase):
             with self.subTest(value=repr(value)):
                 self.assertEqual(float(~value), (~i) / 4)
 
-    # TODO: add test for _binary_op
-    # TODO: add test for __add__
-    # TODO: add test for __radd__
-    # TODO: add test for __sub__
-    # TODO: add test for __rsub__
-    # TODO: add test for __and__
-    # TODO: add test for __rand__
-    # TODO: add test for __or__
-    # TODO: add test for __ror__
-    # TODO: add test for __xor__
-    # TODO: add test for __rxor__
-    # TODO: add test for __mul__
-    # TODO: add test for _cmp_impl
-    # TODO: add test for cmp
-    # TODO: add test for __lt__
-    # TODO: add test for __le__
-    # TODO: add test for __eq__
-    # TODO: add test for __ne__
-    # TODO: add test for __gt__
-    # TODO: add test for __ge__
-    # TODO: add test for __bool__
+    @staticmethod
+    def get_test_values(max_bit_width, include_int):
+        for signed in False, True:
+            if include_int:
+                for bits in range(1 << max_bit_width):
+                    int_value = Const.normalize(bits, (max_bit_width, signed))
+                    yield int_value
+            for bit_width in range(1, max_bit_width):
+                for fract_width in range(bit_width + 1):
+                    for bits in range(1 << bit_width):
+                        yield Fixed.from_bits(bits,
+                                              fract_width,
+                                              bit_width,
+                                              signed)
+
+    def binary_op_test_helper(self,
+                              operation,
+                              is_fixed=True,
+                              width_combine_op=max,
+                              adjust_bits_op=None):
+        def default_adjust_bits_op(bits, out_fract_width, in_fract_width):
+            return bits << (out_fract_width - in_fract_width)
+        if adjust_bits_op is None:
+            adjust_bits_op = default_adjust_bits_op
+        max_bit_width = 5
+        for lhs in self.get_test_values(max_bit_width, True):
+            lhs_is_int = isinstance(lhs, int)
+            for rhs in self.get_test_values(max_bit_width, not lhs_is_int):
+                rhs_is_int = isinstance(rhs, int)
+                if lhs_is_int:
+                    assert not rhs_is_int
+                    lhs_int = adjust_bits_op(lhs, rhs.fract_width, 0)
+                    int_result = operation(lhs_int, rhs.bits)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   rhs.fract_width,
+                                                   rhs.bit_width,
+                                                   rhs.signed)
+                    else:
+                        expected = int_result
+                elif rhs_is_int:
+                    rhs_int = adjust_bits_op(rhs, lhs.fract_width, 0)
+                    int_result = operation(lhs.bits, rhs_int)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   lhs.fract_width,
+                                                   lhs.bit_width,
+                                                   lhs.signed)
+                    else:
+                        expected = int_result
+                elif lhs.signed != rhs.signed:
+                    continue
+                else:
+                    fract_width = width_combine_op(lhs.fract_width,
+                                                   rhs.fract_width)
+                    int_width = width_combine_op(lhs.bit_width
+                                                 - lhs.fract_width,
+                                                 rhs.bit_width
+                                                 - rhs.fract_width)
+                    bit_width = fract_width + int_width
+                    lhs_int = adjust_bits_op(lhs.bits,
+                                             fract_width,
+                                             lhs.fract_width)
+                    rhs_int = adjust_bits_op(rhs.bits,
+                                             fract_width,
+                                             rhs.fract_width)
+                    int_result = operation(lhs_int, rhs_int)
+                    if is_fixed:
+                        expected = Fixed.from_bits(int_result,
+                                                   fract_width,
+                                                   bit_width,
+                                                   lhs.signed)
+                    else:
+                        expected = int_result
+                with self.subTest(lhs=repr(lhs),
+                                  rhs=repr(rhs),
+                                  expected=repr(expected)):
+                    result = operation(lhs, rhs)
+                    if is_fixed:
+                        self.assertEqual(result.bit_width, expected.bit_width)
+                        self.assertEqual(result.signed, expected.signed)
+                        self.assertEqual(result.fract_width,
+                                         expected.fract_width)
+                        self.assertEqual(result.bits, expected.bits)
+                    else:
+                        self.assertEqual(result, expected)
+
+    def test_add(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs + rhs)
+
+    def test_sub(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs - rhs)
+
+    def test_and(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs & rhs)
+
+    def test_or(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs | rhs)
+
+    def test_xor(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs ^ rhs)
+
+    def test_mul(self):
+        def adjust_bits_op(bits, out_fract_width, in_fract_width):
+            return bits
+        self.binary_op_test_helper(lambda lhs, rhs: lhs * rhs,
+                                   True,
+                                   lambda l_width, r_width: l_width + r_width,
+                                   adjust_bits_op)
+
+    def test_cmp(self):
+        def cmp(lhs, rhs):
+            if lhs < rhs:
+                return -1
+            elif lhs > rhs:
+                return 1
+            return 0
+        self.binary_op_test_helper(cmp, False)
+
+    def test_lt(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs < rhs, False)
+
+    def test_le(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs <= rhs, False)
+
+    def test_eq(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs == rhs, False)
+
+    def test_ne(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs != rhs, False)
+
+    def test_gt(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs > rhs, False)
+
+    def test_ge(self):
+        self.binary_op_test_helper(lambda lhs, rhs: lhs >= rhs, False)
+
+    def test_bool(self):
+        for v in self.get_test_values(6, False):
+            with self.subTest(v=repr(v)):
+                self.assertEqual(bool(v), bool(v.bits))
 
     def test_str(self):
         self.assertEqual(str(Fixed.from_bits(0x1234, 0, 16, False)),