split out add align into separate module
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Mux, Array, Const
6 from nmigen.lib.coding import PriorityEncoder
7 from nmigen.cli import main, verilog
8 from math import log
9
10 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
11 from fpbase import MultiShiftRMerge, Trigger
12 from singlepipe import (ControlBase, StageChain, UnbufferedPipeline,
13 PassThroughStage)
14 from multipipe import CombMuxOutPipe
15 from multipipe import PriorityCombMuxInPipe
16
17 from fpbase import FPState, FPID
18 from fpcommon.getop import (FPGetOpMod, FPGetOp, FPNumBase2Ops, FPADDBaseData,
19 FPGet2OpMod, FPGet2Op)
20 from fpcommon.denorm import (FPSCData, FPAddDeNormMod, FPAddDeNorm)
21 from fpcommon.postcalc import FPAddStage1Data
22 from fpcommon.postnormalise import (FPNorm1Data, FPNorm1ModSingle,
23 FPNorm1ModMulti, FPNorm1Single, FPNorm1Multi)
24 from fpcommon.roundz import (FPRoundData, FPRoundMod, FPRound)
25 from fpcommon.corrections import (FPCorrectionsMod, FPCorrections)
26 from fpcommon.pack import (FPPackData, FPPackMod, FPPack)
27 from fpcommon.normtopack import FPNormToPack
28 from fpcommon.putz import (FPPutZ, FPPutZIdx)
29
30 from fpadd.specialcases import (FPAddSpecialCasesMod, FPAddSpecialCases,
31 FPAddSpecialCasesDeNorm)
32 from fpadd.align import (FPAddAlignMulti, FPAddAlignMultiMod, FPNumIn2Ops,
33 FPAddAlignSingleMod, FPAddAlignSingle)
34
35
class FPAddAlignSingleAdd(FPState, UnbufferedPipeline):
    """ Combined "align + add" stage.

        Chains FPAddAlignSingleMod, FPAddStage0Mod and FPAddStage1Mod
        together and exposes the result of the last one.  The object is
        handed to UnbufferedPipeline as its own stage, which is why it
        provides ispec, ospec, setup and process itself.
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        # ispec/ospec read these two attributes, so they are assigned
        # before this object is passed to the pipeline as its stage.
        self.width = width
        self.id_wid = id_wid
        UnbufferedPipeline.__init__(self, self) # pipeline is its own stage
        self.a1o = self.ospec()

    def ispec(self):
        """ input format: FPSCData
        """
        return FPSCData(self.width, self.id_wid)

    def ospec(self):
        """ output format: same as the AddStage1 ospec
        """
        return FPAddStage1Data(self.width, self.id_wid) # AddStage1 ospec

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        # chain AddAlignSingle, AddStage0 and AddStage1 (in that order)
        stages = [FPAddAlignSingleMod(self.width, self.id_wid),
                  FPAddStage0Mod(self.width, self.id_wid),
                  FPAddStage1Mod(self.width, self.id_wid)]
        StageChain(stages).setup(m, i)

        # the output of the chain is the output of its last module
        self.o = stages[-1].o

    def process(self, i):
        """ returns the chained output (the argument is not used)
        """
        return self.o

    def action(self, m):
        """ FSM action: latch the chain output, then go to "normalise_1"
        """
        latched = self.process(None)
        m.d.sync += self.a1o.eq(latched)
        m.next = "normalise_1"
71
72
class FPAddStage0Data:
    """ Data record produced by the first add stage.
    """

    def __init__(self, width, id_wid):
        self.z = FPNumBase(width, False)
        self.out_do_z = Signal(reset_less=True)
        self.oz = Signal(width, reset_less=True)
        # mantissa total: four bits wider than the mantissa of z
        self.tot = Signal(self.z.m_width + 4, reset_less=True)
        self.mid = Signal(id_wid, reset_less=True)

    def eq(self, i):
        """ returns the list of assignments copying every field from i
        """
        fields = ("z", "out_do_z", "oz", "tot", "mid")
        return [getattr(self, name).eq(getattr(i, name)) for name in fields]
85
86
class FPAddStage0Mod:
    """ Combinatorial module for the first stage of add.

        Same-sign operands have their mantissas added; otherwise the
        smaller mantissa is subtracted from the larger one and the
        result takes the sign of the operand with the larger mantissa.
    """

    def __init__(self, width, id_wid):
        self.width = width
        self.id_wid = id_wid
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ input format: FPSCData
        """
        return FPSCData(self.width, self.id_wid)

    def ospec(self):
        """ output format: FPAddStage0Data
        """
        return FPAddStage0Data(self.width, self.id_wid)

    def process(self, i):
        """ returns the module output (the argument is not used)
        """
        return self.o

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        m.submodules.add0 = self
        m.d.comb += self.i.eq(i)

    def elaborate(self, platform):
        """ creates the HDL for the stage and returns the Module
        """
        m = Module()
        comb = m.d.comb
        a = self.i.a
        b = self.i.b
        z = self.o.z

        m.submodules.add0_in_a = a
        m.submodules.add0_in_b = b
        m.submodules.add0_out_z = z

        # store intermediate tests (and zero-extended mantissas).
        # the local names are kept: they become the signal names.
        seq = Signal(reset_less=True)
        mge = Signal(reset_less=True)
        am0 = Signal(len(self.i.a.m)+1, reset_less=True)
        bm0 = Signal(len(self.i.b.m)+1, reset_less=True)
        comb += seq.eq(a.s == b.s)
        comb += mge.eq(a.m >= b.m)
        comb += am0.eq(Cat(a.m, 0))
        comb += bm0.eq(Cat(b.m, 0))

        # the sum is only computed when out_do_z is not set
        with m.If(~self.i.out_do_z):
            comb += z.e.eq(a.e)
            # same-sign (both negative or both positive) add mantissas
            with m.If(seq):
                comb += self.o.tot.eq(am0 + bm0)
                comb += z.s.eq(a.s)
            # a mantissa greater than (or equal to) b, use a
            with m.Elif(mge):
                comb += self.o.tot.eq(am0 - bm0)
                comb += z.s.eq(a.s)
            # b mantissa greater than a, use b
            with m.Else():
                comb += self.o.tot.eq(bm0 - am0)
                comb += z.s.eq(b.s)

        # these three fields are passed through unchanged
        comb += self.o.oz.eq(self.i.oz)
        comb += self.o.out_do_z.eq(self.i.out_do_z)
        comb += self.o.mid.eq(self.i.mid)
        return m
