move add to ieee754 directory
[ieee754fpu.git] / src / ieee754 / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from singlepipe import PrevControl, NextControl
11 from pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Single-cycle variable right-shifter: ``o = i >> s``.

    width: bit-width of the ``i`` input and ``o`` output signals.  The
    shift-amount input ``s`` is ``smax`` bits wide, which is exactly
    enough to express every shift from 0 to width-1.
    """

    def __init__(self, width):
        self.width = width
        # bits needed to encode shifts 0..width-1.  for power-of-two
        # widths this equals log2(width), as before; the previous
        # int(log(width)/log(2)) rounded down for other widths and could
        # not express the larger shifts (e.g. width 24 -> 4 bits -> max 15).
        self.smax = (width - 1).bit_length()
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shift helper.

    ``lshift`` / ``rshift`` return ``op`` shifted by the (possibly
    variable) amount ``s``, truncated back to ``len(op)`` bits.  Both use
    the native shift operators; vacated bits are zero-filled (logical
    shift).

    width: nominal operand width.  ``smax`` is the number of shift-amount
    bits needed to express every shift from 0 to width-1.
    """

    def __init__(self, width):
        self.width = width
        self.smax = (width - 1).bit_length()

    def lshift(self, op, s):
        """ returns ``op << s`` truncated to ``len(op)`` bits """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns ``op >> s`` truncated to ``len(op)`` bits """
        res = op >> s
        return res[:len(op)]
61
62
class FPNumBase: #(Elaboratable):
    """ Floating-point Base Number Class

    Holds an IEEE754 number as a latched raw value (``v``) plus decoded
    sign (``s``), signed exponent (``e``) and mantissa (``m``) signals.
    It also holds the constants and the combinatorial status flags that
    ``elaborate()`` drives from them.

    width: IEEE754 format width; only 16, 32 and 64 are supported
           (any other value raises KeyError).
    m_extra: when true, three extra low-order bits are added to the
             mantissa width.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # (mantissa width incl. 1 extra overflow bit,
        #  exponent width incl. 2 extra overflow bits)
        m_width, e_width = {16: (11, 7),
                            32: (24, 10),
                            64: (53, 13)}[width]
        e_max = 1 << (e_width - 3)
        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = e_max
        # optional extra low-order mantissa bits
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3  # for decoding

        # raw and decoded value
        self.v = Signal(width, reset_less=True)            # latched copy of value
        self.m = Signal(m_width, reset_less=True)          # mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exponent: IEEE754exp+2 bits, signed
        self.s = Signal(reset_less=True)                   # sign bit

        # mantissa constants: zero, bit (m_width-2) set, all ones
        mshape = (m_width, False)
        self.mzero = Const(0, mshape)
        self.msb1 = Const(1 << (self.m_width - 2), mshape)
        self.m1s = Const(-1, mshape)

        # exponent constants (named after their single-precision values)
        eshape = (e_width, True)
        self.P128 = Const(e_max, eshape)
        self.P127 = Const(e_max - 1, eshape)
        self.N127 = Const(-(e_max - 1), eshape)
        self.N126 = Const(-(e_max - 2), eshape)

        # status flags, driven combinatorially by elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal(eshape, reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ builds a Module that drives all the status flags combinatorially """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # maximum exponent with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # minimum (N127) exponent with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return self.exp_n126 & self.m_msbzero

    def __iter__(self):
        """ yields the sign, exponent and mantissa signals, in that order """
        yield from (self.s, self.e, self.m)

    def eq(self, inp):
        """ returns assignments copying s, e and m from ``inp`` """
        return [
            self.s.eq(inp.s),
            self.e.eq(inp.e),
            self.m.eq(inp.m),
        ]
