make module out of overflow class
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat
6 from nmigen.cli import main, verilog
7
8 from fpbase import FPNum, FPOp, Overflow, FPBase
9
10
class FPADD(FPBase):
    """ IEEE floating-point adder, written as a finite state machine.

        Operands are taken from in_a and in_b and the result is handed
        out on out_z; all three are FPOp instances of the given width.
    """

    def __init__(self, width, single_cycle=False):
        """ sets up the operand and result ports

            * width: bit-width passed to every FPOp / FPNum used here
            * single_cycle: if True, the "align" state does the whole
              mantissa shift in one go.  if False, it stays in the
              "align" state, shifting, until the exponents match.
        """
        FPBase.__init__(self)
        self.width = width
        self.single_cycle = single_cycle

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            returns the nmigen Module containing the FSM.  platform is
            accepted but not referenced in this function.
        """
        m = Module()

        # Latches: the two operands and the result
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, False)

        m.submodules.fpnum_a = a
        m.submodules.fpnum_b = b
        m.submodules.fpnum_z = z

        # sticky/round/guard, {mantissa} result, 1 overflow
        tot_width = z.m_width + 4
        tot = Signal(tot_width, reset_less=True)

        of = Overflow()
        m.submodules.overflow = of

        with m.FSM():

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to add.  see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html

            with m.State("special_cases"):

                # set when the operand signs differ
                s_nomatch = Signal()
                m.d.comb += s_nomatch.eq(a.s != b.s)

                # set when the operand mantissas are identical
                m_match = Signal()
                m.d.comb += m_match.eq(a.m == b.m)

                # either operand NaN: result is NaN
                with m.If(a.is_nan | b.is_nan):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # XXX WEIRDNESS for FP16 non-canonical NaN handling
                # under review

                ## if a is zero and b is NaN return -b
                #with m.If(a.is_zero & (a.s==0) & b.is_nan):
                #    m.next = "put_z"
                #    m.d.sync += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))

                ## if b is zero and a is NaN return -a
                #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
                #    m.next = "put_z"
                #    m.d.sync += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))

                ## if a is -zero and b is NaN return -b
                #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
                #    m.next = "put_z"
                #    m.d.sync += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))

                ## if b is -zero and a is NaN return -a
                #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
                #    m.next = "put_z"
                #    m.d.sync += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))

                # a is inf: result is inf (or NaN)
                with m.Elif(a.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # b.exp_128 set and signs don't match: the NaN
                    # assignment comes second so it overrides the inf
                    with m.If(b.exp_128 & s_nomatch):
                        m.d.sync += z.nan(1)

                # b is inf: result is inf
                with m.Elif(b.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # both zero: result is signed-a/b
                with m.Elif(a.is_zero & b.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e, b.m[3:-1])

                # only a is zero: result is b
                with m.Elif(a.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e, b.m[3:-1])

                # only b is zero: result is a
                with m.Elif(b.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e, a.m[3:-1])

                # a equal to -b: result is zero (+ve zero)
                with m.Elif(s_nomatch & m_match & (a.e == b.e)):
                    m.next = "put_z"
                    m.d.sync += z.zero(0)

                # none of the above: do the Denormalised Number checks
                # and carry on to alignment
                with m.Else():
                    m.next = "align"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # align.

            with m.State("align"):
                if self.single_cycle:
                    # This one (single-cycle) will do the shift
                    # in one go.

                    # XXX TODO: the shifter used here is quite expensive
                    # having only one would be better

                    # exponent difference, computed in both directions
                    ediff = Signal((len(a.e), True), reset_less=True)
                    ediffr = Signal((len(a.e), True), reset_less=True)
                    m.d.comb += ediff.eq(a.e - b.e)
                    m.d.comb += ediffr.eq(b.e - a.e)
                    # exponent of a greater than b: shift b down
                    with m.If(ediff > 0):
                        m.d.sync += b.shift_down_multi(ediff)
                    # exponent of b greater than a: shift a down
                    with m.Elif(ediff < 0):
                        m.d.sync += a.shift_down_multi(ediffr)

                    # always moves on, whichever way the shift went
                    m.next = "add_0"
                else:
                    # NOTE: this does *not* do single-cycle multi-shifting,
                    #       it *STAYS* in the align state until exponents match

                    # exponent of a greater than b: shift b down
                    with m.If(a.e > b.e):
                        m.d.sync += b.shift_down()
                    # exponent of b greater than a: shift a down
                    with m.Elif(a.e < b.e):
                        m.d.sync += a.shift_down()
                    # exponents equal: move to next stage.
                    with m.Else():
                        m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive): add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += tot.eq(Cat(a.m, 0) + Cat(b.m, 0))
                    m.d.sync += z.s.eq(a.s)
                # a mantissa greater than (or equal to) b: a - b, sign of a
                with m.Elif(a.m >= b.m):
                    m.d.sync += tot.eq(Cat(a.m, 0) - Cat(b.m, 0))
                    m.d.sync += z.s.eq(a.s)
                # b mantissa greater than a: b - a, sign of b
                with m.Else():
                    m.d.sync += tot.eq(Cat(b.m, 0) - Cat(a.m, 0))
                    m.d.sync += z.s.eq(b.s)

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows.  shift result
                # down: take everything one bit higher, fold the two
                # lowest bits into sticky and bump the exponent
                with m.If(tot[-1]):
                    m.d.sync += [
                        z.m.eq(tot[4:]),
                        of.m0.eq(tot[4]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case: no shift, exponent left alone
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:]),
                        of.m0.eq(tot[3]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                    ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
263
264
if __name__ == "__main__":
    # 32-bit adder, handed to the nmigen command-line front-end with
    # the ports of both operands and of the result
    alu = FPADD(width=32)
    ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=ports)


# works... but don't use, just do "python fname.py convert -t v"
#print (verilog.convert(alu, ports=alu.in_a.ports() + \
#                                  alu.in_b.ports() + \
#                                  alu.out_z.ports()))