start implementing fp fused-mul-add pipeline
[ieee754fpu.git] / src / ieee754 / fpfma / special_cases.py
1 """ floating-point fused-multiply-add
2
3 computes `z = (a * c) + b` but only rounds once at the end
4 """
5
6 from nmutil.pipemodbase import PipeModBase
7 from ieee754.fpcommon.basedata import FPBaseData
8 from nmigen.hdl.ast import Signal
9 from nmigen.hdl.dsl import Module
10 from ieee754.fpcommon.getop import FPPipeContext
11 from ieee754.fpcommon.fpbase import FPRoundingMode, MultiShiftRMerge
12 from ieee754.fpfma.util import expanded_exponent_shape, \
13 expanded_mantissa_shape, get_fpformat, multiplicand_mantissa_shape
14
15
class FPFMAInputData(FPBaseData):
    """Input record for the fused-multiply-add pipeline.

    Extends `FPBaseData` with two one-bit flags that let the caller negate
    the addend and/or the product.  The pipeline spec must declare exactly
    three operands.
    """

    def __init__(self, pspec):
        # an FMA always takes three operands
        assert pspec.n_ops == 3
        super().__init__(pspec)

        self.negate_addend = Signal()
        """set when the addend is to be negated"""

        self.negate_product = Signal()
        """set when the product is to be negated"""

    def eq(self, i):
        """Return the assignments copying every field of `i` into `self`.

        Starts from the base-class assignments and appends the two
        negation flags, addend first.
        """
        assignments = super().eq(i)
        flag_pairs = (
            (self.negate_addend, i.negate_addend),
            (self.negate_product, i.negate_product),
        )
        for dest, src in flag_pairs:
            assignments.append(dest.eq(src))
        return assignments

    def __iter__(self):
        """Yield the base-class fields, then the two negation flags."""
        yield from super().__iter__()
        yield from (self.negate_addend, self.negate_product)

    def ports(self):
        """Return every field of this record as a list."""
        return [port for port in self]
40
41
class FPFMASpecialCasesDeNormOutData:
    """Output record of the FMA special-cases / de-normalization stage."""

    # field order shared by `eq`; `ctx` is a nested record, the rest are
    # plain signals
    _FIELD_NAMES = (
        "sign",
        "exponent",
        "a_mantissa",
        "b_mantissa",
        "c_mantissa",
        "do_sub",
        "bypassed_z",
        "do_bypass",
        "ctx",
        "rm",
    )

    def __init__(self, pspec):
        fmt = get_fpformat(pspec)

        self.sign = Signal()
        """sign bit"""

        self.exponent = Signal(expanded_exponent_shape(fmt))
        """unbiased exponent of the intermediate value"""

        self.a_mantissa = Signal(multiplicand_mantissa_shape(fmt))
        """`a` input mantissa -- not normalized, implicit bit included"""

        # NOTE(review): `b_mantissa` is described as pre-shifted for the add
        # yet uses the multiplicand shape, while `c_mantissa` (a multiplicand)
        # uses the expanded shape -- the two shapes may be swapped; confirm
        # against ieee754.fpfma.util before relying on the widths.
        self.b_mantissa = Signal(multiplicand_mantissa_shape(fmt))
        """`b` input mantissa

        implicit bit included and already shifted into position for the add
        """

        self.c_mantissa = Signal(expanded_mantissa_shape(fmt))
        """`c` input mantissa -- not normalized, implicit bit included"""

        self.do_sub = Signal()
        """set when `b_mantissa` is to be subtracted from
        `a_mantissa * c_mantissa` instead of added to it
        """

        self.bypassed_z = Signal(fmt.width)
        """the fma's final result, valid only while `do_bypass` is set"""

        self.do_bypass = Signal()
        """set when `bypassed_z` already holds the fma's final result"""

        self.ctx = FPPipeContext(pspec)
        """pipeline context"""

        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
        """rounding mode"""

    def __iter__(self):
        """Yield every signal of the record; `ctx` is flattened in place."""
        yield from (
            self.sign,
            self.exponent,
            self.a_mantissa,
            self.b_mantissa,
            self.c_mantissa,
            self.do_sub,
            self.bypassed_z,
            self.do_bypass,
        )
        yield from self.ctx
        yield self.rm

    def eq(self, i):
        """Return the list of assignments copying each field of `i` to `self`.

        The entry for `ctx` is whatever `ctx.eq` returns, kept as a single
        list element.
        """
        return [
            getattr(self, field).eq(getattr(i, field))
            for field in self._FIELD_NAMES
        ]
