remove verilog
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat
6 from nmigen.cli import main
7
8
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        NOTE: decode() and create() use hard-coded bit positions
        (mantissa v[0:23], exponent v[23:31], sign v[31]) and a bias of
        127, so only width=32 is meaningful at present.
    """
    def __init__(self, width, m_width=None):
        """ width:   number of bits in the packed value (v)
            m_width: number of bits in the mantissa signal; defaults to
                     width - 5 (27 bits when width is 32)
        """
        self.width = width
        if m_width is None:
            m_width = width - 5  # mantissa extra bits (top,guard,round)
        self.v = Signal(width)       # Latched copy of value
        self.m = Signal(m_width)     # Mantissa
        self.e = Signal((10, True))  # Exponent: 10 bits, signed
        self.s = Signal()            # Sign bit

    def decode(self):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.
            The mantissa is placed above three zeroed low bits
            (guard / round / sticky).

            returns a list of assignments, to be added to a clock domain
        """
        v = self.v
        return [self.m.eq(Cat(0, 0, 0, v[0:23])),  # mantissa
                self.e.eq(Cat(v[23:31]) - 127),    # exponent (take off bias)
                self.s.eq(Cat(v[31])),             # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.
            Only the packed value (v) is written: the decoded s / e / m
            signals of this object are left untouched.

            returns a list of assignments, to be added to a clock domain
        """
        return [
            self.v[31].eq(s),            # sign
            self.v[23:31].eq(e + 127),   # exp (add on bias)
            self.v[0:23].eq(m),          # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            The decoded exponent and mantissa are updated directly.  Going
            through create() would only write the packed value (v) and
            leave e and m unchanged, so a state that repeats this shift
            until two exponents match would never terminate.

            returns a list of assignments, to be added to a clock domain
        """
        return [
            self.e.eq(self.e + 1),
            # new m[0] is the OR of the two lowest bits (sticky), every
            # other bit moves down one place, and a zero enters at the top
            self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0)),
        ]

    def nan(self, s):
        """ packs a NaN (exponent all 1s, top mantissa bit set), sign s """
        return self.create(s, 0x80, 1 << 22)

    def inf(self, s):
        """ packs an infinity (exponent all 1s, zero mantissa), sign s """
        return self.create(s, 0x80, 0)

    def zero(self, s):
        """ packs a zero (exponent and mantissa all 0s), sign s """
        return self.create(s, -127, 0)

    def is_nan(self):
        """ decoded exponent is 128 and the mantissa is non-zero """
        return (self.e == 128) & (self.m != 0)

    def is_inf(self):
        """ decoded exponent is 128 and the mantissa is zero """
        return (self.e == 128) & (self.m == 0)

    def is_zero(self):
        """ decoded exponent is -127 and the mantissa is zero """
        return (self.e == -127) & (self.m == 0)

    def is_overflowed(self):
        """ decoded exponent is too large to be packed as a finite number

            127 is the largest unbiased exponent of a finite value
            (128 is what nan() and inf() use), so overflow is "greater
            than", not "less than", 127.
        """
        return (self.e > 127)

    def is_denormalised(self):
        """ exponent is at its minimum (-126) and bit 23 of the mantissa
            is clear.

            NOTE: bit 23 is only the top bit when the mantissa is 24 bits
            wide (m_width=24).
        """
        return (self.e == -126) & (self.m[23] == 0)
