create separate modules for fpnum in and out
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Single-cycle variable right-shifter, as a module with ports.

        Attributes:
            width: bit-width of the input and output values
            smax:  bit-width of the shift-amount port, int(log2(width))
            i:     input value (width bits)
            s:     shift amount (smax bits)
            o:     output value (width bits), combinatorially i >> s
    """

    def __init__(self, width):
        self.width = width
        # exact integer log2 (rounded down).  a floating-point
        # log(width) / log(2) can land a hair below the true integer
        # for an exact power of two, which int() would then truncate
        # to one bit too few.
        self.smax = width.bit_length() - 1
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ builds the module: o is driven combinatorially by i >> s
        """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shift helper.

        lshift and rshift apply the operand's own << / >> operator with
        the (variable) shift amount and then truncate the result back
        to the length of the operand, so the value returned is always
        exactly as wide as the value passed in.

        Attributes:
            width: bit-width this shifter was created for
            smax:  int(log2(width)), i.e. the number of bits needed to
                   express a shift amount for a power-of-two width.
                   not used by lshift / rshift themselves.
    """

    def __init__(self, width):
        self.width = width
        # exact integer log2 (rounded down): avoids the rounding error
        # that int(log(width) / log(2)) is exposed to.
        self.smax = width.bit_length() - 1

    def lshift(self, op, s):
        """ returns op shifted up by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Base class for a floating-point number held as separate
        sign (s), exponent (e) and mantissa (m) signals, alongside the
        packed value (v).

        Supported widths are 16, 32 and 64.  The exponent constants
        (P128, P127, N127, N126) are named after the values they take
        for width=32; for other widths they scale with e_max.

        elaborate() drives the is_nan / is_zero / is_inf /
        is_overflowed / is_denormalised / exp_128 status signals from
        the current e and m.
    """
    def __init__(self, width, m_extra=True):
        # per-format sizes, looked up by total width.  any other
        # width raises KeyError.
        mantissa_sizes = {16: 11, 32: 24, 64: 53}  # 1 extra bit (overflow)
        exponent_sizes = {16: 7, 32: 10, 64: 13}   # 2 extra bits (overflow)

        self.width = width
        self.rmw = mantissa_sizes[width]  # real mantissa width (no extras)
        e_width = exponent_sizes[width]
        self.e_max = 1 << (e_width - 3)

        # optional extra low-order mantissa bits (top,guard,round)
        self.m_extra = 3 if m_extra else 0
        m_width = self.rmw + self.m_extra

        self.m_width = m_width
        self.e_width = e_width
        # bit-range of the exponent field within v (used for decoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        # Exponent: IEEE754 exponent plus 2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)             # Sign bit

        # handy mantissa constants: all-zeros and all-ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))

        # handy signed exponent constants
        top = self.e_max
        self.P128 = Const(top, (e_width, True))
        self.P127 = Const(top - 1, (e_width, True))
        self.N127 = Const(-(top - 1), (e_width, True))
        self.N126 = Const(-(top - 2), (e_width, True))

        # status flags, driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)

    def elaborate(self, platform):
        """ wires each status flag, combinatorially, to its test
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
        ]
        return m

    def _is_nan(self):
        """ exponent at P128 and mantissa non-zero
        """
        exp_is_max = self.e == self.P128
        return exp_is_max & (self.m != 0)

    def _is_inf(self):
        """ exponent at P128 and mantissa zero
        """
        exp_is_max = self.e == self.P128
        return exp_is_max & (self.m == 0)

    def _is_zero(self):
        """ exponent at N127 and mantissa zero
        """
        exp_is_min = self.e == self.N127
        return exp_is_min & (self.m == self.mzero)

    def _is_overflowed(self):
        """ exponent greater than P127
        """
        return self.e > self.P127

    def _is_denormalised(self):
        """ exponent at N126 and mantissa bit e_start clear
        """
        exp_is_n126 = self.e == self.N126
        return exp_is_n126 & (self.m[self.e_start] == 0)
