c197d735ed531d73a1671a9d2ecedd8cc1daa4cb
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
from nmigen import Module, Signal, Cat, Const
from nmigen.cli import main
7
8
class FPADD:
    """IEEE754 floating-point adder built as an nmigen finite-state machine.

    A transliteration of Jon Dawson's Verilog adder (the Verilog is kept in
    this file as a reference string).  Operand ``a`` and operand ``b`` are
    each received over a stb/ack handshake; the sum ``z`` is presented over a
    third stb/ack handshake.

    NOTE: although ``width`` sizes the ports, the field slicing in
    ``get_fragment`` is hard-coded to the 32-bit single-precision layout
    (sign bit 31, exponent bits 23..30, mantissa bits 0..22).
    """

    def __init__(self, width):
        self.width = width

        # operand a and its handshake
        self.in_a = Signal(width)
        self.in_a_stb = Signal()
        self.in_a_ack = Signal()

        # operand b and its handshake
        self.in_b = Signal(width)
        self.in_b_stb = Signal()
        self.in_b_ack = Signal()

        # result z and its handshake.  The Verilog drove these through
        # separate s_* registers; here they are driven directly from the
        # sync domain, so no shadow signals are needed.
        self.out_z = Signal(width)
        self.out_z_stb = Signal()
        self.out_z_ack = Signal()

    def get_fragment(self, platform):
        """Build and return the adder ``Module``.

        ``platform`` is part of the nmigen elaboration interface and is
        not used here.
        """
        m = Module()

        # Latched operands and the packed result
        a = Signal(self.width)
        b = Signal(self.width)
        z = Signal(self.width)

        # Mantissas.  a_m/b_m are 27 bits: bit 26 holds the implicit
        # leading one, bits 3..25 the 23 stored mantissa bits, and bits
        # 0..2 are guard/round/sticky headroom for the alignment shifts.
        a_m = Signal(27)
        b_m = Signal(27)
        z_m = Signal(24)

        # Exponent: 10 bits, signed (the exponent bias is subtracted)
        a_e = Signal((10, True))
        b_e = Signal((10, True))
        z_e = Signal((10, True))

        # Sign
        a_s = Signal()
        b_s = Signal()
        z_s = Signal()

        # Rounding information extracted below the 24-bit result mantissa
        guard = Signal()
        round_bit = Signal()
        sticky = Signal()

        # Mantissa sum/difference; 28 bits so a carry out of bit 26 is kept
        tot = Signal(28)

        # Exponent bias as a *signed* constant.  In nmigen, unsigned minus
        # unsigned yields an unsigned result.  Subtracting a plain 127 from
        # the 8-bit exponent field would therefore wrap exponents below the
        # bias to large positive values instead of making them negative.
        bias = Const(127, (10, True))

        with m.FSM():

            # ******
            # gets operand a

            with m.State("get_a"):
                with m.If(self.in_a_ack & self.in_a_stb):
                    m.next = "get_b"
                    m.d.sync += [
                        a.eq(self.in_a),
                        self.in_a_ack.eq(0),
                    ]
                with m.Else():
                    m.d.sync += self.in_a_ack.eq(1)

            # ******
            # gets operand b

            with m.State("get_b"):
                with m.If(self.in_b_ack & self.in_b_stb):
                    # both operands latched: go on to unpack them
                    m.next = "unpack"
                    m.d.sync += [
                        b.eq(self.in_b),
                        self.in_b_ack.eq(0),
                    ]
                with m.Else():
                    m.d.sync += self.in_b_ack.eq(1)

            # ******
            # unpacks operands into sign, mantissa and exponent

            with m.State("unpack"):
                m.next = "special_cases"
                m.d.sync += [
                    # mantissa: stored bits moved up to 3..25
                    a_m.eq(Cat(Const(0, 3), a[0:23])),
                    b_m.eq(Cat(Const(0, 3), b[0:23])),
                    # exponent (take off exponent bias, here)
                    a_e.eq(a[23:31] - bias),
                    b_e.eq(b[23:31] - bias),
                    # sign
                    a_s.eq(a[31]),
                    b_s.eq(b[31]),
                ]

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(((a_e == 128) & (a_m != 0)) |
                          ((b_e == 128) & (b_m != 0))):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(1),        # sign: 1
                        z[23:31].eq(255),   # exp: 0b11111...
                        z[22].eq(1),        # mantissa top bit: 1
                        z[0:22].eq(0),      # mantissa rest: 0b0000...
                    ]

                # if a is inf return inf (or NaN)
                with m.Elif(a_e == 128):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(a_s),      # sign: a_s
                        z[23:31].eq(255),   # exp: 0b11111...
                        z[0:23].eq(0),      # mantissa: 0b0000...
                    ]
                    # if a is inf and b is inf of opposite sign return NaN
                    with m.If((b_e == 128) & (a_s != b_s)):
                        m.d.sync += [
                            z[31].eq(b_s),      # sign: b_s
                            z[23:31].eq(255),   # exp: 0b11111...
                            z[22].eq(1),        # mantissa top bit: 1
                            z[0:22].eq(0),      # mantissa rest: 0b0000...
                        ]

                # if b is inf return inf
                with m.Elif(b_e == 128):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(b_s),      # sign: b_s
                        z[23:31].eq(255),   # exp: 0b11111...
                        z[0:23].eq(0),      # mantissa: 0b0000...
                    ]

                # if a is zero and b is zero return signed zero
                # (negative only when both are -0)
                with m.Elif(((a_e == -127) & (a_m == 0)) &
                            ((b_e == -127) & (b_m == 0))):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(a_s & b_s),            # sign: a_s & b_s
                        z[23:31].eq(b_e[0:8] + 127),    # exp: b_e (plus bias)
                        z[0:23].eq(b_m[3:26]),          # mantissa: b_m top bits
                    ]

                # if a is zero return b
                with m.Elif((a_e == -127) & (a_m == 0)):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(b_s),                  # sign: b_s
                        z[23:31].eq(b_e[0:8] + 127),    # exp: b_e (plus bias)
                        z[0:23].eq(b_m[3:26]),          # mantissa: b_m top bits
                    ]

                # if b is zero return a
                with m.Elif((b_e == -127) & (b_m == 0)):
                    m.next = "put_z"
                    m.d.sync += [
                        z[31].eq(a_s),                  # sign: a_s
                        z[23:31].eq(a_e[0:8] + 127),    # exp: a_e (plus bias)
                        z[0:23].eq(a_m[3:26]),          # mantissa: a_m top bits
                    ]

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalised a: exponent is really -126, no implicit 1
                    with m.If(a_e == -127):
                        m.d.sync += a_e.eq(-126)
                    with m.Else():
                        m.d.sync += a_m[26].eq(1)   # set implicit leading 1
                    # denormalised b: same treatment
                    with m.If(b_e == -127):
                        m.d.sync += b_e.eq(-126)
                    with m.Else():
                        m.d.sync += b_m[26].eq(1)   # set implicit leading 1

            # ******
            # align. NOTE: this does *not* do single-cycle multi-shifting,
            # it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant.
                # The later assignment to bit 0 wins: it ORs the bit being
                # shifted out into bit 0 so it is kept as a sticky bit.
                with m.If(a_e > b_e):
                    m.d.sync += [
                        b_e.eq(b_e + 1),
                        b_m.eq(b_m >> 1),
                        b_m[0].eq(b_m[0] | b_m[1]),
                    ]
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a_e < b_e):
                    m.d.sync += [
                        a_e.eq(a_e + 1),
                        a_m.eq(a_m >> 1),
                        a_m[0].eq(a_m[0] | a_m[1]),
                    ]
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z_e.eq(a_e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a_s == b_s):
                    m.d.sync += [
                        tot.eq(a_m + b_m),
                        z_s.eq(a_s),
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a_m >= b_m):
                    m.d.sync += [
                        tot.eq(a_m - b_m),
                        z_s.eq(a_s),
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b_m - a_m),
                        z_s.eq(b_s),
                    ]

            # ******
            # Second stage of add: extract the 24-bit mantissa plus
            # guard/round/sticky.  Python slices exclude the stop index, so
            # [4:28] and [3:27] are the 24-bit Verilog ranges tot[27:4] and
            # tot[26:3].

            with m.State("add_1"):
                m.next = "normalise_1"

                with m.If(tot[27]):
                    # carry out of the add: take the top 24 bits, bump exp
                    m.d.sync += [
                        z_m.eq(tot[4:28]),
                        guard.eq(tot[3]),
                        round_bit.eq(tot[2]),
                        sticky.eq(tot[1] | tot[0]),
                        z_e.eq(z_e + 1),
                    ]
                with m.Else():
                    m.d.sync += [
                        z_m.eq(tot[3:27]),
                        guard.eq(tot[2]),
                        round_bit.eq(tot[1]),
                        sticky.eq(tot[0]),
                    ]

            # ******
            # normalise_1: shift left one bit per cycle until the leading
            # one reaches bit 23 or the exponent hits the denormal minimum

            with m.State("normalise_1"):
                with m.If((z_m[23] == 0) & (z_e > -126)):
                    m.d.sync += [
                        z_e.eq(z_e - 1),
                        z_m.eq(z_m << 1),
                        z_m[0].eq(guard),
                        guard.eq(round_bit),
                        round_bit.eq(0),
                    ]
                with m.Else():
                    m.next = "normalise_2"

            # ******
            # normalise_2: shift right one bit per cycle while the exponent
            # is below the minimum, accumulating lost bits into sticky

            with m.State("normalise_2"):
                with m.If(z_e < -126):
                    m.d.sync += [
                        z_e.eq(z_e + 1),
                        z_m.eq(z_m >> 1),
                        guard.eq(z_m[0]),
                        round_bit.eq(guard),
                        sticky.eq(sticky | round_bit),
                    ]
                with m.Else():
                    m.next = "round"

            # ******
            # round to nearest, ties to even

            with m.State("round"):
                m.next = "pack"
                with m.If(guard & (round_bit | sticky | z_m[0])):
                    m.d.sync += z_m.eq(z_m + 1)
                    # mantissa overflowed on rounding: bump the exponent
                    with m.If(z_m == 0xffffff):
                        m.d.sync += z_e.eq(z_e + 1)

            # ******
            # pack sign, exponent and mantissa into z (later assignments
            # override earlier ones for the special results)

            with m.State("pack"):
                m.next = "put_z"
                m.d.sync += [
                    z[0:23].eq(z_m[0:23]),
                    z[23:31].eq(z_e[0:8] + 127),
                    z[31].eq(z_s),
                ]
                # denormalised result: exponent field is 0
                with m.If((z_e == -126) & (z_m[23] == 0)):
                    m.d.sync += z[23:31].eq(0)
                # exact zero result (e.g. -a + a) is +0
                with m.If((z_e == -126) & (z_m == 0)):
                    m.d.sync += z[31].eq(0)
                # if overflow occurs, return inf
                with m.If(z_e > 127):
                    m.d.sync += [
                        z[0:23].eq(0),
                        z[23:31].eq(255),
                        z[31].eq(z_s),
                    ]

            # ******
            # present z on the output handshake until acknowledged

            with m.State("put_z"):
                m.d.sync += [
                    self.out_z_stb.eq(1),
                    self.out_z.eq(z),
                ]
                with m.If(self.out_z_stb & self.out_z_ack):
                    m.d.sync += self.out_z_stb.eq(0)
                    m.next = "get_a"

        return m