155
156
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        return m

    def create(self, s, e, m):
        """ returns assignments writing sign / exponent / mantissa into v

        bias is added here, to the exponent
        """
        return [
            self.v[-1].eq(s),                                   # sign
            self.v[self.e_start:self.e_end].eq(e + self.P127),  # exp (add on bias)
            self.v[0:self.e_start].eq(m)                        # mantissa
        ]

    def nan(self, s):
        return self.create(s, self.P128, 1<<(self.e_start-1))

    def inf(self, s):
        return self.create(s, self.P128, 0)

    def zero(self, s):
        return self.create(s, self.N127, 0)

    def create2(self, s, e, m):
        """ returns a single Cat() value built from sign / exponent / mantissa

        bias is added here, to the exponent.  only the low e_start bits
        of m are used (the stored mantissa field).
        """
        e = e + self.P127 # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ returns a NaN as a Cat() value

        uses the same mantissa bit as nan(): the top bit of the stored
        mantissa field (bit e_start-1).  This previously used msb1
        (bit m_width-2).  That bit is inside the field only when
        m_extra is false; with m_extra it was sliced away by create2()
        and an infinity was produced instead of a NaN.
        """
        nan_m = Const(1 << (self.e_start - 1), self.e_start)
        return self.create2(s, self.P128, nan_m)

    def inf2(self, s):
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        return self.create2(s, self.N127, self.mzero)
215
216
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
    m[0] is the "sticky" bit, basically

    inputs:  inp  -- inp[0] is an incoming sticky bit, inp[1:] the value
             diff -- shift amount (clamped to width-1)
    output:  m    -- Cat(sticky, inp[1:] >> diff), where sticky is the OR
                     of inp[0] and every bit of inp[1:] shifted out

    width: width of inp and m.
    s_max: width of diff.  Defaults to the number of bits needed to
           express a shift of width-1.  This matches the old log2-based
           default for power-of-two widths; the old default was too
           narrow for other widths.
    """
    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = (width - 1).bit_length()
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # shifts are clamped to width-1: beyond that everything has been
        # shifted out and is already captured by the sticky bit.  The clamp
        # constant and the internal shift amounts must be able to hold
        # width-1.  Previously mw was Const(width-1, len(diff)), which was
        # silently truncated whenever diff was too narrow.
        mw_val = self.width - 1
        slen = max(1, mw_val.bit_length())

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(slen, reset_less=True)
        maxsleni = Signal(slen, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(mw_val, max(len(self.diff), slen))
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                    ]

        m.d.comb += [
                # shift mantissa by maxslen, mask by inverse
                rs.eq(sm.rshift(self.inp[1:], maxslen)),
                # all-ones shifted by the complementary amount leaves
                # exactly the bits that are shifted out of the mantissa
                m_mask.eq(sm.rshift(~m0s, maxsleni)),
                smask.eq(self.inp[1:] & m_mask),
                # sticky bit combines all mask (and mantissa low bit)
                stickybit.eq(smask.bool() | self.inp[0]),
                # mantissa result contains m[0] already.
                self.m.eq(Cat(stickybit, rs))
           ]
        return m
