add in FPNumShiftMultiRight class
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter.

        i is the value to shift, s the shift amount and o the result;
        i and o are "width" bits wide, s is smax bits wide.
    """

    def __init__(self, width):
        # number of bits given to the shift amount: int(log2(width))
        shift_bits = int(log(width) / log(2))

        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)
        self.s = Signal(shift_bits, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ builds the module: o is combinatorially driven with i >> s
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift and rshift apply the shift operator to the operand and
        then slice the result back down to the operand's own length, so
        the returned expression is never longer than what was passed in.

        Both methods used to carry a second formulation after their
        return statement (a cascade of Muxes, one per bit of the shift
        amount).  That code could never execute and has been removed.
        width and smax are still recorded because they are public
        attributes, although neither method reads them.
    """

    def __init__(self, width):
        self.width = width
        # int(log2(width)); not used by lshift / rshift
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted up by s, cut back to len(op)

            op: value to shift (must support <<, len() and slicing)
            s: shift amount
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down by s, cut back to len(op)

            op: value to shift (must support >>, len() and slicing)
            s: shift amount
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a number as separate sign (s), signed exponent (e) and
        mantissa (m) signals, a set of exponent / mantissa constants,
        and status flags that elaborate() derives from e and m.

        width selects the format and must be 16, 32 or 64 (any other
        value raises KeyError).  The constant and flag names (P128,
        N126, exp_gt127...) spell out the 32-bit values; the actual
        values are computed from e_max.
    """
    def __init__(self, width, m_extra=True):
        real_m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1 << (e_width - 3)

        # mantissa extra bits (top,guard,round) are optional
        self.m_extra = 3 if m_extra else 0
        m_width = real_m_width + self.m_extra

        self.width = width
        self.rmw = real_m_width # real mantissa width (not including extras)
        self.e_max = e_max
        self.m_width = m_width
        self.e_width = e_width
        # slice bounds of the exponent inside the packed value
        self.e_start = real_m_width - 1
        self.e_end = real_m_width + e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        self.e = Signal((e_width, True), reset_less=True) # Exponent, signed
        self.s = Signal(reset_less=True)             # Sign bit

        # mantissa constants: all-zeros and all-ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # exponent constants, all relative to e_max
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        # status flags, driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ creates a Module that combinatorially drives every status flag
        """
        m = Module()
        exp, mant = self.e, self.m
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(exp == self.P128),
            self.exp_gt_n126.eq(exp > self.N126),
            self.exp_lt_n126.eq(exp < self.N126),
            self.exp_gt127.eq(exp > self.P127),
            self.exp_n127.eq(exp == self.N127),
            self.exp_n126.eq(exp == self.N126),
            self.m_zero.eq(mant == self.mzero),
            self.m_msbzero.eq(mant[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent at P128 with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent at P128 with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent at N127 with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent above P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent at N126 and mantissa bit e_start clear
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments copying sign, exponent and mantissa from inp
        """
        copy_s = self.s.eq(inp.s)
        copy_e = self.e.eq(inp.e)
        copy_m = self.m.eq(inp.m)
        return [copy_s, copy_e, copy_m]
142
143
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds encoding functions to FPNumBase: create() packs a sign /
        exponent / mantissa triple into the value signal v, and nan(),
        inf() and zero() are shortcuts that call create() with the
        matching exponent constant and mantissa.

        All four return a list of assignments; the caller decides which
        domain to add them to.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # nothing beyond the status flags set up by the base class
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mant_field = self.v[0:self.e_start]
        return [
          sign_field.eq(s),             # sign
          exp_field.eq(e + self.P127),  # exp (add on bias)
          mant_field.eq(m),             # mantissa
        ]

    def nan(self, s):
        """ assignments for NaN with sign s: exponent P128 and only the
            top bit of the stored mantissa field set
        """
        top_mantissa_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_mantissa_bit)

    def inf(self, s):
        """ assignments for infinity with sign s: exponent P128, mantissa 0
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments for zero with sign s: exponent N127, mantissa 0
        """
        return self.create(s, self.N127, 0)
183
184
class FPNumShiftMultiRight(FPNumBase):
    """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s down by
        the complementary amount (maximum minus the shift), which
        leaves 1s only in the positions that get shifted out.  that is
        used as a mask to get the LSBs of the mantissa, and those are
        then |'d into the sticky bit.

        inp: number to be shifted (its s / e / m signals are read)
        diff: shift amount, also added to the exponent
        width, m_extra: passed through to FPNumBase
    """
    def __init__(self, inp, diff, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.inp = inp
        self.diff = diff

    def elaborate(self, platform):
        """ drives s / e / m of this number combinatorially from inp,
            shifted down by diff
        """
        m = FPNumBase.elaborate(self, platform)
        m.submodules.inp = self.inp

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)

        sm = MultiShift(self.m_width-1)
        # clamp the shift to the mantissa length (minus the sticky bit)
        mw = Const(self.m_width-1, len(self.diff))
        # inp and diff are instance attributes: they are always reached
        # through self here (the bare names are not defined in this scope)
        maxslen = Mux(self.diff > mw, mw, self.diff)
        maxsleni = mw - maxslen
        m.d.comb += [
                # shift mantissa by maxslen, mask by inverse
                rs.eq(sm.rshift(self.inp.m[1:], maxslen)),
                m_mask.eq(sm.rshift(self.m1s[1:], maxsleni)),
                smask.eq(self.inp.m[1:] & m_mask),
                # sticky bit combines all mask (and mantissa low bit)
                stickybit.eq(smask.bool() | self.inp.m[0]),
                self.s.eq(self.inp.s),
                # exponent grows by the unclamped diff
                self.e.eq(self.inp.e + self.diff),
                # mantissa result contains m[0] already.
                self.m.eq(Cat(stickybit, rs))
            ]
        return m
231
232
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm: module whose State("align") context elaborate() enters
        op: number whose s / e / m are copied into this one
        inv: number whose exponent this one's exponent is compared with
        width, m_extra: passed through to FPNumBase
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        """ copies op into this number and, in state "align", shifts
            down by one while the exponent is below inv's exponent
        """
        m = FPNumBase.elaborate(self, platform)

        # op is an instance attribute: the bare name "op" is not
        # defined in this scope
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down() needs the number to read from; this
                # number itself is passed so that it shifts in place.
                # NOTE(review): e and m are then assigned in both comb
                # (above) and sync - confirm this is what is intended.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp: number to read e and m from; the result is assigned to
                 this number's e and m.  returns the list of assignments.
        """
        # the old bit 0 is OR'd into the new bit 0, a zero enters at the top
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s down by
            the complementary amount (maximum minus the shift), which
            leaves 1s only in the positions that get shifted out.  that
            is used as a mask to get the LSBs of the mantissa, and those
            are then |'d into the sticky bit.

            returns the list of assignments to e and m.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # mask of shifted-out bits

        # .bool() ORs all the bits together (same form as FPNumIn uses)
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift is clamped to the mantissa width; the exponent is
            decreased by the unclamped diff.
            returns the list of assignments to e and m.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
303
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Adds decoding of a packed value into sign / exponent / mantissa
        to FPNumBase, together with the mantissa shift helpers used for
        alignment.  Every helper returns a list of assignments; the
        caller decides which domain to add them to.

        When m_extra is set, decode() places m_extra zero bits below
        the decoded mantissa; the shift helpers treat m[0] as the
        sticky bit.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        # only the base-class status flags: latching op into the number
        # on ack & stb is not done here (see FPBase.get_op)
        return FPNumBase.elaborate(self, platform)

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the exponent
            signal is wider than the stored field, so the subtraction
            is carried out at that greater width
        """
        # extra zero bits go below the stored mantissa bits
        low_pad = [0 for _ in range(self.m_extra)]
        stored_mantissa = v[0:self.e_start]
        stored_exponent = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*low_pad, stored_mantissa)),  # mantissa
                self.e.eq(stored_exponent - self.P127),     # exp
                self.s.eq(v[-1]),                           # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            inp supplies e and m; the result is assigned to this number.
            the bit shifted out is OR'd into the new bit 0 (sticky) and
            a zero enters at the top.
        """
        sticky = inp.m[0] | inp.m[1]
        shifted = Cat(sticky, inp.m[2:], 0)
        return [self.e.eq(inp.e + 1), self.m.eq(shifted)]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down by diff. exponent is increased to
            compensate

            the shift itself is clamped to the mantissa length minus
            one (anything more would still give zero); the exponent is
            increased by the unclamped diff.

            the sticky bit is found with a mask: a batch of 1s is
            shifted down by (maximum - shift), leaving 1s only in the
            positions that are about to be shifted out.  the masked
            mantissa bits are OR'd together with the old sticky bit.

            inp: number to read e and m from (this number if None)
        """
        src = self if inp is None else inp
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)

        shifted = shifter.rshift(src.m[1:], amount)
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)
        sticky = (src.m[1:] & lost_mask).bool() | src.m[0]

        return [self.e.eq(src.e + diff),
                self.m.eq(Cat(sticky, shifted))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up by diff. exponent is decreased to compensate

            the shift is clamped to the mantissa width; the exponent is
            decreased by the unclamped diff.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.lshift(self.m, amount)

        return [self.e.eq(self.e - diff), self.m.eq(shifted)]
394
class FPOp:
    """ Operand port: a value (v) with stb / ack handshake signals.

        trigger is a registered copy of (stb & ack).
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        handshake = self.stb & self.ack
        m.d.sync += self.trigger.eq(handshake)
        return m

    def _chain(self, in_op, extra, ack):
        """ shared body of chain_inv / chain_from: takes value and stb
            from in_op (stb gated by extra if given) and drives
            in_op.ack with the given ack expression
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(ack),           # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ connects in_op to this port, sending back the INVERSE of ack
        """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ connects in_op to this port, sending ack back unchanged
        """
        return self._chain(in_op, extra, self.ack)

    def copy(self, inp):
        """ returns assignments copying v, stb and ack from inp
        """
        copy_v = self.v.eq(inp.v)
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_v, copy_stb, copy_ack]

    def ports(self):
        return [self.v, self.stb, self.ack]
435
436
class Overflow:
    """ Guard / round / sticky bits plus the mantissa's bit zero, and
        the roundz flag computed from them.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments copying guard, round_bit, sticky and m0
            from inp (roundz is not copied)
        """
        fields = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, name).eq(getattr(inp, name))
                for name in fields]

    def elaborate(self, platform):
        """ roundz = guard AND (round_bit OR sticky OR m0)
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
457
458
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every function here takes the module being built as "m" and
        adds statements to it; none of them returns anything.  those
        that assign m.next are presumably called from inside an FSM
        state of that module (confirm against the caller).
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            op: operand port (v / stb / ack signals are used)
            v: number that receives the operand, via its decode()
            next_state: state to move to once the operand is taken
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                # op is latched in from FPNumIn class on same ack/stb
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not both set yet: keep ack raised and stay in this state
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.

            a: number to adjust (its exp_n127 flag, e and m are used)
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set

            op: number being normalised (e and m are modified)
            next_state: state to move to once the top bit is set
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                 op.e.eq(op.e - 1),  # DECREASE exponent
                 op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z: result number (e and m are modified)
            of: overflow bits (guard / round_bit / m0 are modified)
            next_state: state to move to when no more shifting is needed
        """
        # shift up while the top bit is clear and the exponent is
        # still above N126
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                # listed after the whole-mantissa shift: presumably
                # relies on the later assignment taking priority for bit 0
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z: result number (e and m are modified)
            of: overflow bits (guard / m0 / round_bit / sticky are modified)
            next_state: state to move to once the exponent is not below N126
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                # bits move one place down the chain:
                # m[1] -> m0, m[0] -> guard -> round_bit -> sticky
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            z: number before rounding (only read)
            out_z: number after rounding (driven combinatorially)
            roundz: condition under which the mantissa is rounded up
        """
        m.d.comb += out_z.copy(z) # copies input to output first
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            z: result number (e is modified if z.is_denormalised)
            next_state: state to move to (always taken)
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            z: result number; its inf() / create() assignments are
               added to the sync domain
            next_state: state to move to (always taken)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            z: result number (its v is copied out)
            out_z: output port (v and stb are driven, ack is read)
            next_state: state to move to once stb and ack are both 1
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
590
591