1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
9 from nmigen
import (Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
, Array
,
10 Value
, Shape
, signed
, unsigned
)
11 from nmigen
.utils
import bits_for
12 from operator
import or_
13 from functools
import reduce
15 from nmutil
.singlepipe
import PrevControl
, NextControl
16 from nmutil
.pipeline
import ObjectProxy
22 from nmigen
.hdl
.smtlib2
import RoundingModeEnum
27 # value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
31 class FPRoundingMode(enum
.Enum
):
32 # matches the FPSCR.RN field values, but includes some extra
33 # values (>= 0b100) used in miscellaneous instructions.
35 # naming matches smtlib2 names, doc strings are the OpenPower ISA
36 # specification's names (v3.1 section 7.3.2.6 --
37 # matches values in section 4.3.6).
39 """Round to Nearest Even
41 Rounds to the nearest representable floating-point number, ties are
42 rounded to the number with the even mantissa. Treats +-Infinity as if
43 it were a normalized floating-point number when deciding which number
44 is closer when rounding. See IEEE754 spec. for details.
47 ROUND_NEAREST_TIES_TO_EVEN
= RNE
53 If the result is exactly representable as a floating-point number, return
54 that, otherwise return the nearest representable floating-point value
55 with magnitude smaller than the exact answer.
58 ROUND_TOWARDS_ZERO
= RTZ
61 """Round towards +Infinity
63 If the result is exactly representable as a floating-point number, return
64 that, otherwise return the nearest representable floating-point value
65 that is numerically greater than the exact answer. This can round up to
69 ROUND_TOWARDS_POSITIVE
= RTP
72 """Round towards -Infinity
74 If the result is exactly representable as a floating-point number, return
75 that, otherwise return the nearest representable floating-point value
76 that is numerically less than the exact answer. This can round down to
80 ROUND_TOWARDS_NEGATIVE
= RTN
83 """Round to Nearest Away
85 Rounds to the nearest representable floating-point number, ties are
86 rounded to the number with the maximum magnitude. Treats +-Infinity as if
87 it were a normalized floating-point number when deciding which number
88 is closer when rounding. See IEEE754 spec. for details.
91 ROUND_NEAREST_TIES_TO_AWAY
= RNA
94 """Round to Odd, unsigned zeros are Positive
98 If the result is exactly representable as a floating-point number, return
99 that, otherwise return the nearest representable floating-point value
100 that has an odd mantissa.
102 If the result is zero but with otherwise undetermined sign
103 (e.g. `1.0 - 1.0`), the sign is positive.
105 This rounding mode is used for instructions with Round To Odd enabled,
106 and `FPSCR.RN != RTN`.
108 This is useful to avoid double-rounding errors when doing arithmetic in a
109 larger type (e.g. f128) but where the answer should be a smaller type
113 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE
= RTOP
116 """Round to Odd, unsigned zeros are Negative
120 If the result is exactly representable as a floating-point number, return
121 that, otherwise return the nearest representable floating-point value
122 that has an odd mantissa.
124 If the result is zero but with otherwise undetermined sign
125 (e.g. `1.0 - 1.0`), the sign is negative.
127 This rounding mode is used for instructions with Round To Odd enabled,
128 and `FPSCR.RN == RTN`.
130 This is useful to avoid double-rounding errors when doing arithmetic in a
131 larger type (e.g. f128) but where the answer should be a smaller type
135 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE
= RTON
139 l
= [None] * len(FPRoundingMode
)
140 for rm
in FPRoundingMode
:
144 def overflow_rounds_to_inf(self
, sign
):
145 """returns true if an overflow should round to `inf`,
146 false if it should round to `max_normal`
148 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
149 if self
is FPRoundingMode
.RNE
:
151 elif self
is FPRoundingMode
.RTZ
:
153 elif self
is FPRoundingMode
.RTP
:
155 elif self
is FPRoundingMode
.RTN
:
157 elif self
is FPRoundingMode
.RNA
:
159 elif self
is FPRoundingMode
.RTOP
:
162 assert self
is FPRoundingMode
.RTON
165 def underflow_rounds_to_zero(self
, sign
):
166 """returns true if an underflow should round to `zero`,
167 false if it should round to `min_denormal`
169 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
170 if self
is FPRoundingMode
.RNE
:
172 elif self
is FPRoundingMode
.RTZ
:
174 elif self
is FPRoundingMode
.RTP
:
176 elif self
is FPRoundingMode
.RTN
:
178 elif self
is FPRoundingMode
.RNA
:
180 elif self
is FPRoundingMode
.RTOP
:
183 assert self
is FPRoundingMode
.RTON
187 """which sign an exact zero result should have when it isn't
188 otherwise determined, e.g. for `1.0 - 1.0`.
190 if self
is FPRoundingMode
.RNE
:
192 elif self
is FPRoundingMode
.RTZ
:
194 elif self
is FPRoundingMode
.RTP
:
196 elif self
is FPRoundingMode
.RTN
:
198 elif self
is FPRoundingMode
.RNA
:
200 elif self
is FPRoundingMode
.RTOP
:
203 assert self
is FPRoundingMode
.RTON
207 def to_smtlib2(self
, default
=_raise_err
):
208 """return the corresponding smtlib2 rounding mode for `self`. If
209 there is no corresponding smtlib2 rounding mode, then return
210 `default` if specified, else raise `ValueError`.
212 if self
is FPRoundingMode
.RNE
:
213 return RoundingModeEnum
.RNE
214 elif self
is FPRoundingMode
.RTZ
:
215 return RoundingModeEnum
.RTZ
216 elif self
is FPRoundingMode
.RTP
:
217 return RoundingModeEnum
.RTP
218 elif self
is FPRoundingMode
.RTN
:
219 return RoundingModeEnum
.RTN
220 elif self
is FPRoundingMode
.RNA
:
221 return RoundingModeEnum
.RNA
223 assert self
in (FPRoundingMode
.RTOP
, FPRoundingMode
.RTON
)
224 if default
is _raise_err
:
226 "no corresponding smtlib2 rounding mode", self
)
233 """ Class describing binary floating-point formats based on IEEE 754.
235 :attribute e_width: the number of bits in the exponent field.
236 :attribute m_width: the number of bits stored in the mantissa
238 :attribute has_int_bit: if the FP format has an explicit integer bit (like
239 the x87 80-bit format). The bit is considered part of the mantissa.
240 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
241 Image/Buffer formats are FP numbers without a sign bit.)
249 """ Create ``FPFormat`` instance. """
250 self
.e_width
= e_width
251 self
.m_width
= m_width
252 self
.has_int_bit
= has_int_bit
253 self
.has_sign
= has_sign
def __eq__(self, other):
    """ Check for equality.

    Two ``FPFormat`` instances are equal when their ``e_width``,
    ``m_width``, ``has_int_bit`` and ``has_sign`` attributes all match.
    Comparing against anything that is not an ``FPFormat`` returns
    ``NotImplemented`` so Python can try the reflected comparison.
    """
    if not isinstance(other, FPFormat):
        return NotImplemented
    ours = (self.e_width, self.m_width, self.has_int_bit, self.has_sign)
    theirs = (other.e_width, other.m_width, other.has_int_bit,
              other.has_sign)
    return ours == theirs
266 """ Get standard IEEE 754-2008 format.
268 :param width: bit-width of requested format.
269 :returns: the requested ``FPFormat`` instance.
272 return FPFormat(5, 10)
274 return FPFormat(8, 23)
276 return FPFormat(11, 52)
278 return FPFormat(15, 112)
279 if width
> 128 and width
% 32 == 0:
280 if width
> 1000000: # arbitrary upper limit
281 raise ValueError("width too big")
282 e_width
= round(4 * math
.log2(width
)) - 13
283 return FPFormat(e_width
, width
- 1 - e_width
)
284 raise ValueError("width must be the bit-width of a valid IEEE"
285 " 754-2008 binary format")
290 if self
== self
.standard(self
.width
):
291 return f
"FPFormat.standard({self.width})"
294 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
295 if self
.has_int_bit
is not False:
296 retval
+= f
", {self.has_int_bit}"
297 if self
.has_sign
is not True:
298 retval
+= f
", {self.has_sign}"
def get_sign_field(self, x):
    """ returns the sign bit of its input number, x

    The sign is everything above the exponent and mantissa fields,
    i.e. ``x`` shifted right by ``e_width + m_width`` bits.
    (assumes FPFormat is set to signed - has_sign=True)
    """
    sign_position = self.e_width + self.m_width
    return x >> sign_position
