put exponent > 126 logic in FPNumBase, use it in norm module
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat
6 from nmigen.cli import main, verilog
7
8 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
9
10
class FPState(FPBase):
    """ base class for a single state of the adder FSM.

        state_from is the name of the FSM state that this object's
        action() is attached to.  the dictionaries passed to set_inputs
        and set_outputs are kept as-is and every entry is also made
        available as an attribute of the instance.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ records the inputs dict, exposing each key as an attribute
        """
        self.inputs = inputs
        for name in inputs:
            setattr(self, name, inputs[name])

    def set_outputs(self, outputs):
        """ records the outputs dict, exposing each key as an attribute
        """
        self.outputs = outputs
        for name in outputs:
            setattr(self, name, outputs[name])
24
25
class FPGetOpA(FPState):
    """ gets operand a

        state name is "get_a".  the operand number (self.a) is created
        here, from in_a and width.
    """

    def __init__(self, in_a, width):
        super().__init__("get_a")
        self.in_a = in_a
        self.a = FPNumIn(in_a, width)

    def action(self, m):
        # get_op is not defined in this file (presumably inherited from
        # FPBase); "get_b" is handed to it as the state to go to next.
        source, target = self.in_a, self.a
        self.get_op(m, source, target, "get_b")
37
38
class FPGetOpB(FPState):
    """ gets operand b

        in_b and b are not created here: they must be supplied through
        set_inputs / set_outputs before action() is called.
    """

    def action(self, m):
        # get_op is not defined in this file (presumably inherited from
        # FPBase); "special_cases" is handed to it as the next state.
        source, target = self.in_b, self.b
        self.get_op(m, source, target, "special_cases")
