add single-cycle version of alignment process in fadd
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShift:
    """ Generates variable-length single-cycle shifter from a series
        of conditional tests on each bit of the left/right shift operand.

        Stage i tests ``s & (1 << i)`` and, when that is non-zero, replaces
        the running value with a copy moved by ``1 << i`` places (stage 0
        moves by 1, stage 1 by 2, stage 2 by 4, ...), each partial result
        cascading into the next Mux.  Vacated positions are filled with
        zeros.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        """ width: operand width; determines the number of Mux stages.

            smax is floor(log2(width)), so the stages together cover shift
            amounts from 0 to (1 << smax) - 1.

            Raises ValueError when width is not positive.
        """
        self.width = width
        # Integer arithmetic: a floating-point log(width) / log(2) is not
        # guaranteed to be exact, and truncating a result that lands just
        # under a whole number would silently drop the top shifter stage.
        whole = int(width)
        if whole < 1:
            raise ValueError("shifter width must be positive")
        self.smax = whole.bit_length() - 1

    def lshift(self, op, s):
        """ returns an expression for op moved up by s places, with zeros
            entering at the low end and the top bits dropped.
        """
        res = op
        for i in range(self.smax):
            step = 1 << i
            zeros = [0] * step
            # stage i: either keep res, or drop its top `step` bits and
            # prepend `step` zeros at the low end
            res = Mux(s & step, Cat(zeros, res[0:-step]), res)
        return res

    def rshift(self, op, s):
        """ returns an expression for op moved down by s places, with zeros
            entering at the high end and the low bits dropped.
        """
        res = op
        for i in range(self.smax):
            step = 1 << i
            zeros = [0] * step
            # stage i: either keep res, or drop its low `step` bits and
            # append `step` zeros at the high end
            res = Mux(s & step, Cat(res[step:], zeros), res)
        return res
38
39
class FPNum:
    """ Floating-point Number Class (widths 32 and 64 are in the tables)

        Holds signals for a latched copy of the encoded value (v) and for
        its decoded sign (s), exponent (e) and mantissa (m), plus helpers
        that build the assignments for encoding, decoding, shifting and
        for creating / recognising zero, NaN and inf (all signed).

        When m_extra is true the mantissa gets three additional low-order
        bits: guard (m[2]), round (m[1]) and sticky (m[0]).  decode() fills
        those with zeros and leaves the top bit (m[-1]) clear.
    """

    def __init__(self, width, m_extra=True):
        """ width:   total width of the encoded number (32 or 64);
                     any other value raises KeyError
            m_extra: when true, add the three guard/round/sticky bits
                     below the mantissa
        """
        self.width = width
        # real mantissa width (not including extras) and exponent width
        self.rmw = {32: 24, 64: 53}[width]
        self.e_width = {32: 10, 64: 13}[width]
        self.e_max = 1 << (self.e_width - 3)
        # number of extra low-order mantissa bits
        self.m_extra = 3 if m_extra else 0
        self.m_width = self.rmw + self.m_extra
        # bit range of the exponent field inside v (used for decoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        self.v = Signal(width)                  # Latched copy of value
        self.m = Signal(self.m_width)           # Mantissa
        self.e = Signal((self.e_width, True))   # Exponent: signed
        self.s = Signal()                       # Sign bit

        m_shape = (self.m_width, False)
        e_shape = (self.e_width, True)
        self.mzero = Const(0, m_shape)              # mantissa all zeros
        self.m1s = Const(-1, m_shape)               # mantissa all ones
        self.P128 = Const(self.e_max, e_shape)
        self.P127 = Const(self.e_max - 1, e_shape)
        self.N127 = Const(-(self.e_max - 1), e_shape)
        self.N126 = Const(-(self.e_max - 2), e_shape)

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            the bias (P127) is subtracted from the exponent field here.
            the mantissa field is padded below with m_extra zero bits.
        """
        fraction = v[0:self.e_start]
        biased_exp = v[self.e_start:self.e_end]
        padding = [0] * self.m_extra        # extra zeros under the fraction
        return [
            self.m.eq(Cat(*padding, fraction)),     # mantissa
            self.e.eq(biased_exp - self.P127),      # exponent
            self.s.eq(v[-1]),                       # sign
        ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            the bias (P127) is added on to the exponent here
        """
        sign_bit = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        return [
            sign_bit.eq(s),
            exp_field.eq(e + self.P127),
            man_field.eq(m),
        ]

    def shift_down(self):
        """ shifts the mantissa down by one, adding one to the exponent
            to compensate.

            the bit that falls off the bottom is OR-ed into m[0] (the
            sticky bit) and a zero enters at the top.
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [self.e.eq(self.e + 1), self.m.eq(shifted)]

    def shift_down_multi(self, diff):
        """ shifts the mantissa down by diff places in one go, adding
            diff to the exponent to compensate.

            the shift amount is limited to m_width-1: no point doing
            more (the shifted part would be zero anyway).

            the sticky bit is worked out with a mask: a batch of 1s is
            shifted down by (limit - amount), which is intended to leave
            1s only in the low positions that are about to be shifted
            out.  the masked mantissa bits are OR-ed together and into
            the existing sticky bit m[0].
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width - 1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        upper = self.m[1:]                  # everything above the sticky bit
        shifted = shifter.rshift(upper, amount)
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)

        sticky = reduce(or_, upper & lost_mask) | self.m[0]
        return [
            self.e.eq(self.e + diff),
            self.m.eq(Cat(sticky, shifted)),
        ]

    def nan(self, s):
        """ assignments that set v to NaN with sign s (exponent P128,
            mantissa field with only its top bit set)
        """
        top_fraction_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_fraction_bit)

    def inf(self, s):
        """ assignments that set v to infinity with sign s """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set v to zero with sign s """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ exponent is P128 and mantissa is non-zero """
        exp_is_max = self.e == self.P128
        return exp_is_max & (self.m != 0)

    def is_inf(self):
        """ exponent is P128 and mantissa is zero """
        exp_is_max = self.e == self.P128
        return exp_is_max & (self.m == 0)

    def is_zero(self):
        """ exponent is N127 and mantissa is zero """
        exp_is_min = self.e == self.N127
        return exp_is_min & (self.m == self.mzero)

    def is_overflowed(self):
        """ exponent is above P127 """
        return self.e > self.P127

    def is_denormalised(self):
        """ exponent is N126 and mantissa bit e_start is clear """
        exp_is_n126 = self.e == self.N126
        return exp_is_n126 & (self.m[self.e_start] == 0)
168
169
class FPOp:
    """ One operand/result port: a value signal together with the two
        handshake signals stb and ack.
    """

    def __init__(self, width):
        """ width: number of bits in the value signal """
        self.width = width

        self.v = Signal(width)      # the value carried by the port
        self.stb = Signal()         # handshake: strobe
        self.ack = Signal()         # handshake: acknowledge

    def ports(self):
        """ returns the port's signals as a list, in the order v, stb, ack """
        return [getattr(self, name) for name in ("v", "stb", "ack")]
180
181
class Overflow:
    """ Groups the guard, round and sticky bits as three separate signals.

        NOTE(review): the "tot[n]" comments refer to a value that is not
        defined in this file - presumably the adder's total; confirm
        against the caller.
    """
    def __init__(self):
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]
187
188
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        In every method, m is the module/FSM builder the statements are
        added to (it is used through m.If / m.Else, m.d.sync and m.next).
        Where present, next_state is the state assigned to m.next once
        the step has finished.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:  module/FSM builder
            op: operand port (uses op.v, op.stb, op.ack)
            v:  number object whose decode() receives op.v

            while the handshake has not completed, ack is driven to 1.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the operand's mantissa is set
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            shifts the mantissa up (pulling the guard bit in at the
            bottom) while its top bit is clear and the exponent is still
            above N126.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            shifts the mantissa down (pushing bits out through guard,
            round and sticky) while the exponent is below N126.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            the mantissa is incremented when guard is set together with
            any of round, sticky or the mantissa's own bottom bit.  if the
            mantissa was all 1s the exponent is incremented as well.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised: it is the *exponent* that has to move to N127, so
        # that adding the bias back in create() gives an exponent field of
        # zero.  (this previously overwrote the mantissa with N127.)
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf carrying the result's own sign
        # (a negative overflow must give -Inf, not +Inf)
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            # this later assignment to stb overrides the stb.eq(1) above
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
319
320