remove m.If/Elif in fpdiv sqrt, replace with Mux
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
index 5ec3194902e43da4acef377748a08bcf25143c69..ee2ff33eb277b900886c91e20463df90b25e41d3 100644 (file)
-# IEEE Floating Point Multiplier 
+""" IEEE Floating Point Divider
 
-from nmigen import Module, Signal, Cat, Const, Elaboratable
+Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+Copyright (C) 2019 Jacob Lifshay
+
+Relevant bugreports:
+* http://bugs.libre-riscv.org/show_bug.cgi?id=99
+* http://bugs.libre-riscv.org/show_bug.cgi?id=43
+* http://bugs.libre-riscv.org/show_bug.cgi?id=44
+"""
+
+from nmigen import Module, Signal, Cat, Mux
 from nmigen.cli import main, verilog
 from math import log
 
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
-from nmutil.singlepipe import SimpleHandshake, StageChain
-
-from ieee754.fpcommon.fpbase import FPState, FPID
-from ieee754.fpcommon.getop import FPADDBaseData
+from ieee754.fpcommon.basedata import FPBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
+from ieee754.fpmul.align import FPAlignModSingle
+from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
 
 
-class FPDIVSpecialCasesMod(Elaboratable):
+class FPDIVSpecialCasesMod(PipeModBase):
     """ special cases: NaNs, infs, zeros, denormalised
         see "Special Operations"
         https://steve.hollasch.net/cgindex/coding/ieeefloat.html
     """
 
-    def __init__(self, width, id_wid):
-        self.width = width
-        self.id_wid = id_wid
-        self.i = self.ispec()
-        self.o = self.ospec()
+    def __init__(self, pspec):
+        super().__init__(pspec, "specialcases")
 
     def ispec(self):
-        return FPADDBaseData(self.width, self.id_wid)
+        return FPBaseData(self.pspec)
 
     def ospec(self):
-        return FPSCData(self.width, self.id_wid, False)
-
-    def setup(self, m, i):
-        """ links module to inputs and outputs
-        """
-        m.submodules.specialcases = self
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
+        return FPSCData(self.pspec, False)
 
     def elaborate(self, platform):
         m = Module()
-
-        #m.submodules.sc_out_z = self.o.z
+        comb = m.d.comb
 
         # decode: XXX really should move to separate stage
-        a1 = FPNumBaseRecord(self.width, False)
-        b1 = FPNumBaseRecord(self.width, False)
+        width = self.pspec.width
+        a1 = FPNumBaseRecord(width, False, name="a1")
+        b1 = FPNumBaseRecord(width, False, name="b1")
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
-        m.d.comb += [a1.v.eq(self.i.a),
+        comb += [a1.v.eq(self.i.a),
                      b1.v.eq(self.i.b),
                      self.o.a.eq(a1),
                      self.o.b.eq(b1)
-                    ]
+                     ]
 
