reorg setup functions in more add phases
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Mux
6 from nmigen.lib.coding import PriorityEncoder
7 from nmigen.cli import main, verilog
8
9 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
10 from fpbase import MultiShiftRMerge
11 #from fpbase import FPNumShiftMultiRight
12
class FPState(FPBase):
    """ base class for a single state of the FP state machine.

        state_from is the name of the FSM state in which this object's
        action() is executed.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ records the inputs dict and exposes every entry as an attribute
        """
        self.inputs = inputs
        for name in inputs:
            setattr(self, name, inputs[name])

    def set_outputs(self, outputs):
        """ records the outputs dict and exposes every entry as an attribute
        """
        self.outputs = outputs
        for name in outputs:
            setattr(self, name, outputs[name])
26
27
class FPGetOpMod:
    """ combinatorial operand decoder.

        out_decode is raised while the incoming operand has both ack and
        stb set; only then is in_op.v decoded into out_op.
    """

    def __init__(self, width):
        self.in_op = FPOp(width)
        self.out_op = FPNumIn(self.in_op, width)
        self.out_decode = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        #m.submodules.get_op_in = self.in_op
        m.submodules.get_op_out = self.out_op

        # handshake complete: operand is strobed *and* acknowledged
        handshake = self.in_op.ack & self.in_op.stb
        m.d.comb += self.out_decode.eq(handshake)

        with m.If(self.out_decode):
            m.d.comb += self.out_op.decode(self.in_op.v)
        return m
44
45
class FPGetOp(FPState):
    """ gets operand: waits in state in_state until the operand has been
        decoded, then latches it and moves on to out_state.
    """

    def __init__(self, in_state, out_state, in_op, width):
        super().__init__(in_state)
        self.out_state = out_state
        self.mod = FPGetOpMod(width)
        self.in_op = in_op
        self.out_op = FPNumIn(in_op, width)
        self.out_decode = Signal(reset_less=True)

    def setup(self, m, in_op):
        """ links module to inputs and outputs.  the decoder module is
            registered as a submodule under this state's name.
        """
        setattr(m.submodules, self.state_from, self.mod)
        m.d.comb += [
            self.mod.in_op.copy(in_op),
            self.out_op.v.eq(self.mod.out_op.v),
            self.out_decode.eq(self.mod.out_decode),
        ]

    def action(self, m):
        with m.If(self.out_decode):
            # operand received: drop ack, store the decoded number
            m.next = self.out_state
            m.d.sync += self.in_op.ack.eq(0)
            m.d.sync += self.out_op.copy(self.mod.out_op)
        with m.Else():
            # still waiting: keep acknowledging
            m.d.sync += self.in_op.ack.eq(1)
75
76
class FPGetOpB(FPState):
    """ gets operand b (state "get_b"), then moves to "special_cases"
    """

    def __init__(self, in_b, width):
        super().__init__("get_b")
        self.in_b = in_b
        self.b = FPNumIn(in_b, width)

    def action(self, m):
        # get_op is not defined in this file; presumably inherited
        # from FPBase -- TODO confirm
        self.get_op(m, self.in_b, self.b, "special_cases")
