create module for FPNum
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter module.

        Attributes:
        * width - width in bits of the input and output
        * smax  - int(log(width) / log(2)), the width of the shift-amount
                  signal
        * i     - value to be shifted
        * s     - shift amount
        * o     - result, driven with i >> s
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

        # ports (attribute assignment kept direct so that the signal
        # names stay "i", "s" and "o")
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ returns a Module whose only logic is o = i >> s (combinatorial)
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift and rshift build the shifted expression directly with the
        << and >> operators, then slice the result back down to the length
        of the operand that was shifted (a left shift would otherwise be
        wider than its input).

        An earlier implementation cascaded one Mux per bit of the shift
        amount; it had been left behind as unreachable code after the
        return statements and has been removed.

        Attributes:
        * width - width passed to the constructor
        * smax  - int(log(width) / log(2)); recorded at construction but
                  not used by lshift or rshift
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns (op << s), sliced to the first len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns (op >> s), sliced to the first len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Holds a latched copy of an encoded value (v) together with its
        decoded sign (s), exponent (e) and mantissa (m).  Provides helper
        functions, each returning a list of assignments, for decoding and
        encoding, for shifting the mantissa, and for creating zero, NaN
        and inf (all signed).  is_nan, is_zero, is_inf, is_overflowed,
        is_denormalised and exp_128 are status signals driven
        combinatorially by elaborate().

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three, only
        present when m_extra is True, are guard (m[2]), round (m[1]),
        and sticky (m[0]).
    """

    def __init__(self, width, m_extra=True):
        self.width = width

        # per-format field widths, looked up by total width
        # (KeyError for anything other than 16, 32 or 64)
        mantissa_bits = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        exponent_bits = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1 << (exponent_bits - 3)

        self.rmw = mantissa_bits # real mantissa width (not including extras)
        self.e_max = e_max

        # optional guard / round / sticky bits below the real mantissa
        self.m_extra = 3 if m_extra else 0
        m_width = mantissa_bits + self.m_extra
        e_width = exponent_bits

        self.m_width = m_width
        self.e_width = e_width

        # bit positions of the exponent field inside v (for decoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        # value and its decoded fields
        self.v = Signal(width, reset_less=True) # latched copy of value
        self.m = Signal(m_width, reset_less=True) # mantissa
        self.e = Signal((e_width, True), reset_less=True) # exponent: IEEE754 exp + 2 bits, signed
        self.s = Signal(reset_less=True) # sign bit

        # constants: all-zeros / all-ones mantissa, then the exponent
        # limits (names reflect the single-precision values)
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

        # status outputs, driven in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module that drives the status signals combinatorially
            from the current values of e and m
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
        ]
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            returns the list of assignments to m, e and s.  the bias
            (P127) is subtracted from the exponent field here.  the
            mantissa field is padded at the bottom with m_extra zero bits.
        """
        low_pad = [0] * self.m_extra # zeros for the extra low bits
        mantissa = Cat(*low_pad, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            returns the list of assignments to the three fields of v.
            the bias (P127) is added to the exponent here.
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        return [
            sign_field.eq(s),
            exp_field.eq(e + self.P127),
            man_field.eq(m),
        ]

    def shift_down(self):
        """ shifts the mantissa down by one, increasing the exponent by
            one to compensate

            the bit shifted out is not simply dropped: the old m[0] is
            OR'd with the old m[1] to form the new bottom (sticky) bit.
            a zero is shifted in at the top.
        """
        sticky = self.m[0] | self.m[1]
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(sticky, self.m[2:], 0)),
               ]

    def shift_down_multi(self, diff):
        """ shifts the mantissa down by diff, increasing the exponent by
            diff to compensate

            the shift amount is clamped to m_width-1: no point doing
            more (the result will still be zero).

            the sticky bit is computed by shifting a batch of 1s right
            by (m_width-1 - shift amount), which leaves 1s only in the
            low "shift amount" positions.  that mask picks out the bits
            of the mantissa that are being shifted out; those are OR'd
            together, and with the existing m[0], to give the new sticky
            bit.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        shift_by = Mux(diff > limit, limit, diff)

        upper = self.m[1:] # everything above the sticky bit
        shifted = shifter.rshift(upper, shift_by)
        lost_mask = shifter.rshift(self.m1s[1:], limit - shift_by)

        sticky = reduce(or_, upper & lost_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(sticky, shifted)),
               ]

    def shift_up_multi(self, diff):
        """ shifts the mantissa up by diff (clamped to m_width),
            decreasing the exponent by diff to compensate
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        shift_by = Mux(diff > limit, limit, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, shift_by)),
               ]

    def nan(self, s):
        """ assignments encoding a NaN of sign s: exponent P128 and a
            mantissa with only bit (e_start-1) set
        """
        return self.create(s, self.P128, 1 << (self.e_start-1))

    def inf(self, s):
        """ assignments encoding infinity of sign s: exponent P128,
            mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments encoding zero of sign s: exponent N127,
            mantissa zero
        """
        return self.create(s, self.N127, 0)

    def _is_nan(self):
        # exponent at P128 with a non-zero mantissa
        return (self.e == self.P128) & (self.m != 0)

    def _is_inf(self):
        # exponent at P128 with a zero mantissa
        return (self.e == self.P128) & (self.m == 0)

    def _is_zero(self):
        # exponent at N127 with an all-zeros mantissa
        return (self.e == self.N127) & (self.m == self.mzero)

    def _is_overflowed(self):
        # exponent beyond the largest finite exponent
        return (self.e > self.P127)

    def _is_denormalised(self):
        # exponent at N126 with mantissa bit e_start clear
        return (self.e == self.N126) & (self.m[self.e_start] == 0)
