create FPDecode module
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial right-shifter module: o = i >> s.

        i and o are `width` bits wide; the shift amount s is
        int(log2(width)) bits wide.
    """

    def __init__(self, width):
        self.width = width
        # number of bits in the shift-amount input
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)      # value to be shifted
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        """ builds the module: a single combinatorial assignment to o
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        Shifting is delegated to the operand's own << and >> operators;
        the result is then truncated to the length of the operand, so
        the output is never wider than the input.

        width and smax (int(log2(width))) are recorded on the instance
        but are not consulted by lshift / rshift.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted up (left) by s, cut down to len(op) entries
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down (right) by s, cut down to len(op) entries
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds the sign / exponent / mantissa signals for one number of
        the given width (16, 32 or 64), a set of exponent constants,
        and combinatorial status flags derived from e and m.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # per-format field widths: any other width raises KeyError
        rmw = {16: 11, 32: 24, 64: 53}[width]     # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width - 3)
        self.rmw = rmw  # real mantissa width (not including extras)
        self.e_max = e_max
        # mantissa extra bits (top,guard,round) are optional
        self.m_extra = 3 if m_extra else 0
        m_width = rmw + self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        # bit positions of the exponent field inside v (for decoding)
        self.e_start = rmw - 1
        self.e_end = rmw + e_width - 3

        self.v = Signal(width, reset_less=True)    # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        # Exponent: IEEE754exp+2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)           # Sign bit

        # mantissa constants: all-zeros and all-ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # exponent constants, named after their single-precision values
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        # status flags, all driven combinatorially by elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ builds the module that drives every status flag from e and m
        """
        m = Module()
        m.d.comb += [
            # number classification
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            # exponent comparisons against the constants
            self.exp_128.eq(self.e == self.P128),
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            # mantissa tests
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent == P128 with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent == P128 with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent == N127 with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent == N126 and mantissa bit e_start clear
        return self.exp_n126 & self.m_msbzero

    def eq(self, inp):
        """ returns assignments copying s, e and m from inp
        """
        copies = [self.s.eq(inp.s)]
        copies.append(self.e.eq(inp.e))
        copies.append(self.m.eq(inp.m))
        return copies
144
145
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds to FPNumBase the functions that assemble a packed value
        in self.v from sign / exponent / mantissa, plus shortcuts for
        creating NaN, inf and zero (all signed).
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # nothing extra to add: the base-class module is returned as-is
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        sign_bit = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        return [
            sign_bit.eq(s),               # sign
            exp_field.eq(e + self.P127),  # exp (add on bias)
            man_field.eq(m),              # mantissa
        ]

    def nan(self, s):
        """ NaN: exponent P128, top bit of the mantissa field set
        """
        top_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_bit)

    def inf(self, s):
        """ infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)
185
186
class MultiShiftRMerge:
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp is the value to shift, diff the shift amount and m the
        result.  diff is s_max bits wide; when s_max is not given it
        defaults to int(log2(width)).
    """
    def __init__(self, width, s_max=None):
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.width = width
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)

    def elaborate(self, platform):
        """ builds the shift-and-merge logic (purely combinatorial)
        """
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width - 1)
        m0s = Const(0, self.width - 1)
        mw = Const(self.width - 1, len(self.diff))
        upper = self.inp[1:]  # everything above the low bit

        m.d.comb += [
            # clamp the shift amount to mw, and compute its complement
            maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
            maxsleni.eq(Mux(self.diff > mw, 0, mw - self.diff)),
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(upper, maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
228
229
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm is the object whose State("align") context is entered
        during elaboration, op is the number copied into this one, and
        inv is the number whose exponent is compared against.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        """ copies op into s/e/m and, in the "align" state, shifts down
            by one for as long as e is less than inv.e
        """
        m = FPNumBase.elaborate(self, platform)

        # op is an instance attribute: a bare "op" is not defined here
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): e and m are assigned in both the comb domain
        # (above) and the sync domain (below) - confirm this is intended.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp is the number to read from; when omitted this number
            is shifted in place.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s down by
            (maximum - shift amount), which leaves ones only in the low
            positions.  that is used as a mask to get the LSBs of the
            mantissa, which are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)  # clamp the shift amount
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)

        # same reduction as FPNumIn.shift_down_multi: any masked bit set
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount applied to the mantissa is clamped to
            m_width; the exponent is decreased by the full diff.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
300
301
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoder)

        Combinatorially decodes the packed value held in self.v into
        sign / exponent / mantissa.  op is stored on the instance.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.op = op

    def elaborate(self, platform):
        """ base-class flags plus the combinatorial decode of self.v
        """
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias (P127) is subtracted here, from the exponent.  the
            mantissa is padded at the bottom with m_extra zero bits.
        """
        pad = [0] * self.m_extra  # pad with extra zeros
        mantissa = Cat(*pad, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),   # mantissa
                self.e.eq(exponent),   # exp
                self.s.eq(v[-1]),      # sign
               ]