88
89
class FPAddSpecialCasesMod:
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        out_do_z is raised when one of the special cases applies; out_z
        then carries the result.  the cases are tested in priority order.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def setup(self, m, in_a, in_b, out_z, out_do_z):
        """ links module to inputs and outputs
            (out_z is accepted but is currently left unconnected)
        """
        m.d.comb += [
            self.in_a.copy(in_a),
            self.in_b.copy(in_b),
            #out_z.v.eq(self.out_z.v),
            out_do_z.eq(self.out_do_z),
        ]

    def elaborate(self, platform):
        m = Module()

        # short aliases, purely for readability
        a, b, z = self.in_a, self.in_b, self.out_z
        do_z = self.out_do_z

        m.submodules.sc_in_a = a
        m.submodules.sc_in_b = b
        m.submodules.sc_out_z = z

        # signs differ / mantissas identical
        s_nomatch = Signal()
        m_match = Signal()
        m.d.comb += [
            s_nomatch.eq(a.s != b.s),
            m_match.eq(a.m == b.m),
        ]

        # if a is NaN or b is NaN return NaN
        with m.If(a.is_nan | b.is_nan):
            m.d.comb += [do_z.eq(1), z.nan(0)]

        # XXX WEIRDNESS for FP16 non-canonical NaN handling
        # under review

        ## if a is zero and b is NaN return -b
        #with m.If(a.is_zero & (a.s==0) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))

        ## if b is zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))

        ## if a is -zero and b is NaN return -b
        #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))

        ## if b is -zero and a is NaN return -a
        #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
        #    m.d.comb += self.out_do_z.eq(1)
        #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))

        # if a is inf return inf (or NaN)
        with m.Elif(a.is_inf):
            m.d.comb += [do_z.eq(1), z.inf(a.s)]
            # b has exponent 128 and the signs don't match: the NaN
            # assigned here overrides the inf assigned just above
            with m.If(b.exp_128 & s_nomatch):
                m.d.comb += z.nan(0)

        # if b is inf return inf
        with m.Elif(b.is_inf):
            m.d.comb += [do_z.eq(1), z.inf(b.s)]

        # if a is zero and b zero return signed-a/b
        with m.Elif(a.is_zero & b.is_zero):
            m.d.comb += [
                do_z.eq(1),
                z.create(a.s & b.s, b.e, b.m[3:-1]),
            ]

        # if a is zero return b
        with m.Elif(a.is_zero):
            m.d.comb += [
                do_z.eq(1),
                z.create(b.s, b.e, b.m[3:-1]),
            ]

        # if b is zero return a
        with m.Elif(b.is_zero):
            m.d.comb += [
                do_z.eq(1),
                z.create(a.s, a.e, a.m[3:-1]),
            ]

        # if a equal to -b return zero (+ve zero)
        with m.Elif(s_nomatch & m_match & (a.e == b.e)):
            m.d.comb += [do_z.eq(1), z.zero(0)]

        # no special case: Denormalised Number checks come next
        with m.Else():
            m.d.comb += do_z.eq(0)

        return m
193
194
class FPAddSpecialCases(FPState):
    """ special cases: NaNs, infs, zeros, denormalised
        NOTE: some of these are unique to add.  see "Special Operations"
        https://steve.hollasch.net/cgindex/coding/ieeefloat.html

        state "special_cases": goes straight to "put_z" when the module
        flagged a special case, otherwise carries on to "denormalise".
    """

    def __init__(self, width):
        super().__init__("special_cases")
        self.mod = FPAddSpecialCasesMod(width)
        self.out_z = FPNumOut(width, False)
        self.out_do_z = Signal(reset_less=True)

    def action(self, m):
        with m.If(self.out_do_z):
            m.next = "put_z"
            # only the packed output value is taken from the module
            m.d.sync += self.out_z.v.eq(self.mod.out_z.v)
        with m.Else():
            m.next = "denormalise"
214
class FPAddDeNormMod(FPState):
    """ combinatorial denormalised-number check, applied identically to
        both operands: the output is a copy of the input, except that
        either the exponent is replaced or the top mantissa bit is set.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def elaborate(self, platform):
        m = Module()
        m.submodules.denorm_in_a = self.in_a
        m.submodules.denorm_in_b = self.in_b
        m.submodules.denorm_out_a = self.out_a
        m.submodules.denorm_out_b = self.out_b

        # same treatment for a, then for b
        operands = ((self.in_a, self.out_a), (self.in_b, self.out_b))
        for num_in, num_out in operands:
            m.d.comb += num_out.copy(num_in)
            with m.If(num_in.exp_n127):
                # limit the exponent
                m.d.comb += num_out.e.eq(num_in.N126)
            with m.Else():
                # set top mantissa bit
                m.d.comb += num_out.m[-1].eq(1)

        return m
243
244
class FPAddDeNorm(FPState):
    """ state "denormalise": latches the outputs of FPAddDeNormMod and
        moves on to "align".
    """

    def __init__(self, width):
        super().__init__("denormalise")
        self.mod = FPAddDeNormMod(width)
        self.out_a = FPNumBase(width)
        self.out_b = FPNumBase(width)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.denormalise = self.mod
        m.d.comb += [
            self.mod.in_a.copy(in_a),
            self.mod.in_b.copy(in_b),
        ]

    def action(self, m):
        # Denormalised Number checks
        m.next = "align"
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
265
266
class FPAddAlignMultiMod(FPState):
    """ multi-cycle alignment step.

        compares the two exponents and applies shift_down to whichever
        operand has the smaller one.  exp_eq is raised when the exponents
        are equal, i.e. no further shifting is needed.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        m.submodules.align_in_a = self.in_a
        m.submodules.align_in_b = self.in_b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # NOTE: this does *not* do single-cycle multi-shifting,
        #       it *STAYS* in the align state until exponents match

        agtb = Signal(reset_less=True)
        altb = Signal(reset_less=True)
        m.d.comb += [
            # defaults: not yet aligned, operands passed through
            self.exp_eq.eq(0),
            self.out_a.copy(self.in_a),
            self.out_b.copy(self.in_b),
            agtb.eq(self.in_a.e > self.in_b.e),
            altb.eq(self.in_a.e < self.in_b.e),
        ]
        # exponent of a greater than b: shift b down
        with m.If(agtb):
            m.d.comb += self.out_b.shift_down(self.in_b)
        # exponent of b greater than a: shift a down
        with m.Elif(altb):
            m.d.comb += self.out_a.shift_down(self.in_a)
        # exponents equal: move to next stage.
        with m.Else():
            m.d.comb += self.exp_eq.eq(1)
        return m
