1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
9 from nmigen
import (Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
, Array
,
11 from nmigen
.utils
import bits_for
12 from operator
import or_
13 from functools
import reduce
15 from nmutil
.singlepipe
import PrevControl
, NextControl
16 from nmutil
.pipeline
import ObjectProxy
22 from nmigen
.hdl
.smtlib2
import RoundingModeEnum
27 # value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
31 class FPRoundingMode(enum
.Enum
):
32 # matches the FPSCR.RN field values, but includes some extra
33 # values (>= 0b100) used in miscellaneous instructions.
35 # naming matches smtlib2 names, doc strings are the OpenPower ISA
36 # specification's names (v3.1 section 7.3.2.6 --
37 # matches values in section 4.3.6).
39 """Round to Nearest Even
41 Rounds to the nearest representable floating-point number, ties are
42 rounded to the number with the even mantissa. Treats +-Infinity as if
43 it were a normalized floating-point number when deciding which number
44 is closer when rounding. See IEEE754 spec. for details.
47 ROUND_NEAREST_TIES_TO_EVEN
= RNE
53 If the result is exactly representable as a floating-point number, return
54 that, otherwise return the nearest representable floating-point value
55 with magnitude smaller than the exact answer.
58 ROUND_TOWARDS_ZERO
= RTZ
61 """Round towards +Infinity
63 If the result is exactly representable as a floating-point number, return
64 that, otherwise return the nearest representable floating-point value
65 that is numerically greater than the exact answer. This can round up to
69 ROUND_TOWARDS_POSITIVE
= RTP
72 """Round towards -Infinity
74 If the result is exactly representable as a floating-point number, return
75 that, otherwise return the nearest representable floating-point value
76 that is numerically less than the exact answer. This can round down to
80 ROUND_TOWARDS_NEGATIVE
= RTN
83 """Round to Nearest Away
85 Rounds to the nearest representable floating-point number, ties are
86 rounded to the number with the maximum magnitude. Treats +-Infinity as if
87 it were a normalized floating-point number when deciding which number
88 is closer when rounding. See IEEE754 spec. for details.
91 ROUND_NEAREST_TIES_TO_AWAY
= RNA
94 """Round to Odd, unsigned zeros are Positive
98 If the result is exactly representable as a floating-point number, return
99 that, otherwise return the nearest representable floating-point value
100 that has an odd mantissa.
102 If the result is zero but with otherwise undetermined sign
103 (e.g. `1.0 - 1.0`), the sign is positive.
105 This rounding mode is used for instructions with Round To Odd enabled,
106 and `FPSCR.RN != RTN`.
108 This is useful to avoid double-rounding errors when doing arithmetic in a
109 larger type (e.g. f128) but where the answer should be a smaller type
113 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE
= RTOP
116 """Round to Odd, unsigned zeros are Negative
120 If the result is exactly representable as a floating-point number, return
121 that, otherwise return the nearest representable floating-point value
122 that has an odd mantissa.
124 If the result is zero but with otherwise undetermined sign
125 (e.g. `1.0 - 1.0`), the sign is negative.
127 This rounding mode is used for instructions with Round To Odd enabled,
128 and `FPSCR.RN == RTN`.
130 This is useful to avoid double-rounding errors when doing arithmetic in a
131 larger type (e.g. f128) but where the answer should be a smaller type
135 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE
= RTON
139 l
= [None] * len(FPRoundingMode
)
140 for rm
in FPRoundingMode
:
144 def overflow_rounds_to_inf(self
, sign
):
145 """returns true if an overflow should round to `inf`,
146 false if it should round to `max_normal`
148 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
149 if self
is FPRoundingMode
.RNE
:
151 elif self
is FPRoundingMode
.RTZ
:
153 elif self
is FPRoundingMode
.RTP
:
155 elif self
is FPRoundingMode
.RTN
:
157 elif self
is FPRoundingMode
.RNA
:
159 elif self
is FPRoundingMode
.RTOP
:
162 assert self
is FPRoundingMode
.RTON
165 def underflow_rounds_to_zero(self
, sign
):
166 """returns true if an underflow should round to `zero`,
167 false if it should round to `min_denormal`
169 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
170 if self
is FPRoundingMode
.RNE
:
172 elif self
is FPRoundingMode
.RTZ
:
174 elif self
is FPRoundingMode
.RTP
:
176 elif self
is FPRoundingMode
.RTN
:
178 elif self
is FPRoundingMode
.RNA
:
180 elif self
is FPRoundingMode
.RTOP
:
183 assert self
is FPRoundingMode
.RTON
187 """which sign an exact zero result should have when it isn't
188 otherwise determined, e.g. for `1.0 - 1.0`.
190 if self
is FPRoundingMode
.RNE
:
192 elif self
is FPRoundingMode
.RTZ
:
194 elif self
is FPRoundingMode
.RTP
:
196 elif self
is FPRoundingMode
.RTN
:
198 elif self
is FPRoundingMode
.RNA
:
200 elif self
is FPRoundingMode
.RTOP
:
203 assert self
is FPRoundingMode
.RTON
207 def to_smtlib2(self
, default
=_raise_err
):
208 """return the corresponding smtlib2 rounding mode for `self`. If
209 there is no corresponding smtlib2 rounding mode, then return
210 `default` if specified, else raise `ValueError`.
212 if self
is FPRoundingMode
.RNE
:
213 return RoundingModeEnum
.RNE
214 elif self
is FPRoundingMode
.RTZ
:
215 return RoundingModeEnum
.RTZ
216 elif self
is FPRoundingMode
.RTP
:
217 return RoundingModeEnum
.RTP
218 elif self
is FPRoundingMode
.RTN
:
219 return RoundingModeEnum
.RTN
220 elif self
is FPRoundingMode
.RNA
:
221 return RoundingModeEnum
.RNA
223 assert self
in (FPRoundingMode
.RTOP
, FPRoundingMode
.RTON
)
224 if default
is _raise_err
:
226 "no corresponding smtlib2 rounding mode", self
)
233 """ Class describing binary floating-point formats based on IEEE 754.
235 :attribute e_width: the number of bits in the exponent field.
236 :attribute m_width: the number of bits stored in the mantissa
238 :attribute has_int_bit: if the FP format has an explicit integer bit (like
239 the x87 80-bit format). The bit is considered part of the mantissa.
240 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
241 Image/Buffer formats are FP numbers without a sign bit.)
249 """ Create ``FPFormat`` instance. """
250 self
.e_width
= e_width
251 self
.m_width
= m_width
252 self
.has_int_bit
= has_int_bit
253 self
.has_sign
= has_sign
255 def __eq__(self
, other
):
256 """ Check for equality. """
257 if not isinstance(other
, FPFormat
):
258 return NotImplemented
259 return (self
.e_width
== other
.e_width
260 and self
.m_width
== other
.m_width
261 and self
.has_int_bit
== other
.has_int_bit
262 and self
.has_sign
== other
.has_sign
)
266 """ Get standard IEEE 754-2008 format.
268 :param width: bit-width of requested format.
269 :returns: the requested ``FPFormat`` instance.
272 return FPFormat(5, 10)
274 return FPFormat(8, 23)
276 return FPFormat(11, 52)
278 return FPFormat(15, 112)
279 if width
> 128 and width
% 32 == 0:
280 if width
> 1000000: # arbitrary upper limit
281 raise ValueError("width too big")
282 e_width
= round(4 * math
.log2(width
)) - 13
283 return FPFormat(e_width
, width
- 1 - e_width
)
284 raise ValueError("width must be the bit-width of a valid IEEE"
285 " 754-2008 binary format")
290 if self
== self
.standard(self
.width
):
291 return f
"FPFormat.standard({self.width})"
294 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
295 if self
.has_int_bit
is not False:
296 retval
+= f
", {self.has_int_bit}"
297 if self
.has_sign
is not True:
298 retval
+= f
", {self.has_sign}"
301 def get_sign_field(self
, x
):
302 """ returns the sign bit of its input number, x
303 (assumes FPFormat is set to signed - has_sign=True)
305 return x
>> (self
.e_width
+ self
.m_width
)
307 def get_exponent_field(self
, x
):
308 """ returns the raw exponent of its input number, x (no bias subtracted)
310 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
313 def get_exponent(self
, x
):
314 """ returns the exponent of its input number, x
316 return self
.get_exponent_field(x
) - self
.exponent_bias
318 def get_mantissa_field(self
, x
):
319 """ returns the mantissa of its input number, x
321 return x
& self
.mantissa_mask
323 def get_mantissa_value(self
, x
):
324 """ returns the mantissa of its input number, x, but with the
325 implicit bit, if any, made explicit.
328 return self
.get_mantissa_field(x
)
329 exponent_field
= self
.get_exponent_field(x
)
330 mantissa_field
= self
.get_mantissa_field(x
)
331 implicit_bit
= exponent_field
== self
.exponent_denormal_zero
332 return (implicit_bit
<< self
.fraction_width
) | mantissa_field
334 def is_zero(self
, x
):
335 """ returns true if x is +/- zero
337 return (self
.get_exponent(x
) == self
.e_sub
) & \
338 (self
.get_mantissa_field(x
) == 0)
340 def is_subnormal(self
, x
):
341 """ returns true if x is subnormal (exp at minimum)
343 return (self
.get_exponent(x
) == self
.e_sub
) & \
344 (self
.get_mantissa_field(x
) != 0)
347 """ returns true if x is infinite
349 return (self
.get_exponent(x
) == self
.e_max
) & \
350 (self
.get_mantissa_field(x
) == 0)
353 """ returns true if x is a nan (quiet or signalling)
355 return (self
.get_exponent(x
) == self
.e_max
) & \
356 (self
.get_mantissa_field(x
) != 0)
358 def is_quiet_nan(self
, x
):
359 """ returns true if x is a quiet nan
361 highbit
= 1 << (self
.m_width
- 1)
362 return (self
.get_exponent(x
) == self
.e_max
) & \
363 (self
.get_mantissa_field(x
) != 0) & \
364 (self
.get_mantissa_field(x
) & highbit
!= 0)
366 def to_quiet_nan(self
, x
):
367 """ converts `x` to a quiet NaN """
368 highbit
= 1 << (self
.m_width
- 1)
369 return x | highbit | self
.exponent_mask
371 def quiet_nan(self
, sign
=0):
372 """ return the default quiet NaN with sign `sign` """
373 return self
.to_quiet_nan(self
.zero(sign
))
375 def zero(self
, sign
=0):
376 """ return zero with sign `sign` """
377 return (sign
!= 0) << (self
.e_width
+ self
.m_width
)
379 def inf(self
, sign
=0):
380 """ return infinity with sign `sign` """
381 return self
.zero(sign
) | self
.exponent_mask
383 def is_nan_signaling(self
, x
):
384 """ returns true if x is a signalling nan
386 highbit
= 1 << (self
.m_width
- 1)
387 return (self
.get_exponent(x
) == self
.e_max
) & \
388 (self
.get_mantissa_field(x
) != 0) & \
389 (self
.get_mantissa_field(x
) & highbit
) == 0
393 """ Get the total number of bits in the FP format. """
394 return self
.has_sign
+ self
.e_width
+ self
.m_width
397 def mantissa_mask(self
):
398 """ Get a mantissa mask based on the mantissa width """
399 return (1 << self
.m_width
) - 1
402 def exponent_mask(self
):
403 """ Get an exponent mask """
404 return self
.exponent_inf_nan
<< self
.m_width
407 def exponent_inf_nan(self
):
408 """ Get the value of the exponent field designating infinity/NaN. """
409 return (1 << self
.e_width
) - 1
413 """ get the maximum exponent (minus bias)
415 return self
.exponent_inf_nan
- self
.exponent_bias
419 return self
.exponent_denormal_zero
- self
.exponent_bias
421 def exponent_denormal_zero(self
):
422 """ Get the value of the exponent field designating denormal/zero. """
426 def exponent_min_normal(self
):
427 """ Get the minimum value of the exponent field for normal numbers. """
431 def exponent_max_normal(self
):
432 """ Get the maximum value of the exponent field for normal numbers. """
433 return self
.exponent_inf_nan
- 1
436 def exponent_bias(self
):
437 """ Get the exponent bias. """
438 return (1 << (self
.e_width
- 1)) - 1
441 def fraction_width(self
):
442 """ Get the number of mantissa bits that are fraction bits. """
443 return self
.m_width
- self
.has_int_bit
446 class TestFPFormat(unittest
.TestCase
):
447 """ very quick test for FPFormat
450 def test_fpformat_fp64(self
):
451 f64
= FPFormat
.standard(64)
452 from sfpy
import Float64
453 x
= Float64(1.0).bits
456 self
.assertEqual(f64
.get_exponent(x
), 0)
457 x
= Float64(2.0).bits
459 self
.assertEqual(f64
.get_exponent(x
), 1)
461 x
= Float64(1.5).bits
462 m
= f64
.get_mantissa_field(x
)
463 print (hex(x
), hex(m
))
464 self
.assertEqual(m
, 0x8000000000000)
466 s
= f64
.get_sign_field(x
)
467 print (hex(x
), hex(s
))
468 self
.assertEqual(s
, 0)
470 x
= Float64(-1.5).bits
471 s
= f64
.get_sign_field(x
)
472 print (hex(x
), hex(s
))
473 self
.assertEqual(s
, 1)
475 def test_fpformat_fp32(self
):
476 f32
= FPFormat
.standard(32)
477 from sfpy
import Float32
478 x
= Float32(1.0).bits
481 self
.assertEqual(f32
.get_exponent(x
), 0)
482 x
= Float32(2.0).bits
484 self
.assertEqual(f32
.get_exponent(x
), 1)
486 x
= Float32(1.5).bits
487 m
= f32
.get_mantissa_field(x
)
488 print (hex(x
), hex(m
))
489 self
.assertEqual(m
, 0x400000)
492 x
= Float32(-1.0).sqrt()
495 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
496 f32
.get_mantissa_field(x
), i
)
497 self
.assertEqual(i
, True)
500 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
503 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
504 f32
.get_mantissa_field(x
), i
)
505 self
.assertEqual(i
, True)
510 i
= f32
.is_subnormal(x
)
511 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
512 f32
.get_mantissa_field(x
), i
)
513 self
.assertEqual(i
, True)
517 i
= f32
.is_subnormal(x
)
518 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
519 f32
.get_mantissa_field(x
), i
)
520 self
.assertEqual(i
, False)
524 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
525 f32
.get_mantissa_field(x
), i
)
526 self
.assertEqual(i
, True)
529 class MultiShiftR(Elaboratable
):
531 def __init__(self
, width
):
533 self
.smax
= bits_for(width
- 1)
534 self
.i
= Signal(width
, reset_less
=True)
535 self
.s
= Signal(self
.smax
, reset_less
=True)
536 self
.o
= Signal(width
, reset_less
=True)
538 def elaborate(self
, platform
):
540 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
545 """ Generates variable-length single-cycle shifter from a series
546 of conditional tests on each bit of the left/right shift operand.
547 Each bit tested produces output shifted by that number of bits,
548 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
549 shifts by 2 bits, each partial result cascading to the next Mux.
551 Could be adapted to do arithmetic shift by taking copies of the
552 MSB instead of zeros.
555 def __init__(self
, width
):
557 self
.smax
= bits_for(width
- 1)
559 def lshift(self
, op
, s
):
563 def rshift(self
, op
, s
):
568 class FPNumBaseRecord
:
569 """ Floating-point Base Number Class.
571 This class is designed to be passed around in other data structures
572 (between pipelines and between stages). Its "friend" is FPNumBase,
573 which is a *module*. The reason for the discernment is because
574 nmigen modules that are not added to submodules results in the
575 irritating "Elaboration" warning. Despite not *needing* FPNumBase
576 in many cases to be added as a submodule (because it is just data)
577 this was not possible to solve without splitting out the data from
581 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
584 # assert false, "missing name"
588 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
589 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
590 e_max
= 1 << (e_width
-3)
591 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
594 # mantissa extra bits (top,guard,round)
596 m_width
+= self
.m_extra
600 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
601 e_width
+= self
.e_extra
604 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
605 self
.m_width
= m_width
606 self
.e_width
= e_width
607 self
.e_start
= self
.rmw
608 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
610 self
.v
= Signal(width
, reset_less
=True,
611 name
=name
+"v") # Latched copy of value
612 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
613 self
.e
= Signal((e_width
, True),
614 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
615 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
620 def drop_in(self
, fp
):
626 fp
.width
= self
.width
627 fp
.e_width
= self
.e_width
628 fp
.e_max
= self
.e_max
629 fp
.m_width
= self
.m_width
630 fp
.e_start
= self
.e_start
631 fp
.e_end
= self
.e_end
632 fp
.m_extra
= self
.m_extra
634 m_width
= self
.m_width
636 e_width
= self
.e_width
638 self
.mzero
= Const(0, (m_width
, False))
639 m_msb
= 1 << (self
.m_width
-2)
640 self
.msb1
= Const(m_msb
, (m_width
, False))
641 self
.m1s
= Const(-1, (m_width
, False))
642 self
.P128
= Const(e_max
, (e_width
, True))
643 self
.P127
= Const(e_max
-1, (e_width
, True))
644 self
.N127
= Const(-(e_max
-1), (e_width
, True))
645 self
.N126
= Const(-(e_max
-2), (e_width
, True))
647 def create(self
, s
, e
, m
):
648 """ creates a value from sign / exponent / mantissa
650 bias is added here, to the exponent.
652 NOTE: order is important, because e_start/e_end can be
653 a bit too long (overwriting s).
656 self
.v
[0:self
.e_start
].eq(m
), # mantissa
657 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
658 self
.v
[-1].eq(s
), # sign
662 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
665 return (s
, self
.fp
.P128
, 0)
668 return (s
, self
.fp
.N127
, 0)
671 return self
.create(*self
._nan
(s
))
673 def quieted_nan(self
, other
):
674 assert isinstance(other
, FPNumBaseRecord
)
675 assert self
.width
== other
.width
676 return self
.create(other
.s
, self
.fp
.P128
,
677 other
.v
[0:self
.e_start
] |
(1 << (self
.e_start
- 1)))
680 return self
.create(*self
._inf
(s
))
682 def max_normal(self
, s
):
683 return self
.create(s
, self
.fp
.P127
, ~
0)
685 def min_denormal(self
, s
):
686 return self
.create(s
, self
.fp
.N127
, 1)
689 return self
.create(*self
._zero
(s
))
691 def create2(self
, s
, e
, m
):
692 """ creates a value from sign / exponent / mantissa
694 bias is added here, to the exponent
696 e
= e
+ self
.P127
# exp (add on bias)
697 return Cat(m
[0:self
.e_start
],
698 e
[0:self
.e_end
-self
.e_start
],
702 return self
.create2(s
, self
.P128
, self
.msb1
)
705 return self
.create2(s
, self
.P128
, self
.mzero
)
708 return self
.create2(s
, self
.N127
, self
.mzero
)
716 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
719 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
720 """ Floating-point Base Number Class
723 def __init__(self
, fp
):
728 self
.is_nan
= Signal(reset_less
=True)
729 self
.is_zero
= Signal(reset_less
=True)
730 self
.is_inf
= Signal(reset_less
=True)
731 self
.is_overflowed
= Signal(reset_less
=True)
732 self
.is_denormalised
= Signal(reset_less
=True)
733 self
.exp_128
= Signal(reset_less
=True)
734 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
735 self
.exp_lt_n126
= Signal(reset_less
=True)
736 self
.exp_zero
= Signal(reset_less
=True)
737 self
.exp_gt_n126
= Signal(reset_less
=True)
738 self
.exp_gt127
= Signal(reset_less
=True)
739 self
.exp_n127
= Signal(reset_less
=True)
740 self
.exp_n126
= Signal(reset_less
=True)
741 self
.m_zero
= Signal(reset_less
=True)
742 self
.m_msbzero
= Signal(reset_less
=True)
744 def elaborate(self
, platform
):
746 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
747 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
748 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
749 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
750 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
751 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
752 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
753 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
754 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
755 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
756 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
757 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
758 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
759 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
760 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
765 return (self
.exp_128
) & (~self
.m_zero
)
768 return (self
.exp_128
) & (self
.m_zero
)
771 return (self
.exp_n127
) & (self
.m_zero
)
773 def _is_overflowed(self
):
774 return self
.exp_gt127
776 def _is_denormalised(self
):
777 # XXX NOT to be used for "official" quiet NaN tests!
778 # particularly when the MSB has been extended
779 return (self
.exp_n126
) & (self
.m_msbzero
)
782 class FPNumOut(FPNumBase
):
783 """ Floating-point Number Class
785 Contains signals for an incoming copy of the value, decoded into
786 sign / exponent / mantissa.
787 Also contains encoding functions, creation and recognition of
788 zero, NaN and inf (all signed)
790 Four extra bits are included in the mantissa: the top bit
791 (m[-1]) is effectively a carry-overflow. The other three are
792 guard (m[2]), round (m[1]), and sticky (m[0])
795 def __init__(self
, fp
):
796 FPNumBase
.__init
__(self
, fp
)
798 def elaborate(self
, platform
):
799 m
= FPNumBase
.elaborate(self
, platform
)
804 class MultiShiftRMerge(Elaboratable
):
805 """ shifts down (right) and merges lower bits into m[0].
806 m[0] is the "sticky" bit, basically
809 def __init__(self
, width
, s_max
=None):
811 s_max
= bits_for(width
- 1)
812 self
.smax
= Shape
.cast(s_max
)
813 self
.m
= Signal(width
, reset_less
=True)
814 self
.inp
= Signal(width
, reset_less
=True)
815 self
.diff
= Signal(s_max
, reset_less
=True)
818 def elaborate(self
, platform
):
821 rs
= Signal(self
.width
, reset_less
=True)
822 m_mask
= Signal(self
.width
, reset_less
=True)
823 smask
= Signal(self
.width
, reset_less
=True)
824 stickybit
= Signal(reset_less
=True)
825 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
826 maxslen
= Signal(self
.smax
.width
, reset_less
=True)
827 maxsleni
= Signal(self
.smax
.width
, reset_less
=True)
829 sm
= MultiShift(self
.width
-1)
830 m0s
= Const(0, self
.width
-1)
831 mw
= Const(self
.width
-1, len(self
.diff
))
832 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
833 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
837 # shift mantissa by maxslen, mask by inverse
838 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
839 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
840 smask
.eq(self
.inp
[1:] & m_mask
),
841 # sticky bit combines all mask (and mantissa low bit)
842 stickybit
.eq(smask
.bool() | self
.inp
[0]),
843 # mantissa result contains m[0] already.
844 self
.m
.eq(Cat(stickybit
, rs
))
849 class FPNumShift(FPNumBase
, Elaboratable
):
850 """ Floating-point Number Class for shifting
853 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
854 FPNumBase
.__init
__(self
, width
, m_extra
)
855 self
.latch_in
= Signal()
860 def elaborate(self
, platform
):
861 m
= FPNumBase
.elaborate(self
, platform
)
863 m
.d
.comb
+= self
.s
.eq(op
.s
)
864 m
.d
.comb
+= self
.e
.eq(op
.e
)
865 m
.d
.comb
+= self
.m
.eq(op
.m
)
867 with self
.mainm
.State("align"):
868 with m
.If(self
.e
< self
.inv
.e
):
869 m
.d
.sync
+= self
.shift_down()
873 def shift_down(self
, inp
):
874 """ shifts a mantissa down by one. exponent is increased to compensate
876 accuracy is lost as a result in the mantissa however there are 3
877 guard bits (the latter of which is the "sticky" bit)
879 return [self
.e
.eq(inp
.e
+ 1),
880 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
883 def shift_down_multi(self
, diff
):
884 """ shifts a mantissa down. exponent is increased to compensate
886 accuracy is lost as a result in the mantissa however there are 3
887 guard bits (the latter of which is the "sticky" bit)
889 this code works by variable-shifting the mantissa by up to
890 its maximum bit-length: no point doing more (it'll still be
893 the sticky bit is computed by shifting a batch of 1s by
894 the same amount, which will introduce zeros. it's then
895 inverted and used as a mask to get the LSBs of the mantissa.
896 those are then |'d into the sticky bit.
898 sm
= MultiShift(self
.width
)
899 mw
= Const(self
.m_width
-1, len(diff
))
900 maxslen
= Mux(diff
> mw
, mw
, diff
)
901 rs
= sm
.rshift(self
.m
[1:], maxslen
)
902 maxsleni
= mw
- maxslen
903 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
905 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
906 return [self
.e
.eq(self
.e
+ diff
),
907 self
.m
.eq(Cat(stickybits
, rs
))
910 def shift_up_multi(self
, diff
):
911 """ shifts a mantissa up. exponent is decreased to compensate
913 sm
= MultiShift(self
.width
)
914 mw
= Const(self
.m_width
, len(diff
))
915 maxslen
= Mux(diff
> mw
, mw
, diff
)
917 return [self
.e
.eq(self
.e
- diff
),
918 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
922 class FPNumDecode(FPNumBase
):
923 """ Floating-point Number Class
925 Contains signals for an incoming copy of the value, decoded into
926 sign / exponent / mantissa.
927 Also contains encoding functions, creation and recognition of
928 zero, NaN and inf (all signed)
930 Four extra bits are included in the mantissa: the top bit
931 (m[-1]) is effectively a carry-overflow. The other three are
932 guard (m[2]), round (m[1]), and sticky (m[0])
935 def __init__(self
, op
, fp
):
936 FPNumBase
.__init
__(self
, fp
)
939 def elaborate(self
, platform
):
940 m
= FPNumBase
.elaborate(self
, platform
)
942 m
.d
.comb
+= self
.decode(self
.v
)
947 """ decodes a latched value into sign / exponent / mantissa
949 bias is subtracted here, from the exponent. exponent
950 is extended to 10 bits so that subtract 127 is done on
953 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
954 #print ("decode", self.e_end)
955 return [self
.m
.eq(Cat(*args
)), # mantissa
956 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
957 self
.s
.eq(v
[-1]), # sign
961 class FPNumIn(FPNumBase
):
962 """ Floating-point Number Class
964 Contains signals for an incoming copy of the value, decoded into
965 sign / exponent / mantissa.
966 Also contains encoding functions, creation and recognition of
967 zero, NaN and inf (all signed)
969 Four extra bits are included in the mantissa: the top bit
970 (m[-1]) is effectively a carry-overflow. The other three are
971 guard (m[2]), round (m[1]), and sticky (m[0])
974 def __init__(self
, op
, fp
):
975 FPNumBase
.__init
__(self
, fp
)
976 self
.latch_in
= Signal()
979 def decode2(self
, m
):
980 """ decodes a latched value into sign / exponent / mantissa
982 bias is subtracted here, from the exponent. exponent
983 is extended to 10 bits so that subtract 127 is done on
987 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
988 #print ("decode", self.e_end)
989 res
= ObjectProxy(m
, pipemode
=False)
990 res
.m
= Cat(*args
) # mantissa
991 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
996 """ decodes a latched value into sign / exponent / mantissa
998 bias is subtracted here, from the exponent. exponent
999 is extended to 10 bits so that subtract 127 is done on
1002 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1003 #print ("decode", self.e_end)
1004 return [self
.m
.eq(Cat(*args
)), # mantissa
1005 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
1006 self
.s
.eq(v
[-1]), # sign
1009 def shift_down(self
, inp
):
1010 """ shifts a mantissa down by one. exponent is increased to compensate
1012 accuracy is lost as a result in the mantissa however there are 3
1013 guard bits (the latter of which is the "sticky" bit)
1015 return [self
.e
.eq(inp
.e
+ 1),
1016 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
1019 def shift_down_multi(self
, diff
, inp
=None):
1020 """ shifts a mantissa down. exponent is increased to compensate
1022 accuracy is lost as a result in the mantissa however there are 3
1023 guard bits (the latter of which is the "sticky" bit)
1025 this code works by variable-shifting the mantissa by up to
1026 its maximum bit-length: no point doing more (it'll still be
1029 the sticky bit is computed by shifting a batch of 1s by
1030 the same amount, which will introduce zeros. it's then
1031 inverted and used as a mask to get the LSBs of the mantissa.
1032 those are then |'d into the sticky bit.
1036 sm
= MultiShift(self
.width
)
1037 mw
= Const(self
.m_width
-1, len(diff
))
1038 maxslen
= Mux(diff
> mw
, mw
, diff
)
1039 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
1040 maxsleni
= mw
- maxslen
1041 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
1043 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1044 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
1045 return [self
.e
.eq(inp
.e
+ diff
),
1046 self
.m
.eq(Cat(stickybit
, rs
))
1049 def shift_up_multi(self
, diff
):
1050 """ shifts a mantissa up. exponent is decreased to compensate
1052 sm
= MultiShift(self
.width
)
1053 mw
= Const(self
.m_width
, len(diff
))
1054 maxslen
= Mux(diff
> mw
, mw
, diff
)
1056 return [self
.e
.eq(self
.e
- diff
),
1057 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
1061 class Trigger(Elaboratable
):
1064 self
.stb
= Signal(reset
=0)
1066 self
.trigger
= Signal(reset_less
=True)
1068 def elaborate(self
, platform
):
1070 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
1074 return [self
.stb
.eq(inp
.stb
),
1075 self
.ack
.eq(inp
.ack
)
1079 return [self
.stb
, self
.ack
]
1082 class FPOpIn(PrevControl
):
1083 def __init__(self
, width
):
1084 PrevControl
.__init
__(self
)
1091 def chain_inv(self
, in_op
, extra
=None):
1093 if extra
is not None:
1095 return [self
.v
.eq(in_op
.v
), # receive value
1096 self
.stb
.eq(stb
), # receive STB
1097 in_op
.ack
.eq(~self
.ack
), # send ACK
1100 def chain_from(self
, in_op
, extra
=None):
1102 if extra
is not None:
1104 return [self
.v
.eq(in_op
.v
), # receive value
1105 self
.stb
.eq(stb
), # receive STB
1106 in_op
.ack
.eq(self
.ack
), # send ACK
1110 class FPOpOut(NextControl
):
1111 def __init__(self
, width
):
1112 NextControl
.__init
__(self
)
1119 def chain_inv(self
, in_op
, extra
=None):
1121 if extra
is not None:
1123 return [self
.v
.eq(in_op
.v
), # receive value
1124 self
.stb
.eq(stb
), # receive STB
1125 in_op
.ack
.eq(~self
.ack
), # send ACK
1128 def chain_from(self
, in_op
, extra
=None):
1130 if extra
is not None:
1132 return [self
.v
.eq(in_op
.v
), # receive value
1133 self
.stb
.eq(stb
), # receive STB
1134 in_op
.ack
.eq(self
.ack
), # send ACK
1139 # TODO: change FFLAGS to be FPSCR's status flags
1140 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
1141 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
1142 FFLAGS_OF
= Const(1<<2, 5) # overflow
1143 FFLAGS_UF
= Const(1<<1, 5) # underflow
1144 FFLAGS_NX
= Const(1<<0, 5) # inexact
1145 def __init__(self
, name
=None):
1148 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
1149 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
1150 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
1151 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
1152 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
1154 self
.sign
= Signal(reset_less
=True, name
=name
+"sign")
1155 """sign bit -- 1 means negative, 0 means positive"""
1157 self
.rm
= Signal(FPRoundingMode
, name
=name
+"rm",
1158 reset
=FPRoundingMode
.DEFAULT
)
1161 #self.roundz = Signal(reset_less=True)
1165 yield self
.round_bit
1173 return [self
.guard
.eq(inp
.guard
),
1174 self
.round_bit
.eq(inp
.round_bit
),
1175 self
.sticky
.eq(inp
.sticky
),
1177 self
.fpflags
.eq(inp
.fpflags
),
1178 self
.sign
.eq(inp
.sign
),
1182 def roundz_rne(self
):
1183 """true if the mantissa should be rounded up for `rm == RNE`
1185 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`
1187 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
1190 def roundz_rna(self
):
1191 """true if the mantissa should be rounded up for `rm == RNA`
1193 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
1198 def roundz_rtn(self
):
1199 """true if the mantissa should be rounded up for `rm == RTN`
1201 assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`
1203 return self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1206 def roundz_rto(self
):
1207 """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`
1209 assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
1210 or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`
1212 return ~self
.m0
& (self
.guard | self
.round_bit | self
.sticky
)
1215 def roundz_rtp(self
):
1216 """true if the mantissa should be rounded up for `rm == RTP`
1218 assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`
1220 return ~self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1223 def roundz_rtz(self
):
1224 """true if the mantissa should be rounded up for `rm == RTZ`
1226 assumes the rounding mode is `ROUND_TOWARDS_ZERO`
1232 """true if the mantissa should be rounded up for the current rounding
1236 FPRoundingMode
.RNA
: self
.roundz_rna
,
1237 FPRoundingMode
.RNE
: self
.roundz_rne
,
1238 FPRoundingMode
.RTN
: self
.roundz_rtn
,
1239 FPRoundingMode
.RTOP
: self
.roundz_rto
,
1240 FPRoundingMode
.RTON
: self
.roundz_rto
,
1241 FPRoundingMode
.RTP
: self
.roundz_rtp
,
1242 FPRoundingMode
.RTZ
: self
.roundz_rtz
,
1244 return FPRoundingMode
.make_array(lambda rm
: d
[rm
])[self
.rm
]
1247 class OverflowMod(Elaboratable
, Overflow
):
1248 def __init__(self
, name
=None):
1249 Overflow
.__init
__(self
, name
)
1252 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
1255 yield from Overflow
.__iter
__(self
)
1256 yield self
.roundz_out
1259 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
1261 def elaborate(self
, platform
):
1263 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
) # roundz is a property
1268 """ IEEE754 Floating Point Base Class
1270 contains common functions for FP manipulation, such as
1271 extracting and packing operands, normalisation, denormalisation,
1275 def get_op(self
, m
, op
, v
, next_state
):
1276 """ this function moves to the next state and copies the operand
1277 when both stb and ack are 1.
1278 acknowledgement is sent by setting ack to ZERO.
1282 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
1284 # op is latched in from FPNumIn class on same ack/stb
1285 m
.d
.comb
+= ack
.eq(0)
1287 m
.d
.comb
+= ack
.eq(1)
1290 def denormalise(self
, m
, a
):
1291 """ denormalises a number. this is probably the wrong name for
1292 this function. for normalised numbers (exponent != minimum)
1293 one *extra* bit (the implicit 1) is added *back in*.
1294 for denormalised numbers, the mantissa is left alone
1295 and the exponent increased by 1.
1297 both cases *effectively multiply the number stored by 2*,
1298 which has to be taken into account when extracting the result.
1300 with m
.If(a
.exp_n127
):
1301 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
1303 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
1305 def op_normalise(self
, m
, op
, next_state
):
1306 """ operand normalisation
1307 NOTE: just like "align", this one keeps going round every clock
1308 until the result's exponent is within acceptable "range"
1310 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
1312 op
.e
.eq(op
.e
- 1), # DECREASE exponent
1313 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
1318 def normalise_1(self
, m
, z
, of
, next_state
):
1319 """ first stage normalisation
1321 NOTE: just like "align", this one keeps going round every clock
1322 until the result's exponent is within acceptable "range"
1323 NOTE: the weirdness of reassigning guard and round is due to
1324 the extra mantissa bits coming from tot[0..2]
1326 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
1328 z
.e
.eq(z
.e
- 1), # DECREASE exponent
1329 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1330 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1331 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1332 of
.round_bit
.eq(0), # reset round bit
1338 def normalise_2(self
, m
, z
, of
, next_state
):
1339 """ second stage normalisation
1341 NOTE: just like "align", this one keeps going round every clock
1342 until the result's exponent is within acceptable "range"
1343 NOTE: the weirdness of reassigning guard and round is due to
1344 the extra mantissa bits coming from tot[0..2]
1346 with m
.If(z
.e
< z
.fp
.N126
):
1348 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1349 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1350 of
.guard
.eq(z
.m
[0]),
1352 of
.round_bit
.eq(of
.guard
),
1353 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1358 def roundz(self
, m
, z
, roundz
):
1359 """ performs rounding on the output. TODO: different kinds of rounding
1362 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1363 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1364 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1366 def corrections(self
, m
, z
, next_state
):
1367 """ denormalisation and sign-bug corrections
1370 # denormalised, correct exponent to zero
1371 with m
.If(z
.is_denormalised
):
1372 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1374 def pack(self
, m
, z
, next_state
):
1375 """ packs the result into the output (detects overflow->Inf)
1378 # if overflow occurs, return inf
1379 with m
.If(z
.is_overflowed
):
1380 m
.d
.sync
+= z
.inf(z
.s
)
1382 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1384 def put_z(self
, m
, z
, out_z
, next_state
):
1385 """ put_z: stores the result in the output. raises stb and waits
1386 for ack to be set to 1 before moving to the next state.
1387 resets stb back to zero when that occurs, as acknowledgement.
1392 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1393 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1396 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1399 class FPState(FPBase
):
1400 def __init__(self
, state_from
):
1401 self
.state_from
= state_from
1403 def set_inputs(self
, inputs
):
1404 self
.inputs
= inputs
1405 for k
, v
in inputs
.items():
1408 def set_outputs(self
, outputs
):
1409 self
.outputs
= outputs
1410 for k
, v
in outputs
.items():
1415 def __init__(self
, id_wid
):
1416 self
.id_wid
= id_wid
1418 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1419 self
.out_mid
= Signal(id_wid
, reset_less
=True)
1424 def idsync(self
, m
):
1425 if self
.id_wid
is not None:
1426 m
.d
.sync
+= self
.out_mid
.eq(self
.in_mid
)
1429 if __name__
== '__main__':