add FPFormat class to describe floating-point formats without all the nmigen signals
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12 import math
13
14
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute exponent_width: the number of bits in the exponent field.
    :attribute mantissa_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)

    ``__eq__`` is defined without ``__hash__``, so instances are not
    hashable.
    """

    def __init__(self,
                 exponent_width,
                 mantissa_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.exponent_width = exponent_width
        self.mantissa_width = mantissa_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality of all four format parameters. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.exponent_width == other.exponent_width
                and self.mantissa_width == other.mantissa_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises TypeError: if ``width`` is not an ``int``.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a multiple
            of 32 above 128 (capped at an arbitrary 1000000 bits).
        """
        if not isinstance(width, int):
            raise TypeError(
                f"width must be an int, not {type(width).__name__}")
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            # IEEE 754-2008 binary{k} formula for k >= 128:
            # exponent width w = round(4 * log2(k)) - 13
            exponent_width = round(4 * math.log2(width)) - 13
            return FPFormat(exponent_width, width - 1 - exponent_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get a repr that evaluates back to an equal ``FPFormat``. """
        try:
            if self == FPFormat.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except (TypeError, ValueError):
            # not a standard width (or beyond the supported limit):
            # fall back to the explicit constructor form
            pass
        retval = f"FPFormat({self.exponent_width}, {self.mantissa_width}"
        # keyword arguments so that has_sign can never be mistaken for the
        # positional has_int_bit parameter
        if self.has_int_bit is not False:
            retval += f", has_int_bit={self.has_int_bit}"
        if self.has_sign is not True:
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.exponent_width + self.mantissa_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.exponent_width) - 1

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.exponent_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.mantissa_width - self.has_int_bit
119
120
class MultiShiftR:
    """ Combinatorial right shifter: drives ``o`` with ``i >> s``.

    ``i`` and ``o`` are ``width`` bits wide.  The shift amount ``s`` is
    ``int(log2(width))`` bits wide (stored as ``smax``).
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width, 2))
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        module = Module()
        shifted = self.i >> self.s
        module.d.comb += self.o.eq(shifted)
        return module
134
135
class MultiShift:
    """ Variable-length single-cycle shift helpers.

    The intended design is a cascade of conditional tests, one per bit of
    the shift operand: bit 1 shifts by 1, bit 2 by 2, and so on, each
    partial result feeding the next Mux.  The current implementation
    simply applies the ``<<`` / ``>>`` operators and truncates the result
    back to the operand's length.

    Could be adapted to do arithmetic shift by taking copies of the
    MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width, 2))

    def lshift(self, op, s):
        """ ``op`` shifted left by ``s``, truncated to ``len(op)`` bits. """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ ``op`` shifted right by ``s``, truncated to ``len(op)`` bits. """
        return (op >> s)[:len(op)]
