Create single- and multi-shift cycles; single doesn't work, multi does
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter module.

        i: value to be shifted (width bits)
        s: shift amount (int(log(width)/log(2)) bits)
        o: result of i >> s (width bits)
    """

    def __init__(self, width):
        self.width = width
        # number of bits used for the shift-amount signal
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ wires the output to the input shifted right by s
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        Both lshift and rshift apply the operand's own shift operator
        and then truncate the result back to the bit-length of the
        operand, so the output is always exactly as wide as the input.

        Earlier code built the shift from a cascade of Muxes, one per
        bit of the shift amount; that code had become unreachable
        (it sat after an unconditional return) and has been removed.
        smax is still computed so that existing users of the attribute
        keep working.
    """

    def __init__(self, width):
        """ width: bit-width the shifter is created for

            smax is int(log(width)/log(2)).
        """
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a copy of the raw value (v) together with its sign (s),
        exponent (e) and mantissa (m), plus a set of single-bit status
        flags that elaborate() derives combinatorially from e and m.
    """
    def __init__(self, width, m_extra=True):
        """ width: 16, 32 or 64 (the lookup tables below have no other
                   entries, so anything else raises KeyError)
            m_extra: when true, 3 extra bits are added to the mantissa
        """
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1 << (e_width - 3)
        self.rmw = m_width # real mantissa width (not including extras)
        self.e_max = e_max
        # mantissa extra bits (top,guard,round): 3 if requested, else 0
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        #print (m_width, e_width, e_max, self.rmw, self.m_extra)
        self.m_width = m_width
        self.e_width = e_width
        # bit-range of the exponent field within v
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True) # Latched copy of value
        self.m = Signal(m_width, reset_less=True) # Mantissa
        # Exponent: IEEE754exp+2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True) # Sign bit

        # constants: all-zeros / all-ones mantissa, and exponent limits.
        # the P/N names match the values for width=32 (e_max == 128)
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        # status flags, all driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ creates a module which drives every status flag from e and m
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            self.exp_gt_n126.eq(self.e > self.N126),
            self.exp_lt_n126.eq(self.e < self.N126),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]

        return m

    def _is_nan(self):
        # exponent at P128 with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent at P128 with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent at N127 with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent at N126 and mantissa bit e_start clear
        return self.exp_n126 & self.m_msbzero

    def copy(self, inp):
        """ returns assignments copying sign, exponent and mantissa of inp
        """
        return [
            self.s.eq(inp.s),
            self.e.eq(inp.e),
            self.m.eq(inp.m),
        ]
142
143
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds encoding functions to FPNumBase: create() packs a
        sign / exponent / mantissa into the value v, and nan(), inf()
        and zero() pack the corresponding special values (all signed).

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # nothing beyond the base-class status flags
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mantissa_field = self.v[0:self.e_start]
        return [
            sign_field.eq(s),               # sign
            exp_field.eq(e + self.P127),    # exp (add on bias)
            mantissa_field.eq(m),           # mantissa
        ]

    def nan(self, s):
        """ signed NaN: exponent P128, top bit of the mantissa field set
        """
        top_mantissa_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_mantissa_bit)

    def inf(self, s):
        """ signed infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ signed zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)
