create FPNumBase2Ops class and add ispec function to SpecialCasesMod
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Mux, Array, Const
6 from nmigen.lib.coding import PriorityEncoder
7 from nmigen.cli import main, verilog
8 from math import log
9
10 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
11 from fpbase import MultiShiftRMerge, Trigger
12 #from fpbase import FPNumShiftMultiRight
13
14
class FPState(FPBase):
    """ Base class for one named state of the FP unit's FSM.

        state_from: the name of the state this object implements
                    (e.g. "special_cases", "denormalise", "align").
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ Stores the ``inputs`` mapping and exposes every entry as an
            attribute with the same name.
        """
        self.inputs = inputs
        for attr_name, attr_value in inputs.items():
            setattr(self, attr_name, attr_value)

    def set_outputs(self, outputs):
        """ Stores the ``outputs`` mapping and exposes every entry as an
            attribute with the same name.
        """
        self.outputs = outputs
        for attr_name, attr_value in outputs.items():
            setattr(self, attr_name, attr_value)
28
29
class FPGetSyncOpsMod:
    """ Synchronised multi-operand input.

        ``ready`` is set when every bit of ``stb`` is set.  ``out_decode``
        is ``ack & ready``.  While ``out_decode`` is set, each out_op[i]
        is combinatorially driven from in_op[i].
    """

    def __init__(self, width, num_ops=2):
        self.width = width
        self.num_ops = num_ops
        self.in_op = []
        self.out_op = []
        for _ in range(num_ops):
            self.in_op.append(Signal(width, reset_less=True))
            self.out_op.append(Signal(width, reset_less=True))
        self.stb = Signal(num_ops)     # one strobe bit per operand
        self.ack = Signal()
        self.ready = Signal(reset_less=True)
        self.out_decode = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        all_strobes = Const(-1, (self.num_ops, False))  # all-ones mask
        m.d.comb += [
            self.ready.eq(self.stb == all_strobes),
            self.out_decode.eq(self.ack & self.ready),
        ]
        with m.If(self.out_decode):
            m.d.comb += [dst.eq(src)
                         for dst, src in zip(self.out_op, self.in_op)]
        return m

    def ports(self):
        return [*self.in_op, *self.out_op, self.stb, self.ack]
59
60
class FPOps(Trigger):
    """ A bank of ``num_ops`` operand signals (each ``width`` bits) held in
        an Array, together with the stb/ack handshake set up by Trigger.
    """

    def __init__(self, width, num_ops):
        Trigger.__init__(self)
        self.width = width
        self.num_ops = num_ops
        self.v = Array(Signal(width) for _ in range(num_ops))

    def ports(self):
        operands = [self.v[idx] for idx in range(self.num_ops)]
        return operands + [self.ack, self.stb]
79
80
class InputGroup:
    """ Groups ``num_rows`` synchronised operand inputs (FPGetSyncOpsMod)
        and uses a priority encoder over their ``ready`` bits to pick one
        row at a time.

        When the encoder has a valid row and ``out_op.trigger`` is set, the
        selected row's operands are registered into ``out_op.v``, the row
        index is registered into ``mid``, and that row's ack and
        ``out_op.stb`` are cleared.  Otherwise ``out_op.stb`` and every
        row's ack are set.
    """

    def __init__(self, width, num_ops=2, num_rows=4):
        self.width = width
        self.num_ops = num_ops
        self.num_rows = num_rows
        # bits needed to hold any row index 0..num_rows-1.  Integer
        # arithmetic instead of int(log(n)/log(2)), which under-sized mid
        # for non-power-of-two row counts (5 rows -> 2 bits, yet index 4
        # needs 3) and relied on float rounding.  Identical result for
        # powers of two.
        self.mmax = (num_rows - 1).bit_length()
        self.mid = Signal(self.mmax, reset_less=True) # multiplex id
        self.rs = Array(FPGetSyncOpsMod(width, num_ops)
                        for _ in range(num_rows))

        self.out_op = FPOps(width, num_ops)

    def elaborate(self, platform):
        m = Module()

        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe
        m.submodules.out_op = self.out_op
        m.submodules += self.rs

        # connect priority encoder: one "ready" bit per row
        in_ready = [self.rs[i].ready for i in range(self.num_rows)]
        m.d.comb += pe.i.eq(Cat(*in_ready))

        active = Signal(reset_less=True)
        out_en = Signal(reset_less=True)
        m.d.comb += active.eq(~pe.n) # encoder active
        m.d.comb += out_en.eq(active & self.out_op.trigger)

        # encoder active: ack relevant input, record MID, pass output
        with m.If(out_en):
            rs = self.rs[pe.o]
            m.d.sync += self.mid.eq(pe.o)
            m.d.sync += rs.ack.eq(0)
            m.d.sync += self.out_op.stb.eq(0)
            for j in range(self.num_ops):
                m.d.sync += self.out_op.v[j].eq(rs.out_op[j])
        with m.Else():
            m.d.sync += self.out_op.stb.eq(1)
            # no row selected: set the ack of every row to 1
            for i in range(self.num_rows):
                m.d.sync += self.rs[i].ack.eq(1)

        return m

    def ports(self):
        res = []
        for i in range(self.num_rows):
            inop = self.rs[i]
            res += inop.in_op + [inop.stb]
        return self.out_op.ports() + res + [self.mid]
136
137
class FPGetOpMod:
    """ Single-operand input stage.

        ``out_decode`` is set when the incoming FPOp has both ack and stb
        asserted.  While it is set, the operand value ``in_op.v`` is
        combinatorially forwarded to ``out_op``.
    """

    def __init__(self, width):
        self.in_op = FPOp(width)
        self.out_op = Signal(width)
        self.out_decode = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        handshake = (self.in_op.ack) & (self.in_op.stb)
        m.d.comb += self.out_decode.eq(handshake)
        m.submodules.get_op_in = self.in_op
        with m.If(self.out_decode):
            m.d.comb += self.out_op.eq(self.in_op.v)
        return m
154
155
class FPGetOp(FPState):
    """ gets operand

        FSM state that waits for the input handshake (FPGetOpMod's
        ``out_decode``).  It then registers the operand into ``out_op``,
        drops ``in_op.ack`` and moves to ``out_state``.
    """

    def __init__(self, in_state, out_state, in_op, width):
        FPState.__init__(self, in_state)
        self.out_state = out_state
        self.mod = FPGetOpMod(width)
        self.in_op = in_op
        self.out_op = Signal(width)
        self.out_decode = Signal(reset_less=True)

    def setup(self, m, in_op):
        """ links module to inputs and outputs
        """
        # the submodule is registered under this state's name
        setattr(m.submodules, self.state_from, self.mod)
        m.d.comb += [
            self.mod.in_op.eq(in_op),
            self.out_decode.eq(self.mod.out_decode),
        ]

    def action(self, m):
        with m.If(self.out_decode):
            m.next = self.out_state
            m.d.sync += [
                self.in_op.ack.eq(0),
                self.out_op.eq(self.mod.out_op),
            ]
        with m.Else():
            m.d.sync += self.in_op.ack.eq(1)
185
186
class FPGet2OpMod(Trigger):
    """ Two-operand input stage.

        While the Trigger handshake (``trigger``) is set, both raw
        operands are decoded into their FPNumIn outputs via
        ``FPNumIn.decode``.
    """

    def __init__(self, width):
        Trigger.__init__(self)
        self.in_op1 = Signal(width, reset_less=True)
        self.in_op2 = Signal(width, reset_less=True)
        self.out_op1 = FPNumIn(None, width)
        self.out_op2 = FPNumIn(None, width)

    def elaborate(self, platform):
        m = Trigger.elaborate(self, platform)
        m.submodules.get_op1_out = self.out_op1
        m.submodules.get_op2_out = self.out_op2
        with m.If(self.trigger):
            for dst, raw in ((self.out_op1, self.in_op1),
                             (self.out_op2, self.in_op2)):
                m.d.comb += dst.decode(raw)
        return m
