""" floating-point fused-multiply-add

computes `z = (a * c) + b` but only rounds once at the end
"""
6 from nmutil
.pipemodbase
import PipeModBase
, PipeModBaseChain
7 from ieee754
.fpcommon
.basedata
import FPBaseData
8 from nmigen
.hdl
.ast
import Signal
9 from nmigen
.hdl
.dsl
import Module
10 from ieee754
.fpcommon
.getop
import FPPipeContext
11 from ieee754
.fpcommon
.fpbase
import FPRoundingMode
, MultiShiftRMerge
12 from ieee754
.fpfma
.util
import expanded_exponent_shape
, \
13 expanded_mantissa_shape
, get_fpformat
, multiplicand_mantissa_shape
, \
14 EXPANDED_MANTISSA_EXTRA_MSBS
, EXPANDED_MANTISSA_EXTRA_LSBS
, \
15 product_mantissa_shape
class FPFMAInputData(FPBaseData):
    """Input record of the fused-multiply-add pipeline.

    Extends `FPBaseData` with two one-bit controls that let the caller
    negate the addend and/or the product.
    """

    def __init__(self, pspec):
        # the fma computes `(a * c) + b`, so exactly three operands
        # must be configured
        assert pspec.n_ops == 3
        super().__init__(pspec)

        self.negate_addend = Signal()
        """if the addend should be negated"""

        self.negate_product = Signal()
        """if the product should be negated"""

    # NOTE(review): the `def` line, the `ret = super().eq(i)` line and the
    # `return ret` line were not visible in the extracted source; they are
    # reconstructed from the `ret.append(...)` statements -- confirm.
    def eq(self, i):
        """Return the assignments copying every field of `i` into `self`."""
        ret = super().eq(i)
        ret.append(self.negate_addend.eq(i.negate_addend))
        ret.append(self.negate_product.eq(i.negate_product))
        return ret

    # NOTE(review): the `def __iter__` header was not visible in the
    # extracted source; reconstructed from the `yield` statements.
    def __iter__(self):
        """Yield the base-class fields followed by the two negate flags."""
        yield from super().__iter__()
        yield self.negate_addend
        yield self.negate_product
