store roundz test in comb variable
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable right-shifter.

        Output "o" is driven with input "i" shifted right by "s" bits.
    """

    def __init__(self, width):
        # number of bits in the shift amount: int(log2(width))
        shift_bits = int(log(width) / log(2))
        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)       # value to shift
        self.s = Signal(shift_bits, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)       # shifted result

    def elaborate(self, platform):
        """ creates the module: one combinatorial assignment, o = i >> s
        """
        shifted = self.i >> self.s
        m = Module()
        m.d.comb += self.o.eq(shifted)
        return m
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        lshift and rshift apply the shift operator directly to the
        operand and then truncate the result back to the operand's own
        bit-length (len(op)).

        An earlier implementation built the shift from a cascade of
        Muxes, one per bit of the shift amount.  That code sat after the
        early "return" in both methods and could never execute, so it
        has been removed.  smax is still computed and stored because it
        is a public attribute.
    """

    def __init__(self, width):
        self.width = width
        # int(log2(width)): number of bits needed for the shift amount
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNum:
    """ Floating-point Number Class (widths 16, 32 and 64 are tabulated)

        Holds signals for a copy of the value (v) and for its decoded
        sign (s) / exponent (e) / mantissa (m), plus functions to
        decode, encode, create and recognise zero, NaN and inf.

        The tabulated mantissa width already includes one extra top bit.
        When m_extra is set, three further low-order bits are added to
        the mantissa: guard (m[2]), round (m[1]) and sticky (m[0]).
    """
    def __init__(self, width, m_extra=True):
        # (mantissa width incl. 1 extra top bit, exponent width incl. 2
        #  extra bits) for each supported width; others raise KeyError
        m_width, e_width = {16: (11, 7),
                            32: (24, 10),
                            64: (53, 13)}[width]
        e_max = 1 << (e_width - 3)
        bias = e_max - 1

        self.width = width
        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = e_max

        # optional extra low-order mantissa bits
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra

        self.m_width = m_width
        self.e_width = e_width
        # bit-range of the exponent field inside v (used by decode/create)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        self.v = Signal(width)                             # copy of value
        self.m = Signal(m_width, reset_less=True)          # mantissa
        self.e = Signal((e_width, True), reset_less=True)  # signed exponent
        self.s = Signal(reset_less=True)                   # sign bit

        # mantissa constants: all-zeros and all-ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # exponent constants (names match the 32-bit values)
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(bias, (e_width, True))
        self.N127 = Const(-bias, (e_width, True))
        self.N126 = Const(2 - e_max, (e_width, True))

    def decode(self, v):
        """ decodes a value into sign / exponent / mantissa

            returns a list of assignments.  the mantissa field is padded
            at the bottom with m_extra zero bits, and the bias (P127) is
            subtracted from the exponent field here.
        """
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            returns a list of assignments into v.  the bias (P127) is
            added to the exponent here.
        """
        sign_bit = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        man_field = self.v[0:self.e_start]
        return [
            sign_bit.eq(s),
            exp_field.eq(e + self.P127),
            man_field.eq(m),
        ]

    def shift_down(self):
        """ shifts the mantissa down by one; exponent is increased by one

            the bit shifted out is OR'd into the lowest (sticky) bit so
            that the lost accuracy is still recorded.
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [self.e.eq(self.e + 1),
                self.m.eq(shifted),
               ]

    def shift_down_multi(self, diff):
        """ shifts the mantissa down by diff; exponent is increased by diff

            the shift amount applied to the mantissa is clamped to
            m_width-1.  bits m[1:] are shifted; m[0] is replaced by the
            new sticky bit.

            the sticky bit is computed by shifting a batch of 1s right
            by (m_width-1 - shift amount), which leaves 1s only in the
            low positions.  that is used as a mask on m[1:], and the
            masked bits are OR'd together along with the old m[0].
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)      # clamped shift amount
        upper = self.m[1:]

        shifted = shifter.rshift(upper, amount)
        mask = shifter.rshift(self.m1s[1:], limit - amount)
        stickybits = reduce(or_, upper & mask) | self.m[0]

        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, shifted)),
               ]

    def shift_up_multi(self, diff):
        """ shifts the mantissa up by diff; exponent is decreased by diff

            the shift amount applied to the mantissa is clamped to m_width.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.lshift(self.m, amount)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifted),
               ]

    def nan(self, s):
        """ assignments creating NaN: exponent P128, one mantissa bit set
        """
        top_fraction_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_fraction_bit)

    def inf(self, s):
        """ assignments creating infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments creating zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ exponent is P128 and mantissa is non-zero
        """
        exp_is_max = (self.e == self.P128)
        return exp_is_max & (self.m != 0)

    def is_inf(self):
        """ exponent is P128 and mantissa is zero
        """
        exp_is_max = (self.e == self.P128)
        return exp_is_max & (self.m == 0)

    def is_zero(self):
        """ exponent is N127 and mantissa is zero
        """
        exp_is_min = (self.e == self.N127)
        return exp_is_min & (self.m == self.mzero)

    def is_overflowed(self):
        """ exponent is greater than P127
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ exponent is N126 and mantissa bit e_start is clear
        """
        exp_is_n126 = (self.e == self.N126)
        return exp_is_n126 & (self.m[self.e_start] == 0)
