add rounding modes to fpbase.Overflow
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 29 Jun 2022 04:39:22 +0000 (21:39 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 29 Jun 2022 04:39:22 +0000 (21:39 -0700)
src/ieee754/fpcommon/fpbase.py

index 6e40b021945236fb914c5e4a7630a0f28b0f582c..01f3e07fd702b873e62f57b0a460183a531d0004 100644 (file)
@@ -6,7 +6,7 @@ Copyright (C) 2019 Jake Lifshay
 """
 
 
-from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array
 from math import log
 from operator import or_
 from functools import reduce
@@ -15,6 +15,120 @@ from nmutil.singlepipe import PrevControl, NextControl
 from nmutil.pipeline import ObjectProxy
 import unittest
 import math
+import enum
+
+try:
+    from nmigen.hdl.smtlib2 import RoundingModeEnum
+    _HAVE_SMTLIB2 = True
+except ImportError:
+    _HAVE_SMTLIB2 = False
+
+# value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
+_raise_err = object()
+
+
+class FPRoundingMode(enum.Enum):
+    # matches the FPSCR.RN field values, but includes some extra
+    # values (>= 0b100) used in miscellaneous instructions.
+
+    # naming matches smtlib2 names, doc strings are the OpenPower ISA
+    # specification's names (v3.1 section 7.3.2.6 --
+    # matches values in section 4.3.6).
+    RNE = 0b00
+    """Round to Nearest Even
+
+    Rounds to the nearest representable floating-point number, ties are
+    rounded to the number with the even mantissa. Treats +-Infinity as if
+    it were a normalized floating-point number when deciding which number
+    is closer when rounding. See IEEE754 spec. for details.
+    """
+
+    ROUND_NEAREST_TIES_TO_EVEN = RNE
+    DEFAULT = RNE
+
+    RTZ = 0b01
+    """Round towards Zero
+
+    If the result is exactly representable as a floating-point number, return
+    that, otherwise return the nearest representable floating-point value
+    with magnitude smaller than the exact answer.
+    """
+
+    ROUND_TOWARDS_ZERO = RTZ
+
+    RTP = 0b10
+    """Round towards +Infinity
+
+    If the result is exactly representable as a floating-point number, return
+    that, otherwise return the nearest representable floating-point value
+    that is numerically greater than the exact answer. This can round up to
+    +Infinity.
+    """
+
+    ROUND_TOWARDS_POSITIVE = RTP
+
+    RTN = 0b11
+    """Round towards -Infinity
+
+    If the result is exactly representable as a floating-point number, return
+    that, otherwise return the nearest representable floating-point value
+    that is numerically less than the exact answer. This can round down to
+    -Infinity.
+    """
+
+    ROUND_TOWARDS_NEGATIVE = RTN
+
+    RNA = 0b100
+    """Round to Nearest Away
+
+    Rounds to the nearest representable floating-point number, ties are
+    rounded to the number with the maximum magnitude. Treats +-Infinity as if
+    it were a normalized floating-point number when deciding which number
+    is closer when rounding. See IEEE754 spec. for details.
+    """
+
+    ROUND_NEAREST_TIES_TO_AWAY = RNA
+
+    RTO = 0b101
+    """Round to Odd
+
+    Not in smtlib2.
+
+    If the result is exactly representable as a floating-point number, return
+    that, otherwise return the nearest representable floating-point value
+    that has an odd mantissa.
+
+    This is useful to avoid double-rounding errors when doing arithmetic in a
+    larger type (e.g. f128) but where the answer should be a smaller type
+    (e.g. f80).
+    """
+
+    ROUND_TO_ODD = RTO
+
+    if _HAVE_SMTLIB2:
+        def to_smtlib2(self, default=_raise_err):
+            """return the corresponding smtlib2 rounding mode for `self`. If
+            there is no corresponding smtlib2 rounding mode, then return
+            `default` if specified, else raise `ValueError`.
+            """
+            if self is FPRoundingMode.RNE:
+                return RoundingModeEnum.RNE
+            elif self is FPRoundingMode.RTZ:
+                return RoundingModeEnum.RTZ
+            elif self is FPRoundingMode.RTP:
+                return RoundingModeEnum.RTP
+            elif self is FPRoundingMode.RTN:
+                return RoundingModeEnum.RTN
+            elif self is FPRoundingMode.RNA:
+                return RoundingModeEnum.RNA
+            else:
+                assert self is FPRoundingMode.RTO
+                if default is _raise_err:
+                    raise ValueError(
+                        "no corresponding smtlib2 rounding mode", self)
+                return default
+
+
 
 
 class FPFormat:
@@ -885,6 +999,7 @@ class FPOpOut(NextControl):
 
 
 class Overflow:
+    # TODO: change FFLAGS to be FPSCR's status flags
     FFLAGS_NV = Const(1<<4, 5) # invalid operation
     FFLAGS_DZ = Const(1<<3, 5) # divide by zero
     FFLAGS_OF = Const(1<<2, 5) # overflow
@@ -899,6 +1014,12 @@ class Overflow:
         self.m0 = Signal(reset_less=True, name=name+"m0")  # mantissa bit 0
         self.fpflags = Signal(5, reset_less=True, name=name+"fflags")
 
+        self.sign = Signal(reset_less=True, name=name+"sign")
+        """sign bit -- 1 means negative, 0 means positive"""
+
+        self.rm = Signal(name=name+"rm", reset=FPRoundingMode.DEFAULT)
+        """rounding mode"""
+
         #self.roundz = Signal(reset_less=True)
 
     def __iter__(self):
@@ -907,18 +1028,86 @@ class Overflow:
         yield self.sticky
         yield self.m0
         yield self.fpflags
+        yield self.sign
+        yield self.rm
 
     def eq(self, inp):
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
                 self.sticky.eq(inp.sticky),
                 self.m0.eq(inp.m0),
-                self.fpflags.eq(inp.fpflags)]
+                self.fpflags.eq(inp.fpflags),
+                self.sign.eq(inp.sign),
+                self.rm.eq(inp.rm)]
 
     @property
-    def roundz(self):
+    def roundz_rne(self):
+        """true if the mantissa should be rounded up for `rm == RNE`
+
+        assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`
+        """
         return self.guard & (self.round_bit | self.sticky | self.m0)
 
+    @property
+    def roundz_rna(self):
+        """true if the mantissa should be rounded up for `rm == RNA`
+
+        assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
+        """
+        return self.guard
+
+    @property
+    def roundz_rtn(self):
+        """true if the mantissa should be rounded up for `rm == RTN`
+
+        assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`
+        """
+        FPRoundingMode.ROUND_TOWARDS_NEGATIVE
+        return self.sign & (self.guard | self.round_bit | self.sticky)
+
+    @property
+    def roundz_rto(self):
+        """true if the mantissa should be rounded up for `rm == RTO`
+
+        assumes the rounding mode is `ROUND_TO_ODD`
+        """
+        FPRoundingMode.ROUND_TO_ODD
+        return ~self.m0 & (self.guard | self.round_bit | self.sticky)
+
+    @property
+    def roundz_rtp(self):
+        """true if the mantissa should be rounded up for `rm == RTP`
+
+        assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`
+        """
+        FPRoundingMode.ROUND_TOWARDS_POSITIVE
+        return ~self.sign & (self.guard | self.round_bit | self.sticky)
+
+    @property
+    def roundz_rtz(self):
+        """true if the mantissa should be rounded up for `rm == RTZ`
+
+        assumes the rounding mode is `ROUND_TOWARDS_ZERO`
+        """
+        FPRoundingMode.ROUND_TOWARDS_ZERO
+        return self.guard & (self.round_bit | self.sticky | self.m0)
+
+    @property
+    def roundz(self):
+        """true if the mantissa should be rounded up for the current rounding
+        mode `self.rm`
+        """
+        l = [None] * len(FPRoundingMode)
+        l[FPRoundingMode.RNA.value] = self.roundz_rna
+        l[FPRoundingMode.RNE.value] = self.roundz_rne
+        l[FPRoundingMode.RTN.value] = self.roundz_rtn
+        l[FPRoundingMode.RTO.value] = self.roundz_rto
+        l[FPRoundingMode.RTP.value] = self.roundz_rtp
+        l[FPRoundingMode.RTZ.value] = self.roundz_rtz
+        for (i, v) in enumerate(l):
+            assert v is not None, f"missing rounding mode {bin(i)}"
+        return Array(l)[self.rm]
+
 
 class OverflowMod(Elaboratable, Overflow):
     def __init__(self, name=None):