206
207
class FPGet2Op(FPState):
    """ gets operands

        FSM state that waits for the two-operand handshake.  It then
        registers both decoded operands into ``out_op1``/``out_op2``,
        drops the module's ack and moves to ``out_state``.
    """

    def __init__(self, in_state, out_state, in_op1, in_op2, width):
        FPState.__init__(self, in_state)
        self.out_state = out_state
        self.mod = FPGet2OpMod(width)
        self.in_op1 = in_op1
        self.in_op2 = in_op2
        self.out_op1 = FPNumIn(None, width)
        self.out_op2 = FPNumIn(None, width)
        self.in_stb = Signal(reset_less=True)
        self.out_ack = Signal(reset_less=True)
        self.out_decode = Signal(reset_less=True)

    def setup(self, m, in_op1, in_op2, in_stb, in_ack):
        """ links module to inputs and outputs
        """
        mod = self.mod
        m.submodules.get_ops = mod
        m.d.comb += [
            mod.in_op1.eq(in_op1),
            mod.in_op2.eq(in_op2),
            mod.stb.eq(in_stb),
            self.out_ack.eq(mod.ack),
            self.out_decode.eq(mod.trigger),
            in_ack.eq(mod.ack),     # caller's ack mirrors the module's
        ]

    def action(self, m):
        mod = self.mod
        with m.If(self.out_decode):
            m.next = self.out_state
            m.d.sync += [
                mod.ack.eq(0),
                self.out_op1.eq(mod.out_op1),
                self.out_op2.eq(mod.out_op2),
            ]
        with m.Else():
            m.d.sync += mod.ack.eq(1)
247
class FPNumBase2Ops:
    """ A pair of same-width FPNumBase operands, ``a`` and ``b``. """

    def __init__(self, width):
        self.a = FPNumBase(width)
        self.b = FPNumBase(width)

    def eq(self, i):
        """ Returns the assignments copying ``i.a`` into ``a`` and ``i.b``
            into ``b``.
        """
        # both assignments used to target self.a: b was never driven and
        # a ended up holding i.b
        return [self.a.eq(i.a), self.b.eq(i.b)]
256
257
class FPAddSpecialCasesMod:
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        ``out_do_z`` is set when one of the special cases below applies,
        in which case ``out_z`` holds the result.  Otherwise ``out_do_z``
        is 0 and the normal add path is expected to run.
    """

    def __init__(self, width):
        self.width = width
        self.i = self.ispec()
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def ispec(self):
        """ Returns the input spec: an FPNumBase2Ops pair (a, b). """
        return FPNumBase2Ops(self.width)

    def setup(self, m, in_a, in_b, out_do_z):
        """ links module to inputs and outputs
        """
        m.submodules.specialcases = self
        m.d.comb += [
            self.i.a.eq(in_a),
            self.i.b.eq(in_b),
            out_do_z.eq(self.out_do_z),
        ]

    def elaborate(self, platform):
        m = Module()
        a, b, z = self.i.a, self.i.b, self.out_z

        m.submodules.sc_in_a = a
        m.submodules.sc_in_b = b
        m.submodules.sc_out_z = z

        s_nomatch = Signal()
        m.d.comb += s_nomatch.eq(a.s != b.s)

        m_match = Signal()
        m.d.comb += m_match.eq(a.m == b.m)

        # if a is NaN or b is NaN return NaN
        with m.If(a.is_nan | b.is_nan):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.nan(0)

            # XXX WEIRDNESS for FP16 non-canonical NaN handling
            # under review

            ## if a is zero and b is NaN return -b
            #with m.If(a.is_zero & (a.s==0) & b.is_nan):
            #    m.d.comb += self.out_do_z.eq(1)
            #    m.d.comb += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))

            ## if b is zero and a is NaN return -a
            #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
            #    m.d.comb += self.out_do_z.eq(1)
            #    m.d.comb += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))

            ## if a is -zero and b is NaN return -b
            #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
            #    m.d.comb += self.out_do_z.eq(1)
            #    m.d.comb += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))

            ## if b is -zero and a is NaN return -a
            #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
            #    m.d.comb += self.out_do_z.eq(1)
            #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))

        # if a is inf return inf (or NaN)
        with m.Elif(a.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(a.s)
            # if a is inf and signs don't match return NaN
            with m.If(b.exp_128 & s_nomatch):
                m.d.comb += z.nan(0)

        # if b is inf return inf
        with m.Elif(b.is_inf):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.inf(b.s)

        # if a is zero and b zero return signed-a/b
        with m.Elif(a.is_zero & b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s & b.s, b.e, b.m[3:-1])

        # if a is zero return b
        with m.Elif(a.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(b.s, b.e, b.m[3:-1])

        # if b is zero return a
        with m.Elif(b.is_zero):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.create(a.s, a.e, a.m[3:-1])

        # if a equal to -b return zero (+ve zero)
        with m.Elif(s_nomatch & m_match & (a.e == b.e)):
            m.d.comb += self.out_do_z.eq(1)
            m.d.comb += z.zero(0)

        # Denormalised Number checks: no special case applies
        with m.Else():
            m.d.comb += self.out_do_z.eq(0)

        return m
364
365
class FPID:
    """ Optional multiplex-ID ("mid") pass-through for an FSM state.

        When ``id_wid`` is truthy, ``in_mid``/``out_mid`` are Signals of
        that width and ``idsync`` registers in_mid into out_mid.  Otherwise
        both are None and ``idsync`` does nothing.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ Adds ``out_mid <= in_mid`` to the sync domain of ``m`` when the
            ID signals exist.
        """
        # check the signal itself, matching how __init__ decides to create
        # it: with id_wid == 0 the old "id_wid is not None" test passed
        # and then failed on None.eq
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
379
380
class FPAddSpecialCases(FPState, FPID):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        FSM state "special_cases": if a special case applies, the result
        value is registered and the FSM goes to "put_z".  Otherwise it goes
        to "denormalise".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "special_cases")
        FPID.__init__(self, id_wid)
        self.mod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_a, in_b, self.out_do_z)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        with m.If(self.out_do_z):
            # only the raw value is taken from the module's output
            m.d.sync += self.out_z.v.eq(self.mod.out_z.v)
            m.next = "put_z"
        with m.Else():
            m.next = "denormalise"
408
409
class FPAddSpecialCasesDeNorm(FPState, FPID):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        Combined "special_cases" + denormalise FSM state.  On a special
        case the result value is registered and the FSM goes to "put_z".
        Otherwise the denormalised operands are registered and the FSM
        goes straight to "align".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "special_cases")
        FPID.__init__(self, id_wid)
        self.smod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

        self.dmod = FPAddDeNormMod(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        # both sub-modules are fed the same raw operands
        self.smod.setup(m, in_a, in_b, self.out_do_z)
        self.dmod.setup(m, in_a, in_b)
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        with m.If(self.out_do_z):
            m.d.sync += self.out_z.v.eq(self.smod.out_z.v) # only take output
            m.next = "put_z"
        with m.Else():
            m.next = "align"
            m.d.sync += [
                self.out_a.eq(self.dmod.out_a),
                self.out_b.eq(self.dmod.out_b),
            ]
444
445
class FPAddDeNormMod(FPState):
    """ Combinatorial denormalisation of both operands.

        Each output starts as a copy of its input.  When the input's
        ``exp_n127`` flag is set, the exponent is replaced by ``N126``.
        Otherwise the top mantissa bit is set.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.denormalise = self
        m.d.comb += self.in_a.eq(in_a)
        m.d.comb += self.in_b.eq(in_b)

    def _denorm(self, m, inp, out):
        """ Adds the denormalisation logic for one operand to ``m``. """
        m.d.comb += out.eq(inp)                 # default: pass through
        with m.If(inp.exp_n127):
            m.d.comb += out.e.eq(inp.N126)      # limit exponent
        with m.Else():
            m.d.comb += out.m[-1].eq(1)         # set top mantissa bit

    def elaborate(self, platform):
        m = Module()
        m.submodules.denorm_in_a = self.in_a
        m.submodules.denorm_in_b = self.in_b
        m.submodules.denorm_out_a = self.out_a
        m.submodules.denorm_out_b = self.out_b
        # identical logic for both operands (previously duplicated inline)
        self._denorm(m, self.in_a, self.out_a)
        self._denorm(m, self.in_b, self.out_b)
        return m