307
308
class FPAddAlignMulti(FPState):
    """ state "align" (multi-cycle variant): latches the module outputs
        every cycle and only leaves for "add_0" once exp_eq is raised.

        NOTE(review): setup() always feeds the module from the in_a/in_b
        it was given, not from the latched out_a/out_b -- confirm that
        the loop can make progress when the exponents differ.
    """

    def __init__(self, width):
        super().__init__("align")
        self.mod = FPAddAlignMultiMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)
        self.exp_eq = Signal(reset_less=True)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.align = self.mod
        m.d.comb += [
            self.mod.in_a.copy(in_a),
            self.mod.in_b.copy(in_b),
            #self.out_a.copy(self.mod.out_a),
            #self.out_b.copy(self.mod.out_b),
            self.exp_eq.eq(self.mod.exp_eq),
        ]

    def action(self, m):
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
        with m.If(self.exp_eq):
            m.next = "add_0"
333
334
class FPAddAlignSingleMod:
    """ single-cycle alignment of the two operands (see elaborate)
    """

    def __init__(self, width):
        self.width = width
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def elaborate(self, platform):
        """ Aligns A against B or B against A, depending on which has the
            greater exponent.  This is done in a *single* cycle using
            variable-width bit-shift

            the shifter used here is quite expensive in terms of gates.
            Mux A or B in (and out) into temporaries, as only one of them
            needs to be aligned against the other
        """
        m = Module()

        m.submodules.align_in_a = self.in_a
        m.submodules.align_in_b = self.in_b
        m.submodules.align_out_a = self.out_a
        m.submodules.align_out_b = self.out_b

        # temporary (muxed) input and output to be shifted
        t_inp = FPNumBase(self.width)
        t_out = FPNumIn(None, self.width)
        espec = (len(self.in_a.e), True)
        msr = MultiShiftRMerge(self.in_a.m_width, espec)
        m.submodules.align_t_in = t_inp
        m.submodules.align_t_out = t_out
        m.submodules.multishift_r = msr

        ediff = Signal(espec, reset_less=True)   # a.e - b.e
        ediffr = Signal(espec, reset_less=True)  # b.e - a.e
        tdiff = Signal(espec, reset_less=True)   # the difference in use
        elz = Signal(reset_less=True)            # a.e < b.e
        egz = Signal(reset_less=True)            # a.e > b.e

        m.d.comb += [
            # connect multi-shifter to t_inp/out mantissa (and tdiff)
            msr.inp.eq(t_inp.m),
            msr.diff.eq(tdiff),
            t_out.m.eq(msr.m),
            t_out.e.eq(t_inp.e + tdiff),
            t_out.s.eq(t_inp.s),
            # exponent differences and comparisons
            ediff.eq(self.in_a.e - self.in_b.e),
            ediffr.eq(self.in_b.e - self.in_a.e),
            elz.eq(self.in_a.e < self.in_b.e),
            egz.eq(self.in_a.e > self.in_b.e),
            # default: A-exp == B-exp, A and B untouched (fall through)
            self.out_a.copy(self.in_a),
            self.out_b.copy(self.in_b),
        ]
        # only one shifter (muxed)
        #m.d.comb += t_out.shift_down_multi(tdiff, t_inp)
        # exponent of a greater than b: shift b down
        with m.If(egz):
            m.d.comb += [
                t_inp.copy(self.in_b),
                tdiff.eq(ediff),
                self.out_b.copy(t_out),
                self.out_b.s.eq(self.in_b.s),  # whoops forgot sign
            ]
        # exponent of b greater than a: shift a down
        with m.Elif(elz):
            m.d.comb += [
                t_inp.copy(self.in_a),
                tdiff.eq(ediffr),
                self.out_a.copy(t_out),
                self.out_a.s.eq(self.in_a.s),  # whoops forgot sign
            ]
        return m
