add FPFormat.get_exponent_value to get an unbiased exponent corrected for subnormals
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 """IEEE754 Floating Point Library
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
5
6 """
7
8
9 from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
10 Value, Shape, signed, unsigned)
11 from nmigen.utils import bits_for
12 from operator import or_
13 from functools import reduce
14
15 from nmutil.singlepipe import PrevControl, NextControl
16 from nmutil.pipeline import ObjectProxy
17 import unittest
18 import math
19 import enum
20
# smtlib2 support is optional: it is only present in nmigen builds that
# ship ``nmigen.hdl.smtlib2``.  When it is missing, FPRoundingMode does not
# get a ``to_smtlib2`` method.
try:
    from nmigen.hdl.smtlib2 import RoundingModeEnum
except ImportError:
    _HAVE_SMTLIB2 = False
else:
    _HAVE_SMTLIB2 = True

# unique sentinel so FPRoundingMode.to_smtlib2 can tell "no default given"
# apart from any real default value (including None)
_raise_err = object()
29
30
class FPRoundingMode(enum.Enum):
    """ Floating-point rounding modes.

    Values 0b00..0b11 match the OpenPower ``FPSCR.RN`` field; values
    >= 0b100 are extra modes used by miscellaneous instructions.
    Member names follow smtlib2.  The comment above each member gives the
    OpenPower ISA specification's name (v3.1 section 7.3.2.6 -- matches
    the values in section 4.3.6).  The long-form names (e.g.
    ``ROUND_TOWARDS_ZERO``) are aliases of the same members.
    """

    # Round to Nearest Even
    #
    # Rounds to the nearest representable floating-point number, ties are
    # rounded to the number with the even mantissa. Treats +-Infinity as if
    # it were a normalized floating-point number when deciding which number
    # is closer when rounding. See IEEE754 spec. for details.
    RNE = 0b00
    ROUND_NEAREST_TIES_TO_EVEN = RNE
    DEFAULT = RNE

    # Round towards Zero
    #
    # If the result is exactly representable as a floating-point number,
    # return that, otherwise return the nearest representable floating-point
    # value with magnitude smaller than the exact answer.
    RTZ = 0b01
    ROUND_TOWARDS_ZERO = RTZ

    # Round towards +Infinity
    #
    # If the result is exactly representable as a floating-point number,
    # return that, otherwise return the nearest representable floating-point
    # value that is numerically greater than the exact answer. This can
    # round up to +Infinity.
    RTP = 0b10
    ROUND_TOWARDS_POSITIVE = RTP

    # Round towards -Infinity
    #
    # If the result is exactly representable as a floating-point number,
    # return that, otherwise return the nearest representable floating-point
    # value that is numerically less than the exact answer. This can round
    # down to -Infinity.
    RTN = 0b11
    ROUND_TOWARDS_NEGATIVE = RTN

    # Round to Nearest Away
    #
    # Rounds to the nearest representable floating-point number, ties are
    # rounded to the number with the maximum magnitude. Treats +-Infinity as
    # if it were a normalized floating-point number when deciding which
    # number is closer when rounding. See IEEE754 spec. for details.
    RNA = 0b100
    ROUND_NEAREST_TIES_TO_AWAY = RNA

    # Round to Odd, unsigned zeros are Positive
    #
    # Not in smtlib2.
    #
    # If the result is exactly representable as a floating-point number,
    # return that, otherwise return the nearest representable floating-point
    # value that has an odd mantissa.
    #
    # If the result is zero but with otherwise undetermined sign
    # (e.g. `1.0 - 1.0`), the sign is positive.
    #
    # This rounding mode is used for instructions with Round To Odd enabled,
    # and `FPSCR.RN != RTN`.
    #
    # This is useful to avoid double-rounding errors when doing arithmetic
    # in a larger type (e.g. f128) but where the answer should be a smaller
    # type (e.g. f80).
    RTOP = 0b101
    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE = RTOP

    # Round to Odd, unsigned zeros are Negative
    #
    # Not in smtlib2.
    #
    # If the result is exactly representable as a floating-point number,
    # return that, otherwise return the nearest representable floating-point
    # value that has an odd mantissa.
    #
    # If the result is zero but with otherwise undetermined sign
    # (e.g. `1.0 - 1.0`), the sign is negative.
    #
    # This rounding mode is used for instructions with Round To Odd enabled,
    # and `FPSCR.RN == RTN`.
    #
    # This is useful to avoid double-rounding errors when doing arithmetic
    # in a larger type (e.g. f128) but where the answer should be a smaller
    # type (e.g. f80).
    RTON = 0b110
    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE = RTON

    @staticmethod
    def make_array(f):
        """ build an nmigen ``Array`` whose entry at index ``rm.value`` is
        ``f(rm)``, for every rounding mode ``rm``.
        """
        # member values are the contiguous range 0 .. len-1, so ordering
        # the members by value gives exactly the index order
        by_value = sorted(FPRoundingMode, key=lambda rm: rm.value)
        return Array([f(rm) for rm in by_value])

    def overflow_rounds_to_inf(self, sign):
        """returns true if an overflow should round to `inf`,
        false if it should round to `max_normal`

        ``sign`` may be a Python bool/int or an nmigen ``Value``; for the
        sign-dependent modes the result is ``sign`` or its inverse.
        """
        not_sign = ~sign if isinstance(sign, Value) else not sign
        if self in (FPRoundingMode.RNE, FPRoundingMode.RNA):
            return True
        if self is FPRoundingMode.RTP:
            return not_sign
        if self is FPRoundingMode.RTN:
            return sign
        # RTZ and both round-to-odd modes saturate at max_normal
        return False

    def underflow_rounds_to_zero(self, sign):
        """returns true if an underflow should round to `zero`,
        false if it should round to `min_denormal`

        ``sign`` may be a Python bool/int or an nmigen ``Value``; for the
        sign-dependent modes the result is ``sign`` or its inverse.
        """
        not_sign = ~sign if isinstance(sign, Value) else not sign
        if self in (FPRoundingMode.RNE, FPRoundingMode.RTZ,
                    FPRoundingMode.RNA):
            return True
        if self is FPRoundingMode.RTP:
            return sign
        if self is FPRoundingMode.RTN:
            return not_sign
        # both round-to-odd modes round to min_denormal (which is odd)
        return False

    def zero_sign(self):
        """which sign an exact zero result should have when it isn't
        otherwise determined, e.g. for `1.0 - 1.0`.

        Only the two modes tied to ``FPSCR.RN == RTN`` give a negative zero.
        """
        return self in (FPRoundingMode.RTN, FPRoundingMode.RTON)

    if _HAVE_SMTLIB2:
        def to_smtlib2(self, default=_raise_err):
            """return the corresponding smtlib2 rounding mode for `self`. If
            there is no corresponding smtlib2 rounding mode, then return
            `default` if specified, else raise `ValueError`.
            """
            if self in (FPRoundingMode.RTOP, FPRoundingMode.RTON):
                if default is _raise_err:
                    raise ValueError(
                        "no corresponding smtlib2 rounding mode", self)
                return default
            # every remaining mode has the same name in RoundingModeEnum
            return getattr(RoundingModeEnum, self.name)
