add correct NaN propagation to the fadd pipeline and formal proof
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 28 Jun 2022 05:27:39 +0000 (22:27 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 28 Jun 2022 05:27:39 +0000 (22:27 -0700)
src/ieee754/fpadd/specialcases.py
src/ieee754/fpadd/test/test_add_formal.py
src/ieee754/fpcommon/fpbase.py

index 656ace3ceced434ee50a16740bd299db8e5bc8ec..3b6c69210b4440ac2c7403e3c5aead6b97b80696 100644 (file)
@@ -80,11 +80,15 @@ class FPAddSpecialCasesMod(PipeModBase):
 
         # prepare inf/zero/nans
         z_zero = FPNumBaseRecord(width, False, name="z_zero")
-        z_nan = FPNumBaseRecord(width, False, name="z_nan")
+        z_default_nan = FPNumBaseRecord(width, False, name="z_default_nan")
+        z_quieted_a = FPNumBaseRecord(width, False, name="z_quieted_a")
+        z_quieted_b = FPNumBaseRecord(width, False, name="z_quieted_b")
         z_infa = FPNumBaseRecord(width, False, name="z_infa")
         z_infb = FPNumBaseRecord(width, False, name="z_infb")
         comb += z_zero.zero(0)
-        comb += z_nan.nan(0)
+        comb += z_default_nan.nan(0)
+        comb += z_quieted_a.quieted_nan(a1)
+        comb += z_quieted_b.quieted_nan(b1)
         comb += z_infa.inf(a1.s)
         comb += z_infb.inf(b1.s)
 
@@ -93,6 +97,8 @@ class FPAddSpecialCasesMod(PipeModBase):
 
         # this is the logic-decision-making for special-cases:
         # if a is NaN or b is NaN return NaN
+        #   if a is NaN return quieted_nan(a)
+        #   else return quieted_nan(b)
         # elif a is inf return inf (or NaN)
         #   if a is inf and signs don't match return NaN
         #   else return inf(a)
@@ -112,8 +118,8 @@ class FPAddSpecialCasesMod(PipeModBase):
         oz = Mux(t_a1zero, b1.v, oz)
         oz = Mux(t_abz, Cat(self.i.b[:-1], absa), oz)
         oz = Mux(t_b1inf, z_infb.v, oz)
-        oz = Mux(t_a1inf, Mux(bexp128s, z_nan.v, z_infa.v), oz)
-        oz = Mux(t_abnan, z_nan.v, oz)
+        oz = Mux(t_a1inf, Mux(bexp128s, z_default_nan.v, z_infa.v), oz)
+        oz = Mux(t_abnan, Mux(a1.is_nan, z_quieted_a.v, z_quieted_b.v), oz)
 
         comb += self.o.oz.eq(oz)
 
index 013aca574e8167dc47600298edf32f8c2d061cef..95e842d09b072165ba476185ef381c0af2f088d0 100644 (file)
@@ -2,7 +2,7 @@ import unittest
 from nmutil.formaltest import FHDLTestCase
 from ieee754.fpadd.pipeline import FPADDBasePipe
 from nmigen.hdl.dsl import Module
-from nmigen.hdl.ast import AnySeq, Initial, Assert, AnyConst, Signal, Assume
+from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume
 from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
     SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, \
     ROUND_NEAREST_TIES_TO_EVEN
@@ -27,17 +27,25 @@ class TestFAddFormal(FHDLTestCase):
         z_fp = SmtFloatingPoint.from_bits(z, sort=sort)
         expected_fp = a_fp.add(b_fp, rm=rm)
         expected = Signal(width)
-        m.d.comb += expected.eq(AnySeq(width))
-        # Important Note: expected and z won't necessarily match bit-exactly
-        # if it's a NaN, all this checks for is z is also any NaN
-        m.d.comb += Assume((SmtFloatingPoint.from_bits(expected, sort=sort)
-                            == expected_fp).as_value())
-        # FIXME: check that it produces the correct NaNs
+        m.d.comb += expected.eq(AnyConst(width))
+        quiet_bit = 1 << (sort.mantissa_field_width - 1)
+        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
+        with m.If(expected_fp.is_nan().as_value()):
+            with m.If(a_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (a | quiet_bit))
+            with m.Elif(b_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (b | quiet_bit))
+            with m.Else():
+                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
+        with m.Else():
+            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
+                               .same(expected_fp).as_value())
         m.d.comb += a.eq(AnyConst(width))
         m.d.comb += b.eq(AnyConst(width))
         with m.If(dut.n.trigger):
-            m.d.sync += Assert((z_fp == expected_fp).as_value())
-        self.assertFormal(m, depth=5, solver="z3")
+            m.d.sync += Assert(z_fp.same(expected_fp).as_value())
+            m.d.sync += Assert(z == expected)
+        self.assertFormal(m, depth=5, solver="bitwuzla")
 
     # FIXME: check other rounding modes
     # FIXME: check exception flags
index 30178633f96baf6771d327ae6fcd9a82dc42edd8..6e40b021945236fb914c5e4a7630a0f28b0f582c 100644 (file)
@@ -425,6 +425,12 @@ class FPNumBaseRecord:
     def nan(self, s):
         return self.create(*self._nan(s))
 
+    def quieted_nan(self, other):
+        assert isinstance(other, FPNumBaseRecord)
+        assert self.width == other.width
+        return self.create(other.s, self.fp.P128,
+                           other.v[0:self.e_start] | (1 << (self.e_start - 1)))
+
     def inf(self, s):
         return self.create(*self._inf(s))