407
408
class FPAddAlignSingle(FPState):
    """ state "align" (single-cycle variant): latches the outputs of
        FPAddAlignSingleMod and moves on to "add_0".
    """

    def __init__(self, width):
        super().__init__("align")
        self.mod = FPAddAlignSingleMod(width)
        self.out_a = FPNumIn(None, width)
        self.out_b = FPNumIn(None, width)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.align = self.mod
        m.d.comb += [
            self.mod.in_a.copy(in_a),
            self.mod.in_b.copy(in_b),
        ]

    def action(self, m):
        m.next = "add_0"
        # NOTE: could be done as comb
        m.d.sync += [
            self.out_a.copy(self.mod.out_a),
            self.out_b.copy(self.mod.out_b),
        ]
429
430
class FPAddStage0Mod:
    """ combinatorial add/subtract of the two mantissas.

        out_tot receives the sum (same signs) or the difference, larger
        minus smaller (different signs).  out_z takes the exponent of a
        and the sign of the result.
    """

    def __init__(self, width):
        self.in_a = FPNumBase(width)
        self.in_b = FPNumBase(width)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        a, b, z = self.in_a, self.in_b, self.out_z

        m.submodules.add0_in_a = a
        m.submodules.add0_in_b = b
        m.submodules.add0_out_z = z

        m.d.comb += z.e.eq(a.e)

        # store intermediate tests (and zero-extended mantissas)
        seq = Signal(reset_less=True)                # signs equal
        mge = Signal(reset_less=True)                # a.m >= b.m
        am0 = Signal(len(a.m)+1, reset_less=True)
        bm0 = Signal(len(b.m)+1, reset_less=True)
        m.d.comb += [
            seq.eq(a.s == b.s),
            mge.eq(a.m >= b.m),
            am0.eq(Cat(a.m, 0)),
            bm0.eq(Cat(b.m, 0)),
        ]
        # same-sign (both negative or both positive) add mantissas
        with m.If(seq):
            m.d.comb += [self.out_tot.eq(am0 + bm0), z.s.eq(a.s)]
        # a mantissa greater than (or equal to) b, use a
        with m.Elif(mge):
            m.d.comb += [self.out_tot.eq(am0 - bm0), z.s.eq(a.s)]
        # b mantissa greater than a, use b
        with m.Else():
            m.d.comb += [self.out_tot.eq(bm0 - am0), z.s.eq(b.s)]
        return m
477
478
class FPAddStage0(FPState):
    """ First stage of add.  covers same-sign (add) and subtract
        special-casing when mantissas are greater or equal, to
        give greatest accuracy.

        state "add_0": latches the module's out_z and out_tot, then
        moves on to "add_1".
    """

    def __init__(self, width):
        super().__init__("add_0")
        self.mod = FPAddStage0Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)

    def setup(self, m, in_a, in_b):
        """ links module to inputs and outputs
        """
        m.submodules.add0 = self.mod

        m.d.comb += [
            self.mod.in_a.copy(in_a),
            self.mod.in_b.copy(in_b),
        ]

    def action(self, m):
        m.next = "add_1"
        # NOTE: these could be done as combinatorial (merge add0+add1)
        m.d.sync += [
            self.out_z.copy(self.mod.out_z),
            self.out_tot.eq(self.mod.out_tot),
        ]
504
505
class FPAddStage1Mod(FPState):
    """ Second stage of add: preparation for normalisation.
        detects when tot sum is too big (tot[27] is kinda a carry bit)

        splits in_tot into the output mantissa plus the overflow bits
        (m0, guard, round, sticky), one bit further up when the top bit
        of in_tot is set.
    """

    def __init__(self, width):
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_tot = Signal(self.in_z.m_width + 4, reset_less=True)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def elaborate(self, platform):
        m = Module()
        tot, z, of = self.in_tot, self.out_z, self.out_of

        #m.submodules.norm1_in_overflow = self.in_of
        #m.submodules.norm1_out_overflow = self.out_of
        #m.submodules.norm1_in_z = self.in_z
        #m.submodules.norm1_out_z = self.out_z
        m.d.comb += z.copy(self.in_z)
        # tot[27] gets set when the sum overflows. shift result down
        with m.If(tot[-1]):
            m.d.comb += [
                z.m.eq(tot[4:]),
                of.m0.eq(tot[4]),
                of.guard.eq(tot[3]),
                of.round_bit.eq(tot[2]),
                of.sticky.eq(tot[1] | tot[0]),
                z.e.eq(self.in_z.e + 1),
            ]
        # tot[27] zero case
        with m.Else():
            m.d.comb += [
                z.m.eq(tot[3:]),
                of.m0.eq(tot[3]),
                of.guard.eq(tot[2]),
                of.round_bit.eq(tot[1]),
                of.sticky.eq(tot[0]),
            ]
        return m
