comment functions
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Const
6 from nmigen.cli import main, verilog
7
8
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        NOTE: decode() and create() still hard-code the 32-bit field
        positions (mantissa v[0:23], exponent v[23:31], sign v[31]).
    """
    def __init__(self, width, m_width=None):
        """ creates the signals and comparison constants

            width:   number of bits of the packed value (self.v)
            m_width: number of bits of the mantissa signal.  defaults
                     to width-5 (27 for a 32-bit value)
        """
        self.width = width
        if m_width is None:
            m_width = width - 5  # 32-bit: 23 mantissa bits + 4 extra bits
        self.v = Signal(width)       # Latched copy of value
        self.m = Signal(m_width)     # Mantissa
        self.e = Signal((10, True))  # Exponent: 10 bits, signed
        self.s = Signal()            # Sign bit

        # constants used by the recognition / creation functions below
        self.mzero = Const(0, (m_width, False))   # mantissa all zeros
        self.m1s = Const(-1, (m_width, False))    # mantissa all ones
        self.P128 = Const(128, (10, True))
        self.P127 = Const(127, (10, True))
        self.N127 = Const(-127, (10, True))
        self.N126 = Const(-126, (10, True))

    def decode(self):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            returns a list of assignments (to be added to a domain)
        """
        v = self.v
        return [self.m.eq(Cat(0, 0, 0, v[0:23])),   # mantissa (3 low zeros)
                self.e.eq(v[23:31] - self.P127),    # exp (minus bias)
                self.s.eq(v[31]),                   # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            returns a list of assignments to the fields of self.v
        """
        return [
            self.v[31].eq(s),                   # sign
            self.v[23:31].eq(e + self.P127),    # exp (add on bias)
            self.v[0:23].eq(m)                  # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # the bit that is shifted out (m[0]) is ORed into the new m[0]
        # so that the sticky bit, once set, stays set.
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
                ]

    def nan(self, s):
        """ assignments that set self.v to NaN (top mantissa bit set), sign s
        """
        return self.create(s, self.P128, 1 << 22)

    def inf(self, s):
        """ assignments that set self.v to infinity with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set self.v to zero with sign s
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ exponent is 128 and the mantissa is non-zero
        """
        return (self.e == self.P128) & (self.m != self.mzero)

    def is_inf(self):
        """ exponent is 128 and the mantissa is zero
        """
        return (self.e == self.P128) & (self.m == self.mzero)

    def is_zero(self):
        """ exponent is -127 and the mantissa is zero
        """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ exponent is greater than 127
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ exponent is -126 and the top bit of the mantissa is clear

            the top bit is taken as m[-1] so that the test follows
            whatever m_width was given to the constructor (for a 24-bit
            mantissa this is m[23]).
        """
        return (self.e == self.N126) & (self.m[-1] == 0)
94
class FPOp:
    """ Operand port: a value plus a strobe / acknowledge signal pair

        Used by FPADD for its two inputs and its output.
    """
    def __init__(self, width):
        """ width: number of bits of the value signal
        """
        self.width = width

        self.v = Signal(width)  # the value being transferred
        self.stb = Signal()     # strobe
        self.ack = Signal()     # acknowledge

    def ports(self):
        """ returns the port signals as a list: value, strobe, acknowledge
        """
        handshake = [self.stb, self.ack]
        return [self.v] + handshake
105
106
class Overflow:
    """ Holder for the three extra result bits: guard, round and sticky

        Each is a single-bit signal; the trailing comments give the bit
        of "tot" that the original author associated with it.
    """
    def __init__(self):
        self.guard = Signal()      # tot[2]
        self.round_bit = Signal()  # tot[1]
        self.sticky = Signal()     # tot[0]
112
113
class FPADD:
    """ Floating-point adder, built as a multi-cycle state machine

        Two operands arrive on in_a and in_b, the result leaves on out_z.
        Each of those is an FPOp (value + stb + ack).  The states, in
        order, are: get_a, get_b, unpack, special_cases, align, add_0,
        add_1, normalise_1, normalise_2, round, corrections, pack, put_z.
    """
    def __init__(self, width):
        """ width: number of bits of the operands and of the result
        """
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:          the Module being built
            op:         the FPOp to read from
            v:          the signal that receives op.v
            next_state: name of the state to move to
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.eq(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            # order matters: the z.m[0] assignment must come after the
            # whole-mantissa shift so that it overrides bit 0 of it.
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),             # INCREASE exponent
                z.m.eq(z.m >> 1),            # shift mantissa DOWN
                of.guard.eq(z.m[0]),         # bit shifted out becomes guard
                of.round_bit.eq(of.guard),   # old guard becomes round
                of.sticky.eq(of.sticky | of.round_bit)  # sticky accumulates
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ rounding stage: always moves on to next_state

            the mantissa is incremented when guard is set together with
            any of round, sticky or the mantissa's lowest bit.  if the
            mantissa was all 1s the exponent is incremented as well.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(z.m == z.m1s):     # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections: always moves on
            to next_state
        """
        m.next = next_state
        # denormalised, correct exponent to zero.  -127 becomes zero once
        # the bias is added back on by FPNum.create().  (this used to
        # assign -127 to the *mantissa*, leaving the exponent untouched)
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs sign / exponent / mantissa of z into z.v: always moves
            on to next_state
        """
        m.next = next_state
        # if overflow occurs, return inf.  the sign of the result is kept
        # (this used to be hard-coded to zero, i.e. always +inf)
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
            out_z.stb.eq(1),
            out_z.v.eq(z.v)
        ]
        # this later assignment to stb overrides the one above
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state

    def get_fragment(self, platform=None):
        """ creates the Module containing the adder state machine

            platform: not used here
            returns the Module
        """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        tot = Signal(28)  # sticky/round/guard bits, 23 result, 1 overflow

        of = Overflow()

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a.v, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b.v, "unpack")

            # ******
            # unpacks operands into sign, mantissa and exponent

            with m.State("unpack"):
                m.next = "special_cases"
                m.d.sync += a.decode()
                m.d.sync += b.decode()

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    # (b cannot be NaN here, so b.e == 128 means b is inf)
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a.e == a.N127):
                        m.d.sync += a.e.eq(-126)  # limit a exponent
                    with m.Else():
                        m.d.sync += a.m[-1].eq(1)  # set top mantissa bit
                    # denormalise b check
                    with m.If(b.e == b.N127):
                        m.d.sync += b.e.eq(-126)  # limit b exponent
                    with m.Else():
                        m.d.sync += b.m[-1].eq(1)  # set top mantissa bit

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                    ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
401
402
if __name__ == "__main__":
    # 32-bit adder; every signal of both operand ports and of the result
    # port is handed to the nmigen command-line entry point.
    alu = FPADD(width=32)
    main(alu, ports=[*alu.in_a.ports(),
                     *alu.in_b.ports(),
                     *alu.out_z.ports()])
406
407
    # works... but don't use, just do "python fname.py convert -t v"
    #print (verilog.convert(alu, ports=alu.in_a.ports() + \
    #                                  alu.in_b.ports() + \
    #                                  alu.out_z.ports()))