258
259
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    Copies sign / exponent / mantissa from ``op``.  In the "align" state
    of ``mainm`` it shifts itself down by one while its exponent is
    smaller than that of ``inv``.

    mainm: provides ``State("align")``.  Presumably this is the Module
           that owns the FSM -- TODO confirm against callers.
    op: number whose s / e / m are copied into this one
    inv: the other operand, whose exponent this one is aligned to
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # copy the operand in.  (This previously referenced an undefined
        # bare name ``op`` and raised NameError.)
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down() requires the number to shift; it was
                # previously called with no argument (TypeError).
                # NOTE(review): self.e / self.m are driven from comb above
                # and from sync here.  nmigen rejects a signal driven from
                # two domains, so the copy-in presumably needs to become a
                # latched (sync) load, perhaps gated by latch_in.
                # TODO confirm the intended design.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s right by
        the complementary amount (mw - shift).  This leaves a mask of
        exactly the LSBs that are shifted out of the mantissa.  Those
        masked bits are OR'd (together with m[0]) into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni) # mask of shifted-out bits

        # .bool() ORs all mask bits together (same as FPNumIn)
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
330
331
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (combinatorial decoder)

    elaborate() decodes the raw value ``v`` into sign / exponent /
    mantissa combinatorially.  Encoding helpers and zero / NaN / inf
    recognition come from FPNumBase.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    op: stored as ``self.op``; not otherwise used by this class.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ returns assignments decoding ``v`` into m, e and s

        the mantissa field is placed above m_extra zero bits.  The bias
        (P127) is subtracted from the exponent field at the (wider,
        signed) width of ``e``.
        """
        padding = [0] * self.m_extra
        fraction = v[0:self.e_start]
        biased_exp = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*padding, fraction)),    # mantissa
                self.e.eq(biased_exp - self.P127),     # exp
                self.s.eq(v[-1]),                      # sign
               ]
368
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.  It also provides single-step and
    multi-step mantissa shifting.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def _split(self, v):
        """ returns (mantissa, exponent, sign) expressions decoded from v.

        the mantissa field is placed above m_extra zero bits; the bias
        (P127) is subtracted from the exponent field.
        """
        mant = Cat(*([0] * self.m_extra), v[0:self.e_start])
        exp = v[self.e_start:self.e_end] - self.P127
        return mant, exp, v[-1]

    def decode2(self, m):
        """ decodes self.v into an ObjectProxy (on module m) with m, e, s """
        mant, exp, sign = self._split(self.v)
        res = ObjectProxy(m, pipemode=False)
        res.m = mant   # mantissa
        res.e = exp    # exp
        res.s = sign   # sign
        return res

    def decode(self, v):
        """ returns assignments decoding ``v`` into self.m, self.e, self.s """
        mant, exp, sign = self._split(v)
        return [self.m.eq(mant),   # mantissa
                self.e.eq(exp),    # exp
                self.s.eq(sign),   # sign
               ]

    def shift_down(self, inp):
        """ returns assignments setting self to inp shifted down by one

        the exponent is increased by one to compensate.  The two lowest
        mantissa bits are OR'd into the new m[0] (the sticky bit), and a
        zero is shifted into the top.
        """
        sticky = inp.m[0] | inp.m[1]
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(sticky, inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ returns assignments setting self to inp (default: self) shifted
        down by ``diff``; the exponent is increased by diff to compensate.

        the shift is clamped to m_width-1: anything larger leaves only
        zeros, so there is no point shifting further.

        the sticky bit is computed by shifting a batch of 1s right by the
        complementary amount (mw - shift).  This leaves a mask of exactly
        the mantissa LSBs that are shifted out.  Those masked bits are
        OR'd (together with the old m[0]) into the new sticky bit.
        """
        src = self if inp is None else inp
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        amount = Mux(diff > mw, mw, diff)
        shifted = sm.rshift(src.m[1:], amount)
        mask = sm.rshift(self.m1s[1:], mw - amount)
        stickybit = (src.m[1:] & mask).bool() | src.m[0]
        return [self.e.eq(src.e + diff),
                self.m.eq(Cat(stickybit, shifted))
               ]

    def shift_up_multi(self, diff):
        """ returns assignments shifting self.m up by ``diff`` (clamped to
        m_width); the exponent is decreased by diff to compensate.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        amount = Mux(diff > mw, mw, diff)
        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, amount))
               ]
466
class Trigger(Elaboratable):
    """ stb / ack handshake pair with a combinatorial ``trigger`` output
    that is high whenever both stb and ack are high.
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both = self.stb & self.ack
        m.d.comb += self.trigger.eq(both)
        return m

    def eq(self, inp):
        """ returns assignments copying stb and ack from ``inp`` """
        return [
            self.stb.eq(inp.stb),
            self.ack.eq(inp.ack),
        ]

    def ports(self):
        return [self.stb, self.ack]