481
482
class FPAddDeNorm(FPState, FPID):
    """ FSM state "denormalise": registers the operands produced by
        FPAddDeNormMod and moves to "align".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "denormalise")
        FPID.__init__(self, id_wid)
        self.mod = FPAddDeNormMod(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        # Denormalised Number checks
        m.next = "align"
        m.d.sync += [
            self.out_a.eq(self.mod.out_a),
            self.out_b.eq(self.mod.out_b),
        ]
505
506
class FPAddAlignMultiMod(FPState):
    """ Multi-cycle exponent alignment (combinatorial step).

        Each evaluation applies ``shift_down`` once to the operand with the
        smaller exponent.  ``exp_eq`` is set only when the exponents already
        match.  This does *not* do single-cycle multi-shifting: the owning
        FSM state stays in "align" until ``exp_eq`` is set.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        m.submodules.align_in_a = self.in_a
        m.submodules.align_in_b = self.in_b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # defaults: operands pass through, exponents "not yet equal"
        m.d.comb += [
            self.exp_eq.eq(0),
            self.out_a.eq(self.in_a),
            self.out_b.eq(self.in_b),
        ]
        agtb = Signal(reset_less=True)
        altb = Signal(reset_less=True)
        m.d.comb += [
            agtb.eq(self.in_a.e > self.in_b.e),
            altb.eq(self.in_a.e < self.in_b.e),
        ]
        with m.If(agtb):
            # exponent of a greater than b: shift b down
            m.d.comb += self.out_b.shift_down(self.in_b)
        with m.Elif(altb):
            # exponent of b greater than a: shift a down
            m.d.comb += self.out_a.shift_down(self.in_a)
        with m.Else():
            # exponents equal: move to next stage.
            m.d.comb += self.exp_eq.eq(1)
        return m
547
548
class FPAddAlignMulti(FPState, FPID):
    """ FSM state "align" (multi-cycle variant).

        Registers FPAddAlignMultiMod's outputs on every cycle spent in this
        state, and moves to "add_0" once the module reports equal exponents.
    """

    def __init__(self, width, id_wid):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "align")
        self.mod = FPAddAlignMultiMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        m.submodules.align = self.mod
        m.d.comb += [
            self.mod.in_a.eq(in_a),
            self.mod.in_b.eq(in_b),
            self.exp_eq.eq(self.mod.exp_eq),
        ]
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        m.d.sync += [
            self.out_a.eq(self.mod.out_a),
            self.out_b.eq(self.mod.out_b),
        ]
        with m.If(self.exp_eq):
            m.next = "add_0"
577
578
class FPAddAlignSingleMod:
    """ Single-cycle (combinatorial) exponent alignment of two operands. """

    def __init__(self, width):
        self.width = width
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.align = self
        m.d.comb += [self.in_a.eq(in_a), self.in_b.eq(in_b)]

    def elaborate(self, platform):
        """ Aligns A against B or B against A, depending on which has the
            greater exponent.  This is done in a *single* cycle using
            variable-width bit-shift.

            the shifter used here is quite expensive in terms of gates.
            Mux A or B in (and out) into temporaries, as only one of them
            needs to be aligned against the other
        """
        m = Module()
        a, b = self.in_a, self.in_b

        m.submodules.align_in_a = a
        m.submodules.align_in_b = b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # temporary (muxed) input and output to be shifted
        t_inp = FPNumBase(self.width)
        t_out = FPNumIn(None, self.width)
        espec = (len(a.e), True)
        msr = MultiShiftRMerge(a.m_width, espec)
        m.submodules.align_t_in = t_inp
        m.submodules.align_t_out = t_out
        m.submodules.multishift_r = msr

        ediff = Signal(espec, reset_less=True)
        ediffr = Signal(espec, reset_less=True)
        tdiff = Signal(espec, reset_less=True)
        elz = Signal(reset_less=True)
        egz = Signal(reset_less=True)

        # connect multi-shifter to t_inp/out mantissa (and tdiff)
        m.d.comb += [
            msr.inp.eq(t_inp.m),
            msr.diff.eq(tdiff),
            t_out.m.eq(msr.m),
            t_out.e.eq(t_inp.e + tdiff),
            t_out.s.eq(t_inp.s),
        ]

        m.d.comb += [
            ediff.eq(a.e - b.e),
            ediffr.eq(b.e - a.e),
            elz.eq(a.e < b.e),
            egz.eq(a.e > b.e),
        ]

        # default: A-exp == B-exp, A and B untouched (fall through)
        m.d.comb += [self.out_a.eq(a), self.out_b.eq(b)]
        # only one shifter (muxed)
        with m.If(egz):
            # exponent of a greater than b: shift b down
            m.d.comb += [t_inp.eq(b),
                         tdiff.eq(ediff),
                         self.out_b.eq(t_out),
                         self.out_b.s.eq(b.s), # whoops forgot sign
                        ]
        with m.Elif(elz):
            # exponent of b greater than a: shift a down
            m.d.comb += [t_inp.eq(a),
                         tdiff.eq(ediffr),
                         self.out_a.eq(t_out),
                         self.out_a.s.eq(a.s), # whoops forgot sign
                        ]
        return m
658
659
class FPAddAlignSingle(FPState, FPID):
    """ FSM state "align" (single-cycle variant): registers the operands
        aligned by FPAddAlignSingleMod and moves to "add_0".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        FPID.__init__(self, id_wid)
        self.mod = FPAddAlignSingleMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        # NOTE: could be done as comb
        m.d.sync += [
            self.out_a.eq(self.mod.out_a),
            self.out_b.eq(self.mod.out_b),
        ]
        m.next = "add_0"
682
683
class FPAddAlignSingleAdd(FPState, FPID):
    """ Merged align + add stage 0 + add stage 1 FSM state (named "align").

        FPAddAlignSingleMod, FPAddStage0Mod and FPAddStage1Mod are chained
        combinatorially in ``setup``.  ``action`` registers only the
        stage-1 result (``out_z``/``out_of``) and moves to "normalise_1".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "align")
        FPID.__init__(self, id_wid)
        self.mod = FPAddAlignSingleMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

        self.a0mod = FPAddStage0Mod(width)
        # constructed once (it used to be built twice, the first
        # instance being discarded right after sizing out_tot)
        self.a0_out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.a0_out_z.m_width + 4, reset_less=True)

        self.a1mod = FPAddStage1Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        # align -> add stage 0
        self.mod.setup(m, in_a, in_b)
        m.d.comb += self.out_a.eq(self.mod.out_a)
        m.d.comb += self.out_b.eq(self.mod.out_b)

        # add stage 0 -> add stage 1
        self.a0mod.setup(m, self.out_a, self.out_b)
        m.d.comb += self.a0_out_z.eq(self.a0mod.out_z)
        m.d.comb += self.out_tot.eq(self.a0mod.out_tot)

        self.a1mod.setup(m, self.out_tot, self.a0_out_z)

        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        m.d.sync += self.out_of.eq(self.a1mod.out_of)
        m.d.sync += self.out_z.eq(self.a1mod.out_z)
        m.next = "normalise_1"