+        # temporaries (used below)
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
-        m.d.comb += sabx.eq(a1.s ^ b1.s)
-
-        abnan = Signal(reset_less=True)
-        m.d.comb += abnan.eq(a1.is_nan | b1.is_nan)
-
-        abinf = Signal(reset_less=True)
-        m.d.comb += abinf.eq(a1.is_inf & b1.is_inf)
-
-        # if a is NaN or b is NaN return NaN
-        with m.If(abnan):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.nan(1)
-
-        # if a is inf and b is Inf return NaN
-        with m.Elif(abnan):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.nan(1)
-
-        # if a is inf return inf
-        with m.Elif(a1.is_inf):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.inf(sabx)
-
-        # if b is inf return zero
-        with m.Elif(b1.is_inf):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.zero(sabx)
-
-        # if a is zero return zero (or NaN if b is zero)
-        with m.Elif(a1.is_zero):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.zero(sabx)
-            # b is zero return NaN
-            with m.If(b1.is_zero):
-                m.d.comb += self.o.z.nan(1)
-
-        # if b is zero return Inf
-        with m.Elif(b1.is_zero):
-            m.d.comb += self.o.out_do_z.eq(1)
-            m.d.comb += self.o.z.inf(sabx)
-
-        # Denormalised Number checks next, so pass a/b data through
-        with m.Else():
-            m.d.comb += self.o.out_do_z.eq(0)
-
-        m.d.comb += self.o.oz.eq(self.o.z.v)
-        m.d.comb += self.o.mid.eq(self.i.mid)
+        t_abnan = Signal(reset_less=True)
+        t_abinf = Signal(reset_less=True)
+        t_a1inf = Signal(reset_less=True)
+        t_b1inf = Signal(reset_less=True)
+        t_a1nan = Signal(reset_less=True)
+        t_a1zero = Signal(reset_less=True)
+        t_b1zero = Signal(reset_less=True)
+        t_abz = Signal(reset_less=True)
+        t_special_div = Signal(reset_less=True)
+        t_special_sqrt = Signal(reset_less=True)
+        t_special_rsqrt = Signal(reset_less=True)
+
+        comb += sabx.eq(a1.s ^ b1.s)
+        comb += t_abnan.eq(a1.is_nan | b1.is_nan)
+        comb += t_abinf.eq(a1.is_inf & b1.is_inf)
+        comb += t_a1inf.eq(a1.is_inf)
+        comb += t_b1inf.eq(b1.is_inf)
+        comb += t_a1nan.eq(a1.is_nan)
+        comb += t_abz.eq(a1.is_zero & b1.is_zero)
+        comb += t_a1zero.eq(a1.is_zero)
+        comb += t_b1zero.eq(b1.is_zero)
+
+        # prepare inf/zero/nans
+        z_zero = FPNumBaseRecord(width, False, name="z_zero")
+        z_zeroa = FPNumBaseRecord(width, False, name="z_zeroa")
+        z_zeroab = FPNumBaseRecord(width, False, name="z_zeroab")
+        z_nan = FPNumBaseRecord(width, False, name="z_nan")
+        z_infa = FPNumBaseRecord(width, False, name="z_infa")
+        z_infb = FPNumBaseRecord(width, False, name="z_infb")
+        z_infab = FPNumBaseRecord(width, False, name="z_infab")
+        comb += z_zero.zero(0)
+        comb += z_zeroa.zero(a1.s)
+        comb += z_zeroab.zero(sabx)
+        comb += z_nan.nan(0)
+        comb += z_infa.inf(a1.s)
+        comb += z_infb.inf(b1.s)
+        comb += z_infab.inf(sabx)
+
+        comb += t_special_div.eq(Cat(t_b1zero, t_a1zero, t_b1inf, t_a1inf,
+                                     t_abinf, t_abnan).bool())
+        comb += t_special_sqrt.eq(Cat(t_a1zero, a1.s, t_a1inf,
+                                     t_a1nan).bool())
+
+        # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
+        with m.Switch(self.i.ctx.op):
+
+            ########## DIV ############
+            with m.Case(int(DP.UDivRem)):
+
+                # any special cases?
+                comb += self.o.out_do_z.eq(t_special_div)
+
+                # if a is NaN or b is NaN return NaN
+                # if a is inf and b is Inf return NaN
+                # if a is inf return inf
+                # if b is inf return zero
+                # if a is zero return zero (or NaN if b is zero)
+                    # b is zero return NaN
+                # if b is zero return Inf
+
+                # sigh inverse order on the above, Mux-cascade
+                oz = 0
+                oz = Mux(t_b1zero, z_infab.v, oz)
+                oz = Mux(t_a1zero, Mux(t_b1zero, z_nan.v, z_zeroab.v), oz)
+                oz = Mux(t_b1inf, z_zeroab.v, oz)
+                oz = Mux(t_a1inf, z_infab.v, oz)
+                oz = Mux(t_abinf, z_nan.v, oz)
+                oz = Mux(t_abnan, z_nan.v, oz)
+
+                comb += self.o.oz.eq(oz)
+
+            ########## SQRT ############
+            with m.Case(int(DP.SqrtRem)):
+
+                # any special cases?
+                comb += self.o.out_do_z.eq(t_special_sqrt)
+
+                # if a is zero return zero
+                # -ve number is NaN
+                # if a is inf return inf
+                # if a is NaN return NaN
+
+                oz = 0
+                oz = Mux(t_a1nan, z_nan.v, oz)
+                oz = Mux(t_a1inf, z_infab.v, oz)
+                oz = Mux(a1.s, z_nan.v, oz)
+                oz = Mux(t_a1zero, z_zeroa.v, oz)
+
+                comb += self.o.oz.eq(oz)
+
+            ########## RSQRT ############
+            with m.Case(int(DP.RSqrtRem)):
+
+                # if a is NaN return canonical NaN
+                with m.If(a1.is_nan):
+                    comb += self.o.z.nan(0)
+
+                # if a is +/- zero return +/- INF
+                with m.Elif(a1.is_zero):
+                    # this includes the "weird" case 1/sqrt(-0) == -Inf
+                    comb += self.o.z.inf(a1.s)
+
+                # -ve number is canonical NaN
+                with m.Elif(a1.s):
+                    comb += self.o.z.nan(0)
+
+                # if a is inf return zero (-ve already excluded, above)
+                with m.Elif(a1.is_inf):
+                    comb += self.o.z.zero(0)
+
+                # Denormalised Number checks next, so pass a/b data through
+                with m.Else():
+                    comb += self.o.out_do_z.eq(0)
+
+                comb += self.o.oz.eq(self.o.z.v)
+
+        # pass through context
+        comb += self.o.ctx.eq(self.i.ctx)
 
         return m
 
 
