split denormalisation to separate state
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Single-cycle variable right-shifter, as an elaboratable module.

        The output ``o`` is combinatorially driven with ``i >> s``.

        * width: bit-width of the ``i`` and ``o`` signals (must be >= 1)

        Attributes:
        * smax: floor(log2(width)), used as the bit-width of ``s``
        * i:    value to be shifted
        * s:    shift amount
        * o:    shifted result
    """

    def __init__(self, width):
        if width < 1:
            # log() of a non-positive width used to raise ValueError here:
            # keep raising the same exception type for such widths.
            raise ValueError("width must be >= 1, got %r" % (width,))
        self.width = width
        # exact integer floor(log2(width)): a float log()/log(2) division
        # is not guaranteed to land exactly on the integer result.
        self.smax = int(width).bit_length() - 1
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ returns a Module that wires o = i >> s combinatorially
        """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        Both directions apply the native shift operator to the operand
        and then truncate the result back to the operand's own length,
        so the returned expression is always len(op) bits wide.

        * width: nominal bit-width the shifter is created for (must be >= 1)

        Attributes:
        * smax: floor(log2(width)).  Kept for callers that read it: the
                shift methods of this class do not use it.
    """

    def __init__(self, width):
        if width < 1:
            # log() of a non-positive width used to raise ValueError here:
            # keep raising the same exception type for such widths.
            raise ValueError("width must be >= 1, got %r" % (width,))
        self.width = width
        # exact integer floor(log2(width)): a float log()/log(2) division
        # is not guaranteed to land exactly on the integer result.
        self.smax = int(width).bit_length() - 1

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a number split into sign (s), exponent (e) and mantissa (m)
        signals, a raw copy of the value (v), a set of exponent/mantissa
        constants, and one-bit classification flags which elaborate()
        drives combinatorially.

        * width:   16, 32 or 64.  any other value raises KeyError from
                   the width lookup tables
        * m_extra: when True, 3 extra bits are added to the mantissa width
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # per-format field widths, looked up by total width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]   # 2 extra bits (overflow)
        e_max = 1 << (e_width - 3)
        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = e_max
        # mantissa extra bits (original note: top,guard,round)
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        # bit range [e_start:e_end] of v that holds the exponent field
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3  # for decoding

        self.v = Signal(width, reset_less=True)    # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        # Exponent: signed, e_width bits wide
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)           # Sign bit

        # mantissa all-zeros / all-ones, and exponent limits built from
        # e_max (names match the 32-bit case, where e_max is 128)
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module driving every classification flag from e and m
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
        ]
        return m

    def _is_nan(self):
        """ exponent equal to P128 with a non-zero mantissa
        """
        exp_is_p128 = (self.e == self.P128)
        return exp_is_p128 & (self.m != 0)

    def _is_inf(self):
        """ exponent equal to P128 with a zero mantissa
        """
        exp_is_p128 = (self.e == self.P128)
        return exp_is_p128 & (self.m == 0)

    def _is_zero(self):
        """ exponent equal to N127 with a zero mantissa
        """
        exp_is_n127 = (self.e == self.N127)
        return exp_is_n127 & (self.m == self.mzero)

    def _is_overflowed(self):
        """ exponent greater than P127
        """
        return (self.e > self.P127)

    def _is_denormalised(self):
        """ exponent equal to N126 with mantissa bit e_start clear

            NOTE(review): bit e_start is only the top mantissa bit when
            m_extra is 0 - confirm this is intended for m_extra=True.
        """
        exp_is_n126 = (self.e == self.N126)
        return exp_is_n126 & (self.m[self.e_start] == 0)
125
126
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (result side)

        Adds encoding helpers on top of FPNumBase: each one returns a
        list of assignments which write sign, biased exponent and
        mantissa into the fields of v.  Helpers are provided for a
        general value as well as for NaN, inf and zero (all signed).

        The number of extra mantissa bits is decided by the m_extra
        argument, exactly as in FPNumBase.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        """ nothing is added to the module built by FPNumBase
        """
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        return [
            sign_field.eq(s),            # sign
            exp_field.eq(e + self.P127), # exp (add on bias)
            man_field.eq(m),             # mantissa
        ]

    def nan(self, s):
        """ NaN with sign s: exponent P128, only the top bit of the
            mantissa field set
        """
        top_field_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_field_bit)

    def inf(self, s):
        """ infinity with sign s: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ zero with sign s: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)