183
184
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm: object providing State() (used for the "align" state)
        op: number whose sign / exponent / mantissa are mirrored here
        inv: the other number; its exponent is compared against ours
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # mirror the operand.  (previously referenced a bare "op", which
        # is not defined in this scope and raised NameError)
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down requires its source number: the call used
                # to be made with no argument (TypeError).
                # NOTE(review): self.op chosen as the source because it
                # is what this class mirrors - confirm.  also confirm
                # that driving e/m from both comb and sync is intended.
                m.d.sync += self.shift_down(self.op)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s right
            by (maximum shift - actual shift), which leaves 1s only in
            the low bits.  that is used directly as a mask (no inversion
            is performed) to get the LSBs of the mantissa, which are
            then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)      # clamp the shift amount
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # mask of low bits

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)      # clamp the shift amount

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
255
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Adds decoding of a raw value into sign / exponent / mantissa
        on top of FPNumBase, together with single-bit and multi-bit
        mantissa shift helpers.

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
        #with m.If(self.latch_in):
        #    m.d.sync += self.decode(self.v)

        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        # the extra low mantissa bits are filled with zeros
        padding = [0] * self.m_extra
        mantissa = Cat(*(padding + [v[0:self.e_start]]))
        #print ("decode", self.e_end)
        exponent = v[self.e_start:self.e_end] - self.P127
        return [
            self.m.eq(mantissa),    # mantissa
            self.e.eq(exponent),    # exp
            self.s.eq(v[-1]),       # sign
        ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # new bit 0 is the OR of the two lowest bits, the remainder moves
        # down one place and a zero is brought in at the top
        sticky = inp.m[0] | inp.m[1]
        shifted = Cat(sticky, inp.m[2:], 0)
        return [
            self.e.eq(inp.e + 1),
            self.m.eq(shifted),
        ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            the shift amount is limited to m_width-1: no point doing
            more (the result will still be zero).

            the sticky bit is computed with a mask made by shifting a
            batch of 1s right by (limit - shift amount).  the mask is
            ANDed with the mantissa (bit 0 excluded), the result is
            OR-reduced, and the old bit 0 is |'d in as well.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        upper_shifted = shifter.rshift(self.m[1:], amount)
        remaining = limit - amount
        lsb_mask = shifter.rshift(self.m1s[1:], remaining)

        sticky = reduce(or_, self.m[1:] & lsb_mask) | self.m[0]
        return [
            self.e.eq(self.e + diff),
            self.m.eq(Cat(sticky, upper_shifted)),
        ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is limited to m_width.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)

        return [
            self.e.eq(self.e - diff),
            self.m.eq(shifter.lshift(self.m, amount)),
        ]
343
class FPOp:
    """ Operand with handshake: a value (v) plus stb and ack signals
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()

    def chain_inv(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the ack sent back to in_op being the inverse of ours.

            extra, if given, is ANDed with the incoming stb
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [
            self.v.eq(in_op.v),          # receive value
            self.stb.eq(stb),            # receive STB
            in_op.ack.eq(~self.ack),     # send ACK
        ]

    def chain_from(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            our ack passed straight back to in_op.

            extra, if given, is ANDed with the incoming stb
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [
            self.v.eq(in_op.v),          # receive value
            self.stb.eq(stb),            # receive STB
            in_op.ack.eq(self.ack),      # send ACK
        ]

    def ports(self):
        """ returns the list of signals: value, stb, ack
        """
        return [self.v, self.stb, self.ack]
372
373
class Overflow:
    """ Holds the guard, round and sticky bits plus the mantissa's
        zero bit (m0), and derives the round-up indicator (roundz)
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ returns assignments copying guard, round, sticky and m0
            from inp (roundz is not copied)
        """
        return [
            self.guard.eq(inp.guard),
            self.round_bit.eq(inp.round_bit),
            self.sticky.eq(inp.sticky),
            self.m0.eq(inp.m0),
        ]

    def elaborate(self, platform):
        """ roundz is set when guard is set and at least one of
            round, sticky or m0 is also set
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
394
395
class FPBase:
    """ IEEE754 Floating Point Base Class

        A collection of state-machine helper functions for FP
        manipulation: fetching operands, denormalisation,
        normalisation, rounding, corrections, packing and output.
        Each takes the module (m) to which statements are added.
    """

    def get_op(self, m, op, v, next_state):
        """ when both ack and stb of the operand are 1, decodes the
            operand's value into v, moves to the next state and sets
            ack to ZERO (which is how acknowledgement is sent).
            otherwise, ack is set to 1.
        """
        ready = (op.ack) & (op.stb)
        with m.If(ready):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0),
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_minimum = a.e == a.N127
        with m.If(at_minimum):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: just like "align", this one keeps going round every clock:
                  whilst the top mantissa bit is clear the mantissa is
                  shifted up and the exponent decremented; once it is
                  set, the state machine moves on.
        """
        top_clear = (op.m[-1] == 0) # check last bit of mantissa
        with m.If(top_clear):
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        keep_going = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_going):
            # order matters: the z.m[0] assignment must come after the
            # whole-mantissa shift so that it takes precedence for bit 0
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        below_range = z.e < z.N126
        with m.If(below_range):
            m.d.sync += [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            out_z starts as a copy of z.  if roundz is set the mantissa
            is incremented, and if the mantissa was all 1s the exponent
            is incremented too.
        """
        m.d.comb += out_z.copy(z) # copies input to output first
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
            all_ones = z.m == z.m1s
            with m.If(all_ones):
                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
527
528