put exponent > 126 logic in FPNumBase, use it in norm module
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter: o = i >> s

    the shift-amount signal is sized as floor(log2(width)) bits.
    """

    def __init__(self, width):
        self.width = width
        # number of bits used for the shift amount
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)      # value to be shifted
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        """ builds the module: a single combinatorial assignment
        """
        mod = Module()
        shifted = self.i >> self.s
        mod.d.comb += self.o.eq(shifted)
        return mod
23
24
class MultiShift:
    """ Generates a variable-length single-cycle shifter.

        The shift itself is expressed with the operand's own << / >>
        operators; the result is then sliced back down to the
        operand's original length so that the output width always
        matches the input width.

        An earlier hand-built cascade of Muxes (one per bit of the
        shift amount) used to follow the return statements here and
        could never execute: it has been removed.  smax is still
        computed and stored, as before.
    """

    def __init__(self, width):
        self.width = width
        # floor(log2(width)): number of bits in a shift amount
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds the sign / exponent / mantissa signals for one number,
        a set of exponent constants derived from the format width,
        and single-bit status flags that elaborate() drives
        combinatorially from the exponent and mantissa.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # field sizes are looked up from the total width (16/32/64 only)
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]   # 2 extra bits (overflow)
        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = 1 << (e_width - 3)
        # mantissa extra bits (top,guard,round) are optional
        self.m_extra = 3 if m_extra else 0
        self.m_width = m_width + self.m_extra
        self.e_width = e_width
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3  # for decoding

        self.v = Signal(width, reset_less=True)         # latched value
        self.m = Signal(self.m_width, reset_less=True)  # mantissa
        # exponent: signed, e_width bits wide
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)                # sign bit

        # constants.  the names carry the 32-bit values, however the
        # actual numbers are all derived from e_max
        self.mzero = Const(0, (self.m_width, False))
        self.m1s = Const(-1, (self.m_width, False))
        self.P128 = Const(self.e_max, (e_width, True))
        self.P127 = Const(self.e_max - 1, (e_width, True))
        self.N127 = Const(-(self.e_max - 1), (e_width, True))
        self.N126 = Const(-(self.e_max - 2), (e_width, True))

        # status flags, driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ creates a Module that drives every status flag combinatorially
        """
        mod = Module()
        mod.d.comb += [
            # composite flags (built from the primitive ones below)
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            # exponent comparisons
            self.exp_128.eq(self.e == self.P128),
            self.exp_gt_n126.eq(self.e > self.N126),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            # mantissa tests
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return mod

    def _is_nan(self):
        """ exponent is P128 and mantissa is non-zero """
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        """ exponent is P128 and mantissa is zero """
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        """ exponent is N127 and mantissa is zero """
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        """ exponent is greater than P127 """
        return self.exp_gt127

    def _is_denormalised(self):
        """ exponent is N126 and bit e_start of the mantissa is clear """
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments copying inp's sign, exponent and mantissa
        """
        sign_eq = self.s.eq(inp.s)
        exp_eq = self.e.eq(inp.e)
        man_eq = self.m.eq(inp.m)
        return [sign_eq, exp_eq, man_eq]
140
141
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds encoding functions on top of FPNumBase: create() packs a
        sign / exponent / mantissa triple into the value signal (v),
        and nan() / inf() / zero() are shortcuts that call create()
        with the relevant constants.  All take the sign as an argument.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        """ nothing to add: the base class Module is returned as-is
        """
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  returns a list of
            three assignments into slices of v.
        """
        sign_eq = self.v[-1].eq(s)                           # sign
        biased = e + self.P127                               # add on bias
        exp_eq = self.v[self.e_start:self.e_end].eq(biased)  # exponent
        man_eq = self.v[0:self.e_start].eq(m)                # mantissa
        return [sign_eq, exp_eq, man_eq]

    def nan(self, s):
        """ exponent P128, mantissa with only bit (e_start-1) set """
        quiet_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, quiet_bit)

    def inf(self, s):
        """ exponent P128, mantissa zero """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ exponent N127, mantissa zero """
        return self.create(s, self.N127, 0)
181
182
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm: object providing a State("...") context manager,
               used in elaborate() to hook into the "align" state
        op:    number whose sign / exponent / mantissa are mirrored
               into this one
        inv:   the other number; its exponent is what this one's
               exponent is compared against
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # mirror the source operand.  this has to go through self.op:
        # a bare "op" is not a name in scope here (NameError)
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e and self.m are assigned combinatorially
        # above and synchronously below - confirm that this is intended
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # bit 0 (sticky) absorbs bit 1; a zero is shifted in at the top
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s right
            by (maximum - shift amount), which leaves 1s only in the
            low "shift amount" bits.  that is used as a mask to pick
            out the mantissa bits that are about to be shifted off the
            end.  those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)  # 1s in low maxslen bits

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
253
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Adds decoding on top of FPNumBase: decode() splits a raw value
        into sign / exponent / mantissa.  Also provides the mantissa
        shift helpers (single-bit and multi-bit), each of which
        adjusts the exponent to compensate.

        In the shift-down helpers, m[0] acts as the sticky bit: any
        bit shifted off the bottom is OR'd into it.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        mod = FPNumBase.elaborate(self, platform)

        # latching on ack & stb is disabled here for now:
        #mod.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
        #with mod.If(self.latch_in):
        #    mod.d.sync += self.decode(self.v)

        return mod

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the mantissa
            is padded at the bottom with m_extra zero bits.
        """
        padding = [0] * self.m_extra  # pad with extra zeros
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(v[-1]),
               ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            bit 1 is merged into bit 0 (sticky) rather than being
            discarded, and a zero is brought in at the top.
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [self.e.eq(self.e + 1),
                self.m.eq(shifted),
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            the shift amount is clamped to (m_width - 1): shifting any
            further would give the same result.

            to preserve the sticky bit, an all-1s value is shifted
            right by (clamp - amount), leaving 1s only in the low
            "amount" bits.  ANDing that with the mantissa (minus bit 0)
            selects the bits about to be lost; these are OR-reduced
            together with the existing bit 0 to make the new sticky bit.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        upper = shifter.rshift(self.m[1:], amount)
        remainder = limit - amount
        lost_mask = shifter.rshift(self.m1s[1:], remainder)

        sticky = reduce(or_, self.m[1:] & lost_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(sticky, upper)),
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        raised = shifter.lshift(self.m, amount)

        return [self.e.eq(self.e - diff),
                self.m.eq(raised),
               ]
341
class FPOp:
    """ Operand port: a value (v) plus stb / ack handshake signals
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()

    def _chain(self, in_op, extra, ack):
        """ shared body of chain_inv / chain_from: ack is the
            expression that gets sent back to in_op.ack
        """
        if extra is None:
            stb = in_op.stb
        else:
            stb = in_op.stb & extra  # gate the incoming strobe
        return [self.v.eq(in_op.v),   # receive value
                self.stb.eq(stb),     # receive STB
                in_op.ack.eq(ack),    # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ connects in_op to this port, sending back an inverted ack.
            if extra is given, it is ANDed with in_op's stb.
        """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ connects in_op to this port, sending ack back unmodified.
            if extra is given, it is ANDed with in_op's stb.
        """
        return self._chain(in_op, extra, self.ack)

    def ports(self):
        """ returns the three signals, in the order v, stb, ack
        """
        return [self.v, self.stb, self.ack]
370
371
class Overflow:
    """ Rounding-bit bundle: guard, round and sticky, plus a copy of
        the mantissa's bit zero.  roundz is derived from these,
        combinatorially, in elaborate().
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments copying the four input bits from inp
            (roundz is not copied)
        """
        fields = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, name).eq(getattr(inp, name))
                for name in fields]

    def elaborate(self, platform):
        """ roundz = guard AND (round_bit OR sticky OR m0)
        """
        mod = Module()
        lower = self.round_bit | self.sticky | self.m0
        mod.d.comb += self.roundz.eq(self.guard & lower)
        return mod
392
393
class FPBase:
    """ IEEE754 Floating Point Base Class

        A collection of helper functions, one per pipeline step
        (operand fetch, denormalise, normalise, round, correct, pack,
        output).  Each adds statements to the module / FSM (m) that is
        passed in; most also set m.next to move on to next_state.
    """

    def get_op(self, m, op, v, next_state):
        """ waits for the operand handshake.  once both stb and ack
            are 1, the operand value is decoded into v, ack is
            dropped to ZERO (as acknowledgement) and the FSM moves to
            next_state.  until then, ack is held at 1.
        """
        handshake = op.ack & op.stb
        with m.If(handshake):
            m.next = next_state
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  when the exponent is at its minimum (N127)
            the mantissa is left alone and the exponent is set to N126.
            otherwise the top mantissa bit (the implicit 1) is set.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_minimum = a.e == a.N127
        with m.If(at_minimum):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: this one keeps going round every clock, shifting
                  up by one bit each time, until the top bit of the
                  mantissa is set
        """
        top_clear = op.m[-1] == 0
        with m.If(top_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: this one keeps going round every clock, shifting
                  up by one bit each time, for as long as the top
                  mantissa bit is clear and the exponent is above N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        needs_shift = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(needs_shift):
            # order matters: the z.m[0] assignment has to come after
            # the whole-mantissa one in order to override its bit 0
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: this one keeps going round every clock, shifting
                  down by one bit each time, for as long as the
                  exponent is below N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        too_small = z.e < z.N126
        with m.If(too_small):
            m.d.sync += [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            out_z starts out as a plain copy of z.  if roundz is set,
            the mantissa is incremented, and if it was all 1s the
            exponent is incremented as well.
        """
        m.d.comb += out_z.copy(z)  # copies input to output first
        all_ones = z.m == z.m1s
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1)      # mantissa rounds up
            with m.If(all_ones):
                m.d.comb += out_z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state; the exponent is additionally
            set to N127 when z is flagged as denormalised.
        """
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)
        m.next = next_state

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)
        m.next = next_state

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
525
526