545
546
class FPAddStage1(FPState):
    """ state "add_1": latches the outputs of FPAddStage1Mod, raises
        norm_stb and moves on to "normalise_1".
    """

    def __init__(self, width):
        super().__init__("add_1")
        self.mod = FPAddStage1Mod(width)
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()
        self.norm_stb = Signal()

    def setup(self, m, in_tot, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.add1 = self.mod

        m.d.comb += [
            self.mod.in_z.copy(in_z),
            self.mod.in_tot.eq(in_tot),
        ]

        # sets to zero when not in add1 state
        m.d.sync += self.norm_stb.eq(0)

    def action(self, m):
        m.submodules.add1_out_overflow = self.out_of
        m.next = "normalise_1"
        m.d.sync += [
            self.out_of.copy(self.mod.out_of),
            self.out_z.copy(self.mod.out_z),
            self.norm_stb.eq(1),
        ]
572
573
class FPNorm1ModSingle:
    """ single-cycle normalisation.

        selects either in_z/in_of (in_select set) or temp_z/temp_of as
        the number to work on, then:

        * "decrease": mantissa MSB is zero and the exponent is above the
          minimum.  the mantissa is shifted UP by the leading-zero count
          (limited by exp_sub_n126) and the exponent reduced to match.
        * "increase": the exponent is below the minimum.  the mantissa is
          shifted DOWN through MultiShiftRMerge and the exponent raised
          by the same amount.

        out_norm is always zero (loop-end condition), as everything is
        done in one go.
    """

    def __init__(self, width):
        self.width = width
        self.in_select = Signal(reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.temp_z = FPNumBase(width, False)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def elaborate(self, platform):
        m = Module()

        mwid = self.out_z.m_width+2
        pe = PriorityEncoder(mwid)
        m.submodules.norm_pe = pe

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_temp_z = self.temp_z
        m.submodules.norm1_temp_of = self.temp_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        espec = (len(in_z.e), True)
        ediff_n126 = Signal(espec, reset_less=True)
        # NOTE(review): the increase path below feeds the shifter with
        # mwid+1 bits (sticky, round, guard, mantissa) whereas it is
        # created mwid wide -- confirm the top bit is not being dropped.
        msr = MultiShiftRMerge(mwid, espec)
        m.submodules.multishift_r = msr

        # select which of temp or in z/of to use
        with m.If(self.in_select):
            m.d.comb += in_z.copy(self.in_z)
            m.d.comb += in_of.copy(self.in_of)
        with m.Else():
            m.d.comb += in_z.copy(self.temp_z)
            m.d.comb += in_of.copy(self.temp_of)
        # initialise out from in (overridden below)
        m.d.comb += self.out_z.copy(in_z)
        m.d.comb += self.out_of.copy(in_of)
        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126)
        m.d.comb += increase.eq(in_z.exp_lt_n126)
        m.d.comb += self.out_norm.eq(0) # loop-end condition
        # decrease exponent
        with m.If(decrease):
            # *sigh* not entirely obvious: count leading zeros (clz)
            # with a PriorityEncoder: to find from the MSB
            # we reverse the order of the bits.
            temp_m = Signal(mwid, reset_less=True)
            temp_s = Signal(mwid+1, reset_less=True)
            clz = Signal((len(in_z.e), True), reset_less=True)
            # make sure that the amount to decrease by does NOT
            # go below the minimum non-INF/NaN exponent
            limclz = Mux(in_z.exp_sub_n126 > pe.o, pe.o,
                         in_z.exp_sub_n126)
            m.d.comb += [
                # cat round and guard bits back into the mantissa
                temp_m.eq(Cat(in_of.round_bit, in_of.guard, in_z.m)),
                pe.i.eq(temp_m[::-1]),          # inverted
                clz.eq(limclz),                 # count zeros from MSB down
                temp_s.eq(temp_m << clz),       # shift mantissa UP
                self.out_z.e.eq(in_z.e - clz),  # DECREASE exponent
                self.out_z.m.eq(temp_s[2:]),    # exclude bits 0&1
                self.out_of.m0.eq(temp_s[2]),   # copy of mantissa[0]
                # overflow in bits 0..1: got shifted too (leave sticky)
                self.out_of.guard.eq(temp_s[1]),     # guard
                self.out_of.round_bit.eq(temp_s[0]), # round
            ]
        # increase exponent
        with m.Elif(increase):
            # separate name from the decrease-branch temporaries: python
            # "with" blocks share one scope, so re-using their names makes
            # it far too easy to pick up the wrong signal.
            inc_m = Signal(mwid+1, reset_less=True)
            m.d.comb += [
                inc_m.eq(Cat(in_of.sticky, in_of.round_bit, in_of.guard,
                             in_z.m)),
                ediff_n126.eq(in_z.N126 - in_z.e),
                # connect multi-shifter to inp/out mantissa (and ediff)
                msr.inp.eq(inc_m),
                msr.diff.eq(ediff_n126),
                self.out_z.m.eq(msr.m[3:]),
                # the overflow bits are the low bits of the *right-shifted*
                # result.  (these used to be read from temp_s, which is
                # only driven in the decrease branch and so is zero here.)
                self.out_of.m0.eq(msr.m[3]),        # copy of mantissa[0]
                self.out_of.guard.eq(msr.m[2]),     # guard
                self.out_of.round_bit.eq(msr.m[1]), # round
                self.out_of.sticky.eq(msr.m[0]),    # sticky
                self.out_z.e.eq(in_z.e + ediff_n126),
            ]

        return m
672
673
class FPNorm1ModMulti:
    """ multi-cycle normalisation: adjusts by one bit per evaluation.

        selects either in_z/in_of (in_select set) or temp_z/temp_of as
        the number to work on.  out_norm is raised for as long as a
        decrease or an increase is still required.

        the single_cycle argument is accepted but not used.
    """

    def __init__(self, width, single_cycle=True):
        self.width = width
        self.in_select = Signal(reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.in_of = Overflow()
        self.temp_z = FPNumBase(width, False)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width, False)
        self.out_of = Overflow()

    def elaborate(self, platform):
        m = Module()
        out_z, out_of = self.out_z, self.out_of

        m.submodules.norm1_out_z = self.out_z
        m.submodules.norm1_out_overflow = self.out_of
        m.submodules.norm1_temp_z = self.temp_z
        m.submodules.norm1_temp_of = self.temp_of
        m.submodules.norm1_in_z = self.in_z
        m.submodules.norm1_in_overflow = self.in_of

        in_z = FPNumBase(self.width, False)
        in_of = Overflow()
        m.submodules.norm1_insel_z = in_z
        m.submodules.norm1_insel_overflow = in_of

        # select which of temp or in z/of to use
        with m.If(self.in_select):
            m.d.comb += [in_z.copy(self.in_z), in_of.copy(self.in_of)]
        with m.Else():
            m.d.comb += [in_z.copy(self.temp_z), in_of.copy(self.temp_of)]

        # normalisation increase/decrease conditions
        decrease = Signal(reset_less=True)
        increase = Signal(reset_less=True)
        m.d.comb += [
            # initialise out from in (overridden below)
            out_z.copy(in_z),
            out_of.copy(in_of),
            decrease.eq(in_z.m_msbzero & in_z.exp_gt_n126),
            increase.eq(in_z.exp_lt_n126),
            self.out_norm.eq(decrease | increase),  # loop-end
        ]
        # decrease exponent
        with m.If(decrease):
            m.d.comb += [
                out_z.e.eq(in_z.e - 1),              # DECREASE exponent
                out_z.m.eq(in_z.m << 1),             # shift mantissa UP
                out_z.m[0].eq(in_of.guard),          # steal guard (was tot[2])
                out_of.guard.eq(in_of.round_bit),    # round (was tot[1])
                out_of.round_bit.eq(0),              # reset round bit
                out_of.m0.eq(in_of.guard),
            ]
        # increase exponent
        with m.Elif(increase):
            m.d.comb += [
                out_z.e.eq(in_z.e + 1),              # INCREASE exponent
                out_z.m.eq(in_z.m >> 1),             # shift mantissa DOWN
                out_of.guard.eq(in_z.m[0]),
                out_of.m0.eq(in_z.m[1]),
                out_of.round_bit.eq(in_of.guard),
                out_of.sticky.eq(in_of.sticky | in_of.round_bit),
            ]

        return m
740
741
class FPNorm1(FPState):
    """ state "normalise_1".

        wraps either the single-cycle or the multi-cycle normalisation
        module.  the module's outputs are latched into temp_z/temp_of on
        every cycle spent in this state; the state is left for "round"
        as soon as the module no longer raises out_norm.
    """

    def __init__(self, width, single_cycle=True):
        super().__init__("normalise_1")
        mod_cls = FPNorm1ModSingle if single_cycle else FPNorm1ModMulti
        self.mod = mod_cls(width)
        self.stb = Signal(reset_less=True)
        self.ack = Signal(reset=0, reset_less=True)
        self.out_norm = Signal(reset_less=True)
        self.in_accept = Signal(reset_less=True)
        self.temp_z = FPNumBase(width)
        self.temp_of = Overflow()
        self.out_z = FPNumBase(width)
        self.out_roundz = Signal(reset_less=True)

    def setup(self, m, in_z, in_of, norm_stb):
        """ links module to inputs and outputs
        """
        m.submodules.normalise_1 = self.mod

        m.d.comb += [
            # fresh input from the previous stage
            self.mod.in_z.copy(in_z),
            self.mod.in_of.copy(in_of),
            # feedback path: previously latched intermediate result
            self.mod.in_select.eq(self.in_accept),
            self.mod.temp_z.copy(self.temp_z),
            self.mod.temp_of.copy(self.temp_of),
            # module outputs
            self.out_z.copy(self.mod.out_z),
            self.out_norm.eq(self.mod.out_norm),
            self.stb.eq(norm_stb),
        ]
        # sets to zero when not in normalise_1 state
        m.d.sync += self.ack.eq(0)

    def action(self, m):
        # accept the fresh input only when strobed and not yet acked
        m.d.comb += self.in_accept.eq((~self.ack) & (self.stb))
        m.d.sync += [
            self.temp_of.copy(self.mod.out_of),
            self.temp_z.copy(self.out_z),
        ]
        with m.If(self.out_norm):
            # still normalising: ack goes high exactly when the input
            # was accepted this cycle, low otherwise
            m.d.sync += self.ack.eq(self.in_accept)
        with m.Else():
            # normalisation not required (or done).
            m.next = "round"
            m.d.sync += [
                self.ack.eq(1),
                self.out_roundz.eq(self.mod.out_of.roundz),
            ]
794
795
class FPRoundMod:
    """ combinatorial rounding: out_z is a copy of in_z unless in_roundz
        is set, in which case the mantissa is incremented (and the
        exponent too, when the mantissa was all 1s).
    """

    def __init__(self, width):
        self.in_roundz = Signal(reset_less=True)
        self.in_z = FPNumBase(width, False)
        self.out_z = FPNumBase(width, False)

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z

        m.d.comb += z_out.copy(z_in)
        with m.If(self.in_roundz):
            # mantissa rounds up
            m.d.comb += z_out.m.eq(z_in.m + 1)
            # all 1s: exponent up
            with m.If(z_in.m == z_in.m1s):
                m.d.comb += z_out.e.eq(z_in.e + 1)
        return m
811
812
class FPRound(FPState):
    """ state "round": latches the output of FPRoundMod and moves on to
        "corrections".
    """

    def __init__(self, width):
        super().__init__("round")
        self.mod = FPRoundMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z, roundz):
        """ links module to inputs and outputs
        """
        m.submodules.roundz = self.mod

        m.d.comb += [
            self.mod.in_z.copy(in_z),
            self.mod.in_roundz.eq(roundz),
        ]

    def action(self, m):
        m.next = "corrections"
        m.d.sync += self.out_z.copy(self.mod.out_z)
