fix shift class syntax errors (untested)
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Variable-amount right shifter module.

        Combinatorially drives ``o`` with ``i`` shifted right (zeros shifted
        in at the top) by ``s`` bits.  ``i`` and ``o`` are ``width`` bits;
        ``s`` is just wide enough to encode every shift amount from 0 to
        ``width - 1``.
    """

    def __init__(self, width):
        self.width = width
        # number of bits needed to encode a shift amount of 0..width-1.
        # exact integer arithmetic: int(log(width) / log(2)) under-sizes
        # this for non-power-of-two widths (e.g. width=24 needs 5 bits, not 4)
        self.smax = (width - 1).bit_length()
        self.i = Signal(width, reset_less=True)      # value to shift
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shift helper.

        lshift() and rshift() shift an nmigen value by a variable amount
        using nmigen's native shift operators, then truncate the result back
        to the operand's own width: the returned value is always len(op)
        bits wide.

        ``smax`` is the number of bits needed to encode any shift amount
        from 0 to ``width - 1``.
    """

    def __init__(self, width):
        self.width = width
        # exact integer arithmetic; int(log(width) / log(2)) under-sizes
        # this for non-power-of-two widths (e.g. width=26 needs 5 bits)
        self.smax = (width - 1).bit_length()

    def lshift(self, op, s):
        """ returns op shifted left by s bits, truncated to len(op) bits
        """
        # a variable left shift widens the result, so cut it back down
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s bits, truncated to len(op) bits
        """
        return (op >> s)[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a floating-point number as a raw value signal (v) plus a
        sign (s), signed exponent (e) and mantissa (m), together with
        constants for the notable exponents and single-bit status signals
        (is_nan, is_zero, is_inf, ...) that elaborate() drives
        combinatorially from e and m.

        The constant names (P128, P127, N127, N126) are those of the
        single-precision case; for other widths the values scale with
        e_max.
    """
    def __init__(self, width, m_extra=True):
        """ width:   IEEE754 width in bits; must be 16, 32 or 64 (any other
                     value raises KeyError from the lookup tables below).
            m_extra: if True, the mantissa is widened by 3 extra low bits.
        """
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        # 1 << (IEEE exponent width - 1): e.g. 128 for single precision
        e_max = 1<<(e_width-3)
        self.rmw = m_width # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        #print (m_width, e_width, e_max, self.rmw, self.m_extra)
        self.m_width = m_width
        self.e_width = e_width
        # bit position in v where the exponent field starts; this is also
        # the width of the stored fraction field v[0:e_start]
        self.e_start = self.rmw - 1
        # one past the last exponent bit in v
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        self.e = Signal((e_width, True), reset_less=True) # Exponent: IEEE754exp+2 bits, signed
        self.s = Signal(reset_less=True)           # Sign bit

        self.mzero = Const(0, (m_width, False))   # all-zeros mantissa
        self.m1s = Const(-1, (m_width, False))    # all-ones mantissa
        # exponent used for inf/NaN (see _is_nan/_is_inf): +e_max
        self.P128 = Const(e_max, (e_width, True))
        # exponent bias / largest non-overflowing exponent: e_max-1
        self.P127 = Const(e_max-1, (e_width, True))
        # exponent used for zero (see _is_zero): -(e_max-1)
        self.N127 = Const(-(e_max-1), (e_width, True))
        # -(e_max-2): -126 for single precision
        self.N126 = Const(-(e_max-2), (e_width, True))

        # status flags, all driven combinatorially in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module driving every status flag from e and m
        """
        m = Module()
        # composite flags (built from the comparison flags below)
        m.d.comb += self.is_nan.eq(self._is_nan())
        m.d.comb += self.is_zero.eq(self._is_zero())
        m.d.comb += self.is_inf.eq(self._is_inf())
        m.d.comb += self.is_overflowed.eq(self._is_overflowed())
        m.d.comb += self.is_denormalised.eq(self._is_denormalised())
        # exponent comparisons against the special constants
        m.d.comb += self.exp_128.eq(self.e == self.P128)
        m.d.comb += self.exp_gt_n126.eq(self.e > self.N126)
        m.d.comb += self.exp_lt_n126.eq(self.e < self.N126)
        m.d.comb += self.exp_gt127.eq(self.e > self.P127)
        m.d.comb += self.exp_n127.eq(self.e == self.N127)
        m.d.comb += self.exp_n126.eq(self.e == self.N126)
        # mantissa tests
        m.d.comb += self.m_zero.eq(self.m == self.mzero)
        # NOTE(review): tests m[e_start]; when m_extra is set the mantissa
        # is m_extra bits wider, so this is not m[-1] - confirm which bit
        # is intended to be the "MSB" here.
        m.d.comb += self.m_msbzero.eq(self.m[self.e_start] == 0)

        return m

    def _is_nan(self):
        # exponent at +e_max with a non-zero mantissa
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        # exponent at +e_max with a zero mantissa
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # exponent at N127 with a zero mantissa
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        # exponent above P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent at N126 with the tested mantissa bit clear
        return (self.exp_n126) & (self.m_msbzero)

    def copy(self, inp):
        """ returns statements copying sign, exponent and mantissa from inp
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
142
143
class FPNumOut(FPNumBase):
    """ Floating-point output number.

        Inherits the sign / exponent / mantissa signals and status flags
        of FPNumBase, and adds helpers that pack fields back into the
        IEEE754 bit pattern held in ``v``:

        * create()            - an arbitrary sign / exponent / mantissa
        * nan(), inf(), zero() - the signed special values

        Every helper returns a list of nmigen assignment statements for the
        caller to add to a domain.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # no logic beyond the base-class status flags
        module = FPNumBase.elaborate(self, platform)
        return module

    def create(self, s, e, m):
        """ returns statements packing sign s, exponent e and mantissa m
            into v.  e is unbiased: the bias (P127) is added here.
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        frac_field = self.v[0:self.e_start]
        biased_exp = e + self.P127
        return [sign_field.eq(s),        # sign
                exp_field.eq(biased_exp), # exp (bias added)
                frac_field.eq(m),        # mantissa
               ]

    def nan(self, s):
        """ statements encoding a NaN of sign s (top fraction bit set) """
        top_fraction_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_fraction_bit)

    def inf(self, s):
        """ statements encoding infinity of sign s """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ statements encoding zero of sign s """
        return self.create(s, self.N127, 0)
183
184
class FPNumShiftMultiRight(FPNumBase):
    """ Combinatorial multi-bit right shift of a mantissa, with sticky bit.

        Drives ``m`` with ``inp.m`` shifted down by ``diff`` bits.  The
        amount is clamped to width-1, since at that point every bit above
        the sticky position has already been shifted out.  Bit 0 of the
        result is the sticky bit: the OR of inp.m[0] and every bit the
        shift discarded.

        The discarded bits are found by shifting a block of 1s (inp.m1s)
        down by the complementary amount (width-1 - shift).  That leaves a
        mask over exactly the low bits the main shift drops, which is
        ANDed with the mantissa and OR-reduced.

        Only the mantissa is produced; the exponent is not adjusted here.
        NOTE(review): FPNumBase.__init__ is not called, so only self.m is
        available as a signal of this object.

        inp:   number supplying the mantissa (inp.m) and the all-ones
               constant (inp.m1s); presumably FPNumBase-derived.
        diff:  shift amount (nmigen value).
        width: mantissa width, sticky bit included.
    """
    def __init__(self, inp, diff, width):
        self.m = Signal(width, reset_less=True)
        self.inp = inp
        self.diff = diff
        self.width = width

    def elaborate(self, platform):
        module = Module()
        #module.submodules.inp = self.inp

        # intermediate signals: nmigen names these after the locals, so
        # the local names are kept stable
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)

        shifter = MultiShift(self.width-1)
        limit = Const(self.width-1, len(self.diff))
        amount = Mux(self.diff > limit, limit, self.diff)
        complement = limit - amount

        upper_bits = self.inp.m[1:]     # everything above the sticky bit
        all_ones = self.inp.m1s[1:]

        # shifted mantissa
        module.d.comb += rs.eq(shifter.rshift(upper_bits, amount))
        # 1s shifted by the complement cover exactly the dropped bits
        module.d.comb += m_mask.eq(shifter.rshift(all_ones, complement))
        module.d.comb += smask.eq(upper_bits & m_mask)
        # sticky = any dropped bit, or the incoming sticky bit
        module.d.comb += stickybit.eq(smask.bool() | self.inp.m[0])
        # result: shifted mantissa above the new sticky bit
        module.d.comb += self.m.eq(Cat(stickybit, rs))
        return module
232
233
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        Mirrors the sign / exponent / mantissa of ``op``.  While ``mainm``
        is in its "align" state, and while this number's exponent is less
        than that of ``inv``, it shifts itself down one bit per clock via
        shift_down().

        NOTE(review): this class is untested.  elaborate() drives s/e/m
        combinatorially from op AND synchronously from shift_down(), and
        nmigen rejects a signal driven from two domains, so this still
        needs redesign before use.  It also opens mainm.State("align") but
        places the If in its own module m; confirm which module is meant.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm  # module providing the "align" FSM state
        self.inv = inv      # other operand whose exponent is aligned to
        self.op = op        # source of s / e / m

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # follow the source operand
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift this number itself down by one bit
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp: number whose e / m are shifted; the result goes to self.
        """
        return [self.e.eq(inp.e + 1),
                # the bit shifted out is ORed into the sticky bit
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros. it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert

        # OR of all shifted-out bits plus the existing sticky bit
        # (.bool() matches FPNumIn.shift_down_multi)
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the mantissa shift is clamped to m_width (beyond which the
            result is zero anyway); the exponent uses diff unclamped.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
304
class FPNumIn(FPNumBase):
    """ Floating-point input number.

        Splits a raw IEEE754 value into sign / exponent / mantissa
        (decode), and provides mantissa shift helpers that adjust the
        exponent to compensate.  All helpers return lists of nmigen
        assignment statements for the caller to add to a domain.

        The mantissa is 4 bits wider than the stored fraction: the top bit
        (m[-1]) acts as a carry / overflow bit, and the low three are
        guard (m[2]), round (m[1]) and sticky (m[0]).
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        module = FPNumBase.elaborate(self, platform)

        # input latching is currently disabled here:
        #module.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
        #with module.If(self.latch_in):
        #    module.d.sync += self.decode(self.v)

        return module

    def decode(self, v):
        """ returns statements splitting the raw value v into fields

            The exponent bias (P127) is subtracted here, on the widened
            (IEEE width + 2 bit) signed exponent, and the fraction is
            padded with m_extra zero bits at the bottom.
        """
        fraction = v[0:self.e_start]
        padded = Cat(*([0] * self.m_extra), fraction)
        unbiased = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(padded),   # mantissa
                self.e.eq(unbiased), # exp
                self.s.eq(v[-1]),    # sign
               ]

    def shift_down(self, inp):
        """ returns statements putting inp shifted down one bit into self

            The exponent goes up by one to compensate.  The bit shifted out
            of the bottom is ORed into the sticky bit, so some precision is
            lost but the guard bits keep track of it.
        """
        bumped_exp = inp.e + 1
        sticky = inp.m[0] | inp.m[1]
        return [self.e.eq(bumped_exp),
                self.m.eq(Cat(sticky, inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ returns statements putting inp (default: self) shifted down by
            diff bits into self, raising the exponent by diff

            The mantissa shift is clamped to m_width-1: past that, every bit
            above the sticky bit is gone anyway.  The sticky bit is the OR
            of the old sticky bit and every bit shifted out.  Those bits are
            found by shifting a block of 1s down by the complementary
            amount and using the result as a mask.
        """
        src = self if inp is None else inp
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.rshift(src.m[1:], amount)
        dropped_mask = shifter.rshift(self.m1s[1:], limit - amount)

        sticky = (src.m[1:] & dropped_mask).bool() | src.m[0]
        return [self.e.eq(src.e + diff),
                self.m.eq(Cat(sticky, shifted))
               ]

    def shift_up_multi(self, diff):
        """ returns statements shifting self's mantissa up by diff bits and
            lowering the exponent by diff

            Only the mantissa shift is clamped (to m_width).
        """
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifter = MultiShift(self.width)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, amount))
               ]
