"""Floating-point fused multiply-add (FMA).

Computes ``z = (a * c) + b``, rounding only once, at the end.
"""
6 from nmutil
.pipemodbase
import PipeModBase
7 from ieee754
.fpcommon
.fpbase
import FPRoundingMode
8 from ieee754
.fpfma
.special_cases
import FPFMASpecialCasesDeNormOutData
9 from nmigen
.hdl
.dsl
import Module
10 from nmigen
.hdl
.ast
import Signal
, signed
, unsigned
, Mux
11 from ieee754
.fpfma
.util
import expanded_exponent_shape
, \
12 expanded_mantissa_shape
, get_fpformat
13 from ieee754
.fpcommon
.getop
import FPPipeContext
class FPFMAPostCalcData:
    """Data record produced by ``FPFMAMainStage`` (presumably its output spec).

    Holds the sign, the unbiased exponent and the not-yet-normalized
    mantissa of ``(a * c) + b``. It also holds the bypass value and flag,
    the pipeline context, and the rounding mode.

    :param pspec: pipeline spec; the float format is derived from it via
        ``get_fpformat`` and it is forwarded to ``FPPipeContext``.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        # The main stage drives ``out.sign`` (sign of the sum XOR the input
        # sign), so the output record must define this field.
        self.sign = Signal()
        """sign bit of the result"""

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """exponent -- unbiased"""

        self.mantissa = Signal(expanded_mantissa_shape(fpf))
        """unnormalized mantissa"""

        self.bypassed_z = Signal(fpf.width)
        """final output value of the fma when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set if `bypassed_z` is the final output value of the fma"""

        self.ctx = FPPipeContext(pspec)
        """pipeline context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""
42 class FPFMAMainStage(PipeModBase
):
43 def __init__(self
, pspec
):
44 super().__init
__(pspec
, "main")
47 return FPFMASpecialCasesDeNormOutData(self
.pspec
)
50 return FPFMAPostCalcData(self
.pspec
)
52 def elaborate(self
, platform
):
54 fpf
= get_fpformat(self
.pspec
)
59 product_v
= inp
.a_mantissa
* inp
.c_mantissa
60 product
= Signal(product_v
.shape())
61 m
.d
.comb
+= product
.eq(product_v
)
62 negate_b_s
= Signal(signed(1))
63 negate_b_u
= Signal(unsigned(1))
65 negate_b_s
.eq(inp
.do_sub
),
66 negate_b_u
.eq(inp
.do_sub
),
68 sum_v
= product_v
+ (inp
.b_mantissa ^ negate_b_s
) + negate_b_u
69 sum = Signal(sum_v
.shape())
70 m
.d
.comb
+= sum.eq(sum_v
)
75 sum_neg
.eq(sum < 0), # just sign bit
76 sum_zero
.eq(sum == 0),
79 zero_sign_array
= FPRoundingMode
.make_array(FPRoundingMode
.zero_sign
)
81 with m
.If(sum_zero
& ~inp
.do_bypass
):
83 out
.bypassed_z
.eq(fpf
.zero(zero_sign_array
[inp
.rm
])),
84 out
.do_bypass
.eq(True),
88 out
.bypassed_z
.eq(inp
.bypassed_z
),
89 out
.do_bypass
.eq(inp
.do_bypass
),
93 out
.sign
.eq(sum_neg ^ inp
.sign
),
94 out
.exponent
.eq(inp
.exponent
),
95 out
.mantissa
.eq(Mux(sum_neg
, -sum, sum)),