307 def get_exponent_field(self
, x
):
308 """ returns the raw exponent of its input number, x (no bias subtracted)
310 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
def get_exponent(self, x):
    """ returns the exponent of its input number, x

    The raw exponent field has ``exponent_bias`` subtracted from it.
    Works on plain ints as well as nmigen ``Value`` expressions.
    """
    field = self.get_exponent_field(x)
    unsigned_value = isinstance(field, Value) and not field.shape().signed
    if unsigned_value:
        # OR-ing with a 1-bit signed zero makes the expression signed
        # without changing its value; exponents can be negative once the
        # bias is removed.
        field = field | Const(0, signed(1))
    return field - self.exponent_bias
def get_exponent_value(self, x):
    """ returns the exponent of its input number, x, adjusted for the
    mathematically correct subnormal exponent.

    A raw exponent field equal to ``exponent_denormal_zero`` is bumped by
    one before ``exponent_bias`` is subtracted.
    """
    raw = self.get_exponent_field(x)
    if isinstance(raw, Value) and not raw.shape().signed:
        # widen to signed without changing the value: the unbiased
        # exponent can be negative
        raw = raw | Const(0, signed(1))
    denormal_adjust = raw == self.exponent_denormal_zero
    return raw + denormal_adjust - self.exponent_bias
def get_mantissa_field(self, x):
    """ returns the mantissa of its input number, x

    Only the stored mantissa bits (``x & mantissa_mask``) are returned;
    no implicit bit is added.
    """
    mask = self.mantissa_mask
    return x & mask
339 def get_mantissa_value(self
, x
):
340 """ returns the mantissa of its input number, x, but with the
341 implicit bit, if any, made explicit.
344 return self
.get_mantissa_field(x
)
345 exponent_field
= self
.get_exponent_field(x
)
346 mantissa_field
= self
.get_mantissa_field(x
)
347 implicit_bit
= exponent_field
!= self
.exponent_denormal_zero
348 return (implicit_bit
<< self
.fraction_width
) | mantissa_field
def is_zero(self, x):
    """ returns true if x is +/- zero

    True when the unbiased exponent equals ``e_sub`` and the stored
    mantissa field is all zeros.
    """
    exponent_matches = self.get_exponent(x) == self.e_sub
    mantissa_clear = self.get_mantissa_field(x) == 0
    return exponent_matches & mantissa_clear
