1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
12 def __init__(self
, width
):
14 self
.smax
= int(log(width
) / log(2))
15 self
.i
= Signal(width
, reset_less
=True)
16 self
.s
= Signal(self
.smax
, reset_less
=True)
17 self
.o
= Signal(width
, reset_less
=True)
19 def elaborate(self
, platform
):
21 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
26 """ Generates variable-length single-cycle shifter from a series
27 of conditional tests on each bit of the left/right shift operand.
28 Each bit tested produces output shifted by that number of bits,
29 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
30 shifts by 2 bits, each partial result cascading to the next Mux.
32 Could be adapted to do arithmetic shift by taking copies of the
36 def __init__(self
, width
):
38 self
.smax
= int(log(width
) / log(2))
40 def lshift(self
, op
, s
):
44 for i
in range(self
.smax
):
46 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
49 def rshift(self
, op
, s
):
53 for i
in range(self
.smax
):
55 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
60 """ Floating-point Base Number Class
62 def __init__(self
, width
, m_extra
=True):
64 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
65 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
66 e_max
= 1<<(e_width
-3)
67 self
.rmw
= m_width
# real mantissa width (not including extras)
70 # mantissa extra bits (top,guard,round)
72 m_width
+= self
.m_extra
75 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
76 self
.m_width
= m_width
77 self
.e_width
= e_width
78 self
.e_start
= self
.rmw
- 1
79 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
81 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
82 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
83 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
84 self
.s
= Signal(reset_less
=True) # Sign bit
86 self
.mzero
= Const(0, (m_width
, False))
87 self
.m1s
= Const(-1, (m_width
, False))
88 self
.P128
= Const(e_max
, (e_width
, True))
89 self
.P127
= Const(e_max
-1, (e_width
, True))
90 self
.N127
= Const(-(e_max
-1), (e_width
, True))
91 self
.N126
= Const(-(e_max
-2), (e_width
, True))
93 self
.is_nan
= Signal(reset_less
=True)
94 self
.is_zero
= Signal(reset_less
=True)
95 self
.is_inf
= Signal(reset_less
=True)
96 self
.is_overflowed
= Signal(reset_less
=True)
97 self
.is_denormalised
= Signal(reset_less
=True)
98 self
.exp_128
= Signal(reset_less
=True)
99 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
100 self
.exp_lt_n126
= Signal(reset_less
=True)
101 self
.exp_gt_n126
= Signal(reset_less
=True)
102 self
.exp_gt127
= Signal(reset_less
=True)
103 self
.exp_n127
= Signal(reset_less
=True)
104 self
.exp_n126
= Signal(reset_less
=True)
105 self
.m_zero
= Signal(reset_less
=True)
106 self
.m_msbzero
= Signal(reset_less
=True)
108 def elaborate(self
, platform
):
110 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
111 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
112 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
113 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
114 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
115 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
116 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.N126
)
117 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
118 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
119 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
120 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
121 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
122 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
123 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
128 return (self
.exp_128
) & (~self
.m_zero
)
131 return (self
.exp_128
) & (self
.m_zero
)
134 return (self
.exp_n127
) & (self
.m_zero
)
136 def _is_overflowed(self
):
137 return self
.exp_gt127
139 def _is_denormalised(self
):
140 return (self
.exp_n126
) & (self
.m_msbzero
)
143 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
146 class FPNumOut(FPNumBase
):
147 """ Floating-point Number Class
149 Contains signals for an incoming copy of the value, decoded into
150 sign / exponent / mantissa.
151 Also contains encoding functions, creation and recognition of
152 zero, NaN and inf (all signed)
154 Four extra bits are included in the mantissa: the top bit
155 (m[-1]) is effectively a carry-overflow. The other three are
156 guard (m[2]), round (m[1]), and sticky (m[0])
158 def __init__(self
, width
, m_extra
=True):
159 FPNumBase
.__init
__(self
, width
, m_extra
)
161 def elaborate(self
, platform
):
162 m
= FPNumBase
.elaborate(self
, platform
)
166 def create(self
, s
, e
, m
):
167 """ creates a value from sign / exponent / mantissa
169 bias is added here, to the exponent
172 self
.v
[-1].eq(s
), # sign
173 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
174 self
.v
[0:self
.e_start
].eq(m
) # mantissa
178 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
181 return self
.create(s
, self
.P128
, 0)
184 return self
.create(s
, self
.N127
, 0)
187 class MultiShiftRMerge
:
188 """ shifts down (right) and merges lower bits into m[0].
189 m[0] is the "sticky" bit, basically
191 def __init__(self
, width
, s_max
=None):
193 s_max
= int(log(width
) / log(2))
195 self
.m
= Signal(width
, reset_less
=True)
196 self
.inp
= Signal(width
, reset_less
=True)
197 self
.diff
= Signal(s_max
, reset_less
=True)
200 def elaborate(self
, platform
):
203 rs
= Signal(self
.width
, reset_less
=True)
204 m_mask
= Signal(self
.width
, reset_less
=True)
205 smask
= Signal(self
.width
, reset_less
=True)
206 stickybit
= Signal(reset_less
=True)
207 maxslen
= Signal(self
.smax
, reset_less
=True)
208 maxsleni
= Signal(self
.smax
, reset_less
=True)
210 sm
= MultiShift(self
.width
-1)
211 m0s
= Const(0, self
.width
-1)
212 mw
= Const(self
.width
-1, len(self
.diff
))
213 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
214 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
218 # shift mantissa by maxslen, mask by inverse
219 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
220 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
221 smask
.eq(self
.inp
[1:] & m_mask
),
222 # sticky bit combines all mask (and mantissa low bit)
223 stickybit
.eq(smask
.bool() | self
.inp
[0]),
224 # mantissa result contains m[0] already.
225 self
.m
.eq(Cat(stickybit
, rs
))
230 class FPNumShift(FPNumBase
):
231 """ Floating-point Number Class for shifting
233 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
234 FPNumBase
.__init
__(self
, width
, m_extra
)
235 self
.latch_in
= Signal()
240 def elaborate(self
, platform
):
241 m
= FPNumBase
.elaborate(self
, platform
)
243 m
.d
.comb
+= self
.s
.eq(op
.s
)
244 m
.d
.comb
+= self
.e
.eq(op
.e
)
245 m
.d
.comb
+= self
.m
.eq(op
.m
)
247 with self
.mainm
.State("align"):
248 with m
.If(self
.e
< self
.inv
.e
):
249 m
.d
.sync
+= self
.shift_down()
253 def shift_down(self
, inp
):
254 """ shifts a mantissa down by one. exponent is increased to compensate
256 accuracy is lost as a result in the mantissa however there are 3
257 guard bits (the latter of which is the "sticky" bit)
259 return [self
.e
.eq(inp
.e
+ 1),
260 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
263 def shift_down_multi(self
, diff
):
264 """ shifts a mantissa down. exponent is increased to compensate
266 accuracy is lost as a result in the mantissa however there are 3
267 guard bits (the latter of which is the "sticky" bit)
269 this code works by variable-shifting the mantissa by up to
270 its maximum bit-length: no point doing more (it'll still be
273 the sticky bit is computed by shifting a batch of 1s by
274 the same amount, which will introduce zeros. it's then
275 inverted and used as a mask to get the LSBs of the mantissa.
276 those are then |'d into the sticky bit.
278 sm
= MultiShift(self
.width
)
279 mw
= Const(self
.m_width
-1, len(diff
))
280 maxslen
= Mux(diff
> mw
, mw
, diff
)
281 rs
= sm
.rshift(self
.m
[1:], maxslen
)
282 maxsleni
= mw
- maxslen
283 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
285 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
286 return [self
.e
.eq(self
.e
+ diff
),
287 self
.m
.eq(Cat(stickybits
, rs
))
290 def shift_up_multi(self
, diff
):
291 """ shifts a mantissa up. exponent is decreased to compensate
293 sm
= MultiShift(self
.width
)
294 mw
= Const(self
.m_width
, len(diff
))
295 maxslen
= Mux(diff
> mw
, mw
, diff
)
297 return [self
.e
.eq(self
.e
- diff
),
298 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
301 class FPNumIn(FPNumBase
):
302 """ Floating-point Number Class
304 Contains signals for an incoming copy of the value, decoded into
305 sign / exponent / mantissa.
306 Also contains encoding functions, creation and recognition of
307 zero, NaN and inf (all signed)
309 Four extra bits are included in the mantissa: the top bit
310 (m[-1]) is effectively a carry-overflow. The other three are
311 guard (m[2]), round (m[1]), and sticky (m[0])
313 def __init__(self
, op
, width
, m_extra
=True):
314 FPNumBase
.__init
__(self
, width
, m_extra
)
315 self
.latch_in
= Signal()
318 def elaborate(self
, platform
):
319 m
= FPNumBase
.elaborate(self
, platform
)
321 #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
322 #with m.If(self.latch_in):
323 # m.d.sync += self.decode(self.v)
328 """ decodes a latched value into sign / exponent / mantissa
330 bias is subtracted here, from the exponent. exponent
331 is extended to 10 bits so that subtract 127 is done on
334 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
335 #print ("decode", self.e_end)
336 return [self
.m
.eq(Cat(*args
)), # mantissa
337 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
338 self
.s
.eq(v
[-1]), # sign
341 def shift_down(self
, inp
):
342 """ shifts a mantissa down by one. exponent is increased to compensate
344 accuracy is lost as a result in the mantissa however there are 3
345 guard bits (the latter of which is the "sticky" bit)
347 return [self
.e
.eq(inp
.e
+ 1),
348 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
351 def shift_down_multi(self
, diff
, inp
=None):
352 """ shifts a mantissa down. exponent is increased to compensate
354 accuracy is lost as a result in the mantissa however there are 3
355 guard bits (the latter of which is the "sticky" bit)
357 this code works by variable-shifting the mantissa by up to
358 its maximum bit-length: no point doing more (it'll still be
361 the sticky bit is computed by shifting a batch of 1s by
362 the same amount, which will introduce zeros. it's then
363 inverted and used as a mask to get the LSBs of the mantissa.
364 those are then |'d into the sticky bit.
368 sm
= MultiShift(self
.width
)
369 mw
= Const(self
.m_width
-1, len(diff
))
370 maxslen
= Mux(diff
> mw
, mw
, diff
)
371 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
372 maxsleni
= mw
- maxslen
373 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
375 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
376 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
377 return [self
.e
.eq(inp
.e
+ diff
),
378 self
.m
.eq(Cat(stickybit
, rs
))
381 def shift_up_multi(self
, diff
):
382 """ shifts a mantissa up. exponent is decreased to compensate
384 sm
= MultiShift(self
.width
)
385 mw
= Const(self
.m_width
, len(diff
))
386 maxslen
= Mux(diff
> mw
, mw
, diff
)
388 return [self
.e
.eq(self
.e
- diff
),
389 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
393 def __init__(self
, width
):
396 self
.v
= Signal(width
)
397 self
.stb
= Signal(reset
=0)
399 self
.trigger
= Signal(reset_less
=True)
401 def elaborate(self
, platform
):
403 m
.d
.sync
+= self
.trigger
.eq(self
.stb
& self
.ack
)
406 def chain_inv(self
, in_op
, extra
=None):
408 if extra
is not None:
410 return [self
.v
.eq(in_op
.v
), # receive value
411 self
.stb
.eq(stb
), # receive STB
412 in_op
.ack
.eq(~self
.ack
), # send ACK
415 def chain_from(self
, in_op
, extra
=None):
417 if extra
is not None:
419 return [self
.v
.eq(in_op
.v
), # receive value
420 self
.stb
.eq(stb
), # receive STB
421 in_op
.ack
.eq(self
.ack
), # send ACK
425 return [self
.v
.eq(inp
.v
),
426 self
.stb
.eq(inp
.stb
),
431 return [self
.v
, self
.stb
, self
.ack
]
436 self
.guard
= Signal(reset_less
=True) # tot[2]
437 self
.round_bit
= Signal(reset_less
=True) # tot[1]
438 self
.sticky
= Signal(reset_less
=True) # tot[0]
439 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
441 self
.roundz
= Signal(reset_less
=True)
444 return [self
.guard
.eq(inp
.guard
),
445 self
.round_bit
.eq(inp
.round_bit
),
446 self
.sticky
.eq(inp
.sticky
),
449 def elaborate(self
, platform
):
451 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
452 (self
.round_bit | self
.sticky | self
.m0
))
457 """ IEEE754 Floating Point Base Class
459 contains common functions for FP manipulation, such as
460 extracting and packing operands, normalisation, denormalisation,
464 def get_op(self
, m
, op
, v
, next_state
):
465 """ this function moves to the next state and copies the operand
466 when both stb and ack are 1.
467 acknowledgement is sent by setting ack to ZERO.
469 with m
.If((op
.ack
) & (op
.stb
)):
472 # op is latched in from FPNumIn class on same ack/stb
477 m
.d
.sync
+= op
.ack
.eq(1)
479 def denormalise(self
, m
, a
):
480 """ denormalises a number. this is probably the wrong name for
481 this function. for normalised numbers (exponent != minimum)
482 one *extra* bit (the implicit 1) is added *back in*.
483 for denormalised numbers, the mantissa is left alone
484 and the exponent increased by 1.
486 both cases *effectively multiply the number stored by 2*,
487 which has to be taken into account when extracting the result.
489 with m
.If(a
.exp_n127
):
490 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
492 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
494 def op_normalise(self
, m
, op
, next_state
):
495 """ operand normalisation
496 NOTE: just like "align", this one keeps going round every clock
497 until the result's exponent is within acceptable "range"
499 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
501 op
.e
.eq(op
.e
- 1), # DECREASE exponent
502 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
507 def normalise_1(self
, m
, z
, of
, next_state
):
508 """ first stage normalisation
510 NOTE: just like "align", this one keeps going round every clock
511 until the result's exponent is within acceptable "range"
512 NOTE: the weirdness of reassigning guard and round is due to
513 the extra mantissa bits coming from tot[0..2]
515 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
517 z
.e
.eq(z
.e
- 1), # DECREASE exponent
518 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
519 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
520 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
521 of
.round_bit
.eq(0), # reset round bit
527 def normalise_2(self
, m
, z
, of
, next_state
):
528 """ second stage normalisation
530 NOTE: just like "align", this one keeps going round every clock
531 until the result's exponent is within acceptable "range"
532 NOTE: the weirdness of reassigning guard and round is due to
533 the extra mantissa bits coming from tot[0..2]
535 with m
.If(z
.e
< z
.N126
):
537 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
538 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
541 of
.round_bit
.eq(of
.guard
),
542 of
.sticky
.eq(of
.sticky | of
.round_bit
)
547 def roundz(self
, m
, z
, out_z
, roundz
):
548 """ performs rounding on the output. TODO: different kinds of rounding
550 m
.d
.comb
+= out_z
.copy(z
) # copies input to output first
552 m
.d
.comb
+= out_z
.m
.eq(z
.m
+ 1) # mantissa rounds up
553 with m
.If(z
.m
== z
.m1s
): # all 1s
554 m
.d
.comb
+= out_z
.e
.eq(z
.e
+ 1) # exponent rounds up
556 def corrections(self
, m
, z
, next_state
):
557 """ denormalisation and sign-bug corrections
560 # denormalised, correct exponent to zero
561 with m
.If(z
.is_denormalised
):
562 m
.d
.sync
+= z
.e
.eq(z
.N127
)
564 def pack(self
, m
, z
, next_state
):
565 """ packs the result into the output (detects overflow->Inf)
568 # if overflow occurs, return inf
569 with m
.If(z
.is_overflowed
):
570 m
.d
.sync
+= z
.inf(z
.s
)
572 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
574 def put_z(self
, m
, z
, out_z
, next_state
):
575 """ put_z: stores the result in the output. raises stb and waits
576 for ack to be set to 1 before moving to the next state.
577 resets stb back to zero when that occurs, as acknowledgement.
582 with m
.If(out_z
.stb
& out_z
.ack
):
583 m
.d
.sync
+= out_z
.stb
.eq(0)
586 m
.d
.sync
+= out_z
.stb
.eq(1)