add capability to pass through operands and muxid to output
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Mux, Array, Const
6 from nmigen.lib.coding import PriorityEncoder
7 from nmigen.cli import main, verilog
8 from math import log
9
10 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
11 from fpbase import MultiShiftRMerge, Trigger
12 #from fpbase import FPNumShiftMultiRight
13
14
class FPState(FPBase):
    """ common base for the named stages in this file

        state_from is the name the stage is known by; subclasses in this
        file pass names such as "align" or "add_0".
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores the inputs mapping and mirrors each entry as an attribute
        """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ stores the outputs mapping and mirrors each entry as an attribute
        """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
28
29
class FPGetSyncOpsMod:
    """ passes a group of num_ops operands through, all at once

        ready is raised when every bit of stb is set; the operands are
        copied from in_op to out_op while both ack and ready are high.
    """

    def __init__(self, width, num_ops=2):
        self.width = width
        self.num_ops = num_ops
        # input and output signals are created alternately, one pair per
        # operand
        pairs = [(Signal(width, reset_less=True),
                  Signal(width, reset_less=True))
                 for _ in range(num_ops)]
        self.in_op = [pair[0] for pair in pairs]
        self.out_op = [pair[1] for pair in pairs]
        self.stb = Signal(num_ops)
        self.ack = Signal()
        self.ready = Signal(reset_less=True)
        self.out_decode = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # -1 in an unsigned num_ops-bit constant is "all strobes set"
        all_strobes = Const(-1, (self.num_ops, False))
        m.d.comb += self.ready.eq(self.stb == all_strobes)
        m.d.comb += self.out_decode.eq(self.ack & self.ready)
        with m.If(self.out_decode):
            for dst, src in zip(self.out_op, self.in_op):
                m.d.comb += dst.eq(src)
        return m

    def ports(self):
        """ returns the operand inputs and outputs followed by stb and ack
        """
        return [*self.in_op, *self.out_op, self.stb, self.ack]
59
60
class InputGroup(Trigger):
    """ selects one of num_rows groups of operands and registers it

        Each row is an FPGetSyncOpsMod.  The rows' ready flags feed a
        PriorityEncoder; whenever at least one row is ready, the encoder's
        output is recorded in mid (the multiplex id) and the chosen row's
        operands are registered into out_op.

        NOTE(review): the rows in self.rs are not added to m.submodules
        in elaborate() - confirm that whoever instantiates this does so.
    """

    def __init__(self, width, num_ops=2, num_rows=4):
        Trigger.__init__(self)
        self.width = width
        self.num_ops = num_ops
        self.num_rows = num_rows
        # number of bits needed to index every row.  integer arithmetic
        # (rather than int(log(n)/log(2))) also gives enough bits when
        # num_rows is not a power of two.
        self.mmax = (self.num_rows - 1).bit_length()
        self.rs = []
        self.mid = Signal(self.mmax, reset_less=True) # multiplex id
        for i in range(num_rows):
            self.rs.append(FPGetSyncOpsMod(width, num_ops))

        outops = []
        for i in range(num_ops):
            outops.append(Signal(width, reset_less=True))
        self.out_op = outops

    def elaborate(self, platform):
        m = Trigger.elaborate(self, platform)
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder
        in_ready = [row.ready for row in self.rs]
        m.d.comb += pe.i.eq(Cat(*in_ready))
        # pe.n is the encoder's "no input asserted" flag, so the encoder
        # is active (and strobe-out is valid) when it is *low*
        m.d.comb += self.stb.eq(~pe.n)

        with m.If(~pe.n):
            m.d.sync += self.mid.eq(pe.o)
            for i in range(self.num_rows):
                with m.If(pe.o == Const(i, (self.mmax, False))):
                    for j in range(self.num_ops):
                        m.d.sync += self.out_op[j].eq(self.rs[i].out_op[j])
        return m

    def ports(self):
        """ returns out_op followed by each row's operand inputs and strobe
        """
        res = []
        for inop in self.rs:
            res += inop.in_op + [inop.stb]
        return self.out_op + res #+ [self.ack + self.stb]
104
class FPGetOpMod:
    """ combinatorial part of fetching a single operand

        out_decode is high when the operand's ack and stb are both high;
        while it is high the operand's value is driven onto out_op.
    """

    def __init__(self, width):
        self.in_op = FPOp(width)
        self.out_op = Signal(width)
        self.out_decode = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.submodules.get_op_in = self.in_op
        handshake = (self.in_op.ack) & (self.in_op.stb)
        m.d.comb += self.out_decode.eq(handshake)
        with m.If(self.out_decode):
            m.d.comb += self.out_op.eq(self.in_op.v)
        return m
121
122
class FPGetOp(FPState):
    """ state which fetches one operand and then moves to out_state
    """

    def __init__(self, in_state, out_state, in_op, width):
        FPState.__init__(self, in_state)
        self.out_state = out_state
        self.mod = FPGetOpMod(width)
        self.in_op = in_op
        self.out_op = Signal(width)
        self.out_decode = Signal(reset_less=True)

    def setup(self, m, in_op):
        """ registers the module (under this state's name) and wires it up
        """
        setattr(m.submodules, self.state_from, self.mod)
        m.d.comb += [
            self.mod.in_op.copy(in_op),
            self.out_decode.eq(self.mod.out_decode),
        ]

    def action(self, m):
        """ keeps ack raised until the operand is decoded, then latches it
        """
        with m.If(self.out_decode):
            m.next = self.out_state
            m.d.sync += self.in_op.ack.eq(0)
            m.d.sync += self.out_op.eq(self.mod.out_op)
        with m.Else():
            m.d.sync += self.in_op.ack.eq(1)
152
153
class FPGet2OpMod(Trigger):
    """ decodes two raw operands into FPNumIn form while triggered
    """

    def __init__(self, width):
        Trigger.__init__(self)
        self.in_op1 = Signal(width, reset_less=True)
        self.in_op2 = Signal(width, reset_less=True)
        self.out_op1 = FPNumIn(None, width)
        self.out_op2 = FPNumIn(None, width)

    def elaborate(self, platform):
        m = Trigger.elaborate(self, platform)
        m.submodules.get_op1_out = self.out_op1
        m.submodules.get_op2_out = self.out_op2
        # decoding only happens while the inherited trigger is active
        with m.If(self.trigger):
            m.d.comb += self.out_op1.decode(self.in_op1)
            m.d.comb += self.out_op2.decode(self.in_op2)
        return m
173
174
class FPGet2Op(FPState):
    """ state which fetches both operands and then moves to out_state
    """

    def __init__(self, in_state, out_state, in_op1, in_op2, width):
        FPState.__init__(self, in_state)
        self.out_state = out_state
        self.mod = FPGet2OpMod(width)
        self.in_op1 = in_op1
        self.in_op2 = in_op2
        self.out_op1 = FPNumIn(None, width)
        self.out_op2 = FPNumIn(None, width)
        self.in_stb = Signal(reset_less=True)
        self.out_ack = Signal(reset_less=True)
        self.out_decode = Signal(reset_less=True)

    def setup(self, m, in_op1, in_op2, in_stb, in_ack):
        """ registers the module as "get_ops" and wires operands/handshake
        """
        m.submodules.get_ops = self.mod
        m.d.comb += [
            self.mod.in_op1.eq(in_op1),
            self.mod.in_op2.eq(in_op2),
            self.mod.stb.eq(in_stb),
            self.out_ack.eq(self.mod.ack),
            self.out_decode.eq(self.mod.trigger),
            in_ack.eq(self.mod.ack),
        ]

    def action(self, m):
        """ keeps ack raised until triggered, then latches both operands
        """
        with m.If(self.out_decode):
            m.next = self.out_state
            m.d.sync += self.mod.ack.eq(0)
            m.d.sync += self.out_op1.copy(self.mod.out_op1)
            m.d.sync += self.out_op2.copy(self.mod.out_op2)
        with m.Else():
            m.d.sync += self.mod.ack.eq(1)
214
215
class FPAddSpecialCasesMod:
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        out_do_z is raised whenever one of the special cases applies, in
        which case out_z holds the result.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, out_do_z):
        """ registers this module as "specialcases" and wires it up
        """
        m.submodules.specialcases = self
        m.d.comb += self.in_a.copy(in_a)
        m.d.comb += self.in_b.copy(in_b)
        m.d.comb += out_do_z.eq(self.out_do_z)

    def elaborate(self, platform):
        m = Module()

        a, b, z = self.in_a, self.in_b, self.out_z

        m.submodules.sc_in_a = a
        m.submodules.sc_in_b = b
        m.submodules.sc_out_z = z

        # signs differ
        s_nomatch = Signal()
        m.d.comb += s_nomatch.eq(a.s != b.s)

        # mantissas identical
        m_match = Signal()
        m.d.comb += m_match.eq(a.m == b.m)

        # either operand NaN: result is NaN
        # (the FP16 non-canonical NaN handling that used to follow this
        #  test - returning -a / -b when the other operand is a signed
        #  zero - is disabled and under review)
        with m.If(a.is_nan | b.is_nan):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.nan(0)

        # a is inf: result is inf with a's sign...
        with m.Elif(a.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(a.s)
            # ...unless b has exponent 128 and the opposite sign: NaN
            with m.If(b.exp_128 & s_nomatch):
                m.d.comb += z.nan(0)

        # b is inf: result is inf with b's sign
        with m.Elif(b.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(b.s)

        # both zero: sign is the AND of both signs
        with m.Elif(a.is_zero & b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s & b.s, b.e, b.m[3:-1])

        # only a is zero: result is b
        with m.Elif(a.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(b.s, b.e, b.m[3:-1])

        # only b is zero: result is a
        with m.Elif(b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s, a.e, a.m[3:-1])

        # a == -b: result is +ve zero
        with m.Elif(s_nomatch & m_match & (a.e == b.e)):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.zero(0)

        # no special case applies
        with m.Else():
            m.d.comb += self.out_do_z.eq(0)

        return m
319
320
class FPID:
    """ optional multiplexer-id pass-through for a stage

        When id_wid is non-zero, in_mid and out_mid are id_wid-bit signals
        and idsync() registers in_mid into out_mid.  When id_wid is zero
        or None there is no id: both attributes are None and idsync() does
        nothing.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds the sync copy of in_mid to out_mid (if an id is in use)
        """
        # test the signal itself rather than id_wid: an id_wid of 0 also
        # leaves out_mid as None, and must not be dereferenced
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
334
335
class FPAddSpecialCases(FPState, FPID):
    """ "special_cases" state: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "special_cases")
        FPID.__init__(self, id_wid)
        self.mod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ wires the special-cases module and (if present) the mux id
        """
        self.mod.setup(m, in_a, in_b, self.out_do_z)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ goes to "put_z" with the result if a special case applied,
            otherwise carries on to "denormalise"
        """
        self.idsync(m)
        with m.If(self.out_do_z):
            # only the packed value of the result is taken
            m.d.sync += self.out_z.v.eq(self.mod.out_z.v)
            m.next = "put_z"
        with m.Else():
            m.next = "denormalise"
