quick debug session on FP div stub pipeline
[ieee754fpu.git] / src / ieee754 / fpdiv / nmigen_div_experiment.py
1 # IEEE Floating Point Divider (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Const, Cat, Elaboratable
6 from nmigen.cli import main, verilog
7
8 from ieee754.fpcommon.fpbase import (FPNumIn, FPNumOut, FPOpIn,
9 FPOpOut, Overflow, FPBase, FPState,
10 FPNumBaseRecord)
11 from nmutil.nmoperator import eq
12
13
class Div:
    """Register set for the bit-serial divide loop driven by FPDIV.

    Holds the quotient, divisor, dividend and remainder registers (all
    ``width`` bits) plus a loop counter.  The FSM that uses them lives in
    FPDIV.elaborate; this class only declares the signals and provides a
    synchronous reset of the loop state.
    """

    def __init__(self, width):
        """
        :param width: bit-width of the quotient, divisor, dividend and
                      remainder registers.
        """
        self.width = width
        self.quot = Signal(width)  # quotient
        self.dor = Signal(width)   # divisor
        self.dend = Signal(width)  # dividend
        self.rem = Signal(width)   # remainder
        # loop count: the divide loop compares it against width-2, so it
        # must be able to hold that value.  Never narrower than the
        # original 7 bits, so existing widths produce identical hardware.
        self.count = Signal(self._count_width(width))

        self.czero = Const(0, width)

    @staticmethod
    def _count_width(width):
        """Return the loop-counter bit-width for a divider of ``width`` bits.

        Large enough to represent ``width - 2`` and never less than 7.
        """
        return max(7, (width - 2).bit_length())

    def reset(self, m):
        """Add sync statements to module ``m`` clearing quot, rem and count."""
        m.d.sync += [
            self.quot.eq(self.czero),
            self.rem.eq(self.czero),
            self.count.eq(Const(0, len(self.count))),
        ]
