add beginning unit tests for 64-bit add
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const
6
7
class FPNum:
    """ Floating-point Number Class, variable-width (32-bit or 64-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        When m_extra is set, three extra low-order bits are added to the
        mantissa: guard (m[2]), round (m[1]) and sticky (m[0]).  decode()
        fills them with zeros and leaves the top mantissa bit (m[-1])
        clear.

        The exponent signal is two bits wider than the encoded exponent
        field (e_end - e_start == e_width - 2) and is signed, so that the
        bias can be subtracted without wrapping.

        The P128 / P127 / N127 / N126 constants are named after their
        32-bit values; for other widths they scale with e_max
        (P127 == e_max-1 is the bias that decode() subtracts and create()
        adds back).
    """
    def __init__(self, width, m_extra=True):
        """ sets up the signals and constants for a width-bit number

            width:   total encoded width, 32 or 64 (any other value
                     raises KeyError from the table lookups below)
            m_extra: when True, 3 extra low-order bits are added to
                     the mantissa signal
        """
        self.width = width
        m_width = {32: 24, 64: 53}[width]
        e_width = {32: 10, 64: 13}[width]
        e_max = 1<<(e_width-3)
        self.rmw = m_width # real mantissa width (not including extras)
        if m_extra:
            # mantissa extra bits (guard, round, sticky)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw - 1 # first bit of exponent field in v
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width)      # Latched copy of value
        self.m = Signal(m_width)    # Mantissa
        self.e = Signal((e_width, True)) # Exponent: e_width bits, signed
        self.s = Signal()           # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the exponent
            signal is wider than the encoded field (and signed) so that
            the subtraction of the bias does not wrap.

            returns a list of assignments to m, e and s
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        return [self.m.eq(Cat(*args)), # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            returns a list of assignments to the sign, exponent and
            mantissa fields of v
        """
        return [
          self.v[-1].eq(s),          # sign
          self.v[self.e_start:self.e_end].eq(e + self.P127), # exp (add on bias)
          self.v[0:self.e_start].eq(m)         # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however the
            bit shifted out is ORed into the new lowest bit (the "sticky"
            bit) rather than being discarded.
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def nan(self, s):
        """ assignments that set v to NaN with sign s: exponent field
            all-ones (P128 + bias), top bit of the stored mantissa set
        """
        return self.create(s, self.P128, 1<<(self.e_start-1))

    def inf(self, s):
        """ assignments that set v to infinity with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set v to zero with sign s
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ true when the decoded exponent is P128 and mantissa is non-zero
        """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ true when the decoded exponent is P128 and mantissa is zero
        """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ true when the decoded exponent is N127 and mantissa is zero
        """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ true when the exponent is greater than the largest (P127)
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ true when the exponent is N126 and mantissa bit e_start is clear
        """
        return (self.e == self.N126) & (self.m[self.e_start] == 0)
108
109
class FPOp:
    """ Operand port: a width-bit value plus stb / ack handshake signals
    """
    def __init__(self, width):
        """ creates the value signal (width bits) and the two
            single-bit handshake signals
        """
        self.width = width

        # each signal is assigned directly to its attribute, one per line
        self.v = Signal(width)   # the operand value
        self.stb = Signal()      # strobe
        self.ack = Signal()      # acknowledge

    def ports(self):
        """ returns the signals of this operand as a list: v, stb, ack
        """
        return [getattr(self, name) for name in ("v", "stb", "ack")]
120
121
class Overflow:
    """ Holds the three single-bit signals that sit below the result
        mantissa: guard, round and sticky (bits 2, 1 and 0 of "tot")
    """
    def __init__(self):
        # one single-bit signal each, assigned directly to its attribute
        self.guard = Signal()     # guard bit:  tot[2]
        self.round_bit = Signal() # round bit:  tot[1]
        self.sticky = Signal()    # sticky bit: tot[0]
127
128
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        In every method, m is the builder object that provides If / Else,
        d.sync and next (presumably an nmigen Module, used from inside
        an FSM state - confirm against callers), and next_state is the
        value assigned to m.next when the step has finished.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            op: operand port (provides v, stb, ack)
            v:  number that op.v is decoded into (provides decode())
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number: an exponent of N127 is limited to N126,
            otherwise the top mantissa bit is set
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(a.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the operand's mantissa is set
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z:  result number, shifted up while its top mantissa bit is
                clear and its exponent is above N126
            of: guard / round_bit / sticky bits (Overflow)
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync +=[
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z:  result number, shifted down while its exponent is
                below N126
            of: guard / round_bit / sticky bits (Overflow); bits shifted
                out of the mantissa pass through guard and round_bit and
                are accumulated (ORed) into sticky
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            the mantissa is incremented when guard is set and any of
            round_bit, sticky or the mantissa's lowest bit is set.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            # the increment carries out of an all-1s mantissa
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero: N127 is what create()
        # turns into an all-zeros exponent field once the bias is added,
        # so it is the EXPONENT (not the mantissa) that must be set here.
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf (keeping the sign of the result,
        # so that a negative overflow gives -Inf rather than +Inf)
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            z:     result number (its v is copied to the output)
            out_z: output port (provides v, stb, ack)
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        # this later assignment to stb is the one that applies once
        # the handshake has completed
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
252
253