Get rid of the unpack phase by making it part of get_op
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Const
6 from nmigen.cli import main, verilog
7
8
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Exponent (e): 10-bit signed and unbiased. decode() subtracts the
        bias of 127 and create() adds it back. -127 is used for zero and
        denormals, and 128 for inf / NaN.

        Mantissa (m) with the default m_width (width - 5, i.e. 27 bits for
        width 32):

        * m[0], m[1], m[2]: sticky, round and guard bits (cleared by decode)
        * m[3:26]: the 23 stored fraction bits
        * m[-1]: the top bit. decode() leaves it clear, and callers set it
          to the implicit leading 1 of a normalised number.

        When m_width is given explicitly, the caller chooses the layout.
        is_denormalised() always treats m[-1] as the implicit bit.

        NOTE: decode() and create() hard-code the single-precision field
        positions (fraction 0-22, exponent 23-30, sign 31) regardless of
        width.
    """
    def __init__(self, width, m_width=None):
        self.width = width
        if m_width is None:
            # 23 fraction bits + top bit + guard/round/sticky (27 at width 32)
            m_width = width - 5
        self.v = Signal(width)       # Latched copy of value
        self.m = Signal(m_width)     # Mantissa
        self.e = Signal((10, True))  # Exponent: 10 bits, signed
        self.s = Signal()            # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))  # all-ones mantissa
        self.P128 = Const(128, (10, True))
        self.P127 = Const(127, (10, True))
        self.N127 = Const(-127, (10, True))
        self.N126 = Const(-126, (10, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent. exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number. The three low (guard/round/sticky) bits of
            the mantissa are cleared.

            returns a list of assignments for the caller to add to a domain.
        """
        return [self.m.eq(Cat(0, 0, 0, v[0:23])),  # mantissa
                self.e.eq(v[23:31] - self.P127),   # exp (minus bias)
                self.s.eq(v[31]),                  # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent. The results are written
            into the 8-bit exponent and 23-bit fraction fields of self.v.

            returns a list of assignments for the caller to add to a domain.
        """
        return [
          self.v[31].eq(s),                 # sign
          self.v[23:31].eq(e + self.P127),  # exp (add on bias)
          self.v[0:23].eq(m)                # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit): the bit
            shifted out is OR-ed into the new bit 0.
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def nan(self, s):
        """ assignments encoding a NaN with sign s (top fraction bit set) """
        return self.create(s, self.P128, 1<<22)

    def inf(self, s):
        """ assignments encoding an infinity with sign s """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments encoding a zero with sign s """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ true when the exponent is 128 and the mantissa is non-zero """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ true when the exponent is 128 and the mantissa is zero """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ true when the exponent is -127 and the mantissa is zero """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ true when the exponent exceeds the largest normal value (127) """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ true when the exponent is at the minimum normal value (-126) but
            the top (implicit) mantissa bit is clear.

            Uses m[-1] rather than a fixed bit index so that the check is
            correct for any mantissa width. For a 24-bit mantissa this is
            bit 23, as before.
        """
        return (self.e == self.N126) & (self.m[-1] == 0)
93
class FPOp:
    """ One side of a stb/ack handshake carrying a width-bit value.

        v   - the value being transferred
        stb - strobe; v is taken when stb and ack are both 1
        ack - acknowledge
    """
    def __init__(self, width):
        self.width = width

        # NOTE: keep one Signal per assignment: nmigen infers each
        # signal's name from its assignment target.
        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

    def ports(self):
        """ returns the signals to expose as ports, in the order v, stb, ack """
        return [getattr(self, name) for name in ("v", "stb", "ack")]
104
105
class Overflow:
    """ Guard / round / sticky bits kept beside the result mantissa.

        They are loaded from the low bits of the adder's 28-bit sum (tot),
        then shuffled by the normalisation stages and consumed by rounding.
    """
    def __init__(self):
        # one Signal per assignment so nmigen can name each one
        self.guard = Signal()      # tot[2] (tot[3] when the sum carried)
        self.round_bit = Signal()  # tot[1] (tot[2] when the sum carried)
        self.sticky = Signal()     # tot[0] (tot[1] | tot[0] when carried)
111
112
class FPADD:
    """ Floating-point adder, implemented as a multi-cycle FSM.

        Operand a is received on in_a, then operand b on in_b (stb/ack
        handshakes). The sum is computed over several states and then
        presented on out_z (stb/ack handshake), after which the FSM goes
        back to waiting for in_a.

        NOTE: like FPNum, the datapath is currently hard-wired to the
        32-bit (single precision) layout.
    """
    def __init__(self, width):
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            The incoming value is decoded straight into v (sign / exponent /
            mantissa) as it is latched, so no separate unpack state is
            needed.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),             # INCREASE exponent
                z.m.eq(z.m >> 1),            # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ rounding stage (round to nearest, ties to even).

            The mantissa is incremented when the guard bit is set and any of
            round / sticky / mantissa LSB is set. If the mantissa was all 1s
            (so the increment wraps), the exponent is also incremented.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1)          # mantissa rounds up
            with m.If(z.m == z.m1s):             # all 1s
                m.d.sync += z.e.eq(z.e + 1)      # exponent rounds up

    def corrections(self, m, z, next_state):
        """ correction stage: fixes up denormalised results and the sign of
            an exact zero before packing.

            Both conditions test the current (pre-clock) values of z, so
            the two checks do not affect each other.
        """
        m.next = next_state
        # denormalised: move the exponent to -127 so that pack() encodes a
        # zero exponent field (-127 + bias 127 = 0).
        # (previously this wrongly assigned -127 to the *mantissa*)
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ pack stage: encodes sign / exponent / mantissa into z.v """
        m.next = next_state
        # if overflow occurs, return inf with the sign of the result
        # (IEEE754 round-to-nearest overflow gives a signed infinity)
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state

    def get_fragment(self, platform=None):
        """ builds the adder FSM.

            States, in order: get_a, get_b, special_cases, align, add_0,
            add_1, normalise_1, normalise_2, round, corrections, pack and
            put_z, then back to get_a. special_cases jumps straight to put_z
            when an operand is NaN, inf or zero.
        """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        # mantissa sum: 3 guard/round/sticky bits, 23 fraction bits,
        # 1 implicit bit and 1 carry bit (tot[27])
        tot = Signal(28)

        of = Overflow()

        with m.FSM():

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a.e == a.N127):
                        m.d.sync += a.e.eq(a.N126)   # limit a exponent
                    with m.Else():
                        m.d.sync += a.m[-1].eq(1)    # set top mantissa bit
                    # denormalise b check
                    with m.If(b.e == b.N127):
                        m.d.sync += b.e.eq(b.N126)   # limit b exponent
                    with m.Else():
                        m.d.sync += b.m[-1].eq(1)    # set top mantissa bit

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than or equal to b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                    ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
392
393
if __name__ == "__main__":
    alu = FPADD(width=32)
    io_ports = (alu.in_a.ports() +
                alu.in_b.ports() +
                alu.out_z.ports())
    main(alu, ports=io_ports)

    # works... but don't use, just do "python fname.py convert -t v"
    # print(verilog.convert(alu, ports=io_ports))