convert to more general base classes, start support for FP64
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const
6
7
class FPNum:
    """ Floating-point Number Class, variable-width (32-bit or 64-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0]); those three are
        only present when m_extra is True.

        All bit positions used by decode / create are derived from the
        width, so the same code serves both supported formats.
    """
    def __init__(self, width, m_extra=True):
        """ width: total bit-width of the encoded value (32 or 64;
                   any other value raises KeyError)
            m_extra: when True, add the three low-order guard / round /
                     sticky bits to the mantissa
        """
        self.width = width
        m_width = {32: 24, 64: 53}[width]  # fraction bits plus one top bit
        e_width = {32: 10, 64: 13}[width]  # exponent field plus two bits
        e_max = 1 << (e_width - 3)         # 128 for 32-bit, 1024 for 64-bit
        self.rmw = m_width                 # mantissa width without extras
        self.e_max = e_max
        # number of low-order extra bits (guard, round, sticky)
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        # field boundaries inside the encoded value:
        #   v[0:e_start]     fraction
        #   v[e_start:e_end] biased exponent
        #   v[e_end]         sign
        self.e_start = self.rmw - 1
        self.e_end = width - 1
        self.v = Signal(width)            # Latched copy of value
        self.m = Signal(m_width)          # Mantissa
        self.e = Signal((e_width, True))  # Exponent: e_width bits, signed
        self.s = Signal()                 # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))  # mantissa of all 1s
        # the names are historical (32-bit values): for any width these
        # are e_max, e_max-1, -(e_max-1) and -(e_max-2) respectively
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the exponent
            signal is wider than the encoded field, so that the
            subtraction of the bias is done on a wider signed number

            returns a list of assignments to m, e and s
        """
        # pad with extra zeros (the guard / round / sticky positions)
        args = [0] * self.m_extra + [v[0:self.e_start]]
        return [self.m.eq(Cat(*args)),                              # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.P127),  # exp
                self.s.eq(v[self.e_end]),                           # sign
                ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            returns a list of assignments to the fields of v
        """
        return [
            self.v[self.e_end].eq(s),                           # sign
            self.v[self.e_start:self.e_end].eq(e + self.P127),  # exp + bias
            self.v[0:self.e_start].eq(m)                        # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(self.e + 1),
                # new sticky bit is the OR of the two lowest bits, the
                # rest move down one place and a zero enters at the top
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
                ]

    def nan(self, s):
        """ assignments that set v to a NaN with sign s: exponent field
            all 1s, top bit of the fraction set
        """
        return self.create(s, self.P128, 1 << (self.e_start - 1))

    def inf(self, s):
        """ assignments that set v to infinity with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set v to zero with sign s
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ decoded exponent is P128 and mantissa is non-zero
        """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ decoded exponent is P128 and mantissa is zero
        """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ decoded exponent is N127 and mantissa is zero
        """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ decoded exponent is above the largest finite exponent
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ exponent is at its minimum (N126) and the mantissa bit at
            index rmw-1 is clear

            NOTE(review): index rmw-1 is the top bit of the mantissa only
            when the number was created with m_extra=False - confirm
            before using this on a number that has the extra bits
        """
        return (self.e == self.N126) & (self.m[self.e_start] == 0)
100
101
class FPOp:
    """ Floating-point operand port: a value plus two handshake signals

        v is the encoded value (width bits wide); stb and ack are
        single-bit signals.
    """
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)  # encoded value
        self.stb = Signal()     # strobe
        self.ack = Signal()     # acknowledge

    def ports(self):
        """ returns the three signals of this operand as a list
        """
        return [self.v, self.stb, self.ack]
112
113
class Overflow:
    """ Holds the three single-bit guard / round / sticky signals
    """
    def __init__(self):
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]
119
120
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Every method takes the module being built (m) and adds logic to
        it.  Methods taking next_state set m.next to it once their step
        is finished.  The number arguments (a, z, v) provide m / e / s
        signals and the helper methods and constants used here.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)

    def denormalise(self, m, a):
        """ denormalises a number

            an exponent of N127 is limited to N126; for any other
            exponent the top mantissa bit is set
        """
        with m.If(a.e == a.N127):
            # use the number's own constant rather than a literal -126,
            # so that the limit is right for every width
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),   # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                # must come after the shift: replaces the zero that the
                # shift put into bit 0
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),   # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                # the bit shifted out moves down through guard and round
                # and is finally ORed into sticky
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            rounds up when guard is set together with any of round,
            sticky or the lowest mantissa bit
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1)      # mantissa rounds up
            with m.If(z.m == z.m1s):         # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero: N127 becomes an
        # all-zero exponent field once the bias is added by create().
        # (this has to write the exponent, not the mantissa)
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf.  the sign of the result is
        # kept, so that a negative overflow gives -inf and not +inf
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
            out_z.stb.eq(1),
            out_z.v.eq(z.v)
        ]
        # added after the assignments above, so that stb is cleared
        # (rather than raised again) once the handshake completes
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state
242 m.d.sync += out_z.stb.eq(0)
243 m.next = next_state
244
245