working on implementing fma, f16 rtz formal proof seems likely to work
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Jul 2022 10:49:45 +0000 (03:49 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Jul 2022 10:49:45 +0000 (03:49 -0700)
bitwuzla has been running the formal proof for the last 23min, seems
like it'll probably succeed after a bunch of time.

src/ieee754/fpfma/main_stage.py
src/ieee754/fpfma/norm.py
src/ieee754/fpfma/pipeline.py
src/ieee754/fpfma/special_cases.py
src/ieee754/fpfma/test/__init__.py [new file with mode: 0644]
src/ieee754/fpfma/test/test_fma_formal.py [new file with mode: 0644]
src/ieee754/fpfma/util.py

index 1ab2b2b8a048a14330c6c0f6230e6b0539cfac65..7a028107e40d1c52ae363b42019b2c550f91e91f 100644 (file)
@@ -3,13 +3,13 @@
 computes `z = (a * c) + b` but only rounds once at the end
 """
 
-from nmutil.pipemodbase import PipeModBase
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.fpbase import FPRoundingMode
 from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
 from nmigen.hdl.dsl import Module
-from nmigen.hdl.ast import Signal, signed, unsigned, Mux
+from nmigen.hdl.ast import Signal, signed, unsigned, Mux, Cat
 from ieee754.fpfma.util import expanded_exponent_shape, \
-    expanded_mantissa_shape, get_fpformat
+    expanded_mantissa_shape, get_fpformat, EXPANDED_MANTISSA_EXTRA_LSBS
 from ieee754.fpcommon.getop import FPPipeContext
 
 
@@ -38,8 +38,31 @@ class FPFMAPostCalcData:
         self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
         """rounding mode"""
 
+    def eq(self, i):
+        return [
+            self.sign.eq(i.sign),
+            self.exponent.eq(i.exponent),
+            self.mantissa.eq(i.mantissa),
+            self.bypassed_z.eq(i.bypassed_z),
+            self.do_bypass.eq(i.do_bypass),
+            self.ctx.eq(i.ctx),
+            self.rm.eq(i.rm),
+        ]
+
+    def __iter__(self):
+        yield self.sign
+        yield self.exponent
+        yield self.mantissa
+        yield self.bypassed_z
+        yield self.do_bypass
+        yield self.ctx
+        yield self.rm
+
+    def ports(self):
+        return list(self)
+
 
-class FPFMAMainStage(PipeModBase):
+class FPFMAMain(PipeModBase):
     def __init__(self, pspec):
         super().__init__(pspec, "main")
 
@@ -65,8 +88,9 @@ class FPFMAMainStage(PipeModBase):
             negate_b_s.eq(inp.do_sub),
             negate_b_u.eq(inp.do_sub),
         ]
-        sum_v = product_v + (inp.b_mantissa ^ negate_b_s) + negate_b_u
-        sum = Signal(sum_v.shape())
+        sum_v = (product_v << EXPANDED_MANTISSA_EXTRA_LSBS) + \
+            (inp.b_mantissa ^ negate_b_s) + negate_b_u
+        sum = Signal(expanded_mantissa_shape(fpf))
         m.d.comb += sum.eq(sum_v)
 
         sum_neg = Signal()
@@ -97,3 +121,13 @@ class FPFMAMainStage(PipeModBase):
             out.rm.eq(inp.rm),
         ]
         return m
+
+
+class FPFMAMainStage(PipeModBaseChain):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+
+    def get_chain(self):
+        """ gets chain of modules
+        """
+        return [FPFMAMain(self.pspec)]
index 0c16f452e3975915bfc1e4f9c65799a1efaf510c..21022c81a5866e4a677e60a1d639c4dec46eed49 100644 (file)
@@ -1,12 +1,14 @@
 from nmutil.pipemodbase import PipeModBaseChain, PipeModBase
+from ieee754.fpcommon.fpbase import OverflowMod
 from ieee754.fpcommon.postnormalise import FPNorm1Data
 from ieee754.fpcommon.roundz import FPRoundMod
 from ieee754.fpcommon.corrections import FPCorrectionsMod
 from ieee754.fpcommon.pack import FPPackMod
 from ieee754.fpfma.main_stage import FPFMAPostCalcData
 from nmigen.hdl.dsl import Module
-
+from nmigen.hdl.ast import Signal
 from ieee754.fpfma.util import get_fpformat
+from nmigen.lib.coding import PriorityEncoder
 
 
 class FPFMANorm(PipeModBase):
@@ -23,16 +25,38 @@ class FPFMANorm(PipeModBase):
         m = Module()
         fpf = get_fpformat(self.pspec)
         assert fpf.has_sign
-        inp = self.i
-        out = self.o
-        raise NotImplementedError  # FIXME: finish
+        inp: FPFMAPostCalcData = self.i
+        out: FPNorm1Data = self.o
+        m.submodules.pri_enc = pri_enc = PriorityEncoder(inp.mantissa.width)
+        m.d.comb += pri_enc.i.eq(inp.mantissa[::-1])
+        unrestricted_shift_amount = Signal(range(inp.mantissa.width))
+        shift_amount = Signal(range(inp.mantissa.width))
+        m.d.comb += unrestricted_shift_amount.eq(pri_enc.o)
+        with m.If(inp.exponent - (1 + fpf.e_sub) < unrestricted_shift_amount):
+            m.d.comb += shift_amount.eq(inp.exponent - (1 + fpf.e_sub))
+        with m.Else():
+            m.d.comb += shift_amount.eq(unrestricted_shift_amount)
+        n_mantissa = Signal(inp.mantissa.width)
+        m.d.comb += n_mantissa.eq(inp.mantissa << shift_amount)
+
+        m.submodules.of = of = OverflowMod()
         m.d.comb += [
-            out.roundz.eq(),
-            out.z.eq(),
-            out.out_do_z.eq(),
-            out.oz.eq(),
-            out.ctx.eq(),
-            out.rm.eq(),
+            pri_enc.i.eq(inp.mantissa[::-1]),
+            of.guard.eq(n_mantissa[-(out.z.m.width + 1)]),
+            of.round_bit.eq(n_mantissa[-(out.z.m.width + 2)]),
+            of.sticky.eq(n_mantissa[:-(out.z.m.width + 2)].bool()),
+            of.m0.eq(out.z.m[0]),
+            of.fpflags.eq(0),
+            of.sign.eq(inp.sign),
+            of.rm.eq(inp.rm),
+            out.roundz.eq(of.roundz_out),
+            out.z.s.eq(inp.sign),
+            out.z.e.eq(inp.exponent - shift_amount),
+            out.z.m.eq(n_mantissa[-out.z.m.width:]),
+            out.out_do_z.eq(inp.do_bypass),
+            out.oz.eq(inp.bypassed_z),
+            out.ctx.eq(inp.ctx),
+            out.rm.eq(inp.rm),
         ]
         return m
 
index f0b928d0235c2e1602f2fd5e17b3df4f33e9243f..3661d3c40f7ac2fb1799614dff44f74affde4e8a 100644 (file)
@@ -4,7 +4,7 @@ computes `z = (a * c) + b` but only rounds once at the end
 """
 
 from nmutil.singlepipe import ControlBase
-from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNorm
+from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormStage
 from ieee754.fpfma.main_stage import FPFMAMainStage
 from ieee754.fpfma.norm import FPFMANormToPack
 
@@ -12,7 +12,7 @@ from ieee754.fpfma.norm import FPFMANormToPack
 class FPFMABasePipe(ControlBase):
     def __init__(self, pspec):
         super().__init__()
-        self.sc_denorm = FPFMASpecialCasesDeNorm(pspec)
+        self.sc_denorm = FPFMASpecialCasesDeNormStage(pspec)
         self.main = FPFMAMainStage(pspec)
         self.normpack = FPFMANormToPack(pspec)
         self._eqs = self.connect([self.sc_denorm, self.main, self.normpack])
index 95d3026692465afeba1bf81955b87c3123b9e779..826c32a8e80f2a121a585d7c83bb65e436e97d50 100644 (file)
@@ -3,14 +3,16 @@
 computes `z = (a * c) + b` but only rounds once at the end
 """
 
-from nmutil.pipemodbase import PipeModBase
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.basedata import FPBaseData
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from ieee754.fpcommon.getop import FPPipeContext
 from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
 from ieee754.fpfma.util import expanded_exponent_shape, \
-    expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
+    expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape, \
+    EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \
+    product_mantissa_shape
 
 
 class FPFMAInputData(FPBaseData):
@@ -52,13 +54,13 @@ class FPFMASpecialCasesDeNormOutData:
         self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
         """mantissa of a input -- un-normalized and with implicit bit added"""
 
-        self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf))
+        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
         """mantissa of b input
 
         shifted to appropriate location for add and with implicit bit added
         """
 
-        self.c_mantissa = Signal(expanded_mantissa_shape(fpf))
+        self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))
         """mantissa of c input -- un-normalized and with implicit bit added"""
 
         self.do_sub = Signal()
@@ -123,15 +125,30 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
         out = self.o
 
         a_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
+        m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a))
         b_exponent_in = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
+        m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b))
         c_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
+        m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c))
+        b_exponent = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS)
         prod_exponent = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
+
+        # number of bits that the product of two normalized signals needs to
+        # be shifted left to be normalized, e.g. the product of 2 8-bit
+        # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted
+        # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again:
+        # `0x4000 << 1 == 0x8000`
+        PROD_STAY_NORM_SHIFT = 1
+
+        extra_prod_exponent = (expanded_mantissa_shape(fpf).width
+                               - product_mantissa_shape(fpf).width
+                               + PROD_STAY_NORM_SHIFT
+                               - EXPANDED_MANTISSA_EXTRA_LSBS)
+        m.d.comb += prod_exponent.eq(a_exponent + c_exponent
+                                     + extra_prod_exponent)
         prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
-        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
+        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent)
         b_mantissa_in = Signal(fpf.fraction_width + 1)
         m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))
         p_sign = Signal()
@@ -150,30 +167,37 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
             ]
         with m.Else():
             m.d.comb += [
-                exponent.eq(b_exponent_in),
+                exponent.eq(b_exponent),
                 b_shift.eq(0),
             ]
 
-        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
+        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(
+            out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS,
+            s_max=expanded_exponent_shape(fpf).width - 1)
         m.d.comb += [
-            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
-                                             - b_mantissa_in.width)),
+            rshiftm.inp.eq(0),
+            rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in),
             rshiftm.diff.eq(b_shift),
         ]
 
+        keep = {"keep": True}
+
         # handle special cases
         with m.If(fpf.is_nan(inp.a)):
             m.d.comb += [
+                Signal(name="case_nan_a", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_nan(inp.b)):
             m.d.comb += [
+                Signal(name="case_nan_b", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_nan(inp.c)):
             m.d.comb += [
+                Signal(name="case_nan_c", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
                 out.do_bypass.eq(True),
             ]
@@ -181,37 +205,50 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
                     | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
             # infinity * 0
             m.d.comb += [
+                Signal(name="case_inf_times_zero", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.quiet_nan()),
                 out.do_bypass.eq(True),
             ]
         with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
-                    & fpf.is_inf(inp.b) & p_sign != b_sign):
+                    & fpf.is_inf(inp.b) & (p_sign != b_sign)):
             # inf - inf
             m.d.comb += [
+                Signal(name="case_inf_minus_inf", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.quiet_nan()),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
             # inf + x
             m.d.comb += [
+                Signal(name="case_inf_plus_x", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.inf(p_sign)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif(fpf.is_inf(inp.b)):
             # x + inf
             m.d.comb += [
+                Signal(name="case_x_plus_inf", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.inf(b_sign)),
                 out.do_bypass.eq(True),
             ]
         with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
-                    & fpf.is_zero(inp.b) & p_sign == b_sign):
+                    & fpf.is_zero(inp.b) & (p_sign == b_sign)):
             # zero + zero
             m.d.comb += [
+                Signal(name="case_zero_plus_zero", attrs=keep).eq(True),
                 out.bypassed_z.eq(fpf.zero(p_sign)),
                 out.do_bypass.eq(True),
             ]
-            # zero - zero handled by FPFMAMainStage
+        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
+                    & ~fpf.is_zero(inp.b)):
+            # zero + x
+            m.d.comb += [
+                Signal(name="case_zero_plus_x", attrs=keep).eq(True),
+                out.bypassed_z.eq(inp.b),
+                out.do_bypass.eq(True),
+            ]
         with m.Else():
+            # zero - zero handled by FPFMAMainStage
             m.d.comb += [
                 out.bypassed_z.eq(0),
                 out.do_bypass.eq(False),
@@ -229,3 +266,13 @@ class FPFMASpecialCasesDeNorm(PipeModBase):
         ]
 
         return m
+
+
+class FPFMASpecialCasesDeNormStage(PipeModBaseChain):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+
+    def get_chain(self):
+        """ gets chain of modules
+        """
+        return [FPFMASpecialCasesDeNorm(self.pspec)]
diff --git a/src/ieee754/fpfma/test/__init__.py b/src/ieee754/fpfma/test/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/fpfma/test/test_fma_formal.py b/src/ieee754/fpfma/test/test_fma_formal.py
new file mode 100644 (file)
index 0000000..7cea7b8
--- /dev/null
@@ -0,0 +1,559 @@
+import unittest
+from nmutil.formaltest import FHDLTestCase
+from ieee754.fpfma.pipeline import FPFMABasePipe
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux
+from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
+    SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \
+    SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE, SmtBitVec
+from ieee754.fpcommon.fpbase import FPFormat, FPRoundingMode
+from ieee754.pipeline import PipelineSpec
+import os
+
+ENABLE_FMA_F32_FORMAL = os.getenv("ENABLE_FMA_F32_FORMAL") is not None
+
+
+class TestFMAFormal(FHDLTestCase):
+    @unittest.skip("not finished implementing")  # FIXME: remove skip
+    def tst_fma_formal(self, sort, rm, negate_addend, negate_product):
+        assert isinstance(sort, SmtSortFloatingPoint)
+        assert isinstance(rm, FPRoundingMode)
+        assert isinstance(negate_addend, bool)
+        assert isinstance(negate_product, bool)
+        width = sort.width
+        pspec = PipelineSpec(width, id_width=4, n_ops=3)
+        pspec.fpformat = FPFormat(e_width=sort.eb,
+                                  m_width=sort.mantissa_field_width)
+        dut = FPFMABasePipe(pspec)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.n.i_ready.eq(True)
+        m.d.comb += dut.p.i_valid.eq(Initial())
+        m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0))
+        out = Signal(width)
+        out_full = Signal(reset=False)
+        with m.If(dut.n.trigger):
+            # check we only got output for one cycle
+            m.d.comb += Assert(~out_full)
+            m.d.sync += out.eq(dut.n.o_data.z)
+            m.d.sync += out_full.eq(True)
+        a = Signal(width)
+        b = Signal(width)
+        c = Signal(width)
+        with m.If(Initial() | True):  # FIXME: remove | True
+            m.d.comb += [
+                dut.p.i_data.a.eq(a),
+                dut.p.i_data.b.eq(b),
+                dut.p.i_data.c.eq(c),
+                dut.p.i_data.negate_addend.eq(negate_addend),
+                dut.p.i_data.negate_product.eq(negate_product),
+            ]
+
+        def smt_op(a_fp, b_fp, c_fp, rm):
+            assert isinstance(a_fp, SmtFloatingPoint)
+            assert isinstance(b_fp, SmtFloatingPoint)
+            assert isinstance(c_fp, SmtFloatingPoint)
+            assert isinstance(rm, SmtRoundingMode)
+            if negate_addend:
+                b_fp = -b_fp
+            if negate_product:
+                a_fp = -a_fp
+            return a_fp.fma(c_fp, b_fp, rm=rm)
+        a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
+        b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
+        c_fp = SmtFloatingPoint.from_bits(c, sort=sort)
+        out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
+        if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE,
+                  FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
+            rounded_up = Signal(width)
+            m.d.comb += rounded_up.eq(AnyConst(width))
+            rounded_up_fp = smt_op(a_fp, b_fp, c_fp, rm=ROUND_TOWARD_POSITIVE)
+            rounded_down_fp = smt_op(a_fp, b_fp, c_fp,
+                                     rm=ROUND_TOWARD_NEGATIVE)
+            m.d.comb += Assume(SmtFloatingPoint.from_bits(
+                rounded_up, sort=sort).same(rounded_up_fp).as_value())
+            use_rounded_up = SmtBool.make(rounded_up[0])
+            if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE:
+                is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero()
+                use_rounded_up |= is_zero
+            expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
+        else:
+            smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
+            expected_fp = smt_op(a_fp, b_fp, c_fp, rm=smt_rm)
+        expected = Signal(width)
+        m.d.comb += expected.eq(AnyConst(width))
+        quiet_bit = 1 << (sort.mantissa_field_width - 1)
+        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
+        with m.If(expected_fp.is_nan().as_value()):
+            with m.If(a_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (a | quiet_bit))
+            with m.Elif(b_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (b | quiet_bit))
+            with m.Elif(c_fp.is_nan().as_value()):
+                m.d.comb += Assume(expected == (c | quiet_bit))
+            with m.Else():
+                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
+        with m.Else():
+            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
+                               .same(expected_fp).as_value())
+        m.d.comb += a.eq(AnyConst(width))
+        m.d.comb += b.eq(AnyConst(width))
+        m.d.comb += c.eq(AnyConst(width))
+        with m.If(out_full):
+            m.d.comb += Assert(out_fp.same(expected_fp).as_value())
+            m.d.comb += Assert(out == expected)
+
+        def fp_from_int(v):
+            return SmtFloatingPoint.from_signed_bv(
+                SmtBitVec.make(v, width=128),
+                rm=ROUND_TOWARD_POSITIVE, sort=sort)
+
+        # FIXME: remove:
+        if False:
+            m.d.comb += Assume(a == 0x05C1)
+            m.d.comb += Assume(b == 0x877F)
+            m.d.comb += Assume(c == 0x7437)
+            with m.If(out_full):
+                m.d.comb += Assert(out == 0x0000)
+                m.d.comb += Assert(out == 0x0001)
+
+        self.assertFormal(m, depth=5, solver="bitwuzla")
+
+    # FIXME: check exception flags
+
+    def test_fmadd_f16_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=True)
+
+    def test_fmadd_f16_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=False)
+
+    def test_fmsub_f16_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=False)
+
+    def test_fnmadd_f16_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=True)
+
+    def test_fnmsub_f16_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat16(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmadd_f32_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fmsub_f32_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmadd_f32_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skipUnless(ENABLE_FMA_F32_FORMAL,
+                         "ENABLE_FMA_F32_FORMAL not in environ")
+    def test_fnmsub_f32_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat32(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rne_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNE,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rtz_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTZ,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rtp_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTP,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rtn_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTN,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rna_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RNA,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rtop_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTOP,
+                            negate_addend=False, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fmadd_f64_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fmsub_f64_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=False)
+
+    @unittest.skip("too slow")
+    def test_fnmadd_f64_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON,
+                            negate_addend=True, negate_product=True)
+
+    @unittest.skip("too slow")
+    def test_fnmsub_f64_rton_formal(self):
+        self.tst_fma_formal(sort=SmtSortFloat64(), rm=FPRoundingMode.RTON,
+                            negate_addend=False, negate_product=True)
+
+    def test_all_rounding_modes_covered(self):
+        for width in 16, 32, 64:
+            for rm in FPRoundingMode:
+                rm_s = rm.name.lower()
+                name = f"test_fmadd_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
+                name = f"test_fmsub_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
+                name = f"test_fnmadd_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
+                name = f"test_fnmsub_f{width}_{rm_s}_formal"
+                assert callable(getattr(self, name))
+
+
+if __name__ == '__main__':
+    unittest.main()
index 5372ab8fb412a31459afed5e3bde2b2e56309868..518a5d9e6e4dba4b1c9f75f8558b0d2788fefe8d 100644 (file)
@@ -7,13 +7,48 @@ def expanded_exponent_shape(fpformat):
     return signed(fpformat.e_width + 3)
 
 
-EXPANDED_MANTISSA_EXTRA_LSBS = 3
+EXPANDED_MANTISSA_SPACE_BETWEEN_SUM_PROD = 16  # FIXME: change back to 3
+r""" the number of bits of space between the lsb of a large addend and the msb
+of the product of two small factors to guarantee that the product ends up
+entirely in the sticky bit.
+
+e.g. let's assume the floating point format has
+5 mantissa bits (4 bits in the field + 1 implicit bit):
+
+if `a` and `b` are `0b11111` and `c` is `0b11111 * 2**-50`, and we are
+computing `a * c + b`:
+
+the computed mantissa would be:
+
+```text
+      sticky bit
+         |
+         v
+0b111110001111000001
+  \-b-/   \-product/
+```
+
+(note this isn't the mathematically correct
+answer, but it rounds to the correct floating-point answer and takes
+less hardware)
+"""
+
+# the number of extra LSBs needed by the expanded mantissa to avoid
+# having a tiny addend conflict with the lsb of the product.
+EXPANDED_MANTISSA_EXTRA_LSBS = 16  # FIXME: change back to 2
+
+
+# the number of extra MSBs needed by the expanded mantissa to avoid
+# overflowing. 2 bits -- 1 bit for carry out of addition, 1 bit for sign.
+EXPANDED_MANTISSA_EXTRA_MSBS = 16  # FIXME: change back to 2
 
 
 def expanded_mantissa_shape(fpformat):
     assert isinstance(fpformat, FPFormat)
-    return signed(fpformat.fraction_width * 3 +
-                  2 + EXPANDED_MANTISSA_EXTRA_LSBS)
+    return signed((fpformat.fraction_width + 1) * 3
+                  + EXPANDED_MANTISSA_EXTRA_MSBS
+                  + EXPANDED_MANTISSA_SPACE_BETWEEN_SUM_PROD
+                  + EXPANDED_MANTISSA_EXTRA_LSBS)
 
 
 def multiplicand_mantissa_shape(fpformat):