723
724
class FPAddStage0Mod:
    """ First add stage (combinatorial).

        The result exponent is taken from operand a.  This assumes the
        operands are already exponent-aligned; TODO confirm with callers.
        The mantissas are added (same sign) or the smaller is subtracted
        from the larger (different signs) into ``out_tot``, which is
        ``m_width + 4`` bits wide.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.add0 = self
        m.d.comb += [self.in_a.eq(in_a), self.in_b.eq(in_b)]

    def elaborate(self, platform):
        m = Module()
        a, b, z = self.in_a, self.in_b, self.out_z
        m.submodules.add0_in_a = a
        m.submodules.add0_in_b = b
        m.submodules.add0_out_z = z

        m.d.comb += z.e.eq(a.e)

        # store intermediate tests (and zero-extended mantissas)
        seq = Signal(reset_less=True)
        mge = Signal(reset_less=True)
        am0 = Signal(len(a.m)+1, reset_less=True)
        bm0 = Signal(len(b.m)+1, reset_less=True)
        m.d.comb += [seq.eq(a.s == b.s),
                     mge.eq(a.m >= b.m),
                     am0.eq(Cat(a.m, 0)),
                     bm0.eq(Cat(b.m, 0))
                    ]
        with m.If(seq):
            # same-sign (both negative or both positive): add mantissas
            m.d.comb += [self.out_tot.eq(am0 + bm0), z.s.eq(a.s)]
        with m.Elif(mge):
            # a mantissa greater (or equal): a - b, sign of a
            m.d.comb += [self.out_tot.eq(am0 - bm0), z.s.eq(a.s)]
        with m.Else():
            # b mantissa greater: b - a, sign of b
            m.d.comb += [self.out_tot.eq(bm0 - am0), z.s.eq(b.s)]
        return m
778
779
class FPAddStage0(FPState, FPID):
    """ First stage of add.  covers same-sign (add) and subtract
        special-casing when mantissas are greater or equal, to
        give greatest accuracy.

        FSM state "add_0": registers FPAddStage0Mod's ``out_z`` and
        ``out_tot``, then moves to "add_1".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "add_0")
        FPID.__init__(self, id_wid)
        self.mod = FPAddStage0Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def setup(self, m, in_a, in_b, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_a, in_b)
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        # NOTE: these could be done as combinatorial (merge add0+add1)
        m.d.sync += [
            self.out_z.eq(self.mod.out_z),
            self.out_tot.eq(self.mod.out_tot),
        ]
        m.next = "add_1"
806
807
class FPAddStage1Mod(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (tot[27] is kinda a carry bit)

        If the MSB of ``in_tot`` is set, the mantissa is taken from
        tot[4:], the exponent is incremented, and bits 4..0 become
        m0/guard/round/sticky (bits 1 and 0 OR'ed into sticky).  Otherwise
        the mantissa is tot[3:] and bits 3..0 become m0/guard/round/sticky.
    """

    def __init__(self, width):
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_tot = Signal(self.in_z.m_width + 4, reset_less=True)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_tot, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.add1 = self
        m.submodules.add1_out_overflow = self.out_of

        m.d.comb += self.in_z.eq(in_z)
        m.d.comb += self.in_tot.eq(in_tot)

    def elaborate(self, platform):
        m = Module()
        tot = self.in_tot
        z_in, z_out, of = self.in_z, self.out_z, self.out_of

        m.d.comb += z_out.eq(z_in)
        with m.If(tot[-1]):
            # MSB set: the sum overflowed, shift result down
            m.d.comb += [
                z_out.m.eq(tot[4:]),
                of.m0.eq(tot[4]),
                of.guard.eq(tot[3]),
                of.round_bit.eq(tot[2]),
                of.sticky.eq(tot[1] | tot[0]),
                z_out.e.eq(z_in.e + 1)
            ]
        with m.Else():
            # MSB clear
            m.d.comb += [
                z_out.m.eq(tot[3:]),
                of.m0.eq(tot[3]),
                of.guard.eq(tot[2]),
                of.round_bit.eq(tot[1]),
                of.sticky.eq(tot[0])
            ]
        return m
856
857
class FPAddStage1(FPState, FPID):
    """ FSM state "add_1": registers FPAddStage1Mod's result and overflow
        bits, sets ``norm_stb`` and moves to "normalise_1".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "add_1")
        FPID.__init__(self, id_wid)
        self.mod = FPAddStage1Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()
        self.norm_stb = Signal()

    def setup(self, m, in_tot, in_z, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_tot, in_z)

        # sets to zero when not in add1 state (action() overrides it)
        m.d.sync += self.norm_stb.eq(0)

        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        m.d.sync += [
            self.out_of.eq(self.mod.out_of),
            self.out_z.eq(self.mod.out_z),
            self.norm_stb.eq(1),
        ]
        m.next = "normalise_1"
884
885
class FPNormaliseModSingle:
    """ Single-cycle normalisation, left-shift only (combinatorial).

        When the mantissa MSB is zero (``m_msbzero``), the mantissa is
        shifted up by its leading-zero count and the exponent is reduced
        by the same amount.  Before shifting, the guard and round bits
        from ``in_of`` are re-attached below the mantissa.  Otherwise
        out_z/out_of are copies of in_z/in_of.
    """

    def __init__(self, width):
        self.width = width
        self.in_z = FPNumBase(width, False)
        # elaborate() reads in_of and drives out_of; neither used to be
        # created, so elaborating this module raised AttributeError.
        self.in_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_z, out_z, modname):
        """ links module to inputs and outputs

            ``modname`` is accepted for interface compatibility but is
            not used.
        """
        m.submodules.normalise = self
        m.d.comb += self.in_z.eq(in_z)
        m.d.comb += out_z.eq(self.out_z)

    def elaborate(self, platform):
        m = Module()

        mwid = self.out_z.m_width+2
        pe = PriorityEncoder(mwid)
        m.submodules.norm_pe = pe

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        m.d.comb += in_z.eq(self.in_z)
        m.d.comb += in_of.eq(self.in_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.eq(in_z)
        m.d.comb += self.out_of.eq(in_of)
        # normalisation decrease condition
        decrease = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero)
        # decrease exponent
        with m.If(decrease):
            # *sigh* not entirely obvious: count leading zeros (clz)
            # with a PriorityEncoder: to find from the MSB
            # we reverse the order of the bits.
            temp_m = Signal(mwid, reset_less=True)
            temp_s = Signal(mwid+1, reset_less=True)
            clz = Signal((len(in_z.e), True), reset_less=True)
            m.d.comb += [
                # cat round and guard bits back into the mantissa
                temp_m.eq(Cat(in_of.round_bit, in_of.guard, in_z.m)),
                pe.i.eq(temp_m[::-1]),          # inverted
                clz.eq(pe.o),                   # count zeros from MSB down
                temp_s.eq(temp_m << clz),       # shift mantissa UP
                self.out_z.e.eq(in_z.e - clz),  # DECREASE exponent
                self.out_z.m.eq(temp_s[2:]),    # exclude bits 0&1
            ]

        return m
