1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
10 from singlepipe
import PrevControl
, NextControl
11 from pipeline
import ObjectProxy
16 def __init__(self
, width
):
18 self
.smax
= int(log(width
) / log(2))
19 self
.i
= Signal(width
, reset_less
=True)
20 self
.s
= Signal(self
.smax
, reset_less
=True)
21 self
.o
= Signal(width
, reset_less
=True)
23 def elaborate(self
, platform
):
25 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
30 """ Generates variable-length single-cycle shifter from a series
31 of conditional tests on each bit of the left/right shift operand.
32 Each bit tested produces output shifted by that number of bits,
33 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
34 shifts by 2 bits, each partial result cascading to the next Mux.
36 Could be adapted to do arithmetic shift by taking copies of the
40 def __init__(self
, width
):
42 self
.smax
= int(log(width
) / log(2))
44 def lshift(self
, op
, s
):
48 for i
in range(self
.smax
):
50 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
53 def rshift(self
, op
, s
):
57 for i
in range(self
.smax
):
59 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
64 """ Floating-point Base Number Class
66 def __init__(self
, width
, m_extra
=True):
68 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
69 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
70 e_max
= 1<<(e_width
-3)
71 self
.rmw
= m_width
# real mantissa width (not including extras)
74 # mantissa extra bits (top,guard,round)
76 m_width
+= self
.m_extra
79 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
80 self
.m_width
= m_width
81 self
.e_width
= e_width
82 self
.e_start
= self
.rmw
- 1
83 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
85 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
86 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
87 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
88 self
.s
= Signal(reset_less
=True) # Sign bit
90 self
.mzero
= Const(0, (m_width
, False))
91 m_msb
= 1<<(self
.m_width
-2)
92 self
.msb1
= Const(m_msb
, (m_width
, False))
93 self
.m1s
= Const(-1, (m_width
, False))
94 self
.P128
= Const(e_max
, (e_width
, True))
95 self
.P127
= Const(e_max
-1, (e_width
, True))
96 self
.N127
= Const(-(e_max
-1), (e_width
, True))
97 self
.N126
= Const(-(e_max
-2), (e_width
, True))
99 self
.is_nan
= Signal(reset_less
=True)
100 self
.is_zero
= Signal(reset_less
=True)
101 self
.is_inf
= Signal(reset_less
=True)
102 self
.is_overflowed
= Signal(reset_less
=True)
103 self
.is_denormalised
= Signal(reset_less
=True)
104 self
.exp_128
= Signal(reset_less
=True)
105 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
106 self
.exp_lt_n126
= Signal(reset_less
=True)
107 self
.exp_gt_n126
= Signal(reset_less
=True)
108 self
.exp_gt127
= Signal(reset_less
=True)
109 self
.exp_n127
= Signal(reset_less
=True)
110 self
.exp_n126
= Signal(reset_less
=True)
111 self
.m_zero
= Signal(reset_less
=True)
112 self
.m_msbzero
= Signal(reset_less
=True)
114 def elaborate(self
, platform
):
116 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
117 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
118 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
119 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
120 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
121 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
122 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.N126
)
123 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
124 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
125 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
126 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
127 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
128 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
129 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
134 return (self
.exp_128
) & (~self
.m_zero
)
137 return (self
.exp_128
) & (self
.m_zero
)
140 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """Return the expression used as this number's overflow flag.

    This is ``self.exp_gt127`` unchanged; ``elaborate`` drives that
    signal with the comparison ``self.e > self.P127``.
    """
    return self.exp_gt127
def _is_denormalised(self):
    """Return the expression used as this number's "denormalised" flag.

    The flag is the bitwise AND of two signals that ``elaborate``
    drives: ``exp_n126`` (``self.e == self.N126``) and ``m_msbzero``
    (``self.m[self.e_start] == 0``).
    """
    at_minimum_exponent = self.exp_n126
    top_mantissa_bit_clear = self.m_msbzero
    return at_minimum_exponent & top_mantissa_bit_clear