395
class FPOp:
    """ Handshaked value port.

        Carries a value ``v`` with a strobe (``stb``) and acknowledge
        (``ack``) pair.  ``trigger`` is a registered copy of stb & ack.
        The chain_* / copy helpers return lists of nmigen statements.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        module = Module()
        handshake = self.stb & self.ack
        module.d.sync += self.trigger.eq(handshake)
        return module

    def _chain(self, in_op, extra, ack_out):
        """ shared body of chain_inv / chain_from: take value and strobe
            (optionally gated by extra) from in_op and send ack_out back
        """
        if extra is None:
            stb = in_op.stb
        else:
            stb = in_op.stb & extra
        return [self.v.eq(in_op.v),    # receive value
                self.stb.eq(stb),      # receive STB
                in_op.ack.eq(ack_out), # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ chain from in_op, sending back the INVERTED ack """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ chain from in_op, sending back ack unchanged """
        return self._chain(in_op, extra, self.ack)

    def copy(self, inp):
        """ statements copying v / stb / ack from inp """
        return [self.v.eq(inp.v),
                self.stb.eq(inp.stb),
                self.ack.eq(inp.ack),
               ]

    def ports(self):
        """ the externally visible signals, in v, stb, ack order """
        return [self.v, self.stb, self.ack]
436
437
class Overflow:
    """ Rounding-state bits kept alongside a result mantissa.

        ``roundz`` is driven combinatorially as
        guard & (round_bit | sticky | m0), i.e. it is set when the guard
        bit is set and any of round, sticky or the mantissa LSB is set.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def copy(self, inp):
        """ statements copying guard / round / sticky / m0 from inp """
        return [self.guard.eq(inp.guard),
                self.round_bit.eq(inp.round_bit),
                self.sticky.eq(inp.sticky),
                self.m0.eq(inp.m0),
               ]

    def elaborate(self, platform):
        module = Module()
        lower_bits = self.round_bit | self.sticky | self.m0
        module.d.comb += self.roundz.eq(self.guard & lower_bits)
        return module
