5bf8e8c46573806ad32201d5d5d5f42c756e3652
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Variable right-shifter as a standalone nmigen module.

        Ports (all reset-less):
        * i: input value, ``width`` bits
        * s: shift amount, ``smax`` bits -- wide enough to hold every
             shift amount from 0 to width-1
        * o: ``i >> s``, ``width`` bits (combinational)
    """

    def __init__(self, width):
        self.width = width
        # number of bits needed to express any shift amount 0..width-1.
        # Exact integer arithmetic: the former int(log(width) / log(2))
        # went through floating point and rounded *down*, so a
        # non-power-of-two width (e.g. 24) could not express its largest
        # shifts.  For power-of-two widths the value is unchanged.
        self.smax = (width - 1).bit_length()
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        lshift() and rshift() shift ``op`` by a run-time amount ``s`` using
        the nmigen shift operators, then truncate the result back to
        ``len(op)`` bits, so the output always has the same width as the
        input.

        (A hand-built cascade of Muxes, one per bit of ``s``, used to live
        here but had become unreachable after an early ``return``; it has
        been removed.)

        ``smax`` is the number of bits needed to express any shift amount
        from 0 to width-1.
    """

    def __init__(self, width):
        self.width = width
        # exact integer arithmetic (no float log); equals log2(width) for
        # power-of-two widths and rounds *up* otherwise, so every shift
        # amount 0..width-1 fits in smax bits
        self.smax = (width - 1).bit_length()

    def lshift(self, op, s):
        """ returns ``op`` shifted left by ``s``, truncated to len(op) bits
        """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns ``op`` shifted right by ``s``, truncated to len(op) bits
        """
        return (op >> s)[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a number split into sign ``s``, a signed, unbiased exponent
        ``e`` and a mantissa ``m``, alongside a ``width``-bit raw value
        ``v``.  elaborate() derives combinational status flags from ``e``
        and ``m``: is_nan, is_zero, is_inf, is_overflowed, is_denormalised
        and the comparison signals they are built from.

        width:   16, 32 or 64 (any other value raises KeyError)
        m_extra: when true, ``m`` is 3 bits wider than the real mantissa
                 width ``rmw``; when false, ``m`` is exactly ``rmw`` bits
    """

    def __init__(self, width, m_extra=True):
        self.width = width
        # real mantissa width (not including extras); the fraction field
        # stored in v is one bit narrower (v[0:e_start], e_start = rmw-1)
        self.rmw = m_width = {16: 11, 32: 24, 64: 53}[width]
        # exponent width: 2 bits wider than the stored exponent field
        # v[e_start:e_end] (extra bits for overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]
        e_max = 1 << (e_width - 3)
        self.e_max = e_max

        # extra low-order mantissa bits, or none
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        self.m_width = m_width
        self.e_width = e_width

        # bit positions of the exponent field inside v (for decoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + e_width - 3

        self.v = Signal(width, reset_less=True)            # latched raw value
        self.m = Signal(m_width, reset_less=True)          # mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exponent, signed
        self.s = Signal(reset_less=True)                   # sign bit

        # constants for comparisons; names follow single precision but
        # the values scale with width
        self.mzero = Const(0, (m_width, False))            # all-zero mantissa
        self.m1s = Const(-1, (m_width, False))             # all-ones mantissa
        self.P128 = Const(e_max, (e_width, True))          # inf / NaN exponent
        self.P127 = Const(e_max-1, (e_width, True))        # exponent bias
        self.N127 = Const(-(e_max-1), (e_width, True))     # zero-field exponent
        self.N126 = Const(-(e_max-2), (e_width, True))     # (de)normalise limit

        # combinational status flags, driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            # NOTE(review): bit e_start (= rmw-1) is the top bit of m only
            # when m_extra is false; with m_extra the top bit is m[-1].
            # Confirm which layout is_denormalised is meant for.
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # inf/NaN exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # inf/NaN exponent with an all-zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # lowest exponent (N127) with an all-zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent above the bias value (P127)
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent at N126 with bit e_start of the mantissa clear
        return self.exp_n126 & self.m_msbzero