106
107
class FPFMASpecialCasesDeNorm(PipeModBase):
    """First FMA pipeline stage: special cases and addend alignment.

    Takes an `FPFMAInputData` (operands `a`, `b`, `c` for
    `z = (a * c) + b`) and produces an `FPFMASpecialCasesDeNormOutData`.
    Inputs whose result is known without arithmetic (NaNs, infinities,
    same-signed zeros) are resolved here and flagged via `do_bypass`;
    otherwise the mantissas, a common exponent and the add/subtract
    selection are passed on.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "sc_denorm")

    def ispec(self):
        """Return a fresh input record for this stage."""
        return FPFMAInputData(self.pspec)

    def ospec(self):
        """Return a fresh output record for this stage."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def elaborate(self, platform):
        """Build and return the combinatorial logic of this stage."""
        m = Module()
        fpf = get_fpformat(self.pspec)
        assert fpf.has_sign
        inp = self.i
        out = self.o

        # exponents of the three operands, the product exponent, and how far
        # the product exponent lies above the addend exponent
        a_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += a_exponent.eq(fpf.get_exponent(inp.a))
        b_exponent_in = Signal(expanded_exponent_shape(fpf))
        m.d.comb += b_exponent_in.eq(fpf.get_exponent(inp.b))
        c_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += c_exponent.eq(fpf.get_exponent(inp.c))
        prod_exponent = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exponent.eq(a_exponent + c_exponent)
        prod_exp_minus_b_exp = Signal(expanded_exponent_shape(fpf))
        m.d.comb += prod_exp_minus_b_exp.eq(prod_exponent - b_exponent_in)
        b_mantissa_in = Signal(fpf.fraction_width + 1)
        m.d.comb += b_mantissa_in.eq(fpf.get_mantissa_value(inp.b))

        # effective signs after applying the optional negations
        p_sign = Signal()
        m.d.comb += p_sign.eq(fpf.get_sign_field(inp.a) ^
                              fpf.get_sign_field(inp.c) ^ inp.negate_product)
        b_sign = Signal()
        m.d.comb += b_sign.eq(fpf.get_sign_field(inp.b) ^ inp.negate_addend)

        # pick the larger exponent as the working exponent; the addend is
        # only shifted right when the product exponent is the larger one
        # NOTE(review): when the addend exponent is larger nothing here
        # shifts the product -- presumably a later stage handles that; confirm
        exponent = Signal(expanded_exponent_shape(fpf))
        b_shift = Signal(expanded_exponent_shape(fpf))
        # use >= since that's just checking the sign bit
        with m.If(prod_exp_minus_b_exp >= 0):
            m.d.comb += [
                exponent.eq(prod_exponent),
                b_shift.eq(prod_exp_minus_b_exp),
            ]
        with m.Else():
            m.d.comb += [
                exponent.eq(b_exponent_in),
                b_shift.eq(0),
            ]

        # place the addend mantissa at the top of the output field, then
        # shift it right by `b_shift` using the shifter from fpcommon.fpbase
        m.submodules.rshiftm = rshiftm = MultiShiftRMerge(out.b_mantissa.width)
        m.d.comb += [
            rshiftm.inp.eq(b_mantissa_in << (out.b_mantissa.width
                                             - b_mantissa_in.width)),
            rshiftm.diff.eq(b_shift),
        ]

        # handle special cases -- checked in priority order; a NaN operand
        # wins over everything else, with `a` before `b` before `c`
        with m.If(fpf.is_nan(inp.a)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.a)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.b)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.b)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_nan(inp.c)):
            m.d.comb += [
                out.bypassed_z.eq(fpf.to_quiet_nan(inp.c)),
                out.do_bypass.eq(True),
            ]
        with m.Elif((fpf.is_zero(inp.a) & fpf.is_inf(inp.c))
                    | (fpf.is_inf(inp.a) & fpf.is_zero(inp.c))):
            # infinity * 0
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        # the sign comparison must be parenthesized: Python's `&` binds
        # tighter than `!=`, so without the parentheses the whole `&` chain
        # would be compared against `b_sign`
        with m.Elif((fpf.is_inf(inp.a) | fpf.is_inf(inp.c))
                    & fpf.is_inf(inp.b) & (p_sign != b_sign)):
            # inf - inf
            m.d.comb += [
                out.bypassed_z.eq(fpf.quiet_nan()),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.a) | fpf.is_inf(inp.c)):
            # inf + x
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(p_sign)),
                out.do_bypass.eq(True),
            ]
        with m.Elif(fpf.is_inf(inp.b)):
            # x + inf
            m.d.comb += [
                out.bypassed_z.eq(fpf.inf(b_sign)),
                out.do_bypass.eq(True),
            ]
        # parenthesized for the same precedence reason as the inf - inf case
        with m.Elif((fpf.is_zero(inp.a) | fpf.is_zero(inp.c))
                    & fpf.is_zero(inp.b) & (p_sign == b_sign)):
            # zero + zero
            m.d.comb += [
                out.bypassed_z.eq(fpf.zero(p_sign)),
                out.do_bypass.eq(True),
            ]
            # zero - zero handled by FPFMAMainStage
        with m.Else():
            m.d.comb += [
                out.bypassed_z.eq(0),
                out.do_bypass.eq(False),
            ]

        # these outputs are driven regardless of whether the bypass is taken
        m.d.comb += [
            out.sign.eq(p_sign),
            out.exponent.eq(exponent),
            out.a_mantissa.eq(fpf.get_mantissa_value(inp.a)),
            out.b_mantissa.eq(rshiftm.m),
            out.c_mantissa.eq(fpf.get_mantissa_value(inp.c)),
            out.do_sub.eq(p_sign != b_sign),
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]

        return m