# split special cases into separate module and use it
# [ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat
6 from nmigen.cli import main, verilog
7
8 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
9
10
class FPState(FPBase):
    """ One named state of the floating-point FSM.

        state_from is the name of the FSM state that this object
        implements: FPADD.get_fragment opens m.State(state_from) and
        calls the object's action(m) inside it.  action() itself is
        supplied by the subclasses.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ remembers the inputs dict and also exposes every entry as
            an attribute of the same name ({"a": x} gives self.a)
        """
        self.inputs = inputs
        for name in inputs:
            setattr(self, name, inputs[name])

    def set_outputs(self, outputs):
        """ remembers the outputs dict and also exposes every entry as
            an attribute of the same name.  An output that shares its
            name with an input replaces that attribute.
        """
        self.outputs = outputs
        for name in outputs:
            setattr(self, name, outputs[name])
24
25
class FPGetOpA(FPState):
    """ gets operand a (FSM state "get_a")

        Owns the FPNumIn that receives the operand; callers pick it up
        through the "a" attribute.
    """

    def __init__(self, in_a, width):
        super().__init__("get_a")
        self.in_a = in_a
        self.a = FPNumIn(in_a, width)

    def action(self, m):
        # get_op is not defined in this file (presumably inherited via
        # FPBase); it is told to continue with "get_b" afterwards.
        following = "get_b"
        self.get_op(m, self.in_a, self.a, following)
37
38
class FPGetOpB(FPState):
    """ gets operand b

        in_b and b are not created here: they must be attached with
        set_inputs / set_outputs before action() is called.
    """

    def action(self, m):
        # once b has been fetched, the special-case checks come next
        following = "special_cases"
        self.get_op(m, self.in_b, self.b, following)
45
46
class FPAddSpecialCasesMod:
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        Combinatorial only.  out_do_z is set to 1 when one of the
        special cases below applies (out_z then carries the answer)
        and to 0 when the normal add path has to be taken.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, out_z, out_do_z):
        """ links module to inputs and outputs

            m is the *parent* module: the copies into this module's
            in_a / in_b and out of its out_z.v / out_do_z are added to
            the parent's combinatorial domain.
        """
        comb = m.d.comb
        comb += self.in_a.copy(in_a)
        comb += self.in_b.copy(in_b)
        comb += out_z.v.eq(self.out_z.v)
        comb += out_do_z.eq(self.out_do_z)

    def elaborate(self, platform):
        m = Module()

        # short names for the operands, the result and the "use z" flag
        a, b, z = self.in_a, self.in_b, self.out_z
        do_z = self.out_do_z

        m.submodules.sc_in_a = a
        m.submodules.sc_in_b = b
        m.submodules.sc_out_z = z

        # signs differ
        s_nomatch = Signal()
        m.d.comb += s_nomatch.eq(a.s != b.s)

        # mantissas identical
        m_match = Signal()
        m.d.comb += m_match.eq(a.m == b.m)

        # NOTE: everything below is one single If/Elif/.../Else chain,
        # so the first matching case wins.

        # if a is NaN or b is NaN return NaN
        with m.If(a.is_nan | b.is_nan):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.nan(1)

        # XXX WEIRDNESS for FP16 non-canonical NaN handling
        # under review

        ## if a is zero and b is NaN return -b
        #with m.If(a.is_zero & (a.s==0) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))

        ## if b is zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))

        ## if a is -zero and b is NaN return -b
        #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))

        ## if b is -zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))

        # if a is inf return inf (or NaN)
        with m.Elif(a.is_inf):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.inf(a.s)
            # if a is inf and signs don't match return NaN
            # (this later assignment overrides the inf just above)
            with m.If(b.exp_128 & s_nomatch):
                m.d.comb += z.nan(1)

        # if b is inf return inf
        with m.Elif(b.is_inf):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.inf(b.s)

        # if a is zero and b zero return signed-a/b
        with m.Elif(a.is_zero & b.is_zero):
            m.d.comb += do_z.eq(1)
            # sign is only set when both zeros are negative
            m.d.comb += z.create(a.s & b.s, b.e, b.m[3:-1])

        # if a is zero return b
        with m.Elif(a.is_zero):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.create(b.s, b.e, b.m[3:-1])

        # if b is zero return a
        with m.Elif(b.is_zero):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.create(a.s, a.e, a.m[3:-1])

        # if a equal to -b return zero (+ve zero)
        with m.Elif(s_nomatch & m_match & (a.e == b.e)):
            m.d.comb += do_z.eq(1)
            m.d.comb += z.zero(0)

        # Denormalised Number checks
        # (no special case matched: tell the FSM to carry on)
        with m.Else():
            m.d.comb += do_z.eq(0)

        return m