363
364
class FPAddSpecialCasesDeNorm(FPState, FPID):
    """ "special_cases" state combined with denormalisation
        NOTE: some of the special cases are unique to add.  see
        "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        uses both FPAddSpecialCasesMod and FPAddDeNormMod on the same
        inputs, so that the next state can be "align" directly.
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "special_cases")
        FPID.__init__(self, id_wid)
        self.smod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

        self.dmod = FPAddDeNormMod(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b, in_mid):
        """ wires both modules to the same operands, plus the mux id
        """
        self.smod.setup(m, in_a, in_b, self.out_do_z)
        self.dmod.setup(m, in_a, in_b)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ goes to "put_z" with the result if a special case applied,
            otherwise latches the denormalised operands and goes to "align"
        """
        self.idsync(m)
        with m.If(self.out_do_z):
            # only the packed value of the result is taken
            m.d.sync += self.out_z.v.eq(self.smod.out_z.v)
            m.next = "put_z"
        with m.Else():
            m.next = "align"
            m.d.sync += [
                self.out_a.copy(self.dmod.out_a),
                self.out_b.copy(self.dmod.out_b),
            ]
399
400
class FPAddDeNormMod(FPState):
    """ denormalised-number adjustment, applied to a and b alike

        each output is a copy of its input, except that: if the input's
        exp_n127 flag is set the exponent is replaced by N126, otherwise
        the top mantissa bit is set.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b):
        """ registers this module as "denormalise" and wires its inputs
        """
        m.submodules.denormalise = self
        m.d.comb += self.in_a.copy(in_a)
        m.d.comb += self.in_b.copy(in_b)

    def elaborate(self, platform):
        m = Module()
        m.submodules.denorm_in_a = self.in_a
        m.submodules.denorm_in_b = self.in_b
        m.submodules.denorm_out_a = self.out_a
        m.submodules.denorm_out_b = self.out_b

        # identical treatment for both operands: a first, then b
        for num_in, num_out in ((self.in_a, self.out_a),
                                (self.in_b, self.out_b)):
            m.d.comb += num_out.copy(num_in)
            with m.If(num_in.exp_n127):
                m.d.comb += num_out.e.eq(num_in.N126) # limit exponent
            with m.Else():
                m.d.comb += num_out.m[-1].eq(1) # set top mantissa bit

        return m
436
437
class FPAddDeNorm(FPState, FPID):
    """ "denormalise" state: latches the adjusted operands, then "align"
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "denormalise")
        FPID.__init__(self, id_wid)
        self.mod = FPAddDeNormMod(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b, in_mid):
        """ wires the denormalise module and (if present) the mux id
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ unconditionally registers the module's outputs
        """
        self.idsync(m)
        m.next = "align"
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
460
461
class FPAddAlignMultiMod(FPState):
    """ one step of multi-cycle alignment

        NOTE: this does *not* do single-cycle multi-shifting: the
        operand with the smaller exponent gets one shift_down per pass,
        and exp_eq is only raised once the exponents match.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        m.submodules.align_in_a = self.in_a
        m.submodules.align_in_b = self.in_b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # defaults: not yet equal, operands passed through unchanged
        m.d.comb += [
            self.exp_eq.eq(0),
            self.out_a.copy(self.in_a),
            self.out_b.copy(self.in_b),
        ]

        agtb = Signal(reset_less=True)
        altb = Signal(reset_less=True)
        m.d.comb += [
            agtb.eq(self.in_a.e > self.in_b.e),
            altb.eq(self.in_a.e < self.in_b.e),
        ]

        with m.If(agtb):
            # a's exponent is the larger: b is the one shifted down
            m.d.comb += self.out_b.shift_down(self.in_b)
        with m.Elif(altb):
            # b's exponent is the larger: a is the one shifted down
            m.d.comb += self.out_a.shift_down(self.in_a)
        with m.Else():
            # exponents match: alignment is finished
            m.d.comb += self.exp_eq.eq(1)
        return m
502
503
class FPAddAlignMulti(FPState, FPID):
    """ "align" state (multi-cycle): stays here until exponents match
    """

    def __init__(self, width, id_wid):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "align")
        self.mod = FPAddAlignMultiMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ registers the module as "align" and wires it up
        """
        m.submodules.align = self.mod
        m.d.comb += [
            self.mod.in_a.copy(in_a),
            self.mod.in_b.copy(in_b),
            self.exp_eq.eq(self.mod.exp_eq),
        ]
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the module's outputs on every pass; only moves on
            to "add_0" once exp_eq is raised
        """
        self.idsync(m)
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
        with m.If(self.exp_eq):
            m.next = "add_0"
532
533
class FPAddAlignSingleMod:
    """ single-cycle alignment of a against b (or b against a)
    """

    def __init__(self, width):
        self.width = width
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def setup(self, m, in_a, in_b):
        """ registers this module as "align" and wires its inputs
        """
        m.submodules.align = self
        m.d.comb += self.in_a.copy(in_a)
        m.d.comb += self.in_b.copy(in_b)

    def elaborate(self, platform):
        """ Whichever operand has the smaller exponent is shifted to match
            the other, in a *single* cycle, using a variable-width shift.

            The shifter is expensive in gates, so there is only one: the
            operand that needs aligning is muxed into a temporary, shifted,
            and muxed back out.
        """
        m = Module()

        m.submodules.align_in_a = self.in_a
        m.submodules.align_in_b = self.in_b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # temporary (muxed) input and output of the one shifter
        t_inp = FPNumBase(self.width)
        t_out = FPNumIn(None, self.width)
        espec = (len(self.in_a.e), True)
        msr = MultiShiftRMerge(self.in_a.m_width, espec)
        m.submodules.align_t_in = t_inp
        m.submodules.align_t_out = t_out
        m.submodules.multishift_r = msr

        ediff = Signal(espec, reset_less=True)
        ediffr = Signal(espec, reset_less=True)
        tdiff = Signal(espec, reset_less=True)
        elz = Signal(reset_less=True)
        egz = Signal(reset_less=True)

        m.d.comb += [
            # shifter sits between t_inp and t_out, shifting by tdiff
            msr.inp.eq(t_inp.m),
            msr.diff.eq(tdiff),
            t_out.m.eq(msr.m),
            t_out.e.eq(t_inp.e + tdiff),
            t_out.s.eq(t_inp.s),
            # exponent differences (both directions) and comparisons
            ediff.eq(self.in_a.e - self.in_b.e),
            ediffr.eq(self.in_b.e - self.in_a.e),
            elz.eq(self.in_a.e < self.in_b.e),
            egz.eq(self.in_a.e > self.in_b.e),
            # default (exponents equal): a and b pass through untouched
            self.out_a.copy(self.in_a),
            self.out_b.copy(self.in_b),
        ]

        with m.If(egz):
            # a's exponent is the larger: b goes through the shifter
            m.d.comb += t_inp.copy(self.in_b)
            m.d.comb += tdiff.eq(ediff)
            m.d.comb += self.out_b.copy(t_out)
            m.d.comb += self.out_b.s.eq(self.in_b.s) # sign comes from b
        with m.Elif(elz):
            # b's exponent is the larger: a goes through the shifter
            m.d.comb += t_inp.copy(self.in_a)
            m.d.comb += tdiff.eq(ediffr)
            m.d.comb += self.out_a.copy(t_out)
            m.d.comb += self.out_a.s.eq(self.in_a.s) # sign comes from a
        return m
613
614
class FPAddAlignSingle(FPState, FPID):
    """ "align" state (single-cycle): latches aligned operands, then "add_0"
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        FPID.__init__(self, id_wid)
        self.mod = FPAddAlignSingleMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def setup(self, m, in_a, in_b, in_mid):
        """ wires the alignment module and (if present) the mux id
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the module's outputs (NOTE: could be done as comb)
        """
        self.idsync(m)
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
        m.next = "add_0"
