1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
12 def __init__(self
, width
):
14 self
.smax
= int(log(width
) / log(2))
15 self
.i
= Signal(width
, reset_less
=True)
16 self
.s
= Signal(self
.smax
, reset_less
=True)
17 self
.o
= Signal(width
, reset_less
=True)
19 def elaborate(self
, platform
):
21 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
26 """ Generates variable-length single-cycle shifter from a series
27 of conditional tests on each bit of the left/right shift operand.
28 Each bit tested produces output shifted by that number of bits,
29 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
30 shifts by 2 bits, each partial result cascading to the next Mux.
32 Could be adapted to do arithmetic shift by taking copies of the
36 def __init__(self
, width
):
38 self
.smax
= int(log(width
) / log(2))
40 def lshift(self
, op
, s
):
44 for i
in range(self
.smax
):
46 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
49 def rshift(self
, op
, s
):
53 for i
in range(self
.smax
):
55 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
60 """ Floating-point Base Number Class
62 def __init__(self
, width
, m_extra
=True):
64 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
65 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
66 e_max
= 1<<(e_width
-3)
67 self
.rmw
= m_width
# real mantissa width (not including extras)
70 # mantissa extra bits (top,guard,round)
72 m_width
+= self
.m_extra
75 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
76 self
.m_width
= m_width
77 self
.e_width
= e_width
78 self
.e_start
= self
.rmw
- 1
79 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
81 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
82 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
83 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
84 self
.s
= Signal(reset_less
=True) # Sign bit
86 self
.mzero
= Const(0, (m_width
, False))
87 self
.m1s
= Const(-1, (m_width
, False))
88 self
.P128
= Const(e_max
, (e_width
, True))
89 self
.P127
= Const(e_max
-1, (e_width
, True))
90 self
.N127
= Const(-(e_max
-1), (e_width
, True))
91 self
.N126
= Const(-(e_max
-2), (e_width
, True))
93 self
.is_nan
= Signal(reset_less
=True)
94 self
.is_zero
= Signal(reset_less
=True)
95 self
.is_inf
= Signal(reset_less
=True)
96 self
.is_overflowed
= Signal(reset_less
=True)
97 self
.is_denormalised
= Signal(reset_less
=True)
98 self
.exp_128
= Signal(reset_less
=True)
99 self
.exp_lt_n126
= Signal(reset_less
=True)
100 self
.exp_gt_n126
= Signal(reset_less
=True)
101 self
.exp_gt127
= Signal(reset_less
=True)
102 self
.exp_n127
= Signal(reset_less
=True)
103 self
.exp_n126
= Signal(reset_less
=True)
104 self
.m_zero
= Signal(reset_less
=True)
105 self
.m_msbzero
= Signal(reset_less
=True)
107 def elaborate(self
, platform
):
109 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
110 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
111 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
112 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
113 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
114 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
115 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.e
> self
.N126
)
116 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.e
< self
.N126
)
117 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
118 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
119 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
120 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
121 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
126 return (self
.exp_128
) & (~self
.m_zero
)
129 return (self
.exp_128
) & (self
.m_zero
)
132 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """ overflow test: hands back the exp_gt127 flag unchanged
        (exp_gt127 is driven from "e > P127" in elaborate)
    """
    overflow_flag = self.exp_gt127
    return overflow_flag
def _is_denormalised(self):
    """ denormalised test: bitwise-AND of the exp_n126 flag (exponent
        equals N126) and the m_msbzero flag (mantissa bit at e_start
        is zero)
    """
    exp_at_n126 = self.exp_n126
    msb_is_clear = self.m_msbzero
    return exp_at_n126 & msb_is_clear
