1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
9 from nmigen
import (Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
, Array
,
10 Value
, Shape
, signed
, unsigned
)
11 from nmigen
.utils
import bits_for
12 from operator
import or_
13 from functools
import reduce
15 from nmutil
.singlepipe
import PrevControl
, NextControl
16 from nmutil
.pipeline
import ObjectProxy
22 from nmigen
.hdl
.smtlib2
import RoundingModeEnum
27 # value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
31 class FPRoundingMode(enum
.Enum
):
32 # matches the FPSCR.RN field values, but includes some extra
33 # values (>= 0b100) used in miscellaneous instructions.
35 # naming matches smtlib2 names, doc strings are the OpenPower ISA
36 # specification's names (v3.1 section 7.3.2.6 --
37 # matches values in section 4.3.6).
39 """Round to Nearest Even
41 Rounds to the nearest representable floating-point number, ties are
42 rounded to the number with the even mantissa. Treats +-Infinity as if
43 it were a normalized floating-point number when deciding which number
44 is closer when rounding. See IEEE754 spec. for details.
47 ROUND_NEAREST_TIES_TO_EVEN
= RNE
53 If the result is exactly representable as a floating-point number, return
54 that, otherwise return the nearest representable floating-point value
55 with magnitude smaller than the exact answer.
58 ROUND_TOWARDS_ZERO
= RTZ
61 """Round towards +Infinity
63 If the result is exactly representable as a floating-point number, return
64 that, otherwise return the nearest representable floating-point value
65 that is numerically greater than the exact answer. This can round up to
69 ROUND_TOWARDS_POSITIVE
= RTP
72 """Round towards -Infinity
74 If the result is exactly representable as a floating-point number, return
75 that, otherwise return the nearest representable floating-point value
76 that is numerically less than the exact answer. This can round down to
80 ROUND_TOWARDS_NEGATIVE
= RTN
83 """Round to Nearest Away
85 Rounds to the nearest representable floating-point number, ties are
86 rounded to the number with the maximum magnitude. Treats +-Infinity as if
87 it were a normalized floating-point number when deciding which number
88 is closer when rounding. See IEEE754 spec. for details.
91 ROUND_NEAREST_TIES_TO_AWAY
= RNA
94 """Round to Odd, unsigned zeros are Positive
98 If the result is exactly representable as a floating-point number, return
99 that, otherwise return the nearest representable floating-point value
100 that has an odd mantissa.
102 If the result is zero but with otherwise undetermined sign
103 (e.g. `1.0 - 1.0`), the sign is positive.
105 This rounding mode is used for instructions with Round To Odd enabled,
106 and `FPSCR.RN != RTN`.
108 This is useful to avoid double-rounding errors when doing arithmetic in a
109 larger type (e.g. f128) but where the answer should be a smaller type
113 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE
= RTOP
116 """Round to Odd, unsigned zeros are Negative
120 If the result is exactly representable as a floating-point number, return
121 that, otherwise return the nearest representable floating-point value
122 that has an odd mantissa.
124 If the result is zero but with otherwise undetermined sign
125 (e.g. `1.0 - 1.0`), the sign is negative.
127 This rounding mode is used for instructions with Round To Odd enabled,
128 and `FPSCR.RN == RTN`.
130 This is useful to avoid double-rounding errors when doing arithmetic in a
131 larger type (e.g. f128) but where the answer should be a smaller type
135 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE
= RTON
139 l
= [None] * len(FPRoundingMode
)
140 for rm
in FPRoundingMode
:
144 def overflow_rounds_to_inf(self
, sign
):
145 """returns true if an overflow should round to `inf`,
146 false if it should round to `max_normal`
148 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
149 if self
is FPRoundingMode
.RNE
:
151 elif self
is FPRoundingMode
.RTZ
:
153 elif self
is FPRoundingMode
.RTP
:
155 elif self
is FPRoundingMode
.RTN
:
157 elif self
is FPRoundingMode
.RNA
:
159 elif self
is FPRoundingMode
.RTOP
:
162 assert self
is FPRoundingMode
.RTON
165 def underflow_rounds_to_zero(self
, sign
):
166 """returns true if an underflow should round to `zero`,
167 false if it should round to `min_denormal`
169 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
170 if self
is FPRoundingMode
.RNE
:
172 elif self
is FPRoundingMode
.RTZ
:
174 elif self
is FPRoundingMode
.RTP
:
176 elif self
is FPRoundingMode
.RTN
:
178 elif self
is FPRoundingMode
.RNA
:
180 elif self
is FPRoundingMode
.RTOP
:
183 assert self
is FPRoundingMode
.RTON
187 """which sign an exact zero result should have when it isn't
188 otherwise determined, e.g. for `1.0 - 1.0`.
190 if self
is FPRoundingMode
.RNE
:
192 elif self
is FPRoundingMode
.RTZ
:
194 elif self
is FPRoundingMode
.RTP
:
196 elif self
is FPRoundingMode
.RTN
:
198 elif self
is FPRoundingMode
.RNA
:
200 elif self
is FPRoundingMode
.RTOP
:
203 assert self
is FPRoundingMode
.RTON
207 def to_smtlib2(self
, default
=_raise_err
):
208 """return the corresponding smtlib2 rounding mode for `self`. If
209 there is no corresponding smtlib2 rounding mode, then return
210 `default` if specified, else raise `ValueError`.
212 if self
is FPRoundingMode
.RNE
:
213 return RoundingModeEnum
.RNE
214 elif self
is FPRoundingMode
.RTZ
:
215 return RoundingModeEnum
.RTZ
216 elif self
is FPRoundingMode
.RTP
:
217 return RoundingModeEnum
.RTP
218 elif self
is FPRoundingMode
.RTN
:
219 return RoundingModeEnum
.RTN
220 elif self
is FPRoundingMode
.RNA
:
221 return RoundingModeEnum
.RNA
223 assert self
in (FPRoundingMode
.RTOP
, FPRoundingMode
.RTON
)
224 if default
is _raise_err
:
226 "no corresponding smtlib2 rounding mode", self
)
233 """ Class describing binary floating-point formats based on IEEE 754.
235 :attribute e_width: the number of bits in the exponent field.
236 :attribute m_width: the number of bits stored in the mantissa
238 :attribute has_int_bit: if the FP format has an explicit integer bit (like
239 the x87 80-bit format). The bit is considered part of the mantissa.
240 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
241 Image/Buffer formats are FP numbers without a sign bit.)
249 """ Create ``FPFormat`` instance. """
250 self
.e_width
= e_width
251 self
.m_width
= m_width
252 self
.has_int_bit
= has_int_bit
253 self
.has_sign
= has_sign
def __eq__(self, other):
    """Compare two formats field by field.

    Returns ``NotImplemented`` when ``other`` is not an ``FPFormat`` so
    that Python can fall back to the reflected comparison.  Otherwise
    the formats are equal when ``e_width``, ``m_width``, ``has_int_bit``
    and ``has_sign`` all match.
    """
    if not isinstance(other, FPFormat):
        return NotImplemented
    # Checked in the same order as the fields are declared; stop at the
    # first mismatch.
    for field in ("e_width", "m_width", "has_int_bit", "has_sign"):
        if not (getattr(self, field) == getattr(other, field)):
            return False
    return True
266 """ Get standard IEEE 754-2008 format.
268 :param width: bit-width of requested format.
269 :returns: the requested ``FPFormat`` instance.
272 return FPFormat(5, 10)
274 return FPFormat(8, 23)
276 return FPFormat(11, 52)
278 return FPFormat(15, 112)
279 if width
> 128 and width
% 32 == 0:
280 if width
> 1000000: # arbitrary upper limit
281 raise ValueError("width too big")
282 e_width
= round(4 * math
.log2(width
)) - 13
283 return FPFormat(e_width
, width
- 1 - e_width
)
284 raise ValueError("width must be the bit-width of a valid IEEE"
285 " 754-2008 binary format")
290 if self
== self
.standard(self
.width
):
291 return f
"FPFormat.standard({self.width})"
294 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
295 if self
.has_int_bit
is not False:
296 retval
+= f
", {self.has_int_bit}"
297 if self
.has_sign
is not True:
298 retval
+= f
", {self.has_sign}"
def get_sign_field(self, x):
    """Extract the sign bit of ``x``.

    The value is shifted right by the combined width of the exponent and
    mantissa fields.  Assumes the format is signed (``has_sign=True``);
    nothing above the sign position is masked off.
    """
    sign_position = self.e_width + self.m_width
    return x >> sign_position
