add fsub support to fadd pipeline
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 1 Jul 2022 05:07:27 +0000 (22:07 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 1 Jul 2022 05:07:27 +0000 (22:07 -0700)
src/ieee754/fpadd/specialcases.py
src/ieee754/fpadd/test/test_add_formal.py
src/ieee754/fpcommon/fpbase.py

index 31a490685c8ed902852186e5569110354f9b9427..8ae66e7e6970fb4a913c615b706a2899d85e1f3a 100644 (file)
@@ -5,13 +5,31 @@
 from nmigen import Module, Signal, Cat, Mux
 
 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
-from ieee754.fpcommon.fpbase import FPNumDecode, FPRoundingMode
+from ieee754.fpcommon.fpbase import FPFormat, FPNumDecode, FPRoundingMode
 
 from ieee754.fpcommon.fpbase import FPNumBaseRecord
 from ieee754.fpcommon.basedata import FPBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 
 
+class FPAddInputData(FPBaseData):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+        self.is_sub = Signal(reset=False)
+
+    def eq(self, i):
+        ret = super().eq(i)
+        ret.append(self.is_sub.eq(i.is_sub))
+        return ret
+
+    def __iter__(self):
+        yield from super().__iter__()
+        yield self.is_sub
+
+    def ports(self):
+        return list(self)
+
+
 class FPAddSpecialCasesMod(PipeModBase):
     """ special cases: NaNs, infs, zeros, denormalised
         NOTE: some of these are unique to add.  see "Special Operations"
@@ -22,7 +40,7 @@ class FPAddSpecialCasesMod(PipeModBase):
         super().__init__(pspec, "specialcases")
 
     def ispec(self):
-        return FPBaseData(self.pspec)
+        return FPAddInputData(self.pspec)
 
     def ospec(self):
         return FPSCData(self.pspec, True)
@@ -37,11 +55,16 @@ class FPAddSpecialCasesMod(PipeModBase):
         b1 = FPNumBaseRecord(width)
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
-        comb += [a1.v.eq(self.i.a),
-                     b1.v.eq(self.i.b),
-                     self.o.a.eq(a1),
-                     self.o.b.eq(b1)
-                    ]
+        flip_b_sign = Signal()
+        b_is_nan = Signal()
+        comb += [
+            b_is_nan.eq(FPFormat.standard(width).is_nan(self.i.b)),
+            flip_b_sign.eq(self.i.is_sub & ~b_is_nan),
+            a1.v.eq(self.i.a),
+            b1.v.eq(self.i.b ^ (flip_b_sign << (width - 1))),
+            self.o.a.eq(a1),
+            self.o.b.eq(b1)
+        ]
 
         zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)
 
index 915c2b94312a8fafc6646cb0ec7ac369055d60cf..95d04d1705675518ab478f59ce5dce9f8d4299e6 100644 (file)
@@ -10,10 +10,11 @@ from ieee754.fpcommon.fpbase import FPRoundingMode
 from ieee754.pipeline import PipelineSpec
 
 
-class TestFAddFormal(FHDLTestCase):
-    def tst_fadd_formal(self, sort, rm):
+class TestFAddFSubFormal(FHDLTestCase):
+    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
         assert isinstance(sort, SmtSortFloatingPoint)
         assert isinstance(rm, FPRoundingMode)
+        assert isinstance(is_sub, bool)
         width = sort.width
         dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
         m = Module()
@@ -32,6 +33,9 @@ class TestFAddFormal(FHDLTestCase):
         b = Signal(width)
         m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
         m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
+        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))
+
+        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
         a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
         b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
         out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
@@ -39,8 +43,8 @@ class TestFAddFormal(FHDLTestCase):
                   FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
             rounded_up = Signal(width)
             m.d.comb += rounded_up.eq(AnyConst(width))
-            rounded_up_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_POSITIVE)
-            rounded_down_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_NEGATIVE)
+            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
+            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
             m.d.comb += Assume(SmtFloatingPoint.from_bits(
                 rounded_up, sort=sort).same(rounded_up_fp).as_value())
             use_rounded_up = SmtBool.make(rounded_up[0])
@@ -50,7 +54,7 @@ class TestFAddFormal(FHDLTestCase):
             expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
         else:
             smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
-            expected_fp = a_fp.add(b_fp, rm=smt_rm)
+            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)
         expected = Signal(width)
         m.d.comb += expected.eq(AnyConst(width))
         quiet_bit = 1 << (sort.mantissa_field_width - 1)
@@ -75,74 +79,144 @@ class TestFAddFormal(FHDLTestCase):
     # FIXME: check exception flags
 
     def test_fadd_f16_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)
 
     def test_fadd_f32_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)
 
     def test_fadd_f16_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f32_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f16_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)
 
     def test_fadd_f32_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)
 
     def test_fadd_f16_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)
 
     def test_fadd_f32_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)
 
     def test_fadd_f16_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)
 
     def test_fadd_f32_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)
 
     def test_fadd_f16_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f32_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f16_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)
 
     def test_fadd_f32_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)
+
+    def test_fsub_f16_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f32_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f16_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f32_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f16_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f32_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f16_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f32_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f16_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f32_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f16_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f32_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f16_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)
+
+    def test_fsub_f32_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)
 
     def test_all_rounding_modes_covered(self):
         for width in 16, 32, 64:
@@ -150,6 +224,8 @@ class TestFAddFormal(FHDLTestCase):
                 rm_s = rm.name.lower()
                 name = f"test_fadd_f{width}_{rm_s}_formal"
                 assert callable(getattr(self, name))
+                name = f"test_fsub_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
 
 
 if __name__ == '__main__':
index fee88bef91d702a74b0e255e5534de191146f459..fd707bc1e21494103b1ce6fdb25d56ef17411ddc 100644 (file)
@@ -322,42 +322,42 @@ class FPFormat:
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_subnormal(self, x):
         """ returns true if x is subnormal (exp at minimum)
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_inf(self, x):
         """ returns true if x is infinite
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_nan(self, x):
         """ returns true if x is a nan (quiet or signalling)
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_quiet_nan(self, x):
         """ returns true if x is a quiet nan
         """
-        highbit = 1<<(self.m_width-1)
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0 and
-                self.get_mantissa_field(x) & highbit != 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            (self.get_mantissa_field(x) & highbit != 0)
 
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
-        highbit = 1<<(self.m_width-1)
-        return ((self.get_exponent(x) == self.e_max) and
-                (self.get_mantissa_field(x) != 0) and
-                (self.get_mantissa_field(x) & highbit) == 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            (self.get_mantissa_field(x) & highbit) == 0
 
     @property
     def width(self):