remove m.If/Elif in fpdiv sqrt, replace with Mux
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
index 20f670288b3efa53f5014df71da73466ccd0e910..ee2ff33eb277b900886c91e20463df90b25e41d3 100644 (file)
@@ -9,13 +9,13 @@ Relevant bugreports:
 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
 """
 
 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
 """
 
-from nmigen import Module, Signal
+from nmigen import Module, Signal, Cat, Mux
 from nmigen.cli import main, verilog
 from math import log
 
 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
 from nmigen.cli import main, verilog
 from math import log
 
 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
-from ieee754.fpcommon.getop import FPADDBaseData
+from ieee754.fpcommon.basedata import FPBaseData
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 from ieee754.fpmul.align import FPAlignModSingle
 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
 from ieee754.fpmul.align import FPAlignModSingle
 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
@@ -31,7 +31,7 @@ class FPDIVSpecialCasesMod(PipeModBase):
         super().__init__(pspec, "specialcases")
 
     def ispec(self):
         super().__init__(pspec, "specialcases")
 
     def ispec(self):
-        return FPADDBaseData(self.pspec)
+        return FPBaseData(self.pspec)
 
     def ospec(self):
         return FPSCData(self.pspec, False)
 
     def ospec(self):
         return FPSCData(self.pspec, False)
@@ -41,8 +41,9 @@ class FPDIVSpecialCasesMod(PipeModBase):
         comb = m.d.comb
 
         # decode: XXX really should move to separate stage
         comb = m.d.comb
 
         # decode: XXX really should move to separate stage
