From aa614b6468513fff3c47d9afba9a1e61107df738 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 1 Jul 2022 19:59:39 -0700 Subject: [PATCH] start implementing fp fused-mul-add pipeline https://bugs.libre-soc.org/show_bug.cgi?id=877 --- src/ieee754/fpcommon/fpbase.py | 42 +++++- src/ieee754/fpfma/__init__.py | 0 src/ieee754/fpfma/main_stage.py | 99 +++++++++++++ src/ieee754/fpfma/norm.py | 51 +++++++ src/ieee754/fpfma/pipeline.py | 26 ++++ src/ieee754/fpfma/special_cases.py | 231 +++++++++++++++++++++++++++++ src/ieee754/fpfma/util.py | 38 +++++ 7 files changed, 483 insertions(+), 4 deletions(-) create mode 100644 src/ieee754/fpfma/__init__.py create mode 100644 src/ieee754/fpfma/main_stage.py create mode 100644 src/ieee754/fpfma/norm.py create mode 100644 src/ieee754/fpfma/pipeline.py create mode 100644 src/ieee754/fpfma/special_cases.py create mode 100644 src/ieee754/fpfma/util.py diff --git a/src/ieee754/fpcommon/fpbase.py b/src/ieee754/fpcommon/fpbase.py index 417a6634..f3e68fe6 100644 --- a/src/ieee754/fpcommon/fpbase.py +++ b/src/ieee754/fpcommon/fpbase.py @@ -6,7 +6,8 @@ Copyright (C) 2019,2022 Jacob Lifshay """ -from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array, Value +from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array, + Value, Shape) from math import log from operator import or_ from functools import reduce @@ -319,6 +320,17 @@ class FPFormat: """ return x & self.mantissa_mask + def get_mantissa_value(self, x): + """ returns the mantissa of its input number, x, but with the + implicit bit, if any, made explicit. + """ + if self.has_int_bit: + return self.get_mantissa_field(x) + exponent_field = self.get_exponent_field(x) + mantissa_field = self.get_mantissa_field(x) + implicit_bit = exponent_field == self.exponent_denormal_zero + return (implicit_bit << self.fraction_width) | mantissa_field + def is_zero(self, x): """ returns true if x is +/- zero """ @@ -351,6 +363,23 @@ class FPFormat: (self.get_mantissa_field(x) != 0) & \ (self.get_mantissa_field(x) & highbit != 0) + def to_quiet_nan(self, x): + """ converts `x` to a quiet NaN """ + highbit = 1 << (self.m_width - 1) + return x | highbit | self.exponent_mask + + def quiet_nan(self, sign=0): + """ return the default quiet NaN with sign `sign` """ + return self.to_quiet_nan(self.zero(sign)) + + def zero(self, sign=0): + """ return zero with sign `sign` """ + return (sign != 0) << (self.e_width + self.m_width) + + def inf(self, sign=0): + """ return infinity with sign `sign` """ + return self.zero(sign) | self.exponent_mask + def is_nan_signaling(self, x): """ returns true if x is a signalling nan """ @@ -369,6 +398,11 @@ class FPFormat: """ Get a mantissa mask based on the mantissa width """ return (1 << self.m_width) - 1 + @property + def exponent_mask(self): + """ Get an exponent mask """ + return self.exponent_inf_nan << self.m_width + @property def exponent_inf_nan(self): """ Get the value of the exponent field designating infinity/NaN. """ @@ -775,7 +809,7 @@ class MultiShiftRMerge(Elaboratable): def __init__(self, width, s_max=None): if s_max is None: s_max = int(log(width) / log(2)) - self.smax = s_max + self.smax = Shape.cast(s_max) self.m = Signal(width, reset_less=True) self.inp = Signal(width, reset_less=True) self.diff = Signal(s_max, reset_less=True) @@ -789,8 +823,8 @@ class MultiShiftRMerge(Elaboratable): smask = Signal(self.width, reset_less=True) stickybit = Signal(reset_less=True) # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302 - maxslen = Signal(self.smax[0], reset_less=True) - maxsleni = Signal(self.smax[0], reset_less=True) + maxslen = Signal(self.smax.width, reset_less=True) + maxsleni = Signal(self.smax.width, reset_less=True) sm = MultiShift(self.width-1) m0s = Const(0, self.width-1) diff --git a/src/ieee754/fpfma/__init__.py b/src/ieee754/fpfma/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ieee754/fpfma/main_stage.py b/src/ieee754/fpfma/main_stage.py new file mode 100644 index 00000000..1ab2b2b8 --- /dev/null +++ b/src/ieee754/fpfma/main_stage.py @@ -0,0 +1,99 @@ +""" floating-point fused-multiply-add + +computes `z = (a * c) + b` but only rounds once at the end +""" + +from nmutil.pipemodbase import PipeModBase +from ieee754.fpcommon.fpbase import FPRoundingMode +from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData +from nmigen.hdl.dsl import Module +from nmigen.hdl.ast import Signal, signed, unsigned, Mux +from ieee754.fpfma.util import expanded_exponent_shape, \ + expanded_mantissa_shape, get_fpformat +from ieee754.fpcommon.getop import FPPipeContext + + +class FPFMAPostCalcData: + def __init__(self, pspec): + fpf = get_fpformat(pspec) + + self.sign = Signal() + """sign""" + + self.exponent = Signal(expanded_exponent_shape(fpf)) + """exponent -- unbiased""" + + self.mantissa = Signal(expanded_mantissa_shape(fpf)) + """unnormalized mantissa""" + + self.bypassed_z = Signal(fpf.width) + """final output value of the fma when `do_bypass` is set""" + + self.do_bypass = Signal() + """set if `bypassed_z` is the final output value of the fma""" + + self.ctx = FPPipeContext(pspec) + """pipe context""" + + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + + +class FPFMAMainStage(PipeModBase): + def __init__(self, pspec): + super().__init__(pspec, "main") + + def ispec(self): + return FPFMASpecialCasesDeNormOutData(self.pspec) + + def ospec(self): + return FPFMAPostCalcData(self.pspec) + + def elaborate(self, platform): + m = Module() + fpf = get_fpformat(self.pspec) + assert fpf.has_sign + inp = self.i + out = self.o + + product_v = inp.a_mantissa * inp.c_mantissa + product = Signal(product_v.shape()) + m.d.comb += product.eq(product_v) + negate_b_s = Signal(signed(1)) + negate_b_u = Signal(unsigned(1)) + m.d.comb += [ + negate_b_s.eq(inp.do_sub), + negate_b_u.eq(inp.do_sub), + ] + sum_v = product_v + (inp.b_mantissa ^ negate_b_s) + negate_b_u + sum = Signal(sum_v.shape()) + m.d.comb += sum.eq(sum_v) + + sum_neg = Signal() + sum_zero = Signal() + m.d.comb += [ + sum_neg.eq(sum < 0), # just sign bit + sum_zero.eq(sum == 0), + ] + + zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign) + + with m.If(sum_zero & ~inp.do_bypass): + m.d.comb += [ + out.bypassed_z.eq(fpf.zero(zero_sign_array[inp.rm])), + out.do_bypass.eq(True), + ] + with m.Else(): + m.d.comb += [ + out.bypassed_z.eq(inp.bypassed_z), + out.do_bypass.eq(inp.do_bypass), + ] + + m.d.comb += [ + out.sign.eq(sum_neg ^ inp.sign), + out.exponent.eq(inp.exponent), + out.mantissa.eq(Mux(sum_neg, -sum, sum)), + out.ctx.eq(inp.ctx), + out.rm.eq(inp.rm), + ] + return m diff --git a/src/ieee754/fpfma/norm.py b/src/ieee754/fpfma/norm.py new file mode 100644 index 00000000..0c16f452 --- /dev/null +++ b/src/ieee754/fpfma/norm.py @@ -0,0 +1,51 @@ +from nmutil.pipemodbase import PipeModBaseChain, PipeModBase +from ieee754.fpcommon.postnormalise import FPNorm1Data +from ieee754.fpcommon.roundz import FPRoundMod +from ieee754.fpcommon.corrections import FPCorrectionsMod +from ieee754.fpcommon.pack import FPPackMod +from ieee754.fpfma.main_stage import FPFMAPostCalcData +from nmigen.hdl.dsl import Module + +from ieee754.fpfma.util import get_fpformat + + +class FPFMANorm(PipeModBase): + def __init__(self, pspec): + super().__init__(pspec, "norm") + + def ispec(self): + return FPFMAPostCalcData(self.pspec) + + def ospec(self): + return FPNorm1Data(self.pspec) + + def elaborate(self, platform): + m = Module() + fpf = get_fpformat(self.pspec) + assert fpf.has_sign + inp = self.i + out = self.o + raise NotImplementedError # FIXME: finish + m.d.comb += [ + out.roundz.eq(), + out.z.eq(), + out.out_do_z.eq(), + out.oz.eq(), + out.ctx.eq(), + out.rm.eq(), + ] + return m + + +class FPFMANormToPack(PipeModBaseChain): + def __init__(self, pspec): + super().__init__(pspec) + + def get_chain(self): + """ gets chain of modules + """ + nmod = FPFMANorm(self.pspec) + rmod = FPRoundMod(self.pspec) + cmod = FPCorrectionsMod(self.pspec) + pmod = FPPackMod(self.pspec) + return [nmod, rmod, cmod, pmod] diff --git a/src/ieee754/fpfma/pipeline.py b/src/ieee754/fpfma/pipeline.py new file mode 100644 index 00000000..f0b928d0 --- /dev/null +++ b/src/ieee754/fpfma/pipeline.py @@ -0,0 +1,26 @@ +""" floating-point fused-multiply-add + +computes `z = (a * c) + b` but only rounds once at the end +""" + +from nmutil.singlepipe import ControlBase +from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNorm +from ieee754.fpfma.main_stage import FPFMAMainStage +from ieee754.fpfma.norm import FPFMANormToPack + + +class FPFMABasePipe(ControlBase): + def __init__(self, pspec): + super().__init__() + self.sc_denorm = FPFMASpecialCasesDeNorm(pspec) + self.main = FPFMAMainStage(pspec) + self.normpack = FPFMANormToPack(pspec) + self._eqs = self.connect([self.sc_denorm, self.main, self.normpack]) + + def elaborate(self, platform): + m = super().elaborate(platform) + m.submodules.sc_denorm = self.sc_denorm + m.submodules.main = self.main + m.submodules.normpack = self.normpack + m.d.comb += self._eqs + return m diff --git a/src/ieee754/fpfma/special_cases.py b/src/ieee754/fpfma/special_cases.py new file mode 100644 index 00000000..95d30266 --- /dev/null +++ b/src/ieee754/fpfma/special_cases.py @@ -0,0 +1,231 @@ +""" floating-point fused-multiply-add + +computes `z = (a * c) + b` but only rounds once at the end +""" + +from nmutil.pipemodbase import PipeModBase +from ieee754.fpcommon.basedata import FPBaseData +from nmigen.hdl.ast import Signal +from nmigen.hdl.dsl import Module +from ieee754.fpcommon.getop import FPPipeContext +from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge +from ieee754.fpfma.util import expanded_exponent_shape, \ + expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape + + +class FPFMAInputData(FPBaseData): + def __init__(self, pspec): + assert pspec.n_ops == 3 + super().__init__(pspec) + + self.negate_addend = Signal() + """if the addend should be negated""" + + self.negate_product = Signal() + """if the product should be negated""" + + def eq(self, i): + ret = super().eq(i) + ret.append(self.negate_addend.eq(i.negate_addend)) + ret.append(self.negate_product.eq(i.negate_product)) + return ret + + def __iter__(self): + yield from super().__iter__() + yield self.negate_addend + yield self.negate_product + + def ports(self): + return list(self) + + +class FPFMASpecialCasesDeNormOutData: + def __init__(self, pspec): + fpf = get_fpformat(pspec) + + self.sign = Signal() + """sign""" + + self.exponent = Signal(expanded_exponent_shape(fpf)) + """exponent of intermediate -- unbiased""" + + self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf)) + """mantissa of a input -- un-normalized and with implicit bit added""" + + self.b_mantissa = Signal(multiplicand_mantissa_shape(fpf)) + """mantissa of b input + + shifted to appropriate location for add and with implicit bit added + """ + + self.c_mantissa = Signal(expanded_mantissa_shape(fpf)) + """mantissa of c input -- un-normalized and with implicit bit added""" + + self.do_sub = Signal() + """true if `b_mantissa` should be subtracted from + `a_mantissa * c_mantissa` rather than added + """ + + self.bypassed_z = Signal(fpf.width) + """final output value of the fma when `do_bypass` is set""" + + self.do_bypass = Signal() + """set if `bypassed_z` is the final output value of the fma""" + + self.ctx = FPPipeContext(pspec) + """pipe context""" + + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + + def __iter__(self): + yield self.sign + yield self.exponent + yield self.a_mantissa + yield self.b_mantissa + yield self.c_mantissa + yield self.do_sub + yield self.bypassed_z + yield self.do_bypass + yield from self.ctx + yield self.rm + + def eq(self, i): + return [ + self.sign.eq(i.sign), + self.exponent.eq(i.exponent), + self.a_mantissa.eq(i.a_mantissa), + self.b_mantissa.eq(i.b_mantissa), + self.c_mantissa.eq(i.c_mantissa), + self.do_sub.eq(i.do_sub), + self.bypassed_z.eq(i.bypassed_z), + self.do_bypass.eq(i.do_bypass), + self.ctx.eq(i.ctx), + self.rm.eq(i.rm), + ] + + +class FPFMASpecialCasesDeNorm(PipeModBase): + def __init__(self, pspec): + super().__init__(pspec, "sc_denorm") + + def ispec(self): + return FPFMAInputData(self.pspec) + + def ospec(self): + return FPFMASpecialCasesDeNormOutData(self.pspec) + + def elaborate(self, platform): + m = Module() + fpf = get_fpformat(self.pspec) + assert fpf.has_sign + inp = self.i + out = self.o + + a_exponent = Signal(expanded_exponent_shape(fpf)) + m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a)) + b_exponent_in = Signal(expanded_exponent_shape(fpf)) + m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b)) + c_exponent = Signal(expanded_exponent_shape(fpf)) + m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c)) + prod_exponent = Signal(expanded_exponent_shape(fpf)) + m.d.comb += prod_exponent.eq(a_exponent + c_exponent) + prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf)) + m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in) + b_mantissa_in = Signal(fpf.fraction_width + 1) + m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b)) + p_sign = Signal() + m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^ + fpf.get_sign_field(inp.c) ^ inp.negate_product) + b_sign = Signal() + m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend) + + exponent = Signal(expanded_exponent_shape(fpf)) + b_shift = Signal(expanded_exponent_shape(fpf)) + # use >= since that's just checking the sign bit + with m.If(prod_exp_minus_b_exp >= 0): + m.d.comb += [ + exponent.eq(prod_exponent), + b_shift.eq(prod_exp_minus_b_exp), + ] + with m.Else(): + m.d.comb += [ + exponent.eq(b_exponent_in), + b_shift.eq(0), + ] + + m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width) + m.d.comb += [ + rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width + - b_mantissa_in.width)), + rshiftm.diff.eq(b_shift), + ] + + # handle special cases + with m.If(fpf.is_nan(inp.a)): + m.d.comb += [ + out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)), + out.do_bypass.eq(True), + ] + with m.Elif(fpf.is_nan(inp.b)): + m.d.comb += [ + out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)), + out.do_bypass.eq(True), + ] + with m.Elif(fpf.is_nan(inp.c)): + m.d.comb += [ + out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)), + out.do_bypass.eq(True), + ] + with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c)) + | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))): + # infinity * 0 + m.d.comb += [ + out.bypassed_z.eq(fpf.quiet_nan()), + out.do_bypass.eq(True), + ] + with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c)) + & fpf.is_inf(inp.b) & p_sign != b_sign): + # inf - inf + m.d.comb += [ + out.bypassed_z.eq(fpf.quiet_nan()), + out.do_bypass.eq(True), + ] + with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)): + # inf + x + m.d.comb += [ + out.bypassed_z.eq(fpf.inf(p_sign)), + out.do_bypass.eq(True), + ] + with m.Elif(fpf.is_inf(inp.b)): + # x + inf + m.d.comb += [ + out.bypassed_z.eq(fpf.inf(b_sign)), + out.do_bypass.eq(True), + ] + with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c)) + & fpf.is_zero(inp.b) & p_sign == b_sign): + # zero + zero + m.d.comb += [ + out.bypassed_z.eq(fpf.zero(p_sign)), + out.do_bypass.eq(True), + ] + # zero - zero handled by FPFMAMainStage + with m.Else(): + m.d.comb += [ + out.bypassed_z.eq(0), + out.do_bypass.eq(False), + ] + + m.d.comb += [ + out.sign.eq(p_sign), + out.exponent.eq(exponent), + out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)), + out.b_mantissa.eq(rshiftm.m), + out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)), + out.do_sub.eq(p_sign != b_sign), + out.ctx.eq(inp.ctx), + out.rm.eq(inp.rm), + ] + + return m diff --git a/src/ieee754/fpfma/util.py b/src/ieee754/fpfma/util.py new file mode 100644 index 00000000..5372ab8f --- /dev/null +++ b/src/ieee754/fpfma/util.py @@ -0,0 +1,38 @@ +from ieee754.fpcommon.fpbase import FPFormat +from nmigen.hdl.ast import signed, unsigned + + +def expanded_exponent_shape(fpformat): + assert isinstance(fpformat, FPFormat) + return signed(fpformat.e_width + 3) + + +EXPANDED_MANTISSA_EXTRA_LSBS = 3 + + +def expanded_mantissa_shape(fpformat): + assert isinstance(fpformat, FPFormat) + return signed(fpformat.fraction_width * 3 + + 2 + EXPANDED_MANTISSA_EXTRA_LSBS) + + +def multiplicand_mantissa_shape(fpformat): + assert isinstance(fpformat, FPFormat) + return unsigned(fpformat.fraction_width + 1) + + +def product_mantissa_shape(fpformat): + assert isinstance(fpformat, FPFormat) + return unsigned(multiplicand_mantissa_shape(fpformat).width * 2) + + +def get_fpformat(pspec): + width = pspec.width + assert isinstance(width, int) + fpformat = getattr(pspec, "fpformat", None) + if fpformat is None: + fpformat = FPFormat.standard(width) + else: + assert isinstance(fpformat, FPFormat) + assert width == fpformat.width + return fpformat -- 2.30.2