151
152
class FPAddStage0(FPState):
    """ First stage of add.  covers same-sign (add) and subtract
        special-casing when mantissas are greater or equal, to
        give greatest accuracy.

        FSM-state wrapper around FPAddStage0Mod: the module output is
        registered into self.o and the FSM then moves on to "add_1".
    """

    def __init__(self, width, id_wid):
        """ * width: bit-width of the floating-point format
            * id_wid: width of the identifier carried with the data
        """
        FPState.__init__(self, "add_0")
        # FPAddStage0Mod requires both width and id_wid
        self.mod = FPAddStage0Mod(width, id_wid)
        self.o = self.mod.ospec()

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, i)

        # NOTE: these could be done as combinatorial (merge add0+add1)
        m.d.sync += self.o.eq(self.mod.o)

    def action(self, m):
        """ FSM action: unconditionally move to the "add_1" state
        """
        m.next = "add_1"
174
175
class FPAddStage1Mod(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (the MSB of tot, tot[27] in
        the original note, is kinda a carry bit)
    """

    def __init__(self, width, id_wid):
        self.width = width
        self.id_wid = id_wid
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ input format: FPAddStage0Data
        """
        return FPAddStage0Data(self.width, self.id_wid)

    def ospec(self):
        """ output format: FPAddStage1Data
        """
        return FPAddStage1Data(self.width, self.id_wid)

    def process(self, i):
        """ returns the module output (the argument is not used)
        """
        return self.o

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        m.submodules.add1 = self
        m.submodules.add1_out_overflow = self.o.of

        m.d.comb += self.i.eq(i)

    def elaborate(self, platform):
        """ creates the HDL for the stage and returns the Module
        """
        m = Module()
        comb = m.d.comb
        tot = self.i.tot
        of = self.o.of

        # default: z is copied over; mantissa/exponent are overridden below
        comb += self.o.z.eq(self.i.z)

        with m.If(~self.i.out_do_z):
            # tot[-1] (MSB) gets set when the sum overflows. shift result
            # down by one extra bit and bump the exponent
            with m.If(tot[-1]):
                comb += self.o.z.m.eq(tot[4:])
                comb += of.m0.eq(tot[4])
                comb += of.guard.eq(tot[3])
                comb += of.round_bit.eq(tot[2])
                comb += of.sticky.eq(tot[1] | tot[0])
                comb += self.o.z.e.eq(self.i.z.e + 1)
            # tot[-1] (MSB) zero case
            with m.Else():
                comb += self.o.z.m.eq(tot[3:])
                comb += of.m0.eq(tot[3])
                comb += of.guard.eq(tot[2])
                comb += of.round_bit.eq(tot[1])
                comb += of.sticky.eq(tot[0])

        # these three fields are passed through unchanged
        comb += self.o.out_do_z.eq(self.i.out_do_z)
        comb += self.o.oz.eq(self.i.oz)
        comb += self.o.mid.eq(self.i.mid)

        return m
233
234
class FPAddStage1(FPState):
    """ FSM-state wrapper around FPAddStage1Mod.

        Registers the module's z and overflow outputs, and drives
        norm_stb, before the FSM moves on to "normalise_1".
    """

    def __init__(self, width, id_wid):
        """ * width: bit-width of the floating-point format
            * id_wid: width of the identifier carried with the data
        """
        FPState.__init__(self, "add_1")
        # FPAddStage1Mod requires both width and id_wid
        self.mod = FPAddStage1Mod(width, id_wid)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()
        self.norm_stb = Signal()

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, i)

        # NOTE(review): norm_stb is assigned 0 here and 1 further down,
        # both unconditionally in the sync domain.  presumably the later
        # assignment wins, so the strobe never drops - confirm whether
        # the block below was meant to live in action() instead.
        m.d.sync += self.norm_stb.eq(0) # sets to zero when not in add1 state

        # the module publishes its results on its output record "o"
        m.d.sync += self.out_of.eq(self.mod.o.of)
        m.d.sync += self.out_z.eq(self.mod.o.z)
        m.d.sync += self.norm_stb.eq(1)

    def action(self, m):
        """ FSM action: unconditionally move to the "normalise_1" state
        """
        m.next = "normalise_1"