228
229
230
231
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)

    The field accessors and predicates work both on plain Python integers
    (raw bit patterns) and on nmigen ``Value``s.
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def _key(self):
        """ tuple of the fields that define this format (shared by
        ``__eq__`` and ``__hash__`` so they always agree).
        """
        return (self.e_width, self.m_width, self.has_int_bit, self.has_sign)

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        """ Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone makes Python set ``__hash__`` to None,
        which would prevent formats being used as dict keys / set members.
        """
        return hash(self._key())

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if ``width`` is not the bit-width of an IEEE
            754-2008 binary interchange format (or is unreasonably large).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            # IEEE 754-2008 exponent width rule for binary{k}, k > 128
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Non-default flags are emitted as keyword arguments so that the repr
        round-trips (a positional ``has_sign`` would be read as
        ``has_int_bit``).
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", has_int_bit={self.has_int_bit}"
        if self.has_sign is not True:
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
        (assumes FPFormat is set to signed - has_sign=True)
        """
        return x >> (self.e_width + self.m_width)

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        return (x >> self.m_width) & self.exponent_inf_nan

    @staticmethod
    def _as_signed(x):
        """ return ``x`` such that subtracting from it may go negative.

        Python integers are returned unchanged.  An unsigned nmigen
        ``Value`` is converted to a signed one holding the same value.
        """
        if isinstance(x, int):
            return x
        if isinstance(x, Value) and not x.shape().signed:
            # OR-ing with a 1-bit signed zero yields a signed value one bit
            # wider, without changing the numeric value
            x |= Const(0, signed(1))
        return x

    def get_exponent(self, x):
        """ returns the exponent of its input number, x (raw field minus
        bias).  Zero and subnormals give ``e_sub``; see
        ``get_exponent_value`` for the mathematically correct exponent.
        """
        return self._as_signed(self.get_exponent_field(x)) - self.exponent_bias

    def get_exponent_value(self, x):
        """ returns the exponent of its input number, x, adjusted for the
        mathematically correct subnormal exponent.

        A zero exponent field denotes the same exponent as field value 1,
        so it is bumped by one before the bias is removed.
        """
        x = self._as_signed(self.get_exponent_field(x))
        is_denormal_or_zero = x == self.exponent_denormal_zero
        return x + is_denormal_or_zero - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def _get_fraction_field(self, x):
        """ returns only the fraction bits of x's mantissa field, i.e. the
        mantissa field without any explicit integer bit.
        """
        return x & ((1 << self.fraction_width) - 1)

    def get_mantissa_value(self, x):
        """ returns the mantissa of its input number, x, but with the
        implicit bit, if any, made explicit.
        """
        if self.has_int_bit:
            return self.get_mantissa_field(x)
        exponent_field = self.get_exponent_field(x)
        mantissa_field = self.get_mantissa_field(x)
        implicit_bit = exponent_field != self.exponent_denormal_zero
        return (implicit_bit << self.fraction_width) | mantissa_field

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent_field(x) == self.exponent_denormal_zero) & \
            (self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum, mantissa non-zero)
        """
        return (self.get_exponent_field(x) == self.exponent_denormal_zero) & \
            (self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite

        Any explicit integer bit is ignored: only the fraction bits decide
        between infinity and NaN.
        """
        return (self.get_exponent_field(x) == self.exponent_inf_nan) & \
            (self._get_fraction_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)

        Any explicit integer bit is ignored: only the fraction bits decide
        between infinity and NaN.
        """
        return (self.get_exponent_field(x) == self.exponent_inf_nan) & \
            (self._get_fraction_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan
        """
        return self.is_nan(x) & \
            ((self.get_mantissa_field(x) & self._quiet_bit) != 0)

    def to_quiet_nan(self, x):
        """ converts `x` to a quiet NaN

        Sets the all-ones exponent and the quiet bit, plus the explicit
        integer bit for formats that have one.
        """
        return x | (self._quiet_bit | self._int_bit | self.exponent_mask)

    def quiet_nan(self, sign=0):
        """ return the default quiet NaN with sign `sign` """
        return self.to_quiet_nan(self.zero(sign))

    def zero(self, sign=0):
        """ return zero with sign `sign` """
        return (sign != 0) << (self.e_width + self.m_width)

    def inf(self, sign=0):
        """ return infinity with sign `sign`

        For formats with an explicit integer bit, that bit is set.
        """
        return self.zero(sign) | self.exponent_mask | self._int_bit

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan
        """
        # the quiet-bit test must be parenthesised: ``a & b & c == 0``
        # would parse as ``(a & b & c) == 0``
        return self.is_nan(x) & \
            ((self.get_mantissa_field(x) & self._quiet_bit) == 0)

    @property
    def _quiet_bit(self):
        """ mask of the most-significant fraction bit, which marks a NaN as
        quiet.  For formats with an explicit integer bit, this is the bit
        just below it.
        """
        return 1 << (self.fraction_width - 1)

    @property
    def _int_bit(self):
        """ mask of the explicit integer bit, or 0 if the format has none """
        return (1 << self.fraction_width) if self.has_int_bit else 0

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_mask(self):
        """ Get an exponent mask """
        return self.exponent_inf_nan << self.m_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias), i.e. the unbiased
        exponent of the infinity/NaN encoding
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the minimum exponent (minus bias), i.e. the unbiased
        exponent of the zero/subnormal encoding
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
460
461
class TestFPFormat(unittest.TestCase):
    """ quick smoke tests for FPFormat, checked against the bit patterns
    produced by sfpy's Float32/Float64 types.
    """

    @staticmethod
    def _report(fmt, bits, tag, result):
        # debug trace: bit pattern, what is being checked, unbiased
        # exponent, e_max, raw mantissa field and the predicate's result
        print(hex(bits), tag, fmt.get_exponent(bits), fmt.e_max,
              fmt.get_mantissa_field(bits), result)

    def test_fpformat_fp64(self):
        f64 = FPFormat.standard(64)
        from sfpy import Float64

        for value, expected_exponent in ((1.0, 0), (2.0, 1)):
            bits = Float64(value).bits
            print(hex(bits))
            self.assertEqual(f64.get_exponent(bits), expected_exponent)

        bits = Float64(1.5).bits
        mantissa = f64.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x8000000000000)

        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 0)

        bits = Float64(-1.5).bits
        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        f32 = FPFormat.standard(32)
        from sfpy import Float32

        for value, expected_exponent in ((1.0, 0), (2.0, 1)):
            bits = Float32(value).bits
            print(hex(bits))
            self.assertEqual(f32.get_exponent(bits), expected_exponent)

        bits = Float32(1.5).bits
        mantissa = f32.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x400000)

        # NaN: square root of a negative number
        bits = Float32(-1.0).sqrt().bits
        result = f32.is_nan(bits)
        self._report(f32, bits, "nan", result)
        self.assertEqual(result, True)

        # Inf: overflow through repeated multiplication
        bits = (Float32(1e36) * Float32(1e36) * Float32(1e36)).bits
        result = f32.is_inf(bits)
        self._report(f32, bits, "inf", result)
        self.assertEqual(result, True)

        # subnormal
        bits = Float32(1e-41).bits
        result = f32.is_subnormal(bits)
        self._report(f32, bits, "sub", result)
        self.assertEqual(result, True)

        # zero is not subnormal...
        bits = Float32(0.0).bits
        result = f32.is_subnormal(bits)
        self._report(f32, bits, "sub", result)
        self.assertEqual(result, False)

        # ...but it is zero
        result = f32.is_zero(bits)
        self._report(f32, bits, "zero", result)
        self.assertEqual(result, True)