135
136
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Holds the sign / exponent / mantissa of a result and provides
        statement generators that encode them into the raw value ``v``:
        create() in general, plus nan(), inf() and zero(), each taking
        the sign to use.

        When m_extra is true, ``m`` is 3 bits wider than the real
        mantissa; create() writes only the low e_start bits of whatever
        mantissa it is given into the fraction field of ``v``.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # no logic beyond the base-class status flags
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ returns statements that encode sign / exponent / mantissa
            into v.

            the bias (P127) is added to the exponent here.
        """
        sign_bit = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        frac_field = self.v[0:self.e_start]
        return [sign_bit.eq(s),
                exp_field.eq(e + self.P127),   # re-apply the bias
                frac_field.eq(m)]

    def nan(self, s):
        """ returns statements encoding a NaN with sign s """
        # top bit of the fraction field (presumably the IEEE754 quiet bit)
        frac_top_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, frac_top_bit)

    def inf(self, s):
        """ returns statements encoding an infinity with sign s """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ returns statements encoding a zero with sign s """
        return self.create(s, self.N127, 0)
176
177
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        Mirrors an operand ``op`` (an object with s / e / m signals) and
        provides statement generators that shift the mantissa while
        compensating the exponent.  ``inv`` is the other operand, whose
        exponent elaborate() compares against inside the "align" state of
        the module ``mainm``.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # copy the mirrored operand.  These previously referenced a bare
        # ``op`` that does not exist in this scope (NameError).
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e / self.m are driven combinationally above
        # and also synchronously by shift_down() below.  nmigen is expected
        # to reject a signal driven from both comb and sync -- confirm
        # which copy is meant to be shifted.
        # NOTE(review): the State context is opened on self.mainm while the
        # If is added to the local module m -- confirm this is intended.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit): the bit
            shifted out of m[1] is ORed into m[0], and a 0 enters at the top.
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down by ``diff``. exponent is increased by
            ``diff`` to compensate.

            the shift itself is clamped to m_width-1: shifting further
            would only produce zeros.

            sticky bit: a run of all-ones is shifted right by
            (m_width-1 - shift), which leaves exactly ``shift`` low ones.
            That mask selects the bits of m[1:] that fall off the bottom;
            they are ORed together with the old sticky bit m[0].
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # mask of the `maxslen` low bits that rs discards
        m_mask = sm.rshift(self.m1s[1:], maxsleni)

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up by ``diff`` (clamped to m_width).
            exponent is decreased by ``diff`` to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
248
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        decode() produces statements that split a raw value into sign /
        exponent / mantissa.  The shift helpers produce statements that
        move the mantissa down or up while compensating the exponent.

        When m_extra is true, the mantissa carries 3 extra low-order bits
        below the real mantissa; decode() fills them with zeros, and the
        shift-down helpers treat m[0] as a sticky bit.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        # only the base-class status flags; decode() statements are added
        # by the caller
        return FPNumBase.elaborate(self, platform)

    def decode(self, v):
        """ returns statements decoding raw value v into s / e / m.

            the bias (P127) is subtracted from the exponent field here;
            e is 2 bits wider than that field and signed, so the unbiased
            result fits.
        """
        low_pad = [0] * self.m_extra        # zero-filled extra bits
        fraction = v[0:self.e_start]
        biased_exp = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*low_pad, fraction)),
                self.e.eq(biased_exp - self.P127),
                self.s.eq(v[-1])]

    def shift_down(self):
        """ returns statements shifting the mantissa down by one and
            incrementing the exponent.

            the bit leaving m[1] is ORed into the sticky bit m[0]; a zero
            enters at the top.
        """
        sticky = self.m[0] | self.m[1]
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(sticky, self.m[2:], 0))]

    def shift_down_multi(self, diff):
        """ returns statements shifting the mantissa down by ``diff``
            (clamped to m_width-1) and adding ``diff`` to the exponent.

            the new sticky bit is the old m[0] ORed with every bit of m[1:]
            that is shifted out.  Those bits are selected by shifting an
            all-ones pattern right by (limit - amount), which leaves exactly
            ``amount`` low ones.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width - 1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.rshift(self.m[1:], amount)
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)
        sticky = reduce(or_, self.m[1:] & lost_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(sticky, shifted))]

    def shift_up_multi(self, diff):
        """ returns statements shifting the mantissa up by ``diff``
            (clamped to m_width) and subtracting ``diff`` from the exponent.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, amount))]