141 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
144 class FPNumOut(FPNumBase
):
145 """ Floating-point Number Class
147 Contains signals for an incoming copy of the value, decoded into
148 sign / exponent / mantissa.
149 Also contains encoding functions, creation and recognition of
150 zero, NaN and inf (all signed)
152 Four extra bits are included in the mantissa: the top bit
153 (m[-1]) is effectively a carry-overflow. The other three are
154 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, width, m_extra=True):
    """ no state of its own is added here: construction is handed
        straight to FPNumBase with the same width and m_extra
    """
    FPNumBase.__init__(self, width=width, m_extra=m_extra)
159 def elaborate(self
, platform
):
160 m
= FPNumBase
.elaborate(self
, platform
)
164 def create(self
, s
, e
, m
):
165 """ creates a value from sign / exponent / mantissa
167 bias is added here, to the exponent
170 self
.v
[-1].eq(s
), # sign
171 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
172 self
.v
[0:self
.e_start
].eq(m
) # mantissa
176 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
179 return self
.create(s
, self
.P128
, 0)
182 return self
.create(s
, self
.N127
, 0)
185 class FPNumShiftMultiRight(FPNumBase
):
186 """ shifts a mantissa down. exponent is increased to compensate
188 accuracy is lost as a result in the mantissa however there are 3
189 guard bits (the latter of which is the "sticky" bit)
191 this code works by variable-shifting the mantissa by up to
192 its maximum bit-length: no point doing more (it'll still be
195 the sticky bit is computed by shifting a batch of 1s by
196 the same amount, which will introduce zeros. it's then
197 inverted and used as a mask to get the LSBs of the mantissa.
198 those are then |'d into the sticky bit.
200 def __init__(self
, inp
, diff
, width
):
201 self
.m
= Signal(width
, reset_less
=True)
206 def elaborate(self
, platform
):
208 #m.submodules.inp = self.inp
210 rs
= Signal(self
.width
, reset_less
=True)
211 m_mask
= Signal(self
.width
, reset_less
=True)
212 smask
= Signal(self
.width
, reset_less
=True)
213 stickybit
= Signal(reset_less
=True)
215 sm
= MultiShift(self
.width
-1)
216 mw
= Const(self
.width
-1, len(self
.diff
))
217 maxslen
= Mux(self
.diff
> mw
, mw
, self
.diff
)
218 maxsleni
= mw
- maxslen
220 # shift mantissa by maxslen, mask by inverse
221 rs
.eq(sm
.rshift(self
.inp
.m
[1:], maxslen
)),
222 m_mask
.eq(sm
.rshift(self
.inp
.m1s
[1:], maxsleni
)),
223 smask
.eq(self
.inp
.m
[1:] & m_mask
),
224 # sticky bit combines all mask (and mantissa low bit)
225 stickybit
.eq(smask
.bool() | self
.inp
.m
[0]),
226 #self.s.eq(self.inp.s),
227 #self.e.eq(self.inp.e + diff),
228 # mantissa result contains m[0] already.
229 self
.m
.eq(Cat(stickybit
, rs
))
234 class FPNumShift(FPNumBase
):
235 """ Floating-point Number Class for shifting
237 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
238 FPNumBase
.__init
__(self
, width
, m_extra
)
239 self
.latch_in
= Signal()
244 def elaborate(self
, platform
):
245 m
= FPNumBase
.elaborate(self
, platform
)
247 m
.d
.comb
+= self
.s
.eq(op
.s
)
248 m
.d
.comb
+= self
.e
.eq(op
.e
)
249 m
.d
.comb
+= self
.m
.eq(op
.m
)
251 with self
.mainm
.State("align"):
252 with m
.If(self
.e
< self
.inv
.e
):
253 m
.d
.sync
+= self
.shift_down()
def shift_down(self, inp):
    """ single-bit downward shift of inp's mantissa, stored into self.
        the exponent written to self is inp's exponent plus one.

        the bit that drops off the bottom is not simply discarded: the
        new lowest ("sticky") bit is the OR of inp.m[0] and inp.m[1].

        returns a list of two assignments: one to self.e, one to self.m
    """
    exp_assign = self.e.eq(inp.e + 1)
    # new bit 0 keeps a record of both old low bits
    sticky = inp.m[0] | inp.m[1]
    # old bits 2 and upwards move down one place, a zero is added on top
    man_assign = self.m.eq(Cat(sticky, inp.m[2:], 0))
    return [exp_assign, man_assign]
def shift_down_multi(self, diff):
    """ shifts the mantissa down by diff bits. the exponent is increased
        by diff to compensate

        accuracy is lost as a result in the mantissa, however the low
        bits are folded into the lowest ("sticky") bit rather than
        being dropped outright.

        the shift amount is clamped to m_width-1: amounts larger than
        that are replaced by m_width-1 before shifting.  note that the
        exponent is still increased by the full, unclamped diff.

        the sticky bit is computed by shifting a batch of 1s (m1s[1:])
        down by (m_width-1 - shift amount) and using the result as a
        mask on m[1:].  if any masked bit is set, or m[0] is set, the
        sticky bit is set.

        returns a list of two assignments: one to self.e, one to self.m
    """
    sm = MultiShift(self.width)
    mw = Const(self.m_width-1, len(diff))
    maxslen = Mux(diff > mw, mw, diff)      # clamp the shift amount
    rs = sm.rshift(self.m[1:], maxslen)
    maxsleni = mw - maxslen
    m_mask = sm.rshift(self.m1s[1:], maxsleni)

    # "any bit set" is expressed with Value.bool(), the same form that
    # FPNumIn.shift_down_multi uses, instead of OR-ing the bits together
    # one at a time with reduce(or_, ...)
    stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
    return [self.e.eq(self.e + diff),
            self.m.eq(Cat(stickybits, rs))
           ]
