bd20364992273638287b6aa9ad16e09eb79d4951
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Variable-length right shifter: drives ``o = i >> s``.

        :param width: bit-width of the input ``i`` and the output ``o``.
        :raises ValueError: if width is less than 1.

        The shift-amount signal ``s`` is ``smax`` bits wide, where ``smax``
        is the number of bits needed to express every shift amount from
        0 to width-1.  For power-of-two widths this equals floor(log2(width));
        for other widths floor(log2) would be one bit too narrow to reach
        the largest shift amounts.
    """

    def __init__(self, width):
        if width < 1:
            raise ValueError("math domain error")
        self.width = width
        # integer arithmetic: avoids float rounding of log(width)/log(2)
        self.smax = (width - 1).bit_length()
        self.i = Signal(width, reset_less=True)      # value to be shifted
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        ``lshift`` / ``rshift`` return nmigen expressions for a logical
        (zero-filling) shift of ``op`` by the variable amount ``s``,
        truncated back to the width of ``op``.

        :param width: nominal operand width.  ``smax`` is floor(log2(width)),
                      the number of bits of a shift amount that stays below
                      ``width`` for power-of-two widths.
        :raises ValueError: if width is less than 1.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        if width < 1:
            raise ValueError("math domain error")
        self.width = width
        # floor(log2(width)) with integer arithmetic (no float rounding)
        self.smax = width.bit_length() - 1

    def lshift(self, op, s):
        """ returns ``op << s`` truncated to ``len(op)`` bits """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns ``op >> s`` truncated to ``len(op)`` bits """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a raw value ``v`` plus its decoded sign ``s``, signed
        exponent ``e`` and mantissa ``m``, a set of constants (all-zero /
        all-ones mantissa, the +128 / +127 / -127 / -126 style exponent
        limits), and status-flag signals that ``elaborate`` drives
        combinatorially from ``e`` and ``m``.

        :param width: IEEE754 format width; one of 16, 32 or 64
                      (any other value raises KeyError).
        :param m_extra: when true, 3 extra low-order mantissa bits
                        (top, guard, round) are added to ``m``.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # mantissa width with 1 extra bit (overflow)
        rmw = {16: 11, 32: 24, 64: 53}[width]
        # exponent width with 2 extra bits (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]
        e_max = 1 << (e_width - 3)

        self.rmw = rmw        # real mantissa width (not including extras)
        self.e_max = e_max
        self.m_extra = 3 if m_extra else 0
        m_width = rmw + self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = rmw - 1                # exponent field start in v
        self.e_end = rmw + e_width - 3        # for decoding

        # one Signal per statement so nmigen infers each name from the
        # attribute it is assigned to
        self.v = Signal(width, reset_less=True)            # Latched copy of value
        self.m = Signal(m_width, reset_less=True)          # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # Exponent, signed
        self.s = Signal(reset_less=True)                   # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        # classification flags (driven in elaborate)
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent / mantissa comparison results (driven in elaborate)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        classification = [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
        ]
        exponent_tests = [
            self.exp_128.eq(self.e == self.P128),
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
        ]
        mantissa_tests = [
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        m.d.comb += classification + exponent_tests + mantissa_tests
        return m

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # maximum exponent with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # minimum (-127 style) exponent with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        # -126 style exponent with bit e_start of the mantissa clear
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments copying s, e and m (in that order) from inp """
        return [getattr(self, field).eq(getattr(inp, field))
                for field in ("s", "e", "m")]
144
145
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (encoding side)

        Reuses FPNumBase's sign / exponent / mantissa signals and status
        flags, and adds encoders that pack a sign, exponent and mantissa
        back into the raw value ``v``, including signed NaN, infinity and
        zero.

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow. The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # no encoder-specific hardware: the base-class flags are all
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ returns assignments packing sign ``s``, exponent ``e`` and
            mantissa ``m`` into ``v``.

            the exponent bias (P127) is added to ``e`` here.
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mant_field = self.v[0:self.e_start]
        biased_e = e + self.P127
        return [sign_field.eq(s), exp_field.eq(biased_e), mant_field.eq(m)]

    def nan(self, s):
        # maximum exponent with the top bit of the stored fraction set
        top_fraction_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_fraction_bit)

    def inf(self, s):
        # maximum exponent, zero fraction
        return self.create(s, self.P128, 0)

    def zero(self, s):
        # minimum exponent, zero fraction
        return self.create(s, self.N127, 0)
185
186
class MultiShiftRMerge:
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        ``inp[0]`` is treated as an incoming sticky bit.  ``inp[1:]`` is
        shifted right by ``diff``, clamped to width-1.  Every bit shifted
        out is OR'd, together with ``inp[0]``, into the result's sticky
        bit ``m[0]``.

        :param width: width of ``inp`` and of the result ``m``.
        :param s_max: width of the shift-amount signal ``diff``.  Defaults
                      to the number of bits needed to express width-1.
        :raises ValueError: if s_max is defaulted and width is less than 1.
    """
    def __init__(self, width, s_max=None):
        if s_max is None:
            if width < 1:
                raise ValueError("math domain error")
            # enough bits for every shift 0..width-1.  floor(log2(width))
            # is one bit short for non-power-of-two widths, which used to
            # truncate the width-1 clamp constant in elaborate().
            s_max = (width - 1).bit_length()
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # internal shift-amount width: wide enough for both diff and the
        # clamp value width-1, so neither the Const nor its complement
        # (width-1 - diff) is truncated even if the caller's s_max is small.
        sw = max(self.smax, (self.width - 1).bit_length())

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(sw, reset_less=True)
        maxsleni = Signal(sw, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, sw)
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                    ]

        m.d.comb += [
                # shift mantissa by maxslen, mask by inverse
                rs.eq(sm.rshift(self.inp[1:], maxslen)),
                # all-ones shifted by (width-1 - maxslen): ones exactly in
                # the low bits that the data shift discards
                m_mask.eq(sm.rshift(~m0s, maxsleni)),
                smask.eq(self.inp[1:] & m_mask),
                # sticky bit combines all mask (and mantissa low bit)
                stickybit.eq(smask.bool() | self.inp[0]),
                # mantissa result contains m[0] already.
                self.m.eq(Cat(stickybit, rs))
                ]
        return m