257
258
259
260
class FPOpData:
    """ Output record: an FPOp result (z) plus its identifier (mid).
    """

    def __init__(self, width, id_wid):
        self.z = FPOp(width)
        self.mid = Signal(id_wid, reset_less=True)

    def eq(self, i):
        """ returns the list of assignments copying z and mid from i
        """
        assignments = [self.z.eq(i.z)]
        assignments.append(self.mid.eq(i.mid))
        return assignments

    def ports(self):
        """ returns the record's fields as a list: z first, then mid
        """
        fields = [self.z, self.mid]
        return fields
271
272
class FPADDBaseMod:
    """ FSM-based IEEE754 FP adder.

        Collects a list of states (objects providing "state_from" and
        "action") and turns them into an FSM in get_fragment.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, compact=True):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * compact: True indicates a reduced number of stages
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle
        self.compact = compact

        self.in_t = Trigger()
        self.i = self.ispec()
        self.o = self.ospec()

        self.states = []

    def ispec(self):
        """ input format: FPADDBaseData
        """
        return FPADDBaseData(self.width, self.id_wid)

    def ospec(self):
        """ output format: FPOpData
        """
        return FPOpData(self.width, self.id_wid)

    def add_state(self, state):
        """ registers a state for the FSM and returns it unchanged
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules.out_z = self.o.z
        m.submodules.in_t = self.in_t
        if self.compact:
            self.get_compact_fragment(m, platform)
        else:
            self.get_longer_fragment(m, platform)

        # one FSM state per registered state object
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m

    def get_longer_fragment(self, m, platform=None):
        """ builds the one-state-per-step version of the adder

            NOTE(review): this method reads self.in_mid, self.out_z and
            self.out_mid, none of which are assigned by __init__ (which
            only sets i, o and in_t).  it looks out of date compared to
            get_compact_fragment - confirm before using compact=False.
        """
        get = self.add_state(FPGet2Op("get_ops", "special_cases",
                                      self.width))
        get.setup(m, self.i)
        a = get.out_op1
        b = get.out_op2
        get.trigger_setup(m, self.in_t.stb, self.in_t.ack)

        sc = self.add_state(FPAddSpecialCases(self.width, self.id_wid))
        sc.setup(m, a, b, self.in_mid)

        dn = self.add_state(FPAddDeNorm(self.width, self.id_wid))
        dn.setup(m, a, b, sc.in_mid)

        # only the alignment class depends on single_cycle; the wiring
        # is the same for both variants
        if self.single_cycle:
            align_cls = FPAddAlignSingle
        else:
            align_cls = FPAddAlignMulti
        alm = self.add_state(align_cls(self.width, self.id_wid))
        alm.setup(m, dn.out_a, dn.out_b, dn.in_mid)

        add0 = self.add_state(FPAddStage0(self.width, self.id_wid))
        add0.setup(m, alm.out_a, alm.out_b, alm.in_mid)

        add1 = self.add_state(FPAddStage1(self.width, self.id_wid))
        add1.setup(m, add0.out_tot, add0.out_z, add0.in_mid)

        if self.single_cycle:
            n1 = self.add_state(FPNorm1Single(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add0.in_mid)
        else:
            # the multi-cycle normaliser additionally takes the strobe
            n1 = self.add_state(FPNorm1Multi(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb, add0.in_mid)

        rn = self.add_state(FPRound(self.width, self.id_wid))
        rn.setup(m, n1.out_z, n1.out_roundz, n1.in_mid)

        cor = self.add_state(FPCorrections(self.width, self.id_wid))
        cor.setup(m, rn.out_z, rn.in_mid)

        pa = self.add_state(FPPack(self.width, self.id_wid))
        pa.setup(m, cor.out_z, rn.in_mid)

        # two output states: one for the packed result, one for the
        # result of the special-cases stage
        self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z,
                              pa.in_mid, self.out_mid))

        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z,
                              pa.in_mid, self.out_mid))

    def get_compact_fragment(self, m, platform=None):
        """ builds the reduced-stage version of the adder: get-ops,
            special-cases + denormalise, align + add, normalise-to-pack
        """
        get = FPGet2Op("get_ops", "special_cases", self.width, self.id_wid)
        sc = FPAddSpecialCasesDeNorm(self.width, self.id_wid)
        alm = FPAddAlignSingleAdd(self.width, self.id_wid)
        n1 = FPNormToPack(self.width, self.id_wid)

        get.trigger_setup(m, self.in_t.stb, self.in_t.ack)

        chainlist = [get, sc, alm, n1]
        chain = StageChain(chainlist, specallocate=True)
        chain.setup(m, self.i)

        # every chained stage is also an FSM state
        for mod in chainlist:
            self.add_state(mod)

        self.add_state(FPPutZ("pack_put_z", n1.out_z.z, self.o,
                              n1.out_z.mid, self.o.mid))

        # a separate "put_z" state for the special-cases result
        # is currently disabled:
        #self.add_state(FPPutZ("put_z", sc.out_z.z, self.o,
        #                        sc.o.mid, self.o.mid))
