latch into FPNumIn within module
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
class MultiShiftR:
    """ Combinatorial variable-distance right shifter

        i: value to be shifted (width bits)
        s: shift distance (smax bits, where smax = int(log2(width)))
        o: i shifted right by s (width bits)
    """

    def __init__(self, width):
        # number of bits given to the shift-distance operand
        shift_bits = int(log(width) / log(2))

        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)
        self.s = Signal(shift_bits, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ returns a Module which combinatorially drives o with (i >> s)
        """
        shifted = self.i >> self.s
        module = Module()
        module.d.comb += self.o.eq(shifted)
        return module
23
24
class MultiShift:
    """ Variable-length single-cycle shifter.

        Shifts are expressed with the << and >> operators applied directly
        to the operand, and the result is then truncated back to the
        bit-length of the operand (len(op)).

        width: bit-width this shifter was created for
        smax:  int(log2(width)), kept as an attribute for users of the class

        NOTE: a previous formulation (a cascade of Muxes, one per bit of
        the shift operand) sat after the return statements of lshift and
        rshift and could never execute: it has been removed.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted up (left) by s, truncated to len(op)
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down (right) by s, truncated to len(op)
        """
        res = op >> s
        return res[:len(op)]
57
58
class FPNumBase:
    """ Floating-point Base Number Class

        Holds a value (v) together with its sign (s), exponent (e) and
        mantissa (m) signals, a set of exponent/mantissa constants for
        the chosen width, and combinatorial status flags (is_nan,
        is_zero, is_inf, is_overflowed, is_denormalised, exp_128).

        width:   16, 32 or 64 (anything else raises KeyError)
        m_extra: when true, 3 extra bits are added to the mantissa width
    """
    def __init__(self, width, m_extra=True):
        # field sizes for each supported width
        real_m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width - 3)

        # mantissa extra bits (top,guard,round)
        extra = 3 if m_extra else 0
        m_width = real_m_width + extra

        self.width = width
        self.rmw = real_m_width  # real mantissa width (not including extras)
        self.e_max = e_max
        self.m_extra = extra
        self.m_width = m_width
        self.e_width = e_width
        # bit-range of the exponent within v (used for decoding/encoding)
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        # Exponent: IEEE754exp+2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)  # Sign bit

        # mantissa constants: all zeros and all ones
        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # exponent constants: e_max, e_max-1, -(e_max-1), -(e_max-2)
        # (names correspond to the values for width=32)
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

        # status flags, driven combinatorially in elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module wiring each status flag to its test
        """
        module = Module()
        module.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
        ]
        return module

    def _is_nan(self):
        # exponent at P128, mantissa non-zero
        exp_is_max = (self.e == self.P128)
        return exp_is_max & (self.m != 0)

    def _is_inf(self):
        # exponent at P128, mantissa zero
        exp_is_max = (self.e == self.P128)
        return exp_is_max & (self.m == 0)

    def _is_zero(self):
        # exponent at N127, mantissa zero
        exp_is_min = (self.e == self.N127)
        return exp_is_min & (self.m == self.mzero)

    def _is_overflowed(self):
        # exponent greater than P127
        return (self.e > self.P127)

    def _is_denormalised(self):
        # exponent at N126 and mantissa bit e_start clear
        top_clear = (self.m[self.e_start] == 0)
        return (self.e == self.N126) & top_clear
