1ab2b2b8a048a14330c6c0f6230e6b0539cfac65
[ieee754fpu.git] / src / ieee754 / fpfma / main_stage.py
1 """ floating-point fused-multiply-add
2
3 computes `z = (a * c) + b` but only rounds once at the end
4 """
5
6 from nmutil.pipemodbase import PipeModBase
7 from ieee754.fpcommon.fpbase import FPRoundingMode
8 from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
9 from nmigen.hdl.dsl import Module
10 from nmigen.hdl.ast import Signal, signed, unsigned, Mux
11 from ieee754.fpfma.util import expanded_exponent_shape, \
12 expanded_mantissa_shape, get_fpformat
13 from ieee754.fpcommon.getop import FPPipeContext
14
15
class FPFMAPostCalcData:
    """Output record of the FMA main stage.

    Carries the (not yet normalized or rounded) result of the
    multiply-add: sign, unbiased exponent and unnormalized mantissa, plus
    an optional already-final value (`bypassed_z` / `do_bypass`), the pipe
    context and the rounding mode for the following stages.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        self.sign = Signal()
        """sign"""

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """exponent -- unbiased"""

        self.mantissa = Signal(expanded_mantissa_shape(fpf))
        """unnormalized mantissa"""

        self.bypassed_z = Signal(fpf.width)
        """final output value of the fma when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set if `bypassed_z` is the final output value of the fma"""

        self.ctx = FPPipeContext(pspec)
        """pipe context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    def eq(self, i):
        """Return statements that copy every field of `i` into `self`.

        Gives this record the same `.eq()` interface as the other data
        objects it is built from (e.g. `ctx.eq`), so a whole record can be
        assigned at once when stages are chained or registered.

        :param i: another record with the same fields (typically another
            `FPFMAPostCalcData`).
        :returns: a list of assignment statements.  The `ctx` entry is
            whatever `FPPipeContext.eq` returns (possibly a nested list).
            nMigen flattens nested statement lists on `+=`.
        """
        return [
            self.sign.eq(i.sign),
            self.exponent.eq(i.exponent),
            self.mantissa.eq(i.mantissa),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]
40
41
class FPFMAMainStage(PipeModBase):
    """Main FMA stage: multiply the `a`/`c` mantissas and add/sub `b`.

    Consumes `FPFMASpecialCasesDeNormOutData` and produces
    `FPFMAPostCalcData` holding the unnormalized, unrounded result.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")

    def ispec(self):
        """Return a fresh input record for this stage."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def ospec(self):
        """Return a fresh output record for this stage."""
        return FPFMAPostCalcData(self.pspec)

    def elaborate(self, platform):
        """Build the combinatorial multiply-add logic.

        * `product = a_mantissa * c_mantissa`
        * `sum = product + b_mantissa`, or `product - b_mantissa` when
          `do_sub` is set (negation done as invert-then-add-one)
        * output mantissa is `|sum|`; output sign is the input sign,
          flipped when `sum` is negative; exponent, ctx and rm pass through
        * an exactly-zero `sum` is replaced by a bypassed zero whose sign
          comes from `FPRoundingMode.zero_sign` for the current rounding
          mode, unless the input already requested a bypass
        """
        m = Module()
        fpf = get_fpformat(self.pspec)
        assert fpf.has_sign
        src, dst = self.i, self.o

        # a * c.  The named `product` signal only mirrors the expression
        # (presumably for debugging); the sum below uses the expression.
        prod_expr = src.a_mantissa * src.c_mantissa
        prod_sig = Signal(prod_expr.shape(), name="product")
        m.d.comb += prod_sig.eq(prod_expr)

        # The do_sub bit seen two ways.  As signed(1) it sign-extends to
        # all ones, so the XOR inverts b.  As unsigned(1) it supplies the
        # +1 that completes the two's-complement negation.
        neg_mask = Signal(signed(1), name="negate_b_s")
        neg_carry = Signal(unsigned(1), name="negate_b_u")
        m.d.comb += [
            neg_mask.eq(src.do_sub),
            neg_carry.eq(src.do_sub),
        ]
        total_expr = prod_expr + (src.b_mantissa ^ neg_mask) + neg_carry
        total = Signal(total_expr.shape(), name="sum")
        m.d.comb += total.eq(total_expr)

        is_negative = Signal(name="sum_neg")
        is_zero = Signal(name="sum_zero")
        m.d.comb += [
            # `total` is signed, so this only inspects its sign bit
            is_negative.eq(total < 0),
            is_zero.eq(total == 0),
        ]

        zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)

        with m.If(is_zero & ~src.do_bypass):
            # exact-zero result: emit a zero signed per the rounding mode
            m.d.comb += [
                dst.bypassed_z.eq(fpf.zero(zero_sign_array[src.rm])),
                dst.do_bypass.eq(True),
            ]
        with m.Else():
            # forward whatever bypass decision the input carried
            m.d.comb += [
                dst.bypassed_z.eq(src.bypassed_z),
                dst.do_bypass.eq(src.do_bypass),
            ]

        # The mantissa holds the magnitude of the signed sum.  The sum's
        # sign is folded into the output sign bit instead.
        m.d.comb += [
            dst.sign.eq(is_negative ^ src.sign),
            dst.exponent.eq(src.exponent),
            dst.mantissa.eq(Mux(is_negative, -total, total)),
            dst.ctx.eq(src.ctx),
            dst.rm.eq(src.rm),
        ]
        return m