637
638
class FPAddAlignSingleAdd(FPState, FPID):
    """ "align" state which also performs both add stages

        The single-cycle alignment module, FPAddStage0Mod and
        FPAddStage1Mod are chained combinatorially; only the outputs of
        the last one (out_z and out_of) are registered, after which the
        next state is "normalise_1".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        FPID.__init__(self, id_wid)
        self.mod = FPAddAlignSingleMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

        self.a0mod = FPAddStage0Mod(width)
        # created once only (a second, identical construction of
        # a0_out_z used to follow out_tot and replace this one)
        self.a0_out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.a0_out_z.m_width + 4, reset_less=True)

        self.a1mod = FPAddStage1Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        # align -> out_a/out_b
        self.mod.setup(m, in_a, in_b)
        m.d.comb += self.out_a.copy(self.mod.out_a)
        m.d.comb += self.out_b.copy(self.mod.out_b)

        # out_a/out_b -> add stage 0 -> a0_out_z/out_tot
        self.a0mod.setup(m, self.out_a, self.out_b)
        m.d.comb += self.a0_out_z.copy(self.a0mod.out_z)
        m.d.comb += self.out_tot.eq(self.a0mod.out_tot)

        # a0_out_z/out_tot -> add stage 1
        self.a1mod.setup(m, self.out_tot, self.a0_out_z)

        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the outputs of add stage 1
        """
        self.idsync(m)
        m.d.sync += self.out_of.copy(self.a1mod.out_of)
        m.d.sync += self.out_z.copy(self.a1mod.out_z)
        m.next = "normalise_1"