307 def get_exponent_field(self
, x
):
308 """ returns the raw exponent of its input number, x (no bias subtracted)
310 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
def get_exponent(self, x):
    """Return the exponent of ``x`` with the bias removed.

    The raw field is read through ``get_exponent_field`` and
    ``exponent_bias`` is subtracted from it.
    """
    raw = self.get_exponent_field(x)
    is_unsigned_hdl_value = isinstance(raw, Value) and not raw.shape().signed
    if is_unsigned_hdl_value:
        # OR with a signed zero constant: converts the value to signed
        # without changing it, since exponents can be negative.
        raw |= Const(0, signed(1))
    return raw - self.exponent_bias
def get_mantissa_field(self, x):
    """Return the mantissa field of ``x``.

    The field is isolated by AND-ing ``x`` with ``mantissa_mask``.
    """
    mask = self.mantissa_mask
    return x & mask
328 def get_mantissa_value(self
, x
):
329 """ returns the mantissa of its input number, x, but with the
330 implicit bit, if any, made explicit.
333 return self
.get_mantissa_field(x
)
334 exponent_field
= self
.get_exponent_field(x
)
335 mantissa_field
= self
.get_mantissa_field(x
)
336 implicit_bit
= exponent_field
!= self
.exponent_denormal_zero
337 return (implicit_bit
<< self
.fraction_width
) | mantissa_field
def is_zero(self, x):
    """Return true if ``x`` is +/- zero.

    That is: the exponent equals ``e_sub`` and the mantissa field is zero.
    """
    exponent_at_minimum = self.get_exponent(x) == self.e_sub
    mantissa_empty = self.get_mantissa_field(x) == 0
    return exponent_at_minimum & mantissa_empty
def is_subnormal(self, x):
    """Return true if ``x`` is subnormal.

    That is: the exponent equals ``e_sub`` (its minimum) and the mantissa
    field is non-zero.
    """
    exponent_at_minimum = self.get_exponent(x) == self.e_sub
    mantissa_present = self.get_mantissa_field(x) != 0
    return exponent_at_minimum & mantissa_present