166
167
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        Copies sign / exponent / mantissa from ``op`` and, inside the
        "align" state of ``mainm``, shifts down by one bit whenever this
        number's exponent is below the exponent of ``inv``.

        * mainm: object whose State() context manager is entered during
                 elaborate (presumably the parent FSM - TODO confirm)
        * op:    source number: its s, e and m attributes are read
        * inv:   number whose exponent (inv.e) is compared against
        * width, m_extra: passed straight to FPNumBase
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # the operand only exists as self.op: the bare name "op" is not
        # defined in this scope and raised NameError.
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): e and m are assigned in the comb domain above and
        # in the sync domain below.  nmigen normally rejects a signal
        # driven from two domains - confirm this is intended.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # new bit 0 is the OR of the two lowest bits (sticky behaviour),
        # bits 2 and up move down by one, and a zero enters at the top.
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by right-shifting a batch of 1s
            by (maximum shift - actual shift), which leaves one 1 for
            every bit position about to be shifted out.  that is used
            directly as a mask (no inversion is needed) to get the LSBs
            of the mantissa.  those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount to mw
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)  # low "maxslen" bits set

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        # the exponent is adjusted by the full, unclamped diff
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount to mw

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
238
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (operand side)

        Decodes v into sign / exponent / mantissa on the clock edge
        where both op.ack and op.stb are high, and provides the
        mantissa shifting helpers used on operands.

        * op: operand port: its ack and stb attributes are read
        * width, m_extra: passed straight to FPNumBase
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # latch_in is high while both handshake signals of op are high
        m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
        with m.If(self.latch_in):
            # NOTE(review): this decodes self.v - nothing in this class
            # drives v, presumably it is connected elsewhere. confirm.
            m.d.sync += self.decode(self.v)

        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the result is
            stored in e, which is wider than the exponent field of v.
        """
        low_zeros = [0] * self.m_extra  # pad with extra zeros
        fraction = v[0:self.e_start]
        biased_exp = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*low_zeros, fraction)),  # mantissa
                self.e.eq(biased_exp - self.P127),     # exp
                self.s.eq(v[-1]),                      # sign
               ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the two lowest mantissa bits are OR'd together into the new
            bit 0 (the "sticky" bit), and a zero is shifted in at the top.
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [self.e.eq(self.e + 1),
                self.m.eq(shifted)
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down by diff. exponent is increased by diff

            the shift applied to the mantissa is limited to m_width-1:
            no point doing more (it'll still be zero).

            the sticky bit is computed by right-shifting a batch of 1s
            by (limit - shift amount).  that result is used as a mask on
            the mantissa, and every masked bit is then |'d, together
            with the old bit 0, into the new sticky bit.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        upper = self.m[1:]
        shifted = shifter.rshift(upper, amount)
        remainder = limit - amount
        lost_mask = shifter.rshift(self.m1s[1:], remainder)

        stickybits = reduce(or_, upper & lost_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, shifted))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up by diff (limited to m_width).
            exponent is decreased by diff to compensate
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.lshift(self.m, amount)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifted)
               ]
326
class FPOp:
    """ Operand / result port: a value plus two handshake signals

        * width: bit-width of the value signal v
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)  # the value itself
        self.stb = Signal()     # handshake: strobe
        self.ack = Signal()     # handshake: acknowledge

    def ports(self):
        """ returns the port signals as a list, in the order v, stb, ack
        """
        port_list = [self.v, self.stb, self.ack]
        return port_list
337
338
class Overflow:
    """ Rounding bits and the round-up decision derived from them

        roundz is driven combinatorially as
        guard & (round_bit | sticky | m0)
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module which drives roundz from the other four bits
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
353
354
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every method receives the module (m) being built and adds
        statements to it; those taking next_state assign it to m.next
        when their work is complete.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state when both stb and
            ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            v is accepted but not used here.
        """
        both_high = (op.ack) & (op.stb)
        with m.If(both_high):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_minimum = (a.e == a.N127)
        with m.If(at_minimum):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        top_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the top mantissa bit is set or the exponent is
                  no longer above N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        keep_going = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_going):
            # order matters: the z.m[0] assignment must come after the
            # whole-mantissa shift so that it overrides bit 0
            shift_up = [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),
            ]
            m.d.sync += shift_up
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the exponent is no longer below N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        too_small = (z.e < z.N126)
        with m.If(too_small):
            shift_down = [
                z.e.eq(z.e + 1),   # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
            m.d.sync += shift_down
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            always moves on to next_state.
        """
        m.next = next_state
        with m.If(of.roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            all_ones = (z.m == z.m1s)    # all 1s
            with m.If(all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves on to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves on to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        # the value is copied to the output on every clock spent here
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
485
486