85
86
class FPADD:
    """ Floating-point adder, built as a multi-cycle state machine.

        Operands arrive on in_a / in_b and the result leaves on out_z.
        Each port has a strobe (stb) and acknowledge (ack) signal:
        * an operand is latched on the clock where both its ack (driven
          here) and its stb (driven by the producer) are high
        * the result is held on out_z with out_z_stb high until the
          consumer raises out_z_ack

        State sequence: get_a, get_b, unpack, special_cases, then either
        straight to put_z (NaN / inf / zero operands) or through align,
        add_0, add_1, normalise_1, normalise_2, round, corrections, pack
        and finally put_z.
    """
    def __init__(self, width):
        """ width: number of bits of the operands and of the result """
        self.width = width

        self.in_a = Signal(width)
        self.in_a_stb = Signal()
        self.in_a_ack = Signal()

        self.in_b = Signal(width)
        self.in_b_stb = Signal()
        self.in_b_ack = Signal()

        self.out_z = Signal(width)
        self.out_z_stb = Signal()
        self.out_z_ack = Signal()

    def get_fragment(self, platform):
        """ builds and returns the Module containing the adder's FSM """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        tot = Signal(28)     # sticky/round/guard bits, 23 result, 1 overflow

        guard = Signal()     # tot[2]
        round_bit = Signal() # tot[1]
        sticky = Signal()    # tot[0]

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                with m.If((self.in_a_ack) & (self.in_a_stb)):
                    m.next = "get_b"
                    m.d.sync += [
                        a.v.eq(self.in_a),
                        self.in_a_ack.eq(0)
                    ]
                with m.Else():
                    m.d.sync += self.in_a_ack.eq(1)

            # ******
            # gets operand b

            with m.State("get_b"):
                with m.If((self.in_b_ack) & (self.in_b_stb)):
                    # both operands are latched: go and decode them.
                    # (returning to "get_a" here would just fetch
                    # operands forever without computing anything)
                    m.next = "unpack"
                    m.d.sync += [
                        b.v.eq(self.in_b),
                        self.in_b_ack.eq(0)
                    ]
                with m.Else():
                    m.d.sync += self.in_b_ack.eq(1)

            # ******
            # unpacks operands into sign, mantissa and exponent

            with m.State("unpack"):
                m.next = "special_cases"
                m.d.sync += a.decode()
                m.d.sync += b.decode()

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    # (added after the inf assignment, so it overrides it)
                    with m.If((b.e == 128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:26])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:26])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:26])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a.e == -127):
                        m.d.sync += a.e.eq(-126) # limit a exponent
                    with m.Else():
                        m.d.sync += a.m[26].eq(1) # set highest mantissa bit
                    # denormalise b check
                    with m.If(b.e == -127):
                        m.d.sync += b.e.eq(-126) # limit b exponent
                    with m.Else():
                        m.d.sync += b.m[26].eq(1) # set highest mantissa bit

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        guard.eq(tot[3]),
                        round_bit.eq(tot[2]),
                        sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        guard.eq(tot[2]),
                        round_bit.eq(tot[1]),
                        sticky.eq(tot[0])
                    ]

            # ******
            # First stage of normalisation.
            # NOTE: just like "align", this one keeps going round every clock
            #       until the top mantissa bit is set or the exponent has
            #       reached its minimum (-126)
            # NOTE: the weirdness of reassigning guard and round is due to
            #       the extra mantissa bits coming from tot[0..2]

            with m.State("normalise_1"):
                with m.If((z.m[23] == 0) & (z.e > -126)):
                    m.d.sync += [
                        z.e.eq(z.e - 1),     # DECREASE exponent
                        z.m.eq(z.m << 1),    # shift mantissa UP
                        z.m[0].eq(guard),    # steal guard bit (was tot[2])
                        guard.eq(round_bit), # steal round_bit (was tot[1])
                        # round_bit has moved up into guard: clear it so
                        # that it is not shifted in a second time
                        round_bit.eq(0),
                    ]
                with m.Else():
                    # must match the state name below exactly
                    m.next = "normalise_2"

            # ******
            # Second stage of normalisation.
            # NOTE: just like "align", this one keeps going round every clock
            #       until the result's exponent is no longer below the
            #       minimum (-126)
            # NOTE: the bits shifted out of the mantissa move down through
            #       guard and round_bit, and accumulate in sticky

            with m.State("normalise_2"):
                with m.If(z.e < -126):
                    m.d.sync += [
                        z.e.eq(z.e + 1),  # INCREASE exponent
                        z.m.eq(z.m >> 1), # shift mantissa DOWN
                        guard.eq(z.m[0]),
                        round_bit.eq(guard),
                        sticky.eq(sticky | round_bit)
                    ]
                with m.Else():
                    m.next = "round"

            # ******
            # rounding stage

            with m.State("round"):
                # must match the state name below exactly
                m.next = "corrections"
                with m.If(guard & (round_bit | sticky | z.m[0])):
                    m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
                    with m.If(z.m == 0xffffff): # all 1s
                        m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

            # ******
            # correction stage

            with m.State("corrections"):
                m.next = "pack"
                # denormalised, correct exponent to zero.  -127 is the
                # unbiased value which create() packs as an all-zero
                # exponent field: it belongs in z.e, not in the mantissa
                with m.If(z.is_denormalised()):
                    m.d.sync += z.e.eq(-127)
                # FIX SIGN BUG: -a + a = +0.
                with m.If((z.e == -126) & (z.m[0:23] == 0)):
                    m.d.sync += z.s.eq(0)

            # ******
            # pack stage

            with m.State("pack"):
                m.next = "put_z"
                # if overflow occurs, return inf (keeping the result's sign)
                with m.If(z.is_overflowed()):
                    m.d.sync += z.inf(z.s)
                with m.Else():
                    m.d.sync += z.create(z.s, z.e, z.m)

            # ******
            # put_z stage: hold the result, with the strobe raised, until
            # it is acknowledged.  only then fetch the next operands

            with m.State("put_z"):
                m.d.sync += [
                    self.out_z_stb.eq(1),
                    self.out_z.eq(z.v)
                ]
                with m.If(self.out_z_stb & self.out_z_ack):
                    m.next = "get_a"
                    # added after the eq(1) above, so it overrides it
                    m.d.sync += self.out_z_stb.eq(0)

        return m
354
355
if __name__ == "__main__":
    # Build a 32-bit adder and pass it to the nmigen command-line driver,
    # listing the operand and result signals as the design's ports.
    alu = FPADD(width=32)
    ports = [
        alu.in_a, alu.in_a_stb, alu.in_a_ack,
        alu.in_b, alu.in_b_stb, alu.in_b_ack,
        alu.out_z, alu.out_z_stb, alu.out_z_ack,
    ]
    main(alu, ports=ports)
363
364
365 """
366 # doesn't work for some reason
367 print(verilog.convert(alu, ports=[in_a, in_a_stb, in_a_ack,
368 in_b, in_b_stb, in_b_ack,
369 out_z, out_z_stb, out_z_ack]))
370 """