def is_subnormal(self, x):
    """ returns true if x is subnormal (exp at minimum)

    True when the unbiased exponent equals ``e_sub`` and the stored
    mantissa field is non-zero.
    """
    exponent_matches = self.get_exponent(x) == self.e_sub
    mantissa_set = self.get_mantissa_field(x) != 0
    return exponent_matches & mantissa_set
363 """ returns true if x is infinite
365 return (self
.get_exponent(x
) == self
.e_max
) & \
366 (self
.get_mantissa_field(x
) == 0)
369 """ returns true if x is a nan (quiet or signalling)
371 return (self
.get_exponent(x
) == self
.e_max
) & \
372 (self
.get_mantissa_field(x
) != 0)
def is_quiet_nan(self, x):
    """ returns true if x is a quiet nan

    Requires the exponent to be ``e_max``, the stored mantissa to be
    non-zero, and the top stored mantissa bit to be set.
    """
    quiet_bit = 1 << (self.m_width - 1)
    mantissa = self.get_mantissa_field(x)
    exponent_is_max = self.get_exponent(x) == self.e_max
    return (exponent_is_max
            & (mantissa != 0)
            & ((mantissa & quiet_bit) != 0))
def to_quiet_nan(self, x):
    """ converts `x` to a quiet NaN

    Sets the top stored mantissa bit and every bit of ``exponent_mask``;
    all other bits of ``x`` are kept unchanged.
    """
    quiet_bit = 1 << (self.m_width - 1)
    with_quiet_bit = x | quiet_bit
    return with_quiet_bit | self.exponent_mask
def quiet_nan(self, sign=0):
    """ return the default quiet NaN with sign `sign`

    Built by passing a zero of the requested sign through
    ``to_quiet_nan``.
    """
    signed_zero = self.zero(sign)
    return self.to_quiet_nan(signed_zero)
def zero(self, sign=0):
    """ return zero with sign `sign`

    Any non-zero ``sign`` sets the bit just above the exponent and
    mantissa fields; every other bit is zero.
    """
    sign_position = self.e_width + self.m_width
    is_negative = sign != 0
    return is_negative << sign_position
def inf(self, sign=0):
    """ return infinity with sign `sign`

    A zero of the requested sign with every ``exponent_mask`` bit set.
    """
    signed_zero = self.zero(sign)
    return signed_zero | self.exponent_mask
def is_nan_signaling(self, x):
    """ returns true if x is a signalling nan

    Requires the exponent to be ``e_max``, the stored mantissa to be
    non-zero, and the top stored mantissa bit to be clear (the mirror
    image of ``is_quiet_nan``).
    """
    quiet_bit = 1 << (self.m_width - 1)
    mantissa = self.get_mantissa_field(x)
    exponent_is_max = self.get_exponent(x) == self.e_max
    # The quiet-bit test must be parenthesised: in Python `&` binds
    # tighter than `==`, so the previous `a & b & c == 0` evaluated as
    # `(a & b & c) == 0` and reported non-NaN inputs as signalling NaNs.
    return (exponent_is_max
            & (mantissa != 0)
            & ((mantissa & quiet_bit) == 0))