458
459
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Every method adds statements (and, where noted, FSM transitions via
        m.next) to the Module ``m`` it is given; none returns a value.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:          Module being built; m.next is set, so presumably
                        called inside an FSM state.
            op:         handshaked input with v / stb / ack (e.g. FPOp).
            v:          number object whose decode() splits op.v into
                        fields (e.g. FPNumIn).
            next_state: FSM state entered once the operand is taken.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                # op is latched in from FPNumIn class on same ack/stb
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not transferred yet: keep ack asserted
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            Each clock in which op.m[-1] is 0, the mantissa is shifted up
            one bit and the exponent decremented; once it is 1, moves to
            next_state.
            NOTE(review): there is no exponent bound here (unlike
            normalise_1), so an all-zero mantissa would never leave this
            state - presumably callers only get here with a non-zero
            mantissa; confirm.
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts z up one bit per clock while z.m[-1] is 0 and
            z.e > N126, pulling the guard bit (of) in at the bottom.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            # order matters: the z.m[0] assignment below overrides bit 0
            # of the whole-mantissa shift just before it
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard), # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts z down one bit per clock while z.e < N126; the bit
            shifted out becomes the guard bit and the old round bit is
            folded into sticky.
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, out_z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            out_z is driven (comb) as a copy of z; if roundz is set the
            mantissa is incremented, and if the mantissa was all 1s the
            exponent is incremented too.
        """
        m.d.comb += out_z.copy(z) # copies input to output first
        with m.If(roundz):
            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Always moves to next_state; if z.is_denormalised, z's exponent
            is set to N127 (the encoding used for zero / denormals).
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            Always moves to next_state; writes z.v (sync) via z.inf() when
            z.is_overflowed, otherwise via z.create().
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            # transfer done: drop stb and advance
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
591
592