228
229
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        Mirrors the sign / exponent / mantissa of ``op`` combinatorially.
        In the ``align`` state of ``mainm``, while its exponent is below
        ``inv``'s, it shifts its mantissa down by one per clock.

        :param mainm: object providing ``State("align")`` (presumably the
                      parent Module inside its FSM — confirm with callers).
        :param op: number whose s / e / m are mirrored.
        :param inv: number whose exponent is the alignment target.
        :param width: IEEE754 format width (16, 32 or 64).
        :param m_extra: see FPNumBase.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # NOTE(review): self.e / self.m are also driven from
                # m.d.comb above, and the FSM state belongs to self.mainm
                # while the If belongs to m.  nmigen rejects a signal driven
                # from two domains, so this class still needs a design fix.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            :param inp: number supplying the e / m to shift.
            :returns: list of assignments to self.e and self.m.
        """
        return [self.e.eq(inp.e + 1),
                # the two lowest bits fold into the new sticky bit
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros. it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param diff: shift amount (nmigen value).
            :param inp: number supplying e / m; defaults to self.
            :returns: list of assignments to self.e and self.m.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert

        # OR-reduce of the shifted-out bits, plus the existing sticky bit
        stickybits = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff, inp=None):
        """ shifts a mantissa up. exponent is decreased to compensate

            :param diff: shift amount (nmigen value), clamped to m_width.
            :param inp: number supplying e / m; defaults to self.
            :returns: list of assignments to self.e and self.m.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(inp.e - diff),
                self.m.eq(sm.lshift(inp.m, maxslen))
                ]
300
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (decoding side)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa, plus mantissa shift helpers.
        Recognition of zero, NaN and inf comes from the FPNumBase flags.

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow. The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        :param op: the incoming operand (stored as ``self.op``).
        :param width: IEEE754 format width (16, 32 or 64).
        :param m_extra: see FPNumBase.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        # only the base-class flags are generated here; latching/decoding
        # of the operand is done by the caller (FPBase.get_op calls
        # decode() on the stb/ack handshake).
        m = FPNumBase.elaborate(self, platform)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  The result lands
            in the ``e_width``-bit signed ``e``, which is 2 bits wider than
            the IEEE754 exponent field (e.g. 10 bits for 32-bit floats).
            The mantissa gets ``m_extra`` zero bits at the bottom.

            :param v: raw IEEE754 value (``width`` bits).
            :returns: list of assignments to self.m, self.e and self.s.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        return [self.m.eq(Cat(*args)), # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            :param inp: number supplying the e / m to shift.
            :returns: list of assignments to self.e and self.m.
        """
        return [self.e.eq(inp.e + 1),
                # the two lowest bits fold into the new sticky bit
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros. it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param diff: shift amount (nmigen value).
            :param inp: number supplying e / m; defaults to self.
            :returns: list of assignments to self.e and self.m.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert

        # OR-reduce of the shifted-out bits, plus the existing sticky bit
        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff, inp=None):
        """ shifts a mantissa up. exponent is decreased to compensate

            :param diff: shift amount (nmigen value), clamped to m_width.
            :param inp: number supplying e / m; defaults to self.
            :returns: list of assignments to self.e and self.m.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(inp.e - diff),
                self.m.eq(sm.lshift(inp.m, maxslen))
                ]