678
679
class FPAddStage0Mod:
    """ first stage of add: adds or subtracts the (aligned) mantissas

        out_z takes a's exponent.  With equal signs the mantissas are
        added; otherwise the smaller mantissa is subtracted from the
        larger, and the sign of the larger is used.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def setup(self, m, in_a, in_b):
        """ registers this module as "add0" and wires its inputs
        """
        m.submodules.add0 = self
        m.d.comb += self.in_a.copy(in_a)
        m.d.comb += self.in_b.copy(in_b)

    def elaborate(self, platform):
        m = Module()
        m.submodules.add0_in_a = self.in_a
        m.submodules.add0_in_b = self.in_b
        m.submodules.add0_out_z = self.out_z

        m.d.comb += self.out_z.e.eq(self.in_a.e)

        # intermediate tests, and mantissas zero-extended by one bit
        seq = Signal(reset_less=True)
        mge = Signal(reset_less=True)
        am0 = Signal(len(self.in_a.m)+1, reset_less=True)
        bm0 = Signal(len(self.in_b.m)+1, reset_less=True)
        m.d.comb += seq.eq(self.in_a.s == self.in_b.s)
        m.d.comb += mge.eq(self.in_a.m >= self.in_b.m)
        m.d.comb += am0.eq(Cat(self.in_a.m, 0))
        m.d.comb += bm0.eq(Cat(self.in_b.m, 0))

        with m.If(seq):
            # same sign (both negative or both positive): add
            m.d.comb += self.out_tot.eq(am0 + bm0)
            m.d.comb += self.out_z.s.eq(self.in_a.s)
        with m.Elif(mge):
            # a's mantissa is at least b's: a - b, sign of a
            m.d.comb += self.out_tot.eq(am0 - bm0)
            m.d.comb += self.out_z.s.eq(self.in_a.s)
        with m.Else():
            # b's mantissa is the greater: b - a, sign of b
            m.d.comb += self.out_tot.eq(bm0 - am0)
            m.d.comb += self.out_z.s.eq(self.in_b.s)
        return m
733
734
class FPAddStage0(FPState, FPID):
    """ "add_0" state.  First stage of add: covers same-sign (add) and
        subtract, special-casing which mantissa is the greater (or
        equal), to give greatest accuracy.
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "add_0")
        FPID.__init__(self, id_wid)
        self.mod = FPAddStage0Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ wires the add-stage-0 module and (if present) the mux id
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the module's outputs, then moves to "add_1"
            NOTE: these could be done as combinatorial (merge add0+add1)
        """
        self.idsync(m)
        m.d.sync += [
            self.out_z.copy(self.mod.out_z),
            self.out_tot.eq(self.mod.out_tot),
        ]
        m.next = "add_1"
761
762
class FPAddStage1Mod(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (the top bit of tot is kinda a
        carry bit) and splits tot into mantissa and overflow bits.
    """

    def __init__(self, width):
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_tot = Signal(self.in_z.m_width + 4, reset_less=True)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_tot, in_z):
        """ registers this module as "add1" (and its overflow output as
            "add1_out_overflow"), then wires its inputs
        """
        m.submodules.add1 = self
        m.submodules.add1_out_overflow = self.out_of

        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += self.in_tot.eq(in_tot)

    def elaborate(self, platform):
        m = Module()
        tot = self.in_tot
        overflow = self.out_of

        # default: z passes through (mantissa always overridden below)
        m.d.comb += self.out_z.copy(self.in_z)

        with m.If(tot[-1]):
            # top bit set: the sum overflowed.  take the mantissa from one
            # bit higher up, merge the two lowest bits into sticky, and
            # bump the exponent
            m.d.comb += self.out_z.m.eq(tot[4:])
            m.d.comb += overflow.m0.eq(tot[4])
            m.d.comb += overflow.guard.eq(tot[3])
            m.d.comb += overflow.round_bit.eq(tot[2])
            m.d.comb += overflow.sticky.eq(tot[1] | tot[0])
            m.d.comb += self.out_z.e.eq(self.in_z.e + 1)
        with m.Else():
            # top bit clear: the low three bits are guard, round, sticky
            m.d.comb += self.out_z.m.eq(tot[3:])
            m.d.comb += overflow.m0.eq(tot[3])
            m.d.comb += overflow.guard.eq(tot[2])
            m.d.comb += overflow.round_bit.eq(tot[1])
            m.d.comb += overflow.sticky.eq(tot[0])
        return m
811
812
class FPAddStage1(FPState, FPID):
    """ "add_1" state: registers mantissa/overflow, then "normalise_1"
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "add_1")
        FPID.__init__(self, id_wid)
        self.mod = FPAddStage1Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()
        self.norm_stb = Signal()

    def setup(self, m, in_tot, in_z, in_mid):
        """ wires the add-stage-1 module and (if present) the mux id
        """
        self.mod.setup(m, in_tot, in_z)

        # default for norm_stb: zero whenever not in the add1 state
        m.d.sync += self.norm_stb.eq(0)

        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the module's outputs and raises norm_stb
        """
        self.idsync(m)
        m.d.sync += [
            self.out_of.copy(self.mod.out_of),
            self.out_z.copy(self.mod.out_z),
            self.norm_stb.eq(1),
        ]
        m.next = "normalise_1"
839
840
class FPNorm1ModSingle:
    """ single-cycle normalisation

        If the mantissa's top bit is clear and the exponent is above the
        minimum, the mantissa is shifted up (and the exponent decreased)
        by the number of leading zeros, limited so that the exponent does
        not drop below the minimum.  If the exponent is below the minimum,
        the mantissa is shifted down through a MultiShiftRMerge and the
        exponent is increased by the same amount.  Otherwise the inputs
        pass straight through.
    """

    def __init__(self, width):
        self.width = width
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_z, in_of, out_z):
        """ links module to inputs and outputs

            registers this module as "normalise_1", copies in_z and in_of
            in, and copies the normalised number out to out_z.
        """
        m.submodules.normalise_1 = self

        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += self.in_of.copy(in_of)

        m.d.comb += out_z.copy(self.out_z)

    def elaborate(self, platform):
        m = Module()

        # mantissa plus guard and round bits
        mwid = self.out_z.m_width+2
        pe = PriorityEncoder(mwid)
        m.submodules.norm_pe = pe

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        espec = (len(in_z.e), True)
        ediff_n126 = Signal(espec, reset_less=True)
        msr = MultiShiftRMerge(mwid, espec)
        m.submodules.multishift_r = msr

        m.d.comb += in_z.copy(self.in_z)
        m.d.comb += in_of.copy(self.in_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.copy(in_z)
        m.d.comb += self.out_of.copy(in_of)
        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126)
        m.d.comb += increase.eq(in_z.exp_lt_n126)
        # decrease exponent
        with m.If(decrease):
            # *sigh* not entirely obvious: count leading zeros (clz)
            # with a PriorityEncoder: to find from the MSB
            # we reverse the order of the bits.
            temp_m = Signal(mwid, reset_less=True)
            temp_s = Signal(mwid+1, reset_less=True)
            clz = Signal((len(in_z.e), True), reset_less=True)
            # make sure that the amount to decrease by does NOT
            # go below the minimum non-INF/NaN exponent
            limclz = Mux(in_z.exp_sub_n126 > pe.o, pe.o,
                         in_z.exp_sub_n126)
            m.d.comb += [
                # cat round and guard bits back into the mantissa
                temp_m.eq(Cat(in_of.round_bit, in_of.guard, in_z.m)),
                pe.i.eq(temp_m[::-1]),          # inverted
                clz.eq(limclz),                 # count zeros from MSB down
                temp_s.eq(temp_m << clz),       # shift mantissa UP
                self.out_z.e.eq(in_z.e - clz),  # DECREASE exponent
                self.out_z.m.eq(temp_s[2:]),    # exclude bits 0&1
                self.out_of.m0.eq(temp_s[2]),   # copy of mantissa[0]
                # overflow in bits 0..1: got shifted too (leave sticky)
                self.out_of.guard.eq(temp_s[1]),     # guard
                self.out_of.round_bit.eq(temp_s[0]), # round
            ]
        # increase exponent
        with m.Elif(increase):
            # NOTE(review): temp_m is one bit wider than the width given
            # to MultiShiftRMerge above - confirm the shifter input is
            # wide enough.
            temp_m = Signal(mwid+1, reset_less=True)
            m.d.comb += [
                temp_m.eq(Cat(in_of.sticky, in_of.round_bit, in_of.guard,
                              in_z.m)),
                ediff_n126.eq(in_z.N126 - in_z.e),
                # connect multi-shifter to inp/out mantissa (and ediff)
                msr.inp.eq(temp_m),
                msr.diff.eq(ediff_n126),
                self.out_z.m.eq(msr.m[3:]),
                # the overflow bits come from the shifter output, like
                # the mantissa (temp_s belongs to the decrease branch and
                # is not driven here)
                self.out_of.m0.eq(msr.m[3]),        # copy of mantissa[0]
                self.out_of.guard.eq(msr.m[2]),     # guard
                self.out_of.round_bit.eq(msr.m[1]), # round
                self.out_of.sticky.eq(msr.m[0]),    # sticky
                self.out_z.e.eq(in_z.e + ediff_n126),
            ]

        return m
