correct FPRSQRT specialcases
[ieee754fpu.git] / src / ieee754 / fpdiv / specialcases.py
# IEEE Floating Point Divider (DIV / SQRT / RSQRT): special cases
2
3 from nmigen import Module, Signal, Cat, Const, Elaboratable
4 from nmigen.cli import main, verilog
5 from math import log
6
7 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
8 from nmutil.singlepipe import SimpleHandshake, StageChain
9
10 from ieee754.fpcommon.fpbase import FPState, FPID
11 from ieee754.fpcommon.getop import FPADDBaseData
12 from ieee754.fpcommon.denorm import (FPSCData, FPAddDeNormMod)
13 from ieee754.fpmul.align import FPAlignModSingle
14
15
class FPDIVSpecialCasesMod(Elaboratable):
    """ special cases: NaNs, infs, zeros, denormalised

        Decodes operands a and b and, depending on ctx.op
        (0: DIV, 1: SQRT, 2: RSQRT), recognises operand combinations
        for which the result is driven directly onto o.z.  In those
        cases o.out_do_z is set to 1; otherwise it is set to 0 and the
        decoded a/b are simply passed through on o.a / o.b.

        see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html
    """

    def __init__(self, pspec):
        self.pspec = pspec
        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """ input data format: FPADDBaseData built from pspec
        """
        return FPADDBaseData(self.pspec)

    def ospec(self):
        """ output data format: FPSCData built from pspec
        """
        return FPSCData(self.pspec, False)

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        m.submodules.specialcases = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """ returns the output record (the argument is not used)
        """
        return self.o

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # decode: XXX really should move to separate stage
        a1 = FPNumBaseRecord(self.pspec.width, False)
        b1 = FPNumBaseRecord(self.pspec.width, False)
        m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1)
        m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1)
        comb += [a1.v.eq(self.i.a),
                 b1.v.eq(self.i.b),
                 self.o.a.eq(a1),
                 self.o.b.eq(b1)
                ]

        sabx = Signal(reset_less=True)   # sign a xor b (sabx, get it?)
        comb += sabx.eq(a1.s ^ b1.s)

        abnan = Signal(reset_less=True)  # either operand is NaN
        comb += abnan.eq(a1.is_nan | b1.is_nan)

        abinf = Signal(reset_less=True)  # both operands are Inf
        comb += abinf.eq(a1.is_inf & b1.is_inf)

        with m.If(self.i.ctx.op == 0):  # DIV
            # if a is NaN or b is NaN return NaN
            with m.If(abnan):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # if a is inf and b is Inf return NaN
            with m.Elif(abinf):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # if a is inf return inf
            with m.Elif(a1.is_inf):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.inf(sabx)

            # if b is inf return zero
            with m.Elif(b1.is_inf):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.zero(sabx)

            # if a is zero return zero (or NaN if b is zero)
            with m.Elif(a1.is_zero):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.zero(sabx)
                # b is zero return NaN (0/0): this later assignment
                # replaces the zero driven just above
                with m.If(b1.is_zero):
                    comb += self.o.z.nan(0)

            # if b is zero return Inf
            with m.Elif(b1.is_zero):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.inf(sabx)

            # Denormalised Number checks next, so pass a/b data through
            with m.Else():
                comb += self.o.out_do_z.eq(0)

        with m.If(self.i.ctx.op == 1):  # SQRT

            # if a is zero return zero (keeping the sign of a)
            with m.If(a1.is_zero):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.zero(a1.s)

            # -ve number is NaN
            with m.Elif(a1.s):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # if a is inf return inf.  the sign must come from a alone
            # (-ve already excluded, above): b is not an operand of
            # SQRT, so its sign bit must not leak into the result.
            with m.Elif(a1.is_inf):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.inf(a1.s)

            # if a is NaN return NaN
            with m.Elif(a1.is_nan):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # Denormalised Number checks next, so pass a/b data through
            with m.Else():
                comb += self.o.out_do_z.eq(0)

        with m.If(self.i.ctx.op == 2):  # RSQRT

            # if a is NaN return canonical NaN
            with m.If(a1.is_nan):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # if a is +/- zero return +/- INF
            with m.Elif(a1.is_zero):
                comb += self.o.out_do_z.eq(1)
                # this includes the "weird" case 1/sqrt(-0) == -Inf
                comb += self.o.z.inf(a1.s)

            # -ve number is canonical NaN
            with m.Elif(a1.s):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.nan(0)

            # if a is inf return zero (-ve already excluded, above)
            with m.Elif(a1.is_inf):
                comb += self.o.out_do_z.eq(1)
                comb += self.o.z.zero(0)

            # Denormalised Number checks next, so pass a/b data through
            with m.Else():
                comb += self.o.out_do_z.eq(0)

        # expose the raw bits of z on oz and forward the context unchanged
        comb += self.o.oz.eq(self.o.z.v)
        comb += self.o.ctx.eq(self.i.ctx)

        return m
163
164
class FPDIVSpecialCases(FPState):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to div.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        FSM state wrapper around FPDIVSpecialCasesMod: latches the
        module's output and selects the next state from out_do_z.
    """

    def __init__(self, pspec):
        FPState.__init__(self, "special_cases")
        self.mod = FPDIVSpecialCasesMod(pspec)
        self.out_z = self.mod.ospec()
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        # FPDIVSpecialCasesMod.setup accepts (m, i) only, so the
        # early-out flag is linked here from the module's output record
        self.mod.setup(m, i)
        m.d.comb += self.out_do_z.eq(self.mod.o.out_do_z)
        # NOTE(review): the previous code latched only ".v" and ".mid"
        # from a "mod.out_z" attribute that the module does not define;
        # the whole output record is latched instead.  confirm that this
        # carries everything the "put_z" state needs.
        m.d.sync += self.out_z.eq(self.mod.o)

    def action(self, m):
        """ moves to "put_z" on a special case, otherwise "denormalise"
        """
        self.idsync(m)
        with m.If(self.out_do_z):
            m.next = "put_z"
        with m.Else():
            m.next = "denormalise"
190
191
class FPDIVSpecialCasesDeNorm(FPState, SimpleHandshake):
    """ special cases: NaNs, infs, zeros, denormalised

        Chains the special-cases, de-normalisation and alignment
        modules into a single stage; this object acts as both the
        pipeline and its own stage.
    """

    def __init__(self, pspec):
        FPState.__init__(self, "special_cases")
        self.pspec = pspec
        # pipe is its own stage (pspec must be assigned before this call)
        SimpleHandshake.__init__(self, self)
        self.out = self.ospec()

    def ispec(self):
        """ input format: same as the SpecialCases ispec
        """
        return FPADDBaseData(self.pspec)

    def ospec(self):
        """ output format: same as the Align ospec
        """
        return FPSCData(self.pspec, False)

    def setup(self, m, i):
        """ links module to inputs and outputs
        """
        stages = [
            FPDIVSpecialCasesMod(self.pspec),
            FPAddDeNormMod(self.pspec, False),
            FPAlignModSingle(self.pspec, False),
        ]
        StageChain(stages).setup(m, i)

        # break-out (early-out) on the special-cases out_do_z flag is
        # not wired up; the output is always that of the last stage
        self.o = stages[-1].o

    def process(self, i):
        """ returns the chained output (the argument is not used)
        """
        return self.o

    def action(self, m):
        """ latches the chained result and moves on to "align"
        """
        # no early-out to "put_z" here: always proceed to "align"
        result = self.process(None)
        m.d.sync += self.out.eq(result)
        m.next = "align"
233
234