start implementing fp fused-mul-add pipeline
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 2 Jul 2022 02:59:39 +0000 (19:59 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 2 Jul 2022 02:59:39 +0000 (19:59 -0700)
https://bugs.libre-soc.org/show_bug.cgi?id=877

src/ieee754/fpcommon/fpbase.py
src/ieee754/fpfma/__init__.py [new file with mode: 0644]
src/ieee754/fpfma/main_stage.py [new file with mode: 0644]
src/ieee754/fpfma/norm.py [new file with mode: 0644]
src/ieee754/fpfma/pipeline.py [new file with mode: 0644]
src/ieee754/fpfma/special_cases.py [new file with mode: 0644]
src/ieee754/fpfma/util.py [new file with mode: 0644]

index 417a66341c30bc3f81f13cefaa09266204fe801e..f3e68fe6add8882507a92c63aa3604886cf3726e 100644 (file)
@@ -6,7 +6,8 @@ Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
 """
 
 
-from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array, Value
+from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
+                    Value, Shape)
 from math import log
 from operator import or_
 from functools import reduce
@@ -319,6 +320,17 @@ class FPFormat:
         """
         return x & self.mantissa_mask
 
+    def get_mantissa_value(self, x):
+        """ returns the mantissa of its input number, x, but with the
+        implicit bit, if any, made explicit.
+        """
+        if self.has_int_bit:
+            return self.get_mantissa_field(x)
+        exponent_field = self.get_exponent_field(x)
+        mantissa_field = self.get_mantissa_field(x)
+        implicit_bit = exponent_field == self.exponent_denormal_zero
+        return (implicit_bit << self.fraction_width) | mantissa_field
+
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
@@ -351,6 +363,23 @@ class FPFormat:
             (self.get_mantissa_field(x) != 0) & \
             (self.get_mantissa_field(x) & highbit != 0)
 
+    def to_quiet_nan(self, x):
+        """ converts `x` to a quiet NaN """
+        highbit = 1 << (self.m_width - 1)
+        return x | highbit | self.exponent_mask
+
+    def quiet_nan(self, sign=0):
+        """ return the default quiet NaN with sign `sign` """
+        return self.to_quiet_nan(self.zero(sign))
+
+    def zero(self, sign=0):
+        """ return zero with sign `sign` """
+        return (sign != 0) << (self.e_width + self.m_width)
+
+    def inf(self, sign=0):
+        """ return infinity with sign `sign` """
+        return self.zero(sign) | self.exponent_mask
+
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
@@ -369,6 +398,11 @@ class FPFormat:
         """ Get a mantissa mask based on the mantissa width """
         return (1 << self.m_width) - 1
 
+    @property
+    def exponent_mask(self):
+        """ Get an exponent mask """
+        return self.exponent_inf_nan << self.m_width
+
     @property
     def exponent_inf_nan(self):
         """ Get the value of the exponent field designating infinity/NaN. """
@@ -775,7 +809,7 @@ class MultiShiftRMerge(Elaboratable):
     def __init__(self, width, s_max=None):
         if s_max is None:
             s_max = int(log(width) / log(2))
-        self.smax = s_max
+        self.smax = Shape.cast(s_max)
         self.m = Signal(width, reset_less=True)
         self.inp = Signal(width, reset_less=True)
         self.diff = Signal(s_max, reset_less=True)
@@ -789,8 +823,8 @@ class MultiShiftRMerge(Elaboratable):
         smask = Signal(self.width, reset_less=True)
         stickybit = Signal(reset_less=True)
         # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
-        maxslen = Signal(self.smax[0], reset_less=True)
-        maxsleni = Signal(self.smax[0], reset_less=True)
+        maxslen = Signal(self.smax.width, reset_less=True)
+        maxsleni = Signal(self.smax.width, reset_less=True)
 
         sm = MultiShift(self.width-1)
         m0s = Const(0, self.width-1)
diff --git a/src/ieee754/fpfma/__init__.py b/src/ieee754/fpfma/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/fpfma/main_stage.py b/src/ieee754/fpfma/main_stage.py
new file mode 100644 (file)
index 0000000..1ab2b2b
--- /dev/null
@@ -0,0 +1,99 @@
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.pipemodbase import PipeModBase
+from ieee754.fpcommon.fpbase import FPRoundingMode
+from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ast import Signal, signed, unsigned, Mux
+from ieee754.fpfma.util import expanded_exponent_shape, \
+    expanded_mantissa_shape, get_fpformat
+from ieee754.fpcommon.getop import FPPipeContext
+
+
+class FPFMAPostCalcData:
+    def __init__(self, pspec):
+        fpf = get_fpformat(pspec)
+
+        self.sign = Signal()
+        """sign"""
+
+        self.exponent = Signal(expanded_exponent_shape(fpf))
+        """exponent -- unbiased"""
+
+        self.mantissa = Signal(expanded_mantissa_shape(fpf))
+        """unnormalized mantissa"""
+
+        self.bypassed_z = Signal(fpf.width)
+        """final output value of the fma when `do_bypass` is set"""
+
+        self.do_bypass = Signal()
+        """set if `bypassed_z` is the final output value of the fma"""
+
+        self.ctx = FPPipeContext(pspec)
+        """pipe context"""
+
+        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
+        """rounding mode"""
+
+
+class FPFMAMainStage(PipeModBase):
+    def __init__(self, pspec):
+        super().__init__(pspec, "main")
+
+    def ispec(self):
+        return FPFMASpecialCasesDeNormOutData(self.pspec)
+
+    def ospec(self):
+        return FPFMAPostCalcData(self.pspec)
+
+    def elaborate(self, platform):
+        m = Module()
+        fpf = get_fpformat(self.pspec)
+        assert fpf.has_sign
+        inp = self.i
+        out = self.o
+
+        product_v = inp.a_mantissa * inp.c_mantissa
+        product = Signal(product_v.shape())
+        m.d.comb += product.eq(product_v)
+        negate_b_s = Signal(signed(1))
+        negate_b_u = Signal(unsigned(1))
+        m.d.comb += [
+            negate_b_s.eq(inp.do_sub),
+            negate_b_u.eq(inp.do_sub),
+        ]
+        sum_v = product_v + (inp.b_mantissa ^ negate_b_s) + negate_b_u
+        sum = Signal(sum_v.shape())
+        m.d.comb += sum.eq(sum_v)
+
+        sum_neg = Signal()
+        sum_zero = Signal()
+        m.d.comb += [
+            sum_neg.eq(sum < 0),  # just sign bit
+            sum_zero.eq(sum == 0),
+        ]
+
+        zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)
+
+        with m.If(sum_zero & ~inp.do_bypass):
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.zero(zero_sign_array[inp.rm])),
+                out.do_bypass.eq(True),
+            ]
+        with m.Else():
+            m.d.comb += [
+                out.bypassed_z.eq(inp.bypassed_z),
+                out.do_bypass.eq(inp.do_bypass),
+            ]
+
+        m.d.comb += [
+            out.sign.eq(sum_neg ^ inp.sign),
+            out.exponent.eq(inp.exponent),
+            out.mantissa.eq(Mux(sum_neg, -sum, sum)),
+            out.ctx.eq(inp.ctx),
+            out.rm.eq(inp.rm),
+        ]
+        return m
diff --git a/src/ieee754/fpfma/norm.py b/src/ieee754/fpfma/norm.py
new file mode 100644 (file)
index 0000000..0c16f45
--- /dev/null
@@ -0,0 +1,51 @@
+from nmutil.pipemodbase import PipeModBaseChain, PipeModBase
+from ieee754.fpcommon.postnormalise import FPNorm1Data
+from ieee754.fpcommon.roundz import FPRoundMod
+from ieee754.fpcommon.corrections import FPCorrectionsMod
+from ieee754.fpcommon.pack import FPPackMod
+from ieee754.fpfma.main_stage import FPFMAPostCalcData
+from nmigen.hdl.dsl import Module
+
+from ieee754.fpfma.util import get_fpformat
+
+
+class FPFMANorm(PipeModBase):
+    def __init__(self, pspec):
+        super().__init__(pspec, "norm")
+
+    def ispec(self):
+        return FPFMAPostCalcData(self.pspec)
+
+    def ospec(self):
+        return FPNorm1Data(self.pspec)
+
+    def elaborate(self, platform):
+        m = Module()
+        fpf = get_fpformat(self.pspec)
+        assert fpf.has_sign
+        inp = self.i
+        out = self.o
+        raise NotImplementedError  # FIXME: finish
+        m.d.comb += [
+            out.roundz.eq(),
+            out.z.eq(),
+            out.out_do_z.eq(),
+            out.oz.eq(),
+            out.ctx.eq(),
+            out.rm.eq(),
+        ]
+        return m
+
+
+class FPFMANormToPack(PipeModBaseChain):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+
+    def get_chain(self):
+        """ gets chain of modules
+        """
+        nmod = FPFMANorm(self.pspec)
+        rmod = FPRoundMod(self.pspec)
+        cmod = FPCorrectionsMod(self.pspec)
+        pmod = FPPackMod(self.pspec)
+        return [nmod, rmod, cmod, pmod]
diff --git a/src/ieee754/fpfma/pipeline.py b/src/ieee754/fpfma/pipeline.py
new file mode 100644 (file)
index 0000000..f0b928d
--- /dev/null
@@ -0,0 +1,26 @@
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.singlepipe import ControlBase
+from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNorm
+from ieee754.fpfma.main_stage import FPFMAMainStage
+from ieee754.fpfma.norm import FPFMANormToPack
+
+
+class FPFMABasePipe(ControlBase):
+    def __init__(self, pspec):
+        super().__init__()
+        self.sc_denorm = FPFMASpecialCasesDeNorm(pspec)
+        self.main = FPFMAMainStage(pspec)
+        self.normpack = FPFMANormToPack(pspec)
+        self._eqs = self.connect([self.sc_denorm, self.main, self.normpack])
+
+    def elaborate(self, platform):
+        m = super().elaborate(platform)
+        m.submodules.sc_denorm = self.sc_denorm
+        m.submodules.main = self.main
+        m.submodules.normpack = self.normpack
+        m.d.comb += self._eqs
+        return m
diff --git a/src/ieee754/fpfma/special_cases.py b/src/ieee754/fpfma/special_cases.py
new file mode 100644 (file)
index 0000000..95d3026
--- /dev/null
@@ -0,0 +1,231 @@
+""" floating-point fused-multiply-add
+
+computes `z = (a * c) + b` but only rounds once at the end
+"""
+
+from nmutil.pipemodbase import PipeModBase
+from ieee754.fpcommon.basedata import FPBaseData
+from nmigen.hdl.ast import Signal
+from nmigen.hdl.dsl import Module
+from ieee754.fpcommon.getop import FPPipeContext
+from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
+from ieee754.fpfma.util import expanded_exponent_shape, \
+    expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
+
+
+class FPFMAInputData(FPBaseData):
+    def __init__(self, pspec):
+        assert pspec.n_ops == 3
+        super().__init__(pspec)
+
+        self.negate_addend = Signal()
+        """if the addend should be negated"""
+
+        self.negate_product = Signal()
+        """if the product should be negated"""
+
+    def eq(self, i):
+        ret = super().eq(i)
+        ret.append(self.negate_addend.eq(i.negate_addend))
+        ret.append(self.negate_product.eq(i.negate_product))
+        return ret
+
+    def __iter__(self):
+        yield from super().__iter__()
+        yield self.negate_addend
+        yield self.negate_product
+
+    def ports(self):
+        return list(self)
+
+
+class FPFMASpecialCasesDeNormOutData:
+    def __init__(self, pspec):
+        fpf = get_fpformat(pspec)
+
+        self.sign = Signal()
+        """sign"""
+
+        self.exponent = Signal(expanded_exponent_shape(fpf))
+        """exponent of intermediate -- unbiased"""
+
+        self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
+        """mantissa of a input -- un-normalized and with implicit bit added"""
+
+        self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf))
+        """mantissa of b input
+
+        shifted to appropriate location for add and with implicit bit added
+        """
+
+        self.c_mantissa = Signal(expanded_mantissa_shape(fpf))
+        """mantissa of c input -- un-normalized and with implicit bit added"""
+
+        self.do_sub = Signal()
+        """true if `b_mantissa` should be subtracted from
+        `a_mantissa * c_mantissa` rather than added
+        """
+
+        self.bypassed_z = Signal(fpf.width)
+        """final output value of the fma when `do_bypass` is set"""
+
+        self.do_bypass = Signal()
+        """set if `bypassed_z` is the final output value of the fma"""
+
+        self.ctx = FPPipeContext(pspec)
+        """pipe context"""
+
+        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
+        """rounding mode"""
+
+    def __iter__(self):
+        yield self.sign
+        yield self.exponent
+        yield self.a_mantissa
+        yield self.b_mantissa
+        yield self.c_mantissa
+        yield self.do_sub
+        yield self.bypassed_z
+        yield self.do_bypass
+        yield from self.ctx
+        yield self.rm
+
+    def eq(self, i):
+        return [
+            self.sign.eq(i.sign),
+            self.exponent.eq(i.exponent),
+            self.a_mantissa.eq(i.a_mantissa),
+            self.b_mantissa.eq(i.b_mantissa),
+            self.c_mantissa.eq(i.c_mantissa),
+            self.do_sub.eq(i.do_sub),
+            self.bypassed_z.eq(i.bypassed_z),
+            self.do_bypass.eq(i.do_bypass),
+            self.ctx.eq(i.ctx),
+            self.rm.eq(i.rm),
+        ]
+
+
+class FPFMASpecialCasesDeNorm(PipeModBase):
+    def __init__(self, pspec):
+        super().__init__(pspec, "sc_denorm")
+
+    def ispec(self):
+        return FPFMAInputData(self.pspec)
+
+    def ospec(self):
+        return FPFMASpecialCasesDeNormOutData(self.pspec)
+
+    def elaborate(self, platform):
+        m = Module()
+        fpf = get_fpformat(self.pspec)
+        assert fpf.has_sign
+        inp = self.i
+        out = self.o
+
+        a_exponent = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
+        b_exponent_in = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
+        c_exponent = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
+        prod_exponent = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
+        prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
+        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
+        b_mantissa_in = Signal(fpf.fraction_width + 1)
+        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))
+        p_sign = Signal()
+        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
+                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
+        b_sign = Signal()
+        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)
+
+        exponent = Signal(expanded_exponent_shape(fpf))
+        b_shift = Signal(expanded_exponent_shape(fpf))
+        # use >= since that's just checking the sign bit
+        with m.If(prod_exp_minus_b_exp >= 0):
+            m.d.comb += [
+                exponent.eq(prod_exponent),
+                b_shift.eq(prod_exp_minus_b_exp),
+            ]
+        with m.Else():
+            m.d.comb += [
+                exponent.eq(b_exponent_in),
+                b_shift.eq(0),
+            ]
+
+        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
+        m.d.comb += [
+            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
+                                             - b_mantissa_in.width)),
+            rshiftm.diff.eq(b_shift),
+        ]
+
+        # handle special cases
+        with m.If(fpf.is_nan(inp.a)):
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif(fpf.is_nan(inp.b)):
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif(fpf.is_nan(inp.c)):
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c))
+                    | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
+            # infinity * 0
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.quiet_nan()),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
+                    & fpf.is_inf(inp.b) & p_sign != b_sign):
+            # inf - inf
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.quiet_nan()),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
+            # inf + x
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.inf(p_sign)),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif(fpf.is_inf(inp.b)):
+            # x + inf
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.inf(b_sign)),
+                out.do_bypass.eq(True),
+            ]
+        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
+                    & fpf.is_zero(inp.b) & p_sign == b_sign):
+            # zero + zero
+            m.d.comb += [
+                out.bypassed_z.eq(fpf.zero(p_sign)),
+                out.do_bypass.eq(True),
+            ]
+            # zero - zero handled by FPFMAMainStage
+        with m.Else():
+            m.d.comb += [
+                out.bypassed_z.eq(0),
+                out.do_bypass.eq(False),
+            ]
+
+        m.d.comb += [
+            out.sign.eq(p_sign),
+            out.exponent.eq(exponent),
+            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
+            out.b_mantissa.eq(rshiftm.m),
+            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
+            out.do_sub.eq(p_sign != b_sign),
+            out.ctx.eq(inp.ctx),
+            out.rm.eq(inp.rm),
+        ]
+
+        return m
diff --git a/src/ieee754/fpfma/util.py b/src/ieee754/fpfma/util.py
new file mode 100644 (file)
index 0000000..5372ab8
--- /dev/null
@@ -0,0 +1,38 @@
+from ieee754.fpcommon.fpbase import FPFormat
+from nmigen.hdl.ast import signed, unsigned
+
+
+def expanded_exponent_shape(fpformat):
+    assert isinstance(fpformat, FPFormat)
+    return signed(fpformat.e_width + 3)
+
+
+EXPANDED_MANTISSA_EXTRA_LSBS = 3
+
+
+def expanded_mantissa_shape(fpformat):
+    assert isinstance(fpformat, FPFormat)
+    return signed(fpformat.fraction_width * 3 +
+                  2 + EXPANDED_MANTISSA_EXTRA_LSBS)
+
+
+def multiplicand_mantissa_shape(fpformat):
+    assert isinstance(fpformat, FPFormat)
+    return unsigned(fpformat.fraction_width + 1)
+
+
+def product_mantissa_shape(fpformat):
+    assert isinstance(fpformat, FPFormat)
+    return unsigned(multiplicand_mantissa_shape(fpformat).width * 2)
+
+
+def get_fpformat(pspec):
+    width = pspec.width
+    assert isinstance(width, int)
+    fpformat = getattr(pspec, "fpformat", None)
+    if fpformat is None:
+        fpformat = FPFormat.standard(width)
+    else:
+        assert isinstance(fpformat, FPFormat)
+    assert width == fpformat.width
+    return fpformat