938
939
class FPNorm1ModMulti:
    """ one step of multi-cycle normalisation

        Works on either the fresh inputs (in_z/in_of, when in_select is
        high) or the fed-back intermediates (temp_z/temp_of), and moves
        the mantissa by one bit per pass.  out_norm is high for as long
        as another pass is needed.
    """

    def __init__(self, width, single_cycle=True):
        self.width = width
        self.in_select = Signal(reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.temp_z = FPNumBase(width, False)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_z, in_of, norm_stb, in_select, temp_z, temp_of,
              out_z, out_norm):
        """ links module to inputs and outputs

            Matches the call made by FPNorm1Multi.setup.  norm_stb is
            accepted for that call's sake but is not used here: the
            caller connects it to its own strobe.
        """
        m.submodules.normalise_1 = self

        m.d.comb += self.in_z.copy(in_z)
        m.d.comb += self.in_of.copy(in_of)

        m.d.comb += self.in_select.eq(in_select)
        m.d.comb += self.temp_z.copy(temp_z)
        m.d.comb += self.temp_of.copy(temp_of)

        m.d.comb += out_z.copy(self.out_z)
        m.d.comb += out_norm.eq(self.out_norm)

    def elaborate(self, platform):
        m = Module()

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_temp_z = self.temp_z
        m.submodules.norm1_temp_of = self.temp_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        # select which of temp or in z/of to use
        with m.If(self.in_select):
            m.d.comb += in_z.copy(self.in_z)
            m.d.comb += in_of.copy(self.in_of)
        with m.Else():
            m.d.comb += in_z.copy(self.temp_z)
            m.d.comb += in_of.copy(self.temp_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.copy(in_z)
        m.d.comb += self.out_of.copy(in_of)
        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126)
        m.d.comb += increase.eq(in_z.exp_lt_n126)
        m.d.comb += self.out_norm.eq(decrease | increase) # loop-end
        # decrease exponent
        with m.If(decrease):
            m.d.comb += [
                self.out_z.e.eq(in_z.e - 1),  # DECREASE exponent
                self.out_z.m.eq(in_z.m << 1), # shift mantissa UP
                self.out_z.m[0].eq(in_of.guard), # steal guard (was tot[2])
                self.out_of.guard.eq(in_of.round_bit), # round (was tot[1])
                self.out_of.round_bit.eq(0),        # reset round bit
                self.out_of.m0.eq(in_of.guard),
            ]
        # increase exponent
        with m.Elif(increase):
            m.d.comb += [
                self.out_z.e.eq(in_z.e + 1),  # INCREASE exponent
                self.out_z.m.eq(in_z.m >> 1), # shift mantissa DOWN
                self.out_of.guard.eq(in_z.m[0]),
                self.out_of.m0.eq(in_z.m[1]),
                self.out_of.round_bit.eq(in_of.guard),
                self.out_of.sticky.eq(in_of.sticky | in_of.round_bit)
            ]

        return m
1006
1007
class FPNorm1Single(FPState, FPID):
    """ "normalise_1" state (single-cycle), followed by "round"
    """

    def __init__(self, width, id_wid, single_cycle=True):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.mod = FPNorm1ModSingle(width)
        self.out_norm = Signal(reset_less=True)
        self.out_z = FPNumBase(width)
        self.out_roundz = Signal(reset_less=True)

    def setup(self, m, in_z, in_of, in_mid):
        """ wires the normalisation module (which drives out_z) and, if
            present, the mux id
        """
        self.mod.setup(m, in_z, in_of, self.out_z)

        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ registers the module's roundz flag, then moves to "round"
        """
        self.idsync(m)
        roundz = self.mod.out_of.roundz
        m.d.sync += self.out_roundz.eq(roundz)
        m.next = "round"
1030
1031
class FPNorm1Multi(FPState, FPID):
    """ FSM state "normalise_1", multi-cycle variant.

        Wraps an FPNorm1ModMulti.  Each clock spent in this state stores
        the module's z/overflow outputs into temp_z/temp_of; the state is
        left (for "round") once out_norm is no longer raised.
    """

    def __init__(self, width, id_wid):
        """ * width: handed to FPNorm1ModMulti and FPNumBase
            * id_wid: handed to FPID
        """
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.mod = FPNorm1ModMulti(width)
        self.stb = Signal(reset_less=True)
        self.ack = Signal(reset=0, reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_accept = Signal(reset_less=True)
        self.temp_z = FPNumBase(width)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width)
        self.out_roundz = Signal(reset_less=True)

    def setup(self, m, in_z, in_of, norm_stb, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z, in_of, norm_stb,
                       self.in_accept, self.temp_z, self.temp_of,
                       self.out_z, self.out_norm)

        m.d.comb += self.stb.eq(norm_stb)
        # default: ack is zero whenever the normalise_1 state is not active
        m.d.sync += self.ack.eq(0)

        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ per-state behaviour: loop while out_norm is set, else "round"
        """
        self.idsync(m)
        # accept when strobe is HIGH and ack is LOW
        m.d.comb += self.in_accept.eq((~self.ack) & (self.stb))
        m.d.sync += [
            self.temp_of.copy(self.mod.out_of),
            self.temp_z.copy(self.out_z),
        ]
        with m.If(self.out_norm):
            # ack becomes 1 when accepting, 0 otherwise (both are 1-bit)
            m.d.sync += self.ack.eq(self.in_accept)
        with m.Else():
            # normalisation not required (or done).
            m.next = "round"
            m.d.sync += [
                self.ack.eq(1),
                self.out_roundz.eq(self.mod.out_of.roundz),
            ]