831
832
class FPCorrectionsMod:
    """ combinatorial exponent correction: out_z is a copy of in_z, with
        the exponent replaced by N127 when in_z is denormalised.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z

        m.submodules.corr_in_z = z_in
        m.submodules.corr_out_z = z_out

        m.d.comb += z_out.copy(z_in)
        with m.If(z_in.is_denormalised):
            m.d.comb += z_out.e.eq(z_in.N127)

        #        with m.If(self.in_z.is_overflowed):
        #            m.d.comb += self.out_z.inf(self.in_z.s)
        #        with m.Else():
        #            m.d.comb += self.out_z.create(self.in_z.s, self.in_z.e, self.in_z.m)
        return m
852
853
class FPCorrections(FPState):
    """ state "corrections": latches the output of FPCorrectionsMod and
        moves on to "pack".
    """

    def __init__(self, width):
        super().__init__("corrections")
        self.mod = FPCorrectionsMod(width)
        self.out_z = FPNumBase(width)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.corrections = self.mod
        m.d.comb += self.mod.in_z.copy(in_z)

    def action(self, m):
        m.next = "pack"
        m.d.sync += self.out_z.copy(self.mod.out_z)
870
871
class FPPackMod:
    """ combinatorial packing of the result: infinity (with the sign of
        in_z) when in_z has overflowed, otherwise created from the sign,
        exponent and mantissa of in_z.
    """

    def __init__(self, width):
        self.in_z = FPNumOut(width, False)
        self.out_z = FPNumOut(width, False)

    def elaborate(self, platform):
        m = Module()
        z_in, z_out = self.in_z, self.out_z

        m.submodules.pack_in_z = z_in
        with m.If(z_in.is_overflowed):
            m.d.comb += z_out.inf(z_in.s)
        with m.Else():
            m.d.comb += z_out.create(z_in.s, z_in.e, z_in.m)
        return m
886
887
class FPPack(FPState):
    """ state "pack": latches the packed value produced by FPPackMod and
        moves on to "pack_put_z".
    """

    def __init__(self, width):
        super().__init__("pack")
        self.mod = FPPackMod(width)
        self.out_z = FPNumOut(width, False)

    def setup(self, m, in_z):
        """ links module to inputs and outputs
        """
        m.submodules.pack = self.mod
        m.d.comb += self.mod.in_z.copy(in_z)

    def action(self, m):
        m.next = "pack_put_z"
        m.d.sync += self.out_z.v.eq(self.mod.out_z.v)
904
905
class FPPutZ(FPState):
    """ output state: presents in_z on out_z and holds the strobe high
        until it is acknowledged, then returns to "get_a".
    """

    def __init__(self, state, in_z, out_z):
        super().__init__(state)
        self.in_z = in_z
        self.out_z = out_z

    def action(self, m):
        result = self.out_z
        m.d.sync += result.v.eq(self.in_z.v)
        with m.If(result.stb & result.ack):
            # handshake complete: drop the strobe, start over
            m.d.sync += result.stb.eq(0)
            m.next = "get_a"
        with m.Else():
            m.d.sync += result.stb.eq(1)
922
923
class FPADD:
    """ IEEE754 floating-point adder, built as a state machine.

        width:        bit-width of the operands (in_a, in_b) and of the
                      result (out_z)
        single_cycle: selects the single-cycle alignment state instead of
                      the multi-cycle one
    """

    def __init__(self, width, single_cycle=False):
        self.width = width
        self.single_cycle = single_cycle

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

        self.states = []

    def add_state(self, state):
        """ registers a state for inclusion in the FSM and returns it
        """
        self.states.append(state)
        return state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()

        # start from an empty list: every state is (re)created below, so
        # without this a second call would register each state twice and
        # the FSM would see duplicate state names.
        self.states = []

        m.submodules.in_a = self.in_a
        m.submodules.in_b = self.in_b
        m.submodules.out_z = self.out_z

        geta = self.add_state(FPGetOp("get_a", "get_b",
                                      self.in_a, self.width))
        geta.setup(m, self.in_a)
        a = geta.out_op

        getb = self.add_state(FPGetOp("get_b", "special_cases",
                                      self.in_b, self.width))
        getb.setup(m, self.in_b)
        b = getb.out_op

        sc = self.add_state(FPAddSpecialCases(self.width))
        sc.mod.setup(m, a, b, sc.out_z, sc.out_do_z)
        m.submodules.specialcases = sc.mod

        dn = self.add_state(FPAddDeNorm(self.width))
        dn.setup(m, a, b)

        # both alignment variants share the same interface
        if self.single_cycle:
            align_cls = FPAddAlignSingle
        else:
            align_cls = FPAddAlignMulti
        alm = self.add_state(align_cls(self.width))
        alm.setup(m, dn.out_a, dn.out_b)

        add0 = self.add_state(FPAddStage0(self.width))
        add0.setup(m, alm.out_a, alm.out_b)

        add1 = self.add_state(FPAddStage1(self.width))
        add1.setup(m, add0.out_tot, add0.out_z)

        # NOTE(review): single_cycle is not passed on here, so FPNorm1
        # always uses its own default -- confirm this is intended.
        n1 = self.add_state(FPNorm1(self.width))
        n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb)

        rn = self.add_state(FPRound(self.width))
        rn.setup(m, n1.out_z, n1.out_roundz)

        cor = self.add_state(FPCorrections(self.width))
        cor.setup(m, rn.out_z)

        pa = self.add_state(FPPack(self.width))
        pa.setup(m, cor.out_z)

        # two ways out: the normal (packed) result, and the special-case
        # result which bypasses everything after "special_cases"
        self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z))
        self.add_state(FPPutZ("put_z", sc.out_z, self.out_z))

        with m.FSM():
            for state in self.states:
                with m.State(state.state_from):
                    state.action(m)

        return m
1002
1003
if __name__ == "__main__":
    # single-precision adder using the single-cycle alignment state
    alu = FPADD(width=32, single_cycle=True)
    ports = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=ports)
1007
1008
    # calling verilog.convert directly also works, but don't use it:
    # just run "python fname.py convert -t v" instead.
    #print (verilog.convert(alu, ports=alu.in_a.ports() + \
    #                                  alu.in_b.ports() + \
    #                                  alu.out_z.ports()))