add fsub support to fadd pipeline
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 1 Jul 2022 05:07:27 +0000 (22:07 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 1 Jul 2022 05:07:27 +0000 (22:07 -0700)
src/ieee754/fpadd/specialcases.py
src/ieee754/fpadd/test/test_add_formal.py
src/ieee754/fpcommon/fpbase.py

index 31a490685c8ed902852186e5569110354f9b9427..8ae66e7e6970fb4a913c615b706a2899d85e1f3a 100644 (file)
@@ -5,13 +5,31 @@
 from nmigen import Module, Signal, Cat, Mux
 
 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from nmigen import Module, Signal, Cat, Mux
 
 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
-from ieee754.fpcommon.fpbase import FPNumDecode, FPRoundingMode
+from ieee754.fpcommon.fpbase import FPFormat, FPNumDecode, FPRoundingMode
 
 from ieee754.fpcommon.fpbase import FPNumBaseRecord
 from ieee754.fpcommon.basedata import FPBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 
 
 
 from ieee754.fpcommon.fpbase import FPNumBaseRecord
 from ieee754.fpcommon.basedata import FPBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 
 
+class FPAddInputData(FPBaseData):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+        self.is_sub = Signal(reset=False)
+
+    def eq(self, i):
+        ret = super().eq(i)
+        ret.append(self.is_sub.eq(i.is_sub))
+        return ret
+
+    def __iter__(self):
+        yield from super().__iter__()
+        yield self.is_sub
+
+    def ports(self):
+        return list(self)
+
+
 class FPAddSpecialCasesMod(PipeModBase):
     """ special cases: NaNs, infs, zeros, denormalised
         NOTE: some of these are unique to add.  see "Special Operations"
 class FPAddSpecialCasesMod(PipeModBase):
     """ special cases: NaNs, infs, zeros, denormalised
         NOTE: some of these are unique to add.  see "Special Operations"
@@ -22,7 +40,7 @@ class FPAddSpecialCasesMod(PipeModBase):
         super().__init__(pspec, "specialcases")
 
     def ispec(self):
         super().__init__(pspec, "specialcases")
 
     def ispec(self):
-        return FPBaseData(self.pspec)
+        return FPAddInputData(self.pspec)
 
     def ospec(self):
         return FPSCData(self.pspec, True)
 
     def ospec(self):
         return FPSCData(self.pspec, True)
@@ -37,11 +55,16 @@ class FPAddSpecialCasesMod(PipeModBase):
         b1 = FPNumBaseRecord(width)
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
         b1 = FPNumBaseRecord(width)
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
-        comb += [a1.v.eq(self.i.a),
-                     b1.v.eq(self.i.b),
-                     self.o.a.eq(a1),
-                     self.o.b.eq(b1)
-                    ]
+        flip_b_sign = Signal()
+        b_is_nan = Signal()
+        comb += [
+            b_is_nan.eq(FPFormat.standard(width).is_nan(self.i.b)),
+            flip_b_sign.eq(self.i.is_sub & ~b_is_nan),
+            a1.v.eq(self.i.a),
+            b1.v.eq(self.i.b ^ (flip_b_sign << (width - 1))),
+            self.o.a.eq(a1),
+            self.o.b.eq(b1)
+        ]
 
         zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)
 
 
         zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)
 
index 915c2b94312a8fafc6646cb0ec7ac369055d60cf..95d04d1705675518ab478f59ce5dce9f8d4299e6 100644 (file)
@@ -10,10 +10,11 @@ from ieee754.fpcommon.fpbase import FPRoundingMode
 from ieee754.pipeline import PipelineSpec
 
 
 from ieee754.pipeline import PipelineSpec
 
 
-class TestFAddFormal(FHDLTestCase):
-    def tst_fadd_formal(self, sort, rm):
+class TestFAddFSubFormal(FHDLTestCase):
+    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
         assert isinstance(sort, SmtSortFloatingPoint)
         assert isinstance(rm, FPRoundingMode)
         assert isinstance(sort, SmtSortFloatingPoint)
         assert isinstance(rm, FPRoundingMode)
+        assert isinstance(is_sub, bool)
         width = sort.width
         dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
         m = Module()
         width = sort.width
         dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
         m = Module()
@@ -32,6 +33,9 @@ class TestFAddFormal(FHDLTestCase):
         b = Signal(width)
         m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
         m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
         b = Signal(width)
         m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
         m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
+        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))
+
+        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
         a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
         b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
         out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
         a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
         b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
         out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
@@ -39,8 +43,8 @@ class TestFAddFormal(FHDLTestCase):
                   FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
             rounded_up = Signal(width)
             m.d.comb += rounded_up.eq(AnyConst(width))
                   FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
             rounded_up = Signal(width)
             m.d.comb += rounded_up.eq(AnyConst(width))
-            rounded_up_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_POSITIVE)
-            rounded_down_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_NEGATIVE)
+            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
+            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
             m.d.comb += Assume(SmtFloatingPoint.from_bits(
                 rounded_up, sort=sort).same(rounded_up_fp).as_value())
             use_rounded_up = SmtBool.make(rounded_up[0])
             m.d.comb += Assume(SmtFloatingPoint.from_bits(
                 rounded_up, sort=sort).same(rounded_up_fp).as_value())
             use_rounded_up = SmtBool.make(rounded_up[0])
@@ -50,7 +54,7 @@ class TestFAddFormal(FHDLTestCase):
             expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
         else:
             smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
             expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
         else:
             smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