947
948
class FPNorm1ModSingle:
    """ Single-cycle normalisation (combinatorial).

        Two mutually exclusive adjustments:

        * decrease: mantissa MSB zero (``m_msbzero``) and exponent above
          N126 (``exp_gt_n126``).  The guard and round bits are
          re-attached below the mantissa, which is then shifted up by its
          leading-zero count.  The shift is limited by ``exp_sub_n126`` so
          the exponent does not go below the minimum non-INF/NaN exponent.
        * increase: exponent below N126 (``exp_lt_n126``).  The guard,
          round and sticky bits are re-attached below the mantissa, which
          is shifted down by ``N126 - e`` through MultiShiftRMerge.  The
          exponent is raised by the same amount.

        Otherwise out_z/out_of are copies of in_z/in_of.
    """

    def __init__(self, width):
        self.width = width
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_z, in_of, out_z):
        """ links module to inputs and outputs
        """
        m.submodules.normalise_1 = self

        m.d.comb += self.in_z.eq(in_z)
        m.d.comb += self.in_of.eq(in_of)

        m.d.comb += out_z.eq(self.out_z)

    def elaborate(self, platform):
        m = Module()

        mwid = self.out_z.m_width+2
        pe = PriorityEncoder(mwid)
        m.submodules.norm_pe = pe

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        espec = (len(in_z.e), True)
        ediff_n126 = Signal(espec, reset_less=True)
        # the increase path feeds the shifter mantissa + guard + round +
        # sticky = mwid+1 bits.  A mwid-wide shifter (first argument is
        # the inp/m width, as used in FPAddAlignSingleMod) truncated the
        # mantissa MSB and made msr.m[3:] one bit narrower than out_z.m.
        msr = MultiShiftRMerge(mwid+1, espec)
        m.submodules.multishift_r = msr

        m.d.comb += in_z.eq(self.in_z)
        m.d.comb += in_of.eq(self.in_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.eq(in_z)
        m.d.comb += self.out_of.eq(in_of)
        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126)
        m.d.comb += increase.eq(in_z.exp_lt_n126)
        # decrease exponent
        with m.If(decrease):
            # *sigh* not entirely obvious: count leading zeros (clz)
            # with a PriorityEncoder: to find from the MSB
            # we reverse the order of the bits.
            temp_m = Signal(mwid, reset_less=True)
            temp_s = Signal(mwid+1, reset_less=True)
            clz = Signal((len(in_z.e), True), reset_less=True)
            # make sure that the amount to decrease by does NOT
            # go below the minimum non-INF/NaN exponent
            limclz = Mux(in_z.exp_sub_n126 > pe.o, pe.o,
                         in_z.exp_sub_n126)
            m.d.comb += [
                # cat round and guard bits back into the mantissa
                temp_m.eq(Cat(in_of.round_bit, in_of.guard, in_z.m)),
                pe.i.eq(temp_m[::-1]),          # inverted
                clz.eq(limclz),                 # count zeros from MSB down
                temp_s.eq(temp_m << clz),       # shift mantissa UP
                self.out_z.e.eq(in_z.e - clz),  # DECREASE exponent
                self.out_z.m.eq(temp_s[2:]),    # exclude bits 0&1
                self.out_of.m0.eq(temp_s[2]),   # copy of mantissa[0]
                # overflow in bits 0..1: got shifted too (leave sticky)
                self.out_of.guard.eq(temp_s[1]),     # guard
                self.out_of.round_bit.eq(temp_s[0]), # round
            ]
        # increase exponent
        with m.Elif(increase):
            temp_m = Signal(mwid+1, reset_less=True)
            m.d.comb += [
                temp_m.eq(Cat(in_of.sticky, in_of.round_bit, in_of.guard,
                              in_z.m)),
                ediff_n126.eq(in_z.N126 - in_z.e),
                # connect multi-shifter to inp/out mantissa (and ediff)
                msr.inp.eq(temp_m),
                msr.diff.eq(ediff_n126),
                self.out_z.m.eq(msr.m[3:]),
                # the low bits come from the *shifted* value.  They used
                # to read temp_s, which belongs to the decrease branch and
                # is undriven (zero) here, losing guard/round/sticky.
                self.out_of.m0.eq(msr.m[3]),         # copy of mantissa[0]
                self.out_of.guard.eq(msr.m[2]),      # guard
                self.out_of.round_bit.eq(msr.m[1]),  # round
                self.out_of.sticky.eq(msr.m[0]),     # sticky
                self.out_z.e.eq(in_z.e + ediff_n126),
            ]

        return m