486
487
class FPOpIn(PrevControl):
    """ input-side operand port: a PrevControl whose ``data_i`` is
    exposed as ``v``.
    """
    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def _chain_op(self, in_op, extra, ack_value):
        """ shared body of chain_inv / chain_from.

        copies in_op's value, takes its stb (optionally AND'd with
        ``extra``) and drives ``ack_value`` back onto in_op.ack.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),        # receive value
                self.stb.eq(stb),          # receive STB
                in_op.ack.eq(ack_value),   # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, but sends back the inverse of self.ack """
        return self._chain_op(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        return self._chain_op(in_op, extra, self.ack)
514
515
class FPOpOut(NextControl):
    """ output-side operand port: a NextControl whose ``data_o`` is
    exposed as ``v``.
    """
    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def _chain_op(self, in_op, extra, ack_value):
        """ shared body of chain_inv / chain_from.

        copies in_op's value, takes its stb (optionally AND'd with
        ``extra``) and drives ``ack_value`` back onto in_op.ack.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),        # receive value
                self.stb.eq(stb),          # receive STB
                in_op.ack.eq(ack_value),   # send ACK
               ]

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, but sends back the inverse of self.ack """
        return self._chain_op(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        return self._chain_op(in_op, extra, self.ack)
542
543
class Overflow: #(Elaboratable):
    """ guard / round / sticky bits plus the mantissa LSB, and the
    combinatorial round-up decision derived from them.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ returns assignments copying guard/round/sticky/m0 from ``inp`` """
        return [
            self.guard.eq(inp.guard),
            self.round_bit.eq(inp.round_bit),
            self.sticky.eq(inp.sticky),
            self.m0.eq(inp.m0),
        ]

    def elaborate(self, platform):
        m = Module()
        # round up when guard is set and any of round / sticky / LSB is set
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
570
571
class FPBase:
    """ IEEE754 Floating Point Base Class

    common helpers for building FP FSMs: operand fetch, (de)normalisation,
    rounding, packing and output.  Every helper adds statements to the
    module ``m`` passed in; ``m.next`` selects the FSM's next state.
    """

    def get_op(self, m, op, v, next_state):
        """ fetches an operand.

        moves to ``next_state`` once op.ready_o and op.valid_i_test are
        both set.  returns [res, ack]: ``res`` is ``v.decode2(m)`` and
        ``ack`` is a new Signal, driven 0 in the transfer cycle and 1
        otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        transfer = op.ready_o & op.valid_i_test
        with m.If(transfer):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
        this function.

        if a's exponent is the minimum (exp_n127), the exponent is set to
        N126 and the mantissa is left alone; otherwise the top mantissa
        bit (the implicit 1) is set.  Both cases *effectively multiply
        the stored number by 2*, which has to be taken into account when
        extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)   # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

        while the top mantissa bit is clear, the mantissa is shifted up
        and the exponent decreased, once per clock; when it is set the
        FSM moves to ``next_state``.
        """
        top_clear = op.m[-1] == 0
        with m.If(top_clear):
            m.d.sync += [
                op.e.eq(op.e - 1),   # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

        while z's top mantissa bit is clear and its exponent is above
        N126, shift up by one per clock; the guard bit moves into m[0],
        round moves into guard.  Otherwise move to ``next_state``.
        """
        shift_up = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(shift_up):
            # order matters: z.m[0] is assigned after the whole-mantissa
            # shift so that it overrides bit 0 of the shifted value
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

        while z's exponent is below N126, shift down by one per clock;
        bits falling out of the mantissa move into guard / round and are
        accumulated into sticky.  Otherwise move to ``next_state``.
        """
        below_min = z.e < z.N126
        with m.If(below_min):
            m.d.sync += [
                z.e.eq(z.e + 1),   # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

        when ``roundz`` is set the mantissa is incremented; if it was all
        1s, the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            all_ones = z.m == z.m1s
            with m.If(all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

        moves to ``next_state``; a denormalised z has its exponent set
        to N127.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.

        z.v is registered into out_z.v.  out_z.valid_o is raised and held
        until out_z.ready_i_test is also seen; then valid_o is cleared
        and the FSM moves to ``next_state``.
        """
        m.d.sync += [out_z.v.eq(z.v)]
        handshake = out_z.valid_o & out_z.ready_i_test
        with m.If(handshake):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
702
703
class FPState(FPBase):
    """ base for an FSM state: records the state it comes from and
    exposes each named input / output as an attribute of the same name.
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def _bind_all(self, mapping):
        """ sets one attribute per (name, value) pair in ``mapping`` """
        for name, value in mapping.items():
            setattr(self, name, value)

    def set_inputs(self, inputs):
        self.inputs = inputs
        self._bind_all(inputs)

    def set_outputs(self, outputs):
        self.outputs = outputs
        self._bind_all(outputs)
717
718
class FPID:
    """ optional ID (mid) carried alongside the data.

    id_wid: width of the ID signals.  When falsy (None or 0) no ID
            signals are created and in_mid / out_mid are None.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ registers in_mid into out_mid (sync domain) when IDs are in use.

        uses the same truthiness test as __init__.  It previously checked
        ``id_wid is not None``, so FPID(0), which creates no signals,
        dereferenced None here.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
732
733