remove m.If/Elif from fpdiv specialcases
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
1 """ IEEE Floating Point Divider
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019 Jacob Lifshay
5
6 Relevant bugreports:
7 * http://bugs.libre-riscv.org/show_bug.cgi?id=99
8 * http://bugs.libre-riscv.org/show_bug.cgi?id=43
9 * http://bugs.libre-riscv.org/show_bug.cgi?id=44
10 """
11
12 from nmigen import Module, Signal, Cat, Mux
13 from nmigen.cli import main, verilog
14 from math import log
15
16 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
17 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
18 from ieee754.fpcommon.basedata import FPBaseData
19 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
20 from ieee754.fpmul.align import FPAlignModSingle
21 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation as DP
22
23
class FPDIVSpecialCasesMod(PipeModBase):
    """ special cases: NaNs, infs, zeros, denormalised
        see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        Three operations are selected by ``i.ctx.op``: DIV (a / b),
        SQRT (sqrt(a)) and RSQRT (1 / sqrt(a)).  For each one a
        "special case" flag is computed and driven onto ``o.out_do_z``,
        and the special-case result is placed in ``o.oz``.  The decoded
        operands are always copied to ``o.a`` / ``o.b`` and the context
        is passed through unchanged.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "specialcases")

    def ispec(self):
        return FPBaseData(self.pspec)

    def ospec(self):
        return FPSCData(self.pspec, False)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # decode: XXX really should move to separate stage
        width = self.pspec.width
        a1 = FPNumBaseRecord(width, False, name="a1")
        b1 = FPNumBaseRecord(width, False, name="b1")
        m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
        m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
        comb += [a1.v.eq(self.i.a),
                 b1.v.eq(self.i.b),
                 self.o.a.eq(a1),
                 self.o.b.eq(b1)
                 ]

        # temporaries (used below)
        sabx = Signal(reset_less=True)  # sign a xor b (sabx, get it?)
        t_abnan = Signal(reset_less=True)
        t_abinf = Signal(reset_less=True)
        t_a1nan = Signal(reset_less=True)
        t_a1inf = Signal(reset_less=True)
        t_b1inf = Signal(reset_less=True)
        t_a1zero = Signal(reset_less=True)
        t_b1zero = Signal(reset_less=True)
        t_special_div = Signal(reset_less=True)
        t_special_sqrt = Signal(reset_less=True)
        t_special_rsqrt = Signal(reset_less=True)

        comb += sabx.eq(a1.s ^ b1.s)
        comb += t_abnan.eq(a1.is_nan | b1.is_nan)
        comb += t_abinf.eq(a1.is_inf & b1.is_inf)
        comb += t_a1nan.eq(a1.is_nan)
        comb += t_a1inf.eq(a1.is_inf)
        comb += t_b1inf.eq(b1.is_inf)
        comb += t_a1zero.eq(a1.is_zero)
        comb += t_b1zero.eq(b1.is_zero)

        # prepare inf/zero/nans
        z_zero = FPNumBaseRecord(width, False, name="z_zero")
        z_zeroa = FPNumBaseRecord(width, False, name="z_zeroa")
        z_zeroab = FPNumBaseRecord(width, False, name="z_zeroab")
        z_nan = FPNumBaseRecord(width, False, name="z_nan")
        z_infa = FPNumBaseRecord(width, False, name="z_infa")
        z_infab = FPNumBaseRecord(width, False, name="z_infab")
        comb += z_zero.zero(0)          # +0
        comb += z_zeroa.zero(a1.s)      # zero with the sign of a
        comb += z_zeroab.zero(sabx)     # zero with sign a xor b
        comb += z_nan.nan(0)            # canonical NaN
        comb += z_infa.inf(a1.s)        # inf with the sign of a
        comb += z_infab.inf(sabx)       # inf with sign a xor b

        # per-operation "is this a special case?" flags
        comb += t_special_div.eq(Cat(t_b1zero, t_a1zero, t_b1inf, t_a1inf,
                                     t_abinf, t_abnan).bool())
        # sqrt and rsqrt only look at a: zero, negative, inf or NaN
        comb += t_special_sqrt.eq(Cat(t_a1zero, a1.s, t_a1inf,
                                      t_a1nan).bool())
        comb += t_special_rsqrt.eq(t_special_sqrt)

        # select one of 3 different sets of specialcases (DIV, SQRT, RSQRT)
        # NOTE: each case below is a Mux-cascade.  The cascade is built
        # lowest-priority first, so the last (outermost) Mux wins.  The
        # non-special value (0) is ignored because out_do_z is then clear.
        with m.Switch(self.i.ctx.op):

            ########## DIV ############
            with m.Case(int(DP.UDivRem)):

                # any special cases?
                comb += self.o.out_do_z.eq(t_special_div)

                # priority, highest first:
                # if a is NaN or b is NaN return NaN
                # if a is inf and b is inf return NaN
                # if a is inf return inf (sign a xor b)
                # if b is inf return zero (sign a xor b)
                # if a is zero return zero (sign a xor b), or NaN if b
                #     is also zero
                # if b is zero return inf (sign a xor b)
                oz = 0
                oz = Mux(t_b1zero, z_infab.v, oz)
                oz = Mux(t_a1zero, Mux(t_b1zero, z_nan.v, z_zeroab.v), oz)
                oz = Mux(t_b1inf, z_zeroab.v, oz)
                oz = Mux(t_a1inf, z_infab.v, oz)
                oz = Mux(t_abinf, z_nan.v, oz)
                oz = Mux(t_abnan, z_nan.v, oz)

                comb += self.o.oz.eq(oz)

            ########## SQRT ############
            with m.Case(int(DP.SqrtRem)):

                # any special cases?  (this must be driven explicitly:
                # previously no special-case branch ever set it high)
                comb += self.o.out_do_z.eq(t_special_sqrt)

                # priority, highest first:
                # if a is zero return zero, keeping its sign
                # -ve number is canonical NaN
                # if a is inf return +inf.  Negatives are already caught
                #     above, so a1.s is 0 here and z_infa is +inf.  (This
                #     used to take b's sign via sabx.)
                # if a is NaN return canonical NaN
                oz = 0
                oz = Mux(t_a1nan, z_nan.v, oz)
                oz = Mux(t_a1inf, z_infa.v, oz)
                oz = Mux(a1.s, z_nan.v, oz)
                oz = Mux(t_a1zero, z_zeroa.v, oz)

                comb += self.o.oz.eq(oz)

            ########## RSQRT ############
            with m.Case(int(DP.RSqrtRem)):

                # any special cases?  (driven explicitly, as for SQRT)
                comb += self.o.out_do_z.eq(t_special_rsqrt)

                # priority, highest first:
                # if a is NaN return canonical NaN
                # if a is +/- zero return +/- INF.  This includes the
                #     "weird" case 1/sqrt(-0) == -Inf.
                # -ve number is canonical NaN
                # if a is inf return +zero (-ve already excluded, above)
                oz = 0
                oz = Mux(t_a1inf, z_zero.v, oz)
                oz = Mux(a1.s, z_nan.v, oz)
                oz = Mux(t_a1zero, z_infa.v, oz)
                oz = Mux(t_a1nan, z_nan.v, oz)

                comb += self.o.oz.eq(oz)

        # pass through context
        comb += self.o.ctx.eq(self.i.ctx)

        return m
177
178
class FPDIVSpecialCasesDeNorm(PipeModBaseChain):
    """ special cases: NaNs, infs, zeros, denormalised

        Chains FPDIVSpecialCasesMod, FPAddDeNormMod and FPAlignModSingle
        together as a single pipeline step.
    """

    def get_chain(self):
        """ links module to inputs and outputs

            returns the chained sub-modules in order: special cases
            (FPDIVSpecialCasesMod), then FPAddDeNormMod, then
            FPAlignModSingle.
        """
        spec = self.pspec
        return [FPDIVSpecialCasesMod(spec),
                FPAddDeNormMod(spec, False),
                FPAlignModSingle(spec, False)]
189
190 return [smod, dmod, amod]