149 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
152 class FPNumOut(FPNumBase
):
153 """ Floating-point Number Class
155 Contains signals for an incoming copy of the value, decoded into
156 sign / exponent / mantissa.
157 Also contains encoding functions, creation and recognition of
158 zero, NaN and inf (all signed)
160 Four extra bits are included in the mantissa: the top bit
161 (m[-1]) is effectively a carry-overflow. The other three are
162 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, width, m_extra=True):
    """Set up the number by delegating entirely to ``FPNumBase``.

    width   -- passed through to ``FPNumBase.__init__``, which uses it
               as a key into its 16/32/64 width tables.
    m_extra -- passed through unchanged (defaults to True).
    """
    FPNumBase.__init__(self, width, m_extra)
167 def elaborate(self
, platform
):
168 m
= FPNumBase
.elaborate(self
, platform
)
172 def create(self
, s
, e
, m
):
173 """ creates a value from sign / exponent / mantissa
175 bias is added here, to the exponent
178 self
.v
[-1].eq(s
), # sign
179 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
180 self
.v
[0:self
.e_start
].eq(m
) # mantissa
184 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
187 return self
.create(s
, self
.P128
, 0)
190 return self
.create(s
, self
.N127
, 0)
192 def create2(self
, s
, e
, m
):
193 """ creates a value from sign / exponent / mantissa
195 bias is added here, to the exponent
197 e
= e
+ self
.P127
# exp (add on bias)
198 return Cat(m
[0:self
.e_start
],
199 e
[0:self
.e_end
-self
.e_start
],
203 return self
.create2(s
, self
.P128
, self
.msb1
)
206 return self
.create2(s
, self
.P128
, self
.mzero
)
209 return self
.create2(s
, self
.N127
, self
.mzero
)
212 class MultiShiftRMerge
:
213 """ shifts down (right) and merges lower bits into m[0].
214 m[0] is the "sticky" bit, basically
216 def __init__(self
, width
, s_max
=None):
218 s_max
= int(log(width
) / log(2))
220 self
.m
= Signal(width
, reset_less
=True)
221 self
.inp
= Signal(width
, reset_less
=True)
222 self
.diff
= Signal(s_max
, reset_less
=True)
225 def elaborate(self
, platform
):
228 rs
= Signal(self
.width
, reset_less
=True)
229 m_mask
= Signal(self
.width
, reset_less
=True)
230 smask
= Signal(self
.width
, reset_less
=True)
231 stickybit
= Signal(reset_less
=True)
232 maxslen
= Signal(self
.smax
, reset_less
=True)
233 maxsleni
= Signal(self
.smax
, reset_less
=True)
235 sm
= MultiShift(self
.width
-1)
236 m0s
= Const(0, self
.width
-1)
237 mw
= Const(self
.width
-1, len(self
.diff
))
238 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
239 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
243 # shift mantissa by maxslen, mask by inverse
244 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
245 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
246 smask
.eq(self
.inp
[1:] & m_mask
),
247 # sticky bit combines all mask (and mantissa low bit)
248 stickybit
.eq(smask
.bool() | self
.inp
[0]),
249 # mantissa result contains m[0] already.
250 self
.m
.eq(Cat(stickybit
, rs
))
255 class FPNumShift(FPNumBase
):
256 """ Floating-point Number Class for shifting
258 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
259 FPNumBase
.__init
__(self
, width
, m_extra
)
260 self
.latch_in
= Signal()
265 def elaborate(self
, platform
):
266 m
= FPNumBase
.elaborate(self
, platform
)
268 m
.d
.comb
+= self
.s
.eq(op
.s
)
269 m
.d
.comb
+= self
.e
.eq(op
.e
)
270 m
.d
.comb
+= self
.m
.eq(op
.m
)
272 with self
.mainm
.State("align"):
273 with m
.If(self
.e
< self
.inv
.e
):
274 m
.d
.sync
+= self
.shift_down()
def shift_down(self, inp):
    """Build the assignments that shift a mantissa down by one bit.

    The exponent is increased by one to compensate.  Accuracy is lost
    in the mantissa, but the lowest bit of the result is the OR of the
    two lowest input bits, so a set low bit is never dropped.

    inp -- object supplying the source exponent ``inp.e`` and
           mantissa ``inp.m``.

    Returns a two-element list: the assignment to ``self.e`` and the
    assignment to ``self.m``.
    """
    exponent_update = self.e.eq(inp.e + 1)
    mantissa = inp.m
    merged_low_bit = mantissa[0] | mantissa[1]
    # bits 2 and up move down one place; a constant 0 fills the top
    shifted = Cat(merged_low_bit, mantissa[2:], 0)
    return [exponent_update, self.m.eq(shifted)]
288 def shift_down_multi(self
, diff
):
289 """ shifts a mantissa down. exponent is increased to compensate
291 accuracy is lost as a result in the mantissa however there are 3
292 guard bits (the latter of which is the "sticky" bit)
294 this code works by variable-shifting the mantissa by up to
295 its maximum bit-length: no point doing more (it'll still be
298 the sticky bit is computed by shifting a batch of 1s by
299 the same amount, which will introduce zeros. it's then
300 inverted and used as a mask to get the LSBs of the mantissa.
301 those are then |'d into the sticky bit.
303 sm
= MultiShift(self
.width
)
304 mw
= Const(self
.m_width
-1, len(diff
))
305 maxslen
= Mux(diff
> mw
, mw
, diff
)
306 rs
= sm
.rshift(self
.m
[1:], maxslen
)
307 maxsleni
= mw
- maxslen
308 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
310 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
311 return [self
.e
.eq(self
.e
+ diff
),
312 self
.m
.eq(Cat(stickybits
, rs
))
315 def shift_up_multi(self
, diff
):
316 """ shifts a mantissa up. exponent is decreased to compensate
318 sm
= MultiShift(self
.width
)
319 mw
= Const(self
.m_width
, len(diff
))
320 maxslen
= Mux(diff
> mw
, mw
, diff
)
322 return [self
.e
.eq(self
.e
- diff
),
323 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
327 class FPNumDecode(FPNumBase
):
328 """ Floating-point Number Class
330 Contains signals for an incoming copy of the value, decoded into
331 sign / exponent / mantissa.
332 Also contains encoding functions, creation and recognition of
333 zero, NaN and inf (all signed)
335 Four extra bits are included in the mantissa: the top bit
336 (m[-1]) is effectively a carry-overflow. The other three are
337 guard (m[2]), round (m[1]), and sticky (m[0])
339 def __init__(self
, op
, width
, m_extra
=True):
340 FPNumBase
.__init
__(self
, width
, m_extra
)
343 def elaborate(self
, platform
):
344 m
= FPNumBase
.elaborate(self
, platform
)
346 m
.d
.comb
+= self
.decode(self
.v
)
351 """ decodes a latched value into sign / exponent / mantissa
353 bias is subtracted here, from the exponent. exponent
354 is extended to 10 bits so that subtract 127 is done on
357 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
358 #print ("decode", self.e_end)
359 return [self
.m
.eq(Cat(*args
)), # mantissa
360 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
361 self
.s
.eq(v
[-1]), # sign
364 class FPNumIn(FPNumBase
):
365 """ Floating-point Number Class
367 Contains signals for an incoming copy of the value, decoded into
368 sign / exponent / mantissa.
369 Also contains encoding functions, creation and recognition of
370 zero, NaN and inf (all signed)
372 Four extra bits are included in the mantissa: the top bit
373 (m[-1]) is effectively a carry-overflow. The other three are
374 guard (m[2]), round (m[1]), and sticky (m[0])
376 def __init__(self
, op
, width
, m_extra
=True):
377 FPNumBase
.__init
__(self
, width
, m_extra
)
378 self
.latch_in
= Signal()
381 def decode2(self
, m
):
382 """ decodes a latched value into sign / exponent / mantissa
384 bias is subtracted here, from the exponent. exponent
385 is extended to 10 bits so that subtract 127 is done on
389 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
390 #print ("decode", self.e_end)
391 res
= ObjectProxy(m
, pipemode
=False)
392 res
.m
= Cat(*args
) # mantissa
393 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.P127
# exp
398 """ decodes a latched value into sign / exponent / mantissa
400 bias is subtracted here, from the exponent. exponent
401 is extended to 10 bits so that subtract 127 is done on
404 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
405 #print ("decode", self.e_end)
406 return [self
.m
.eq(Cat(*args
)), # mantissa
407 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
408 self
.s
.eq(v
[-1]), # sign
def shift_down(self, inp):
    """Build the assignments that shift a mantissa down by one bit.

    The exponent is increased by one to compensate.  Accuracy is lost
    in the mantissa, but the lowest bit of the result is the OR of the
    two lowest input bits, so a set low bit is never dropped.

    inp -- object supplying the source exponent ``inp.e`` and
           mantissa ``inp.m``.

    Returns a two-element list: the assignment to ``self.e`` and the
    assignment to ``self.m``.
    """
    exponent_update = self.e.eq(inp.e + 1)
    mantissa = inp.m
    merged_low_bit = mantissa[0] | mantissa[1]
    # bits 2 and up move down one place; a constant 0 fills the top
    shifted = Cat(merged_low_bit, mantissa[2:], 0)
    return [exponent_update, self.m.eq(shifted)]
421 def shift_down_multi(self
, diff
, inp
=None):
422 """ shifts a mantissa down. exponent is increased to compensate
424 accuracy is lost as a result in the mantissa however there are 3
425 guard bits (the latter of which is the "sticky" bit)
427 this code works by variable-shifting the mantissa by up to
428 its maximum bit-length: no point doing more (it'll still be
431 the sticky bit is computed by shifting a batch of 1s by
432 the same amount, which will introduce zeros. it's then
433 inverted and used as a mask to get the LSBs of the mantissa.
434 those are then |'d into the sticky bit.
438 sm
= MultiShift(self
.width
)
439 mw
= Const(self
.m_width
-1, len(diff
))
440 maxslen
= Mux(diff
> mw
, mw
, diff
)
441 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
442 maxsleni
= mw
- maxslen
443 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
445 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
446 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
447 return [self
.e
.eq(inp
.e
+ diff
),
448 self
.m
.eq(Cat(stickybit
, rs
))
451 def shift_up_multi(self
, diff
):
452 """ shifts a mantissa up. exponent is decreased to compensate
454 sm
= MultiShift(self
.width
)
455 mw
= Const(self
.m_width
, len(diff
))
456 maxslen
= Mux(diff
> mw
, mw
, diff
)
458 return [self
.e
.eq(self
.e
- diff
),
459 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
465 self
.stb
= Signal(reset
=0)
467 self
.trigger
= Signal(reset_less
=True)
469 def elaborate(self
, platform
):
471 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
475 return [self
.stb
.eq(inp
.stb
),
480 return [self
.stb
, self
.ack
]
483 class FPOpIn(PrevControl
):
484 def __init__(self
, width
):
485 PrevControl
.__init
__(self
)
487 self
.v
= Signal(width
)
490 def chain_inv(self
, in_op
, extra
=None):
492 if extra
is not None:
494 return [self
.v
.eq(in_op
.v
), # receive value
495 self
.stb
.eq(stb
), # receive STB
496 in_op
.ack
.eq(~self
.ack
), # send ACK
499 def chain_from(self
, in_op
, extra
=None):
501 if extra
is not None:
503 return [self
.v
.eq(in_op
.v
), # receive value
504 self
.stb
.eq(stb
), # receive STB
505 in_op
.ack
.eq(self
.ack
), # send ACK
509 class FPOpOut(NextControl
):
510 def __init__(self
, width
):
511 NextControl
.__init
__(self
)
513 self
.v
= Signal(width
)
516 def chain_inv(self
, in_op
, extra
=None):
518 if extra
is not None:
520 return [self
.v
.eq(in_op
.v
), # receive value
521 self
.stb
.eq(stb
), # receive STB
522 in_op
.ack
.eq(~self
.ack
), # send ACK
525 def chain_from(self
, in_op
, extra
=None):
527 if extra
is not None:
529 return [self
.v
.eq(in_op
.v
), # receive value
530 self
.stb
.eq(stb
), # receive STB
531 in_op
.ack
.eq(self
.ack
), # send ACK
537 self
.guard
= Signal(reset_less
=True) # tot[2]
538 self
.round_bit
= Signal(reset_less
=True) # tot[1]
539 self
.sticky
= Signal(reset_less
=True) # tot[0]
540 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
542 self
.roundz
= Signal(reset_less
=True)
545 return [self
.guard
.eq(inp
.guard
),
546 self
.round_bit
.eq(inp
.round_bit
),
547 self
.sticky
.eq(inp
.sticky
),
550 def elaborate(self
, platform
):
552 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
553 (self
.round_bit | self
.sticky | self
.m0
))
558 """ IEEE754 Floating Point Base Class
560 contains common functions for FP manipulation, such as
561 extracting and packing operands, normalisation, denormalisation,
565 def get_op(self
, m
, op
, v
, next_state
):
566 """ this function moves to the next state and copies the operand
567 when both stb and ack are 1.
568 acknowledgement is sent by setting ack to ZERO.
572 with m
.If((op
.o_ready
) & (op
.i_valid_test
)):
574 # op is latched in from FPNumIn class on same ack/stb
575 m
.d
.comb
+= ack
.eq(0)
577 m
.d
.comb
+= ack
.eq(1)
580 def denormalise(self
, m
, a
):
581 """ denormalises a number. this is probably the wrong name for
582 this function. for normalised numbers (exponent != minimum)
583 one *extra* bit (the implicit 1) is added *back in*.
584 for denormalised numbers, the mantissa is left alone
585 and the exponent increased by 1.
587 both cases *effectively multiply the number stored by 2*,
588 which has to be taken into account when extracting the result.
590 with m
.If(a
.exp_n127
):
591 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
593 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
595 def op_normalise(self
, m
, op
, next_state
):
596 """ operand normalisation
597 NOTE: just like "align", this one keeps going round every clock
598 until the result's exponent is within acceptable "range"
600 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
602 op
.e
.eq(op
.e
- 1), # DECREASE exponent
603 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
608 def normalise_1(self
, m
, z
, of
, next_state
):
609 """ first stage normalisation
611 NOTE: just like "align", this one keeps going round every clock
612 until the result's exponent is within acceptable "range"
613 NOTE: the weirdness of reassigning guard and round is due to
614 the extra mantissa bits coming from tot[0..2]
616 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
618 z
.e
.eq(z
.e
- 1), # DECREASE exponent
619 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
620 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
621 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
622 of
.round_bit
.eq(0), # reset round bit
628 def normalise_2(self
, m
, z
, of
, next_state
):
629 """ second stage normalisation
631 NOTE: just like "align", this one keeps going round every clock
632 until the result's exponent is within acceptable "range"
633 NOTE: the weirdness of reassigning guard and round is due to
634 the extra mantissa bits coming from tot[0..2]
636 with m
.If(z
.e
< z
.N126
):
638 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
639 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
642 of
.round_bit
.eq(of
.guard
),
643 of
.sticky
.eq(of
.sticky | of
.round_bit
)
648 def roundz(self
, m
, z
, roundz
):
649 """ performs rounding on the output. TODO: different kinds of rounding
652 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
653 with m
.If(z
.m
== z
.m1s
): # all 1s
654 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
656 def corrections(self
, m
, z
, next_state
):
657 """ denormalisation and sign-bug corrections
660 # denormalised, correct exponent to zero
661 with m
.If(z
.is_denormalised
):
662 m
.d
.sync
+= z
.e
.eq(z
.N127
)
664 def pack(self
, m
, z
, next_state
):
665 """ packs the result into the output (detects overflow->Inf)
668 # if overflow occurs, return inf
669 with m
.If(z
.is_overflowed
):
670 m
.d
.sync
+= z
.inf(z
.s
)
672 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
674 def put_z(self
, m
, z
, out_z
, next_state
):
675 """ put_z: stores the result in the output. raises stb and waits
676 for ack to be set to 1 before moving to the next state.
677 resets stb back to zero when that occurs, as acknowledgement.
682 with m
.If(out_z
.o_valid
& out_z
.i_ready_test
):
683 m
.d
.sync
+= out_z
.o_valid
.eq(0)
686 m
.d
.sync
+= out_z
.o_valid
.eq(1)
689 class FPState(FPBase
):
def __init__(self, state_from):
    """Record the value given for this state.

    state_from -- stored unchanged as ``self.state_from``.
                  NOTE(review): presumably the name of the FSM state
                  this object implements -- confirm against callers.
    """
    self.state_from = state_from
693 def set_inputs(self
, inputs
):
695 for k
,v
in inputs
.items():
698 def set_outputs(self
, outputs
):
699 self
.outputs
= outputs
700 for k
,v
in outputs
.items():
705 def __init__(self
, id_wid
):
708 self
.in_mid
= Signal(id_wid
, reset_less
=True)
709 self
.out_mid
= Signal(id_wid
, reset_less
=True)
715 if self
.id_wid
is not None:
716 m
.d
.sync
+= self
.out_mid
.eq(self
.in_mid
)