543
544
class MultiShiftR(Elaboratable):
    """ combinatorial right-shifter: ``o = i >> s``.

    :param width: bit-width of the input ``i`` and the output ``o``.
    """

    def __init__(self, width):
        self.width = width
        # the shift amount only needs enough bits to express width-1
        self.smax = bits_for(width - 1)
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
558
559
class MultiShift:
    """ width-preserving variable shifts for nmigen values.

    ``lshift`` / ``rshift`` shift ``op`` by ``s`` bits using nmigen's
    ``<<`` / ``>>`` operators and truncate the result back to ``len(op)``
    bits.

    ``smax`` is the number of bits needed to express a shift amount of up
    to ``width - 1``; the shift methods themselves do not use it.
    """

    def __init__(self, width):
        self.width = width
        self.smax = bits_for(width - 1)

    def lshift(self, op, s):
        """ shift ``op`` left by ``s``, keeping only the low len(op) bits """
        shifted = op << s
        keep = len(op)
        return shifted[:keep]

    def rshift(self, op, s):
        """ shift ``op`` right by ``s``, keeping only the low len(op) bits """
        shifted = op >> s
        keep = len(op)
        return shifted[:keep]
582
583
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    This class is designed to be passed around in other data structures
    (between pipelines and between stages). Its "friend" is FPNumBase,
    which is a *module*. The reason for the discernment is because
    nmigen modules that are not added to submodules results in the
    irritating "Elaboration" warning. Despite not *needing* FPNumBase
    in many cases to be added as a submodule (because it is just data)
    this was not possible to solve without splitting out the data from
    the module.
    """

    def __init__(self, width, m_extra=True, e_extra=False, name=None):
        """ :param width: bit-width of the packed value (16, 32 or 64).
        :param m_extra: if true, add 3 extra low mantissa bits.
        :param e_extra: if true, widen the exponent by 6 bits.
        :param name: optional prefix for the generated signal names.
        """
        if name is None:
            # a missing name is tolerated: signals are simply unprefixed
            name = ""
        else:
            name += "_"
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True,
                        name=name+"v")  # Latched copy of value
        self.m = Signal(m_width, reset_less=True, name=name+"m")  # Mantissa
        self.e = Signal(signed(e_width),
                        reset_less=True, name=name+"e")  # exp+2 bits, signed
        self.s = Signal(reset_less=True, name=name+"s")  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ share this record's signals and layout parameters with ``fp``
        (which may be ``self``), and build the derived constants.

        The constants (mzero, msb1, m1s, P128, P127, N127, N126) are built
        on ``self`` and also copied onto ``fp``.  That way objects sharing
        this record (such as FPNumBase) can use ``self.P127`` etc.
        directly, as well as through ``self.fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, unsigned(m_width))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, unsigned(m_width))
        self.m1s = Const(-1, unsigned(m_width))
        self.P128 = Const(e_max, signed(e_width))
        self.P127 = Const(e_max-1, signed(e_width))
        self.N127 = Const(-(e_max-1), signed(e_width))
        self.N126 = Const(-(e_max-2), signed(e_width))

        for const_name in ("mzero", "msb1", "m1s",
                           "P128", "P127", "N127", "N126"):
            setattr(fp, const_name, getattr(self, const_name))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        NOTE: order is important, because e_start/e_end can be
        a bit too long (overwriting s).
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # add bias
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def quieted_nan(self, other):
        """ assign a NaN that has ``other``'s sign and mantissa with the
        top mantissa bit forced on.
        """
        assert isinstance(other, FPNumBaseRecord)
        assert self.width == other.width
        return self.create(other.s, self.fp.P128,
                           other.v[0:self.e_start] | (1 << (self.e_start - 1)))

    def inf(self, s):
        return self.create(*self._inf(s))

    def max_normal(self, s):
        return self.create(s, self.fp.P127, ~0)

    def min_denormal(self, s):
        return self.create(s, self.fp.N127, 1)

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.  Unlike ``create`` this
        returns a ``Cat`` expression rather than a list of assignments.
        Constants are read from ``self.fp`` (as in ``create``), since
        objects sharing a record do not necessarily carry them directly.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
733
734
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Shares the signals of an ``FPNumBaseRecord`` (``fp``) through
    ``fp.drop_in(self)``.  ``elaborate`` derives classification flags
    (nan / zero / inf / overflowed / denormalised) and exponent
    comparison flags from the ``e`` and ``m`` signals.
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # classification flags
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the record's constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal(signed(e_width), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        fp = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return self.exp_n126 & self.m_msbzero
796
797
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Adds nothing to FPNumBase: construction and elaboration are delegated
    to it unchanged.
    """

    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        module = FPNumBase.elaborate(self, platform)
        return module
