Diff of div and mul shows the corrections stage was missed out
[ieee754fpu.git] / src / add / fmul.py
1 from nmigen import Module, Signal
2 from nmigen.cli import main, verilog
3
4 from fpbase import FPNum, FPOp, Overflow, FPBase
5
6
class FPMUL(FPBase):
    """ IEEE754 floating-point multiplier, implemented as an nmigen FSM.

        Operands are read from ``in_a`` then ``in_b``; the result is sent
        out on ``out_z``.  The state sequence follows the Verilog reference
        kept in the string at the end of this class, with an additional
        "corrections" stage between rounding and packing.
    """

    def __init__(self, width):
        """ width: total bit-width of the floating-point format
            (the ``__main__`` entry point uses 32).
        """
        FPBase.__init__(self)
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPMUL

            Builds a single FSM:
              get_a -> get_b -> special_cases -> normalise_a -> normalise_b
              -> multiply_0 -> multiply_1 -> normalise_1 -> normalise_2
              -> round -> corrections -> pack -> put_z -> get_a
            special_cases jumps straight to put_z for NaN / inf / zero
            operands.
        """
        m = Module()

        # Latches
        a = FPNum(self.width, False)
        b = FPNum(self.width, False)
        z = FPNum(self.width, False)

        # NOTE(review): m_width is not assigned in this class, so it must be
        # provided by FPBase.  It is taken to be the mantissa width including
        # the implicit bit (24 for width=32, which is what the previously
        # hard-coded product[26:50] slice implied).  Confirm against fpbase.
        mant_w = self.m_width
        mw = mant_w*2 - 1 + 3 # sticky/round/guard bits + (2*mant) - 1
        product = Signal(mw)

        # Bit positions inside product, derived from mw instead of being
        # hard-coded.  For mant_w=24 this gives mantissa=[26:50], guard=25,
        # round=24, sticky=[0:24], i.e. the Verilog reference's
        # product[49:26], product[25], product[24] and product[23:0].
        mant_lsb = mw - mant_w      # lowest product bit copied into z.m
        guard_pos = mant_lsb - 1
        round_pos = mant_lsb - 2

        # NOTE(review): of is assumed to expose the guard / round_bit /
        # sticky signals that normalise_1 / normalise_2 / roundz consume
        # (the old code referenced bare, undefined names).  Confirm against
        # fpbase.Overflow.
        of = Overflow()

        # NB: nmigen only permits *assignment* to m.next; the previous
        # "m.next += ..." form reads it first and fails at elaboration.
        with m.FSM():

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases

            with m.State("special_cases"):
                # if a or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)
                # if a is inf return inf, unless b is zero (inf * 0 = NaN).
                # The later sync assignment overrides the earlier one, as
                # the nested "if" does in the Verilog reference.
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s ^ b.s)
                    with m.If(b.is_zero()):
                        m.d.sync += z.nan(1)
                # if b is inf return inf, unless a is zero (0 * inf = NaN)
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s ^ b.s)
                    with m.If(a.is_zero()):
                        m.d.sync += z.nan(1)
                # if a is zero return (signed) zero
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.zero(a.s ^ b.s)
                # if b is zero return (signed) zero
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.zero(a.s ^ b.s)
                # Denormalised Number checks
                with m.Else():
                    m.next = "normalise_a"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # normalise_a

            with m.State("normalise_a"):
                self.op_normalise(m, a, "normalise_b")

            # ******
            # normalise_b

            with m.State("normalise_b"):
                self.op_normalise(m, b, "multiply_0")

            # ******
            # multiply_0: sign, exponent and full-width mantissa product

            with m.State("multiply_0"):
                m.next = "multiply_1"
                m.d.sync += [
                    z.s.eq(a.s ^ b.s),
                    z.e.eq(a.e + b.e + 1),
                    product.eq(a.m * b.m * 4)
                ]

            # ******
            # multiply_1: split product into mantissa / guard / round / sticky

            with m.State("multiply_1"):
                m.next = "normalise_1"
                m.d.sync += [
                    z.m.eq(product[mant_lsb:mw]),
                    of.guard.eq(product[guard_pos]),
                    of.round_bit.eq(product[round_pos]),
                    # sticky = OR of every bit below the round bit
                    of.sticky.eq(product[0:round_pos] != 0)
                ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m

    # Verilog reference whose state names the FSM above follows
    # (kept for side-by-side comparison; not executed).
    """
    special_cases:
    begin
      //if a is NaN or b is NaN return NaN
      if ((a_e == 128 && a_m != 0) || (b_e == 128 && b_m != 0)) begin
        z[31] <= 1;
        z[30:23] <= 255;
        z[22] <= 1;
        z[21:0] <= 0;
        state <= put_z;
      //if a is inf return inf
      end else if (a_e == 128) begin
        z[31] <= a_s ^ b_s;
        z[30:23] <= 255;
        z[22:0] <= 0;
        //if b is zero return NaN
        if (($signed(b_e) == -127) && (b_m == 0)) begin
          z[31] <= 1;
          z[30:23] <= 255;
          z[22] <= 1;
          z[21:0] <= 0;
        end
        state <= put_z;
      //if b is inf return inf
      end else if (b_e == 128) begin
        z[31] <= a_s ^ b_s;
        z[30:23] <= 255;
        z[22:0] <= 0;
        //if a is zero return NaN
        if (($signed(a_e) == -127) && (a_m == 0)) begin
          z[31] <= 1;
          z[30:23] <= 255;
          z[22] <= 1;
          z[21:0] <= 0;
        end
        state <= put_z;
      //if a is zero return zero
      end else if (($signed(a_e) == -127) && (a_m == 0)) begin
        z[31] <= a_s ^ b_s;
        z[30:23] <= 0;
        z[22:0] <= 0;
        state <= put_z;
      //if b is zero return zero
      end else if (($signed(b_e) == -127) && (b_m == 0)) begin
        z[31] <= a_s ^ b_s;
        z[30:23] <= 0;
        z[22:0] <= 0;
        state <= put_z;
      //^ done up to here
      end else begin
        //Denormalised Number
        if ($signed(a_e) == -127) begin
          a_e <= -126;
        end else begin
          a_m[23] <= 1;
        end
        //Denormalised Number
        if ($signed(b_e) == -127) begin
          b_e <= -126;
        end else begin
          b_m[23] <= 1;
        end
        state <= normalise_a;
      end
    end

    normalise_a:
    begin
      if (a_m[23]) begin
        state <= normalise_b;
      end else begin
        a_m <= a_m << 1;
        a_e <= a_e - 1;
      end
    end

    normalise_b:
    begin
      if (b_m[23]) begin
        state <= multiply_0;
      end else begin
        b_m <= b_m << 1;
        b_e <= b_e - 1;
      end
    end

    multiply_0:
    begin
      z_s <= a_s ^ b_s;
      z_e <= a_e + b_e + 1;
      product <= a_m * b_m * 4;
      state <= multiply_1;
    end

    multiply_1:
    begin
      z_m <= product[49:26];
      guard <= product[25];
      round_bit <= product[24];
      sticky <= (product[23:0] != 0);
      state <= normalise_1;
    end

    normalise_1:
    begin
      if (z_m[23] == 0) begin
        z_e <= z_e - 1;
        z_m <= z_m << 1;
        z_m[0] <= guard;
        guard <= round_bit;
        round_bit <= 0;
      end else begin
        state <= normalise_2;
      end
    end

    normalise_2:
    begin
      if ($signed(z_e) < -126) begin
        z_e <= z_e + 1;
        z_m <= z_m >> 1;
        guard <= z_m[0];
        round_bit <= guard;
        sticky <= sticky | round_bit;
      end else begin
        state <= round;
      end
    end

    round:
    begin
      if (guard && (round_bit | sticky | z_m[0])) begin
        z_m <= z_m + 1;
        if (z_m == 24'hffffff) begin
          z_e <=z_e + 1;
        end
      end
      state <= pack;
    end

    pack:
    begin
      z[22 : 0] <= z_m[22:0];
      z[30 : 23] <= z_e[7:0] + 127;
      z[31] <= z_s;
      if ($signed(z_e) == -126 && z_m[23] == 0) begin
        z[30 : 23] <= 0;
      end
      //if overflow occurs, return inf
      if ($signed(z_e) > 127) begin
        z[22 : 0] <= 0;
        z[30 : 23] <= 255;
        z[31] <= z_s;
      end
      state <= put_z;
    end

    put_z:
    begin
      s_output_z_stb <= 1;
      s_output_z <= z;
      if (s_output_z_stb && output_z_ack) begin
        s_output_z_stb <= 0;
        state <= get_a;
      end
    end

    """
320
if __name__ == "__main__":
    # Build a 32-bit multiplier and hand it to nmigen's command-line
    # driver, exposing the ports of both operand inputs and the result.
    fpmul = FPMUL(width=32)
    all_ports = (fpmul.in_a.ports()
                 + fpmul.in_b.ports()
                 + fpmul.out_z.ports())
    main(fpmul, ports=all_ports)