def shift_up_multi(self, diff):
    """ shifts the mantissa up by diff bits. the exponent is decreased
        by diff to compensate

        the amount actually applied to the mantissa is limited to
        m_width; the exponent is decreased by the full diff regardless.

        returns a list of two assignments: one to self.e, one to self.m
    """
    shifter = MultiShift(self.width)
    limit = Const(self.m_width, len(diff))
    # anything above the limit is replaced by the limit itself
    amount = Mux(diff > limit, limit, diff)

    exp_assign = self.e.eq(self.e - diff)
    man_assign = self.m.eq(shifter.lshift(self.m, amount))
    return [exp_assign, man_assign]
305 class FPNumIn(FPNumBase
):
306 """ Floating-point Number Class
308 Contains signals for an incoming copy of the value, decoded into
309 sign / exponent / mantissa.
310 Also contains encoding functions, creation and recognition of
311 zero, NaN and inf (all signed)
313 Four extra bits are included in the mantissa: the top bit
314 (m[-1]) is effectively a carry-overflow. The other three are
315 guard (m[2]), round (m[1]), and sticky (m[0])
317 def __init__(self
, op
, width
, m_extra
=True):
318 FPNumBase
.__init
__(self
, width
, m_extra
)
319 self
.latch_in
= Signal()
322 def elaborate(self
, platform
):
323 m
= FPNumBase
.elaborate(self
, platform
)
325 #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
326 #with m.If(self.latch_in):
327 # m.d.sync += self.decode(self.v)
332 """ decodes a latched value into sign / exponent / mantissa
334 bias is subtracted here, from the exponent. exponent
335 is extended to 10 bits so that subtract 127 is done on
338 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
339 #print ("decode", self.e_end)
340 return [self
.m
.eq(Cat(*args
)), # mantissa
341 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
342 self
.s
.eq(v
[-1]), # sign
def shift_down(self, inp):
    """ moves inp's mantissa down by exactly one bit and stores it in
        self, with self's exponent set to inp's exponent plus one.

        the lowest bit of the result is inp.m[0] OR'd with inp.m[1], so
        a one shifted off the bottom still shows up in the "sticky"
        position.

        returns a list of two assignments: one to self.e, one to self.m
    """
    exp_assign = self.e.eq(inp.e + 1)
    # merge the two lowest bits into the new bit 0
    low_bit = inp.m[0] | inp.m[1]
    # bits 2 and upwards follow, then a zero fills the vacated top bit
    man_assign = self.m.eq(Cat(low_bit, inp.m[2:], 0))
    return [exp_assign, man_assign]
355 def shift_down_multi(self
, diff
, inp
=None):
356 """ shifts a mantissa down. exponent is increased to compensate
358 accuracy is lost as a result in the mantissa however there are 3
359 guard bits (the latter of which is the "sticky" bit)
361 this code works by variable-shifting the mantissa by up to
362 its maximum bit-length: no point doing more (it'll still be
365 the sticky bit is computed by shifting a batch of 1s by
366 the same amount, which will introduce zeros. it's then
367 inverted and used as a mask to get the LSBs of the mantissa.
368 those are then |'d into the sticky bit.
372 sm
= MultiShift(self
.width
)
373 mw
= Const(self
.m_width
-1, len(diff
))
374 maxslen
= Mux(diff
> mw
, mw
, diff
)
375 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
376 maxsleni
= mw
- maxslen
377 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
379 #stickybits = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
380 stickybits
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
381 return [self
.e
.eq(inp
.e
+ diff
),
382 self
.m
.eq(Cat(stickybits
, rs
))
def shift_up_multi(self, diff):
    """ multi-bit upward shift of the mantissa, with the exponent
        reduced by diff to compensate

        the mantissa shift is capped at m_width bits; the exponent is
        always reduced by the full diff.

        returns a list of two assignments: one to self.e, one to self.m
    """
    shifter = MultiShift(self.width)
    cap = Const(self.m_width, len(diff))
    # use diff as the shift amount unless it exceeds the cap
    amount = Mux(diff > cap, cap, diff)

    exp_assign = self.e.eq(self.e - diff)
    man_assign = self.m.eq(shifter.lshift(self.m, amount))
    return [exp_assign, man_assign]
397 def __init__(self
, width
):
400 self
.v
= Signal(width
)
401 self
.stb
= Signal(reset
=0)
403 self
.trigger
= Signal(reset_less
=True)
405 def elaborate(self
, platform
):
407 m
.d
.sync
+= self
.trigger
.eq(self
.stb
& self
.ack
)
410 def chain_inv(self
, in_op
, extra
=None):
412 if extra
is not None:
414 return [self
.v
.eq(in_op
.v
), # receive value
415 self
.stb
.eq(stb
), # receive STB
416 in_op
.ack
.eq(~self
.ack
), # send ACK
419 def chain_from(self
, in_op
, extra
=None):
421 if extra
is not None:
423 return [self
.v
.eq(in_op
.v
), # receive value
424 self
.stb
.eq(stb
), # receive STB
425 in_op
.ack
.eq(self
.ack
), # send ACK
429 return [self
.v
.eq(inp
.v
),
430 self
.stb
.eq(inp
.stb
),
435 return [self
.v
, self
.stb
, self
.ack
]
440 self
.guard
= Signal(reset_less
=True) # tot[2]
441 self
.round_bit
= Signal(reset_less
=True) # tot[1]
442 self
.sticky
= Signal(reset_less
=True) # tot[0]
443 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
445 self
.roundz
= Signal(reset_less
=True)
448 return [self
.guard
.eq(inp
.guard
),
449 self
.round_bit
.eq(inp
.round_bit
),
450 self
.sticky
.eq(inp
.sticky
),
453 def elaborate(self
, platform
):
455 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
456 (self
.round_bit | self
.sticky | self
.m0
))
461 """ IEEE754 Floating Point Base Class
463 contains common functions for FP manipulation, such as
464 extracting and packing operands, normalisation, denormalisation,
468 def get_op(self
, m
, op
, v
, next_state
):
469 """ this function moves to the next state and copies the operand
470 when both stb and ack are 1.
471 acknowledgement is sent by setting ack to ZERO.
473 with m
.If((op
.ack
) & (op
.stb
)):
476 # op is latched in from FPNumIn class on same ack/stb
481 m
.d
.sync
+= op
.ack
.eq(1)
483 def denormalise(self
, m
, a
):
484 """ denormalises a number. this is probably the wrong name for
485 this function. for normalised numbers (exponent != minimum)
486 one *extra* bit (the implicit 1) is added *back in*.
487 for denormalised numbers, the mantissa is left alone
488 and the exponent increased by 1.
490 both cases *effectively multiply the number stored by 2*,
491 which has to be taken into account when extracting the result.
493 with m
.If(a
.exp_n127
):
494 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
496 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
498 def op_normalise(self
, m
, op
, next_state
):
499 """ operand normalisation
500 NOTE: just like "align", this one keeps going round every clock
501 until the result's exponent is within acceptable "range"
503 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
505 op
.e
.eq(op
.e
- 1), # DECREASE exponent
506 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
511 def normalise_1(self
, m
, z
, of
, next_state
):
512 """ first stage normalisation
514 NOTE: just like "align", this one keeps going round every clock
515 until the result's exponent is within acceptable "range"
516 NOTE: the weirdness of reassigning guard and round is due to
517 the extra mantissa bits coming from tot[0..2]
519 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
521 z
.e
.eq(z
.e
- 1), # DECREASE exponent
522 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
523 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
524 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
525 of
.round_bit
.eq(0), # reset round bit
531 def normalise_2(self
, m
, z
, of
, next_state
):
532 """ second stage normalisation
534 NOTE: just like "align", this one keeps going round every clock
535 until the result's exponent is within acceptable "range"
536 NOTE: the weirdness of reassigning guard and round is due to
537 the extra mantissa bits coming from tot[0..2]
539 with m
.If(z
.e
< z
.N126
):
541 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
542 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
545 of
.round_bit
.eq(of
.guard
),
546 of
.sticky
.eq(of
.sticky | of
.round_bit
)
551 def roundz(self
, m
, z
, out_z
, roundz
):
552 """ performs rounding on the output. TODO: different kinds of rounding
554 m
.d
.comb
+= out_z
.copy(z
) # copies input to output first
556 m
.d
.comb
+= out_z
.m
.eq(z
.m
+ 1) # mantissa rounds up
557 with m
.If(z
.m
== z
.m1s
): # all 1s
558 m
.d
.comb
+= out_z
.e
.eq(z
.e
+ 1) # exponent rounds up
560 def corrections(self
, m
, z
, next_state
):
561 """ denormalisation and sign-bug corrections
564 # denormalised, correct exponent to zero
565 with m
.If(z
.is_denormalised
):
566 m
.d
.sync
+= z
.e
.eq(z
.N127
)
568 def pack(self
, m
, z
, next_state
):
569 """ packs the result into the output (detects overflow->Inf)
572 # if overflow occurs, return inf
573 with m
.If(z
.is_overflowed
):
574 m
.d
.sync
+= z
.inf(z
.s
)
576 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
578 def put_z(self
, m
, z
, out_z
, next_state
):
579 """ put_z: stores the result in the output. raises stb and waits
580 for ack to be set to 1 before moving to the next state.
581 resets stb back to zero when that occurs, as acknowledgement.
586 with m
.If(out_z
.stb
& out_z
.ack
):
587 m
.d
.sync
+= out_z
.stb
.eq(0)
590 m
.d
.sync
+= out_z
.stb
.eq(1)