Fix carry output of adder/subtracter
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
index 53647082033ed646fad84d4067ef07d60fdf233c..b209aec18a5ab582f802ffb138fbbd057a045d41 100644 (file)
@@ -12,6 +12,7 @@ from ieee754.part_mux.part_mux import PMux
 from random import randint
 import unittest
 import itertools
+import math
 
 
 def perms(k):
@@ -71,7 +72,7 @@ class TestAddMod(Elaboratable):
         sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                            self.carry_in)
         comb += self.sub_output.eq(sub_out)
-        comb += self.sub_carry_out.eq(add_carry)
+        comb += self.sub_carry_out.eq(sub_carry)
         comb += self.neg_output.eq(-self.a)
         ppts = self.partpoints
         comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
@@ -97,14 +98,21 @@ class TestPartitionPoints(unittest.TestCase):
 
             def test_add_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
-                return mask & ((a & mask) + (b & mask) + lsb)
+                sum = (a & mask) + (b & mask) + lsb
+                result = mask & sum
+                carry = (sum & mask) != sum
+                print(a, b, sum, mask)
+                return result, carry
 
             def test_sub_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
-                return mask & ((a & mask) + (~b & mask) + lsb)
+                sum = (a & mask) + (~b & mask) + lsb
+                result = mask & sum
+                carry = (sum & mask) != sum
+                return result, carry
 
             def test_neg_fn(carry_in, a, b, mask):
-                return mask & ((a & mask) + (~0 & mask))
+                return test_add_fn(0, a, ~0, mask)
 
             def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                 rand_data = []
@@ -124,14 +132,25 @@ class TestPartitionPoints(unittest.TestCase):
                     yield module.carry_in.eq(carry_sig)
                     yield Delay(0.1e-6)
                     y = 0
+                    carry_result = 0
                     for i, mask in enumerate(mask_list):
-                        y |= test_fn(carry, a, b, mask)
+                        res, c = test_fn(carry, a, b, mask)
+                        y |= res
+                        lsb = mask & ~(mask - 1)
+                        bit_set = int(math.log2(lsb))
+                        carry_result |= c << int(bit_set/4)
                     outval = (yield getattr(module, "%s_output" % mod_attr))
                     # TODO: get (and test) carry output as well
                     print(a, b, outval, carry)
                     msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
-                        f" => 0x{y:X} != 0x{outval:X}"
+                        f" => 0x{y:X} != 0x{outval:X} ({mod_attr})"
                     self.assertEqual(y, outval, msg)
+                    if hasattr(module, "%s_carry_out" % mod_attr):
+                        c_outval = (yield getattr(module,
+                                                  "%s_carry_out" % mod_attr))
+                        msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
+                            f" => 0x{carry_result:X} != 0x{c_outval:X} ({mod_attr})"
+                        self.assertEqual(carry_result, c_outval, msg)
 
             for (test_fn, mod_attr) in ((test_add_fn, "add"),
                                         (test_sub_fn, "sub"),