1077
1078
class FPNormToPack(FPState, FPID):
    """ FSM state "normalise_1" for the compact pipeline.

        Chains normalisation, rounding, corrections and packing together
        combinatorially, then registers the packed value in one step.
    """

    def __init__(self, width, id_wid):
        """ * width: handed to every chained module and FPNumBase
            * id_wid: handed to FPID
        """
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.width = width

    def setup(self, m, in_z, in_of, in_mid):
        """ links module to inputs and outputs
        """
        width = self.width

        # Stage 1: normalisation, fed directly from in_z + in_of
        nmod = FPNorm1ModSingle(width)
        n_out_z = FPNumBase(width)
        n_out_roundz = Signal(reset_less=True)
        nmod.setup(m, in_z, in_of, n_out_z)

        # Stage 2: rounding, fed from the normalisation outputs
        rmod = FPRoundMod(width)
        r_out_z = FPNumBase(width)
        rmod.setup(m, n_out_z, n_out_roundz)
        m.d.comb += [
            n_out_roundz.eq(nmod.out_of.roundz),
            r_out_z.copy(rmod.out_z),
        ]

        # Stage 3: corrections, fed from the rounded value
        cmod = FPCorrectionsMod(width)
        c_out_z = FPNumBase(width)
        cmod.setup(m, r_out_z)
        m.d.comb += c_out_z.copy(cmod.out_z)

        # Stage 4: pack, fed from the corrected value
        self.pmod = FPPackMod(width)
        self.out_z = FPNumBase(width)
        self.pmod.setup(m, c_out_z)

        # Multiplex ID (only when one is in use)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ per-state behaviour: register the packed result, go to
            "pack_put_z"
        """
        self.idsync(m)  # copies incoming ID to outgoing
        packed = self.pmod.out_z.v
        m.d.sync += self.out_z.v.eq(packed)  # outputs packed result
        m.next = "pack_put_z"
1122
1123
class FPRoundMod:
    """ Combinatorial rounding module.

        out_z is a copy of in_z unless in_roundz is raised, in which case
        the mantissa is incremented (and the exponent too, when the
        mantissa was all 1s).
    """

    def __init__(self, width):
        self.in_roundz = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)

    def setup(self, m, in_z, roundz):
        """ registers this module as submodule "roundz" of m and wires
            up its inputs
        """
        m.submodules.roundz = self

        m.d.comb += [
            self.in_z.copy(in_z),
            self.in_roundz.eq(roundz),
        ]

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z
        # default: pass the number through untouched
        m.d.comb += dst.copy(src)
        with m.If(self.in_roundz):
            m.d.comb += dst.m.eq(src.m + 1)  # mantissa rounds up
            with m.If(src.m == src.m1s):  # all 1s
                m.d.comb += dst.e.eq(src.e + 1)  # exponent up
        return m
1145
1146
class FPRound(FPState, FPID):
    """ FSM state "round": registers the output of an FPRoundMod and
        then moves to "corrections".
    """

    def __init__(self, width, id_wid):
        """ * width: handed to FPRoundMod and FPNumBase
            * id_wid: handed to FPID
        """
        FPState.__init__(self, "round")
        FPID.__init__(self, id_wid)
        self.mod = FPRoundMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, roundz, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z, roundz)

        # nothing more to do when no ID is in use
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ per-state behaviour: sync the ID, latch the rounded number
        """
        self.idsync(m)
        rounded = self.mod.out_z
        m.d.sync += self.out_z.copy(rounded)
        m.next = "corrections"
1167
1168
class FPCorrectionsMod:
    """ Combinatorial corrections module.

        out_z is a copy of in_z, except that the exponent is replaced by
        in_z.N127 when in_z.is_denormalised is raised.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.corrections = self
        m.d.comb += self.in_z.copy(in_z)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z
        m.submodules.corr_in_z = src
        m.submodules.corr_out_z = dst
        # default: pass through, then override the exponent if needed
        m.d.comb += dst.copy(src)
        with m.If(src.is_denormalised):
            m.d.comb += dst.e.eq(src.N127)
        return m
1189
1190
class FPCorrections(FPState, FPID):
    """ FSM state "corrections": registers the output of an
        FPCorrectionsMod and then moves to "pack".
    """

    def __init__(self, width, id_wid):
        """ * width: handed to FPCorrectionsMod and FPNumBase
            * id_wid: handed to FPID
        """
        FPState.__init__(self, "corrections")
        FPID.__init__(self, id_wid)
        self.mod = FPCorrectionsMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z)
        # nothing more to do when no ID is in use
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ per-state behaviour: sync the ID, latch the corrected number
        """
        self.idsync(m)
        corrected = self.mod.out_z
        m.d.sync += self.out_z.copy(corrected)
        m.next = "pack"