338
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Provides decoding of a packed value into sign / exponent /
        mantissa, plus the mantissa shift helpers (down by one, down by
        a variable amount with sticky-bit merge, up by a variable amount).
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias (P127) is subtracted here, from the exponent.  the
            mantissa is padded at the bottom with m_extra zero bits.
        """
        pad = [0] * self.m_extra  # pad with extra zeros
        mantissa = Cat(*pad, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),   # mantissa
                self.e.eq(exponent),   # exp
                self.s.eq(v[-1]),      # sign
               ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the two lowest bits of inp.m are OR-ed into the new bit 0
            and a zero is shifted in at the top.
        """
        merged_low = inp.m[0] | inp.m[1]
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(merged_low, inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            the mantissa is variable-shifted by at most m_width-1: no
            point doing more (it'll still be zero).

            the sticky bit is computed by shifting a batch of 1s down by
            (m_width-1 - shift amount), which leaves ones only in the low
            positions.  that is used as a mask to get the LSBs of the
            mantissa, which are then |'d into the sticky bit.

            inp is the number to read from (defaults to this number).
        """
        src = self if inp is None else inp
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)  # clamped shift amount
        shifted = shifter.rshift(src.m[1:], amount)
        mask = shifter.rshift(self.m1s[1:], limit - amount)

        sticky = (src.m[1:] & mask).bool() | src.m[0]
        return [self.e.eq(src.e + diff),
                self.m.eq(Cat(sticky, shifted))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount applied to the mantissa is clamped to
            m_width; the exponent is decreased by the full diff.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, amount))
               ]
420
class Trigger:
    """ strobe / acknowledge pair, with a combinatorial "trigger"
        output that is set when both stb and ack are set.
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        """ builds the module driving trigger from stb & ack
        """
        m = Module()
        both = self.stb & self.ack
        m.d.comb += self.trigger.eq(both)
        return m

    def eq(self, inp):
        """ returns assignments copying stb and ack from inp
        """
        copies = [self.stb.eq(inp.stb)]
        copies.append(self.ack.eq(inp.ack))
        return copies

    def ports(self):
        """ returns the list of signals: stb, ack
        """
        return [self.stb, self.ack]
440
441
class FPOp(Trigger):
    """ an operand: a value v of the given width together with the
        stb / ack handshake inherited from Trigger.
    """
    def __init__(self, width):
        Trigger.__init__(self)
        self.width = width

        self.v = Signal(width)

    def chain_inv(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledge sent back to in_op inverted.

            extra, when given, is AND-ed into the incoming strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
               ]

    def chain_from(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledge sent back to in_op unchanged.

            extra, when given, is AND-ed into the incoming strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
               ]

    def eq(self, inp):
        """ returns assignments copying v, stb and ack from inp
        """
        copies = [self.v.eq(inp.v)]
        copies.append(self.stb.eq(inp.stb))
        copies.append(self.ack.eq(inp.ack))
        return copies

    def ports(self):
        """ returns the list of signals: v, stb, ack
        """
        return [self.v, self.stb, self.ack]
475
476
class Overflow:
    """ holds the guard / round / sticky bits and the mantissa low bit,
        and derives the "roundz" flag from them.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def eq(self, inp):
        """ returns assignments copying guard, round_bit, sticky and m0
            from inp (roundz is not copied)
        """
        copies = [self.guard.eq(inp.guard)]
        copies.append(self.round_bit.eq(inp.round_bit))
        copies.append(self.sticky.eq(inp.sticky))
        copies.append(self.m0.eq(inp.m0))
        return copies

    def elaborate(self, platform):
        """ roundz = guard AND (round_bit OR sticky OR m0)
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
497
498
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every function takes the module (m) being built and adds
        statements to it; those that take next_state also set m.next.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.
        """
        handshake = op.ack & op.stb
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # limit a exponent
            m.d.sync += a.e.eq(a.N126)
        with m.Else():
            # set top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        top_clear = op.m[-1] == 0  # check last bit of mantissa
        with m.If(top_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        keep_going = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_going):
            # order matters: z.m[0] is assigned after the whole of z.m
            m.d.sync += z.e.eq(z.e - 1)             # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)            # shift mantissa UP
            m.d.sync += z.m[0].eq(of.guard)         # steal guard (was tot[2])
            m.d.sync += of.guard.eq(of.round_bit)   # steal round (was tot[1])
            m.d.sync += of.round_bit.eq(0)          # reset round bit
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        below_min = z.e < z.N126
        with m.If(below_min):
            m.d.sync += z.e.eq(z.e + 1)   # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)  # shift mantissa DOWN
            # bits falling off the bottom ripple through guard/round/sticky
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        with m.If(roundz):
            # mantissa rounds up
            m.d.sync += z.m.eq(z.m + 1)
            all_ones = z.m == z.m1s
            with m.If(all_ones):
                # exponent rounds up
                m.d.sync += z.e.eq(z.e + 1)

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
629
630
class FPState(FPBase):
    """ a state: records the name it is entered from, and lets its
        inputs / outputs be attached as attributes.
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores the inputs dict, and also sets each of its entries
            as an attribute on this object (key -> attribute name)
        """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ stores the outputs dict, and also sets each of its entries
            as an attribute on this object (key -> attribute name)
        """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
644
645
class FPID:
    """ optional "id" that is passed from input to output.

        id_wid is the width of the id signals.  when it is None (or
        zero) no signals are created: in_mid and out_mid are None and
        idsync does nothing.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync assignment copying in_mid to out_mid to module m

            uses the same test as __init__, so that nothing is added
            when the signals were never created (id_wid None or 0).
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
659
660