class FPFMASpecialCasesDeNormOutData:
    """Output record of `FPFMASpecialCasesDeNorm`.

    Carries the unpacked operands (shared exponent, three mantissas and
    the add/subtract selector) together with the special-case bypass
    value, the pipeline context and the rounding mode.
    """

    def __init__(self, pspec):
        fpf = get_fpformat(pspec)

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """exponent of intermediate -- unbiased"""

        self.a_mantissa = Signal(multiplicand_mantissa_shape(fpf))
        """mantissa of a input -- un-normalized and with implicit bit added"""

        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
        """mantissa of b input

        shifted to appropriate location for add and with implicit bit added
        """

        self.c_mantissa = Signal(multiplicand_mantissa_shape(fpf))
        """mantissa of c input -- un-normalized and with implicit bit added"""

        self.do_sub = Signal()
        """true if `b_mantissa` should be subtracted from
        `a_mantissa * c_mantissa` rather than added
        """

        self.bypassed_z = Signal(fpf.width)
        """final output value of the fma when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set if `bypassed_z` is the final output value of the fma"""

        self.ctx = FPPipeContext(pspec)
        """pipeline context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    # NOTE(review): the `def eq` header was not visible in the extracted
    # source; it is reconstructed from the list of `.eq(i.<field>)` items.
    def eq(self, i):
        """Return the assignments copying every field of `i` into `self`."""
        return [
            self.exponent.eq(i.exponent),
            self.a_mantissa.eq(i.a_mantissa),
            self.b_mantissa.eq(i.b_mantissa),
            self.c_mantissa.eq(i.c_mantissa),
            self.do_sub.eq(i.do_sub),
            self.bypassed_z.eq(i.bypassed_z),
            self.do_bypass.eq(i.do_bypass),
            # `ctx` and `rm` are declared above but were not copied in the
            # visible part of the list; without them a downstream copy of
            # this record would lose the context and the rounding mode.
            # NOTE(review): assumes FPPipeContext provides `eq` -- confirm.
            self.ctx.eq(i.ctx),
            self.rm.eq(i.rm),
        ]
class FPFMASpecialCasesDeNorm(PipeModBase):
    """Special-case detection and operand unpacking for the fma.

    Consumes an `FPFMAInputData` and produces an
    `FPFMASpecialCasesDeNormOutData`: either a bypass value for
    NaN / infinity / zero corner cases, or the exponent and mantissas
    needed to compute `(a * c) + b`.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "sc_denorm")

    # NOTE(review): the `def` headers of the next two methods were not
    # visible in the extracted source; the names follow the
    # ispec/ospec convention of PipeModBase -- confirm.
    def ispec(self):
        """Return a fresh input record."""
        return FPFMAInputData(self.pspec)

    def ospec(self):
        """Return a fresh output record."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def elaborate(self, platform):
        """Build the combinatorial logic of this stage and return it."""
        # NOTE(review): the creation of `m` and the bindings of `inp` /
        # `out` were not visible in the extracted source; `self.i` and
        # `self.o` are presumably set up by PipeModBase -- confirm.
        m = Module()
        fpf = get_fpformat(self.pspec)
        inp = self.i
        out = self.o

        # unbiased exponents of the three operands
        a_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a))
        b_exponent_in = Signal(expanded_exponent_shape(fpf))
        m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b))
        c_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c))

        # exponent of the addend once offset by the extra MSBs
        b_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += b_exponent.eq(b_exponent_in
                                  + EXPANDED_MANTISSA_EXTRA_MSBS)
        prod_exponent = Signal(expanded_exponent_shape(fpf))

        # number of bits that the product of two normalized signals needs to
        # be shifted left to be normalized, e.g. the product of 2 8-bit
        # numbers `0x80 * 0x80 == 0x4000` and `0x4000` needs to be shifted
        # left by `PROD_STAY_NORM_SHIFT` bits to be normalized again:
        # `0x4000 << 1 == 0x8000`
        PROD_STAY_NORM_SHIFT = 1

        extra_prod_exponent = (expanded_mantissa_shape(fpf).width
                               - product_mantissa_shape(fpf).width
                               + PROD_STAY_NORM_SHIFT
                               - EXPANDED_MANTISSA_EXTRA_LSBS)
        m.d.comb += prod_exponent.eq(a_exponent + c_exponent
                                     + extra_prod_exponent)
        prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent)
        b_mantissa_in = Signal(fpf.fraction_width + 1)
        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))

        # NOTE(review): the declarations of `p_sign` and `b_sign` were not
        # visible in the extracted source; one-bit signals are assumed.
        # sign of the product, after the optional negation
        p_sign = Signal()
        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
        # sign of the addend, after the optional negation
        b_sign = Signal()
        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)

        exponent = Signal(expanded_exponent_shape(fpf))
        b_shift = Signal(expanded_exponent_shape(fpf))
        # use >= since that's just checking the sign bit
        with m.If(prod_exp_minus_b_exp >= 0):
            # product exponent is the larger one: use it and shift the
            # addend right by the difference
            m.d.comb += [
                exponent.eq(prod_exponent),
                b_shift.eq(prod_exp_minus_b_exp),
            ]
        with m.Else():
            # addend exponent is the larger one; no assignment to
            # `b_shift` is visible on this path, so it is left at its
            # reset value
            m.d.comb += [
                exponent.eq(b_exponent),
            ]

        # right-shift the addend mantissa into position for the add
        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(
            out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS,
            s_max=expanded_exponent_shape(fpf).width - 1)
        m.d.comb += [
            # place the addend mantissa in the top bits of the shifter input
            rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in),
            rshiftm.diff.eq(b_shift),
        ]

        keep = {"keep": True}

        def bypass(case_name, value):
            # statements for one special case: set a marker signal named
            # after the case (carrying the `keep` attribute) and route
            # `value` to the output as the final result
            return [
                Signal(name=case_name, attrs=keep).eq(True),
                out.bypassed_z.eq(value),
                out.do_bypass.eq(True),
            ]

        prod_has_inf = fpf.is_inf(inp.a) | fpf.is_inf(inp.c)
        prod_has_zero = fpf.is_zero(inp.a) | fpf.is_zero(inp.c)

        # handle special cases -- the order of the chain sets the priority
        with m.If(fpf.is_nan(inp.a)):
            m.d.comb += bypass("case_nan_a", fpf.to_quiet_nan(inp.a))
        with m.Elif(fpf.is_nan(inp.b)):
            m.d.comb += bypass("case_nan_b", fpf.to_quiet_nan(inp.b))
        with m.Elif(fpf.is_nan(inp.c)):
            m.d.comb += bypass("case_nan_c", fpf.to_quiet_nan(inp.c))
        with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c))
                    | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
            m.d.comb += bypass("case_inf_times_zero", fpf.quiet_nan())
        with m.Elif(prod_has_inf & fpf.is_inf(inp.b) & (p_sign != b_sign)):
            m.d.comb += bypass("case_inf_minus_inf", fpf.quiet_nan())
        with m.Elif(prod_has_inf):
            m.d.comb += bypass("case_inf_plus_x", fpf.inf(p_sign))
        with m.Elif(fpf.is_inf(inp.b)):
            m.d.comb += bypass("case_x_plus_inf", fpf.inf(b_sign))
        with m.Elif(prod_has_zero & fpf.is_zero(inp.b)
                    & (p_sign == b_sign)):
            m.d.comb += bypass("case_zero_plus_zero", fpf.zero(p_sign))
        with m.Elif(prod_has_zero & ~fpf.is_zero(inp.b)):
            # the product is zero, so the result is the addend itself
            m.d.comb += bypass("case_zero_plus_x", inp.b)
        with m.Else():
            # zero - zero handled by FPFMAMainStage
            m.d.comb += [
                out.bypassed_z.eq(0),
                out.do_bypass.eq(False),
            ]

        # operands for the main stage; driven regardless of the bypass
        m.d.comb += [
            out.exponent.eq(exponent),
            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
            out.b_mantissa.eq(rshiftm.m),
            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
            out.do_sub.eq(p_sign != b_sign),
        ]
        # NOTE(review): no assignment to `out.ctx` or `out.rm` is visible
        # in the extracted source -- confirm against the full file whether
        # they are forwarded from `inp` here.

        return m
class FPFMASpecialCasesDeNormStage(PipeModBaseChain):
    """Pipeline chain made of a single `FPFMASpecialCasesDeNorm` module."""

    def __init__(self, pspec):
        super().__init__(pspec)

    # NOTE(review): the `def` header was not visible in the extracted
    # source; `get_chain` is inferred from the docstring and from the
    # PipeModBaseChain convention -- confirm.
    def get_chain(self):
        """ gets chain of modules
        """
        return [FPFMASpecialCasesDeNorm(self.pspec)]