1046
1047
class FPNorm1ModMulti:
    """ Multi-cycle normalisation: moves the mantissa by one bit per pass.

        Each pass selects either the fresh inputs (in_z/in_of, when
        in_select is HIGH) or the values fed back from the previous pass
        (temp_z/temp_of), then:

        * decreases the exponent by 1 and shifts the mantissa UP (pulling
          the guard bit into the LSB) when in_z.m_msbzero & exp_gt_n126, or
        * increases the exponent by 1 and shifts the mantissa DOWN (into
          guard/round/sticky) when in_z.exp_lt_n126.

        out_norm is HIGH whenever either adjustment was made on this pass,
        i.e. another pass may still be needed.
    """

    def __init__(self, width, single_cycle=True):
        """ * width: bit-width of the IEEE754 number
            * single_cycle: accepted for interface compatibility; unused
        """
        self.width = width
        self.in_select = Signal(reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.temp_z = FPNumBase(width, False)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def setup(self, m, in_z, in_of, norm_stb, in_accept, temp_z, temp_of,
              out_z, out_norm):
        """ links module to inputs and outputs

            (signature matches the call made by FPNorm1Multi.setup)

            * in_z, in_of: fresh number / overflow bits from the previous stage
            * norm_stb: accepted to match the caller; not used here
              (FPNorm1Multi already connects it to its own stb)
            * in_accept: HIGH selects in_z/in_of, LOW selects temp_z/temp_of
            * temp_z, temp_of: values fed back from the previous pass
            * out_z: receives the (partially) normalised number
            * out_norm: receives the "normalisation adjusted" flag
        """
        m.submodules.normalise_1 = self

        m.d.comb += self.in_z.eq(in_z)
        m.d.comb += self.in_of.eq(in_of)
        m.d.comb += self.in_select.eq(in_accept)
        m.d.comb += self.temp_z.eq(temp_z)
        m.d.comb += self.temp_of.eq(temp_of)
        m.d.comb += out_z.eq(self.out_z)
        m.d.comb += out_norm.eq(self.out_norm)

    def elaborate(self, platform):
        m = Module()

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_temp_z = self.temp_z
        m.submodules.norm1_temp_of = self.temp_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        # select which of temp or in z/of to use
        with m.If(self.in_select):
            m.d.comb += in_z.eq(self.in_z)
            m.d.comb += in_of.eq(self.in_of)
        with m.Else():
            m.d.comb += in_z.eq(self.temp_z)
            m.d.comb += in_of.eq(self.temp_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.eq(in_z)
        m.d.comb += self.out_of.eq(in_of)
        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126)
        m.d.comb += increase.eq(in_z.exp_lt_n126)
        m.d.comb += self.out_norm.eq(decrease | increase) # loop-end
        # decrease exponent
        with m.If(decrease):
            m.d.comb += [
                self.out_z.e.eq(in_z.e - 1),  # DECREASE exponent
                self.out_z.m.eq(in_z.m << 1), # shift mantissa UP
                self.out_z.m[0].eq(in_of.guard), # steal guard (was tot[2])
                self.out_of.guard.eq(in_of.round_bit), # round (was tot[1])
                self.out_of.round_bit.eq(0),        # reset round bit
                # new mantissa LSB is the old guard bit
                self.out_of.m0.eq(in_of.guard),
            ]
        # increase exponent
        with m.Elif(increase):
            m.d.comb += [
                self.out_z.e.eq(in_z.e + 1),  # INCREASE exponent
                self.out_z.m.eq(in_z.m >> 1), # shift mantissa DOWN
                # bit shifted out of the mantissa becomes the guard bit
                self.out_of.guard.eq(in_z.m[0]),
                # new mantissa LSB is the old bit 1
                self.out_of.m0.eq(in_z.m[1]),
                self.out_of.round_bit.eq(in_of.guard),
                # anything shifted past round accumulates into sticky
                self.out_of.sticky.eq(in_of.sticky | in_of.round_bit)
            ]

        return m
1114
1115
class FPNorm1Single(FPState, FPID):
    """ FSM state "normalise_1", single-cycle variant.

        Wraps FPNorm1ModSingle: the normalised number is wired to out_z in
        setup(); on action() the round-up flag from the module's overflow
        output is registered into out_roundz and the FSM moves to "round".
    """

    def __init__(self, width, id_wid, single_cycle=True):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.mod = FPNorm1ModSingle(width)
        self.out_norm = Signal(reset_less=True)
        self.out_z = FPNumBase(width)
        self.out_roundz = Signal(reset_less=True)

    def setup(self, m, in_z, in_of, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z, in_of, self.out_z)

        # no identifier configured: nothing more to connect
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        roundz = self.mod.out_of.roundz
        m.d.sync += self.out_roundz.eq(roundz)
        m.next = "round"
1138
1139
class FPNorm1Multi(FPState, FPID):
    """ FSM state "normalise_1", multi-cycle variant.

        Drives FPNorm1ModMulti one pass per clock.  Each pass's result is
        registered into temp_z/temp_of and fed back into the module.  While
        the module reports out_norm the FSM stays here; once it drops, the
        round-up flag is registered into out_roundz and the FSM moves on to
        "round".
    """

    def __init__(self, width, id_wid):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.mod = FPNorm1ModMulti(width)
        self.stb = Signal(reset_less=True)
        self.ack = Signal(reset=0, reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_accept = Signal(reset_less=True)
        self.temp_z = FPNumBase(width)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width)
        self.out_roundz = Signal(reset_less=True)

    def setup(self, m, in_z, in_of, norm_stb, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z, in_of, norm_stb,
                       self.in_accept, self.temp_z, self.temp_of,
                       self.out_z, self.out_norm)

        m.d.comb += self.stb.eq(norm_stb)
        # default assignment: ack is cleared whenever this state is inactive
        m.d.sync += self.ack.eq(0)

        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        # a fresh input is accepted when strobed and not yet acknowledged
        m.d.comb += self.in_accept.eq((~self.ack) & (self.stb))
        # latch this pass's result so it is fed back on the next pass
        m.d.sync += [self.temp_of.eq(self.mod.out_of),
                     self.temp_z.eq(self.out_z)]
        with m.If(~self.out_norm):
            # normalisation not required (or done).
            m.next = "round"
            m.d.sync += [self.ack.eq(1),
                         self.out_roundz.eq(self.mod.out_of.roundz)]
        with m.Else():
            # another pass is needed: acknowledge only a newly accepted input
            m.d.sync += self.ack.eq(self.in_accept)
1185
1186
class FPNormToPack(FPState, FPID):
    """ Compact FSM state: normalise -> round -> corrections -> pack.

        Registered under the "normalise_1" state name.  setup() chains
        FPNorm1ModSingle, FPRoundMod, FPCorrectionsMod and FPPackMod
        combinatorially; action() registers the packed value into out_z.v
        and moves the FSM to "pack_put_z".
    """

    def __init__(self, width, id_wid):
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "normalise_1")
        self.width = width
        # created here, as in every other state, so that out_z (and the
        # pack module feeding it) exist as soon as the state is constructed
        # rather than only after setup() has been called.
        self.pmod = FPPackMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, in_of, in_mid):
        """ links module to inputs and outputs
        """

        # Normalisation (chained to input in_z+in_of)
        nmod = FPNorm1ModSingle(self.width)
        n_out_z = FPNumBase(self.width)
        n_out_roundz = Signal(reset_less=True)
        nmod.setup(m, in_z, in_of, n_out_z)

        # Rounding (chained to normalisation)
        rmod = FPRoundMod(self.width)
        r_out_z = FPNumBase(self.width)
        rmod.setup(m, n_out_z, n_out_roundz)
        m.d.comb += n_out_roundz.eq(nmod.out_of.roundz)
        m.d.comb += r_out_z.eq(rmod.out_z)

        # Corrections (chained to rounding)
        cmod = FPCorrectionsMod(self.width)
        c_out_z = FPNumBase(self.width)
        cmod.setup(m, r_out_z)
        m.d.comb += c_out_z.eq(cmod.out_z)

        # Pack (chained to corrections)
        self.pmod.setup(m, c_out_z)

        # Multiplex ID
        if self.in_mid is not None:
            m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m) # copies incoming ID to outgoing
        m.d.sync += self.out_z.v.eq(self.pmod.out_z.v) # outputs packed result
        m.next = "pack_put_z"
1230
1231
class FPRoundMod:
    """ Rounding: conditionally adds one to the mantissa.

        When in_roundz is HIGH the mantissa is incremented.  If the mantissa
        was all ones (in_z.m1s) beforehand, the exponent is incremented too.
        Otherwise out_z is a copy of in_z.
    """

    def __init__(self, width):
        self.in_roundz = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)

    def setup(self, m, in_z, roundz):
        """ links module to inputs and outputs
        """
        m.submodules.roundz = self
        m.d.comb += [self.in_z.eq(in_z),
                     self.in_roundz.eq(roundz)]

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z
        m.d.comb += z_out.eq(z_in) # default: pass straight through
        with m.If(self.in_roundz):
            m.d.comb += z_out.m.eq(z_in.m + 1) # mantissa rounds up
            with m.If(z_in.m == z_in.m1s): # all 1s: carry into exponent
                m.d.comb += z_out.e.eq(z_in.e + 1)
        return m
1253
1254
class FPRound(FPState, FPID):
    """ FSM state "round": registers FPRoundMod's result into out_z and
        moves on to "corrections".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "round")
        FPID.__init__(self, id_wid)
        self.mod = FPRoundMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, roundz, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z, roundz)

        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        rounded = self.mod.out_z
        m.d.sync += self.out_z.eq(rounded)
        m.next = "corrections"
1275
1276
class FPCorrectionsMod:
    """ Corrections: when in_z is flagged is_denormalised, the exponent is
        forced to in_z.N127; otherwise in_z passes through unchanged.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.corrections = self
        m.d.comb += self.in_z.eq(in_z)

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z
        # both are registered so their derived flags (e.g. is_denormalised)
        # are elaborated alongside this module
        m.submodules.corr_in_z = z_in
        m.submodules.corr_out_z = z_out
        m.d.comb += z_out.eq(z_in)
        with m.If(z_in.is_denormalised):
            m.d.comb += z_out.e.eq(z_in.N127)
        return m
1297
1298
class FPCorrections(FPState, FPID):
    """ FSM state "corrections": registers FPCorrectionsMod's result into
        out_z and moves on to "pack".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "corrections")
        FPID.__init__(self, id_wid)
        self.mod = FPCorrectionsMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        corrected = self.mod.out_z
        m.d.sync += self.out_z.eq(corrected)
        m.next = "pack"
1318
1319
class FPPackMod:
    """ Pack: produces INF (with in_z's sign) when in_z is_overflowed,
        otherwise assembles out_z from in_z's sign, exponent and mantissa.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.pack = self
        m.d.comb += self.in_z.eq(in_z)

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z
        m.submodules.pack_in_z = z_in
        with m.If(z_in.is_overflowed):
            m.d.comb += z_out.inf(z_in.s)
        with m.Else():
            m.d.comb += z_out.create(z_in.s, z_in.e, z_in.m)
        return m
1340
1341
class FPPack(FPState, FPID):
    """ FSM state "pack": registers the packed value (out_z.v) from
        FPPackMod and moves on to "pack_put_z".
    """

    def __init__(self, width, id_wid):
        FPState.__init__(self, "pack")
        FPID.__init__(self, id_wid)
        self.mod = FPPackMod(width)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z, in_mid):
        """ links module to inputs and outputs
        """
        self.mod.setup(m, in_z)
        if self.in_mid is None:
            return
        m.d.comb += self.in_mid.eq(in_mid)

    def action(self, m):
        self.idsync(m)
        packed = self.mod.out_z.v
        m.d.sync += self.out_z.v.eq(packed)
        m.next = "pack_put_z"