125
126
class FPNumOut(FPNumBase):
    """ Floating-point number, output side.

        Adds the encoding functions to FPNumBase: create() packs a
        sign / exponent / mantissa into the value signal (v), and
        nan() / inf() / zero() pack the corresponding special values
        with a given sign.

        All of these return lists of assignments; it is up to the
        caller to add them to a domain.
    """
    def __init__(self, width, m_extra=True):
        super().__init__(width, m_extra)

    def elaborate(self, platform):
        """ nothing added beyond the base class status flags
        """
        return super().elaborate(platform)

    def create(self, s, e, m):
        """ packs sign / exponent / mantissa into v

            the bias (P127) is added to the exponent here.
            returns a list of three assignments.
        """
        exp_lo, exp_hi = self.e_start, self.e_end
        biased = e + self.P127  # add on bias
        return [
            self.v[-1].eq(s),             # sign: top bit of v
            self.v[exp_lo:exp_hi].eq(biased),  # exponent field
            self.v[0:exp_lo].eq(m),       # mantissa: everything below
        ]

    def nan(self, s):
        """ packs a NaN: exponent P128, top bit of the mantissa field set
        """
        top_fraction_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_fraction_bit)

    def inf(self, s):
        """ packs an infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ packs a zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)
166
167
class FPNumIn(FPNumBase):
    """ Floating-point number, input side.

        Adds decoding of a packed value into sign / exponent /
        mantissa, plus the mantissa shift operations (with matching
        exponent adjustment), to FPNumBase.

        When created with m_extra set, the mantissa carries m_extra
        additional low-order bits, which decode() fills with zeros.
        The shift-down functions treat m[0] as the sticky bit.

        All functions return lists of assignments; it is up to the
        caller to add them to a domain.
    """
    def __init__(self, width, m_extra=True):
        super().__init__(width, m_extra)

    def elaborate(self, platform):
        """ nothing added beyond the base class status flags
        """
        return super().elaborate(platform)

    def decode(self, v):
        """ unpacks a value into sign / exponent / mantissa

            the bias (P127) is subtracted from the exponent field
            here.  the mantissa field is placed above m_extra zero
            bits.
        """
        padding = [0] * self.m_extra  # extra low-order zeros
        fraction = v[0:self.e_start]
        exp_field = v[self.e_start:self.e_end]
        return [
            self.m.eq(Cat(*padding, fraction)),  # mantissa
            self.e.eq(exp_field - self.P127),    # exponent, bias removed
            self.s.eq(v[-1]),                    # sign
        ]

    def shift_down(self):
        """ shifts the mantissa down by one, exponent goes up by one

            the bit shifted out of m[1] is not lost: it is OR'd
            into m[0], the sticky bit.  a zero is shifted in at the top.
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [
            self.e.eq(self.e + 1),
            self.m.eq(shifted),
        ]

    def shift_down_multi(self, diff):
        """ shifts the mantissa down by diff, exponent goes up by diff

            the shift amount applied to the mantissa is limited to
            m_width-1 (the exponent still has the full diff added).

            m[0] is the sticky bit and does not take part in the
            shift: only m[1:] is shifted.  the bits that fall off the
            bottom are picked out with a mask, made by shifting
            all-1s down by (m_width-1 - shift amount), which leaves
            exactly "shift amount" 1s at the bottom.  the masked bits
            are all OR'd together with the old m[0] to make the new
            sticky bit.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)  # clamp to the limit

        upper = self.m[1:]
        shifted = shifter.rshift(upper, amount)

        # mask of the bits of m[1:] that are about to be shifted out
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)
        sticky = reduce(or_, upper & lost_mask) | self.m[0]

        return [
            self.e.eq(self.e + diff),
            self.m.eq(Cat(sticky, shifted)),
        ]

    def shift_up_multi(self, diff):
        """ shifts the mantissa up by diff, exponent goes down by diff

            the shift amount applied to the mantissa is limited to
            m_width (the exponent still has the full diff subtracted).
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)  # clamp to the limit

        return [
            self.e.eq(self.e - diff),
            self.m.eq(shifter.lshift(self.m, amount)),
        ]