150
151
class FPAddSpecialCases(FPState):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        FSM wrapper ("special_cases" state) around FPAddSpecialCasesMod.
        out_z / out_do_z are the signals the module's results get
        copied into (see FPAddSpecialCasesMod.setup).
    """

    def __init__(self, width):
        super().__init__("special_cases")
        self.mod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def action(self, m):
        # a special case was detected: latch its packed result and
        # jump straight to the output state
        with m.If(self.out_do_z):
            m.d.sync += self.z.v.eq(self.out_z.v) # only take the output
            m.next = "put_z"
        # otherwise continue down the normal add path
        with m.Else():
            m.next = "denormalise"
170
171
class FPAddDeNorm(FPState):
    """ Denormalised Number checks on both operands.

        a and b must have been attached with set_inputs beforehand.
    """

    def action(self, m):
        m.next = "align"
        # same helper applied to a first, then to b
        for operand in (self.a, self.b):
            self.denormalise(m, operand)
179
180
class FPAddAlignMulti(FPState):
    """ Multi-cycle exponent alignment.

        NOTE: this does *not* do single-cycle multi-shifting,
        it *STAYS* in the align state until exponents match
    """

    def action(self, m):
        a, b = self.a, self.b

        # exponent of a greater than b: shift b down
        with m.If(a.e > b.e):
            m.d.sync += b.shift_down()
        # exponent of b greater than a: shift a down
        with m.Elif(a.e < b.e):
            m.d.sync += a.shift_down()
        # exponents equal: move to next stage.
        with m.Else():
            m.next = "add_0"
196
197
class FPAddAlignSingle(FPState):
    """ Single-cycle exponent alignment: the operand with the smaller
        exponent is shifted down by the whole difference in one go.

        XXX TODO: the shifter used here is quite expensive
        having only one would be better
    """

    def action(self, m):
        a, b = self.a, self.b

        # both differences use a signed shape as wide as the exponent
        eshape = (len(a.e), True)
        ediff = Signal(eshape, reset_less=True)   # a.e - b.e
        ediffr = Signal(eshape, reset_less=True)  # b.e - a.e
        m.d.comb += [
            ediff.eq(a.e - b.e),
            ediffr.eq(b.e - a.e),
        ]

        # exponent of a greater than b: shift b down
        with m.If(ediff > 0):
            m.d.sync += b.shift_down_multi(ediff)
        # exponent of b greater than a: shift a down
        with m.Elif(ediff < 0):
            m.d.sync += a.shift_down_multi(ediffr)

        # always a single cycle, whichever branch was taken
        m.next = "add_0"
218
219
class FPAddStage0(FPState):
    """ First stage of add.  covers same-sign (add) and subtract
        special-casing when mantissas are greater or equal, to
        give greatest accuracy.

        Writes tot, z.s and z.e; the result exponent is always taken
        from a.
    """

    def action(self, m):
        a, b, z, tot = self.a, self.b, self.z, self.tot

        # both mantissas widened by one (zero) top bit
        am0 = Cat(a.m, 0)
        bm0 = Cat(b.m, 0)

        m.next = "add_1"
        m.d.sync += z.e.eq(a.e)

        # same-sign (both negative or both positive) add mantissas
        with m.If(a.s == b.s):
            m.d.sync += tot.eq(am0 + bm0)
            m.d.sync += z.s.eq(a.s)
        # a mantissa greater than b, use a
        with m.Elif(a.m >= b.m):
            m.d.sync += tot.eq(am0 - bm0)
            m.d.sync += z.s.eq(a.s)
        # b mantissa greater than a, use b
        with m.Else():
            m.d.sync += tot.eq(bm0 - am0)
            m.d.sync += z.s.eq(b.s)
247
248
class FPAddStage1(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (tot[27] is kinda a carry bit)

        The code tests the top bit of tot (tot[-1]); the "27" in the
        comments is that bit's index for one particular width.
    """

    def action(self, m):
        tot, z, of = self.tot, self.z, self.of

        m.next = "normalise_1"

        # tot[27] gets set when the sum overflows. shift result down
        with m.If(tot[-1]):
            m.d.sync += z.m.eq(tot[4:])
            m.d.sync += of.m0.eq(tot[4])
            m.d.sync += of.guard.eq(tot[3])
            m.d.sync += of.round_bit.eq(tot[2])
            # the two bits falling off the bottom both feed sticky
            m.d.sync += of.sticky.eq(tot[1] | tot[0])
            m.d.sync += z.e.eq(z.e + 1)
        # tot[27] zero case
        with m.Else():
            m.d.sync += z.m.eq(tot[3:])
            m.d.sync += of.m0.eq(tot[3])
            m.d.sync += of.guard.eq(tot[2])
            m.d.sync += of.round_bit.eq(tot[1])
            m.d.sync += of.sticky.eq(tot[0])
