remove m.If/Elif from fpdiv specialcases
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
index 92eec06222c6bc4232b01e8fbfa46f27162d06f0..b8698c0836763c9cf3a7fbff375a174493a630f8 100644 (file)
@@ -9,7 +9,7 @@ Relevant bugreports:
 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
 """
 
-from nmigen import Module, Signal
+from nmigen import Module, Signal, Cat, Mux
 from nmigen.cli import main, verilog
 from math import log
 
@@ -41,8 +41,9 @@ class FPDIVSpecialCasesMod(PipeModBase):
         comb = m.d.comb
 
         # decode: XXX really should move to separate stage
-        a1 = FPNumBaseRecord(self.pspec.width, False, name="a1")
-        b1 = FPNumBaseRecord(self.pspec.width, False, name="b1")
+        width = self.pspec.width
+        a1 = FPNumBaseRecord(width, False, name="a1")
+        b1 = FPNumBaseRecord(width, False, name="b1")
         m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
         m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
         comb += [a1.v.eq(self.i.a),
@@ -53,15 +54,42 @@ class FPDIVSpecialCasesMod(PipeModBase):
 
         # temporaries (used below)
         sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
-        abnan = Signal(reset_less=True)
-        abinf = Signal(reset_less=True)
+        t_abnan = Signal(reset_less=True)
+        t_abinf = Signal(reset_less=True)
+        t_a1inf = Signal(reset_less=True)
+        t_b1inf = Signal(reset_less=True)
+        t_a1zero = Signal(reset_less=True)
+        t_b1zero = Signal(reset_less=True)
+        t_abz = Signal(reset_less=True)
+        t_special_div = Signal(reset_less=True)
+        t_special_sqrt = Signal(reset_less=True)
+        t_special_rsqrt = Signal(reset_less=True)
 
         comb += sabx.eq(a1.s ^ b1.s)
-        comb += abnan.eq(a1.is_nan | b1.is_nan)
-        comb += abinf.eq(a1.is_inf & b1.is_inf)
-
-        # default (overridden if needed)
-        comb += self.o.out_do_z.eq(1)
+        comb += t_abnan.eq(a1.is_nan | b1.is_nan)
+        comb += t_abinf.eq(a1.is_inf & b1.is_inf)
+        comb += t_a1inf.eq(a1.is_inf)
+        comb += t_b1inf.eq(b1.is_inf)
+        comb += t_abz.eq(a1.is_zero & b1.is_zero)
+        comb += t_a1zero.eq(a1.is_zero)
+        comb += t_b1zero.eq(b1.is_zero)
+
+        # prepare inf/zero/nans
+        z_zero = FPNumBaseRecord(width, False, name="z_zero")
+        z_zeroab = FPNumBaseRecord(width, False, name="z_zeroab")
+        z_nan = FPNumBaseRecord(width, False, name="z_nan")
+        z_infa = FPNumBaseRecord(width, False, name="z_infa")
+        z_infb = FPNumBaseRecord(width, False, name="z_infb")
+        z_infab = FPNumBaseRecord(width, False, name="z_infab")
+        comb += z_zero.zero(0)
+        comb += z_zeroab.zero(sabx)
+        comb += z_nan.nan(0)
+        comb += z_infa.inf(a1.s)
+        comb += z_infb.inf(b1.s)
+        comb += z_infab.inf(sabx)
+
+        comb += t_special_div.eq(Cat(t_b1zero, t_a1zero, t_b1inf, t_a1inf,
+                                     t_abinf, t_abnan).bool())
 
         # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
         with m.Switch(self.i.ctx.op):
@@ -69,36 +97,27 @@ class FPDIVSpecialCasesMod(PipeModBase):
             ########## DIV ############
             with m.Case(int(DP.UDivRem)):
 
-                # if a is NaN or b is NaN return NaN
-                with m.If(abnan):
-                    comb += self.o.z.nan(0)
+                # any special cases?
+                comb += self.o.out_do_z.eq(t_special_div)
 
+                # if a is NaN or b is NaN return NaN
                 # if a is inf and b is Inf return NaN
-                with m.Elif(abinf):
-                    comb += self.o.z.nan(0)
-
                 # if a is inf return inf
-                with m.Elif(a1.is_inf):
-                    comb += self.o.z.inf(sabx)
-
                 # if b is inf return zero
-                with m.Elif(b1.is_inf):
-                    comb += self.o.z.zero(sabx)
-
                 # if a is zero return zero (or NaN if b is zero)
-                with m.Elif(a1.is_zero):
-                    comb += self.o.z.zero(sabx)
                     # b is zero return NaN
-                    with m.If(b1.is_zero):
-                        comb += self.o.z.nan(0)
-
                 # if b is zero return Inf
-                with m.Elif(b1.is_zero):
-                    comb += self.o.z.inf(sabx)
 
-                # Denormalised Number checks next, so pass a/b data through
-                with m.Else():
-                    comb += self.o.out_do_z.eq(0)
+                # sigh inverse order on the above, Mux-cascade
+                oz = 0
+                oz = Mux(t_b1zero, z_infab.v, oz)
+                oz = Mux(t_a1zero, Mux(t_b1zero, z_nan.v, z_zeroab.v), oz)
+                oz = Mux(t_b1inf, z_zeroab.v, oz)
+                oz = Mux(t_a1inf, z_infab.v, oz)
+                oz = Mux(t_abinf, z_nan.v, oz)
+                oz = Mux(t_abnan, z_nan.v, oz)
+
+                comb += self.o.oz.eq(oz)
 
             ########## SQRT ############
             with m.Case(int(DP.SqrtRem)):
@@ -123,6 +142,8 @@ class FPDIVSpecialCasesMod(PipeModBase):
                 with m.Else():
                     comb += self.o.out_do_z.eq(0)
 
+                comb += self.o.oz.eq(self.o.z.v)
+
             ########## RSQRT ############
             with m.Case(int(DP.RSqrtRem)):
 
@@ -147,8 +168,9 @@ class FPDIVSpecialCasesMod(PipeModBase):
                 with m.Else():
                     comb += self.o.out_do_z.eq(0)
 
+                comb += self.o.oz.eq(self.o.z.v)
+
         # pass through context
-        comb += self.o.oz.eq(self.o.z.v)
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m