249
class FPOp:
    """ A floating-point operand port: a value plus two
        handshake signals.

        Attributes:
            width: bit-width of the value
            v:     the value (width bits)
            stb:   strobe signal (1 bit)
            ack:   acknowledge signal (1 bit)
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

    def ports(self):
        """ returns the port signals as a list: value, strobe, ack
        """
        return [getattr(self, name) for name in ("v", "stb", "ack")]
260
261
class Overflow:
    """ Holds the rounding bits and works out whether to round up.

        Inputs are guard, round_bit, sticky and m0 (the mantissa's
        zero bit).  The output, roundz, is set when guard is set and
        at least one of round_bit, sticky or m0 is also set.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def elaborate(self, platform):
        """ drives roundz combinatorially from the four input bits
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
276
277
class FPBase:
    """ IEEE754 Floating Point Base Class

        A collection of state-machine steps shared by FP units:
        fetching operands, (de)normalising, rounding, correcting,
        packing and handing over the result.

        Each function takes the module (m) that the FSM is being
        built in, adds its synchronous assignments to it, and sets
        m.next to next_state when the step has finished.
    """

    def get_op(self, m, op, v, next_state):
        """ fetches an operand

            when op.ack and op.stb are both 1, the operand value is
            decoded into v, ack is dropped to ZERO (this is the
            acknowledgement), and the FSM moves to next_state.
            otherwise ack is raised, and the state is left unchanged.
        """
        transfer = (op.ack) & (op.stb)
        with m.If(transfer):
            m.next = next_state
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ prepares a decoded operand for arithmetic

            if the exponent is at its minimum (N127) it is raised
            to N126 and the mantissa is left alone.  otherwise the
            top bit of the mantissa is set.

            does not change the FSM state.
        """
        exp_at_minimum = a.e == a.N127
        with m.If(exp_at_minimum):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            one bit per clock: for as long as the top bit of the
            mantissa is clear, the mantissa is shifted up by one and
            the exponent decreased by one.  once the top bit is set,
            the FSM moves to next_state.
        """
        top_bit_clear = op.m[-1] == 0
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation (shifting up)

            one bit per clock: for as long as the top bit of the
            mantissa is clear and the exponent is above N126, the
            mantissa is shifted up by one and the exponent decreased
            by one.  the guard bit is shifted in at the bottom of the
            mantissa, round_bit moves up to become the guard bit, and
            round_bit is cleared.  otherwise the FSM moves to
            next_state.
        """
        keep_shifting = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_shifting):
            m.d.sync += z.e.eq(z.e - 1)   # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)  # shift mantissa UP
            # must come after the shift: this overrides bit 0 of it
            m.d.sync += z.m[0].eq(of.guard)
            m.d.sync += of.guard.eq(of.round_bit)
            m.d.sync += of.round_bit.eq(0)
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation (shifting down)

            one bit per clock: for as long as the exponent is below
            N126, the mantissa is shifted down by one and the exponent
            increased by one.  the bit leaving the bottom of the
            mantissa becomes the guard bit, guard moves down to
            round_bit, and round_bit is OR'd into sticky.  otherwise
            the FSM moves to next_state.
        """
        exp_too_small = z.e < z.N126
        with m.If(exp_too_small):
            m.d.sync += z.e.eq(z.e + 1)   # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)  # shift mantissa DOWN
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output, then moves to next_state

            if of.roundz is set the mantissa is incremented; if the
            mantissa was all 1s beforehand, the exponent is
            incremented as well.

            TODO: different kinds of rounding
        """
        m.next = next_state
        with m.If(of.roundz):
            m.d.sync += z.m.eq(z.m + 1)      # mantissa rounds up
            mantissa_all_ones = z.m == z.m1s
            with m.If(mantissa_all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation correction, then moves to next_state

            if the result is flagged as denormalised, its exponent
            is set to N127.
        """
        m.next = next_state
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into z's value, then moves to next_state

            if the result is flagged as overflowed, an infinity (with
            the result's sign) is packed instead.
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ hands the result over to the output port

            the result value is copied to the output on every clock
            spent in this state.  stb is raised until stb and ack are
            both 1; at that point stb is dropped back to zero (as
            acknowledgement) and the FSM moves to next_state.
        """
        m.d.sync += out_z.v.eq(z.v)
        handed_over = out_z.stb & out_z.ack
        with m.If(handed_over):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
408
409