352 """ returns true if x is infinite
354 return (self
.get_exponent(x
) == self
.e_max
) & \
355 (self
.get_mantissa_field(x
) == 0)
358 """ returns true if x is a nan (quiet or signalling)
360 return (self
.get_exponent(x
) == self
.e_max
) & \
361 (self
.get_mantissa_field(x
) != 0)
def is_quiet_nan(self, x):
    """Return true if ``x`` is a quiet NaN.

    Requires the exponent to equal ``e_max``, a non-zero mantissa field,
    and the most significant mantissa bit to be set.
    """
    quiet_bit = 1 << (self.m_width - 1)
    exponent_is_max = self.get_exponent(x) == self.e_max
    mantissa = self.get_mantissa_field(x)
    quiet_bit_set = (mantissa & quiet_bit) != 0
    return exponent_is_max & (mantissa != 0) & quiet_bit_set
def to_quiet_nan(self, x):
    """Convert ``x`` to a quiet NaN.

    ORs the most significant mantissa bit and ``exponent_mask`` into ``x``;
    all other bits of ``x`` are passed through unchanged.
    """
    quiet_bit = 1 << (self.m_width - 1)
    with_quiet_bit = x | quiet_bit
    return with_quiet_bit | self.exponent_mask
def quiet_nan(self, sign=0):
    """Return the default quiet NaN with sign ``sign``.

    Built by quieting a zero of the requested sign.
    """
    signed_zero = self.zero(sign)
    return self.to_quiet_nan(signed_zero)
def zero(self, sign=0):
    """Return zero with sign ``sign``.

    Any non-zero ``sign`` sets the bit just above the exponent and mantissa
    fields; every other bit is zero.
    """
    is_negative = sign != 0
    sign_position = self.e_width + self.m_width
    return is_negative << sign_position
def inf(self, sign=0):
    """Return infinity with sign ``sign``.

    A zero of the requested sign with ``exponent_mask`` OR-ed in.
    """
    signed_zero = self.zero(sign)
    return signed_zero | self.exponent_mask
def is_nan_signaling(self, x):
    """Return true if ``x`` is a signalling NaN.

    A signalling NaN has its exponent equal to ``e_max``, a non-zero
    mantissa field, and the most significant mantissa bit clear (that bit
    being set is what ``is_quiet_nan`` tests for).
    """
    highbit = 1 << (self.m_width - 1)
    exponent_is_max = self.get_exponent(x) == self.e_max
    mantissa = self.get_mantissa_field(x)
    # The comparison must be parenthesised: `&` binds tighter than `==`,
    # so `a & b & (m & highbit) == 0` compares the *whole* conjunction
    # against zero and comes out true for almost every input, including
    # quiet NaNs and ordinary numbers.
    quiet_bit_clear = (mantissa & highbit) == 0
    return exponent_is_max & (mantissa != 0) & quiet_bit_clear
398 """ Get the total number of bits in the FP format. """
399 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """Return a mask with the low ``m_width`` bits set."""
    one_past_mantissa = 1 << self.m_width
    return one_past_mantissa - 1
def exponent_mask(self):
    """Return the exponent mask.

    This is ``exponent_inf_nan`` moved up past the ``m_width`` mantissa
    bits.
    """
    all_ones_exponent = self.exponent_inf_nan
    return all_ones_exponent << self.m_width
def exponent_inf_nan(self):
    """Return the exponent field value designating infinity/NaN.

    This is the value with all ``e_width`` bits set.
    """
    one_past_exponent = 1 << self.e_width
    return one_past_exponent - 1
