add f8 fma tests -- f16 takes >8hr to run with bitwuzla
[ieee754fpu.git] / src / ieee754 / fpfma / special_cases.py
1 """ floating-point fused-multiply-add
2
3 computes `z = (a * c) + b` but only rounds once at the end
4 """
5
6 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
7 from ieee754.fpcommon.basedata import FPBaseData
8 from nmigen.hdl.ast import Signal
9 from nmigen.hdl.dsl import Module
10 from ieee754.fpcommon.getop import FPPipeContext
11 from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge, FPFormat
12 from ieee754.fpfma.util import expanded_exponent_shape, \
13 expanded_mantissa_shape, multiplicand_mantissa_shape, \
14 EXPANDED_MANTISSA_EXTRA_MSBS, EXPANDED_MANTISSA_EXTRA_LSBS, \
15 product_mantissa_shape
16
17
class FPFMAInputData(FPBaseData):
    """Input record of the FMA pipeline.

    Adds two one-bit negation controls on top of `FPBaseData`; the pipe
    spec must describe exactly three operands.
    """

    def __init__(self, pspec):
        assert pspec.n_ops == 3
        super().__init__(pspec)

        self.negate_addend = Signal()
        """set when the addend is to be negated"""

        self.negate_product = Signal()
        """set when the product is to be negated"""

    def eq(self, i):
        """Return the base-class assignments from `i`, followed by the
        assignments of both negation flags.
        """
        assignments = super().eq(i)
        for flag in ("negate_addend", "negate_product"):
            assignments.append(getattr(self, flag).eq(getattr(i, flag)))
        return assignments

    def __iter__(self):
        """Yield everything the base class yields, then the two flags."""
        for base_port in super().__iter__():
            yield base_port
        yield from (self.negate_addend, self.negate_product)

    def ports(self):
        """Return all members produced by iterating over `self` as a list."""
        return [port for port in self]
