split out add stage 0 to separate module
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Mux, Array, Const
6 from nmigen.lib.coding import PriorityEncoder
7 from nmigen.cli import main, verilog
8 from math import log
9
10 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
11 from fpbase import MultiShiftRMerge, Trigger
12 from singlepipe import (ControlBase, StageChain, UnbufferedPipeline,
13 PassThroughStage)
14 from multipipe import CombMuxOutPipe
15 from multipipe import PriorityCombMuxInPipe
16
17 from fpbase import FPState, FPID
18 from fpcommon.getop import (FPGetOpMod, FPGetOp, FPNumBase2Ops, FPADDBaseData,
19 FPGet2OpMod, FPGet2Op)
20 from fpcommon.denorm import (FPSCData, FPAddDeNormMod, FPAddDeNorm)
21 from fpcommon.postcalc import FPAddStage1Data
22 from fpcommon.postnormalise import (FPNorm1Data, FPNorm1ModSingle,
23 FPNorm1ModMulti, FPNorm1Single, FPNorm1Multi)
24 from fpcommon.roundz import (FPRoundData, FPRoundMod, FPRound)
25 from fpcommon.corrections import (FPCorrectionsMod, FPCorrections)
26 from fpcommon.pack import (FPPackData, FPPackMod, FPPack)
27 from fpcommon.normtopack import FPNormToPack
28 from fpcommon.putz import (FPPutZ, FPPutZIdx)
29
30 from fpadd.specialcases import (FPAddSpecialCasesMod, FPAddSpecialCases,
31 FPAddSpecialCasesDeNorm)
32 from fpadd.align import (FPAddAlignMulti, FPAddAlignMultiMod, FPNumIn2Ops,
33 FPAddAlignSingleMod, FPAddAlignSingle)
34 from fpadd.add0 import (FPAddStage0Data, FPAddStage0Mod, FPAddStage0)
35
36
class FPAddAlignSingleAdd(FPState, UnbufferedPipeline):
    """ "align" state: single-cycle align followed by add stages 0 and 1.

        The object is passed to UnbufferedPipeline as its own stage, so it
        provides ispec / ospec / setup / process itself.
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        # width / id_wid are stored ahead of the pipeline constructor,
        # presumably because it calls ispec()/ospec() - TODO confirm
        self.width, self.id_wid = width, id_wid
        UnbufferedPipeline.__init__(self, self)  # pipeline is its own stage
        self.a1o = self.ospec()

    def ispec(self):
        """ input format: FPSCData
        """
        return FPSCData(self.width, self.id_wid)

    def ospec(self):
        """ output format: same as the AddStage1 ospec
        """
        return FPAddStage1Data(self.width, self.id_wid)

    def setup(self, m, i):
        """ links module to inputs and outputs

            Builds AddAlignSingle, AddStage0 and AddStage1, chains them
            together starting from ``i`` and publishes the last module's
            output as ``self.o``.
        """
        align_mod = FPAddAlignSingleMod(self.width, self.id_wid)
        add0_mod = FPAddStage0Mod(self.width, self.id_wid)
        add1_mod = FPAddStage1Mod(self.width, self.id_wid)

        StageChain([align_mod, add0_mod, add1_mod]).setup(m, i)

        self.o = add1_mod.o

    def process(self, i):
        """ returns the chained output; the argument is not used
        """
        return self.o

    def action(self, m):
        """ FSM action: register the chain output, go to "normalise_1"
        """
        result = self.process(None)
        m.d.sync += self.a1o.eq(result)
        m.next = "normalise_1"
72
73
74
75
class FPAddStage1Mod(FPState):
    """ Second stage of add: preparation for normalisation.

        Detects when the total is too big (the top bit of tot acts as a
        carry) and, if so, shifts the result down by one while bumping
        the exponent.
    """

    def __init__(self, width, id_wid):
        self.width, self.id_wid = width, id_wid
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ input format: FPAddStage0Data
        """
        return FPAddStage0Data(self.width, self.id_wid)

    def ospec(self):
        """ output format: FPAddStage1Data
        """
        return FPAddStage1Data(self.width, self.id_wid)

    def process(self, i):
        """ returns the module output; the argument is not used
        """
        return self.o

    def setup(self, m, i):
        """ links module to inputs and outputs

            Registers this module and its overflow object as submodules
            of ``m`` and combinatorially copies ``i`` into ``self.i``.
        """
        m.submodules.add1 = self
        m.submodules.add1_out_overflow = self.o.of

        m.d.comb += self.i.eq(i)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.i, self.o
        tot, of = src.tot, dst.of

        # default: pass z through.  the conditional assignments below come
        # later, so they take priority over this one for z.m and z.e
        m.d.comb += dst.z.eq(src.z)

        with m.If(~src.out_do_z):
            # tot[-1] (MSB) gets set when the sum overflows: shift down
            with m.If(tot[-1]):
                m.d.comb += [
                    dst.z.m.eq(tot[4:]),
                    of.m0.eq(tot[4]),
                    of.guard.eq(tot[3]),
                    of.round_bit.eq(tot[2]),
                    of.sticky.eq(tot[1] | tot[0]),
                    dst.z.e.eq(src.z.e + 1)
                ]
            # tot[-1] (MSB) zero case: no shift, exponent unchanged
            with m.Else():
                m.d.comb += [
                    dst.z.m.eq(tot[3:]),
                    of.m0.eq(tot[3]),
                    of.guard.eq(tot[2]),
                    of.round_bit.eq(tot[1]),
                    of.sticky.eq(tot[0])
                ]

        # fields that are always copied straight through
        m.d.comb += dst.out_do_z.eq(src.out_do_z)
        m.d.comb += dst.oz.eq(src.oz)
        m.d.comb += dst.mid.eq(src.mid)

        return m