391
class Trigger:
    """ stb / ack handshake pair.

        ``trigger`` is driven combinatorially high whenever both ``stb``
        and ``ack`` are high.
    """
    def __init__(self):

        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        mod = Module()
        both_high = self.stb & self.ack
        mod.d.comb += self.trigger.eq(both_high)
        return mod

    def copy(self, inp):
        """ returns assignments copying stb then ack from inp """
        return [getattr(self, name).eq(getattr(inp, name))
                for name in ("stb", "ack")]

    def ports(self):
        return [self.stb, self.ack]
411
412
class FPOp(Trigger):
    """ stb / ack handshake that also carries a ``width``-bit value ``v``.

        :param width: width of the value signal ``v``.
    """
    def __init__(self, width):
        Trigger.__init__(self)
        self.width = width

        self.v = Signal(width)

    def _chain(self, in_op, extra, ack_back):
        """ shared body of chain_inv / chain_from: take in_op's value and
            strobe (ANDed with ``extra`` when given), and drive ``ack_back``
            onto in_op's ack.
        """
        if extra is None:
            stb = in_op.stb
        else:
            stb = in_op.stb & extra
        return [self.v.eq(in_op.v),     # receive value
                self.stb.eq(stb),       # receive STB
                in_op.ack.eq(ack_back), # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ like chain_from, but sends back the inverse of self.ack """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ receive in_op's value/strobe and pass self.ack back to it """
        return self._chain(in_op, extra, self.ack)

    def copy(self, inp):
        """ returns assignments copying v, stb and ack from inp """
        return [getattr(self, name).eq(getattr(inp, name))
                for name in ("v", "stb", "ack")]

    def ports(self):
        return [self.v, self.stb, self.ack]
446
447
class Overflow:
    """ guard / round / sticky bits carried next to a mantissa during
        normalisation, plus the mantissa LSB (``m0``) and the
        combinatorial round-up decision ``roundz``.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments copying guard, round_bit, sticky, m0 """
        names = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, n).eq(getattr(inp, n)) for n in names]

    def elaborate(self, platform):
        m = Module()
        # round up when guard is set and anything below it (round, sticky)
        # or the mantissa LSB is set: the usual round-to-nearest-even test
        below_or_odd = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & below_or_odd)
        return m
468
469
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Each method adds statements to the nmigen Module ``m`` it is given
        (``m.d.sync`` and, for most, ``m.next``).  Since ``m.next`` is
        assigned, these are presumably called from inside an FSM state
        body — confirm with callers.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            :param m: nmigen Module (inside an FSM state).
            :param op: handshake operand with ``stb``, ``ack`` and ``v``
                       (e.g. FPOp).
            :param v: destination number providing ``decode`` (e.g. FPNumIn).
            :param next_state: FSM state to move to once the operand is latched.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                # op is latched in from FPNumIn class on same ack/stb
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not latched yet: (re-)assert ack
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.

            :param m: nmigen Module.
            :param a: number with exp_n127, N126, e and m (an FPNumBase).
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            Shifts op.m up (decrementing op.e) each clock until its top bit
            is set, then moves to next_state.

            NOTE(review): if op.m is zero its top bit never becomes 1, so
            this state never exits; callers presumably exclude zero
            operands beforehand — confirm.
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts z.m up (decrementing z.e) while its top bit is clear and
            z.e is above N126.  The guard bit moves into the new mantissa
            LSB and round moves into guard.

            :param z: result number (e, m, N126).
            :param of: Overflow holding guard / round_bit / sticky / m0.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                # must follow the z.m assignment: the later statement wins
                # for bit 0, overriding the zero shifted in above
                z.m[0].eq(of.guard), # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                # sync reads the pre-clock guard, i.e. the new z.m[0]
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts z.m down (incrementing z.e) while z.e is below N126.
            All right-hand sides read pre-clock values, so:
            - the old LSB becomes guard;
            - the old m[1] becomes the new LSB (m0);
            - guard moves into round;
            - round is OR'd into sticky.
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

            :param z: number to round (m, e, m1s).
            :param out_z: destination whose m (and possibly e) is updated.
            :param roundz: 1-bit round-up condition (e.g. Overflow.roundz);
                           note it shares its name with this method.

            When rounding up and z.m is all ones, the mantissa increment
            wraps, so the exponent is incremented as well.
        """
        #m.d.comb += out_z.copy(z) # copies input to output first
        with m.If(roundz):
            m.d.sync += out_z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += out_z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Moves to next_state unconditionally; if z is flagged as
            denormalised its exponent is set to N127.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            Moves to next_state unconditionally.  z must provide inf() and
            create() (e.g. an FPNumOut).
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            out_z.v is reloaded from z.v on every clock spent in this state.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
601
602