1361
1362
class FPPutZ(FPState):
    """ Output FSM state: registers in_z.v onto out_z.v and holds out_z.stb
        HIGH until out_z.ack is seen, then clears stb and moves to to_state
        (default "get_ops").  While in this state, in_mid (if given) is
        registered into out_mid.
    """

    def __init__(self, state, in_z, out_z, in_mid, out_mid, to_state=None):
        FPState.__init__(self, state)
        self.to_state = "get_ops" if to_state is None else to_state
        self.in_z = in_z
        self.out_z = out_z
        self.in_mid = in_mid
        self.out_mid = out_mid

    def action(self, m):
        if self.in_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
        m.d.sync += self.out_z.v.eq(self.in_z.v)
        handshake = self.out_z.stb & self.out_z.ack
        with m.If(handshake):
            # receiver has taken the value: drop strobe and leave
            m.d.sync += self.out_z.stb.eq(0)
            m.next = self.to_state
        with m.Else():
            m.d.sync += self.out_z.stb.eq(1)
1386
1387
class FPPutZIdx(FPState):
    """ Indexed output FSM state: like FPPutZ, but the destination is the
        entry of out_zs selected by the in_mid signal.  Registers in_z.v
        into that entry and holds its stb HIGH until its ack is seen, then
        clears stb and moves to to_state (default "get_ops").
    """

    def __init__(self, state, in_z, out_zs, in_mid, to_state=None):
        FPState.__init__(self, state)
        self.to_state = "get_ops" if to_state is None else to_state
        self.in_z = in_z
        self.out_zs = out_zs
        self.in_mid = in_mid

    def action(self, m):
        target = self.out_zs[self.in_mid] # entry selected by the id
        outz_stb = Signal(reset_less=True)
        outz_ack = Signal(reset_less=True)
        m.d.comb += [outz_stb.eq(target.stb),
                     outz_ack.eq(target.ack)]
        m.d.sync += target.v.eq(self.in_z.v)
        with m.If(outz_stb & outz_ack):
            m.d.sync += target.stb.eq(0)
            m.next = self.to_state
        with m.Else():
            m.d.sync += target.stb.eq(1)
1413
1414
class FPADDBaseMod(FPID):
    """ FSM-based IEEE754 adder core.

        Builds a list of FPState stages (via add_state), wires each one to
        the previous stage in its setup(), then emits one FSM state per stage
        whose body is that stage's action().  The identifier (in_mid) is
        threaded from each stage to the next, in pipeline order.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, compact=True):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754. supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * compact: True indicates a reduced number of stages
        """
        FPID.__init__(self, id_wid)
        self.width = width
        self.single_cycle = single_cycle
        self.compact = compact

        self.in_t = Trigger()
        self.in_a = Signal(width)
        self.in_b = Signal(width)
        self.out_z = FPOp(width)

        self.states = []

    def add_state(self, state):
        """ appends state to the FSM state list (in order); returns it """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules.out_z = self.out_z
        m.submodules.in_t = self.in_t
        if self.compact:
            self.get_compact_fragment(m, platform)
        else:
            self.get_longer_fragment(m, platform)

        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m

    def get_longer_fragment(self, m, platform=None):
        """ one FSM state per pipeline step (get, specials, denorm, align,
            add0, add1, normalise, round, corrections, pack, put).
        """
        get = self.add_state(FPGet2Op("get_ops", "special_cases",
                                      self.in_a, self.in_b, self.width))
        get.setup(m, self.in_a, self.in_b, self.in_t.stb, self.in_t.ack)
        a = get.out_op1
        b = get.out_op2

        sc = self.add_state(FPAddSpecialCases(self.width, self.id_wid))
        sc.setup(m, a, b, self.in_mid)

        dn = self.add_state(FPAddDeNorm(self.width, self.id_wid))
        dn.setup(m, a, b, sc.in_mid)

        if self.single_cycle:
            alm = self.add_state(FPAddAlignSingle(self.width, self.id_wid))
        else:
            alm = self.add_state(FPAddAlignMulti(self.width, self.id_wid))
        alm.setup(m, dn.out_a, dn.out_b, dn.in_mid)

        add0 = self.add_state(FPAddStage0(self.width, self.id_wid))
        add0.setup(m, alm.out_a, alm.out_b, alm.in_mid)

        add1 = self.add_state(FPAddStage1(self.width, self.id_wid))
        add1.setup(m, add0.out_tot, add0.out_z, add0.in_mid)

        # each stage takes its id from the stage immediately before it
        if self.single_cycle:
            n1 = self.add_state(FPNorm1Single(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add1.in_mid)
        else:
            n1 = self.add_state(FPNorm1Multi(self.width, self.id_wid))
            n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb,
                     add1.in_mid)

        rn = self.add_state(FPRound(self.width, self.id_wid))
        rn.setup(m, n1.out_z, n1.out_roundz, n1.in_mid)

        cor = self.add_state(FPCorrections(self.width, self.id_wid))
        cor.setup(m, rn.out_z, rn.in_mid)

        pa = self.add_state(FPPack(self.width, self.id_wid))
        pa.setup(m, cor.out_z, cor.in_mid)

        self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z,
                              pa.in_mid, self.out_mid))

        # special-case results leave directly from the special-cases stage
        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z,
                              sc.in_mid, self.out_mid))

    def get_compact_fragment(self, m, platform=None):
        """ reduced stage count: specials+denorm, align+add, and
            normalise-to-pack are each merged into a single FSM state.
        """
        get = self.add_state(FPGet2Op("get_ops", "special_cases",
                                      self.in_a, self.in_b, self.width))
        get.setup(m, self.in_a, self.in_b, self.in_t.stb, self.in_t.ack)
        a = get.out_op1
        b = get.out_op2

        sc = self.add_state(FPAddSpecialCasesDeNorm(self.width, self.id_wid))
        sc.setup(m, a, b, self.in_mid)

        alm = self.add_state(FPAddAlignSingleAdd(self.width, self.id_wid))
        alm.setup(m, sc.out_a, sc.out_b, sc.in_mid)

        n1 = self.add_state(FPNormToPack(self.width, self.id_wid))
        n1.setup(m, alm.out_z, alm.out_of, alm.in_mid)

        self.add_state(FPPutZ("pack_put_z", n1.out_z, self.out_z,
                              n1.in_mid, self.out_mid))

        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z,
                              sc.in_mid, self.out_mid))