158
159
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Holds the nmigen signals of one FP number: the packed value ``v``
    and its unpacked sign ``s``, signed (unbiased) exponent ``e`` and
    mantissa ``m``.  It also holds the field layout and the constants
    used to encode/decode it.

    :param width: packed bit-width; only 16, 32 and 64 are supported
        (any other value raises ``KeyError``).
    :param m_extra: if true, the mantissa gets 3 extra bits.
    :param e_extra: if true, the exponent gets 6 extra bits.
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        # unbiased exponent used for Inf/NaN (e.g. 128 for FP32)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Copy this record's signals and layout attributes onto ``fp``.

        Also (re)creates the encoding constants ``mzero``, ``msb1``,
        ``m1s``, ``P128``, ``P127``, ``N127`` and ``N126``.  These are set
        on *this* record, not on ``fp``, so objects the record was dropped
        into must reach them through their ``.fp`` attribute.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))  # all-zero mantissa
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))  # all-ones mantissa
        self.P128 = Const(e_max, (e_width, True))  # Inf/NaN exponent
        self.P127 = Const(e_max-1, (e_width, True))  # bias
        self.N127 = Const(-(e_max-1), (e_width, True))  # zero/denormal exp
        self.N126 = Const(-(e_max-2), (e_width, True))  # min normal exponent

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        NOTE: order is important, because e_start/e_end can be
        a bit too long (overwriting s).

        :returns: list of statements assigning slices of ``self.v``.
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        :returns: a ``Cat`` of the low ``e_start`` bits of ``m``, the packed
            exponent field and ``s``, exactly ``width`` bits wide.
        """
        # constants live on the record: go through .fp so this also works
        # on objects the record was dropped into
        e = e + self.fp.P127  # exp (add on bias)
        # packed exponent field = everything between mantissa and sign.
        # (e_end - e_start would include e_extra and push the sign out.)
        e_field_width = self.width - 1 - self.e_start
        return Cat(m[0:self.e_start],
                   e[0:e_field_width],
                   s)

    def nan2(self, s):
        # quiet-NaN bit = top bit of the packed fraction (as in _nan);
        # msb1 lies above the packed fraction when m_extra is set
        qnan = Const(1 << (self.e_start - 1), self.e_start)
        return self.create2(s, self.fp.P128, qnan)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
281
282
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    The signals of an ``FPNumBaseRecord`` are dropped into this object,
    and the record itself is kept as ``self.fp``.  On top of those signals
    it adds combinatorial classification flags (NaN, zero, Inf, overflow,
    denormalised).  The flags are computed from the unbiased exponent
    ``e`` and the mantissa ``m``.
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        exp_shape = (fp.e_width, True)  # signed, record's exponent width

        # final classification flags
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)

        # intermediate exponent comparisons
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal(exp_shape, reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)

        # intermediate mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        fp = self.fp

        # classification, built from the comparison signals below
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
        ]

        # exponent against the special (unbiased) exponent constants
        m.d.comb += [
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
        ]

        # mantissa tests.
        # NOTE(review): m_msbzero tests bit e_start of m.  That is the top
        # bit only when m_extra == 0; with m_extra == 3 the top bit is m[-1].
        # Confirm this is intended.
        m.d.comb += [
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]

        return m

    def _is_nan(self):
        """ Inf/NaN exponent with a non-zero mantissa. """
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        """ Inf/NaN exponent with a zero mantissa. """
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        """ Zero/denormal exponent with a zero mantissa. """
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        """ Exponent above the largest normal exponent. """
        return self.exp_gt127

    def _is_denormalised(self):
        """ Minimum-normal exponent with the tested mantissa bit clear. """
        return self.exp_n126 & self.m_msbzero
340
341
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    Holds an outgoing value together with its sign / exponent / mantissa
    decomposition.  It also provides the encoding helpers that create and
    recognise (signed) zero, NaN and Inf.  No logic is added on top of
    ``FPNumBase``.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
362
363
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
    m[0] is the "sticky" bit, basically

    Ports:

    * ``inp``: ``width``-bit input.  ``inp[0]`` is the incoming sticky
      bit and ``inp[1:]`` is the value to be shifted.
    * ``diff``: ``s_max``-bit shift amount.
    * ``m``: ``width``-bit result.  It is ``inp[1:]`` shifted right by
      ``min(diff, width-1)``, with ``m[0]`` set to the OR of every bit
      shifted out and of ``inp[0]``.
    """

    def __init__(self, width, s_max=None):
        # NOTE(review): the default s_max is int(log2(width)).  For a width
        # that is not a power of two this cannot express every shift up to
        # width-1, and ``mw`` in elaborate() (width-1 as a len(diff)-bit
        # Const) may not fit.  Confirm callers pass s_max explicitly.
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)  # inp[1:] shifted right
        m_mask = Signal(self.width, reset_less=True)  # 1s where bits fell out
        smask = Signal(self.width, reset_less=True)  # the bits that fell out
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)  # min(diff, width-1)
        maxsleni = Signal(self.smax, reset_less=True)  # (width-1) - maxslen

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, len(self.diff))
        # clamp the shift to the shiftable width; shifting further would
        # only produce zeros
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            # all-ones shifted right by (width-1 - maxslen) leaves the low
            # maxslen bits set, i.e. exactly the positions shifted out
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
406
407
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    :param mainm: module whose ``State("align")`` context is entered
        in ``elaborate``.
    :param op: number whose ``s``/``e``/``m`` are copied into this one.
    :param inv: other number; this one is shifted down while its exponent
        is below ``inv.e``.
    :param width: packed width (16/32/64) or an existing
        ``FPNumBaseRecord``.
    :param m_extra: passed to ``FPNumBaseRecord`` when ``width`` is an int.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase.__init__ takes a record, not (width, m_extra)
        if isinstance(width, FPNumBaseRecord):
            fp = width
        else:
            fp = FPNumBaseRecord(width, m_extra)
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e / self.m are driven combinatorially above
        # and synchronously by shift_down below.  nmigen rejects a signal
        # driven from two domains, so this class appears unfinished.
        # Confirm the intended structure before using it.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the constants live on the record (see FPNumBaseRecord.drop_in)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