-        a1 = FPNumBaseRecord(self.pspec.width, False, name="a1")
-        b1 = FPNumBaseRecord(self.pspec.width, False, name="b1")
+        width = self.pspec.width
+        a1 = FPNumBaseRecord(width, False, name="a1")
+        b1 = FPNumBaseRecord(width, False, name="b1")
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
         comb += [a1.v.eq(self.i.a),
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
         comb += [a1.v.eq(self.i.a),
@@ -53,75 +54,98 @@ class FPDIVSpecialCasesMod(PipeModBase):
 
         # temporaries (used below)
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
 
         # temporaries (used below)
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
-        abnan = Signal(reset_less=True)
-        abinf = Signal(reset_less=True)
+        t_abnan = Signal(reset_less=True)
+        t_abinf = Signal(reset_less=True)
+        t_a1inf = Signal(reset_less=True)
+        t_b1inf = Signal(reset_less=True)
+        t_a1nan = Signal(reset_less=True)
+        t_a1zero = Signal(reset_less=True)
+        t_b1zero = Signal(reset_less=True)
+        t_abz = Signal(reset_less=True)
+        t_special_div = Signal(reset_less=True)
+        t_special_sqrt = Signal(reset_less=True)
+        t_special_rsqrt = Signal(reset_less=True)
 
         comb += sabx.eq(a1.s ^ b1.s)
 
         comb += sabx.eq(a1.s ^ b1.s)
-        comb += abnan.eq(a1.is_nan | b1.is_nan)
-        comb += abinf.eq(a1.is_inf & b1.is_inf)
-
-        # default (overridden if needed)
-        comb += self.o.out_do_z.eq(1)
+        comb += t_abnan.eq(a1.is_nan | b1.is_nan)
+        comb += t_abinf.eq(a1.is_inf & b1.is_inf)
+        comb += t_a1inf.eq(a1.is_inf)
+        comb += t_b1inf.eq(b1.is_inf)
+        comb += t_a1nan.eq(a1.is_nan)
+        comb += t_abz.eq(a1.is_zero & b1.is_zero)
+        comb += t_a1zero.eq(a1.is_zero)
+        comb += t_b1zero.eq(b1.is_zero)
+
+        # prepare inf/zero/nans
+        z_zero = FPNumBaseRecord(width, False, name="z_zero")
+        z_zeroa = FPNumBaseRecord(width, False, name="z_zeroa")
+        z_zeroab = FPNumBaseRecord(width, False, name="z_zeroab")
+        z_nan = FPNumBaseRecord(width, False, name="z_nan")
+        z_infa = FPNumBaseRecord(width, False, name="z_infa")
+        z_infb = FPNumBaseRecord(width, False, name="z_infb")
+        z_infab = FPNumBaseRecord(width, False, name="z_infab")
+        comb += z_zero.zero(0)
+        comb += z_zeroa.zero(a1.s)
+        comb += z_zeroab.zero(sabx)
+        comb += z_nan.nan(0)
+        comb += z_infa.inf(a1.s)
+        comb += z_infb.inf(b1.s)
+        comb += z_infab.inf(sabx)
+
+        comb += t_special_div.eq(Cat(t_b1zero, t_a1zero, t_b1inf, t_a1inf,
+                                     t_abinf, t_abnan).bool())
+        comb += t_special_sqrt.eq(Cat(t_a1zero, a1.s, t_a1inf,
+                                     t_a1nan).bool())
 
         # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
         with m.Switch(self.i.ctx.op):
 
 
         # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
         with m.Switch(self.i.ctx.op):
 
-            with m.Case(int(DP.UDivRem)): # DIV
+            ########## DIV ############
+            with m.Case(int(DP.UDivRem)):
 
 
-                # if a is NaN or b is NaN return NaN
-                with m.If(abnan):
-                    comb += self.o.z.nan(0)
+                # any special cases?
+                comb += self.o.out_do_z.eq(t_special_div)
 
 
+                # if a is NaN or b is NaN return NaN
                 # if a is inf and b is Inf return NaN
                 # if a is inf and b is Inf return NaN
-                with m.Elif(abinf):
-                    comb += self.o.z.nan(0)
-
                 # if a is inf return inf
                 # if a is inf return inf
-                with m.Elif(a1.is_inf):
-                    comb += self.o.z.inf(sabx)
-
                 # if b is inf return zero
                 # if b is inf return zero
-                with m.Elif(b1.is_inf):
-                    comb += self.o.z.zero(sabx)
-
                 # if a is zero return zero (or NaN if b is zero)
                 # if a is zero return zero (or NaN if b is zero)
-                with m.Elif(a1.is_zero):
-                    comb += self.o.z.zero(sabx)
                     # b is zero return NaN
                     # b is zero return NaN
-                    with m.If(b1.is_zero):
-                        comb += self.o.z.nan(0)
-
                 # if b is zero return Inf
                 # if b is zero return Inf
-                with m.Elif(b1.is_zero):
-                    comb += self.o.z.inf(sabx)
 
 
-                # Denormalised Number checks next, so pass a/b data through
-                with m.Else():
-                    comb += self.o.out_do_z.eq(0)
+                # sigh inverse order on the above, Mux-cascade
+                oz = 0
+                oz = Mux(t_b1zero, z_infab.v, oz)
+                oz = Mux(t_a1zero, Mux(t_b1zero, z_nan.v, z_zeroab.v), oz)
+                oz = Mux(t_b1inf, z_zeroab.v, oz)
+                oz = Mux(t_a1inf, z_infab.v, oz)
+                oz = Mux(t_abinf, z_nan.v, oz)
+                oz = Mux(t_abnan, z_nan.v, oz)
 
 
-            with m.Case(int(DP.SqrtRem)): # SQRT
+                comb += self.o.oz.eq(oz)
 
 
-                # if a is zero return zero
-                with m.If(a1.is_zero):
-                    comb += self.o.z.zero(a1.s)
+            ########## SQRT ############
+            with m.Case(int(DP.SqrtRem)):
 
 
-                # -ve number is NaN
-                with m.Elif(a1.s):
-                    comb += self.o.z.nan(0)
+                # any special cases?
+                comb += self.o.out_do_z.eq(t_special_sqrt)
 
 
+                # if a is zero return zero
+                # -ve number is NaN
                 # if a is inf return inf
                 # if a is inf return inf
-                with m.Elif(a1.is_inf):
-                    comb += self.o.z.inf(sabx)
-
                 # if a is NaN return NaN
                 # if a is NaN return NaN
-                with m.Elif(a1.is_nan):
-                    comb += self.o.z.nan(0)
 
 
-                # Denormalised Number checks next, so pass a/b data through
-                with m.Else():
-                    comb += self.o.out_do_z.eq(0)
+                oz = 0
+                oz = Mux(t_a1nan, z_nan.v, oz)
+                oz = Mux(t_a1inf, z_infab.v, oz)
+                oz = Mux(a1.s, z_nan.v, oz)
+                oz = Mux(t_a1zero, z_zeroa.v, oz)
+
+                comb += self.o.oz.eq(oz)
 
 
-            with m.Case(int(DP.RSqrtRem)): # RSQRT
+            ########## RSQRT ############
+            with m.Case(int(DP.RSqrtRem)):
 
                 # if a is NaN return canonical NaN
                 with m.If(a1.is_nan):
 
                 # if a is NaN return canonical NaN
                 with m.If(a1.is_nan):
@@ -144,7 +168,9 @@ class FPDIVSpecialCasesMod(PipeModBase):
                 with m.Else():
                     comb += self.o.out_do_z.eq(0)
 
                 with m.Else():
                     comb += self.o.out_do_z.eq(0)
 
-        comb += self.o.oz.eq(self.o.z.v)
+                comb += self.o.oz.eq(self.o.z.v)
+
+        # pass through context
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m