42
43
class FPFMASpecialCasesDeNormOutData:
    """Output record of the special-cases / de-normalization stage.

    Holds either a finished result (`bypassed_z` together with `do_bypass`)
    or the sign, exponent and mantissas for the rest of the FMA computation.
    """

    def __init__(self, pspec):
        fpf = FPFormat.from_pspec(pspec)
        # `a` and `c` mantissas share the same shape
        multiplicand_shape = multiplicand_mantissa_shape(fpf)

        self.sign = Signal()
        """sign"""

        self.exponent = Signal(expanded_exponent_shape(fpf))
        """unbiased exponent of the intermediate value"""

        self.a_mantissa = Signal(multiplicand_shape)
        """mantissa of input `a` -- implicit bit added, not normalized"""

        self.b_mantissa = Signal(expanded_mantissa_shape(fpf))
        """mantissa of input `b`

        implicit bit added and shifted into position for the addition
        """

        self.c_mantissa = Signal(multiplicand_shape)
        """mantissa of input `c` -- implicit bit added, not normalized"""

        self.do_sub = Signal()
        """set when `b_mantissa` is to be subtracted from
        `a_mantissa * c_mantissa` instead of added to it
        """

        self.bypassed_z = Signal(fpf.width)
        """final fma result, valid when `do_bypass` is set"""

        self.do_bypass = Signal()
        """set when `bypassed_z` is the final fma result"""

        self.ctx = FPPipeContext(pspec)
        """pipe context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    def __iter__(self):
        """Yield the data signals, then the members of `ctx`, then `rm`."""
        yield from (
            self.sign,
            self.exponent,
            self.a_mantissa,
            self.b_mantissa,
            self.c_mantissa,
            self.do_sub,
            self.bypassed_z,
            self.do_bypass,
        )
        yield from self.ctx
        yield self.rm

    def eq(self, i):
        """Return a list assigning each field of `i` to the matching field
        of `self`, in declaration order.
        """
        field_names = (
            "sign",
            "exponent",
            "a_mantissa",
            "b_mantissa",
            "c_mantissa",
            "do_sub",
            "bypassed_z",
            "do_bypass",
            "ctx",
            "rm",
        )
        return [getattr(self, name).eq(getattr(i, name))
                for name in field_names]
108
109
class FPFMASpecialCasesDeNorm(PipeModBase):
    """Special-case detection and addend alignment for the FMA.

    Takes `FPFMAInputData` and produces `FPFMASpecialCasesDeNormOutData`:
    either a bypass result for NaN / infinity / zero operand combinations,
    or the sign, exponent and mantissas needed by the following stages.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "sc_denorm")

    def ispec(self):
        """Return a new input record for this stage."""
        return FPFMAInputData(self.pspec)

    def ospec(self):
        """Return a new output record for this stage."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def elaborate(self, platform):
        """Build the combinatorial logic of the stage and return the Module.

        NOTE: `m.d.comb` is intentionally looked up afresh for every
        statement rather than cached in a local variable.
        """
        m = Module()
        fpf = FPFormat.from_pspec(self.pspec)
        assert fpf.has_sign
        inp = self.i
        out = self.o
        exp_shape = expanded_exponent_shape(fpf)

        # operand exponents as reported by `fpf.get_exponent_value`
        a_exponent = Signal(exp_shape)
        m.d.comb += a_exponent.eq(fpf.get_exponent_value(inp.a))
        b_exponent_in = Signal(exp_shape)
        m.d.comb += b_exponent_in.eq(fpf.get_exponent_value(inp.b))
        c_exponent = Signal(exp_shape)
        m.d.comb += c_exponent.eq(fpf.get_exponent_value(inp.c))
        b_exponent = Signal(exp_shape)
        m.d.comb += b_exponent.eq(b_exponent_in + EXPANDED_MANTISSA_EXTRA_MSBS)
        prod_exponent = Signal(exp_shape)

        # how far left the product of two normalized values has to be
        # shifted to become normalized again.  Example with 8-bit values:
        # `0x80 * 0x80 == 0x4000`, and `0x4000 << 1 == 0x8000`, so the
        # shift amount is 1.
        PROD_STAY_NORM_SHIFT = 1

        expanded_width = expanded_mantissa_shape(fpf).width
        product_width = product_mantissa_shape(fpf).width
        extra_prod_exponent = (expanded_width
                               - product_width
                               + PROD_STAY_NORM_SHIFT
                               - EXPANDED_MANTISSA_EXTRA_LSBS)
        m.d.comb += prod_exponent.eq(a_exponent + c_exponent
                                     + extra_prod_exponent)
        prod_exp_minus_b_exp = Signal(exp_shape)
        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent)
        b_mantissa_in = Signal(fpf.fraction_width + 1)
        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))

        # sign of the (optionally negated) product and of the
        # (optionally negated) addend
        p_sign = Signal()
        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
        b_sign = Signal()
        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)

        # the output exponent is the product exponent when it is not smaller
        # than the addend exponent, in which case the addend mantissa is
        # shifted right by the difference; otherwise the addend exponent is
        # used and the addend is not shifted.
        exponent = Signal(exp_shape)
        b_shift = Signal(exp_shape)
        # use >= since that's just checking the sign bit
        with m.If(prod_exp_minus_b_exp >= 0):
            m.d.comb += [
                exponent.eq(prod_exponent),
                b_shift.eq(prod_exp_minus_b_exp),
            ]
        with m.Else():
            m.d.comb += [
                exponent.eq(b_exponent),
                b_shift.eq(0),
            ]

        rshiftm = MultiShiftRMerge(
            out.b_mantissa.width - EXPANDED_MANTISSA_EXTRA_MSBS,
            s_max=exp_shape.width - 1)
        m.submodules.rshiftm = rshiftm
        m.d.comb += [
            # clear the whole shifter input first, then place the addend
            # mantissa in its most-significant bits
            rshiftm.inp.eq(0),
            rshiftm.inp[-b_mantissa_in.width:].eq(b_mantissa_in),
            rshiftm.diff.eq(b_shift),
        ]

        keep = {"keep": True}

        def bypass_with(case_name, z_value):
            """statements that raise the marker signal `case_name` and
            make `z_value` the bypassed result
            """
            return [
                Signal(name=case_name, attrs=keep).eq(True),
                out.bypassed_z.eq(z_value),
                out.do_bypass.eq(True),
            ]

        # operand classification shared by the special cases below
        a_is_inf = fpf.is_inf(inp.a)
        b_is_inf = fpf.is_inf(inp.b)
        c_is_inf = fpf.is_inf(inp.c)
        a_is_zero = fpf.is_zero(inp.a)
        b_is_zero = fpf.is_zero(inp.b)
        c_is_zero = fpf.is_zero(inp.c)
        product_is_inf = a_is_inf | c_is_inf
        product_is_zero = a_is_zero | c_is_zero
        signs_differ = p_sign != b_sign

        # handle special cases -- earlier branches take priority
        with m.If(fpf.is_nan(inp.a)):
            m.d.comb += bypass_with("case_nan_a", fpf.to_quiet_nan(inp.a))
        with m.Elif(fpf.is_nan(inp.b)):
            m.d.comb += bypass_with("case_nan_b", fpf.to_quiet_nan(inp.b))
        with m.Elif(fpf.is_nan(inp.c)):
            m.d.comb += bypass_with("case_nan_c", fpf.to_quiet_nan(inp.c))
        with m.Elif((a_is_zero & c_is_inf) | (a_is_inf & c_is_zero)):
            # infinity * 0
            m.d.comb += bypass_with("case_inf_times_zero", fpf.quiet_nan())
        with m.Elif(product_is_inf & b_is_inf & signs_differ):
            # inf - inf
            m.d.comb += bypass_with("case_inf_minus_inf", fpf.quiet_nan())
        with m.Elif(product_is_inf):
            # inf + x
            m.d.comb += bypass_with("case_inf_plus_x", fpf.inf(p_sign))
        with m.Elif(b_is_inf):
            # x + inf
            m.d.comb += bypass_with("case_x_plus_inf", fpf.inf(b_sign))
        with m.Elif(product_is_zero & b_is_zero & (p_sign == b_sign)):
            # zero + zero
            m.d.comb += bypass_with("case_zero_plus_zero", fpf.zero(p_sign))
        with m.Elif(product_is_zero & ~b_is_zero):
            # zero + x
            m.d.comb += bypass_with("case_zero_plus_x",
                                    inp.b ^ fpf.zero(inp.negate_addend))
        with m.Else():
            # zero - zero handled by FPFMAMainStage
            m.d.comb += [
                out.bypassed_z.eq(0),
                out.do_bypass.eq(False),
            ]

        # outputs that are driven regardless of the special-case decision
        m.d.comb += [
            out.sign.eq(p_sign),
            out.exponent.eq(exponent),
            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
            out.b_mantissa.eq(rshiftm.m),
            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
            out.do_sub.eq(signs_differ),
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]

        return m
269
270
class FPFMASpecialCasesDeNormStage(PipeModBaseChain):
    """Module chain made up of a single `FPFMASpecialCasesDeNorm`."""

    def __init__(self, pspec):
        super().__init__(pspec)

    def get_chain(self):
        """ returns the list of modules in this chain
        """
        special_cases = FPFMASpecialCasesDeNorm(self.pspec)
        return [special_cases]