479
480
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoder)

    Combinatorially splits the packed value ``v`` into sign / exponent /
    mantissa, and also inherits the encoding helpers that create and
    recognise (signed) zero, NaN and Inf.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ Return statements that unpack ``v`` into ``s``, ``e`` and ``m``.

        The mantissa gets ``m_extra`` zero bits below the stored fraction.
        The bias (``P127``) is subtracted from the exponent field here; the
        result is assigned to the signed ``e`` signal.
        """
        low_zeros = [0] * self.m_extra
        fraction = v[0:self.e_start]
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        return [
            self.m.eq(Cat(*low_zeros, fraction)),  # mantissa
            self.e.eq(exponent),                   # exp
            self.s.eq(v[-1]),                      # sign
        ]
518
519
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    The encoding constants (``P127``, ``m1s``, ...) live on the record,
    ``self.fp``, not on this object.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes the latched value ``self.v`` into sign / exponent /
        mantissa

        bias is subtracted here, from the exponent.

        :param m: object wrapped by the returned ``ObjectProxy``.
        :returns: an ``ObjectProxy`` with ``m``, ``e`` and ``s`` set to the
            decoded mantissa (``m_extra`` zero bits below the fraction),
            unbiased exponent and sign expressions.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

        bias is subtracted here, from the exponent; the result is assigned
        to the signed, ``e_width``-bit ``e`` (10 bits for FP32).

        :returns: list of statements assigning ``m``, ``e`` and ``s``.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # P127 is on the record (self.fp), as in decode2
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.

        :param diff: shift amount (nmigen value).
        :param inp: number to take ``e``/``m`` from (defaults to ``self``).
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # all-ones constant lives on the record (self.fp)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
618
619
class Trigger(Elaboratable):
    """ Strobe/acknowledge pair with a combinatorial ``trigger`` output
    that is high when both ``stb`` and ``ack`` are high.
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_high = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_high)
        return m

    def eq(self, inp):
        """ Statements copying ``stb`` and ``ack`` from ``inp``. """
        pairs = ((self.stb, inp.stb), (self.ack, inp.ack))
        return [dst.eq(src) for dst, src in pairs]

    def ports(self):
        return [self.stb, self.ack]