216
217
class FPOp:
    """ Operand / result port: a value plus two handshake signals.

        Attributes:
        * width - width in bits of the value
        * v     - the value
        * stb   - strobe handshake signal
        * ack   - acknowledge handshake signal
    """

    def __init__(self, width):
        self.width = width

        # attribute assignment kept direct so that the signal names
        # stay "v", "stb" and "ack"
        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

    def ports(self):
        """ returns the port signals as a list, in the order v, stb, ack
        """
        return [getattr(self, name) for name in ("v", "stb", "ack")]
228
229
class Overflow:
    """ Holder for the three single-bit signals (guard, round, sticky)
        that the normalisation and rounding helpers read and update.

        The tot[N] notes record which bit of the caller's "tot" value
        each signal corresponds to.
    """

    def __init__(self):
        # guard bit: copy of tot[2]
        self.guard = Signal(reset_less=True)
        # round bit: copy of tot[1]
        self.round_bit = Signal(reset_less=True)
        # sticky bit: copy of tot[0]
        self.sticky = Signal(reset_less=True)
235
236
class FPBase:
    """ IEEE754 Floating Point Base Class

        Collection of helper functions for FP manipulation: fetching
        operands, denormalisation, normalisation, rounding, correction,
        packing and output of the result.  Each helper adds logic to the
        Module (m) passed in, and those taking next_state set m.next.
    """

    def get_op(self, m, op, v, next_state):
        """ copies the operand op.v into v (decoding it) and moves to
            next_state once both op.ack and op.stb are 1.
            acknowledgement is sent by setting ack back to ZERO; until
            then ack is held at 1.
        """
        handshake = op.ack & op.stb
        with m.If(handshake):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0),
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_min_exponent = (a.e == a.N127)
        with m.If(at_min_exponent):
            # limit the exponent
            m.d.sync += a.e.eq(a.N126)
        with m.Else():
            # set the top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation: while the top mantissa bit is clear,
            shift the mantissa up and decrease the exponent.

            NOTE: just like "align", this one keeps going round every
            clock; it only moves to next_state once the top bit is set.
        """
        top_bit_clear = (op.m[-1] == 0)
        with m.If(top_bit_clear):
            m.d.sync += [
                op.e.eq(op.e - 1),   # exponent DOWN by one
                op.m.eq(op.m << 1),  # mantissa UP by one bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation: while the top mantissa bit is
            clear and the exponent is above N126, shift the mantissa up
            and decrease the exponent.

            NOTE: just like "align", this one keeps going round every
            clock until the condition no longer holds.
            NOTE: the guard bit is shifted into the bottom of the
            mantissa and the round bit moves up to guard, because the
            extra mantissa bits come from tot[0..2].
        """
        needs_shift_up = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(needs_shift_up):
            m.d.sync += [
                z.e.eq(z.e - 1),            # exponent DOWN by one
                z.m.eq(z.m << 1),           # mantissa UP by one bit
                z.m[0].eq(of.guard),        # ...then guard overrides bit 0
                of.guard.eq(of.round_bit),  # round bit moves up to guard
                of.round_bit.eq(0),         # round bit is cleared
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation: while the exponent is below
            N126, shift the mantissa down and increase the exponent.

            NOTE: just like "align", this one keeps going round every
            clock until the exponent is no longer below N126.
            NOTE: the bit shifted out of the mantissa goes to guard,
            guard moves down to round, and round is OR'd into sticky.
        """
        below_min_exponent = z.e < z.N126
        with m.If(below_min_exponent):
            m.d.sync += [
                z.e.eq(z.e + 1),    # exponent UP by one
                z.m.eq(z.m >> 1),   # mantissa DOWN by one bit
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of
            rounding

            rounds up when guard is set and any of round, sticky or the
            bottom mantissa bit is set.  if the mantissa was all 1s the
            exponent is incremented as well.
        """
        m.next = next_state
        # local kept as "roundz" so that the signal name is unchanged
        roundz = Signal(reset_less=True)
        tie_breakers = of.round_bit | of.sticky | z.m[0]
        m.d.comb += roundz.eq(of.guard & tie_breakers)
        with m.If(roundz):
            # mantissa rounds up
            m.d.sync += z.m.eq(z.m + 1)
            with m.If(z.m == z.m1s):
                # mantissa was all 1s: exponent rounds up too
                m.d.sync += z.e.eq(z.e + 1)

    def corrections(self, m, z, next_state):
        """ denormalisation correction: a result flagged is_denormalised
            has its exponent set to N127
        """
        m.next = next_state
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into z.v (detects overflow->Inf)
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # overflowed: encode inf, keeping the sign
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
367
368