small optimisation, move subtraction of -126 from exponent into FPNumBase module...
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shift module.

        i is the value to shift, s the shift amount and o the result
        (i >> s).  s is int(log(width) / log(2)) bits wide.
    """

    def __init__(self, width):
        shift_bits = int(log(width) / log(2))
        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)       # value in
        self.s = Signal(shift_bits, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)       # value out

    def elaborate(self, platform):
        """ builds the module: o is combinatorially assigned i >> s
        """
        shifted = self.i >> self.s
        m = Module()
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        Both directions apply the shift operator directly to the operand
        and slice the result back to the operand's length.

        An earlier implementation (a cascade of Muxes, one per bit of
        the shift amount) sat after the return statements and could
        never execute; it has been removed.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        # int(log2(width)): kept as a public attribute, although
        # lshift/rshift themselves do not consult it.
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted up by s, sliced to len(op)
        """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down by s, sliced to len(op)
        """
        return (op >> s)[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Bundles the sign / exponent / mantissa signals of one number
        together with exponent limit constants and a set of status
        flags that elaborate() derives combinatorially from e and m.
    """
    def __init__(self, width, m_extra=True):
        # field sizes are looked up by overall bit-width (16, 32 or 64)
        mant_bits = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        exp_bits = {16: 7, 32: 10, 64: 13}[width]   # 2 extra bits (overflow)
        # mantissa extra bits (top,guard,round), only when requested
        extra_bits = 3 if m_extra else 0
        full_mant_bits = mant_bits + extra_bits
        exp_limit = 1 << (exp_bits - 3)

        self.width = width
        self.rmw = mant_bits # real mantissa width (not including extras)
        self.e_max = exp_limit
        self.m_extra = extra_bits
        self.m_width = full_mant_bits
        self.e_width = exp_bits
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True)            # Latched copy of value
        self.m = Signal(full_mant_bits, reset_less=True)   # Mantissa
        self.e = Signal((exp_bits, True), reset_less=True) # Exponent, signed
        self.s = Signal(reset_less=True)                   # Sign bit

        # constants: all-zeros / all-ones mantissa, exponent limits
        self.mzero = Const(0, (full_mant_bits, False))
        self.m1s = Const(-1, (full_mant_bits, False))
        self.P128 = Const(exp_limit, (exp_bits, True))
        self.P127 = Const(exp_limit - 1, (exp_bits, True))
        self.N127 = Const(-(exp_limit - 1), (exp_bits, True))
        self.N126 = Const(-(exp_limit - 2), (exp_bits, True))

        # status flags, all assigned in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((exp_bits, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ wires up every status flag as a combinatorial function of
            the exponent and mantissa
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            # e - N126 is computed once; the gt/lt flags test its sign
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent is P128 and mantissa is non-zero
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent is P128 and mantissa is zero
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent is N127 and mantissa is zero
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent is greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent is N126 and mantissa bit e_start is clear
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments copying sign, exponent and mantissa
            from inp into this number
        """
        return [
            self.s.eq(inp.s),
            self.e.eq(inp.e),
            self.m.eq(inp.m),
        ]
144
145
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds encoding helpers on top of FPNumBase: packing a
        sign / exponent / mantissa triple into v, plus ready-made
        NaN, inf and zero encodings (all signed).

        Extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # nothing beyond the base-class status flags
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ returns assignments packing sign / exponent / mantissa into v

            the bias (P127) is added to the exponent here
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mant_field = self.v[0:self.e_start]
        biased_exp = e + self.P127
        return [
            sign_field.eq(s),
            exp_field.eq(biased_exp),
            mant_field.eq(m),
        ]

    def nan(self, s):
        """ encoding with exponent P128 and only mantissa bit
            (e_start-1) set
        """
        top_mant_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_mant_bit)

    def inf(self, s):
        """ encoding with exponent P128 and zero mantissa
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ encoding with exponent N127 and zero mantissa
        """
        return self.create(s, self.N127, 0)
185
186
class FPNumShiftMultiRight(FPNumBase):
    """ shifts a mantissa down by a variable amount, keeping a sticky bit

        the shift amount (diff) is clamped to width-1: no point doing
        more (the result will still be zero).

        inp.m[1:] is shifted right by the clamped amount.  a batch of
        1s (inp.m1s[1:]) is shifted right by (width-1 - clamped amount)
        to form a mask that picks out the mantissa LSBs; any set bit
        among those, or inp.m[0], sets the sticky bit, which becomes
        bit 0 of the output m.

        only the mantissa is produced here: sign and exponent are not
        assigned by this module.
    """
    def __init__(self, inp, diff, width):
        self.width = width
        self.inp = inp
        self.diff = diff
        self.m = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # note: self.inp is not added as a submodule here

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)

        shifter = MultiShift(self.width-1)
        limit = Const(self.width-1, len(self.diff))
        # clamp the requested shift to the limit
        amount = Mux(self.diff > limit, limit, self.diff)
        mask_amount = limit - amount

        # shift mantissa by the clamped amount
        m.d.comb += rs.eq(shifter.rshift(self.inp.m[1:], amount))
        # mask of 1s covering the bits selected for the sticky bit
        m.d.comb += m_mask.eq(shifter.rshift(self.inp.m1s[1:], mask_amount))
        m.d.comb += smask.eq(self.inp.m[1:] & m_mask)
        # sticky bit combines all masked bits (and mantissa low bit)
        m.d.comb += stickybit.eq(smask.bool() | self.inp.m[0])
        # mantissa result: sticky bit lowest, shifted value above it
        m.d.comb += self.m.eq(Cat(stickybit, rs))
        return m
234
235
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mirrors the sign / exponent / mantissa of op and, in the
        "align" state of mainm, shifts down by one bit while its
        exponent is below that of inv.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # the operand is an attribute: a bare `op` here was a NameError
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): State() is invoked on self.mainm, not on m, and
        # e/m are assigned from both comb (above) and sync (below).
        # nmigen normally rejects one signal driven from two domains -
        # confirm the intended design before relying on this module.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp is the number to read from; it defaults to this number
            itself, so that elaborate() can call shift_down() with no
            arguments.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                # new bit 0 is the OR of the two lowest bits (sticky)
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            the shift amount is clamped to m_width-1: no point doing
            more (the result will still be zero).

            the sticky bit is computed by shifting a batch of 1s down
            by (m_width-1 - shift amount) and using that as a mask to
            get the LSBs of the mantissa.  those are then |'d into the
            sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)

        # .bool() ORs all bits together (same form FPNumIn uses)
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
306
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Adds, on top of FPNumBase, decoding of a packed value into
        sign / exponent / mantissa, and mantissa shift helpers that
        adjust the exponent to compensate.

        Extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        # only the base-class status flags: latching the operand on
        # ack & stb (and decoding v) is currently disabled here.
        return FPNumBase.elaborate(self, platform)

    def decode(self, v):
        """ returns assignments splitting v into sign / exponent / mantissa

            the bias (P127) is subtracted from the exponent here.
            m_extra zero bits are placed below the mantissa field.
        """
        low_padding = [0] * self.m_extra # pad with extra zeros
        mantissa = Cat(*low_padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        sign = v[-1]
        return [
            self.m.eq(mantissa),
            self.e.eq(exponent),
            self.s.eq(sign),
        ]

    def shift_down(self, inp):
        """ one-bit shift down of inp's mantissa into this number,
            with the exponent increased by one to compensate

            accuracy is lost in the mantissa, however there are 3
            guard bits (the lowest of which is the "sticky" bit)
        """
        # new bit 0 is the OR of the two lowest bits
        sticky = inp.m[0] | inp.m[1]
        return [
            self.e.eq(inp.e + 1),
            self.m.eq(Cat(sticky, inp.m[2:], 0)),
        ]

    def shift_down_multi(self, diff, inp=None):
        """ variable shift down of a mantissa, with the exponent
            increased by diff to compensate

            inp is the number to read from (this number if omitted).

            the shift amount is clamped to m_width-1: no point doing
            more (the result will still be zero).

            the sticky bit is computed by shifting a batch of 1s down
            by (m_width-1 - shift amount) and using that as a mask to
            get the LSBs of the mantissa.  those are then |'d into the
            sticky bit.
        """
        source = self if inp is None else inp
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.rshift(source.m[1:], amount)
        mask = shifter.rshift(self.m1s[1:], limit - amount)

        sticky = (source.m[1:] & mask).bool() | source.m[0]
        return [
            self.e.eq(source.e + diff),
            self.m.eq(Cat(sticky, shifted)),
        ]

    def shift_up_multi(self, diff):
        """ variable shift up of the mantissa, with the exponent
            decreased by diff to compensate

            the shift amount is clamped to m_width.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        return [
            self.e.eq(self.e - diff),
            self.m.eq(shifter.lshift(self.m, amount)),
        ]