275
276
class FPNorm1(FPState):
    """ First normalisation state: hands z and the overflow object to
        normalise_1 (not defined in this file), continuing with
        "normalise_2" afterwards.
    """

    def action(self, m):
        following = "normalise_2"
        self.normalise_1(m, self.z, self.of, following)
281
282
class FPNorm2(FPState):
    """ Second normalisation state: hands z and the overflow object to
        normalise_2 (not defined in this file), continuing with
        "round" afterwards.
    """

    def action(self, m):
        following = "round"
        self.normalise_2(m, self.z, self.of, following)
287
288
class FPRoundMod:
    """ Combinatorial rounding: out_z is in_z, with the mantissa
        incremented when in_roundz is set (and the exponent too, if
        the mantissa was all 1s).
    """

    def __init__(self, width):
        self.in_roundz = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)

    def setup(self, m, in_z, out_z, in_of):
        """ links module to inputs and outputs

            m is the parent module; only the roundz flag of in_of
            is used.
        """
        comb = m.d.comb
        comb += self.in_z.copy(in_z)
        comb += out_z.copy(self.out_z)
        comb += self.in_roundz.eq(in_of.roundz)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z

        # default: pass the number through untouched
        m.d.comb += dst.copy(src)

        with m.If(self.in_roundz):
            m.d.comb += dst.m.eq(src.m + 1)      # mantissa rounds up
            with m.If(src.m == src.m1s):         # all 1s
                m.d.comb += dst.e.eq(src.e + 1)  # exponent up

        return m
311
312
class FPRound(FPState):
    """ FSM wrapper ("round" state) around FPRoundMod: latches the
        module's rounded result into z.
    """

    def __init__(self, width):
        super().__init__("round")
        self.mod = FPRoundMod(width)
        # NOTE(review): built as FPNumBase(width), whereas the module
        # uses FPNumBase(width, False) for its own out_z -- confirm
        # that the differing second argument is intended.
        self.out_z = FPNumBase(width)

    def action(self, m):
        m.next = "corrections"
        m.d.sync += self.z.copy(self.out_z)