1531
1532
class FPADDBase(FPState, FPID):
    """ FSM state "fpadd": hands an operand pair to FPADDBaseMod and waits
        for its result.

        setup() wires the operands, the in/out handshakes and (when an id
        width is configured) the identifiers between this state and the
        wrapped FPADDBaseMod.  action() raises in_t.stb when a new operand
        pair is accepted, and moves the FSM to "put_z" once the module's
        output trigger (z_done) fires.
    """

    def __init__(self, width, id_wid=None, single_cycle=False):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754. supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
        """
        FPID.__init__(self, id_wid)
        FPState.__init__(self, "fpadd")
        self.width = width
        self.single_cycle = single_cycle
        self.mod = FPADDBaseMod(width, id_wid, single_cycle)

        self.in_t = Trigger()
        self.in_a = Signal(width)
        self.in_b = Signal(width)
        # out_z and out_mid are supplied by the caller in setup()

        self.z_done = Signal(reset_less=True) # connects to out_z Strobe
        self.in_accept = Signal(reset_less=True)
        self.add_stb = Signal(reset_less=True)
        self.add_ack = Signal(reset=0, reset_less=True)

    def setup(self, m, a, b, add_stb, in_mid, out_z, out_mid):
        """ links module to inputs and outputs

            * a, b: operands
            * add_stb: incoming strobe (registered into self.add_stb)
            * in_mid: incoming identifier (ignored when id_wid is None)
            * out_z: FPOp receiving the result (stb/ack handshake)
            * out_mid: receives the outgoing identifier
        """
        self.out_z = out_z
        self.out_mid = out_mid
        m.d.comb += [self.in_a.eq(a),
                     self.in_b.eq(b),
                     self.mod.in_a.eq(self.in_a),
                     self.mod.in_b.eq(self.in_b),
                     self.z_done.eq(self.mod.out_z.trigger),
                     self.mod.in_t.stb.eq(self.in_t.stb),
                     self.in_t.ack.eq(self.mod.in_t.ack),
                     self.out_z.v.eq(self.mod.out_z.v),
                     self.out_z.stb.eq(self.mod.out_z.stb),
                     self.mod.out_z.ack.eq(self.out_z.ack),
                    ]
        # identifiers only exist when an id width was given (FPID sets
        # in_mid to None otherwise, as the other states check for)
        if self.in_mid is not None:
            m.d.comb += [self.in_mid.eq(in_mid),
                         self.mod.in_mid.eq(self.in_mid),
                         self.out_mid.eq(self.mod.out_mid),
                        ]

        m.d.sync += self.add_stb.eq(add_stb)
        m.d.sync += self.add_ack.eq(0) # sets to zero when not in active state
        m.d.sync += self.out_z.ack.eq(0) # likewise

        m.submodules.fpadd = self.mod

    def action(self, m):

        # in_accept is set on incoming strobe HIGH and ack LOW.
        m.d.comb += self.in_accept.eq((~self.add_ack) & (self.add_stb))

        with m.If(~self.z_done):
            # not done: test for accepting an incoming operand pair
            with m.If(self.in_accept):
                m.d.sync += [
                    self.add_ack.eq(1), # acknowledge receipt...
                    self.in_t.stb.eq(1), # initiate add
                ]
            with m.Else():
                m.d.sync += [self.add_ack.eq(0),
                             self.in_t.stb.eq(0),
                             self.out_z.ack.eq(1),
                            ]
        with m.Else():
            # done: acknowledge, and move on to output the result
            m.d.sync += [self.add_ack.eq(1),
                         self.in_t.stb.eq(0)
                        ]
            m.next = "put_z"
1624
class ResArray:
    """ Array of result FPOps (one per slot) plus an input FPOp and id.

        * width: bit-width of each FPOp
        * id_wid: bit-width of in_mid
        * rs_sz: number of result slots (default 2, matching FPADD)
    """

    def __init__(self, width, id_wid, rs_sz=2):
        self.width = width
        self.id_wid = id_wid
        # rs_sz was previously read as an undefined global (NameError)
        self.rs_sz = rs_sz
        res = []
        for i in range(rs_sz):
            out_z = FPOp(width)
            out_z.name = "out_z_%d" % i
            res.append(out_z)
        self.res = Array(res)
        self.in_z = FPOp(width)
        self.in_mid = Signal(self.id_wid, reset_less=True)

    def setup(self, m, in_z, in_mid):
        """ connects the incoming result and id """
        m.d.comb += [self.in_z.eq(in_z),
                     self.in_mid.eq(in_mid)]

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules.res_in_z = self.in_z
        m.submodules += self.res

        return m

    def ports(self):
        """ concatenated ports of every result slot """
        res = []
        for z in self.res:
            res += z.ports()
        return res
1656
1657
class FPADD(FPID):
    """ FPADD: stages as follows:

        FPGetOp (a)
           |
        FPGetOp (b)
           |
        FPAddBase---> FPAddBaseMod
           |             |
        PutZ          GetOps->Specials->Align->Add1/2->Norm->Round/Pack->PutZ

        FPAddBase is tricky: it is both a stage and *has* stages.
        Connection to FPAddBaseMod therefore requires an in stb/ack
        and an out stb/ack. Just as with Add1-Norm1 interaction, FPGetOp
        needs to be the thing that raises the incoming stb.
    """

    def __init__(self, width, id_wid=None, single_cycle=False, rs_sz=2):
        """ IEEE754 FP Add

            * width: bit-width of IEEE754. supported: 16, 32, 64
            * id_wid: an identifier that is sync-connected to the input
            * single_cycle: True indicates each stage to complete in 1 clock
            * rs_sz: number of operand-pair (rs) and result (res) slots
        """
        self.width = width
        self.id_wid = id_wid
        self.single_cycle = single_cycle

        self.ids = FPID(id_wid)

        # operand pairs first, then result slots (same creation order)
        pairs = []
        for slot in range(rs_sz):
            op_a, op_b = FPOp(width), FPOp(width)
            op_a.name = "in_a_%d" % slot
            op_b.name = "in_b_%d" % slot
            pairs.append((op_a, op_b))
        self.rs = Array(pairs)

        results = []
        for slot in range(rs_sz):
            z = FPOp(width)
            z.name = "out_z_%d" % slot
            results.append(z)
        self.res = Array(results)

        self.states = []

    def add_state(self, state):
        """ appends state to the FSM state list (in order); returns it """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()
        m.submodules += self.rs

        # only the first operand pair is currently wired up
        in_a, in_b = self.rs[0]

        out_z = FPOp(self.width)
        out_mid = Signal(self.id_wid, reset_less=True)
        m.submodules.out_z = out_z

        geta = self.add_state(FPGetOp("get_a", "get_b", in_a, self.width))
        geta.setup(m, in_a)
        op_a = geta.out_op

        getb = self.add_state(FPGetOp("get_b", "fpadd", in_b, self.width))
        getb.setup(m, in_b)
        op_b = getb.out_op

        ab = self.add_state(FPADDBase(self.width, self.id_wid,
                                      self.single_cycle))
        ab.setup(m, op_a, op_b, getb.out_decode, self.ids.in_mid,
                 out_z, out_mid)

        # result goes to the res slot selected by out_mid, then back to get_a
        self.add_state(FPPutZIdx("put_z", ab.out_z, self.res,
                                 out_mid, "get_a"))

        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
1749
1750
if __name__ == "__main__":
    # True: full FPADD with operand/result slots; False: bare FPADDBase.
    # NOTE(review): FPADDBase only assigns out_z inside setup(), so the
    # False branch looks like it would raise AttributeError as written.
    build_full_fpadd = True
    if build_full_fpadd:
        alu = FPADD(width=32, id_wid=5, single_cycle=True)
        ports = (alu.rs[0][0].ports() +
                 alu.rs[0][1].ports() +
                 alu.res[0].ports() +
                 [alu.ids.in_mid, alu.ids.out_mid])
    else:
        alu = FPADDBase(width=32, id_wid=5, single_cycle=True)
        ports = ([alu.in_a, alu.in_b] +
                 alu.in_t.ports() +
                 alu.out_z.ports() +
                 [alu.in_mid, alu.out_mid])
    main(alu, ports=ports)
1764
1765
1766 # works... but don't use, just do "python fname.py convert -t v"
1767 #print (verilog.convert(alu, ports=[
1768 # ports=alu.in_a.ports() + \
1769 # alu.in_b.ports() + \
1770 # alu.out_z.ports())