255
256 """
257 always @(posedge clk)
258 begin
259
260 case(state)
261
262 get_a:
263 begin
264 s_in_a_ack <= 1;
265 if (s_in_a_ack && in_a_stb) begin
266 a <= in_a;
267 s_in_a_ack <= 0;
268 state <= get_b;
269 end
270 end
271
272 get_b:
273 begin
274 s_in_b_ack <= 1;
275 if (s_in_b_ack && in_b_stb) begin
276 b <= in_b;
277 s_in_b_ack <= 0;
278 state <= unpack;
279 end
280 end
281
282 unpack:
283 begin
284 a_m <= {a[22 : 0], 3'd0};
285 b_m <= {b[22 : 0], 3'd0};
286 a_e <= a[30 : 23] - 127;
287 b_e <= b[30 : 23] - 127;
288 a_s <= a[31];
289 b_s <= b[31];
290 state <= special_cases;
291 end
292
293 special_cases:
294 begin
295 //if a is NaN or b is NaN return NaN
296 if ((a_e == 128 && a_m != 0) || (b_e == 128 && b_m != 0)) begin
297 z[31] <= 1;
298 z[30:23] <= 255;
299 z[22] <= 1;
300 z[21:0] <= 0;
301 state <= put_z;
302 //if a is inf return inf
303 end else if (a_e == 128) begin
304 z[31] <= a_s;
305 z[30:23] <= 255;
306 z[22:0] <= 0;
307 //if a is inf and signs don't match return nan
308 if ((b_e == 128) && (a_s != b_s)) begin
309 z[31] <= b_s;
310 z[30:23] <= 255;
311 z[22] <= 1;
312 z[21:0] <= 0;
313 end
314 state <= put_z;
315 //if b is inf return inf
316 end else if (b_e == 128) begin
317 z[31] <= b_s;
318 z[30:23] <= 255;
319 z[22:0] <= 0;
320 state <= put_z;
321 //if a and b are both zero return signed zero (-0 only if both are -0)
322 end else if ((($signed(a_e) == -127) && (a_m == 0)) && (($signed(b_e) == -127) && (b_m == 0))) begin
323 z[31] <= a_s & b_s;
324 z[30:23] <= b_e[7:0] + 127;
325 z[22:0] <= b_m[26:3];
326 state <= put_z;
327 //if a is zero return b
328 end else if (($signed(a_e) == -127) && (a_m == 0)) begin
329 z[31] <= b_s;
330 z[30:23] <= b_e[7:0] + 127;
331 z[22:0] <= b_m[26:3];
332 state <= put_z;
333 //if b is zero return a
334 end else if (($signed(b_e) == -127) && (b_m == 0)) begin
335 z[31] <= a_s;
336 z[30:23] <= a_e[7:0] + 127;
337 z[22:0] <= a_m[26:3];
338 state <= put_z;
339 end else begin
340 //Denormalised Number
341 if ($signed(a_e) == -127) begin
342 a_e <= -126;
343 end else begin
344 a_m[26] <= 1;
345 end
346 //Denormalised Number
347 if ($signed(b_e) == -127) begin
348 b_e <= -126;
349 end else begin
350 b_m[26] <= 1;
351 end
352 state <= align;
353 end
354 end
355
356 align:
357 begin
358 if ($signed(a_e) > $signed(b_e)) begin
359 b_e <= b_e + 1;
360 b_m <= b_m >> 1;
361 b_m[0] <= b_m[0] | b_m[1];
362 end else if ($signed(a_e) < $signed(b_e)) begin
363 a_e <= a_e + 1;
364 a_m <= a_m >> 1;
365 a_m[0] <= a_m[0] | a_m[1];
366 end else begin
367 state <= add_0;
368 end
369 end
370
371 add_0:
372 begin
373 z_e <= a_e;
374 if (a_s == b_s) begin
375 tot <= a_m + b_m;
376 z_s <= a_s;
377 end else begin
378 if (a_m >= b_m) begin
379 tot <= a_m - b_m;
380 z_s <= a_s;
381 end else begin
382 tot <= b_m - a_m;
383 z_s <= b_s;
384 end
385 end
386 state <= add_1;
387 end
388
389 add_1:
390 begin
391 if (tot[27]) begin
392 z_m <= tot[27:4];
393 guard <= tot[3];
394 round_bit <= tot[2];
395 sticky <= tot[1] | tot[0];
396 z_e <= z_e + 1;
397 end else begin
398 z_m <= tot[26:3];
399 guard <= tot[2];
400 round_bit <= tot[1];
401 sticky <= tot[0];
402 end
403 state <= normalise_1;
404 end
405
406 normalise_1:
407 begin
408 if (z_m[23] == 0 && $signed(z_e) > -126) begin
409 z_e <= z_e - 1;
410 z_m <= z_m << 1;
411 z_m[0] <= guard;
412 guard <= round_bit;
413 round_bit <= 0;
414 end else begin
415 state <= normalise_2;
416 end
417 end
418
419 normalise_2:
420 begin
421 if ($signed(z_e) < -126) begin
422 z_e <= z_e + 1;
423 z_m <= z_m >> 1;
424 guard <= z_m[0];
425 round_bit <= guard;
426 sticky <= sticky | round_bit;
427 end else begin
428 state <= round;
429 end
430 end
431
432 round:
433 begin
434 if (guard && (round_bit | sticky | z_m[0])) begin
435 z_m <= z_m + 1;
436 if (z_m == 24'hffffff) begin
437 z_e <=z_e + 1;
438 end
439 end
440 state <= pack;
441 end
442
443 pack:
444 begin
445 z[22 : 0] <= z_m[22:0];
446 z[30 : 23] <= z_e[7:0] + 127;
447 z[31] <= z_s;
448 if ($signed(z_e) == -126 && z_m[23] == 0) begin
449 z[30 : 23] <= 0;
450 end
451 if ($signed(z_e) == -126 && z_m[23:0] == 24'h0) begin
452 z[31] <= 1'b0; // FIX SIGN BUG: -a + a = +0.
453 end
454 //if overflow occurs, return inf
455 if ($signed(z_e) > 127) begin
456 z[22 : 0] <= 0;
457 z[30 : 23] <= 255;
458 z[31] <= z_s;
459 end
460 state <= put_z;
461 end
462
463 put_z:
464 begin
465 s_out_z_stb <= 1;
466 s_out_z <= z;
467 if (s_out_z_stb && out_z_ack) begin
468 s_out_z_stb <= 0;
469 state <= get_a;
470 end
471 end
472
473 endcase
474
475 if (rst == 1) begin
476 state <= get_a;
477 s_in_a_ack <= 0;
478 s_in_b_ack <= 0;
479 s_out_z_stb <= 0;
480 end
481
482 end
483 assign in_a_ack = s_in_a_ack;
484 assign in_b_ack = s_in_b_ack;
485 assign out_z_stb = s_out_z_stb;
486 assign out_z = s_out_z;
487
488 endmodule
489 """
490
if __name__ == "__main__":
    # Elaborate a single-precision adder and hand it to the nmigen CLI,
    # exposing the three handshaken interfaces as top-level ports.
    adder = FPADD(width=32)
    handshake_ports = [
        # operand a: data, strobe, acknowledge
        adder.in_a, adder.in_a_stb, adder.in_a_ack,
        # operand b: data, strobe, acknowledge
        adder.in_b, adder.in_b_stb, adder.in_b_ack,
        # result z: data, strobe, acknowledge
        adder.out_z, adder.out_z_stb, adder.out_z_ack,
    ]
    main(adder, ports=handshake_ports)