45
46
class FPAddSpecialCasesMod:
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        combinatorial module: out_do_z is raised whenever one of the
        special cases applies, in which case out_z holds the result.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, out_z, out_do_z):
        """ links module to inputs and outputs
        """
        # operands flow into the module...
        m.d.comb += self.in_a.copy(in_a)
        m.d.comb += self.in_b.copy(in_b)
        # ...the packed result and the "result valid" flag flow out
        m.d.comb += [
            out_z.v.eq(self.out_z.v),
            out_do_z.eq(self.out_do_z),
        ]

    def elaborate(self, platform):
        m = Module()
        a, b, z = self.in_a, self.in_b, self.out_z

        m.submodules.sc_in_a = a
        m.submodules.sc_in_b = b
        m.submodules.sc_out_z = z

        # helper flags: signs differ / mantissas identical
        s_nomatch = Signal()
        m_match = Signal()
        m.d.comb += [
            s_nomatch.eq(a.s != b.s),
            m_match.eq(a.m == b.m),
        ]

        # either operand NaN: result is NaN
        with m.If(a.is_nan | b.is_nan):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.nan(1)

        # XXX WEIRDNESS for FP16 non-canonical NaN handling
        # under review

        ## if a is zero and b is NaN return -b
        #with m.If(a.is_zero & (a.s==0) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))

        ## if b is zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))

        ## if a is -zero and b is NaN return -b
        #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))

        ## if b is -zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))

        # a is inf: result is inf with a's sign...
        with m.Elif(a.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(a.s)
            # ...unless b has exp_128 set and the signs differ: NaN
            with m.If(b.exp_128 & s_nomatch):
                m.d.comb += z.nan(1)

        # b is inf: result is inf with b's sign
        with m.Elif(b.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(b.s)

        # both zero: sign is the AND of the two signs, rest taken from b
        with m.Elif(a.is_zero & b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s & b.s, b.e, b.m[3:-1])

        # only a is zero: result is b
        with m.Elif(a.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(b.s, b.e, b.m[3:-1])

        # only b is zero: result is a
        with m.Elif(b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s, a.e, a.m[3:-1])

        # a == -b (signs differ, same mantissa and exponent): +ve zero
        with m.Elif(s_nomatch & m_match & (a.e == b.e)):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.zero(0)

        # no special case applies: carry on to the denormalised checks
        with m.Else():
            m.d.comb += self.out_do_z.eq(0)

        return m
150
151
class FPAddSpecialCases(FPState):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        FSM state "special_cases": wraps an FPAddSpecialCasesMod and
        holds the signals (out_z, out_do_z) that receive its outputs.
        self.z must be supplied through set_inputs / set_outputs.
    """

    def __init__(self, width):
        super().__init__("special_cases")
        self.mod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def action(self, m):
        special_result = self.out_z.v
        # a special case was detected: latch the result, skip to output
        with m.If(self.out_do_z):
            m.d.sync += self.z.v.eq(special_result) # only take the output
            m.next = "put_z"
        # otherwise continue down the normal add pipeline
        with m.Else():
            m.next = "denormalise"
170
171
class FPAddDeNorm(FPState):
    """ denormalised-number checks on both operands, then on to "align".

        self.a and self.b must be supplied through set_inputs /
        set_outputs.  denormalise is not defined in this file
        (presumably inherited from FPBase).
    """

    def action(self, m):
        # Denormalised Number checks
        m.next = "align"
        for operand in (self.a, self.b):
            self.denormalise(m, operand)
179
180
class FPAddAlignMulti(FPState):
    """ multi-cycle alignment of the two operands.

        self.a and self.b must be supplied through set_inputs /
        set_outputs.
    """

    def action(self, m):
        # NOTE: this does *not* do single-cycle multi-shifting,
        #       it *STAYS* in the align state until exponents match
        a, b = self.a, self.b

        # a has the larger exponent: b is the one shifted down
        with m.If(a.e > b.e):
            m.d.sync += b.shift_down()
        # b has the larger exponent: a is the one shifted down
        with m.Elif(a.e < b.e):
            m.d.sync += a.shift_down()
        # exponents now agree: alignment done, move to next stage.
        with m.Else():
            m.next = "add_0"
196
197
class FPAddAlignSingle(FPState):
    """ single-cycle alignment of the two operands.

        self.a and self.b must be supplied through set_inputs /
        set_outputs.
    """

    def action(self, m):
        # This one however (single-cycle) will do the shift
        # in one go.

        # XXX TODO: the shifter used here is quite expensive
        # having only one would be better
        a, b = self.a, self.b

        # signed exponent differences, same width as the exponent itself
        ediff = Signal((len(a.e), True), reset_less=True)
        ediffr = Signal((len(a.e), True), reset_less=True)
        m.d.comb += [
            ediff.eq(a.e - b.e),
            ediffr.eq(b.e - a.e),
        ]

        # exponent of a greater than b: shift b down by the difference
        with m.If(ediff > 0):
            m.d.sync += b.shift_down_multi(ediff)
        # exponent of b greater than a: shift a down by the difference
        with m.Elif(ediff < 0):
            m.d.sync += a.shift_down_multi(ediffr)

        # always a single cycle: move on regardless of which case applied
        m.next = "add_0"
218
219
class FPAddStage0(FPState):
    """ First stage of add.  covers same-sign (add) and subtract
        special-casing when mantissas are greater or equal, to
        give greatest accuracy.

        self.a, self.b, self.z and self.tot must be supplied through
        set_inputs / set_outputs.
    """

    def action(self, m):
        a, b, z = self.a, self.b, self.z
        # each mantissa with a constant 0 concatenated above it
        a_ext = Cat(a.m, 0)
        b_ext = Cat(b.m, 0)

        m.next = "add_1"
        m.d.sync += z.e.eq(a.e)
        # signs agree (both negative or both positive): add mantissas
        with m.If(a.s == b.s):
            m.d.sync += [
                self.tot.eq(a_ext + b_ext),
                z.s.eq(a.s),
            ]
        # signs differ, a's mantissa is at least b's: a - b, sign of a
        with m.Elif(a.m >= b.m):
            m.d.sync += [
                self.tot.eq(a_ext - b_ext),
                z.s.eq(a.s),
            ]
        # signs differ, b's mantissa is the larger: b - a, sign of b
        with m.Else():
            m.d.sync += [
                self.tot.eq(b_ext - a_ext),
                z.s.eq(b.s),
            ]
247
248
class FPAddStage1(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (the top bit of tot is kinda
        a carry bit)

        self.tot, self.z and self.of must be supplied through
        set_inputs / set_outputs.
    """

    def action(self, m):
        tot, z, of = self.tot, self.z, self.of

        m.next = "normalise_1"
        # top bit of tot gets set when the sum overflows: take the
        # mantissa from one bit higher up and bump the exponent
        with m.If(tot[-1]):
            m.d.sync += [
                z.m.eq(tot[4:]),
                of.m0.eq(tot[4]),
                of.guard.eq(tot[3]),
                of.round_bit.eq(tot[2]),
                of.sticky.eq(tot[1] | tot[0]),
                z.e.eq(z.e + 1),
            ]
        # top bit of tot clear: no adjustment needed
        with m.Else():
            m.d.sync += [
                z.m.eq(tot[3:]),
                of.m0.eq(tot[3]),
                of.guard.eq(tot[2]),
                of.round_bit.eq(tot[1]),
                of.sticky.eq(tot[0]),
            ]
275
276
class FPNorm1Mod:
    """ first normalisation step, as a combinatorial module.

        out_norm is raised when in_z.m_msbzero and in_z.exp_gt_n126 are
        both set.  when raised, out_z / out_of hold the input shifted up
        by one position; otherwise they are plain copies of the inputs.
    """

    def __init__(self, width):
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.out_of = Overflow()

    def setup(self, m, in_z, out_z, in_of, out_of, out_norm):
        """ links module to inputs and outputs
        """
        # number: into the module, then back out
        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += out_z.copy(self.out_z)
        # overflow bits: into the module, then back out
        m.d.comb += self.in_of.copy(in_of)
        m.d.comb += out_of.copy(self.out_of)
        # "still normalising" flag
        m.d.comb += out_norm.eq(self.out_norm)

    def elaborate(self, platform):
        m = Module()
        in_z, out_z = self.in_z, self.out_z
        in_of, out_of = self.in_of, self.out_of

        m.submodules.norm1_in_overflow = in_of
        m.submodules.norm1_out_overflow = out_of
        m.submodules.norm1_in_z = in_z
        m.submodules.norm1_out_z = out_z

        # default: pass everything straight through
        m.d.comb += out_z.copy(in_z)
        m.d.comb += out_of.copy(in_of)

        needs_shift = in_z.m_msbzero & in_z.exp_gt_n126
        m.d.comb += self.out_norm.eq(needs_shift)
        with m.If(self.out_norm):
            # order matters here: bit 0 is overridden after the shift
            m.d.comb += [
                out_z.e.eq(in_z.e - 1),            # DECREASE exponent
                out_z.m.eq(in_z.m << 1),           # shift mantissa UP
                out_z.m[0].eq(in_of.guard),        # steal guard (was tot[2])
                out_of.guard.eq(in_of.round_bit),  # round (was tot[1])
                out_of.round_bit.eq(0),            # reset round bit
                out_of.m0.eq(in_of.guard),
            ]

        return m
316
317
class FPNorm1(FPState):
    """ FSM state "normalise_1": wraps an FPNorm1Mod.

        the state is only left (for "normalise_2") once out_norm is
        low; self.z and self.of are re-latched from the module outputs
        on every cycle spent in this state.
    """

    def __init__(self, width):
        super().__init__("normalise_1")
        self.mod = FPNorm1Mod(width)
        self.out_norm = Signal(reset_less=True)
        # NOTE(review): built without the second FPNumBase argument,
        # unlike the in_z/out_z ports inside self.mod - confirm intended
        self.out_z = FPNumBase(width)
        self.out_of = Overflow()

    def action(self, m):
        latched_of = self.of.copy(self.out_of)
        latched_z = self.z.copy(self.out_z)
        m.d.sync += latched_of
        m.d.sync += latched_z
        # nothing left to shift: move on
        with m.If(~self.out_norm):
            m.next = "normalise_2"
332
333
class FPNorm2(FPState):
    """ second normalisation step.

        self.z and self.of must be supplied through set_inputs /
        set_outputs.  normalise_2 is not defined in this file
        (presumably inherited from FPBase); "round" is handed to it
        as the state to go to next.
    """

    def action(self, m):
        number, overflow = self.z, self.of
        self.normalise_2(m, number, overflow, "round")
338
339
class FPRoundMod:
    """ rounding, as a combinatorial module.

        out_z is a copy of in_z, except that when in_roundz is set the
        mantissa is incremented, and the exponent is incremented as well
        if the incoming mantissa equals in_z.m1s.
    """

    def __init__(self, width):
        self.in_roundz = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)

    def setup(self, m, in_z, out_z, in_of):
        """ links module to inputs and outputs
        """
        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += out_z.copy(self.out_z)
        # only the roundz flag of the overflow object is consumed
        m.d.comb += self.in_roundz.eq(in_of.roundz)

    def elaborate(self, platform):
        m = Module()
        in_z, out_z = self.in_z, self.out_z

        # default: pass the number straight through
        m.d.comb += out_z.copy(in_z)
        with m.If(self.in_roundz):
            m.d.comb += out_z.m.eq(in_z.m + 1)      # mantissa rounds up
            with m.If(in_z.m == in_z.m1s):          # all 1s
                m.d.comb += out_z.e.eq(in_z.e + 1)  # exponent up
        return m
362
363
class FPRound(FPState):
    """ FSM state "round": wraps an FPRoundMod and latches its result
        into self.z before moving on to "corrections".
    """

    def __init__(self, width):
        super().__init__("round")
        self.mod = FPRoundMod(width)
        # NOTE(review): built without the second FPNumBase argument,
        # unlike the in_z/out_z ports inside self.mod - confirm intended
        self.out_z = FPNumBase(width)

    def action(self, m):
        rounded = self.out_z
        m.d.sync += self.z.copy(rounded)
        m.next = "corrections"
374
375
class FPCorrectionsMod:
    """ corrections, as a combinatorial module.

        out_z is a copy of in_z, except that the exponent is replaced
        by in_z.N127 when in_z.is_denormalised is set.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, out_z):
        """ links module to inputs and outputs
        """
        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += out_z.copy(self.out_z)

    def elaborate(self, platform):
        m = Module()
        in_z, out_z = self.in_z, self.out_z

        m.submodules.corr_in_z = in_z
        m.submodules.corr_out_z = out_z

        # default: pass the number straight through
        m.d.comb += out_z.copy(in_z)
        with m.If(in_z.is_denormalised):
            m.d.comb += out_z.e.eq(in_z.N127)

        #        with m.If(self.in_z.is_overflowed):
        #            m.d.comb += self.out_z.inf(self.in_z.s)
        #        with m.Else():
        #            m.d.comb += self.out_z.create(self.in_z.s, self.in_z.e, self.in_z.m)
        return m
401
402
class FPCorrections(FPState):
    """ FSM state "corrections": wraps an FPCorrectionsMod and latches
        its result into self.z before moving on to "pack".
    """

    def __init__(self, width):
        super().__init__("corrections")
        self.mod = FPCorrectionsMod(width)
        # NOTE(review): a plain FPNumBase built without the second
        # argument, whereas self.mod uses FPNumOut(width, False) for
        # its ports - confirm intended
        self.out_z = FPNumBase(width)

    def action(self, m):
        corrected = self.out_z
        m.d.sync += self.z.copy(corrected)
        m.next = "pack"
413
414
class FPPackMod:
    """ packing of the result, as a combinatorial module.

        out_z is set to infinity (with in_z's sign) when
        in_z.is_overflowed is set, and is otherwise created from in_z's
        sign, exponent and mantissa.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, out_z):
        """ links module to inputs and outputs
        """
        m.d.comb += self.in_z.copy(in_z)
        # only the packed value leaves the module
        m.d.comb += out_z.v.eq(self.out_z.v)

    def elaborate(self, platform):
        m = Module()
        in_z, out_z = self.in_z, self.out_z

        m.submodules.pack_in_z = in_z
        with m.If(in_z.is_overflowed):
            m.d.comb += out_z.inf(in_z.s)
        with m.Else():
            m.d.comb += out_z.create(in_z.s, in_z.e, in_z.m)
        return m
435
436
class FPPack(FPState):
    """ FSM state "pack": wraps an FPPackMod and latches the packed
        value into self.z.v before moving on to "put_z".
    """

    def __init__(self, width):
        super().__init__("pack")
        self.mod = FPPackMod(width)
        self.out_z = FPNumOut(width, False)

    def action(self, m):
        packed = self.out_z.v
        m.d.sync += self.z.v.eq(packed)
        m.next = "put_z"
447
448
class FPPutZ(FPState):
    """ output stage.

        self.z and self.out_z must be supplied through set_inputs /
        set_outputs.  put_z is not defined in this file (presumably
        inherited from FPBase); "get_a" is handed to it as the state
        to go to next.
    """

    def action(self, m):
        result, port = self.z, self.out_z
        self.put_z(m, result, port, "get_a")
453
454
class FPADD:
    """ floating-point adder, built as an FSM with one FPState-derived
        object per state.

        width is passed on to every operand / number object that is
        created.  single_cycle selects FPAddAlignSingle rather than
        FPAddAlignMulti for the "align" state.

        ports: in_a, in_b (operands) and out_z (result), all FPOp.
    """

    def __init__(self, width, single_cycle=False):
        self.width = width
        self.single_cycle = single_cycle

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

        self.states = []

    def add_state(self, state):
        """ registers a state object and hands it straight back, so
            that creation and registration fit on one line.
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()

        # every state is rebuilt from scratch below, so start from an
        # empty list: entries left over from an earlier call would
        # otherwise be registered a second time and define duplicate
        # FSM states.
        self.states = []

        # Latches
        b = FPNumIn(self.in_b, self.width)
        z = FPNumOut(self.width, False)

        m.submodules.fpnum_b = b
        m.submodules.fpnum_z = z

        w = z.m_width + 4
        tot = Signal(w, reset_less=True) # sticky/round/guard, {mantissa} result, 1 overflow

        of = Overflow()
        m.submodules.overflow = of

        # operand a is created by (and owned by) its get-state
        geta = self.add_state(FPGetOpA(self.in_a, self.width))
        a = geta.a
        # XXX m.d.comb += a.v.eq(self.in_a.v) # links in_a to a
        m.submodules.fpnum_a = a

        getb = self.add_state(FPGetOpB("get_b"))
        getb.set_inputs({"in_b": self.in_b})
        getb.set_outputs({"b": b})
        # XXX m.d.comb += b.v.eq(self.in_b.v) # links in_b to b

        sc = self.add_state(FPAddSpecialCases(self.width))
        sc.set_inputs({"a": a, "b": b})
        sc.set_outputs({"z": z})
        sc.mod.setup(m, a, b, sc.out_z, sc.out_do_z)
        m.submodules.specialcases = sc.mod

        dn = self.add_state(FPAddDeNorm("denormalise"))
        dn.set_inputs({"a": a, "b": b})
        dn.set_outputs({"a": a, "b": b}) # XXX outputs same as inputs

        if self.single_cycle:
            alm = self.add_state(FPAddAlignSingle("align"))
        else:
            alm = self.add_state(FPAddAlignMulti("align"))
        alm.set_inputs({"a": a, "b": b})
        alm.set_outputs({"a": a, "b": b}) # XXX outputs same as inputs

        add0 = self.add_state(FPAddStage0("add_0"))
        add0.set_inputs({"a": a, "b": b})
        add0.set_outputs({"z": z, "tot": tot})

        add1 = self.add_state(FPAddStage1("add_1"))
        add1.set_inputs({"tot": tot, "z": z}) # Z input passes through
        add1.set_outputs({"z": z, "of": of})  # XXX Z as output

        n1 = self.add_state(FPNorm1(self.width))
        n1.set_inputs({"z": z, "of": of})  # XXX Z as output
        n1.set_outputs({"z": z})  # XXX Z as output
        n1.mod.setup(m, z, n1.out_z, of, n1.out_of, n1.out_norm)
        m.submodules.normalise_1 = n1.mod

        n2 = self.add_state(FPNorm2("normalise_2"))
        n2.set_inputs({"z": z, "of": of})  # XXX Z as output
        n2.set_outputs({"z": z})  # XXX Z as output

        rn = self.add_state(FPRound(self.width))
        rn.set_inputs({"z": z, "of": of})  # XXX Z as output
        rn.set_outputs({"z": z})  # XXX Z as output
        rn.mod.setup(m, z, rn.out_z, of)
        m.submodules.roundz = rn.mod

        cor = self.add_state(FPCorrections(self.width))
        cor.set_inputs({"z": z})  # XXX Z as output
        cor.set_outputs({"z": z})  # XXX Z as output
        cor.mod.setup(m, z, cor.out_z)
        m.submodules.corrections = cor.mod

        pa = self.add_state(FPPack(self.width))
        pa.set_inputs({"z": z})  # XXX Z as output
        pa.set_outputs({"z": z})  # XXX Z as output
        pa.mod.setup(m, z, pa.out_z)
        m.submodules.pack = pa.mod

        pz = self.add_state(FPPutZ("put_z"))
        pz.set_inputs({"z": z})
        pz.set_outputs({"out_z": self.out_z})

        # one FSM state per registered object, in registration order
        # (so the first one registered, "get_a", is defined first)
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
566
567
if __name__ == "__main__":
    # build a 32-bit adder and hand it to the nmigen command line,
    # exposing the ports of both operands and of the result
    alu = FPADD(width=32)
    ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=ports)
571
572
    # works... but don't use, just do "python fname.py convert -t v"
    #print (verilog.convert(alu, ports=alu.in_a.ports() + \
    #                                   alu.in_b.ports() + \
    #                                   alu.out_z.ports()))