1210
1211
class FPPackMod:
    """ Combinatorial pack module.

        Drives out_z with infinity (keeping the sign) when
        in_z.is_overflowed is raised, otherwise builds it from the
        sign, exponent and mantissa of in_z.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.pack = self
        m.d.comb += self.in_z.copy(in_z)

    def elaborate(self, platform):
        m = Module()
        src, dst = self.in_z, self.out_z
        m.submodules.pack_in_z = src
        with m.If(src.is_overflowed):
            m.d.comb += dst.inf(src.s)
        with m.Else():
            m.d.comb += dst.create(src.s, src.e, src.m)
        return m
1232
1233
class FPPack(FPState, FPID):
    """ FSM state "pack": registers the packed value produced by an
        FPPackMod and then moves to "pack_put_z".
    """

    def __init__(self, width, id_wid):
        """ * width: handed to FPPackMod and FPNumOut
            * id_wid: handed to FPID
        """
        FPState.__init__(self, "pack")
        FPID.__init__(self, id_wid)
        self.mod = FPPackMod(width)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z)
        # nothing more to do when no ID is in use
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        """ per-state behaviour: sync the ID, latch the packed value
        """
        self.idsync(m)
        packed = self.mod.out_z.v
        m.d.sync += self.out_z.v.eq(packed)
        m.next = "pack_put_z"
1253
1254
class FPPutZ(FPState):
    """ Output state: presents in_z on out_z and holds the strobe up
        until out_z.stb and out_z.ack are both set, then moves on.
    """

    def __init__(self, state, in_z, out_z, in_mid, out_mid, to_state=None):
        """ * state: name of the FSM state this object implements
            * in_z / out_z: source and destination of the value
            * in_mid / out_mid: source and destination of the ID
              (in_mid may be None, in which case no ID is copied)
            * to_state: state to go to afterwards (default "get_ops")
        """
        FPState.__init__(self, state)
        self.to_state = "get_ops" if to_state is None else to_state
        self.in_z = in_z
        self.out_z = out_z
        self.in_mid = in_mid
        self.out_mid = out_mid

    def action(self, m):
        """ per-state behaviour: copy ID and value, run the stb/ack
            handshake on out_z
        """
        dest = self.out_z
        if self.in_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
        m.d.sync += dest.v.eq(self.in_z.v)
        with m.If(dest.stb & dest.ack):
            # handshake complete: drop the strobe and leave the state
            m.d.sync += dest.stb.eq(0)
            m.next = self.to_state
        with m.Else():
            m.d.sync += dest.stb.eq(1)
1278
1279
class FPPutZIdx(FPState):
    """ Output state like FPPutZ, but the destination is selected out
        of out_zs by indexing it with in_mid.
    """

    def __init__(self, state, in_z, out_zs, in_mid, to_state=None):
        """ * state: name of the FSM state this object implements
            * in_z: source of the value
            * out_zs: indexable collection of destinations
            * in_mid: index into out_zs
            * to_state: state to go to afterwards (default "get_ops")
        """
        FPState.__init__(self, state)
        self.to_state = "get_ops" if to_state is None else to_state
        self.in_z = in_z
        self.out_zs = out_zs
        self.in_mid = in_mid

    def action(self, m):
        """ per-state behaviour: copy the value to the selected output
            and run the stb/ack handshake on it
        """
        outz_stb = Signal(reset_less=True)
        outz_ack = Signal(reset_less=True)
        selected = self.out_zs[self.in_mid]
        m.d.comb += [
            outz_stb.eq(selected.stb),
            outz_ack.eq(selected.ack),
        ]
        m.d.sync += selected.v.eq(self.in_z.v)
        with m.If(outz_stb & outz_ack):
            # handshake complete: drop the strobe and leave the state
            m.d.sync += selected.stb.eq(0)
            m.next = self.to_state
        with m.Else():
            m.d.sync += selected.stb.eq(1)
1305
1306
class FPADDBaseMod(FPID):
    """ Builds the FP add state machine, either as the "compact"
        chain of states or as the longer one-stage-per-state chain.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, compact=True):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * compact: True indicates a reduced number of stages
        """
        FPID.__init__(self, id_wid)
        self.width = width
        self.single_cycle = single_cycle
        self.compact = compact

        self.in_t = Trigger()
        self.in_a = Signal(width)
        self.in_b = Signal(width)
        self.out_z = FPOp(width)

        self.states = []

    def add_state(self, state):
        """ records a state for the FSM and hands it straight back
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules.out_z = self.out_z
        m.submodules.in_t = self.in_t

        # populate self.states with the requested chain
        if self.compact:
            build = self.get_compact_fragment
        else:
            build = self.get_longer_fragment
        build(m, platform)

        # one FSM state per recorded state object
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m

    def get_longer_fragment(self, m, platform=None):
        """ one state per stage: get_ops, special cases, denorm, align,
            add0, add1, normalise, round, corrections, pack, put_z
        """
        width, id_wid = self.width, self.id_wid

        getops = self.add_state(FPGet2Op("get_ops", "special_cases",
                                         self.in_a, self.in_b, width))
        getops.setup(m, self.in_a, self.in_b, self.in_t.stb, self.in_t.ack)
        a = getops.out_op1
        b = getops.out_op2

        sc = self.add_state(FPAddSpecialCases(width, id_wid))
        sc.setup(m, a, b, self.in_mid)

        dn = self.add_state(FPAddDeNorm(width, id_wid))
        dn.setup(m, a, b, sc.in_mid)

        # both alignment variants are wired up identically
        if self.single_cycle:
            align_cls = FPAddAlignSingle
        else:
            align_cls = FPAddAlignMulti
        alm = self.add_state(align_cls(width, id_wid))
        alm.setup(m, dn.out_a, dn.out_b, dn.in_mid)

        add0 = self.add_state(FPAddStage0(width, id_wid))
        add0.setup(m, alm.out_a, alm.out_b, alm.in_mid)

        add1 = self.add_state(FPAddStage1(width, id_wid))
        add1.setup(m, add0.out_tot, add0.out_z, add0.in_mid)

        # NOTE(review): the normalise stage takes its ID from add0 (not
        # add1), and pack/put_z below from rn/pa rather than from the
        # immediately preceding stage - kept exactly as wired; confirm
        # whether that is intentional.
        if self.single_cycle:
            n1 = self.add_state(FPNorm1Single(width, id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add0.in_mid)
        else:
            n1 = self.add_state(FPNorm1Multi(width, id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb, add0.in_mid)

        rn = self.add_state(FPRound(width, id_wid))
        rn.setup(m, n1.out_z, n1.out_roundz, n1.in_mid)

        cor = self.add_state(FPCorrections(width, id_wid))
        cor.setup(m, rn.out_z, rn.in_mid)

        pa = self.add_state(FPPack(width, id_wid))
        pa.setup(m, cor.out_z, rn.in_mid)

        # output of the normal path
        self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z,
                              pa.in_mid, self.out_mid))

        # output of the special-cases path
        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z,
                              pa.in_mid, self.out_mid))

    def get_compact_fragment(self, m, platform=None):
        """ reduced chain: get_ops, special cases + denorm, align + add,
            normalise-to-pack, put_z
        """
        width, id_wid = self.width, self.id_wid

        getops = self.add_state(FPGet2Op("get_ops", "special_cases",
                                         self.in_a, self.in_b, width))
        getops.setup(m, self.in_a, self.in_b, self.in_t.stb, self.in_t.ack)
        a = getops.out_op1
        b = getops.out_op2

        sc = self.add_state(FPAddSpecialCasesDeNorm(width, id_wid))
        sc.setup(m, a, b, self.in_mid)

        alm = self.add_state(FPAddAlignSingleAdd(width, id_wid))
        alm.setup(m, sc.out_a, sc.out_b, sc.in_mid)

        n1 = self.add_state(FPNormToPack(width, id_wid))
        n1.setup(m, alm.out_z, alm.out_of, alm.in_mid)

        # output of the normal path
        self.add_state(FPPutZ("pack_put_z", n1.out_z, self.out_z,
                              n1.in_mid, self.out_mid))

        # output of the special-cases path
        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z,
                              sc.in_mid, self.out_mid))