323
324
class FPCorrectionsMod:
    """ Combinatorial post-rounding correction: out_z is in_z, except
        that the exponent is replaced by in_z.N127 when in_z reports
        is_denormalised.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, out_z):
        """ links module to inputs and outputs

            m is the parent module.
        """
        comb = m.d.comb
        comb += self.in_z.copy(in_z)
        comb += out_z.copy(self.out_z)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z

        m.submodules.corr_in_z = src
        m.submodules.corr_out_z = dst

        # default: pass the number through untouched
        m.d.comb += dst.copy(src)
        with m.If(src.is_denormalised):
            m.d.comb += dst.e.eq(src.N127)

        #        with m.If(self.in_z.is_overflowed):
        #            m.d.comb += self.out_z.inf(self.in_z.s)
        #        with m.Else():
        #            m.d.comb += self.out_z.create(self.in_z.s, self.in_z.e, self.in_z.m)
        return m
350
351
class FPCorrections(FPState):
    """ FSM wrapper ("corrections" state) around FPCorrectionsMod:
        latches the module's corrected result into z.
    """

    def __init__(self, width):
        super().__init__("corrections")
        self.mod = FPCorrectionsMod(width)
        # NOTE(review): built as FPNumBase(width), whereas the module
        # uses FPNumOut(width, False) for its own out_z -- confirm
        # that the difference is intended.
        self.out_z = FPNumBase(width)

    def action(self, m):
        m.next = "pack"
        m.d.sync += self.z.copy(self.out_z)
362
363
class FPPackMod:
    """ Combinatorial packing of the result: infinity (keeping the
        sign) when in_z reports is_overflowed, otherwise the number
        created from in_z's sign, exponent and mantissa.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, out_z):
        """ links module to inputs and outputs

            m is the parent module.  Only the packed value (v) of
            this module's out_z is passed on to the caller's out_z.
        """
        comb = m.d.comb
        comb += self.in_z.copy(in_z)
        comb += out_z.v.eq(self.out_z.v)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z

        m.submodules.pack_in_z = src

        with m.If(src.is_overflowed):
            m.d.comb += dst.inf(src.s)
        with m.Else():
            m.d.comb += dst.create(src.s, src.e, src.m)
        return m
384
385
class FPPack(FPState):
    """ FSM wrapper ("pack" state) around FPPackMod: latches the
        module's packed value into z.v.
    """

    def __init__(self, width):
        super().__init__("pack")
        self.mod = FPPackMod(width)
        self.out_z = FPNumOut(width, False)

    def action(self, m):
        m.next = "put_z"
        m.d.sync += self.z.v.eq(self.out_z.v)
396
397
class FPPutZ(FPState):
    """ Output state: hands z and out_z to put_z (not defined in this
        file), then returns the FSM to "get_a" for the next operation.
    """

    def action(self, m):
        following = "get_a"
        self.put_z(m, self.z, self.out_z, following)