125
126
class FPNumOut(FPNumBase):
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Adds encoding functions to FPNumBase: sign / exponent / mantissa
        are packed into the value (v), and NaN, inf and zero (all
        signed) can be created.

        Each function returns a list of assignments: it is up to the
        caller to add them to a domain.
    """
    def __init__(self, width, m_extra=True):
        super().__init__(width, m_extra)

    def elaborate(self, platform):
        """ nothing extra is added: the FPNumBase module is returned as-is
        """
        return super().elaborate(platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias (P127) is added here, to the exponent
        """
        biased_exp = e + self.P127  # exp (add on bias)
        set_sign = self.v[-1].eq(s)
        set_exp = self.v[self.e_start:self.e_end].eq(biased_exp)
        set_mantissa = self.v[0:self.e_start].eq(m)
        return [set_sign, set_exp, set_mantissa]

    def nan(self, s):
        """ NaN: exponent P128, mantissa with only bit (e_start-1) set
        """
        top_mantissa_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, top_mantissa_bit)

    def inf(self, s):
        """ infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)
166
167
class FPNumIn(FPNumBase):
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Adds decoding to FPNumBase: the value (v) is split into
        sign / exponent / mantissa whenever the operand (op) has both
        ack and stb set.  Also provides mantissa shifting functions
        which adjust the exponent to compensate.

        When m_extra is set, the decoded mantissa is padded at the LSB
        end with m_extra zero bits.  The lowest bit (m[0]) is treated
        as the "sticky" bit by the shift-down functions.
    """
    def __init__(self, op, width, m_extra=True):
        super().__init__(width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def elaborate(self, platform):
        """ extends the FPNumBase module: latch_in is (op.ack & op.stb),
            and when it is set, v is decoded on the sync domain
        """
        module = super().elaborate(platform)

        handshake = self.op.ack & self.op.stb
        module.d.comb += self.latch_in.eq(handshake)
        with module.If(self.latch_in):
            module.d.sync += self.decode(self.v)

        return module

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias (P127) is subtracted here, from the exponent.
            returns a list of assignments to m, e and s
        """
        padding = [0] * self.m_extra  # pad with extra zeros
        mantissa_bits = v[0:self.e_start]
        exponent_bits = v[self.e_start:self.e_end]
        return [
            self.m.eq(Cat(*padding, mantissa_bits)),  # mantissa
            self.e.eq(exponent_bits - self.P127),  # exp
            self.s.eq(v[-1]),  # sign
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the bit that is shifted out is OR'd into m[0] (the "sticky"
            bit) rather than lost, and a zero is shifted in at the top
        """
        sticky = self.m[0] | self.m[1]
        shifted = Cat(sticky, self.m[2:], 0)
        return [self.e.eq(self.e + 1), self.m.eq(shifted)]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            the shift distance is limited to m_width-1; the exponent
            however is increased by the full (unlimited) diff.

            the sticky bit is computed by shifting a batch of 1s down
            by (m_width-1 - distance), leaving 1s only in the positions
            of the bits which are about to be shifted out.  that is used
            as a mask to get those LSBs of the mantissa, which are then
            |'d together with the existing sticky bit (m[0]).
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        distance = Mux(diff > limit, limit, diff)

        upper = self.m[1:]  # everything except the sticky bit
        shifted = shifter.rshift(upper, distance)
        lost_mask = shifter.rshift(self.m1s[1:], limit - distance)

        sticky = reduce(or_, upper & lost_mask) | self.m[0]
        return [
            self.e.eq(self.e + diff),
            self.m.eq(Cat(sticky, shifted)),
        ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift distance is limited to m_width; the exponent
            however is decreased by the full (unlimited) diff.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        distance = Mux(diff > limit, limit, diff)
        shifted = shifter.lshift(self.m, distance)

        return [self.e.eq(self.e - diff), self.m.eq(shifted)]
255
class FPOp:
    """ Operand port: a value (v) plus stb and ack handshake signals
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)  # value, width bits
        self.stb = Signal()     # strobe
        self.ack = Signal()     # acknowledge

    def ports(self):
        """ returns the port's signals as a list, in the order v, stb, ack
        """
        signals = [self.v]
        signals.append(self.stb)
        signals.append(self.ack)
        return signals
266
267
class Overflow:
    """ Holds the guard, round and sticky bits plus the mantissa zero bit,
        and computes from them whether to round (roundz)
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module which combinatorially sets roundz when guard
            is set and at least one of round_bit, sticky or m0 is set
        """
        any_lower = self.round_bit | self.sticky | self.m0
        module = Module()
        module.d.comb += self.roundz.eq(self.guard & any_lower)
        return module
282
283
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every function takes the Module (m) being built as its first
        argument and adds statements to it; those taking next_state
        also set m.next, so are expected to be used from inside an
        FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state when both stb and ack
            are 1.  acknowledgement is sent by setting ack to ZERO.

            v is not used here: the operand is latched in from the
            FPNumIn class on the same ack/stb condition.
        """
        handshake = (op.ack) & (op.stb)
        with m.If(handshake):
            m.next = next_state
            m.d.sync += op.ack.eq(0)
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        exp_at_minimum = (a.e == a.N127)
        with m.If(exp_at_minimum):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: this one keeps going round every clock until the top
            bit of the mantissa is set: only then is next_state entered
        """
        top_bit_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: this one keeps going round every clock for as long as
                  the top mantissa bit is clear and the exponent is
                  above N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        keep_going = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(keep_going):
            # order matters: z.m[0] is assigned *after* z.m so that the
            # guard bit replaces the zero shifted in at the bottom
            shift_up = [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
                of.m0.eq(of.guard),
            ]
            m.d.sync += shift_up
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: this one keeps going round every clock for as long as
                  the exponent is below N126
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        exp_too_small = z.e < z.N126
        with m.If(exp_too_small):
            shift_down = [
                z.e.eq(z.e + 1),   # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
            m.d.sync += shift_down
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            always moves to next_state.  when of.roundz is set the
            mantissa is incremented, and if the mantissa was all 1s the
            exponent is incremented as well.
        """
        m.next = next_state
        with m.If(of.roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            mantissa_all_ones = (z.m == z.m1s)
            with m.If(mantissa_all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation correction: always moves to next_state, and
            sets the exponent to N127 if z is flagged as denormalised
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.stb & out_z.ack
        with m.If(accepted):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.stb.eq(1)
414
415