418 """ get the maximum exponent (minus bias)
420 return self
.exponent_inf_nan
- self
.exponent_bias
424 return self
.exponent_denormal_zero
- self
.exponent_bias
426 def exponent_denormal_zero(self
):
427 """ Get the value of the exponent field designating denormal/zero. """
431 def exponent_min_normal(self
):
432 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """Return the largest exponent field value used by normal numbers.

    This is one below the infinity/NaN encoding.
    """
    inf_nan_encoding = self.exponent_inf_nan
    return inf_nan_encoding - 1
def exponent_bias(self):
    """Return the exponent bias: ``2 ** (e_width - 1) - 1``."""
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """Return the number of mantissa bits that are fraction bits.

    This is ``m_width`` less the explicit integer bit, when the format has
    one.
    """
    integer_bits = self.has_int_bit
    return self.m_width - integer_bits
451 class TestFPFormat(unittest
.TestCase
):
452 """ very quick test for FPFormat
455 def test_fpformat_fp64(self
):
456 f64
= FPFormat
.standard(64)
457 from sfpy
import Float64
458 x
= Float64(1.0).bits
461 self
.assertEqual(f64
.get_exponent(x
), 0)
462 x
= Float64(2.0).bits
464 self
.assertEqual(f64
.get_exponent(x
), 1)
466 x
= Float64(1.5).bits
467 m
= f64
.get_mantissa_field(x
)
468 print (hex(x
), hex(m
))
469 self
.assertEqual(m
, 0x8000000000000)
471 s
= f64
.get_sign_field(x
)
472 print (hex(x
), hex(s
))
473 self
.assertEqual(s
, 0)
475 x
= Float64(-1.5).bits
476 s
= f64
.get_sign_field(x
)
477 print (hex(x
), hex(s
))
478 self
.assertEqual(s
, 1)
480 def test_fpformat_fp32(self
):
481 f32
= FPFormat
.standard(32)
482 from sfpy
import Float32
483 x
= Float32(1.0).bits
486 self
.assertEqual(f32
.get_exponent(x
), 0)
487 x
= Float32(2.0).bits
489 self
.assertEqual(f32
.get_exponent(x
), 1)
491 x
= Float32(1.5).bits
492 m
= f32
.get_mantissa_field(x
)
493 print (hex(x
), hex(m
))
494 self
.assertEqual(m
, 0x400000)
497 x
= Float32(-1.0).sqrt()
500 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
501 f32
.get_mantissa_field(x
), i
)
502 self
.assertEqual(i
, True)
505 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
508 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
509 f32
.get_mantissa_field(x
), i
)
510 self
.assertEqual(i
, True)
515 i
= f32
.is_subnormal(x
)
516 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
517 f32
.get_mantissa_field(x
), i
)
518 self
.assertEqual(i
, True)
522 i
= f32
.is_subnormal(x
)
523 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
524 f32
.get_mantissa_field(x
), i
)
525 self
.assertEqual(i
, False)
529 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
530 f32
.get_mantissa_field(x
), i
)
531 self
.assertEqual(i
, True)
534 class MultiShiftR(Elaboratable
):
536 def __init__(self
, width
):
538 self
.smax
= bits_for(width
- 1)
539 self
.i
= Signal(width
, reset_less
=True)
540 self
.s
= Signal(self
.smax
, reset_less
=True)
541 self
.o
= Signal(width
, reset_less
=True)
543 def elaborate(self
, platform
):
545 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
550 """ Generates variable-length single-cycle shifter from a series
551 of conditional tests on each bit of the left/right shift operand.
552 Each bit tested produces output shifted by that number of bits,
553 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
554 shifts by 2 bits, each partial result cascading to the next Mux.
556 Could be adapted to do arithmetic shift by taking copies of the
557 MSB instead of zeros.
560 def __init__(self
, width
):
562 self
.smax
= bits_for(width
- 1)
564 def lshift(self
, op
, s
):
568 def rshift(self
, op
, s
):
573 class FPNumBaseRecord
:
574 """ Floating-point Base Number Class.
576 This class is designed to be passed around in other data structures
577 (between pipelines and between stages). Its "friend" is FPNumBase,
578 which is a *module*. The reason for the discernment is because
579 nmigen modules that are not added to submodules results in the
580 irritating "Elaboration" warning. Despite not *needing* FPNumBase
581 in many cases to be added as a submodule (because it is just data)
582 this was not possible to solve without splitting out the data from
586 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
589 # assert false, "missing name"
593 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
594 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
595 e_max
= 1 << (e_width
-3)
596 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
599 # mantissa extra bits (top,guard,round)
601 m_width
+= self
.m_extra
605 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
606 e_width
+= self
.e_extra
609 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
610 self
.m_width
= m_width
611 self
.e_width
= e_width
612 self
.e_start
= self
.rmw
613 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
615 self
.v
= Signal(width
, reset_less
=True,
616 name
=name
+"v") # Latched copy of value
617 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
618 self
.e
= Signal(signed(e_width
),
619 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
620 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
625 def drop_in(self
, fp
):
631 fp
.width
= self
.width
632 fp
.e_width
= self
.e_width
633 fp
.e_max
= self
.e_max
634 fp
.m_width
= self
.m_width
635 fp
.e_start
= self
.e_start
636 fp
.e_end
= self
.e_end
637 fp
.m_extra
= self
.m_extra
639 m_width
= self
.m_width
641 e_width
= self
.e_width
643 self
.mzero
= Const(0, unsigned(m_width
))
644 m_msb
= 1 << (self
.m_width
-2)
645 self
.msb1
= Const(m_msb
, unsigned(m_width
))
646 self
.m1s
= Const(-1, unsigned(m_width
))
647 self
.P128
= Const(e_max
, signed(e_width
))
648 self
.P127
= Const(e_max
-1, signed(e_width
))
649 self
.N127
= Const(-(e_max
-1), signed(e_width
))
650 self
.N126
= Const(-(e_max
-2), signed(e_width
))
652 def create(self
, s
, e
, m
):
653 """ creates a value from sign / exponent / mantissa
655 bias is added here, to the exponent.
657 NOTE: order is important, because e_start/e_end can be
658 a bit too long (overwriting s).
661 self
.v
[0:self
.e_start
].eq(m
), # mantissa
662 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
663 self
.v
[-1].eq(s
), # sign
667 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
670 return (s
, self
.fp
.P128
, 0)
673 return (s
, self
.fp
.N127
, 0)
676 return self
.create(*self
._nan
(s
))
def quieted_nan(self, other):
    """Build a quiet NaN that carries over the sign and low bits of ``other``.

    ``other`` must be an ``FPNumBaseRecord`` of the same ``width``.  The
    mantissa passed to ``create`` is the low ``e_start`` bits of
    ``other.v`` with bit ``e_start - 1`` forced on; the exponent is
    ``fp.P128``.
    """
    assert isinstance(other, FPNumBaseRecord)
    assert self.width == other.width
    quiet_bit = 1 << (self.e_start - 1)
    payload = other.v[0:self.e_start] | quiet_bit
    return self.create(other.s, self.fp.P128, payload)
685 return self
.create(*self
._inf
(s
))
def max_normal(self, s):
    """Create a value with sign ``s``, exponent ``fp.P127`` and an all-ones
    mantissa argument (``~0``), via ``create``.
    """
    all_ones = ~0
    return self.create(s, self.fp.P127, all_ones)
def min_denormal(self, s):
    """Create a value with sign ``s``, exponent ``fp.N127`` and a mantissa
    of 1, via ``create``.
    """
    lowest_exponent = self.fp.N127
    return self.create(s, lowest_exponent, 1)
694 return self
.create(*self
._zero
(s
))
696 def create2(self
, s
, e
, m
):
697 """ creates a value from sign / exponent / mantissa
699 bias is added here, to the exponent
701 e
= e
+ self
.P127
# exp (add on bias)
702 return Cat(m
[0:self
.e_start
],
703 e
[0:self
.e_end
-self
.e_start
],
707 return self
.create2(s
, self
.P128
, self
.msb1
)
710 return self
.create2(s
, self
.P128
, self
.mzero
)
713 return self
.create2(s
, self
.N127
, self
.mzero
)
721 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
724 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
725 """ Floating-point Base Number Class
728 def __init__(self
, fp
):
733 self
.is_nan
= Signal(reset_less
=True)
734 self
.is_zero
= Signal(reset_less
=True)
735 self
.is_inf
= Signal(reset_less
=True)
736 self
.is_overflowed
= Signal(reset_less
=True)
737 self
.is_denormalised
= Signal(reset_less
=True)
738 self
.exp_128
= Signal(reset_less
=True)
739 self
.exp_sub_n126
= Signal(signed(e_width
), reset_less
=True)
740 self
.exp_lt_n126
= Signal(reset_less
=True)
741 self
.exp_zero
= Signal(reset_less
=True)
742 self
.exp_gt_n126
= Signal(reset_less
=True)
743 self
.exp_gt127
= Signal(reset_less
=True)
744 self
.exp_n127
= Signal(reset_less
=True)
745 self
.exp_n126
= Signal(reset_less
=True)
746 self
.m_zero
= Signal(reset_less
=True)
747 self
.m_msbzero
= Signal(reset_less
=True)
749 def elaborate(self
, platform
):
751 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
752 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
753 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
754 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
755 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
756 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
757 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
758 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
759 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
760 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
761 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
762 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
763 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
764 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
765 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
770 return (self
.exp_128
) & (~self
.m_zero
)
773 return (self
.exp_128
) & (self
.m_zero
)
776 return (self
.exp_n127
) & (self
.m_zero
)
778 def _is_overflowed(self
):
779 return self
.exp_gt127
781 def _is_denormalised(self
):
782 # XXX NOT to be used for "official" quiet NaN tests!
783 # particularly when the MSB has been extended
784 return (self
.exp_n126
) & (self
.m_msbzero
)
787 class FPNumOut(FPNumBase
):
788 """ Floating-point Number Class
790 Contains signals for an incoming copy of the value, decoded into
791 sign / exponent / mantissa.
792 Also contains encoding functions, creation and recognition of
793 zero, NaN and inf (all signed)
795 Four extra bits are included in the mantissa: the top bit
796 (m[-1]) is effectively a carry-overflow. The other three are
797 guard (m[2]), round (m[1]), and sticky (m[0])
800 def __init__(self
, fp
):
801 FPNumBase
.__init
__(self
, fp
)
803 def elaborate(self
, platform
):
804 m
= FPNumBase
.elaborate(self
, platform
)
809 class MultiShiftRMerge(Elaboratable
):
810 """ shifts down (right) and merges lower bits into m[0].
811 m[0] is the "sticky" bit, basically
814 def __init__(self
, width
, s_max
=None):
816 s_max
= bits_for(width
- 1)
817 self
.smax
= Shape
.cast(s_max
)
818 self
.m
= Signal(width
, reset_less
=True)
819 self
.inp
= Signal(width
, reset_less
=True)
820 self
.diff
= Signal(s_max
, reset_less
=True)
823 def elaborate(self
, platform
):
826 rs
= Signal(self
.width
, reset_less
=True)
827 m_mask
= Signal(self
.width
, reset_less
=True)
828 smask
= Signal(self
.width
, reset_less
=True)
829 stickybit
= Signal(reset_less
=True)
830 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
831 maxslen
= Signal(self
.smax
.width
, reset_less
=True)
832 maxsleni
= Signal(self
.smax
.width
, reset_less
=True)
834 sm
= MultiShift(self
.width
-1)
835 m0s
= Const(0, self
.width
-1)
836 mw
= Const(self
.width
-1, len(self
.diff
))
837 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
838 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
842 # shift mantissa by maxslen, mask by inverse
843 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
844 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
845 smask
.eq(self
.inp
[1:] & m_mask
),
846 # sticky bit combines all mask (and mantissa low bit)
847 stickybit
.eq(smask
.bool() | self
.inp
[0]),
848 # mantissa result contains m[0] already.
849 self
.m
.eq(Cat(stickybit
, rs
))
854 class FPNumShift(FPNumBase
, Elaboratable
):
855 """ Floating-point Number Class for shifting
858 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
859 FPNumBase
.__init
__(self
, width
, m_extra
)
860 self
.latch_in
= Signal()
865 def elaborate(self
, platform
):
866 m
= FPNumBase
.elaborate(self
, platform
)
868 m
.d
.comb
+= self
.s
.eq(op
.s
)
869 m
.d
.comb
+= self
.e
.eq(op
.e
)
870 m
.d
.comb
+= self
.m
.eq(op
.m
)
872 with self
.mainm
.State("align"):
873 with m
.If(self
.e
< self
.inv
.e
):
874 m
.d
.sync
+= self
.shift_down()
878 def shift_down(self
, inp
):
879 """ shifts a mantissa down by one. exponent is increased to compensate
881 accuracy is lost as a result in the mantissa however there are 3
882 guard bits (the latter of which is the "sticky" bit)
884 return [self
.e
.eq(inp
.e
+ 1),
885 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
888 def shift_down_multi(self
, diff
):
889 """ shifts a mantissa down. exponent is increased to compensate
891 accuracy is lost as a result in the mantissa however there are 3
892 guard bits (the latter of which is the "sticky" bit)
894 this code works by variable-shifting the mantissa by up to
895 its maximum bit-length: no point doing more (it'll still be
898 the sticky bit is computed by shifting a batch of 1s by
899 the same amount, which will introduce zeros. it's then
900 inverted and used as a mask to get the LSBs of the mantissa.
901 those are then |'d into the sticky bit.
903 sm
= MultiShift(self
.width
)
904 mw
= Const(self
.m_width
-1, len(diff
))
905 maxslen
= Mux(diff
> mw
, mw
, diff
)
906 rs
= sm
.rshift(self
.m
[1:], maxslen
)
907 maxsleni
= mw
- maxslen
908 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
910 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
911 return [self
.e
.eq(self
.e
+ diff
),
912 self
.m
.eq(Cat(stickybits
, rs
))
915 def shift_up_multi(self
, diff
):
916 """ shifts a mantissa up. exponent is decreased to compensate
918 sm
= MultiShift(self
.width
)
919 mw
= Const(self
.m_width
, len(diff
))
920 maxslen
= Mux(diff
> mw
, mw
, diff
)
922 return [self
.e
.eq(self
.e
- diff
),
923 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
927 class FPNumDecode(FPNumBase
):
928 """ Floating-point Number Class
930 Contains signals for an incoming copy of the value, decoded into
931 sign / exponent / mantissa.
932 Also contains encoding functions, creation and recognition of
933 zero, NaN and inf (all signed)
935 Four extra bits are included in the mantissa: the top bit
936 (m[-1]) is effectively a carry-overflow. The other three are
937 guard (m[2]), round (m[1]), and sticky (m[0])
940 def __init__(self
, op
, fp
):
941 FPNumBase
.__init
__(self
, fp
)
944 def elaborate(self
, platform
):
945 m
= FPNumBase
.elaborate(self
, platform
)
947 m
.d
.comb
+= self
.decode(self
.v
)
952 """ decodes a latched value into sign / exponent / mantissa
954 bias is subtracted here, from the exponent. exponent
955 is extended to 10 bits so that subtract 127 is done on
958 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
959 #print ("decode", self.e_end)
960 return [self
.m
.eq(Cat(*args
)), # mantissa
961 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
962 self
.s
.eq(v
[-1]), # sign
966 class FPNumIn(FPNumBase
):
967 """ Floating-point Number Class
969 Contains signals for an incoming copy of the value, decoded into
970 sign / exponent / mantissa.
971 Also contains encoding functions, creation and recognition of
972 zero, NaN and inf (all signed)
974 Four extra bits are included in the mantissa: the top bit
975 (m[-1]) is effectively a carry-overflow. The other three are
976 guard (m[2]), round (m[1]), and sticky (m[0])
979 def __init__(self
, op
, fp
):
980 FPNumBase
.__init
__(self
, fp
)
981 self
.latch_in
= Signal()
984 def decode2(self
, m
):
985 """ decodes a latched value into sign / exponent / mantissa
987 bias is subtracted here, from the exponent. exponent
988 is extended to 10 bits so that subtract 127 is done on
992 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
993 #print ("decode", self.e_end)
994 res
= ObjectProxy(m
, pipemode
=False)
995 res
.m
= Cat(*args
) # mantissa
996 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
1000 def decode(self
, v
):
1001 """ decodes a latched value into sign / exponent / mantissa
1003 bias is subtracted here, from the exponent. exponent
1004 is extended to 10 bits so that subtract 127 is done on
1007 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1008 #print ("decode", self.e_end)
1009 return [self
.m
.eq(Cat(*args
)), # mantissa
1010 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
1011 self
.s
.eq(v
[-1]), # sign
1014 def shift_down(self
, inp
):
1015 """ shifts a mantissa down by one. exponent is increased to compensate
1017 accuracy is lost as a result in the mantissa however there are 3
1018 guard bits (the latter of which is the "sticky" bit)
1020 return [self
.e
.eq(inp
.e
+ 1),
1021 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
1024 def shift_down_multi(self
, diff
, inp
=None):
1025 """ shifts a mantissa down. exponent is increased to compensate
1027 accuracy is lost as a result in the mantissa however there are 3
1028 guard bits (the latter of which is the "sticky" bit)
1030 this code works by variable-shifting the mantissa by up to
1031 its maximum bit-length: no point doing more (it'll still be
1034 the sticky bit is computed by shifting a batch of 1s by
1035 the same amount, which will introduce zeros. it's then
1036 inverted and used as a mask to get the LSBs of the mantissa.
1037 those are then |'d into the sticky bit.
1041 sm
= MultiShift(self
.width
)
1042 mw
= Const(self
.m_width
-1, len(diff
))
1043 maxslen
= Mux(diff
> mw
, mw
, diff
)
1044 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
1045 maxsleni
= mw
- maxslen
1046 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
1048 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1049 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
1050 return [self
.e
.eq(inp
.e
+ diff
),
1051 self
.m
.eq(Cat(stickybit
, rs
))
1054 def shift_up_multi(self
, diff
):
1055 """ shifts a mantissa up. exponent is decreased to compensate
1057 sm
= MultiShift(self
.width
)
1058 mw
= Const(self
.m_width
, len(diff
))
1059 maxslen
= Mux(diff
> mw
, mw
, diff
)
1061 return [self
.e
.eq(self
.e
- diff
),
1062 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
1066 class Trigger(Elaboratable
):
1069 self
.stb
= Signal(reset
=0)
1071 self
.trigger
= Signal(reset_less
=True)
1073 def elaborate(self
, platform
):
1075 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
1079 return [self
.stb
.eq(inp
.stb
),
1080 self
.ack
.eq(inp
.ack
)
1084 return [self
.stb
, self
.ack
]
1087 class FPOpIn(PrevControl
):
1088 def __init__(self
, width
):
1089 PrevControl
.__init
__(self
)
1096 def chain_inv(self
, in_op
, extra
=None):
1098 if extra
is not None:
1100 return [self
.v
.eq(in_op
.v
), # receive value
1101 self
.stb
.eq(stb
), # receive STB
1102 in_op
.ack
.eq(~self
.ack
), # send ACK
1105 def chain_from(self
, in_op
, extra
=None):
1107 if extra
is not None:
1109 return [self
.v
.eq(in_op
.v
), # receive value
1110 self
.stb
.eq(stb
), # receive STB
1111 in_op
.ack
.eq(self
.ack
), # send ACK
1115 class FPOpOut(NextControl
):
1116 def __init__(self
, width
):
1117 NextControl
.__init
__(self
)
1124 def chain_inv(self
, in_op
, extra
=None):
1126 if extra
is not None:
1128 return [self
.v
.eq(in_op
.v
), # receive value
1129 self
.stb
.eq(stb
), # receive STB
1130 in_op
.ack
.eq(~self
.ack
), # send ACK
1133 def chain_from(self
, in_op
, extra
=None):
1135 if extra
is not None:
1137 return [self
.v
.eq(in_op
.v
), # receive value
1138 self
.stb
.eq(stb
), # receive STB
1139 in_op
.ack
.eq(self
.ack
), # send ACK
1144 # TODO: change FFLAGS to be FPSCR's status flags
1145 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
1146 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
1147 FFLAGS_OF
= Const(1<<2, 5) # overflow
1148 FFLAGS_UF
= Const(1<<1, 5) # underflow
1149 FFLAGS_NX
= Const(1<<0, 5) # inexact
1150 def __init__(self
, name
=None):
1153 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
1154 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
1155 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
1156 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
1157 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
1159 self
.sign
= Signal(reset_less
=True, name
=name
+"sign")
1160 """sign bit -- 1 means negative, 0 means positive"""
1162 self
.rm
= Signal(FPRoundingMode
, name
=name
+"rm",
1163 reset
=FPRoundingMode
.DEFAULT
)
1166 #self.roundz = Signal(reset_less=True)
1170 yield self
.round_bit
1178 return [self
.guard
.eq(inp
.guard
),
1179 self
.round_bit
.eq(inp
.round_bit
),
1180 self
.sticky
.eq(inp
.sticky
),
1182 self
.fpflags
.eq(inp
.fpflags
),
1183 self
.sign
.eq(inp
.sign
),
def roundz_rne(self):
    """True if the mantissa should be rounded up for `rm == RNE`.

    Assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`: round up
    when the guard bit is set and any of the round bit, the sticky bit or
    mantissa bit 0 is set.
    """
    tie_breaker = self.round_bit | self.sticky | self.m0
    return self.guard & tie_breaker
1195 def roundz_rna(self
):
1196 """true if the mantissa should be rounded up for `rm == RNA`
1198 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
def roundz_rtn(self):
    """true if the mantissa should be rounded up for `rm == RTN`

    assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`

    The result is `sign & (guard | round_bit | sticky)`.
    """
    # set when any of the three extra bits below the mantissa is set
    any_extra_bit = self.guard | self.round_bit | self.sticky
    return self.sign & any_extra_bit
def roundz_rto(self):
    """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`

    assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
    or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`

    The result is `~m0 & (guard | round_bit | sticky)`.
    """
    # set when any of the three extra bits below the mantissa is set
    any_extra_bit = self.guard | self.round_bit | self.sticky
    m0_clear = ~self.m0
    return m0_clear & any_extra_bit
def roundz_rtp(self):
    """true if the mantissa should be rounded up for `rm == RTP`

    assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`

    The result is `~sign & (guard | round_bit | sticky)`.
    """
    # set when any of the three extra bits below the mantissa is set
    any_extra_bit = self.guard | self.round_bit | self.sticky
    sign_clear = ~self.sign
    return sign_clear & any_extra_bit
1228 def roundz_rtz(self
):
1229 """true if the mantissa should be rounded up for `rm == RTZ`
1231 assumes the rounding mode is `ROUND_TOWARDS_ZERO`
1237 """true if the mantissa should be rounded up for the current rounding
1241 FPRoundingMode
.RNA
: self
.roundz_rna
,
1242 FPRoundingMode
.RNE
: self
.roundz_rne
,
1243 FPRoundingMode
.RTN
: self
.roundz_rtn
,
1244 FPRoundingMode
.RTOP
: self
.roundz_rto
,
1245 FPRoundingMode
.RTON
: self
.roundz_rto
,
1246 FPRoundingMode
.RTP
: self
.roundz_rtp
,
1247 FPRoundingMode
.RTZ
: self
.roundz_rtz
,
1249 return FPRoundingMode
.make_array(lambda rm
: d
[rm
])[self
.rm
]
class OverflowMod(Elaboratable, Overflow):
    """Elaboratable form of Overflow with an extra `roundz_out` signal.

    `elaborate` combinatorially drives `roundz_out` from the `roundz`
    property.
    """

    def __init__(self, name=None):
        """Create the Overflow signals plus `roundz_out`.

        name -- optional prefix for the signal names; None means no prefix.
        """
        Overflow.__init__(self, name)
        # the name is built by string concatenation, so None must become
        # an empty prefix first (None + str raises TypeError)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name + "roundz_out")

    def __iter__(self):
        """Yield everything Overflow yields, then `roundz_out`."""
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """Return the list of assignments copying `inp` into this object."""
        # Overflow.eq is called unbound, so both `self` and the source
        # object have to be passed explicitly; leaving `inp` out makes the
        # call fail with a missing-argument TypeError.
        return ([self.roundz_out.eq(inp.roundz_out)]
                + Overflow.eq(self, inp))

    def elaborate(self, platform):
        """Build the module that drives `roundz_out`."""
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)  # roundz is a property
        return m
1273 """ IEEE754 Floating Point Base Class
1275 contains common functions for FP manipulation, such as
1276 extracting and packing operands, normalisation, denormalisation,
1280 def get_op(self
, m
, op
, v
, next_state
):
1281 """ this function moves to the next state and copies the operand
1282 when both stb and ack are 1.
1283 acknowledgement is sent by setting ack to ZERO.
1287 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
1289 # op is latched in from FPNumIn class on same ack/stb
1290 m
.d
.comb
+= ack
.eq(0)
1292 m
.d
.comb
+= ack
.eq(1)
1295 def denormalise(self
, m
, a
):
1296 """ denormalises a number. this is probably the wrong name for
1297 this function. for normalised numbers (exponent != minimum)
1298 one *extra* bit (the implicit 1) is added *back in*.
1299 for denormalised numbers, the mantissa is left alone
1300 and the exponent increased by 1.
1302 both cases *effectively multiply the number stored by 2*,
1303 which has to be taken into account when extracting the result.
1305 with m
.If(a
.exp_n127
):
1306 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
1308 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
1310 def op_normalise(self
, m
, op
, next_state
):
1311 """ operand normalisation
1312 NOTE: just like "align", this one keeps going round every clock
1313 until the result's exponent is within acceptable "range"
1315 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
1317 op
.e
.eq(op
.e
- 1), # DECREASE exponent
1318 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
1323 def normalise_1(self
, m
, z
, of
, next_state
):
1324 """ first stage normalisation
1326 NOTE: just like "align", this one keeps going round every clock
1327 until the result's exponent is within acceptable "range"
1328 NOTE: the weirdness of reassigning guard and round is due to
1329 the extra mantissa bits coming from tot[0..2]
1331 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
1333 z
.e
.eq(z
.e
- 1), # DECREASE exponent
1334 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1335 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1336 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1337 of
.round_bit
.eq(0), # reset round bit
1343 def normalise_2(self
, m
, z
, of
, next_state
):
1344 """ second stage normalisation
1346 NOTE: just like "align", this one keeps going round every clock
1347 until the result's exponent is within acceptable "range"
1348 NOTE: the weirdness of reassigning guard and round is due to
1349 the extra mantissa bits coming from tot[0..2]
1351 with m
.If(z
.e
< z
.fp
.N126
):
1353 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1354 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1355 of
.guard
.eq(z
.m
[0]),
1357 of
.round_bit
.eq(of
.guard
),
1358 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1363 def roundz(self
, m
, z
, roundz
):
1364 """ performs rounding on the output. TODO: different kinds of rounding
1367 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1368 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1369 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1371 def corrections(self
, m
, z
, next_state
):
1372 """ denormalisation and sign-bug corrections
1375 # denormalised, correct exponent to zero
1376 with m
.If(z
.is_denormalised
):
1377 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1379 def pack(self
, m
, z
, next_state
):
1380 """ packs the result into the output (detects overflow->Inf)
1383 # if overflow occurs, return inf
1384 with m
.If(z
.is_overflowed
):
1385 m
.d
.sync
+= z
.inf(z
.s
)
1387 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1389 def put_z(self
, m
, z
, out_z
, next_state
):
1390 """ put_z: stores the result in the output. raises stb and waits
1391 for ack to be set to 1 before moving to the next state.
1392 resets stb back to zero when that occurs, as acknowledgement.
1397 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1398 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1401 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1404 class FPState(FPBase
):
1405 def __init__(self
, state_from
):
1406 self
.state_from
= state_from
1408 def set_inputs(self
, inputs
):
1409 self
.inputs
= inputs
1410 for k
, v
in inputs
.items():
1413 def set_outputs(self
, outputs
):
1414 self
.outputs
= outputs
1415 for k
, v
in outputs
.items():
1420 def __init__(self
, id_wid
):
1421 self
.id_wid
= id_wid
1423 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1424 self
.out_mid
= Signal(id_wid
, reset_less
=True)
def idsync(self, m):
    """Copy `in_mid` to `out_mid` in the sync domain of `m`.

    m -- the module whose `d.sync` receives the assignment.

    Adds nothing when `id_wid` is None.
    """
    if self.id_wid is None:
        return
    m.d.sync += self.out_mid.eq(self.in_mid)
1434 if __name__
== '__main__':