198
199
class FPOp:
    """ Operand port: a value of the given width plus stb and ack signals
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width, reset_less=True)  # the value itself
        self.stb = Signal(reset_less=True)       # strobe
        self.ack = Signal(reset_less=True)       # acknowledge

    def ports(self):
        """ returns the signals of this port as a list: value, stb, ack
        """
        signals = [self.v, self.stb, self.ack]
        return signals
210
211
class Overflow:
    """ Holds the three single-bit rounding signals: guard, round, sticky
    """
    def __init__(self):
        # the tot[N] notes record where each bit originally came from
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
217
218
class FPBase:
    """ IEEE754 Floating Point Base Class

        common state-machine building blocks for FP manipulation:
        fetching operands, normalisation, denormalisation, rounding,
        packing and outputting the result.  each function adds
        statements to the module "m" passed in, and sets m.next to
        next_state when its work is finished.
    """

    def get_op(self, m, op, v, next_state):
        """ fetches an operand.  when both ack and stb of op are 1,
            op.v is decoded into v, ack is set to ZERO (which is how
            acknowledgement is signalled) and the state moves on.
            otherwise ack is set to 1.
        """
        handshake = op.ack & op.stb
        with m.If(handshake):
            m.next = next_state
            m.d.sync += v.decode(op.v)
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ prepares a decoded number for arithmetic (the name is
            probably not quite right).  when the exponent is not N127
            the top mantissa bit (the implicit 1) is set.  when the
            exponent is N127 the mantissa is left alone and the
            exponent is set to N126 instead.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        exp_is_min = (a.e == a.N127)
        with m.If(exp_is_min):
            m.d.sync += a.e.eq(a.N126)   # limit the exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)    # put the top mantissa bit in

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: repeats on every clock, shifting up by one bit each
            time, until the top bit of the mantissa is set.  only then
            does the state move on.
        """
        top_bit_clear = (op.m[-1] == 0)
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)    # exponent goes DOWN
            m.d.sync += op.m.eq(op.m << 1)   # mantissa shifts UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: repeats on every clock while the top mantissa bit is
            clear and the exponent is above N126, shifting up one bit
            each time.
            NOTE: guard and round get reassigned because the bottom bit
            of the mantissa has to be filled from the overflow bits.
        """
        keep_going = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_going):
            m.d.sync += z.e.eq(z.e - 1)      # exponent goes DOWN
            m.d.sync += z.m.eq(z.m << 1)     # mantissa shifts UP
            # must come after the shift: overrides bit 0 of it
            m.d.sync += z.m[0].eq(of.guard)  # guard bit moves into mantissa
            m.d.sync += of.guard.eq(of.round_bit)  # round bit becomes guard
            m.d.sync += of.round_bit.eq(0)   # round bit is cleared
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: repeats on every clock while the exponent is below
            N126, shifting down one bit each time.
            NOTE: the bit shifted out of the mantissa goes into guard,
            guard goes into round, and round is OR'd into sticky.
        """
        exp_too_low = (z.e < z.N126)
        with m.If(exp_too_low):
            m.d.sync += z.e.eq(z.e + 1)      # exponent goes UP
            m.d.sync += z.m.eq(z.m >> 1)     # mantissa shifts DOWN
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ rounds the output.  TODO: different kinds of rounding

            rounds up when guard is set and any of round, sticky or the
            bottom mantissa bit is set.  if the mantissa was all 1s the
            exponent is incremented as well.
        """
        m.next = next_state
        # the decision is held in its own combinatorial signal
        roundz = Signal(reset_less=True)
        lower_bits = of.round_bit | of.sticky | z.m[0]
        m.d.comb += roundz.eq(of.guard & lower_bits)
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)        # mantissa goes up by one
            mantissa_full = (z.m == z.m1s)     # all 1s before rounding
            with m.If(mantissa_full):
                m.d.sync += z.e.eq(z.e + 1)    # exponent goes up by one

    def corrections(self, m, z, next_state):
        """ denormalisation correction: when z is denormalised the
            exponent is set to N127
        """
        m.next = next_state
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into z.v.  an overflowed result is packed
            as inf (with the sign of z) instead.
        """
        m.next = next_state
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: copies the result to the output and raises stb.
            once stb and ack are both 1, stb is reset back to zero
            (as acknowledgement) and the state moves on.
        """
        m.d.sync += out_z.v.eq(z.v)
        transfer_done = out_z.stb & out_z.ack
        with m.If(transfer_done):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
349
350