818
819
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
    m[0] is the "sticky" bit, basically

    ``inp[1:]`` is shifted right by ``diff``.  The amount is clamped to
    ``width - 1``, because shifting further would only produce zeros.
    Every bit shifted out, plus ``inp[0]``, is ORed into bit 0 of the
    output ``m``.
    """

    def __init__(self, width, s_max=None):
        s_max = bits_for(width - 1) if s_max is None else s_max
        self.smax = Shape.cast(s_max)
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()
        width = self.width

        rs = Signal(width, reset_less=True)
        m_mask = Signal(width, reset_less=True)
        smask = Signal(width, reset_less=True)
        stickybit = Signal(reset_less=True)
        # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
        maxslen = Signal(self.smax.width, reset_less=True)
        maxsleni = Signal(self.smax.width, reset_less=True)

        shifter = MultiShift(width - 1)
        zeros = Const(0, width - 1)
        limit = Const(width - 1, len(self.diff))
        too_far = self.diff > limit
        # clamped shift amount, and its complement for the sticky mask
        m.d.comb += [
            maxslen.eq(Mux(too_far, limit, self.diff)),
            maxsleni.eq(Mux(too_far, 0, limit - self.diff)),
        ]

        payload = self.inp[1:]
        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(shifter.rshift(payload, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(payload & m_mask),
            # sticky bit combines all masked-off bits and the old bit 0
            stickybit.eq(smask.bool() | self.inp[0]),
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
863
864
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    :param mainm: object providing ``State(...)``, used as an FSM context.
    :param op: source number whose s/e/m are copied in.
    :param inv: number whose exponent this one is aligned against.
    :param width: packed bit-width (passed to FPNumBaseRecord).
    :param m_extra: whether the mantissa has extra guard bits.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase expects an FPNumBaseRecord describing the layout,
        # not a raw width (passing width/m_extra was a TypeError)
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # ``op`` is an attribute, not a local (was a NameError)
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down needs the number to shift; shift this one.
                # NOTE(review): e/m are also driven from comb above, and
                # nmigen rejects driving one signal from comb and sync.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the shared record (self.fp)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        # OR-reduction of the masked bits, as in FPNumIn.shift_down_multi
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
936
937
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ split a packed value ``v`` into sign / exponent / mantissa.

        The exponent bits ``v[e_start:e_end]`` have the bias (``fp.P127``)
        subtracted in the record's signed exponent width.  The stored
        mantissa bits are placed above ``m_extra`` zero bits.
        """
        low_zeros = [0] * self.m_extra
        mantissa = Cat(*low_zeros, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        return [
            self.m.eq(mantissa),
            self.e.eq(exponent),
            self.s.eq(v[-1]),
        ]
975
976
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Constants (P127, m1s, ...) are read through ``self.fp``, the shared
    record, because FPNumBase objects do not necessarily carry them.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes ``self.v`` into sign / exponent / mantissa, returned as
        attributes ``s`` / ``e`` / ``m`` of an ObjectProxy wrapping ``m``.

        The bias is subtracted from the exponent field, in the record's
        signed exponent width.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

        The bias is subtracted from the exponent field, in the record's
        signed exponent width.  ``P127`` is read from ``self.fp``
        (``self.P127`` is not set on FPNumBase objects).
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the shared record (self.fp)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
1075
1076
class Trigger(Elaboratable):
    """ strobe / acknowledge pair.

        ``trigger`` is driven combinatorially high whenever ``stb`` and
        ``ack`` are both high.
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_high = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_high)
        return m

    def eq(self, inp):
        """ returns statements copying ``stb`` and ``ack`` from ``inp`` """
        pairs = ((self.stb, inp.stb), (self.ack, inp.ack))
        return [dst.eq(src) for dst, src in pairs]

    def ports(self):
        """ the handshake signals, for simulation / RTL generation """
        return list((self.stb, self.ack))
1096
1097
class FPOpIn(PrevControl):
    """ input-side operand port; ``v`` is an alias for ``data_i``. """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def _chain(self, in_op, ack_value, extra):
        """ shared body of chain_inv / chain_from.

            copies ``in_op.v`` and ``in_op.stb`` (AND-ed with ``extra`` when
            given) into this port and drives ``in_op.ack`` with
            ``ack_value``.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),       # receive value
                self.stb.eq(stb),         # receive STB
                in_op.ack.eq(ack_value),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chain from ``in_op``, sending back the *inverted* ack """
        return self._chain(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        """ chain from ``in_op``, sending back this port's ack as-is """
        return self._chain(in_op, self.ack, extra)