336
class FPOp:
    """ Operand port: a ``width``-bit value ``v`` plus ``stb`` / ``ack``
        handshake signals.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()

    def chain_from(self, in_op, extra=None):
        """ returns statements connecting this port to upstream port
            ``in_op``: v and stb are copied in, ack is sent back.  When
            ``extra`` is given, the incoming stb is ANDed with it.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),       # receive value
                self.stb.eq(stb),         # receive STB
                in_op.ack.eq(self.ack)]   # send ACK

    def ports(self):
        """ returns the port's signals: value, strobe, acknowledge """
        return [self.v, self.stb, self.ack]
356
357
class Overflow:
    """ Guard / round / sticky bits carried below a mantissa, plus the
        round-up decision derived from them.

        roundz is set (combinationally) when guard is set and at least one
        of round_bit, sticky or m0 (the mantissa LSB) is also set -- the
        usual round-to-nearest, ties-to-even test.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa bit 0

        self.roundz = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        below_or_lsb = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & below_or_lsb)
        return m
372
373
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        The class holds no state of its own: every method receives the
        nmigen Module ``m`` and adds statements to it.  Methods that set
        ``m.next`` are intended to be called from inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:          module to add statements to (inside an FSM state)
            op:         input port with v / stb / ack signals (e.g. FPOp)
            v:          number object whose decode() statements latch op.v
                        (e.g. FPNumIn)
            next_state: state to move to once the operand is latched

            Until the handshake completes, ack is driven to 1 on each clock.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                # op is latched in from FPNumIn class on same ack/stb
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.

            m: module to add the synchronous statements to
            a: number object with e, m, N127 and N126 (e.g. FPNumIn)
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"

            Each clock that the top mantissa bit op.m[-1] is 0, the mantissa
            is shifted up by one and the exponent decremented.  Once it is 1
            the FSM moves to next_state.

            NOTE(review): there is no exponent limit here, so an all-zero
            mantissa never sets op.m[-1] and this state would never exit.
            Presumably zero operands are handled before this state is
            reached -- confirm with the callers.
        """
        with m.If((op.m[-1] == 0)): # check top (most-significant) mantissa bit
            m.d.sync +=[
                op.e.eq(op.e - 1), # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]

            While the top mantissa bit is 0 and z.e > z.N126, each clock
            shifts z.m up by one and decrements z.e.  The guard bit moves
            into z.m[0], the round bit moves into guard, round is cleared,
            and of.m0 records the new mantissa LSB (the old guard).
            Otherwise the FSM moves to next_state.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1), # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                # order matters: this comes after the whole-mantissa shift
                # so that it overrides bit 0 (which the shift leaves 0)
                z.m[0].eq(of.guard), # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0), # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]

            While z.e < z.N126, each clock shifts z.m down by one and
            increments z.e.  The bit shifted out (z.m[0]) becomes guard,
            the old guard becomes round, the old round is ORed into sticky,
            and of.m0 takes the new mantissa LSB (the old z.m[1]).
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1), # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output. TODO: different kinds of rounding

            Always moves to next_state.  If of.roundz is set, the mantissa
            is incremented.  If the mantissa was all ones, the increment
            wraps it to zero and the exponent is incremented as well.
        """
        m.next = next_state
        with m.If(of.roundz):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Always moves to next_state.  If z.is_denormalised is set, the
            exponent is forced to z.N127.  (Only the denormal correction
            is present here.)
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            Always moves to next_state.  If z.is_overflowed is set, z.v is
            loaded with an infinity of sign z.s.  Otherwise z.v is encoded
            from z.s / z.e / z.m via z.create().
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            out_z.v is loaded from z.v on every clock spent in this state.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
505
506