add FPFormat.get_exponent_value to get an unbiased exponent corrected for subnormals
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
index c53538f403f9f31f6632d0a7de0b9a1688add4b0..84768edc158e2a15323f2baef6984ca2f6f874e6 100644 (file)
@@ -1,13 +1,14 @@
 """IEEE754 Floating Point Library
 
 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
-Copyright (C) 2019,2021 Jake Lifshay
+Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
 
 """
 
 
-from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array, Value
-from math import log
+from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
+                    Value, Shape, signed, unsigned)
+from nmigen.utils import bits_for
 from operator import or_
 from functools import reduce
 
@@ -312,13 +313,40 @@ class FPFormat:
     def get_exponent(self, x):
         """ returns the exponent of its input number, x
         """
-        return self.get_exponent_field(x) - self.exponent_bias
+        x = self.get_exponent_field(x)
+        if isinstance(x, Value) and not x.shape().signed:
+            # convert x to signed without changing its value,
+            # since exponents can be negative
+            x |= Const(0, signed(1))
+        return x - self.exponent_bias
+
+    def get_exponent_value(self, x):
+        """ returns the exponent of its input number, x, adjusted for the
+        mathematically correct subnormal exponent.
+        """
+        x = self.get_exponent_field(x)
+        if isinstance(x, Value) and not x.shape().signed:
+            # convert x to signed without changing its value,
+            # since exponents can be negative
+            x |= Const(0, signed(1))
+        return x + (x == self.exponent_denormal_zero) - self.exponent_bias
 
     def get_mantissa_field(self, x):
         """ returns the mantissa of its input number, x
         """
         return x & self.mantissa_mask
 
+    def get_mantissa_value(self, x):
+        """ returns the mantissa of its input number, x, but with the
+        implicit bit, if any, made explicit.
+        """
+        if self.has_int_bit:
+            return self.get_mantissa_field(x)
+        exponent_field = self.get_exponent_field(x)
+        mantissa_field = self.get_mantissa_field(x)
+        implicit_bit = exponent_field != self.exponent_denormal_zero
+        return (implicit_bit << self.fraction_width) | mantissa_field
+
     def is_zero(self, x):
         """ returns true if x is +/- zero
         """
@@ -351,6 +379,23 @@ class FPFormat:
             (self.get_mantissa_field(x) != 0) & \
             (self.get_mantissa_field(x) & highbit != 0)
 
+    def to_quiet_nan(self, x):
+        """ converts `x` to a quiet NaN """
+        highbit = 1 << (self.m_width - 1)
+        return x | highbit | self.exponent_mask
+
+    def quiet_nan(self, sign=0):
+        """ return the default quiet NaN with sign `sign` """
+        return self.to_quiet_nan(self.zero(sign))
+
+    def zero(self, sign=0):
+        """ return zero with sign `sign` """
+        return (sign != 0) << (self.e_width + self.m_width)
+
+    def inf(self, sign=0):
+        """ return infinity with sign `sign` """
+        return self.zero(sign) | self.exponent_mask
+
     def is_nan_signaling(self, x):
         """ returns true if x is a signalling nan
         """
@@ -369,6 +414,11 @@ class FPFormat:
         """ Get a mantissa mask based on the mantissa width """
         return (1 << self.m_width) - 1
 
+    @property
+    def exponent_mask(self):
+        """ Get an exponent mask """
+        return self.exponent_inf_nan << self.m_width
+
     @property
     def exponent_inf_nan(self):
         """ Get the value of the exponent field designating infinity/NaN. """
@@ -492,11 +542,11 @@ class TestFPFormat(unittest.TestCase):
         self.assertEqual(i, True)
 
 
-class MultiShiftR:
+class MultiShiftR(Elaboratable):
 
     def __init__(self, width):
         self.width = width
-        self.smax = int(log(width) / log(2))
+        self.smax = bits_for(width - 1)
         self.i = Signal(width, reset_less=True)
         self.s = Signal(self.smax, reset_less=True)
         self.o = Signal(width, reset_less=True)
@@ -520,7 +570,7 @@ class MultiShift:
 
     def __init__(self, width):
         self.width = width
-        self.smax = int(log(width) / log(2))
+        self.smax = bits_for(width - 1)
 
     def lshift(self, op, s):
         res = op << s
@@ -576,7 +626,7 @@ class FPNumBaseRecord:
         self.v = Signal(width, reset_less=True,
                         name=name+"v")  # Latched copy of value
         self.m = Signal(m_width, reset_less=True, name=name+"m")  # Mantissa
-        self.e = Signal((e_width, True),
+        self.e = Signal(signed(e_width),
                         reset_less=True, name=name+"e")  # exp+2 bits, signed
         self.s = Signal(reset_less=True, name=name+"s")  # Sign bit
 
@@ -601,14 +651,14 @@ class FPNumBaseRecord:
         e_max = self.e_max
         e_width = self.e_width
 
-        self.mzero = Const(0, (m_width, False))
+        self.mzero = Const(0, unsigned(m_width))
         m_msb = 1 << (self.m_width-2)
-        self.msb1 = Const(m_msb, (m_width, False))
-        self.m1s = Const(-1, (m_width, False))
-        self.P128 = Const(e_max, (e_width, True))
-        self.P127 = Const(e_max-1, (e_width, True))
-        self.N127 = Const(-(e_max-1), (e_width, True))
-        self.N126 = Const(-(e_max-2), (e_width, True))
+        self.msb1 = Const(m_msb, unsigned(m_width))
+        self.m1s = Const(-1, unsigned(m_width))
+        self.P128 = Const(e_max, signed(e_width))
+        self.P127 = Const(e_max-1, signed(e_width))
+        self.N127 = Const(-(e_max-1), signed(e_width))
+        self.N126 = Const(-(e_max-2), signed(e_width))
 
     def create(self, s, e, m):
         """ creates a value from sign / exponent / mantissa
@@ -697,7 +747,7 @@ class FPNumBase(FPNumBaseRecord, Elaboratable):
         self.is_overflowed = Signal(reset_less=True)
         self.is_denormalised = Signal(reset_less=True)
         self.exp_128 = Signal(reset_less=True)
-        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
+        self.exp_sub_n126 = Signal(signed(e_width), reset_less=True)
         self.exp_lt_n126 = Signal(reset_less=True)
         self.exp_zero = Signal(reset_less=True)
         self.exp_gt_n126 = Signal(reset_less=True)
@@ -774,8 +824,8 @@ class MultiShiftRMerge(Elaboratable):
 
     def __init__(self, width, s_max=None):
         if s_max is None:
-            s_max = int(log(width) / log(2))
-        self.smax = s_max
+            s_max = bits_for(width - 1)
+        self.smax = Shape.cast(s_max)
         self.m = Signal(width, reset_less=True)
         self.inp = Signal(width, reset_less=True)
         self.diff = Signal(s_max, reset_less=True)
@@ -789,8 +839,8 @@ class MultiShiftRMerge(Elaboratable):
         smask = Signal(self.width, reset_less=True)
         stickybit = Signal(reset_less=True)
         # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
-        maxslen = Signal(self.smax[0], reset_less=True)
-        maxsleni = Signal(self.smax[0], reset_less=True)
+        maxslen = Signal(self.smax.width, reset_less=True)
+        maxsleni = Signal(self.smax.width, reset_less=True)
 
         sm = MultiShift(self.width-1)
         m0s = Const(0, self.width-1)