unit test for multi-bit shift right with merge (sticky bit)
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter: o is driven with i >> s.

    i and o are both "width" bits wide. the shift amount s is
    int(log(width) / log(2)) bits wide, also stored as smax.
    """

    def __init__(self, width):
        shift_bits = int(log(width) / log(2))  # floor of log2(width)
        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)       # value to be shifted
        self.s = Signal(shift_bits, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)       # shifted result

    def elaborate(self, platform):
        """ returns a Module that combinatorially assigns i >> s to o
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

    lshift / rshift apply the operand's own << / >> operator and then
    truncate the result to the length of the operand, so the value
    returned is never longer than the value passed in.

    (an earlier hand-built cascade of Muxes, one per bit of the shift
    amount, used to follow the return statements and could never
    execute: it has been removed.)

    smax (int(log(width) / log(2))) is not used by the shift methods
    themselves; it is retained because it is a public attribute.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, cut down to len(op) entries
        """
        shifted = op << s
        return shifted[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, cut down to len(op) entries
        """
        shifted = op >> s
        return shifted[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

    holds the sign (s), exponent (e) and mantissa (m) signals of a
    number, a full-width value signal (v), a handful of exponent and
    mantissa constants, and one-bit status flags which elaborate()
    derives combinatorially from e and m.

    width must be 16, 32 or 64 (anything else raises KeyError).
    when m_extra is true the mantissa is made 3 bits wider.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # field widths per supported format, keyed by total bit-width
        mantissa_bits = {16: 11, 32: 24, 64: 53}  # 1 extra bit (overflow)
        exponent_bits = {16: 7, 32: 10, 64: 13}   # 2 extra bits (overflow)
        m_width = mantissa_bits[width]
        e_width = exponent_bits[width]
        e_max = 1 << (e_width - 3)

        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = e_max
        # mantissa extra bits (top,guard,round): 3 if requested, else none
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        # bit positions of the exponent field inside v (used for decoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        self.v = Signal(width, reset_less=True)    # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        # Exponent: IEEE754exp+2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)           # Sign bit

        # mantissa constants: all-zeros and all-ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # exponent constants. the names match the values they take in
        # the 32-bit format (e_max = 128): +128, +127, -127, -126
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

        # status flags, all driven by elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module that drives every status flag combinatorially
        """
        m = Module()
        m.d.comb += [
            # classification flags (built from the primitive flags below)
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            # primitive exponent comparisons
            self.exp_128.eq(self.e == self.P128),
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            # primitive mantissa tests
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        """ exponent equals P128 and mantissa is non-zero """
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        """ exponent equals P128 and mantissa is zero """
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        """ exponent equals N127 and mantissa is zero """
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        """ exponent is greater than P127 """
        return self.exp_gt127

    def _is_denormalised(self):
        """ exponent equals N126 and mantissa bit e_start is zero """
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments of inp's sign, exponent and mantissa to self
        """
        sign = self.s.eq(inp.s)
        exponent = self.e.eq(inp.e)
        mantissa = self.m.eq(inp.m)
        return [sign, exponent, mantissa]
144
145
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    adds encoding functions on top of FPNumBase: create() writes a
    sign / exponent / mantissa triple into the value signal v, and
    nan(), inf() and zero() create those special values (all signed).
    """
    def __init__(self, width, m_extra=True):
        super().__init__(width, m_extra)

    def elaborate(self, platform):
        """ no additional logic: returns the FPNumBase module unchanged
        """
        return super().elaborate(platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias (P127) is added here, to the exponent
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        biased_exp = e + self.P127
        return [
            sign_field.eq(s),       # sign
            exp_field.eq(biased_exp),  # exp (with bias added on)
            man_field.eq(m),        # mantissa
        ]

    def nan(self, s):
        """ exponent P128, only the top bit of the mantissa field set """
        top_mantissa_bit = 1<<(self.e_start-1)
        return self.create(s, self.P128, top_mantissa_bit)

    def inf(self, s):
        """ exponent P128, mantissa zero """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ exponent N127, mantissa zero """
        return self.create(s, self.N127, 0)
185
186
class MultiShiftRMerge:
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp is the value to shift, diff the requested shift amount and
        m the result (inp and m are both "width" bits wide).
    """
    def __init__(self, width):
        # NOTE(review): smax is floor(log2(width)), so for a width that is
        # not a power of two diff cannot hold every value up to width-1.
        # confirm this is intended.
        self.smax = int(log(width) / log(2))
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        """ returns a Module that combinatorially computes m from inp, diff
        """
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)

        shifter = MultiShift(self.width-1)
        zeros = Const(0, self.width-1)
        limit = Const(self.width-1, len(self.diff))
        # requested shift, clamped so that it never exceeds the limit
        amount = Mux(self.diff > limit, limit, self.diff)
        remainder = limit - amount

        # shift mantissa (everything above bit 0) by the clamped amount
        m.d.comb += rs.eq(shifter.rshift(self.inp[1:], amount))
        # mask: a batch of 1s shifted down by the inverse amount
        m.d.comb += m_mask.eq(shifter.rshift(~zeros, remainder))
        m.d.comb += smask.eq(self.inp[1:] & m_mask)
        # sticky bit combines all masked bits (and mantissa low bit)
        m.d.comb += stickybit.eq(smask.bool() | self.inp[0])
        # result: sticky bit in position 0, shifted mantissa above it
        m.d.comb += self.m.eq(Cat(stickybit, rs))
        return m
222
223
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

    mainm is the object whose State("align") context the alignment
    logic is placed in, op is the number that s / e / m are copied
    from, and inv is the number whose exponent is compared against.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        """ returns the FPNumBase module with copy and align logic added
        """
        m = FPNumBase.elaborate(self, platform)

        # was a bare (undefined) "op": the operand is stored on self
        op = self.op
        m.d.comb += self.s.eq(op.s)
        m.d.comb += self.e.eq(op.e)
        m.d.comb += self.m.eq(op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down requires the number to shift from; it was
                # previously called with no argument at all.
                # NOTE(review): shifting this number's own e/m is assumed
                # to be the intent - confirm against the caller.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp supplies the exponent and mantissa to shift; the result
            is assigned to this number's e and m.
        """
        # new bit 0 merges old bits 0 and 1 (sticky); a zero enters at the top
        sticky = inp.m[0] | inp.m[1]
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(sticky, inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the inverse amount, which is then used as a mask to get
            the LSBs of the mantissa. those are then |'d into the
            sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount to mw
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)

        # OR-reduce the masked bits with .bool(), as FPNumIn does
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount to mw

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
294
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa, plus functions that shift the
    mantissa down or up while adjusting the exponent to compensate.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        """ returns the FPNumBase module; no extra logic is added here
            (latch_in is not driven)
        """
        return FPNumBase.elaborate(self, platform)

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias (P127) is subtracted here, from the exponent.
            the mantissa is padded at the bottom with m_extra zeros.
        """
        low_padding = [0] * self.m_extra
        mantissa = Cat(*low_padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(v[-1]),  # sign
               ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # new bit 0 merges old bits 0 and 1; a zero enters at the top
        merged_low = inp.m[0] | inp.m[1]
        shifted = Cat(merged_low, inp.m[2:], 0)
        return [self.e.eq(inp.e + 1), self.m.eq(shifted)]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the inverse amount, which is then used as a mask to get
            the LSBs of the mantissa. those are then |'d into the
            sticky bit.

            inp is the number to shift from; it defaults to self.
        """
        src = self if inp is None else inp
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)  # clamped shift amount
        shifted = shifter.rshift(src.m[1:], amount)
        mask = shifter.rshift(self.m1s[1:], limit - amount)

        sticky = (src.m[1:] & mask).bool() | src.m[0]
        return [self.e.eq(src.e + diff),
                self.m.eq(Cat(sticky, shifted))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)  # clamped shift amount
        shifted = shifter.lshift(self.m, amount)
        return [self.e.eq(self.e - diff), self.m.eq(shifted)]