1124
1125
class FPOpOut(NextControl):
    """ output-side operand port; ``v`` is an alias for ``data_o``. """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def _chain(self, in_op, ack_value, extra):
        """ shared body of chain_inv / chain_from.

            copies ``in_op.v`` and ``in_op.stb`` (AND-ed with ``extra`` when
            given) into this port and drives ``in_op.ack`` with
            ``ack_value``.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),       # receive value
                self.stb.eq(stb),         # receive STB
                in_op.ack.eq(ack_value),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chain from ``in_op``, sending back the *inverted* ack """
        return self._chain(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        """ chain from ``in_op``, sending back this port's ack as-is """
        return self._chain(in_op, self.ack, extra)
1152
1153
class Overflow:
    """ rounding state that travels with a result mantissa.

        holds the guard / round / sticky bits, the mantissa LSB ``m0``, a
        5-bit status-flag field, the sign and the rounding mode ``rm``.
        the ``roundz*`` properties turn these into a "round the mantissa
        up" decision.
    """
    # TODO: change FFLAGS to be FPSCR's status flags
    FFLAGS_NV = Const(1 << 4, 5)  # invalid operation
    FFLAGS_DZ = Const(1 << 3, 5)  # divide by zero
    FFLAGS_OF = Const(1 << 2, 5)  # overflow
    FFLAGS_UF = Const(1 << 1, 5)  # underflow
    FFLAGS_NX = Const(1 << 0, 5)  # inexact

    # field order used by __iter__ and eq
    _FIELD_NAMES = ("guard", "round_bit", "sticky", "m0", "fpflags",
                    "sign", "rm")

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")      # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")    # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0
        self.fpflags = Signal(5, reset_less=True, name=prefix+"fflags")

        # sign bit -- 1 means negative, 0 means positive
        self.sign = Signal(reset_less=True, name=prefix+"sign")

        # rounding mode
        self.rm = Signal(FPRoundingMode, name=prefix+"rm",
                         reset=FPRoundingMode.DEFAULT)

    def __iter__(self):
        for field in Overflow._FIELD_NAMES:
            yield getattr(self, field)

    def eq(self, inp):
        return [getattr(self, field).eq(getattr(inp, field))
                for field in Overflow._FIELD_NAMES]

    @property
    def _inexact(self):
        """ any of guard / round / sticky set, i.e. bits were lost """
        return self.guard | self.round_bit | self.sticky

    @property
    def roundz_rne(self):
        """ round-up decision assuming `ROUND_NEAREST_TIES_TO_EVEN` (RNE):
            guard set, and either something below it or an odd LSB.
        """
        return self.guard & (self.round_bit | self.sticky | self.m0)

    @property
    def roundz_rna(self):
        """ round-up decision assuming `ROUND_NEAREST_TIES_TO_AWAY` (RNA) """
        return self.guard

    @property
    def roundz_rtn(self):
        """ round-up decision assuming `ROUND_TOWARDS_NEGATIVE` (RTN):
            only negative values that lost bits grow in magnitude.
        """
        return self.sign & self._inexact

    @property
    def roundz_rto(self):
        """ round-up decision assuming `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
            or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE` (RTOP / RTON):
            an inexact result with an even LSB is bumped to odd.
        """
        return ~self.m0 & self._inexact

    @property
    def roundz_rtp(self):
        """ round-up decision assuming `ROUND_TOWARDS_POSITIVE` (RTP):
            only non-negative values that lost bits grow in magnitude.
        """
        return ~self.sign & self._inexact

    @property
    def roundz_rtz(self):
        """ round-up decision assuming `ROUND_TOWARDS_ZERO` (RTZ): never """
        return False

    @property
    def roundz(self):
        """ round-up decision for the current rounding mode ``self.rm`` """
        per_mode = {
            FPRoundingMode.RNA: self.roundz_rna,
            FPRoundingMode.RNE: self.roundz_rne,
            FPRoundingMode.RTN: self.roundz_rtn,
            FPRoundingMode.RTOP: self.roundz_rto,
            FPRoundingMode.RTON: self.roundz_rto,
            FPRoundingMode.RTP: self.roundz_rtp,
            FPRoundingMode.RTZ: self.roundz_rtz,
        }
        table = FPRoundingMode.make_array(lambda rm: per_mode[rm])
        return table[self.rm]
