# add a MultiShift class for generating single-cycle bit-shifters
# [ieee754fpu.git] / src / add / fpbase.py
# IEEE Floating Point Adder (Single Precision)
# Copyright (C) Jonathan P Dawson 2013
# 2013-12-12

5 from nmigen import Signal, Cat, Const, Mux
6 from math import log
7
class MultiShift:
    """ Generates variable-length single-cycle shifter from a series
        of conditional tests on each bit of the left/right shift operand.
        Each bit tested produces output shifted by that number of bits,
        in a binary fashion: bit 0 if set shifts by 1 bit, bit 1 if set
        shifts by 2 bits, bit 2 by 4 bits and so on, each partial result
        cascading to the next Mux.

        Only the lowest "smax" bits of the shift amount are tested, where
        smax = floor(log2(width)).  Higher bits of the shift amount are
        ignored.
        NOTE(review): for a width that is not a power of two this means
        the largest supported shift is (1<<smax)-1, which is less than
        width-1 - confirm that this is what callers expect.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        """ width: number of bits of the operand to be shifted (>= 1)

            raises ValueError if width is less than 1
        """
        if width < 1:
            raise ValueError("MultiShift width must be at least 1")
        self.width = width
        # floor(log2(width)), computed exactly on integers: dividing two
        # floating-point logarithms can round to just under the true value
        self.smax = int(width).bit_length() - 1

    def lshift(self, op, s):
        """ returns an expression for op shifted left (towards the MSB)
            by s, zero-filling from the LSB end.
        """
        res = op
        for i in range(self.smax):
            amount = 1 << i
            zeros = [0] * amount
            # Cat is LSB-first: zeros go in at the bottom, top bits drop off
            res = Mux(s & amount, Cat(zeros, res[0:-amount]), res)
        return res

    def rshift(self, op, s):
        """ returns an expression for op shifted right (towards the LSB)
            by s, zero-filling from the MSB end.
        """
        res = op
        for i in range(self.smax):
            amount = 1 << i
            zeros = [0] * amount
            # Cat is LSB-first: bottom bits drop off, zeros go in at the top
            res = Mux(s & amount, Cat(res[amount:], zeros), res)
        return res
36
37
class FPNum:
    """ Floating-point Number Class (16, 32 and 64-bit IEEE754 layouts)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        The mantissa is one bit wider than the stored fraction: decode()
        leaves that top bit (m[-1]) as zero.  When m_extra is True, three
        further bits are added at the bottom of the mantissa:
        guard (m[2]), round (m[1]), and sticky (m[0])

        The constants P128 / P127 / N127 / N126 keep their single-precision
        names but are scaled to the selected width: they are
        e_max, e_max-1, -(e_max-1) and -(e_max-2) respectively, where
        e_max-1 is the exponent bias.
    """
    def __init__(self, width, m_extra=True):
        """ width: total number of bits of the encoded value (16, 32 or 64)
            m_extra: if True, add guard/round/sticky bits below the mantissa

            raises KeyError if width is not one of the supported values
        """
        self.width = width
        # mantissa: stored fraction bits plus one top bit
        m_width = {16: 11, 32: 24, 64: 53}[width]
        # exponent: two bits wider than the encoded field, and signed,
        # so that the bias can be subtracted without overflow
        e_width = {16: 7, 32: 10, 64: 13}[width]
        e_max = 1<<(e_width-3)
        self.rmw = m_width # real mantissa width (not including extras)
        if m_extra:
            # mantissa extra bits (guard,round,sticky)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw - 1 # first bit of exponent field in v
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width)      # Latched copy of value
        self.m = Signal(m_width)    # Mantissa
        self.e = Signal((e_width, True)) # Exponent: signed, e_width bits
        self.s = Signal()           # Sign bit

        self.mzero = Const(0, (m_width, False))  # all-zeros mantissa
        self.m1s = Const(-1, (m_width, False))   # all-ones mantissa
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the subtraction
            is done against a signed constant of the (wider) exponent
            width.

            returns a list of assignments to m, e and s
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        return [self.m.eq(Cat(*args)), # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            returns a list of assignments to the fields of v
        """
        return [
          self.v[-1].eq(s),          # sign
          self.v[self.e_start:self.e_end].eq(e + self.P127), # exp (add on bias)
          self.v[0:self.e_start].eq(m)         # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        # the bit shifted out of m[0] is ORed into the new m[0] (sticky),
        # and a zero is shifted in at the top
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def nan(self, s):
        """ assignments that set v to NaN (top fraction bit set), sign s
        """
        return self.create(s, self.P128, 1<<(self.e_start-1))

    def inf(self, s):
        """ assignments that set v to infinity with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set v to zero with sign s
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ expression: exponent is P128 and mantissa is non-zero
        """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ expression: exponent is P128 and mantissa is zero
        """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ expression: exponent is N127 and mantissa is zero
        """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ expression: exponent is greater than P127
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ expression: exponent is N126 and mantissa bit e_start is clear

            NOTE(review): m[e_start] is the top mantissa bit only when
            the number was created with m_extra=False - confirm callers.
        """
        return (self.e == self.N126) & (self.m[self.e_start] == 0)
138
139
class FPOp:
    """ Operand port: a value of the given width, accompanied by
        a strobe (stb) and an acknowledge (ack) signal.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)  # the value being transferred
        self.stb = Signal()     # strobe
        self.ack = Signal()     # acknowledge

    def ports(self):
        """ returns the signals of this port as a list: value, stb, ack
        """
        signals = (self.v, self.stb, self.ack)
        return list(signals)
150
151
class Overflow:
    """ Holds the three single-bit signals (guard, round, sticky) used
        by the normalisation and rounding stages.
    """
    def __init__(self):
        # each Signal is assigned directly to its attribute, one per line
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]
157
158
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every function takes the module "m" (whose FSM state is advanced
        by assigning m.next) and adds statements to m.d.sync.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m: module, op: operand port (v/stb/ack), v: number that
            decodes the operand, next_state: state to move to
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not yet transferred: keep signalling readiness
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync +=[
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            rounds up when guard is set and any of round, sticky or the
            mantissa LSB is set.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero.  the *exponent* is what
        # must be set to N127 (encoded as zero once the bias is added):
        # the mantissa is left untouched.
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf, keeping the sign of the result
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            # added after stb.eq(1) above, so this assignment wins
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
289
290