rounding done in module
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter: o = i >> s.

        width: bit width of the input (i) and output (o).

        The shift-amount signal s is made just wide enough to express every
        shift from 0 to width-1.  For power-of-two widths this is
        log2(width) bits, exactly as before; for other widths it is now
        rounded up, so that no shift amount is unreachable.
    """

    def __init__(self, width):
        self.width = width
        # exact integer arithmetic: int(log(width) / log(2)) rounded down,
        # leaving s too narrow for non-power-of-two widths
        self.smax = (width - 1).bit_length()
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift / rshift shift a value by a variable amount using nmigen's
        native shift operators.  The result is truncated back to the width
        of the operand, so the returned value is always len(op) bits wide
        and vacated bits are filled with zeros.

        width: nominal data width.  smax is the number of shift-amount bits
               needed to express shifts of 0 .. width-1 (log2(width) for
               power-of-two widths).
    """

    def __init__(self, width):
        self.width = width
        # exact integer arithmetic instead of int(log(width) / log(2)),
        # which rounds down for non-power-of-two widths
        self.smax = (width - 1).bit_length()

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits """
        res = op << s
        # a variable left shift widens the result: keep only the low bits
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s (zero-filled), len(op) bits wide """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point base number class.

        Holds the decoded sign (s), exponent (e) and mantissa (m) signals
        shared by the input and output number classes. elaborate() derives
        combinatorial flags from them: is_nan, is_zero, is_inf,
        is_overflowed and is_denormalised.

        width:   IEEE754 width in bits: 16, 32 or 64.  Any other value
                 raises KeyError from the width tables below.
        m_extra: if true, three extra low-order mantissa bits are added.

        The constant names (P128, P127, N127, N126) follow the
        single-precision values.  For other widths they scale with e_max.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # mantissa: one bit more than the IEEE754 fraction field;
        # exponent: two bits more than the IEEE754 exponent field
        rmw = {16: 11, 32: 24, 64: 53}[width]
        ew = {16: 7, 32: 10, 64: 13}[width]
        emax = 1 << (ew - 3)

        self.rmw = rmw    # real mantissa width (not including extras)
        self.e_max = emax
        self.m_extra = 3 if m_extra else 0
        mw = rmw + self.m_extra
        self.m_width = mw
        self.e_width = ew
        # bit range of the exponent field in the packed value: v[e_start:e_end]
        self.e_start = rmw - 1
        self.e_end = rmw + ew - 3

        self.v = Signal(width, reset_less=True)       # latched copy of value
        self.m = Signal(mw, reset_less=True)          # mantissa
        self.e = Signal((ew, True), reset_less=True)  # exponent: IEEE754exp+2 bits, signed
        self.s = Signal(reset_less=True)              # sign bit

        self.mzero = Const(0, (mw, False))
        self.m1s = Const(-1, (mw, False))
        self.P128 = Const(emax, (ew, True))
        self.P127 = Const(emax - 1, (ew, True))
        self.N127 = Const(-(emax - 1), (ew, True))
        self.N126 = Const(-(emax - 2), (ew, True))

        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        mod = Module()
        mod.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            # NOTE(review): m[e_start] is the top mantissa bit only when
            # m_extra is False; with m_extra=True, m has e_start+4 bits and
            # this tests a fraction bit instead. Confirm the intended usage.
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return mod

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # maximum exponent with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # minimum exponent with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns statements assigning inp's sign, exponent and mantissa
            to this number
        """
        return [self.s.eq(inp.s),
                self.e.eq(inp.e),
                self.m.eq(inp.m)]
138
139
class FPNumOut(FPNumBase):
    """ Floating-point output number.

        Holds sign / exponent / mantissa signals (from FPNumBase) for a
        result. create() packs them into the value signal v. nan(), inf()
        and zero() produce the corresponding signed special values.

        With m_extra, four bits beyond the IEEE754 fraction are carried in
        the mantissa. The top bit (m[-1]) is effectively a carry-overflow.
        The other three are guard (m[2]), round (m[1]) and sticky (m[0]).
    """
    def __init__(self, width, m_extra=True):
        super().__init__(width, m_extra)

    def elaborate(self, platform):
        # nothing beyond the base-class flag logic
        return super().elaborate(platform)

    def create(self, s, e, m):
        """ returns statements packing sign / exponent / mantissa into v

            e is unbiased: the bias (P127) is added back here.
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mant_field = self.v[0:self.e_start]
        return [
            sign_field.eq(s),
            exp_field.eq(e + self.P127),
            mant_field.eq(m),
        ]

    def nan(self, s):
        # maximum exponent, only the top fraction bit set
        return self.create(s, self.P128, 1 << (self.e_start - 1))

    def inf(self, s):
        # maximum exponent, zero fraction
        return self.create(s, self.P128, 0)

    def zero(self, s):
        # minimum exponent, zero fraction
        return self.create(s, self.N127, 0)
179
180
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm:  module owning the FSM in which the "align" state lives
        op:     operand whose s / e / m are copied into this number
        inv:    the other operand.  While this number's exponent is below
                inv.e, the mantissa is shifted down one bit per clock.
        width, m_extra: as for FPNumBase
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # copy the operand in.  This previously referenced an undefined
        # local name ``op``, which raised NameError on elaboration.
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e / self.m are driven from d.comb above and
        # again from d.sync below (via shift_down).  nmigen rejects driving
        # one signal from two domains.  mainm.State() also appears to
        # require an enclosing FSM on mainm.  This class looks like work in
        # progress; confirm the intended latching scheme before use.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # the bit shifted out of m[1] is OR-ed into the sticky bit m[0]
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down by diff bits. exponent is increased to
            compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            m[1:] is variable-shifted by up to its own bit-length.  There is
            no point shifting further, because the result is already zero.

            The sticky bit collects every bit that the shift discards.  An
            all-ones value is shifted right by the complementary amount
            (mw - maxslen), which leaves exactly maxslen low ones.  That
            mask selects the discarded LSBs of the mantissa, and they are
            OR-ed together with the old sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # mask of discarded bits

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up by diff bits (clamped to m_width).
            exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