1261
1262
class OverflowMod(Elaboratable, Overflow):
    """ Overflow record that is also an Elaboratable.

        elaborate() drives ``roundz_out`` combinatorially from the
        ``roundz`` property (the round-up decision for ``self.rm``).
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ returns statements copying ``roundz_out`` and every Overflow
            field from ``inp``.
        """
        # Overflow.eq needs the source record as well; calling it as
        # Overflow.eq(self) (missing ``inp``) always raised TypeError.
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)  # roundz is a property
        return m
1281
1282
class FPBase:
    """ IEEE754 Floating Point Base Class

        Shared FSM building blocks for FP manipulation: operand fetch and
        decode, denormalisation, normalisation, rounding, denormal
        corrections, packing and result hand-off.  Every method adds
        statements and/or an ``m.next`` transition to the nmigen Module
        ``m`` it is given.
    """

    def get_op(self, m, op, v, next_state):
        """ operand fetch.

            decodes ``v`` via ``v.decode2(m)`` and moves to ``next_state``
            once ``op.ready_o`` and ``op.valid_i_test`` are both set.
            acknowledgement is sent by setting ack to ZERO: the returned
            ``ack`` is 0 in the transfer cycle and 1 otherwise.

            returns ``[decoded_operand, ack]``.
        """
        decoded = v.decode2(m)
        ack = Signal()
        handshake = (op.ready_o) & (op.valid_i_test)
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [decoded, ack]

    def denormalise(self, m, a):
        """ prepares operand ``a`` for arithmetic ("denormalise" is probably
            the wrong name).

            when ``a.exp_n127`` is set the exponent is forced to
            ``a.fp.N126`` and the mantissa is left alone; otherwise the
            top mantissa bit (the implicit 1) is set.  either way the
            stored value is effectively doubled, which has to be taken
            into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)   # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)       # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation: one left-shift per clock.

            while the mantissa MSB is clear, shift the mantissa up and
            decrement the exponent; once it is set, go to ``next_state``.
        """
        msb_clear = ~op.m[-1]
        with m.If(msb_clear):
            m.d.sync += [
                op.e.eq(op.e - 1),   # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation: one left-shift per clock.

            while the mantissa MSB is clear and ``z.e > z.fp.N126``, shift
            the mantissa up (pulling the guard bit into the LSB) and
            decrement the exponent; otherwise go to ``next_state``.
        """
        can_shift_up = ~z.m[-1] & (z.e > z.fp.N126)
        with m.If(can_shift_up):
            # NOTE: order matters -- z.m[0] must be assigned after z.m so
            # that it overrides the zero shifted into bit 0.
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),         # new LSB is the old guard bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation: one right-shift per clock.

            while ``z.e < z.fp.N126``, shift the mantissa down and increment
            the exponent, moving the dropped bits through guard -> round
            -> sticky; otherwise go to ``next_state``.
        """
        exp_too_small = z.e < z.fp.N126
        with m.If(exp_too_small):
            m.d.sync += [
                z.e.eq(z.e + 1),                        # INCREASE exponent
                z.m.eq(z.m >> 1),                       # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ when ``roundz`` is set, increments the mantissa, also bumping
            the exponent if the mantissa was all ones.
            TODO: different kinds of rounding
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            mantissa_all_ones = z.m == z.fp.m1s
            with m.If(mantissa_all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections.

            always moves to ``next_state``; a denormalised ``z`` has its
            exponent set to ``z.fp.N127``.
        """
        m.next = next_state
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs ``z`` into its output encoding and moves to
            ``next_state``; an overflowed ``z`` becomes a signed infinity.
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ result hand-off.

            copies ``z.v`` into ``out_z.v`` every clock and drives
            ``out_z.valid_o`` high; once ``valid_o`` and ``ready_i_test``
            are both set, clears ``valid_o`` and moves to ``next_state``.
        """
        m.d.sync += out_z.v.eq(z.v)
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1413
1414
class FPState(FPBase):
    """ an FPBase helper tied to the FSM state it is entered from.

        set_inputs / set_outputs store the given dict and also expose each
        entry as an attribute of the same name.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
1428
1429
class FPID:
    """ optional pipeline-ID ("mid") pass-through.

        when ``id_wid`` is truthy, ``in_mid`` / ``out_mid`` are created as
        reset-less Signals of that width; otherwise both are None and
        idsync() does nothing.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync-domain copy ``out_mid <= in_mid`` to module ``m``.

            a no-op when no ID signals exist (``id_wid`` of None *or* 0).
        """
        # use the same condition __init__ used to create the signals: an
        # ``is not None`` test would let id_wid=0 through and then fail on
        # ``None.eq``.
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
1443
1444
if __name__ == "__main__":
    # executed directly: run any unittest.TestCase classes in this module
    unittest.main()