402
403
class FPADD:
    """ Floating-point adder, built as an FSM whose states are
        FPState objects.

        width is handed to every FPOp / FPNum* / stage constructor.
        single_cycle selects the alignment stage: FPAddAlignSingle
        when True, FPAddAlignMulti (one shift per clock) when False.

        Ports: in_a and in_b (operands), out_z (result), all FPOp.
    """

    def __init__(self, width, single_cycle=False):
        self.width = width
        self.single_cycle = single_cycle

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

        self.states = []

    def add_state(self, state):
        """ registers an FSM state object and returns it, so that the
            call can be used inline: s = self.add_state(Foo(...))
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            Builds all state objects, wires the combinatorial helper
            modules into the returned Module, then emits one FSM state
            per registered state object, in registration order
            ("get_a" is therefore the first state declared).
        """
        m = Module()

        # Start from an empty list on every call: states used to
        # accumulate across calls, so elaborating the same FPADD a
        # second time would have declared every FSM state twice.
        self.states = []

        # Latches (a is created by, and owned by, the get_a state)
        b = FPNumIn(self.in_b, self.width)
        z = FPNumOut(self.width, False)

        m.submodules.fpnum_b = b
        m.submodules.fpnum_z = z

        w = z.m_width + 4
        tot = Signal(w, reset_less=True) # sticky/round/guard, {mantissa} result, 1 overflow

        of = Overflow()
        m.submodules.overflow = of

        geta = self.add_state(FPGetOpA(self.in_a, self.width))
        a = geta.a
        m.submodules.fpnum_a = a

        getb = self.add_state(FPGetOpB("get_b"))
        getb.set_inputs({"in_b": self.in_b})
        getb.set_outputs({"b": b})

        sc = self.add_state(FPAddSpecialCases(self.width))
        sc.set_inputs({"a": a, "b": b})
        sc.set_outputs({"z": z})
        sc.mod.setup(m, a, b, sc.out_z, sc.out_do_z)
        m.submodules.specialcases = sc.mod

        dn = self.add_state(FPAddDeNorm("denormalise"))
        dn.set_inputs({"a": a, "b": b})
        dn.set_outputs({"a": a, "b": b}) # XXX outputs same as inputs

        # both alignment variants implement the same "align" state
        if self.single_cycle:
            alm = self.add_state(FPAddAlignSingle("align"))
        else:
            alm = self.add_state(FPAddAlignMulti("align"))
        alm.set_inputs({"a": a, "b": b})
        alm.set_outputs({"a": a, "b": b}) # XXX outputs same as inputs

        add0 = self.add_state(FPAddStage0("add_0"))
        add0.set_inputs({"a": a, "b": b})
        add0.set_outputs({"z": z, "tot": tot})

        add1 = self.add_state(FPAddStage1("add_1"))
        add1.set_inputs({"tot": tot, "z": z}) # Z input passes through
        add1.set_outputs({"z": z, "of": of})  # XXX Z as output

        n1 = self.add_state(FPNorm1("normalise_1"))
        n1.set_inputs({"z": z, "of": of})  # XXX Z as output
        n1.set_outputs({"z": z})  # XXX Z as output

        n2 = self.add_state(FPNorm2("normalise_2"))
        n2.set_inputs({"z": z, "of": of})  # XXX Z as output
        n2.set_outputs({"z": z})  # XXX Z as output

        rn = self.add_state(FPRound(self.width))
        rn.set_inputs({"z": z, "of": of})  # XXX Z as output
        rn.set_outputs({"z": z})  # XXX Z as output
        rn.mod.setup(m, z, rn.out_z, of)
        m.submodules.roundz = rn.mod

        cor = self.add_state(FPCorrections(self.width))
        cor.set_inputs({"z": z})  # XXX Z as output
        cor.set_outputs({"z": z})  # XXX Z as output
        cor.mod.setup(m, z, cor.out_z)
        m.submodules.corrections = cor.mod

        pa = self.add_state(FPPack(self.width))
        pa.set_inputs({"z": z})  # XXX Z as output
        pa.set_outputs({"z": z})  # XXX Z as output
        pa.mod.setup(m, z, pa.out_z)
        m.submodules.pack = pa.mod

        pz = self.add_state(FPPutZ("put_z"))
        pz.set_inputs({"z": z})
        pz.set_outputs({"out_z": self.out_z})

        # one FSM state per registered object; each object emits its
        # own logic and next-state choice from action()
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
513
514
if __name__ == "__main__":
    # 32-bit adder, exposing both operands and the result as ports
    alu = FPADD(width=32)
    ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=ports)
518
519
    # works... but don't use, just do "python fname.py convert -t v"
    #print (verilog.convert(alu, ports=alu.in_a.ports() + \
    #                                  alu.in_b.ports() + \
    #                                  alu.out_z.ports()))