385
class FPOp:
    """ an operand port: a value (v) of "width" bits with stb / ack
        handshake signals and a trigger signal.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module in which trigger is assigned stb & ack
            in the sync domain
        """
        m = Module()
        both = self.stb & self.ack
        m.d.sync += self.trigger.eq(both)
        return m

    def chain_inv(self, in_op, extra=None):
        """ returns assignments that take v and stb from in_op and give
            in_op the INVERSE of this port's ack.  if extra is given,
            it is ANDed with in_op.stb first.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),        # receive value
                self.stb.eq(strobe),       # receive STB
                in_op.ack.eq(~self.ack),   # send ACK (inverted)
               ]

    def chain_from(self, in_op, extra=None):
        """ returns assignments that take v and stb from in_op and give
            in_op this port's ack.  if extra is given, it is ANDed with
            in_op.stb first.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),       # receive value
                self.stb.eq(strobe),      # receive STB
                in_op.ack.eq(self.ack),   # send ACK
               ]

    def copy(self, inp):
        """ returns assignments of inp's v, stb and ack to this port
        """
        value = self.v.eq(inp.v)
        strobe = self.stb.eq(inp.stb)
        acknowledge = self.ack.eq(inp.ack)
        return [value, strobe, acknowledge]

    def ports(self):
        """ returns the list [v, stb, ack]
        """
        return [self.v, self.stb, self.ack]