639
640
class FPOpIn(PrevControl):
    """ Input-side handshake: ``v`` aliases ``data_i``.

    ``chain_from``/``chain_inv`` return statements that pull value and
    strobe from another operand and send back this side's ack (inverted,
    for ``chain_inv``).
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def _chain_stmts(self, in_op, ack_out, extra):
        """ Shared body of chain_inv / chain_from. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),     # receive value
                self.stb.eq(stb),       # receive STB
                in_op.ack.eq(ack_out),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        return self._chain_stmts(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        return self._chain_stmts(in_op, self.ack, extra)
667
668
class FPOpOut(NextControl):
    """ Output-side handshake: ``v`` aliases ``data_o``.

    ``chain_from``/``chain_inv`` return statements that pull value and
    strobe from another operand and send back this side's ack (inverted,
    for ``chain_inv``).
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def _chain_stmts(self, in_op, ack_out, extra):
        """ Shared body of chain_inv / chain_from. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),     # receive value
                self.stb.eq(stb),       # receive STB
                in_op.ack.eq(ack_out),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        return self._chain_stmts(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        return self._chain_stmts(in_op, self.ack, extra)
695
696
class Overflow:  # (Elaboratable):
    """ Guard / round / sticky / mantissa-LSB bits used for rounding.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")  # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")  # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

    def __iter__(self):
        for attr in ("guard", "round_bit", "sticky", "m0"):
            yield getattr(self, attr)

    def eq(self, inp):
        """ Statements copying all four bits from ``inp``. """
        return [getattr(self, attr).eq(getattr(inp, attr))
                for attr in ("guard", "round_bit", "sticky", "m0")]

    @property
    def roundz(self):
        """ Round-up condition: guard set and any of round/sticky/m0 set. """
        tail = self.round_bit | self.sticky | self.m0
        return self.guard & tail
723
724
class OverflowMod(Elaboratable, Overflow):
    """ ``Overflow`` plus a ``roundz_out`` signal driven combinatorially
    from the ``roundz`` expression.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ Statements copying ``roundz_out`` and the four Overflow bits
        from ``inp``.
        """
        # Overflow.eq needs the source object too (was called without it)
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
743
744
class FPBase:
    """ IEEE754 Floating Point Base Class

    contains common functions for FP manipulation, such as
    extracting and packing operands, normalisation, denormalisation,
    rounding etc.

    Each method adds statements and/or FSM transitions (``m.next``) to the
    nmigen module ``m`` it is given.  Because of ``m.next``, the methods
    that use it must be called from within an FSM state of ``m``.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
        when both stb and ack are 1.
        acknowledgement is sent by setting ack to ZERO.

        In this code the handshake test is ``op.ready_o & op.valid_i_test``.

        :param m: module / FSM being built.
        :param op: handshake object providing ``ready_o``/``valid_i_test``.
        :param v: number object whose ``decode2(m)`` decodes the operand.
        :param next_state: FSM state to enter when the handshake fires.
        :returns: ``[res, ack]``.  ``res`` is the result of ``v.decode2(m)``;
            ``ack`` is a new comb signal that is 0 when the handshake fires
            and 1 otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
        this function. for normalised numbers (exponent != minimum)
        one *extra* bit (the implicit 1) is added *back in*.
        for denormalised numbers, the mantissa is left alone
        and the exponent increased by 1.

        both cases *effectively multiply the number stored by 2*,
        which has to be taken into account when extracting the result.

        :param a: number to adjust: reads ``a.exp_n127``, writes ``a.e`` or
            the top bit of ``a.m`` on the sync domain.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
        NOTE: just like "align", this one keeps going round every clock
        until the result's exponent is within acceptable "range"

        Each clock in which the top mantissa bit is 0, the mantissa is
        shifted up by one and the exponent decremented.  Once the top bit
        is set, the FSM moves to ``next_state``.
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1), # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

        NOTE: just like "align", this one keeps going round every clock
        until the result's exponent is within acceptable "range"
        NOTE: the weirdness of reassigning guard and round is due to
        the extra mantissa bits coming from tot[0..2]

        Shifts up while the top mantissa bit of ``z`` is 0 and ``z.e`` is
        above ``N126``; otherwise moves to ``next_state``.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # sync domain: every right-hand side below reads the value from
            # before this clock edge.  The z.m[0] assignment comes after the
            # full z.m shift and takes priority for bit 0.
            m.d.sync += [
                z.e.eq(z.e - 1), # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard), # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0), # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

        NOTE: just like "align", this one keeps going round every clock
        until the result's exponent is within acceptable "range"
        NOTE: the weirdness of reassigning guard and round is due to
        the extra mantissa bits coming from tot[0..2]

        Shifts down while ``z.e`` is below ``N126``.  The bits shifted out
        move into guard/round, and the old round bit is ORed into sticky.
        Otherwise moves to ``next_state``.
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1), # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

        :param roundz: condition under which the mantissa is rounded up.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            # compares the pre-increment mantissa (sync domain): an all-1s
            # mantissa wraps to 0, so the exponent is bumped instead
            with m.If(z.m == z.fp.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

        Always moves to ``next_state``; if ``z.is_denormalised``, sets the
        exponent to ``N127`` (the zero/denormal exponent).
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

        Always moves to ``next_state``; writes ``z.inf(z.s)`` if
        ``z.is_overflowed``, otherwise ``z.create(z.s, z.e, z.m)``.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
        for ack to be set to 1 before moving to the next state.
        resets stb back to zero when that occurs, as acknowledgement.

        Here stb is ``out_z.valid_o`` and ack is ``out_z.ready_i_test``.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
875
876
class FPState(FPBase):
    """ ``FPBase`` helper that stores ``state_from`` and exposes the entries
    of its input/output dictionaries as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def _copy_to_attrs(self, mapping):
        """ Set one attribute per ``mapping`` entry (key -> value). """
        for attr_name, attr_value in mapping.items():
            setattr(self, attr_name, attr_value)

    def set_inputs(self, inputs):
        self.inputs = inputs
        self._copy_to_attrs(inputs)

    def set_outputs(self, outputs):
        self.outputs = outputs
        self._copy_to_attrs(outputs)
890
891
class FPID:
    """ Optional ID signals carried alongside an FP operation.

    :param id_wid: bit-width of the ID signals.  If falsy (``None`` or
        0), no signals are created and ``in_mid``/``out_mid`` are ``None``.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ Add ``out_mid <= in_mid`` to the sync domain of ``m``.

        Does nothing when no ID signals exist.
        """
        # test the signal itself, not ``id_wid is not None``: id_wid == 0
        # creates no signals in __init__
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
903 if self.id_wid is not None:
904 m.d.sync += self.out_mid.eq(self.in_mid)