1423
1424
class FPADDBase(FPState, FPID):
    """ FSM state "fpadd": a stage that itself contains an FPADDBaseMod
        state machine, bridged with an input trigger (in_t) and the
        strobe/ack of the shared out_z.
    """

    def __init__(self, width, id_wid=None, single_cycle=False):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
        """
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "fpadd")
        self.width = width
        self.single_cycle = single_cycle
        self.mod = FPADDBaseMod(width, id_wid, single_cycle)

        self.in_t = Trigger()
        self.in_a = Signal(width)
        self.in_b = Signal(width)

        self.z_done = Signal(reset_less=True) # connects to out_z Strobe
        self.in_accept = Signal(reset_less=True)
        self.add_stb = Signal(reset_less=True)
        self.add_ack = Signal(reset=0, reset_less=True)

    def setup(self, m, a, b, add_stb, in_mid, out_z, out_mid):
        """ links module to inputs and outputs

            * a, b: operands, forwarded to the inner module
            * add_stb: incoming strobe (registered into self.add_stb)
            * in_mid: incoming ID (ignored when this object has no ID)
            * out_z, out_mid: result and ID outputs, driven from the
              inner module
        """
        self.out_z = out_z
        self.out_mid = out_mid
        m.d.comb += [self.in_a.eq(a),
                     self.in_b.eq(b),
                     self.mod.in_a.eq(self.in_a),
                     self.mod.in_b.eq(self.in_b),
                     self.z_done.eq(self.mod.out_z.trigger),
                     self.mod.in_t.stb.eq(self.in_t.stb),
                     self.in_t.ack.eq(self.mod.in_t.ack),
                     self.out_z.v.eq(self.mod.out_z.v),
                     self.out_z.stb.eq(self.mod.out_z.stb),
                     self.mod.out_z.ack.eq(self.out_z.ack),
                    ]

        # ID wiring is skipped when no ID is in use (in_mid is None),
        # matching the guard used by the other states in this file.
        if self.in_mid is not None:
            m.d.comb += [self.in_mid.eq(in_mid),
                         self.mod.in_mid.eq(self.in_mid),
                         self.out_mid.eq(self.mod.out_mid),
                        ]

        m.d.sync += self.add_stb.eq(add_stb)
        m.d.sync += self.add_ack.eq(0) # sets to zero when not in active state
        m.d.sync += self.out_z.ack.eq(0) # likewise

        m.submodules.fpadd = self.mod

    def action(self, m):
        """ per-state behaviour: start the inner add on an accepted
            strobe, then wait for z_done before moving to "put_z"
        """

        # in_accept is set on incoming strobe HIGH and ack LOW.
        m.d.comb += self.in_accept.eq((~self.add_ack) & (self.add_stb))

        with m.If(~self.z_done):
            # not done: test for accepting an incoming operand pair
            with m.If(self.in_accept):
                m.d.sync += [
                    self.add_ack.eq(1), # acknowledge receipt...
                    self.in_t.stb.eq(1), # initiate add
                ]
            with m.Else():
                m.d.sync += [self.add_ack.eq(0),
                             self.in_t.stb.eq(0),
                             self.out_z.ack.eq(1),
                            ]
        with m.Else():
            # done: acknowledge, and move on to the output state
            m.d.sync += [self.add_ack.eq(1),
                         self.in_t.stb.eq(0)
                        ]
            m.next = "put_z"
1516
class ResArray:
    """ Holds an Array of rs_sz FPOp results (named out_z_0, out_z_1,
        ...) plus an incoming FPOp and an incoming ID signal.
    """

    def __init__(self, width, id_wid, rs_sz=2):
        """ * width: passed to every FPOp created here
            * id_wid: width of the in_mid signal
            * rs_sz: number of result entries to create (default 2)
        """
        self.width = width
        self.id_wid = id_wid
        res = []
        for i in range(rs_sz):
            out_z = FPOp(width)
            out_z.name = "out_z_%d" % i
            res.append(out_z)
        self.res = Array(res)
        self.in_z = FPOp(width)
        self.in_mid = Signal(self.id_wid, reset_less=True)

    def setup(self, m, in_z, in_mid):
        """ combinatorially connects in_z and in_mid to this object's
            own in_z / in_mid
        """
        m.d.comb += [self.in_z.copy(in_z),
                     self.in_mid.eq(in_mid)]

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment: in_z and every result entry
            are registered as submodules
        """
        m = Module()
        m.submodules.res_in_z = self.in_z
        m.submodules += self.res

        return m

    def ports(self):
        """ returns the concatenated ports() of every result entry
        """
        return [port for z in self.res for port in z.ports()]
1548
1549
class FPADD(FPID):
    """ FPADD: stages as follows:

        FPGetOp (a)
           |
        FPGetOp (b)
           |
        FPAddBase---> FPAddBaseMod
           |            |
        PutZ          GetOps->Specials->Align->Add1/2->Norm->Round/Pack->PutZ

        FPAddBase is tricky: it is both a stage and *has* stages.
        Connection to FPAddBaseMod therefore requires an in stb/ack
        and an out stb/ack.  Just as with Add1-Norm1 interaction, FPGetOp
        needs to be the thing that raises the incoming stb.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, rs_sz=2):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754.  supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * rs_sz: number of operand pairs (self.rs) and results
              (self.res) to create
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle

        self.ids = FPID(id_wid)

        # operand pairs, named in_a_<n> / in_b_<n>
        operand_pairs = []
        for idx in range(rs_sz):
            in_a = FPOp(width)
            in_b = FPOp(width)
            in_a.name = "in_a_%d" % idx
            in_b.name = "in_b_%d" % idx
            operand_pairs.append((in_a, in_b))
        self.rs = Array(operand_pairs)

        # results, named out_z_<n>
        results = []
        for idx in range(rs_sz):
            out_z = FPOp(width)
            out_z.name = "out_z_%d" % idx
            results.append(out_z)
        self.res = Array(results)

        self.states = []

    def add_state(self, state):
        """ records a state for the FSM and hands it straight back
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules += self.rs

        # only the first operand pair is connected here
        first_pair = self.rs[0]
        in_a = first_pair[0]
        in_b = first_pair[1]

        out_z = FPOp(self.width)
        out_mid = Signal(self.id_wid, reset_less=True)
        m.submodules.out_z = out_z

        # get_a -> get_b: fetch the two operands one after the other
        geta = self.add_state(FPGetOp("get_a", "get_b",
                                      in_a, self.width))
        geta.setup(m, in_a)

        getb = self.add_state(FPGetOp("get_b", "fpadd",
                                      in_b, self.width))
        getb.setup(m, in_b)

        # fpadd: the adder itself, strobed by get_b's out_decode
        adder = self.add_state(FPADDBase(self.width, self.id_wid,
                                         self.single_cycle))
        adder.setup(m, geta.out_op, getb.out_op, getb.out_decode,
                    self.ids.in_mid, out_z, out_mid)

        # put_z: deliver to the result selected by out_mid, then get_a
        self.add_state(FPPutZIdx("put_z", adder.out_z, self.res,
                                 out_mid, "get_a"))

        # one FSM state per recorded state object
        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
1641
1642
if __name__ == "__main__":
    # Hard-wired switch between the two top-levels.
    if True:
        alu = FPADD(width=32, id_wid=5, single_cycle=True)
        main(alu, ports=(alu.rs[0][0].ports() +
                         alu.rs[0][1].ports() +
                         alu.res[0].ports() +
                         [alu.ids.in_mid, alu.ids.out_mid]))
    else:
        # NOTE(review): this branch is never taken while the switch
        # above is True.
        alu = FPADDBase(width=32, id_wid=5, single_cycle=True)
        main(alu, ports=([alu.in_a, alu.in_b] +
                         alu.in_t.ports() +
                         alu.out_z.ports() +
                         [alu.in_mid, alu.out_mid]))
1656
1657
1658 # works... but don't use, just do "python fname.py convert -t v"
1659 #print (verilog.convert(alu, ports=[
1660 # ports=alu.in_a.ports() + \
1661 # alu.in_b.ports() + \
1662 # alu.out_z.ports())