397
class FPOp:
    """ An operand port: a value (v) with stb / ack handshake signals.

        trigger is registered (sync) from stb & ack.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_high = self.stb & self.ack
        m.d.sync += self.trigger.eq(both_high)
        return m

    def chain_inv(self, in_op, extra=None):
        """ returns assignments linking in_op to this port, with the
            ack sent back to in_op INVERTED.

            if extra is given, it is ANDed into the received stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [
            self.v.eq(in_op.v),        # receive value
            self.stb.eq(stb),          # receive STB
            in_op.ack.eq(~self.ack),   # send ACK
        ]

    def chain_from(self, in_op, extra=None):
        """ returns assignments linking in_op to this port, with the
            ack sent back to in_op unchanged.

            if extra is given, it is ANDed into the received stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [
            self.v.eq(in_op.v),        # receive value
            self.stb.eq(stb),          # receive STB
            in_op.ack.eq(self.ack),    # send ACK
        ]

    def copy(self, inp):
        """ returns assignments copying v, stb and ack from inp
        """
        return [
            self.v.eq(inp.v),
            self.stb.eq(inp.stb),
            self.ack.eq(inp.ack),
        ]

    def ports(self):
        """ returns the list of signals making up this port
        """
        return [self.v, self.stb, self.ack]
438
439
class Overflow:
    """ Holds the guard / round / sticky bits plus the mantissa low
        bit, and derives from them the roundz flag.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments copying guard, round_bit, sticky and
            m0 from inp (roundz is not copied)
        """
        return [
            self.guard.eq(inp.guard),
            self.round_bit.eq(inp.round_bit),
            self.sticky.eq(inp.sticky),
            self.m0.eq(inp.m0),
        ]

    def elaborate(self, platform):
        """ roundz is set when guard is set together with any of
            round_bit, sticky or m0
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
460
461
class FPBase:
    """ IEEE754 Floating Point Base Class

        a collection of state-machine building blocks for FP
        manipulation: fetching operands, denormalisation,
        normalisation, rounding, corrections, packing and output.
        each takes the module (m) being built as its first argument.
    """

    def get_op(self, m, op, v, next_state):
        """ when both ack and stb of op are 1: moves to next_state,
            decodes op.v into v, and drops ack to ZERO (which is how
            acknowledgement is sent).  otherwise ack is raised to 1.
        """
        handshake = op.ack & op.stb
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number (probably the wrong name for it).

            when the exponent is N127 the exponent is set to N126 and
            the mantissa is left alone; otherwise the top mantissa bit
            (the implicit 1) is set.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the
            result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.N126)   # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)    # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: keeps going round every clock, shifting up by one
            bit each time, until the top mantissa bit is set; only
            then does it move to next_state.
        """
        top_bit_clear = op.m[-1] == 0
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)    # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)   # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: keeps going round every clock while the top mantissa
                  bit is clear and the exponent is above N126.
            NOTE: guard and round are reassigned because the extra
                  mantissa bits come from tot[0..2].
        """
        needs_shift_up = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(needs_shift_up):
            # order matters: the m[0] assignment must follow the
            # whole-mantissa assignment so that it takes precedence.
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: keeps going round every clock while the exponent is
                  below N126.
            NOTE: guard and round are reassigned because the extra
                  mantissa bits come from tot[0..2].
        """
        below_minimum = z.e < z.N126
        with m.If(below_minimum):
            m.d.sync += [
                z.e.eq(z.e + 1),             # INCREASE exponent
                z.m.eq(z.m >> 1),            # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output (combinatorially).
            TODO: different kinds of rounding
        """
        # input is copied to output first; overridden below if rounding
        m.d.comb += out_z.copy(z)
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1)       # mantissa rounds up
            all_ones = z.m == z.m1s
            with m.If(all_ones):
                m.d.comb += out_z.e.eq(z.e + 1)   # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections: always moves to
            next_state
        """
        m.next = next_state
        # denormalised: exponent is set to N127
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf):
            always moves to next_state
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # overflow occurred: result is inf, sign kept
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and
            waits for ack to be set to 1 before moving to the next
            state.  resets stb back to zero when that occurs, as
            acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
593
594