31
32
class FPDIV(FPBase, Elaboratable):
    """Multi-cycle IEEE 754 floating-point divider implemented as one FSM.

    State sequence: get_a -> get_b -> special_cases -> normalise_a ->
    normalise_b -> divide_0 -> (divide_1 <-> divide_2 loop) -> divide_3 ->
    normalise_1 -> normalise_2 -> round -> corrections -> pack -> put_z,
    then back to get_a.  special_cases may jump straight to put_z.
    """

    def __init__(self, width):
        """Create the divider.

        :param width: IEEE 754 format width in bits (e.g. 32 for single
                      precision); used for the operand/result ports and the
                      internal number records.
        """
        FPBase.__init__(self)
        self.width = width

        self.in_a = FPOpIn(width)
        self.in_b = FPOpIn(width)
        self.out_z = FPOpOut(width)
        # raw width-bit data signals for the two operand ports and the result
        self.in_a.data_i = Signal(width)
        self.in_b.data_i = Signal(width)
        self.out_z.data_o = Signal(width)

        # populated via add_state(); not read by elaborate()
        self.states = []

    def add_state(self, state):
        """Append ``state`` to ``self.states`` and return it unchanged."""
        self.states.append(state)
        return state

    def elaborate(self, platform=None):
        """Create the HDL code-fragment for FPDiv.

        :param platform: unused; accepted for the nmigen Elaboratable
                         interface.
        :returns: a Module holding the a/b/z latches as submodules and the
                  FSM that performs the divide.
        """
        m = Module()

        # Latches
        a = FPNumBaseRecord(self.width, False)
        b = FPNumBaseRecord(self.width, False)
        z = FPNumBaseRecord(self.width, False)
        a = FPNumIn(None, a)
        b = FPNumIn(None, b)
        z = FPNumOut(z)

        div = Div(a.m_width*2 + 3)  # double the mantissa width plus g/r/sticky

        # guard/round/sticky holder.  NOTE(review): `of` is deliberately not
        # added as a submodule (see the commented-out line below); this
        # assumes of.roundz needs no elaboration of its own -- confirm
        # against fpbase.Overflow.
        of = Overflow()
        m.submodules.in_a = a
        m.submodules.in_b = b
        m.submodules.z = z
        #m.submodules.of = of

        m.d.comb += a.v.eq(self.in_a.v)
        m.d.comb += b.v.eq(self.in_b.v)

        with m.FSM():

            # ******
            # gets operand a: get_op's result is latched into a and
            # in_a.ready_o

            with m.State("get_a"):
                res = self.get_op(m, self.in_a, a, "get_b")
                m.d.sync += eq([a, self.in_a.ready_o], res)

            # ******
            # gets operand b: get_op's result is latched into b and
            # in_b.ready_o

            with m.State("get_b"):
                res = self.get_op(m, self.in_b, b, "special_cases")
                m.d.sync += eq([b, self.in_b.ready_o], res)

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to div. see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html
            # The branches are tested in order, so each later branch may
            # assume every earlier condition was false.

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan | b.is_nan):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is Inf and b is Inf return NaN
                with m.Elif(a.is_inf & b.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is Inf (b is finite here, possibly zero) return Inf:
                # Inf / x is Inf for every finite x, including x == 0
                with m.Elif(a.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s ^ b.s)

                # if b is Inf (a is finite here) return zero
                with m.Elif(b.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.zero(a.s ^ b.s)

                # if a is zero return zero, or NaN if b is also zero (0/0)
                with m.Elif(a.is_zero):
                    m.next = "put_z"
                    with m.If(b.is_zero):
                        m.d.sync += z.nan(1)
                    with m.Else():
                        m.d.sync += z.zero(a.s ^ b.s)

                # if b is zero (a is finite and non-zero here) return Inf
                with m.Elif(b.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s ^ b.s)

                # both operands finite and non-zero: apply FPBase's
                # denormalised-number handling, then go normalise
                with m.Else():
                    m.next = "normalise_a"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # normalise_a

            with m.State("normalise_a"):
                self.op_normalise(m, a, "normalise_b")

            # ******
            # normalise_b

            with m.State("normalise_b"):
                self.op_normalise(m, b, "divide_0")

            # ******
            # First stage of divide. initialise state: sign and exponent of
            # the result, and the dividend/divisor.  The dividend is a.m
            # shifted up by a.m_width+3 so the quotient carries 3 extra low
            # bits for guard/round/sticky.

            with m.State("divide_0"):
                m.next = "divide_1"
                m.d.sync += [
                    z.s.eq(a.s ^ b.s),  # sign
                    z.e.eq(a.e - b.e),  # exponent
                    div.dend.eq(a.m<<(a.m_width+3)),  # 3 bits for g/r/sticky
                    div.dor.eq(b.m),
                ]
                div.reset(m)

            # ******
            # Second stage of divide: shift quotient and dividend left by
            # one, moving the dividend's MSB into the remainder's LSB.

            with m.State("divide_1"):
                m.next = "divide_2"
                m.d.sync += [
                    div.quot.eq(div.quot << 1),
                    div.rem.eq(Cat(div.dend[-1], div.rem[0:])),
                    div.dend.eq(div.dend << 1),
                ]

            # ******
            # Third stage of divide: if the remainder can absorb the
            # divisor, subtract it and set the quotient LSB.
            # This stage ends by jumping out to divide_3 once count reaches
            # div.width-2; otherwise it increments count and jumps back to
            # divide_1 (which comes back here).

            with m.State("divide_2"):
                with m.If(div.rem >= div.dor):
                    m.d.sync += [
                        div.quot[0].eq(1),
                        div.rem.eq(div.rem - div.dor),
                    ]
                with m.If(div.count == div.width-2):
                    m.next = "divide_3"
                with m.Else():
                    m.next = "divide_1"
                    m.d.sync += [
                        div.count.eq(div.count + 1),
                    ]

            # ******
            # Fourth stage of divide: quotient bits from 3 upwards become
            # the mantissa (bits above z.m's width are truncated by the
            # assignment); the low three bits become guard/round/sticky,
            # and any non-zero remainder also sets sticky.

            with m.State("divide_3"):
                m.next = "normalise_1"
                m.d.sync += [
                    z.m.eq(div.quot[3:]),
                    of.guard.eq(div.quot[2]),
                    of.round_bit.eq(div.quot[1]),
                    of.sticky.eq(div.quot[0] | (div.rem != 0))
                ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of.roundz)
                m.next = "corrections"

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage: hand z to out_z, then wait for the next operand

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
245
246
if __name__ == "__main__":
    # Build a single-precision divider and hand it to the nmigen CLI,
    # exposing both operand ports and the result port.
    alu = FPDIV(width=32)
    io_ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=io_ports)

    # Direct conversion also works, but prefer the CLI instead:
    #     python fname.py convert -t v
    # print(verilog.convert(alu, ports=alu.in_a.ports() +
    #                                  alu.in_b.ports() +
    #                                  alu.out_z.ports()))