""" floating-point fused-multiply-add

computes `z = (a * c) + b` but rounds only once, at the end
"""

from nmigen.hdl.ast import Signal, signed, unsigned, Mux, Cat
from nmigen.hdl.dsl import Module
from nmutil.pipemodbase import PipeModBase, PipeModBaseChain

from ieee754.fpcommon.fpbase import FPRoundingMode, FPFormat
from ieee754.fpcommon.getop import FPPipeContext
from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
from ieee754.fpfma.util import (
    expanded_exponent_shape, expanded_mantissa_shape,
    EXPANDED_MANTISSA_EXTRA_LSBS)


class FPFMAPostCalcData:
    """Data produced by the FMA main stage, before normalization/rounding.

    Holds the sign, unbiased exponent and unnormalized mantissa of the
    fused sum. When `do_bypass` is set, `bypassed_z` is the final output
    value instead and the other numeric fields are not the result.
    """

    def __init__(self, pspec):
        """Create the output signals.

        :param pspec: pipeline spec; the floating-point format is derived
            from it via `FPFormat.from_pspec`.
        """
        fpf = FPFormat.from_pspec(pspec)

        self.sign = Signal()
        """sign of the result"""

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """exponent -- unbiased"""

        self.mantissa = Signal(expanded_mantissa_shape(fpf))
        """unnormalized mantissa"""

        self.bypassed_z = Signal(fpf.width)
        """final output value of the fma when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set if `bypassed_z` is the final output value of the fma"""

        self.ctx = FPPipeContext(pspec)
        """pipeline context, carried through unchanged"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    def eq(self, i):
        """Return assignment statements copying every field of `i` to self.

        All fields -- including the pipeline context and rounding mode --
        must be copied, otherwise they are lost between pipeline stages.
        """
        return [
            self.sign.eq(i.sign),
            self.exponent.eq(i.exponent),
            self.mantissa.eq(i.mantissa),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]


class FPFMAMain(PipeModBase):
    """FMA main stage: adds the mantissa product and the (possibly negated)
    addend mantissa in a single wide sum, without intermediate rounding.

    The sign of the sum is folded into the output sign so that the output
    mantissa is the sum's magnitude. A sum of exactly zero (when not already
    bypassed) is emitted through the bypass path as a zero whose sign is
    chosen by the rounding mode.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")

    def ispec(self):
        """Input data spec."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def ospec(self):
        """Output data spec: un-normalized, un-rounded sum."""
        return FPFMAPostCalcData(self.pspec)

    def elaborate(self, platform):
        """Build the combinational adder/sign-folding logic.

        :param platform: nmigen platform (unused).
        :returns: the elaborated `Module`.
        """
        m = Module()
        fpf = FPFormat.from_pspec(self.pspec)
        # NOTE(review): assumes PipeModBase exposes the ispec()/ospec()
        # instances as self.i / self.o -- confirm against nmutil.
        inp = self.i
        out = self.o

        product_v = inp.a_mantissa * inp.c_mantissa
        # Named copy of the product: it is only driven here, never read
        # below (presumably kept so it shows up as a signal in traces).
        product = Signal(product_v.shape())
        m.d.comb += product.eq(product_v)

        # When do_sub is set, (b ^ -1) + 1 == -b in two's complement, so the
        # addend is negated inside the same adder as the product.
        negate_b_s = Signal(signed(1))    # -1 (all ones) when do_sub
        negate_b_u = Signal(unsigned(1))  # +1 when do_sub
        m.d.comb += [
            negate_b_s.eq(inp.do_sub),
            negate_b_u.eq(inp.do_sub),
        ]

        # Shift the product up by the expanded mantissa's extra low bits.
        sum_v = ((product_v << EXPANDED_MANTISSA_EXTRA_LSBS)
                 + (inp.b_mantissa ^ negate_b_s) + negate_b_u)
        # renamed from `sum` to avoid shadowing the builtin
        mantissa_sum = Signal(expanded_mantissa_shape(fpf))
        m.d.comb += mantissa_sum.eq(sum_v)

        sum_neg = Signal()
        sum_zero = Signal()
        m.d.comb += [
            # just the sign bit; assumes expanded_mantissa_shape is signed
            sum_neg.eq(mantissa_sum < 0),
            sum_zero.eq(mantissa_sum == 0),
        ]

        # the sign of an exact-zero sum is looked up from the rounding mode
        zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign)

        with m.If(sum_zero & ~inp.do_bypass):
            m.d.comb += [
                out.bypassed_z.eq(fpf.zero(zero_sign_array[inp.rm])),
                out.do_bypass.eq(True),
            ]
        with m.Else():
            # must be an Else: an unconditional pass-through placed after
            # the If would override the zero bypass above
            m.d.comb += [
                out.bypassed_z.eq(inp.bypassed_z),
                out.do_bypass.eq(inp.do_bypass),
            ]

        m.d.comb += [
            # fold the sum's sign into the result sign; mantissa = |sum|
            out.sign.eq(sum_neg ^ inp.sign),
            out.exponent.eq(inp.exponent),
            out.mantissa.eq(Mux(sum_neg, -mantissa_sum, mantissa_sum)),
            # assumes the input data carries ctx like the output does
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]
        return m


class FPFMAMainStage(PipeModBaseChain):
    """Pipeline stage wrapping `FPFMAMain` as a single-module chain."""

    def __init__(self, pspec):
        super().__init__(pspec)

    def get_chain(self):
        """ gets chain of modules
        """
        return [FPFMAMain(self.pspec)]