-class FPDIVSpecialCases(FPState):
+class FPDIVSpecialCasesDeNorm(PipeModBaseChain):
     """ special cases: NaNs, infs, zeros, denormalised
-        NOTE: some of these are unique to div.  see "Special Operations"
-        https://steve.hollasch.net/cgindex/coding/ieeefloat.html
     """
 
-    def __init__(self, width, id_wid):
-        FPState.__init__(self, "special_cases")
-        self.mod = FPDIVSpecialCasesMod(width)
-        self.out_z = self.mod.ospec()
-        self.out_do_z = Signal(reset_less=True)
-
-    def setup(self, m, i):
+    def get_chain(self):
         """ links module to inputs and outputs
         """
-        self.mod.setup(m, i, self.out_do_z)
-        m.d.sync += self.out_z.v.eq(self.mod.out_z.v) # only take the output
-        m.d.sync += self.out_z.mid.eq(self.mod.o.mid)  # (and mid)
-
-    def action(self, m):
-        self.idsync(m)
-        with m.If(self.out_do_z):
-            m.next = "put_z"
-        with m.Else():
-            m.next = "denormalise"
-
-
-class FPDIVSpecialCasesDeNorm(FPState, SimpleHandshake):
-    """ special cases: NaNs, infs, zeros, denormalised
-    """
-
-    def __init__(self, width, id_wid):
-        FPState.__init__(self, "special_cases")
-        self.width = width
-        self.id_wid = id_wid
-        SimpleHandshake.__init__(self, self) # pipe is its own stage
-        self.out = self.ospec()
-
-    def ispec(self):
-        return FPADDBaseData(self.width, self.id_wid) # SpecialCases ispec
-
-    def ospec(self):
-        return FPSCData(self.width, self.id_wid, False) # DeNorm ospec
-
-    def setup(self, m, i):
-        """ links module to inputs and outputs
-        """
-        smod = FPDIVSpecialCasesMod(self.width, self.id_wid)
-        dmod = FPAddDeNormMod(self.width, self.id_wid, False)
-
-        chain = StageChain([smod, dmod])
-        chain.setup(m, i)
-
-        # only needed for break-out (early-out)
-        # self.out_do_z = smod.o.out_do_z
-
-        self.o = dmod.o
-
-    def process(self, i):
-        return self.o
-
-    def action(self, m):
-        # for break-out (early-out)
-        #with m.If(self.out_do_z):
-        #    m.next = "put_z"
-        #with m.Else():
-            m.d.sync += self.out.eq(self.process(None))
-            m.next = "align"
-
+        smod = FPDIVSpecialCasesMod(self.pspec)
+        dmod = FPAddDeNormMod(self.pspec, False)
+        amod = FPAlignModSingle(self.pspec, False)
 
+        return [smod, dmod, amod]