133
134
class FPAddStage1(FPState):
    """ "add_1" state: registers the results of FPAddStage1Mod.

        Holds a FPAddStage1Mod and sync-latches its z and overflow outputs
        into out_z / out_of, raising norm_stb while doing so.
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "add_1")
        # FPAddStage1Mod takes (width, id_wid): both must be passed
        self.mod = FPAddStage1Mod(width, id_wid)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()
        self.norm_stb = Signal()

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, i)

        # default assignment, overridden by action() while in this state
        m.d.sync += self.norm_stb.eq(0)  # sets to zero when not in add1 state

    def action(self, m):
        """ FSM action: latch the module outputs, strobe, go to normalise_1

            The latch lives here rather than in setup() so that the
            norm_stb default of zero is only overridden while the FSM is
            in this state.
        """
        # FPAddStage1Mod exposes its results through .o (fields of, z)
        m.d.sync += self.out_of.eq(self.mod.o.of)
        m.d.sync += self.out_z.eq(self.mod.o.z)
        m.d.sync += self.norm_stb.eq(1)
        m.next = "normalise_1"
157
158
159
160
class FPOpData:
    """ Output bundle: a result operand (z) plus its identifier (mid).
    """

    def __init__(self, width, id_wid):
        self.z = FPOp(width)
        self.mid = Signal(id_wid, reset_less=True)

    def eq(self, i):
        """ returns the assignments copying i.z and i.mid into this object
        """
        pairs = ((self.z, i.z), (self.mid, i.mid))
        return [dst.eq(src) for dst, src in pairs]

    def ports(self):
        """ returns the members of this bundle as a list: z, then mid
        """
        members = [self.z]
        members.append(self.mid)
        return members
171
172
class FPADDBaseMod:
    """ FSM-based floating-point adder: builds a list of states and
        emits one FSM state per entry.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, compact=True):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * compact: True indicates a reduced number of stages
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle
        self.compact = compact

        self.in_t = Trigger()
        self.i = self.ispec()
        self.o = self.ospec()

        self.states = []

    def ispec(self):
        """ input format: FPADDBaseData
        """
        return FPADDBaseData(self.width, self.id_wid)

    def ospec(self):
        """ output format: FPOpData
        """
        return FPOpData(self.width, self.id_wid)

    def add_state(self, state):
        """ appends state to the list of FSM states and returns it
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            Populates self.states through the compact or the longer
            builder, then creates one FSM state per entry, keyed on the
            entry's state_from and filled in by its action().
        """
        m = Module()
        m.submodules.out_z = self.o.z
        m.submodules.in_t = self.in_t
        if self.compact:
            self.get_compact_fragment(m, platform)
        else:
            self.get_longer_fragment(m, platform)

        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m

    def get_longer_fragment(self, m, platform=None):
        """ builds the one-state-per-stage version of the adder

            NOTE(review): the setup() signatures of the stage classes are
            not visible from this file; the calls below keep the argument
            shapes this method already used - confirm against the stages.
        """
        # id_wid is passed here exactly as in get_compact_fragment
        get = self.add_state(FPGet2Op("get_ops", "special_cases",
                                      self.width, self.id_wid))
        get.setup(m, self.i)
        a = get.out_op1
        b = get.out_op2
        get.trigger_setup(m, self.in_t.stb, self.in_t.ack)

        # this class has no in_mid attribute: the input id is self.i.mid
        sc = self.add_state(FPAddSpecialCases(self.width, self.id_wid))
        sc.setup(m, a, b, self.i.mid)

        dn = self.add_state(FPAddDeNorm(self.width, self.id_wid))
        dn.setup(m, a, b, sc.in_mid)

        if self.single_cycle:
            align = FPAddAlignSingle(self.width, self.id_wid)
        else:
            align = FPAddAlignMulti(self.width, self.id_wid)
        alm = self.add_state(align)
        alm.setup(m, dn.out_a, dn.out_b, dn.in_mid)

        add0 = self.add_state(FPAddStage0(self.width, self.id_wid))
        add0.setup(m, alm.out_a, alm.out_b, alm.in_mid)

        add1 = self.add_state(FPAddStage1(self.width, self.id_wid))
        add1.setup(m, add0.out_tot, add0.out_z, add0.in_mid)

        if self.single_cycle:
            n1 = self.add_state(FPNorm1Single(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add0.in_mid)
        else:
            n1 = self.add_state(FPNorm1Multi(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb, add0.in_mid)

        rn = self.add_state(FPRound(self.width, self.id_wid))
        rn.setup(m, n1.out_z, n1.out_roundz, n1.in_mid)

        cor = self.add_state(FPCorrections(self.width, self.id_wid))
        cor.setup(m, rn.out_z, rn.in_mid)

        pa = self.add_state(FPPack(self.width, self.id_wid))
        pa.setup(m, cor.out_z, rn.in_mid)

        # this class has no out_z / out_mid attributes: the output bundle
        # is self.o.  NOTE(review): wired like the compact path (self.o and
        # self.o.mid) - confirm against FPPutZ
        self.add_state(FPPutZ("pack_put_z", pa.out_z, self.o,
                              pa.in_mid, self.o.mid))

        self.add_state(FPPutZ("put_z", sc.out_z, self.o,
                              pa.in_mid, self.o.mid))

    def get_compact_fragment(self, m, platform=None):
        """ builds the reduced-stage version of the adder

            Four stages (get ops, special-cases + denorm, align + add,
            norm-to-pack) are chained together from self.i and each is
            registered as an FSM state, followed by a final put-z state.
        """
        get = FPGet2Op("get_ops", "special_cases", self.width, self.id_wid)
        sc = FPAddSpecialCasesDeNorm(self.width, self.id_wid)
        alm = FPAddAlignSingleAdd(self.width, self.id_wid)
        n1 = FPNormToPack(self.width, self.id_wid)

        get.trigger_setup(m, self.in_t.stb, self.in_t.ack)

        chainlist = [get, sc, alm, n1]
        chain = StageChain(chainlist, specallocate=True)
        chain.setup(m, self.i)

        # register the stages without rebinding sc, so that it keeps
        # referring to the special-cases stage
        for mod in chainlist:
            self.add_state(mod)

        self.add_state(FPPutZ("pack_put_z", n1.out_z.z, self.o,
                              n1.out_z.mid, self.o.mid))

        #pz = self.add_state(FPPutZ("put_z", sc.out_z.z, self.o,
        #                            sc.o.mid, self.o.mid))
294
295
class FPADDBase(FPState):
    """ "fpadd" state: wraps an FPADDBaseMod and handles its in/out
        strobe-acknowledge handshakes.
    """

    def __init__(self, width, id_wid=None, single_cycle=False):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
        """
        FPState.__init__(self, "fpadd")
        self.width = width
        self.single_cycle = single_cycle
        self.mod = FPADDBaseMod(width, id_wid, single_cycle)
        self.o = self.ospec()

        self.in_t = Trigger()
        self.i = self.ispec()

        self.z_done = Signal(reset_less=True)  # connects to out_z Strobe
        self.in_accept = Signal(reset_less=True)
        self.add_stb = Signal(reset_less=True)
        self.add_ack = Signal(reset=0, reset_less=True)

    def ispec(self):
        """ input format: that of the wrapped FPADDBaseMod
        """
        return self.mod.ispec()

    def ospec(self):
        """ output format: that of the wrapped FPADDBaseMod
        """
        return self.mod.ospec()

    def setup(self, m, i, add_stb, in_mid):
        """ links module to inputs and outputs

            * m: module to add the wiring and the submodule to
            * i: input data, copied combinatorially into self.i
            * add_stb: incoming strobe, registered into self.add_stb
            * in_mid: accepted but not used by this method
        """
        m.d.comb += [self.i.eq(i),
                     self.mod.i.eq(self.i),
                     self.z_done.eq(self.mod.o.z.trigger),
                     self.mod.in_t.stb.eq(self.in_t.stb),
                     self.in_t.ack.eq(self.mod.in_t.ack),
                     self.o.mid.eq(self.mod.o.mid),
                     self.o.z.v.eq(self.mod.o.z.v),
                     self.o.z.stb.eq(self.mod.o.z.stb),
                     self.mod.o.z.ack.eq(self.o.z.ack),
                    ]

        m.d.sync += self.add_stb.eq(add_stb)
        m.d.sync += self.add_ack.eq(0)  # sets to zero when not in active state
        m.d.sync += self.o.z.ack.eq(0)  # likewise

        m.submodules.fpadd = self.mod

    def action(self, m):
        """ FSM action for the "fpadd" state

            While the result is not done, an incoming operand pair is
            accepted (ack raised, add started); otherwise the output ack
            is raised.  Once done, the state moves on to "put_z".
        """
        # in_accept is set on incoming strobe HIGH and ack LOW.
        m.d.comb += self.in_accept.eq((~self.add_ack) & (self.add_stb))

        with m.If(~self.z_done):
            # not done: test for accepting an incoming operand pair
            with m.If(self.in_accept):
                m.d.sync += [
                    self.add_ack.eq(1),   # acknowledge receipt...
                    self.in_t.stb.eq(1),  # initiate add
                ]
            with m.Else():
                m.d.sync += [self.add_ack.eq(0),
                             self.in_t.stb.eq(0),
                             self.o.z.ack.eq(1),
                            ]
        with m.Else():
            # done: acknowledge, and write out id and value
            m.d.sync += [self.add_ack.eq(1),
                         self.in_t.stb.eq(0)
                        ]
            m.next = "put_z"
385
386
class FPADDBasePipe(ControlBase):
    """ Three pipeline stages connected in sequence: special-cases +
        denorm, align + add, then normalise-to-pack.
    """

    def __init__(self, width, id_wid):
        ControlBase.__init__(self)
        self.pipe1 = FPAddSpecialCasesDeNorm(width, id_wid)
        self.pipe2 = FPAddAlignSingleAdd(width, id_wid)
        self.pipe3 = FPNormToPack(width, id_wid)

        # connection statements are kept for use in elaborate()
        stages = [self.pipe1, self.pipe2, self.pipe3]
        self._eqs = self.connect(stages)

    def elaborate(self, platform):
        """ adds the three stages as submodules, plus their connections
        """
        m = Module()
        named_stages = (("scnorm", self.pipe1),
                        ("addalign", self.pipe2),
                        ("normpack", self.pipe3))
        for name, stage in named_stages:
            setattr(m.submodules, name, stage)
        m.d.comb += self._eqs
        return m
403
404
class FPADDInMuxPipe(PriorityCombMuxInPipe):
    """ Priority mux input pipe with num_rows inputs, each carrying
        FPADDBaseData through a pass-through stage.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows
        # the stage gets a factory, so each call creates a new data object
        stage = PassThroughStage(lambda: FPADDBaseData(width, id_wid))
        PriorityCombMuxInPipe.__init__(self, stage, p_len=self.num_rows)
411
412
class FPADDMuxOutPipe(CombMuxOutPipe):
    """ Mux output pipe with num_rows outputs, each carrying FPPackData
        through a pass-through stage.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows
        # the stage gets a factory, so each call creates a new data object
        stage = PassThroughStage(lambda: FPPackData(width, id_wid))
        CombMuxOutPipe.__init__(self, stage, n_len=self.num_rows)
419
420
class FPADDMuxInOut:
    """ Reservation-Station version of FPADD pipeline.

        * fan-in on inputs (an array of FPADDBaseData: a,b,mid)
        * 3-stage adder pipeline
        * fan-out on outputs (an array of FPPackData: z,mid)

        Fan-in and Fan-out are combinatorial.
    """

    def __init__(self, width, id_wid, num_rows):
        self.num_rows = num_rows
        self.inpipe = FPADDInMuxPipe(width, id_wid, num_rows)    # fan-in
        self.fpadd = FPADDBasePipe(width, id_wid)                # add stage
        self.outpipe = FPADDMuxOutPipe(width, id_wid, num_rows)  # fan-out

        # this class's own in/out are simply those of the outer pipes
        self.p = self.inpipe.p
        self.n = self.outpipe.n
        self._ports = self.inpipe.ports() + self.outpipe.ports()

    def elaborate(self, platform):
        """ adds the three pipes as submodules and links them in order
        """
        m = Module()
        m.submodules.inpipe = self.inpipe
        m.submodules.fpadd = self.fpadd
        m.submodules.outpipe = self.outpipe

        fanin_to_add = self.inpipe.n.connect_to_next(self.fpadd.p)
        m.d.comb += fanin_to_add
        add_to_fanout = self.fpadd.connect_to_next(self.outpipe)
        m.d.comb += add_to_fanout

        return m

    def ports(self):
        """ returns the port list assembled in the constructor
        """
        return self._ports
453
454
class FPADD(FPID):
    """ FPADD: stages as follows:

        FPGetOp (a)
           |
        FPGetOp (b)
           |
        FPAddBase---> FPAddBaseMod
           |            |
        PutZ          GetOps->Specials->Align->Add1/2->Norm->Round/Pack->PutZ

        FPAddBase is tricky: it is both a stage and *has* stages.
        Connection to FPAddBaseMod therefore requires an in stb/ack
        and an out stb/ack.  Just as with Add1-Norm1 interaction, FPGetOp
        needs to be the thing that raises the incoming stb.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, rs_sz=2):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * rs_sz: number of operand pairs (and of results) to create
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle

        #self.out_z = FPOp(width)
        self.ids = FPID(id_wid)

        # one (in_a, in_b) operand pair per row
        operand_pairs = []
        for idx in range(rs_sz):
            in_a = FPOp(width)
            in_b = FPOp(width)
            in_a.name = "in_a_%d" % idx
            in_b.name = "in_b_%d" % idx
            operand_pairs.append((in_a, in_b))
        self.rs = Array(operand_pairs)

        # one result operand per row
        results = []
        for idx in range(rs_sz):
            out_z = FPOp(width)
            out_z.name = "out_z_%d" % idx
            results.append(out_z)
        self.res = Array(results)

        self.states = []

    def add_state(self, state):
        """ appends state to the list of FSM states and returns it
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            Only the first operand pair (row 0) is wired to the get_a and
            get_b states here.
        """
        m = Module()
        m.submodules += self.rs

        first_row = self.rs[0]
        in_a = first_row[0]
        in_b = first_row[1]

        geta = self.add_state(FPGetOp("get_a", "get_b",
                                      in_a, self.width))
        geta.setup(m, in_a)

        getb = self.add_state(FPGetOp("get_b", "fpadd",
                                      in_b, self.width))
        getb.setup(m, in_b)

        ab = self.add_state(FPADDBase(self.width, self.id_wid,
                                      self.single_cycle))
        abd = ab.ispec()  # create an input spec object for FPADDBase
        m.d.sync += [abd.a.eq(geta.out_op),
                     abd.b.eq(getb.out_op),
                     abd.mid.eq(self.ids.in_mid)]
        ab.setup(m, abd, getb.out_decode, self.ids.in_mid)

        # final state: write the result out, then go back to "get_a"
        self.add_state(FPPutZIdx("put_z", ab.o.z, self.res,
                                 ab.o.mid, "get_a"))

        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
544
545
if __name__ == "__main__":
    # Build a 32-bit FPADD and hand it to the nmigen command-line entry
    # point.  Ports are row 0's two input operands and its result, plus
    # the in/out identifiers.
    alu = FPADD(width=32, id_wid=5, single_cycle=True)
    ports = (alu.rs[0][0].ports() +
             alu.rs[0][1].ports() +
             alu.res[0].ports() +
             [alu.ids.in_mid, alu.ids.out_mid])
    main(alu, ports=ports)

559
560
561 # works... but don't use, just do "python fname.py convert -t v"
562 #print (verilog.convert(alu, ports=[
563 # ports=alu.in_a.ports() + \
564 # alu.in_b.ports() + \
565 # alu.out_z.ports())