395
396
class FPADDBase(FPState):
    """ FSM state ("fpadd") that wraps an FPADDBaseMod.

        Forwards operands and the in/out strobe-ack handshakes to the
        wrapped module, and moves to "put_z" once the module signals
        that its result has been triggered.
    """

    def __init__(self, width, id_wid=None, single_cycle=False):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
        """
        FPState.__init__(self, "fpadd")
        self.width = width
        self.single_cycle = single_cycle
        self.mod = FPADDBaseMod(width, id_wid, single_cycle)
        self.o = self.ospec()

        self.in_t = Trigger()
        self.i = self.ispec()

        self.z_done = Signal(reset_less=True) # connects to out_z Strobe
        self.in_accept = Signal(reset_less=True)
        self.add_stb = Signal(reset_less=True)
        self.add_ack = Signal(reset=0, reset_less=True)

    def ispec(self):
        """ input format: that of the wrapped module
        """
        return self.mod.ispec()

    def ospec(self):
        """ output format: that of the wrapped module
        """
        return self.mod.ospec()

    def setup(self, m, i, add_stb, in_mid):
        """ links module to inputs and outputs

            * m: the Module being built
            * i: input data, copied to self.i and on to the wrapped module
            * add_stb: incoming strobe, registered into self.add_stb
            * in_mid: not used by this method
        """
        m.d.comb += [self.i.eq(i),
                     self.mod.i.eq(self.i),
                     self.z_done.eq(self.mod.o.z.trigger),
                     self.mod.in_t.stb.eq(self.in_t.stb),
                     self.in_t.ack.eq(self.mod.in_t.ack),
                     self.o.mid.eq(self.mod.o.mid),
                     self.o.z.v.eq(self.mod.o.z.v),
                     self.o.z.stb.eq(self.mod.o.z.stb),
                     self.mod.o.z.ack.eq(self.o.z.ack),
                    ]

        m.d.sync += self.add_stb.eq(add_stb)
        m.d.sync += self.add_ack.eq(0) # sets to zero when not in active state
        m.d.sync += self.o.z.ack.eq(0) # likewise

        m.submodules.fpadd = self.mod

    def action(self, m):
        """ FSM action for the "fpadd" state
        """
        # in_accept is set on incoming strobe HIGH and ack LOW.
        m.d.comb += self.in_accept.eq((~self.add_ack) & (self.add_stb))

        with m.If(~self.z_done):
            # not done: test for accepting an incoming operand pair
            with m.If(self.in_accept):
                m.d.sync += [
                    self.add_ack.eq(1), # acknowledge receipt...
                    self.in_t.stb.eq(1), # initiate add
                ]
            with m.Else():
                m.d.sync += [self.add_ack.eq(0),
                             self.in_t.stb.eq(0),
                             self.o.z.ack.eq(1),
                            ]
        with m.Else():
            # done: acknowledge, and write out id and value
            m.d.sync += [self.add_ack.eq(1),
                         self.in_t.stb.eq(0)
                        ]
            m.next = "put_z"
484 with m.Else():
485 m.d.sync += self.out_z.stb.eq(1)
486
487
class FPADDBasePipe(ControlBase):
    """ Three-stage adder pipeline: special-cases + denormalise,
        align + add, normalise-to-pack.
    """

    def __init__(self, width, id_wid):
        ControlBase.__init__(self)
        self.pipe1 = FPAddSpecialCasesDeNorm(width, id_wid)
        self.pipe2 = FPAddAlignSingleAdd(width, id_wid)
        self.pipe3 = FPNormToPack(width, id_wid)

        # connection statements are kept for elaborate() to add
        stages = [self.pipe1, self.pipe2, self.pipe3]
        self._eqs = self.connect(stages)

    def elaborate(self, platform):
        """ registers the three stages as submodules and wires them up
        """
        m = Module()
        m.submodules.scnorm = self.pipe1
        m.submodules.addalign = self.pipe2
        m.submodules.normpack = self.pipe3
        m.d.comb += self._eqs
        return m
504
505
class FPADDInMuxPipe(PriorityCombMuxInPipe):
    """ Fan-in: num_rows inputs of FPADDBaseData, through a
        pass-through stage.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows

        def iospec():
            # every call creates a fresh data record
            return FPADDBaseData(width, id_wid)

        PriorityCombMuxInPipe.__init__(self, PassThroughStage(iospec),
                                       p_len=self.num_rows)