409 """ Get the total number of bits in the FP format. """
410 return self
.has_sign
+ self
.e_width
+ self
.m_width
413 def mantissa_mask(self
):
414 """ Get a mantissa mask based on the mantissa width """
415 return (1 << self
.m_width
) - 1
418 def exponent_mask(self
):
419 """ Get an exponent mask """
420 return self
.exponent_inf_nan
<< self
.m_width
423 def exponent_inf_nan(self
):
424 """ Get the value of the exponent field designating infinity/NaN. """
425 return (1 << self
.e_width
) - 1
429 """ get the maximum exponent (minus bias)
431 return self
.exponent_inf_nan
- self
.exponent_bias
435 return self
.exponent_denormal_zero
- self
.exponent_bias
437 def exponent_denormal_zero(self
):
438 """ Get the value of the exponent field designating denormal/zero. """
442 def exponent_min_normal(self
):
443 """ Get the minimum value of the exponent field for normal numbers. """
447 def exponent_max_normal(self
):
448 """ Get the maximum value of the exponent field for normal numbers. """
449 return self
.exponent_inf_nan
- 1
452 def exponent_bias(self
):
453 """ Get the exponent bias. """
454 return (1 << (self
.e_width
- 1)) - 1
457 def fraction_width(self
):
458 """ Get the number of mantissa bits that are fraction bits. """
459 return self
.m_width
- self
.has_int_bit
462 class TestFPFormat(unittest
.TestCase
):
463 """ very quick test for FPFormat
466 def test_fpformat_fp64(self
):
467 f64
= FPFormat
.standard(64)
468 from sfpy
import Float64
469 x
= Float64(1.0).bits
472 self
.assertEqual(f64
.get_exponent(x
), 0)
473 x
= Float64(2.0).bits
475 self
.assertEqual(f64
.get_exponent(x
), 1)
477 x
= Float64(1.5).bits
478 m
= f64
.get_mantissa_field(x
)
479 print (hex(x
), hex(m
))
480 self
.assertEqual(m
, 0x8000000000000)
482 s
= f64
.get_sign_field(x
)
483 print (hex(x
), hex(s
))
484 self
.assertEqual(s
, 0)
486 x
= Float64(-1.5).bits
487 s
= f64
.get_sign_field(x
)
488 print (hex(x
), hex(s
))
489 self
.assertEqual(s
, 1)
491 def test_fpformat_fp32(self
):
492 f32
= FPFormat
.standard(32)
493 from sfpy
import Float32
494 x
= Float32(1.0).bits
497 self
.assertEqual(f32
.get_exponent(x
), 0)
498 x
= Float32(2.0).bits
500 self
.assertEqual(f32
.get_exponent(x
), 1)
502 x
= Float32(1.5).bits
503 m
= f32
.get_mantissa_field(x
)
504 print (hex(x
), hex(m
))
505 self
.assertEqual(m
, 0x400000)
508 x
= Float32(-1.0).sqrt()
511 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
512 f32
.get_mantissa_field(x
), i
)
513 self
.assertEqual(i
, True)
516 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
519 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
520 f32
.get_mantissa_field(x
), i
)
521 self
.assertEqual(i
, True)
526 i
= f32
.is_subnormal(x
)
527 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
528 f32
.get_mantissa_field(x
), i
)
529 self
.assertEqual(i
, True)
533 i
= f32
.is_subnormal(x
)
534 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
535 f32
.get_mantissa_field(x
), i
)
536 self
.assertEqual(i
, False)
540 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
541 f32
.get_mantissa_field(x
), i
)
542 self
.assertEqual(i
, True)
545 class MultiShiftR(Elaboratable
):
547 def __init__(self
, width
):
549 self
.smax
= bits_for(width
- 1)
550 self
.i
= Signal(width
, reset_less
=True)
551 self
.s
= Signal(self
.smax
, reset_less
=True)
552 self
.o
= Signal(width
, reset_less
=True)
554 def elaborate(self
, platform
):
556 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
561 """ Generates variable-length single-cycle shifter from a series
562 of conditional tests on each bit of the left/right shift operand.
563 Each bit tested produces output shifted by that number of bits,
564 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
565 shifts by 2 bits, each partial result cascading to the next Mux.
567 Could be adapted to do arithmetic shift by taking copies of the
568 MSB instead of zeros.
571 def __init__(self
, width
):
573 self
.smax
= bits_for(width
- 1)
575 def lshift(self
, op
, s
):
579 def rshift(self
, op
, s
):
584 class FPNumBaseRecord
:
585 """ Floating-point Base Number Class.
587 This class is designed to be passed around in other data structures
588 (between pipelines and between stages). Its "friend" is FPNumBase,
589 which is a *module*. The reason for the discernment is because
590 nmigen modules that are not added to submodules results in the
591 irritating "Elaboration" warning. Despite not *needing* FPNumBase
592 in many cases to be added as a submodule (because it is just data)
593 this was not possible to solve without splitting out the data from
597 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
600 # assert false, "missing name"
604 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
605 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
606 e_max
= 1 << (e_width
-3)
607 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
610 # mantissa extra bits (top,guard,round)
612 m_width
+= self
.m_extra
616 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
617 e_width
+= self
.e_extra
620 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
621 self
.m_width
= m_width
622 self
.e_width
= e_width
623 self
.e_start
= self
.rmw
624 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
626 self
.v
= Signal(width
, reset_less
=True,
627 name
=name
+"v") # Latched copy of value
628 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
629 self
.e
= Signal(signed(e_width
),
630 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
631 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
636 def drop_in(self
, fp
):
642 fp
.width
= self
.width
643 fp
.e_width
= self
.e_width
644 fp
.e_max
= self
.e_max
645 fp
.m_width
= self
.m_width
646 fp
.e_start
= self
.e_start
647 fp
.e_end
= self
.e_end
648 fp
.m_extra
= self
.m_extra
650 m_width
= self
.m_width
652 e_width
= self
.e_width
654 self
.mzero
= Const(0, unsigned(m_width
))
655 m_msb
= 1 << (self
.m_width
-2)
656 self
.msb1
= Const(m_msb
, unsigned(m_width
))
657 self
.m1s
= Const(-1, unsigned(m_width
))
658 self
.P128
= Const(e_max
, signed(e_width
))
659 self
.P127
= Const(e_max
-1, signed(e_width
))
660 self
.N127
= Const(-(e_max
-1), signed(e_width
))
661 self
.N126
= Const(-(e_max
-2), signed(e_width
))
663 def create(self
, s
, e
, m
):
664 """ creates a value from sign / exponent / mantissa
666 bias is added here, to the exponent.
668 NOTE: order is important, because e_start/e_end can be
669 a bit too long (overwriting s).
672 self
.v
[0:self
.e_start
].eq(m
), # mantissa
673 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
674 self
.v
[-1].eq(s
), # sign
678 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
681 return (s
, self
.fp
.P128
, 0)
684 return (s
, self
.fp
.N127
, 0)
687 return self
.create(*self
._nan
(s
))
def quieted_nan(self, other):
    """ build a quiet NaN carrying the sign and payload of `other`

    Returns the result of ``self.create`` with sign ``other.s``, exponent
    ``self.fp.P128``, and a mantissa made from the low ``e_start`` bits of
    ``other.v`` with the top one of those bits forced on.
    """
    assert isinstance(other, FPNumBaseRecord)
    assert self.width == other.width
    quiet_bit = 1 << (self.e_start - 1)
    payload = other.v[0:self.e_start] | quiet_bit
    return self.create(other.s, self.fp.P128, payload)
696 return self
.create(*self
._inf
(s
))
def max_normal(self, s):
    """ return the maximum normal number with sign `s`

    Exponent ``self.fp.P127`` with an all-ones mantissa (``~0``), passed
    through ``self.create``.
    """
    all_ones = ~0
    return self.create(s, self.fp.P127, all_ones)
def min_denormal(self, s):
    """ return the minimum denormal number with sign `s`

    Exponent ``self.fp.N127`` with mantissa ``1``, passed through
    ``self.create``.
    """
    lowest_mantissa = 1
    return self.create(s, self.fp.N127, lowest_mantissa)
705 return self
.create(*self
._zero
(s
))
707 def create2(self
, s
, e
, m
):
708 """ creates a value from sign / exponent / mantissa
710 bias is added here, to the exponent
712 e
= e
+ self
.P127
# exp (add on bias)
713 return Cat(m
[0:self
.e_start
],
714 e
[0:self
.e_end
-self
.e_start
],
718 return self
.create2(s
, self
.P128
, self
.msb1
)
721 return self
.create2(s
, self
.P128
, self
.mzero
)
724 return self
.create2(s
, self
.N127
, self
.mzero
)
732 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
735 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
736 """ Floating-point Base Number Class
739 def __init__(self
, fp
):
744 self
.is_nan
= Signal(reset_less
=True)
745 self
.is_zero
= Signal(reset_less
=True)
746 self
.is_inf
= Signal(reset_less
=True)
747 self
.is_overflowed
= Signal(reset_less
=True)
748 self
.is_denormalised
= Signal(reset_less
=True)
749 self
.exp_128
= Signal(reset_less
=True)
750 self
.exp_sub_n126
= Signal(signed(e_width
), reset_less
=True)
751 self
.exp_lt_n126
= Signal(reset_less
=True)
752 self
.exp_zero
= Signal(reset_less
=True)
753 self
.exp_gt_n126
= Signal(reset_less
=True)
754 self
.exp_gt127
= Signal(reset_less
=True)
755 self
.exp_n127
= Signal(reset_less
=True)
756 self
.exp_n126
= Signal(reset_less
=True)
757 self
.m_zero
= Signal(reset_less
=True)
758 self
.m_msbzero
= Signal(reset_less
=True)
760 def elaborate(self
, platform
):
762 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
763 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
764 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
765 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
766 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
767 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
768 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
769 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
770 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
771 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
772 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
773 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
774 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
775 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
776 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
781 return (self
.exp_128
) & (~self
.m_zero
)
784 return (self
.exp_128
) & (self
.m_zero
)
787 return (self
.exp_n127
) & (self
.m_zero
)
789 def _is_overflowed(self
):
790 return self
.exp_gt127
792 def _is_denormalised(self
):
793 # XXX NOT to be used for "official" quiet NaN tests!
794 # particularly when the MSB has been extended
795 return (self
.exp_n126
) & (self
.m_msbzero
)
798 class FPNumOut(FPNumBase
):
799 """ Floating-point Number Class
801 Contains signals for an incoming copy of the value, decoded into
802 sign / exponent / mantissa.
803 Also contains encoding functions, creation and recognition of
804 zero, NaN and inf (all signed)
806 Four extra bits are included in the mantissa: the top bit
807 (m[-1]) is effectively a carry-overflow. The other three are
808 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, fp):
    """ Set up the number's signals by delegating to
    ``FPNumBase.__init__``; nothing extra is initialised here.
    """
    base_init = FPNumBase.__init__
    base_init(self, fp)
814 def elaborate(self
, platform
):
815 m
= FPNumBase
.elaborate(self
, platform
)
820 class MultiShiftRMerge(Elaboratable
):
821 """ shifts down (right) and merges lower bits into m[0].
822 m[0] is the "sticky" bit, basically
825 def __init__(self
, width
, s_max
=None):
827 s_max
= bits_for(width
- 1)
828 self
.smax
= Shape
.cast(s_max
)
829 self
.m
= Signal(width
, reset_less
=True)
830 self
.inp
= Signal(width
, reset_less
=True)
831 self
.diff
= Signal(s_max
, reset_less
=True)
834 def elaborate(self
, platform
):
837 rs
= Signal(self
.width
, reset_less
=True)
838 m_mask
= Signal(self
.width
, reset_less
=True)
839 smask
= Signal(self
.width
, reset_less
=True)
840 stickybit
= Signal(reset_less
=True)
841 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
842 maxslen
= Signal(self
.smax
.width
, reset_less
=True)
843 maxsleni
= Signal(self
.smax
.width
, reset_less
=True)
845 sm
= MultiShift(self
.width
-1)
846 m0s
= Const(0, self
.width
-1)
847 mw
= Const(self
.width
-1, len(self
.diff
))
848 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
849 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
853 # shift mantissa by maxslen, mask by inverse
854 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
855 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
856 smask
.eq(self
.inp
[1:] & m_mask
),
857 # sticky bit combines all mask (and mantissa low bit)
858 stickybit
.eq(smask
.bool() | self
.inp
[0]),
859 # mantissa result contains m[0] already.
860 self
.m
.eq(Cat(stickybit
, rs
))
865 class FPNumShift(FPNumBase
, Elaboratable
):
866 """ Floating-point Number Class for shifting
869 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
870 FPNumBase
.__init
__(self
, width
, m_extra
)
871 self
.latch_in
= Signal()
876 def elaborate(self
, platform
):
877 m
= FPNumBase
.elaborate(self
, platform
)
879 m
.d
.comb
+= self
.s
.eq(op
.s
)
880 m
.d
.comb
+= self
.e
.eq(op
.e
)
881 m
.d
.comb
+= self
.m
.eq(op
.m
)
883 with self
.mainm
.State("align"):
884 with m
.If(self
.e
< self
.inv
.e
):
885 m
.d
.sync
+= self
.shift_down()
889 def shift_down(self
, inp
):
890 """ shifts a mantissa down by one. exponent is increased to compensate
892 accuracy is lost as a result in the mantissa however there are 3
893 guard bits (the latter of which is the "sticky" bit)
895 return [self
.e
.eq(inp
.e
+ 1),
896 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
899 def shift_down_multi(self
, diff
):
900 """ shifts a mantissa down. exponent is increased to compensate
902 accuracy is lost as a result in the mantissa however there are 3
903 guard bits (the latter of which is the "sticky" bit)
905 this code works by variable-shifting the mantissa by up to
906 its maximum bit-length: no point doing more (it'll still be
909 the sticky bit is computed by shifting a batch of 1s by
910 the same amount, which will introduce zeros. it's then
911 inverted and used as a mask to get the LSBs of the mantissa.
912 those are then |'d into the sticky bit.
914 sm
= MultiShift(self
.width
)
915 mw
= Const(self
.m_width
-1, len(diff
))
916 maxslen
= Mux(diff
> mw
, mw
, diff
)
917 rs
= sm
.rshift(self
.m
[1:], maxslen
)
918 maxsleni
= mw
- maxslen
919 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
921 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
922 return [self
.e
.eq(self
.e
+ diff
),
923 self
.m
.eq(Cat(stickybits
, rs
))
926 def shift_up_multi(self
, diff
):
927 """ shifts a mantissa up. exponent is decreased to compensate
929 sm
= MultiShift(self
.width
)
930 mw
= Const(self
.m_width
, len(diff
))
931 maxslen
= Mux(diff
> mw
, mw
, diff
)
933 return [self
.e
.eq(self
.e
- diff
),
934 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
938 class FPNumDecode(FPNumBase
):
939 """ Floating-point Number Class
941 Contains signals for an incoming copy of the value, decoded into
942 sign / exponent / mantissa.
943 Also contains encoding functions, creation and recognition of
944 zero, NaN and inf (all signed)
946 Four extra bits are included in the mantissa: the top bit
947 (m[-1]) is effectively a carry-overflow. The other three are
948 guard (m[2]), round (m[1]), and sticky (m[0])
951 def __init__(self
, op
, fp
):
952 FPNumBase
.__init
__(self
, fp
)
955 def elaborate(self
, platform
):
956 m
= FPNumBase
.elaborate(self
, platform
)
958 m
.d
.comb
+= self
.decode(self
.v
)
963 """ decodes a latched value into sign / exponent / mantissa
965 bias is subtracted here, from the exponent. exponent
966 is extended to 10 bits so that subtract 127 is done on
969 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
970 #print ("decode", self.e_end)
971 return [self
.m
.eq(Cat(*args
)), # mantissa
972 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
973 self
.s
.eq(v
[-1]), # sign
977 class FPNumIn(FPNumBase
):
978 """ Floating-point Number Class
980 Contains signals for an incoming copy of the value, decoded into
981 sign / exponent / mantissa.
982 Also contains encoding functions, creation and recognition of
983 zero, NaN and inf (all signed)
985 Four extra bits are included in the mantissa: the top bit
986 (m[-1]) is effectively a carry-overflow. The other three are
987 guard (m[2]), round (m[1]), and sticky (m[0])
990 def __init__(self
, op
, fp
):
991 FPNumBase
.__init
__(self
, fp
)
992 self
.latch_in
= Signal()
995 def decode2(self
, m
):
996 """ decodes a latched value into sign / exponent / mantissa
998 bias is subtracted here, from the exponent. exponent
999 is extended to 10 bits so that subtract 127 is done on
1003 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1004 #print ("decode", self.e_end)
1005 res
= ObjectProxy(m
, pipemode
=False)
1006 res
.m
= Cat(*args
) # mantissa
1007 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
1008 res
.s
= v
[-1] # sign
1011 def decode(self
, v
):
1012 """ decodes a latched value into sign / exponent / mantissa
1014 bias is subtracted here, from the exponent. exponent
1015 is extended to 10 bits so that subtract 127 is done on
1018 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1019 #print ("decode", self.e_end)
1020 return [self
.m
.eq(Cat(*args
)), # mantissa
1021 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
1022 self
.s
.eq(v
[-1]), # sign
1025 def shift_down(self
, inp
):
1026 """ shifts a mantissa down by one. exponent is increased to compensate
1028 accuracy is lost as a result in the mantissa however there are 3
1029 guard bits (the latter of which is the "sticky" bit)
1031 return [self
.e
.eq(inp
.e
+ 1),
1032 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
1035 def shift_down_multi(self
, diff
, inp
=None):
1036 """ shifts a mantissa down. exponent is increased to compensate
1038 accuracy is lost as a result in the mantissa however there are 3
1039 guard bits (the latter of which is the "sticky" bit)
1041 this code works by variable-shifting the mantissa by up to
1042 its maximum bit-length: no point doing more (it'll still be
1045 the sticky bit is computed by shifting a batch of 1s by
1046 the same amount, which will introduce zeros. it's then
1047 inverted and used as a mask to get the LSBs of the mantissa.
1048 those are then |'d into the sticky bit.
1052 sm
= MultiShift(self
.width
)
1053 mw
= Const(self
.m_width
-1, len(diff
))
1054 maxslen
= Mux(diff
> mw
, mw
, diff
)
1055 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
1056 maxsleni
= mw
- maxslen
1057 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
1059 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1060 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
1061 return [self
.e
.eq(inp
.e
+ diff
),
1062 self
.m
.eq(Cat(stickybit
, rs
))
1065 def shift_up_multi(self
, diff
):
1066 """ shifts a mantissa up. exponent is decreased to compensate
1068 sm
= MultiShift(self
.width
)
1069 mw
= Const(self
.m_width
, len(diff
))
1070 maxslen
= Mux(diff
> mw
, mw
, diff
)
1072 return [self
.e
.eq(self
.e
- diff
),
1073 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
1077 class Trigger(Elaboratable
):
1080 self
.stb
= Signal(reset
=0)
1082 self
.trigger
= Signal(reset_less
=True)
1084 def elaborate(self
, platform
):
1086 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
1090 return [self
.stb
.eq(inp
.stb
),
1091 self
.ack
.eq(inp
.ack
)
1095 return [self
.stb
, self
.ack
]
1098 class FPOpIn(PrevControl
):
1099 def __init__(self
, width
):
1100 PrevControl
.__init
__(self
)
1107 def chain_inv(self
, in_op
, extra
=None):
1109 if extra
is not None:
1111 return [self
.v
.eq(in_op
.v
), # receive value
1112 self
.stb
.eq(stb
), # receive STB
1113 in_op
.ack
.eq(~self
.ack
), # send ACK
1116 def chain_from(self
, in_op
, extra
=None):
1118 if extra
is not None:
1120 return [self
.v
.eq(in_op
.v
), # receive value
1121 self
.stb
.eq(stb
), # receive STB
1122 in_op
.ack
.eq(self
.ack
), # send ACK
1126 class FPOpOut(NextControl
):
1127 def __init__(self
, width
):
1128 NextControl
.__init
__(self
)
1135 def chain_inv(self
, in_op
, extra
=None):
1137 if extra
is not None:
1139 return [self
.v
.eq(in_op
.v
), # receive value
1140 self
.stb
.eq(stb
), # receive STB
1141 in_op
.ack
.eq(~self
.ack
), # send ACK
1144 def chain_from(self
, in_op
, extra
=None):
1146 if extra
is not None:
1148 return [self
.v
.eq(in_op
.v
), # receive value
1149 self
.stb
.eq(stb
), # receive STB
1150 in_op
.ack
.eq(self
.ack
), # send ACK
1155 # TODO: change FFLAGS to be FPSCR's status flags
1156 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
1157 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
1158 FFLAGS_OF
= Const(1<<2, 5) # overflow
1159 FFLAGS_UF
= Const(1<<1, 5) # underflow
1160 FFLAGS_NX
= Const(1<<0, 5) # inexact
1161 def __init__(self
, name
=None):
1164 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
1165 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
1166 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
1167 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
1168 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
1170 self
.sign
= Signal(reset_less
=True, name
=name
+"sign")
1171 """sign bit -- 1 means negative, 0 means positive"""
1173 self
.rm
= Signal(FPRoundingMode
, name
=name
+"rm",
1174 reset
=FPRoundingMode
.DEFAULT
)
1177 #self.roundz = Signal(reset_less=True)
1181 yield self
.round_bit
1189 return [self
.guard
.eq(inp
.guard
),
1190 self
.round_bit
.eq(inp
.round_bit
),
1191 self
.sticky
.eq(inp
.sticky
),
1193 self
.fpflags
.eq(inp
.fpflags
),
1194 self
.sign
.eq(inp
.sign
),
1198 def roundz_rne(self
):
1199 """true if the mantissa should be rounded up for `rm == RNE`
1201 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`
1203 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
1206 def roundz_rna(self
):
1207 """true if the mantissa should be rounded up for `rm == RNA`
1209 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
1214 def roundz_rtn(self
):
1215 """true if the mantissa should be rounded up for `rm == RTN`
1217 assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`
1219 return self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1222 def roundz_rto(self
):
1223 """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`
1225 assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
1226 or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`
1228 return ~self
.m0
& (self
.guard | self
.round_bit | self
.sticky
)
1231 def roundz_rtp(self
):
1232 """true if the mantissa should be rounded up for `rm == RTP`
1234 assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`
1236 return ~self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1239 def roundz_rtz(self
):
1240 """true if the mantissa should be rounded up for `rm == RTZ`
1242 assumes the rounding mode is `ROUND_TOWARDS_ZERO`
1248 """true if the mantissa should be rounded up for the current rounding
1252 FPRoundingMode
.RNA
: self
.roundz_rna
,
1253 FPRoundingMode
.RNE
: self
.roundz_rne
,
1254 FPRoundingMode
.RTN
: self
.roundz_rtn
,
1255 FPRoundingMode
.RTOP
: self
.roundz_rto
,
1256 FPRoundingMode
.RTON
: self
.roundz_rto
,
1257 FPRoundingMode
.RTP
: self
.roundz_rtp
,
1258 FPRoundingMode
.RTZ
: self
.roundz_rtz
,
1260 return FPRoundingMode
.make_array(lambda rm
: d
[rm
])[self
.rm
]
1263 class OverflowMod(Elaboratable
, Overflow
):
1264 def __init__(self
, name
=None):
1265 Overflow
.__init
__(self
, name
)
1268 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
1271 yield from Overflow
.__iter
__(self
)
1272 yield self
.roundz_out
1275 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
1277 def elaborate(self
, platform
):
1279 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
) # roundz is a property
1284 """ IEEE754 Floating Point Base Class
1286 contains common functions for FP manipulation, such as
1287 extracting and packing operands, normalisation, denormalisation,
1291 def get_op(self
, m
, op
, v
, next_state
):
1292 """ this function moves to the next state and copies the operand
1293 when both stb and ack are 1.
1294 acknowledgement is sent by setting ack to ZERO.
1298 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
1300 # op is latched in from FPNumIn class on same ack/stb
1301 m
.d
.comb
+= ack
.eq(0)
1303 m
.d
.comb
+= ack
.eq(1)
1306 def denormalise(self
, m
, a
):
1307 """ denormalises a number. this is probably the wrong name for
1308 this function. for normalised numbers (exponent != minimum)
1309 one *extra* bit (the implicit 1) is added *back in*.
1310 for denormalised numbers, the mantissa is left alone
1311 and the exponent increased by 1.
1313 both cases *effectively multiply the number stored by 2*,
1314 which has to be taken into account when extracting the result.
1316 with m
.If(a
.exp_n127
):
1317 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
1319 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
1321 def op_normalise(self
, m
, op
, next_state
):
1322 """ operand normalisation
1323 NOTE: just like "align", this one keeps going round every clock
1324 until the result's exponent is within acceptable "range"
1326 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
1328 op
.e
.eq(op
.e
- 1), # DECREASE exponent
1329 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
1334 def normalise_1(self
, m
, z
, of
, next_state
):
1335 """ first stage normalisation
1337 NOTE: just like "align", this one keeps going round every clock
1338 until the result's exponent is within acceptable "range"
1339 NOTE: the weirdness of reassigning guard and round is due to
1340 the extra mantissa bits coming from tot[0..2]
1342 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
1344 z
.e
.eq(z
.e
- 1), # DECREASE exponent
1345 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1346 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1347 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1348 of
.round_bit
.eq(0), # reset round bit
1354 def normalise_2(self
, m
, z
, of
, next_state
):
1355 """ second stage normalisation
1357 NOTE: just like "align", this one keeps going round every clock
1358 until the result's exponent is within acceptable "range"
1359 NOTE: the weirdness of reassigning guard and round is due to
1360 the extra mantissa bits coming from tot[0..2]
1362 with m
.If(z
.e
< z
.fp
.N126
):
1364 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1365 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1366 of
.guard
.eq(z
.m
[0]),
1368 of
.round_bit
.eq(of
.guard
),
1369 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1374 def roundz(self
, m
, z
, roundz
):
1375 """ performs rounding on the output. TODO: different kinds of rounding
1378 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1379 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1380 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1382 def corrections(self
, m
, z
, next_state
):
1383 """ denormalisation and sign-bug corrections
1386 # denormalised, correct exponent to zero
1387 with m
.If(z
.is_denormalised
):
1388 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1390 def pack(self
, m
, z
, next_state
):
1391 """ packs the result into the output (detects overflow->Inf)
1394 # if overflow occurs, return inf
1395 with m
.If(z
.is_overflowed
):
1396 m
.d
.sync
+= z
.inf(z
.s
)
1398 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1400 def put_z(self
, m
, z
, out_z
, next_state
):
1401 """ put_z: stores the result in the output. raises stb and waits
1402 for ack to be set to 1 before moving to the next state.
1403 resets stb back to zero when that occurs, as acknowledgement.
1408 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1409 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1412 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1415 class FPState(FPBase
):
1416 def __init__(self
, state_from
):
1417 self
.state_from
= state_from
1419 def set_inputs(self
, inputs
):
1420 self
.inputs
= inputs
1421 for k
, v
in inputs
.items():
1424 def set_outputs(self
, outputs
):
1425 self
.outputs
= outputs
1426 for k
, v
in outputs
.items():
1431 def __init__(self
, id_wid
):
1432 self
.id_wid
= id_wid
1434 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1435 self
.out_mid
= Signal(id_wid
, reset_less
=True)
1440 def idsync(self
, m
):
1441 if self
.id_wid
is not None:
1442 m
.d
.sync
+= self
.out_mid
.eq(self
.in_mid
)
1445 if __name__
== '__main__':