251
class FPNumIn(FPNumBase):
    """ Floating-point input number.

        decode() unpacks an incoming IEEE754 value into the sign / exponent
        / mantissa signals inherited from FPNumBase.  The shift helpers
        move the mantissa while adjusting the exponent to compensate.

        With m_extra, four bits beyond the IEEE754 fraction are carried in
        the mantissa. The top bit (m[-1]) is effectively a carry-overflow.
        The other three are guard (m[2]), round (m[1]) and sticky (m[0]).
    """
    def __init__(self, op, width, m_extra=True):
        super().__init__(width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        # only the base-class flag logic; decode() is applied by the caller
        return super().elaborate(platform)

    def decode(self, v):
        """ returns statements decoding packed value v into s / e / m

            The fraction field lands above m_extra zero bits.  The exponent
            bias (P127) is subtracted on the e_width-bit signed exponent
            (two bits wider than the IEEE754 field).
        """
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(v[-1]),
               ]

    def shift_down(self):
        """ returns statements shifting the mantissa down by one bit and
            incrementing the exponent.  The bit leaving m[1] is folded
            into the sticky bit m[0].
        """
        sticky = self.m[0] | self.m[1]
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(sticky, self.m[2:], 0)),
               ]

    def shift_down_multi(self, diff):
        """ returns statements shifting the mantissa down by diff bits and
            adding diff to the exponent.

            The shift of m[1:] is clamped to its bit-length, because any
            further shift would give zero anyway.  Every discarded bit is
            OR-ed into the sticky bit.  The mask of discarded bits is
            all-ones shifted right by the complementary amount.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width - 1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.rshift(self.m[1:], amount)
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)
        sticky = reduce(or_, self.m[1:] & lost_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(sticky, shifted)),
               ]

    def shift_up_multi(self, diff):
        """ returns statements shifting the mantissa up by diff bits
            (clamped to m_width) and subtracting diff from the exponent
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, amount)),
               ]
339
class FPOp:
    """ Floating-point operand port: a value (v) with a stb/ack handshake.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()

    def _chain(self, in_op, extra, ack):
        """ shared body of chain_inv / chain_from: take value and strobe
            from in_op (strobe optionally gated by extra) and drive in_op's
            ack from the given expression
        """
        if extra is None:
            stb = in_op.stb
        else:
            stb = in_op.stb & extra
        return [self.v.eq(in_op.v),   # receive value
                self.stb.eq(stb),     # receive STB
                in_op.ack.eq(ack),    # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ chains from in_op, returning the INVERTED local ack """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ chains from in_op, returning the local ack unchanged """
        return self._chain(in_op, extra, self.ack)

    def ports(self):
        return [self.v, self.stb, self.ack]
368
369
class Overflow:
    """ Rounding state carried alongside a result mantissa.

        guard, round_bit and sticky hold the bits that were tot[2], tot[1]
        and tot[0].  m0 tracks the mantissa's least-significant bit.
        roundz is the combinatorial round-up decision built in elaborate().
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def elaborate(self, platform):
        mod = Module()
        # round up when guard is set and either a lower bit is set (above
        # the halfway point) or the mantissa LSB is 1 (a tie on an odd
        # mantissa, so rounding moves it to even)
        tail = self.round_bit | self.sticky | self.m0
        mod.d.comb += self.roundz.eq(self.guard & tail)
        return mod
384
385
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Each helper emits nmigen statements into the Module m passed in.
        Helpers that assign m.next are meant to be called from inside an
        FSM state of m.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:          Module whose FSM is being built
            op:         handshake port with v / stb / ack (e.g. FPOp)
            v:          number receiving the operand; must provide
                        decode() (e.g. FPNumIn)
            next_state: FSM state to enter once the operand is latched
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                # op is latched in from FPNumIn class on same ack/stb
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not yet transferred: (re)assert ack to signal readiness
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            Shifts op's mantissa up one bit per clock (decrementing the
            exponent) until its top bit is 1, then moves to next_state.
            There is no lower bound on the exponent here.
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            While z's top mantissa bit is 0 and the exponent is above N126,
            shift z up one bit per clock.  The guard bit shifts into the
            new LSB.
            z:  result number (s / e / m)
            of: rounding state (guard / round_bit / sticky / m0)
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            # all right-hand sides read the pre-clock values.  The later
            # z.m[0] assignment overrides bit 0 of the shifted mantissa, so
            # its ORDER relative to z.m.eq(...) matters.
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                of.m0.eq(of.guard),        # m0 follows the new z.m[0]
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            While z's exponent is below N126, shift z down one bit per
            clock.  The bits falling off the bottom cascade through guard
            and round into sticky.
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),  # old m[1] becomes the new mantissa LSB
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

            z:      number to round (copied to out_z via out_z.copy)
            out_z:  number receiving the (possibly rounded-up) result
            roundz: round-up condition, presumably Overflow.roundz
        """
        m.d.comb += out_z.copy(z) # copies input to output first
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
            # an all-ones mantissa wraps to zero on increment, so the
            # exponent is bumped to compensate
            with m.If(z.m == z.m1s): # all 1s
                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Only the denormal correction is performed here: a denormalised
            z has its exponent set to N127.  Always moves to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            z must provide create() and inf() (e.g. FPNumOut).  Always
            moves to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
517
518