512
513
class FPADDMuxOutPipe(CombMuxOutPipe):
    """ Fan-out: num_rows outputs of FPPackData, through a
        pass-through stage.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows

        def iospec():
            # every call creates a fresh data record
            return FPPackData(width, id_wid)

        CombMuxOutPipe.__init__(self, PassThroughStage(iospec),
                                n_len=self.num_rows)
520
521
class FPADDMuxInOut:
    """ Reservation-Station version of FPADD pipeline.

        * fan-in on inputs (an array of FPADDBaseData: a,b,mid)
        * 3-stage adder pipeline
        * fan-out on outputs (an array of FPPackData: z,mid)

        Fan-in and Fan-out are combinatorial.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows
        self.inpipe = FPADDInMuxPipe(width, id_wid, num_rows)   # fan-in
        self.fpadd = FPADDBasePipe(width, id_wid)               # add stage
        self.outpipe = FPADDMuxOutPipe(width, id_wid, num_rows) # fan-out

        # this class's in/out are simply those of the fan-in/fan-out pipes
        self.p = self.inpipe.p
        self.n = self.outpipe.n
        self._ports = self.inpipe.ports() + self.outpipe.ports()

    def elaborate(self, platform):
        """ registers the three pipes and connects them in sequence
        """
        m = Module()
        m.submodules.inpipe = self.inpipe
        m.submodules.fpadd = self.fpadd
        m.submodules.outpipe = self.outpipe

        # fan-in -> adder, then adder -> fan-out
        m.d.comb += self.inpipe.n.connect_to_next(self.fpadd.p)
        m.d.comb += self.fpadd.connect_to_next(self.outpipe)

        return m

    def ports(self):
        """ returns the port list gathered at construction time
        """
        return self._ports