-            expected_fp = a_fp.add(b_fp, rm=smt_rm)
+            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)
         expected = Signal(width)
         m.d.comb += expected.eq(AnyConst(width))
         quiet_bit = 1 << (sort.mantissa_field_width - 1)
         expected = Signal(width)
         m.d.comb += expected.eq(AnyConst(width))
         quiet_bit = 1 << (sort.mantissa_field_width - 1)
@@ -75,74 +79,144 @@ class TestFAddFormal(FHDLTestCase):
     # FIXME: check exception flags
 
     def test_fadd_f16_rne_formal(self):
     # FIXME: check exception flags
 
     def test_fadd_f16_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)
 
     def test_fadd_f32_rne_formal(self):
 
     def test_fadd_f32_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rne_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rne_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNE)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)
 
     def test_fadd_f16_rtz_formal(self):
 
     def test_fadd_f16_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f32_rtz_formal(self):
 
     def test_fadd_f32_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtz_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtz_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTZ)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)
 
     def test_fadd_f16_rtp_formal(self):
 
     def test_fadd_f16_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)
 
     def test_fadd_f32_rtp_formal(self):
 
     def test_fadd_f32_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtp_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtp_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)
 
     def test_fadd_f16_rtn_formal(self):
 
     def test_fadd_f16_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)
 
     def test_fadd_f32_rtn_formal(self):
 
     def test_fadd_f32_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtn_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtn_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTN)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)
 
     def test_fadd_f16_rna_formal(self):
 
     def test_fadd_f16_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)
 
     def test_fadd_f32_rna_formal(self):
 
     def test_fadd_f32_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rna_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rna_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNA)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)
 
     def test_fadd_f16_rtop_formal(self):
 
     def test_fadd_f16_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f32_rtop_formal(self):
 
     def test_fadd_f32_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtop_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rtop_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTOP)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)
 
     def test_fadd_f16_rton_formal(self):
 
     def test_fadd_f16_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)
 
     def test_fadd_f32_rton_formal(self):
 
     def test_fadd_f32_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)
 
     @unittest.skip("too slow")
     def test_fadd_f64_rton_formal(self):
 
     @unittest.skip("too slow")
     def test_fadd_f64_rton_formal(self):
-        self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTON)
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)
+
+    def test_fsub_f16_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f32_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rne_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)
+
+    def test_fsub_f16_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f32_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtz_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)
+
+    def test_fsub_f16_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f32_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtp_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)
+
+    def test_fsub_f16_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f32_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtn_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)
+
+    def test_fsub_f16_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f32_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rna_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)
+
+    def test_fsub_f16_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f32_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rtop_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)
+
+    def test_fsub_f16_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)
+
+    def test_fsub_f32_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)
+
+    @unittest.skip("too slow")
+    def test_fsub_f64_rton_formal(self):
+        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)
 
     def test_all_rounding_modes_covered(self):
         for width in 16, 32, 64:
 
     def test_all_rounding_modes_covered(self):
         for width in 16, 32, 64:
@@ -150,6 +224,8 @@ class TestFAddFormal(FHDLTestCase):
                 rm_s = rm.name.lower()
                 name = f"test_fadd_f{width}_{rm_s}_formal"
                 assert callable(getattr(self, name))
                 rm_s = rm.name.lower()
                 name = f"test_fadd_f{width}_{rm_s}_formal"
                 assert callable(getattr(self, name))
+                name = f"test_fsub_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
 
 
 if __name__ == '__main__':
 
 
 if __name__ == '__main__':
index fee88bef91d702a74b0e255e5534de191146f459..fd707bc1e21494103b1ce6fdb25d56ef17411ddc 100644 (file)
@@ -322,42 +322,42 @@ class FPFormat:
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_subnormal(self, x):
         """ returns true if x is subnormal (exp at minimum)
         """
 
     def is_subnormal(self, x):
         """ returns true if x is subnormal (exp at minimum)
         """
-        return (self.get_exponent(x) == self.e_sub and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_sub) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_inf(self, x):
         """ returns true if x is infinite
         """
 
     def is_inf(self, x):
         """ returns true if x is infinite
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) == 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) == 0)
 
     def is_nan(self, x):
         """ returns true if x is a nan (quiet or signalling)
         """
 
     def is_nan(self, x):
         """ returns true if x is a nan (quiet or signalling)
         """
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0)
 
     def is_quiet_nan(self, x):
         """ returns true if x is a quiet nan
         """
 
     def is_quiet_nan(self, x):
         """ returns true if x is a quiet nan
         """
-        highbit = 1<<(self.m_width-1)
-        return (self.get_exponent(x) == self.e_max and
-                self.get_mantissa_field(x) != 0 and
-                self.get_mantissa_field(x) & highbit != 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            (self.get_mantissa_field(x) & highbit != 0)
 
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
 
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
-        highbit = 1<<(self.m_width-1)
-        return ((self.get_exponent(x) == self.e_max) and
-                (self.get_mantissa_field(x) != 0) and
-                (self.get_mantissa_field(x) & highbit) == 0)
+        highbit = 1 << (self.m_width - 1)
+        return (self.get_exponent(x) == self.e_max) & \
+            (self.get_mantissa_field(x) != 0) & \
+            (self.get_mantissa_field(x) & highbit) == 0
 
     @property
     def width(self):
 
     @property
     def width(self):