426
427
class Overflow:
    """ guard / round / sticky bits plus the mantissa's bit zero, and
        the round-up decision (roundz) computed from them.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments of inp's guard, round_bit, sticky and m0
            to self (roundz is not copied)
        """
        guard = self.guard.eq(inp.guard)
        round_bit = self.round_bit.eq(inp.round_bit)
        sticky = self.sticky.eq(inp.sticky)
        m0 = self.m0.eq(inp.m0)
        return [guard, round_bit, sticky, m0]

    def elaborate(self, platform):
        """ returns a Module driving roundz: guard set AND at least one
            of round_bit, sticky or m0 set
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
448
449
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every function adds statements to the module "m" passed in;
        those taking next_state also set m.next.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.
            in every other case ack is set to 1.
        """
        handshake = (op.ack) & (op.stb)
        with m.If(handshake):
            m.next = next_state
            # operand value is decoded into v on the same ack/stb
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_minimum = a.exp_n127
        with m.If(at_minimum):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        top_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        top_clear = (z.m[-1] == 0)
        above_minimum = (z.e > z.N126)
        with m.If(top_clear & above_minimum):
            # order matters: the bit-0 assignment must follow the shift
            m.d.sync += z.e.eq(z.e - 1)            # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)           # shift mantissa UP
            m.d.sync += z.m[0].eq(of.guard)        # steal guard (was tot[2])
            m.d.sync += of.guard.eq(of.round_bit)  # steal round (was tot[1])
            m.d.sync += of.round_bit.eq(0)         # reset round bit
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        below_minimum = z.e < z.N126
        with m.If(below_minimum):
            m.d.sync += z.e.eq(z.e + 1)   # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)  # shift mantissa DOWN
            # everything moves down one place: m[1] -> m0, m[0] -> guard,
            # guard -> round, round is merged into sticky
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            out_z is a copy of z, except that when roundz is set the
            mantissa is incremented, and if the mantissa was all 1s
            the exponent is incremented too.
        """
        m.d.comb += out_z.copy(z)  # copies input to output first
        mantissa_all_ones = (z.m == z.m1s)
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1)      # mantissa rounds up
            with m.If(mantissa_all_ones):
                m.d.comb += out_z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        denormal = z.is_denormalised
        # denormalised: exponent is replaced with N127
        with m.If(denormal):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        overflowed = z.is_overflowed
        with m.If(overflowed):
            # if overflow occurs, return inf (sign is kept)
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
581
582