554
555
class FPADD(FPID):
    """ FPADD: stages as follows:

        FPGetOp (a)
           |
        FPGetOp (b)
           |
        FPAddBase---> FPAddBaseMod
           |            |
        PutZ          GetOps->Specials->Align->Add1/2->Norm->Round/Pack->PutZ

        FPAddBase is tricky: it is both a stage and *has* stages.
        Connection to FPAddBaseMod therefore requires an in stb/ack
        and an out stb/ack.  Just as with Add1-Norm1 interaction, FPGetOp
        needs to be the thing that raises the incoming stb.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, rs_sz=2):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * rs_sz: number of operand pairs (and of results) to create
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle

        # NOTE(review): FPID.__init__ is not called here; the ids are
        # held in a separate FPID object instead.
        self.ids = FPID(id_wid)

        # operand pairs, one (in_a, in_b) tuple per row
        rs = []
        for idx in range(rs_sz):
            in_a = FPOp(width)
            in_b = FPOp(width)
            in_a.name = "in_a_%d" % idx
            in_b.name = "in_b_%d" % idx
            rs.append((in_a, in_b))
        self.rs = Array(rs)

        # results, one per row
        res = []
        for idx in range(rs_sz):
            out_z = FPOp(width)
            out_z.name = "out_z_%d" % idx
            res.append(out_z)
        self.res = Array(res)

        self.states = []

    def add_state(self, state):
        """ registers a state for the FSM and returns it unchanged
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules += self.rs

        # only the first operand pair is wired up
        in_a, in_b = self.rs[0][0], self.rs[0][1]

        geta = self.add_state(FPGetOp("get_a", "get_b", in_a, self.width))
        geta.setup(m, in_a)
        a = geta.out_op

        getb = self.add_state(FPGetOp("get_b", "fpadd", in_b, self.width))
        getb.setup(m, in_b)
        b = getb.out_op

        ab = self.add_state(FPADDBase(self.width, self.id_wid,
                                      self.single_cycle))
        abd = ab.ispec() # create an input spec object for FPADDBase
        m.d.sync += [abd.a.eq(a), abd.b.eq(b), abd.mid.eq(self.ids.in_mid)]
        ab.setup(m, abd, getb.out_decode, self.ids.in_mid)
        o = ab.o

        # output state
        self.add_state(FPPutZIdx("put_z", o.z, self.res, o.mid, "get_a"))

        # one FSM state per registered state object
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
645
646
if __name__ == "__main__":
    # development toggle: the first branch (full FPADD) is the one in use
    if True:
        alu = FPADD(width=32, id_wid=5, single_cycle=True)
        alu_ports = (alu.rs[0][0].ports() +
                     alu.rs[0][1].ports() +
                     alu.res[0].ports() +
                     [alu.ids.in_mid, alu.ids.out_mid])
        main(alu, ports=alu_ports)
    else:
        alu = FPADDBase(width=32, id_wid=5, single_cycle=True)
        alu_ports = ([alu.in_a, alu.in_b] +
                     alu.in_t.ports() +
                     alu.out_z.ports() +
                     [alu.in_mid, alu.out_mid])
        main(alu, ports=alu_ports)
660
661
662 # works... but don't use, just do "python fname.py convert -t v"
663 #print (verilog.convert(alu, ports=[
664 # ports=alu.in_a.ports() + \
665 # alu.in_b.ports() + \
666 # alu.out_z.ports())