add op_normalise function
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
# IEEE Floating Point Adder (Single Precision)
# Copyright (C) Jonathan P Dawson 2013
# 2013-12-12

from nmigen import Module, Signal, Cat, Const
from nmigen.cli import main, verilog


class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed).

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0]).
    """
    def __init__(self, width, m_width=None):
        self.width = width
        if m_width is None:
            m_width = width - 5 # mantissa extra bits (top, guard, round, sticky)
        self.m_width = m_width
        self.v = Signal(width)      # Latched copy of value
        self.m = Signal(m_width)    # Mantissa
        self.e = Signal((10, True)) # Exponent: 10 bits, signed
        self.s = Signal()           # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(128, (10, True))
        self.P127 = Const(127, (10, True))
        self.N127 = Const(-127, (10, True))
        self.N126 = Const(-126, (10, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        args = [0] * (self.m_width-24) + [v[0:23]] # pad with zero LSBs (guard/round/sticky)
        return [self.m.eq(Cat(*args)),           # mantissa
                self.e.eq(v[23:31] - self.P127), # exp (minus bias)
                self.s.eq(v[31]),                # sign
               ]
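
    # Worked example (for the default width): decoding 0x40490FDB (~pi).
    #   v[31]    = 0          -> s = 0
    #   v[23:31] = 0x80 (128) -> e = 128 - 127 = 1
    #   v[0:23]  = 0x490FDB   -> m = 0x490FDB << 3 (guard/round/sticky zero)
    # The hidden leading "1" is not recovered here; it is set later, in
    # denormalise(), for normalised numbers.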

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        return [
            self.v[31].eq(s),                # sign
            self.v[23:31].eq(e + self.P127), # exp (add on bias)
            self.v[0:23].eq(m)               # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost in the mantissa as a result, however there are
            three extra bits (guard, round and sticky); the sticky bit
            accumulates any "1"s shifted out, so they are not lost entirely
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]
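
    # Worked example: with an (illustrative) 6-bit mantissa m = 0b100011,
    # shift_down gives m = 0b010001: every bit moves down one place and the
    # old round bit (m[1]=1) and sticky bit (m[0]=1) are OR-ed into the new
    # sticky bit, so no shifted-out "1" is lost; e increases by 1.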

    def nan(self, s):
        return self.create(s, self.P128, 1<<22)

    def inf(self, s):
        return self.create(s, self.P128, 0)

    def zero(self, s):
        return self.create(s, self.N127, 0)
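
    # For the default 32-bit width these produce the usual IEEE754 encodings:
    #   nan(0)  -> 0x7FC00000 (exponent all-1s, quiet-NaN bit 22 set)
    #   inf(0)  -> 0x7F800000 (exponent all-1s, zero mantissa)
    #   zero(0) -> 0x00000000 (zero exponent field, zero mantissa)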

    def is_nan(self):
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        return (self.e > self.P127)

    def is_denormalised(self):
        return (self.e == self.N126) & (self.m[23] == 0)


class FPOp:
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

    def ports(self):
        return [self.v, self.stb, self.ack]


class Overflow:
    def __init__(self):
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]


class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.
    """

    def get_op(self, m, op, v, next_state):
        """ latches the operand and moves to the next state when both
            stb and ack are 1.  acknowledgement is then signalled by
            dropping ack back to ZERO.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)
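
    # Handshake sketch (illustrative): ack is raised to show this side is
    # ready; on a cycle where both stb and ack are 1 the value on op.v is
    # latched (via decode()), ack drops back to 0 and the FSM moves on to
    # next_state on the following cycle.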

    def denormalise(self, m, a):
        """ denormalises a number
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(-126)  # limit exponent to -126 (denormal)
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set hidden (top) mantissa bit
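
    # Example: a biased exponent field of 0 decodes (via FPNum.decode) to
    # e == -127, which is treated as denormal: e is clamped to -126 and the
    # hidden bit stays clear.  Any other value is a normalised number, so
    # the implied leading "1" (m[-1]) is set instead.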

    def op_normalise(self, m, op, of, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
        """
        with m.If((op.m[-1] == 0)): # check top bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state
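
    # Example: with an (illustrative) operand mantissa of 0b001xxx and e = 3,
    # this state runs for two clocks: 0b001xxx/e=3 -> 0b01xxx0/e=2 ->
    # 0b1xxx00/e=1, and only then does the FSM advance: one clock per shift.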

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
            ]
        with m.Else():
            m.next = next_state
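
    # Example pass: with z.m = 0b01... and guard = 1, round_bit = 0, one
    # clock shifts z.m up, pulls the old guard bit in as the new z.m[0],
    # moves round_bit into guard, clears round_bit and decrements e.  The
    # loop stops once z.m[-1] == 1 or e has dropped to -126.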

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output. TODO: different kinds of rounding
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s):    # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
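
    # Round-to-nearest-even: guard=1 with round_bit=0 and sticky=0 is the
    # exact half-way case, so it rounds up only when z.m[0] is already 1
    # (towards the even mantissa); if round_bit or sticky is set the value
    # is past half-way and always rounds up.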

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero (biased field of 0)
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)
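
    # Example of the sign fix: (-a) + a gives a zero mantissa, but the
    # mantissa-subtract path in add_0 leaves z.s copied from one of the
    # operands, which would produce -0; the test above forces +0 instead.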

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf (with the result's sign)
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)
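
    # Example: if rounding or normalisation pushed the exponent above +127,
    # is_overflowed() is true and the packed result becomes infinity rather
    # than a wrapped exponent (e.g. 0x7F800000 for a positive overflow).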

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
            out_z.stb.eq(1),
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state


class FPADD(FPBase):

    def __init__(self, width):
        FPBase.__init__(self)
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        tot = Signal(28) # 3 guard/round/sticky bits, 24-bit mantissa, 1 carry

        of = Overflow()

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to add.  see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if b is also inf and the signs don't match, return NaN
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b is zero, return signed zero
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            # it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive): add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than or equal to b: use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a: use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when the tot sum is too big (tot[27] is effectively
            # the carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows.  shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                    ]
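
            # Worked example (illustrative): 1.5 + 1.5.  Both mantissas are
            # 1.100...0, so tot = 11.000...0 and tot[27] is set: the result
            # mantissa is taken one bit higher up (tot[4:28]) and the
            # exponent is incremented, giving 1.100...0 * 2^1 = 3.0.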

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m


if __name__ == "__main__":
    alu = FPADD(width=32)
    main(alu, ports=alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports())


# works... but don't use, just do "python fname.py convert -t v"
#print (verilog.convert(alu,
#    ports=alu.in_a.ports() + \
#          alu.in_b.ports() + \
#          alu.out_z.ports()))
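

# Minimal software cross-check sketch (illustrative only, not part of the
# HDL): it decodes a 32-bit IEEE754 value in the same way as FPNum.decode(),
# using only the python standard library, so test vectors for the adder can
# be sanity-checked.  The function name is arbitrary.

def decode_f32_reference(x):
    """ returns (sign, unbiased exponent, 23-bit fraction) for a python float
    """
    import struct
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    s = (bits >> 31) & 1
    e = ((bits >> 23) & 0xff) - 127 # same bias subtraction as FPNum.decode()
    f = bits & 0x7fffff
    return (s, e, f)

# example: decode_f32_reference(3.141592653589793) == (0, 1, 0x490fdb)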