remove m.If/Elif in fpdiv sqrt, replace with Mux
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
1 """ IEEE Floating Point Divider
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019 Jacob Lifshay
5
6 Relevant bugreports:
7 * http://bugs.libre-riscv.org/show_bug.cgi?id=99
8 * http://bugs.libre-riscv.org/show_bug.cgi?id=43
9 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
10 """
11
12 from nmigen import Module, Signal, Cat, Mux
13 from nmigen.cli import main, verilog
14 from math import log
15
16 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
17 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
18 from ieee754.fpcommon.basedata import FPBaseData
19 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
20 from ieee754.fpmul.align import FPAlignModSingle
21 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
22
23
class FPDIVSpecialCasesMod(PipeModBase):
    """ special cases: NaNs, infs, zeros, denormalised
        see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        Three different sets of special cases are selected by ``ctx.op``
        (DIV, SQRT, RSQRT).  For each one, ``out_do_z`` is raised when the
        input(s) hit a special case, and ``oz`` carries the pre-computed
        special-case result.  NOTE(review): later stages are presumably
        expected to forward ``oz`` when ``out_do_z`` is set - confirm.

        Every result is built as a Mux-cascade: the *last* Mux applied
        wins, so each cascade is written in the reverse order of the
        priority list in the comment above it.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "specialcases")

    def ispec(self):
        return FPBaseData(self.pspec)

    def ospec(self):
        return FPSCData(self.pspec, False)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # decode: XXX really should move to separate stage
        width = self.pspec.width
        a1 = FPNumBaseRecord(width, False, name="a1")
        b1 = FPNumBaseRecord(width, False, name="b1")
        m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
        m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
        comb += [a1.v.eq(self.i.a),
                 b1.v.eq(self.i.b),
                 self.o.a.eq(a1),
                 self.o.b.eq(b1)
                 ]

        # temporaries (used below)
        sabx = Signal(reset_less=True)  # sign a xor b (sabx, get it?)
        t_abnan = Signal(reset_less=True)
        t_abinf = Signal(reset_less=True)
        t_a1inf = Signal(reset_less=True)
        t_b1inf = Signal(reset_less=True)
        t_a1nan = Signal(reset_less=True)
        t_a1zero = Signal(reset_less=True)
        t_b1zero = Signal(reset_less=True)
        t_special_div = Signal(reset_less=True)
        t_special_sqrt = Signal(reset_less=True)
        t_special_rsqrt = Signal(reset_less=True)

        comb += sabx.eq(a1.s ^ b1.s)
        comb += t_abnan.eq(a1.is_nan | b1.is_nan)
        comb += t_abinf.eq(a1.is_inf & b1.is_inf)
        comb += t_a1inf.eq(a1.is_inf)
        comb += t_b1inf.eq(b1.is_inf)
        comb += t_a1nan.eq(a1.is_nan)
        comb += t_a1zero.eq(a1.is_zero)
        comb += t_b1zero.eq(b1.is_zero)

        # prepare inf/zero/nans (candidate special-case results)
        z_zero = FPNumBaseRecord(width, False, name="z_zero")
        z_zeroa = FPNumBaseRecord(width, False, name="z_zeroa")
        z_zeroab = FPNumBaseRecord(width, False, name="z_zeroab")
        z_nan = FPNumBaseRecord(width, False, name="z_nan")
        z_infa = FPNumBaseRecord(width, False, name="z_infa")
        z_infab = FPNumBaseRecord(width, False, name="z_infab")
        comb += z_zero.zero(0)
        comb += z_zeroa.zero(a1.s)
        comb += z_zeroab.zero(sabx)
        comb += z_nan.nan(0)
        comb += z_infa.inf(a1.s)
        comb += z_infab.inf(sabx)

        # "is this a special case at all?" flags, one per operation
        comb += t_special_div.eq(Cat(t_b1zero, t_a1zero, t_b1inf, t_a1inf,
                                     t_abinf, t_abnan).bool())
        comb += t_special_sqrt.eq(Cat(t_a1zero, a1.s, t_a1inf,
                                      t_a1nan).bool())
        # RSQRT has the same set of special inputs as SQRT (zero,
        # negative, inf, NaN): only the results differ.
        comb += t_special_rsqrt.eq(Cat(t_a1zero, a1.s, t_a1inf,
                                       t_a1nan).bool())

        # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
        with m.Switch(self.i.ctx.op):

            ########## DIV ############
            with m.Case(int(DP.UDivRem)):

                # any special cases?
                comb += self.o.out_do_z.eq(t_special_div)

                # priority (highest first):
                # if a is NaN or b is NaN return NaN
                # if a is inf and b is inf return NaN
                # if a is inf return inf
                # if b is inf return zero
                # if a is zero return zero (or NaN if b is also zero)
                # if b is zero return inf

                # Mux-cascade: lowest priority first
                oz = 0
                oz = Mux(t_b1zero, z_infab.v, oz)
                oz = Mux(t_a1zero, Mux(t_b1zero, z_nan.v, z_zeroab.v), oz)
                oz = Mux(t_b1inf, z_zeroab.v, oz)
                oz = Mux(t_a1inf, z_infab.v, oz)
                oz = Mux(t_abinf, z_nan.v, oz)
                oz = Mux(t_abnan, z_nan.v, oz)

                comb += self.o.oz.eq(oz)

            ########## SQRT ############
            with m.Case(int(DP.SqrtRem)):

                # any special cases?
                comb += self.o.out_do_z.eq(t_special_sqrt)

                # priority (highest first):
                # if a is zero return zero (keeping the sign: sqrt(-0) == -0)
                # -ve number is NaN
                # if a is inf return inf
                # if a is NaN return NaN

                # Mux-cascade: lowest priority first.  The inf result uses
                # a's sign only (which is 0 here, since negative a was
                # already caught): b plays no part in sqrt.
                oz = 0
                oz = Mux(t_a1nan, z_nan.v, oz)
                oz = Mux(t_a1inf, z_infa.v, oz)
                oz = Mux(a1.s, z_nan.v, oz)
                oz = Mux(t_a1zero, z_zeroa.v, oz)

                comb += self.o.oz.eq(oz)

            ########## RSQRT ############
            with m.Case(int(DP.RSqrtRem)):

                # any special cases?  (must be driven here: nothing else
                # raises out_do_z for RSQRT)
                comb += self.o.out_do_z.eq(t_special_rsqrt)

                # priority (highest first):
                # if a is NaN return canonical NaN
                # if a is +/- zero return +/- INF
                #   (this includes the "weird" case 1/sqrt(-0) == -Inf)
                # -ve number is canonical NaN
                # if a is inf return zero (-ve already excluded, above)

                # Mux-cascade: lowest priority first
                oz = 0
                oz = Mux(t_a1inf, z_zero.v, oz)
                oz = Mux(a1.s, z_nan.v, oz)
                oz = Mux(t_a1zero, z_infa.v, oz)
                oz = Mux(t_a1nan, z_nan.v, oz)

                comb += self.o.oz.eq(oz)

        # pass through context
        comb += self.o.ctx.eq(self.i.ctx)

        return m
177
178
class FPDIVSpecialCasesDeNorm(PipeModBaseChain):
    """ chain of: special cases (NaNs, infs, zeros), then denormalisation,
        then alignment
    """

    def get_chain(self):
        """ returns the sub-modules of this chain, in order:
            special cases -> denormalise -> align
        """
        pspec = self.pspec
        return [FPDIVSpecialCasesMod(pspec),
                FPAddDeNormMod(pspec, False),
                FPAlignModSingle(pspec, False)]