pad with zeros if needed in decode
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Const
6 from nmigen.cli import main, verilog
7
8
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        With the default mantissa width (width - 5, i.e. 27 bits for a
        32-bit value) four extra bits surround the 23 stored mantissa
        bits: the top bit (m[-1]) holds the implicit leading one and is
        effectively a carry-overflow.  The other three are guard (m[2]),
        round (m[1]), and sticky (m[0]).

        The bit positions used by decode() / create() (mantissa [0:23],
        exponent [23:31], sign [31]) are fixed at IEEE754 single precision.
    """
    def __init__(self, width, m_width=None):
        self.width = width
        if m_width is None:
            # mantissa extra bits: top, guard, round, sticky (23 + 4 = 27
            # for a 32-bit value)
            m_width = width - 5
        self.m_width = m_width
        # NOTE: keep one Signal per attribute assignment: nmigen derives a
        # signal's default name from the attribute it is assigned to.
        self.v = Signal(width)       # Latched copy of value
        self.m = Signal(m_width)     # Mantissa
        self.e = Signal((10, True))  # Exponent: 10 bits, signed (unbiased)
        self.s = Signal()            # Sign bit

        self.mzero = Const(0, (m_width, False))   # all-zero mantissa
        self.m1s = Const(-1, (m_width, False))    # all-ones mantissa
        self.P128 = Const(128, (10, True))   # exponent used for inf / NaN
        self.P127 = Const(127, (10, True))   # bias; largest normal exponent
        self.N127 = Const(-127, (10, True))  # exponent used for zero / denormals
        self.N126 = Const(-126, (10, True))  # smallest normal exponent

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number.

            The 23 stored mantissa bits are placed above (m_width - 24)
            zero bits (the guard/round/sticky positions).  The top
            mantissa bit is left clear here.

            :param v: the raw 32-bit value to decode
            :returns: list of nmigen assignments to m, e and s
        """
        # pad the low end with zeros (each Python 0 becomes a 1-bit zero
        # constant in Cat); no padding at all when m_width <= 24
        args = [0] * (self.m_width-24) + [v[0:23]]
        return [self.m.eq(Cat(*args)),            # mantissa
                self.e.eq(v[23:31] - self.P127),  # exp (minus bias)
                self.s.eq(v[31]),                 # sign
        ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  Only the low 8 bits of
            (e + 127) and the low 23 bits of m are stored.

            :returns: list of nmigen assignments to self.v
        """
        return [
          self.v[31].eq(s),                 # sign
          self.v[23:31].eq(e + self.P127),  # exp (add on bias)
          self.v[0:23].eq(m)                # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit): the two
            lowest bits are ORed together so a set bit is never lost.
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def nan(self, s):
        """ assignments that make self.v a NaN with sign s """
        return self.create(s, self.P128, 1<<22)

    def inf(self, s):
        """ assignments that make self.v an infinity with sign s """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that make self.v a zero with sign s """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ true when the exponent is 128 and the mantissa is non-zero """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ true when the exponent is 128 and the mantissa is zero """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ true when the exponent is -127 and the mantissa is zero """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ true when the exponent is above the largest normal (127) """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ true when the exponent is -126 and the top mantissa bit is clear

            m[-1] (the top bit, where the implicit leading one lives) is
            used rather than a fixed index so that this is correct for any
            m_width, not only the 24-bit result mantissa.
        """
        return (self.e == self.N126) & (self.m[-1] == 0)
95
96
class FPOp:
    """ One handshaked port: a data value plus strobe / acknowledge lines.

        The stb / ack pair is driven and sampled by FPBase.get_op
        (operand input) and FPBase.put_z (result output).
    """
    def __init__(self, width):
        self.width = width

        # NOTE: assign each Signal directly to its attribute: nmigen names
        # signals after their assignment target, and these become ports.
        self.v = Signal(width)  # data
        self.stb = Signal()     # strobe: data is valid
        self.ack = Signal()     # acknowledge

    def ports(self):
        """ returns the port signals in the order value, stb, ack """
        return [getattr(self, port_name) for port_name in ("v", "stb", "ack")]
107
108
class Overflow:
    """ The three extra low-order result bits kept alongside the mantissa.

        They are loaded from the bottom of the adder's ``tot`` sum and then
        shifted / consumed by the normalisation and rounding stages.
    """
    def __init__(self):
        # each Signal is assigned straight to its attribute so that nmigen
        # names it after that attribute
        self.guard = Signal()      # from tot[2] (tot[3] on carry-out)
        self.round_bit = Signal()  # from tot[1] (tot[2] on carry-out)
        self.sticky = Signal()     # from tot[0] (tot[1] | tot[0] on carry-out)
114
115
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Every method adds nmigen statements to the module ``m`` it is
        given.  Methods taking ``next_state`` also set ``m.next``, so they
        must be called from inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            :param m: module being built (inside an FSM state)
            :param op: FPOp supplying the operand (v, stb, ack)
            :param v: FPNum that receives the decoded operand
            :param next_state: FSM state to enter once the operand is taken
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not yet transferred: keep ack raised until stb is also seen
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number

            a number whose exponent is -127 (exponent field 0) has its
            exponent limited to -126; any other number gets its top
            mantissa bit (the implicit leading one) set.
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(-126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            while the top mantissa bit is clear and the exponent is above
            -126, shifts the mantissa up by one and decrements the exponent.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync +=[
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard), # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            while the exponent is below -126, shifts the mantissa down by
            one (through guard / round, accumulating into sticky) and
            increments the exponent.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output. TODO: different kinds of rounding

            rounds up when guard is set and any of round, sticky or the
            mantissa LSB is set (round half to even).
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            # compares the pre-increment value: an all-ones mantissa wraps
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            * a denormalised result has its exponent set to -127, so that
              create() (which adds the bias) stores an all-zero exponent
              field.
            * an exact-zero result at exponent -126 is forced positive
              (-a + a = +0).
        """
        m.next = next_state
        # denormalised, correct exponent to zero.  It is the *exponent*
        # that must change here; the mantissa is left intact.
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            on overflow the result is an infinity carrying the result's
            own sign, so that negative overflow gives -inf.
        """
        m.next = next_state
        # if overflow occurs, return (signed) inf
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            # this later assignment overrides the stb.eq(1) above
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
226
227
class FPADD(FPBase):
    """ IEEE754 floating-point adder, built as a multi-cycle nmigen FSM.

        Operands are taken from in_a and then in_b, and the result is
        offered on out_z.  All three are FPOp ports (value plus stb/ack
        handshake).  The field layout used via FPNum is 32-bit.
    """

    def __init__(self, width):
        FPBase.__init__(self)
        self.width = width

        self.in_a = FPOp(width)   # first operand port
        self.in_b = FPOp(width)   # second operand port
        self.out_z = FPOp(width)  # result port

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            FSM state order: get_a -> get_b -> special_cases, then either
            straight to put_z, or align -> add_0 -> add_1 -> normalise_1
            -> normalise_2 -> round -> corrections -> pack -> put_z; put_z
            returns to get_a.
        """
        m = Module()

        # Latches
        a = FPNum(self.width)       # operand a (default m_width: width - 5)
        b = FPNum(self.width)       # operand b, same layout as a
        z = FPNum(self.width, 24)   # result: 24-bit mantissa (top bit + 23)

        # 28 bits: 3 guard/round/sticky, 23 mantissa, 1 top bit, 1 carry-out
        tot = Signal(28)

        of = Overflow()

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to add.  see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    # (b cannot be NaN here, so exponent 128 means b is inf;
                    # this later assignment overrides the inf above)
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                # (sign is a.s & b.s: only -0 + -0 gives -0).  The low 8
                # exponent bits are passed; create() re-adds the bias.
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                # (b.m[3:-1] drops the 3 low guard bits and the top bit)
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                # by one: take the top 24 bits, fold the two lowest bits
                # into sticky, and bump the exponent
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
415
416
if __name__ == "__main__":
    # Command-line entry point (nmigen.cli.main).  To emit Verilog, prefer
    # the CLI (the original note suggests "python fname.py convert -t v";
    # check the exact subcommand with --help) over calling verilog.convert.
    alu = FPADD(width=32)
    # operand a, operand b, then the result: value/stb/ack for each
    top_level_ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=top_level_ports)
425 # alu.in_b.ports() + \
426 # alu.out_z.ports())