convert to use PrevControl and NextControl instead of Trigger class
[ieee754fpu.git] / src / add / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from singlepipe import PrevControl, NextControl
11 from pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Combinatorial right-shifter

        i: input value (width bits)
        s: shift amount (floor(log2(width)) bits)
        o: output, driven with i >> s (width bits)
    """

    def __init__(self, width):
        self.width = width
        # exact integer floor(log2(width)).  the previous form,
        # int(log(width) / log(2)), goes through floating point and
        # can be off by one when the quotient rounds the wrong way.
        self.smax = width.bit_length() - 1
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ returns a Module which combinatorially assigns i >> s to o
        """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shifter.

        lshift and rshift apply the operand's own << and >> operators
        and then truncate the result back to the length of the operand,
        so the result is never wider than the input.

        An earlier implementation built the shift from a cascade of
        Muxes (one per bit of the shift amount).  That code sat after
        the return statements, could never execute, and has been
        removed.
    """

    def __init__(self, width):
        self.width = width
        # exact integer floor(log2(width)): avoids the floating-point
        # rounding that int(log(width) / log(2)) is exposed to
        self.smax = width.bit_length() - 1

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op)
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op)
        """
        res = op >> s
        return res[:len(op)]
61
62
class FPNumBase:
    """ Floating-point Base Number Class

        Holds the sign (s), exponent (e) and mantissa (m) signals of a
        number whose packed form is "width" bits (16, 32 or 64), plus
        constants and status flags derived from e and m.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        # per-format field sizes.  mantissa: 1 extra bit (overflow).
        # exponent: 2 extra bits (overflow)
        self.rmw = {16: 11, 32: 24, 64: 53}[width]  # real mantissa width
        e_width = {16: 7, 32: 10, 64: 13}[width]
        e_max = 1 << (e_width - 3)
        self.e_max = e_max

        # mantissa extra bits (top,guard,round), if requested
        self.m_extra = 3 if m_extra else 0
        m_width = self.rmw + self.m_extra
        self.m_width = m_width
        self.e_width = e_width

        # bit positions of the exponent field in the packed value
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3  # for decoding

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        # Exponent: IEEE754exp+2 bits, signed
        self.e = Signal((e_width, True), reset_less=True)
        self.s = Signal(reset_less=True)  # Sign bit

        # constants for comparison against mantissa / exponent
        self.mzero = Const(0, (m_width, False))
        self.msb1 = Const(1 << (self.m_width - 2), (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max - 1, (e_width, True))
        self.N127 = Const(-(e_max - 1), (e_width, True))
        self.N126 = Const(-(e_max - 2), (e_width, True))

        # status flags, all driven combinatorially by elaborate()
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module in which every status flag is assigned
            combinatorially from the exponent and mantissa
        """
        m = Module()
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == self.P128),
            self.exp_sub_n126.eq(self.e - self.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > self.P127),
            self.exp_n127.eq(self.e == self.N127),
            self.exp_n126.eq(self.e == self.N126),
            self.m_zero.eq(self.m == self.mzero),
            self.m_msbzero.eq(self.m[self.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent equals P128 and mantissa is non-zero
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent equals P128 and mantissa is zero
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent equals N127 and mantissa is zero
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent is greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent equals N126 and mantissa bit e_start is clear
        return self.exp_n126 & self.m_msbzero

    def eq(self, inp):
        """ returns assignments copying sign, exponent and mantissa
            from inp
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
150
151
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Adds encoding functions to FPNumBase: packing of a sign /
        exponent / mantissa triple into the value v, and creation of
        zero, NaN and inf (all signed).

        create, nan, inf and zero return assignments to slices of v.
        create2, nan2, inf2 and zero2 return the packed value itself as
        an expression.
    """
    def __init__(self, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)

    def elaborate(self, platform):
        # nothing beyond the flags set up by the base class
        return FPNumBase.elaborate(self, platform)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias (P127) is added here, to the exponent
        """
        sign_field = self.v[-1]
        exp_field = self.v[self.e_start:self.e_end]
        mantissa_field = self.v[0:self.e_start]
        return [
            sign_field.eq(s),
            exp_field.eq(e + self.P127),
            mantissa_field.eq(m),
        ]

    def nan(self, s):
        """ NaN: exponent P128, one mantissa bit (e_start-1) set
        """
        quiet_bit = 1 << (self.e_start - 1)
        return self.create(s, self.P128, quiet_bit)

    def inf(self, s):
        """ infinity: exponent P128, mantissa zero
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ zero: exponent N127, mantissa zero
        """
        return self.create(s, self.N127, 0)

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa, returning
            it as a concatenation (mantissa lowest, sign highest)

            bias (P127) is added here, to the exponent
        """
        biased = e + self.P127
        exp_bits = self.e_end - self.e_start
        return Cat(m[0:self.e_start], biased[0:exp_bits], s)

    def nan2(self, s):
        """ as nan(), but built with create2 and the msb1 constant
        """
        return self.create2(s, self.P128, self.msb1)

    def inf2(self, s):
        """ as inf(), but built with create2
        """
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        """ as zero(), but built with create2
        """
        return self.create2(s, self.N127, self.mzero)
210
211
class MultiShiftRMerge:
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp:  value to shift (width bits)
        diff: shift amount (s_max bits)
        m:    result (width bits)

        s_max defaults to floor(log2(width)) when not given.
    """
    def __init__(self, width, s_max=None):
        if s_max is None:
            # exact integer floor(log2(width)): avoids the rounding
            # that int(log(width) / log(2)) is exposed to
            s_max = width.bit_length() - 1
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        """ returns a Module computing m from inp and diff
        """
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        # NOTE(review): mw is only len(diff) bits wide.  with the default
        # s_max and a width that is not a power of two, width-1 may not
        # fit - confirm that callers pass a large enough s_max.
        mw = Const(self.width-1, len(self.diff))
        # clamp the shift amount to mw, and compute its complement
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                    ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
253
254
class FPNumShift(FPNumBase):
    """ Floating-point Number Class for shifting

        mainm: module whose "align" FSM state the shift is attached to
        op:    number that sign / exponent / mantissa are copied from
        inv:   number whose exponent this one is compared against
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # copy the operand.  (was a bare "op", which is not in scope
        # here and raised NameError: the operand is stored on self)
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e and self.m are assigned combinatorially
        # above and synchronously below - confirm this is intended.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp defaults to self (elaborate calls this with no argument)
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed from a batch of 1s shifted right
            by (maximum - shift amount), used as a mask to get the LSBs
            of the mantissa.  those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.m1s[1:], maxsleni)

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
325
326
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoding)

        Continuously (combinatorially) decodes the value held in v
        into sign / exponent / mantissa, on top of the flags provided
        by FPNumBase.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        decoded = self.decode(self.v)
        m.d.comb += decoded
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias (P127) is subtracted here, from the exponent.  the
            mantissa is padded at the low end with m_extra zero bits.
        """
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]
363
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

        Adds to FPNumBase the decoding of a packed value into
        sign / exponent / mantissa, and the mantissa shift operations
        (with matching exponent adjustment) used for alignment.
    """
    def __init__(self, op, width, m_extra=True):
        FPNumBase.__init__(self, width, m_extra)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes the latched value v into sign / exponent / mantissa,
            returned as attributes (m, e, s) of a new ObjectProxy

            bias (P127) is subtracted here, from the exponent.  the
            mantissa is padded at the low end with m_extra zero bits.
        """
        val = self.v
        fields = [0] * self.m_extra + [val[0:self.e_start]]
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*fields)                               # mantissa
        res.e = val[self.e_start:self.e_end] - self.P127   # exponent
        res.s = val[-1]                                    # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa,
            returned as assignments to this object's m, e and s

            bias (P127) is subtracted here, from the exponent.  the
            mantissa is padded at the low end with m_extra zero bits.
        """
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.P127
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(v[-1]),
               ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the bit shifted out is OR'd into the new bit 0 (the "sticky"
            bit), and a zero is shifted in at the top.
        """
        sticky = inp.m[0] | inp.m[1]
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(sticky, inp.m[2:], 0)),
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            inp is the number to shift, and defaults to self.

            the shift amount is limited to m_width-1: no point doing
            more (the result will still be zero).

            the sticky bit is computed from a batch of 1s shifted right
            by (limit - shift amount), used as a mask on the mantissa.
            the masked bits are OR'd together with the existing bit 0.
        """
        if inp is None:
            inp = self
        shifter = MultiShift(self.width)
        limit = Const(self.m_width-1, len(diff))
        amount = Mux(diff > limit, limit, diff)
        shifted = shifter.rshift(inp.m[1:], amount)
        lost_mask = shifter.rshift(self.m1s[1:], limit - amount)

        stickybit = (inp.m[1:] & lost_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, shifted)),
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is limited to m_width.
        """
        shifter = MultiShift(self.width)
        limit = Const(self.m_width, len(diff))
        amount = Mux(diff > limit, limit, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(shifter.lshift(self.m, amount)),
               ]
461
class Trigger:
    """ strobe / acknowledge pair, with a "trigger" output that is
        the AND of the two
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module driving trigger with stb & ack
        """
        m = Module()
        both = self.stb & self.ack
        m.d.comb += self.trigger.eq(both)
        return m

    def eq(self, inp):
        """ returns assignments copying stb and ack from inp
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ returns the list [stb, ack]
        """
        return [self.stb, self.ack]
481
482
class FPOpIn(PrevControl):
    """ input operand: a PrevControl carrying a value signal v of the
        given width (also available as i_data)
    """
    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width
        self.v = Signal(width)
        self.i_data = self.v

    def chain_inv(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledgement sent back INVERTED

            extra, if given, is AND'd with in_op's strobe
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledgement sent back unmodified

            extra, if given, is AND'd with in_op's strobe
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
507
508
class FPOpOut(NextControl):
    """ output operand: a NextControl carrying a value signal v of the
        given width (also available as o_data)
    """
    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width
        self.v = Signal(width)
        self.o_data = self.v

    def chain_inv(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledgement sent back INVERTED

            extra, if given, is AND'd with in_op's strobe
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments connecting in_op to this operand, with
            the acknowledgement sent back unmodified

            extra, if given, is AND'd with in_op's strobe
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
533
534
class Overflow:
    """ guard / round / sticky bits plus the mantissa's bit zero, and
        the "roundz" flag computed from them
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def eq(self, inp):
        """ returns assignments copying guard, round_bit, sticky and m0
            from inp (roundz is not copied)
        """
        fields = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, f).eq(getattr(inp, f)) for f in fields]

    def elaborate(self, platform):
        """ returns a Module driving roundz: set when guard is set
            together with any of round_bit, sticky or m0
        """
        m = Module()
        any_lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & any_lower)
        return m
555
556
class FPBase:
    """ IEEE754 Floating Point Base Class

        a collection of FSM building blocks for FP manipulation:
        fetching operands, normalisation, denormalisation, rounding,
        packing and output.  every method adds statements to the
        module "m" passed in; most also select the next FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ decodes operand v and moves to next_state when both
            op.o_ready and op.i_valid_test are set.

            returns [decoded result, ack], where ack is a new signal
            driven to ZERO when that condition holds and to 1 otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        accepted = (op.o_ready) & (op.i_valid_test)
        with m.If(accepted):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.

            when the exponent equals N127 it is replaced with N126 and
            the mantissa is left alone.  otherwise the top mantissa bit
            (the implicit 1) is set.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # limit a exponent
            m.d.sync += a.e.eq(a.N126)
        with m.Else():
            # set top mantissa bit
            top_bit = a.m[-1]
            m.d.sync += top_bit.eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the top bit of the mantissa is set
        """
        with m.If((op.m[-1] == 0)):  # top mantissa bit still clear
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation: shifts the mantissa UP while its
            top bit is clear and the exponent is still above N126

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        needs_shift = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(needs_shift):
            # order matters: the z.m[0] assignment must come after the
            # whole-mantissa assignment
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation: shifts the mantissa DOWN while
            the exponent is below N126

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        too_small = z.e < z.N126
        with m.If(too_small):
            m.d.sync += [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented; if the
            mantissa was all 1s the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += [z.m.eq(z.m + 1)]       # mantissa rounds up
            with m.If(z.m == z.m1s):            # all 1s
                m.d.sync += [z.e.eq(z.e + 1)]   # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state; a denormalised result has its
            exponent set to N127.
        """
        m.next = next_state
        with m.If(z.is_denormalised):
            m.d.sync += [z.e.eq(z.N127)]

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state.
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # overflow occurred: return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises o_valid and
            waits for both o_valid and i_ready_test to be set before
            moving to the next state.  o_valid is reset back to zero
            when that occurs.
        """
        m.d.sync += out_z.v.eq(z.v)
        handshake_done = out_z.o_valid & out_z.i_ready_test
        with m.If(handshake_done):
            m.d.sync += out_z.o_valid.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.o_valid.eq(1)
687
688
class FPState(FPBase):
    """ an FPBase which remembers the state it came from (state_from)
        and exposes its inputs and outputs as attributes
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores the inputs mapping, and sets each of its entries as
            an attribute of the same name
        """
        self.inputs = inputs
        for name, value in inputs.items():
            setattr(self, name, value)

    def set_outputs(self, outputs):
        """ stores the outputs mapping, and sets each of its entries as
            an attribute of the same name
        """
        self.outputs = outputs
        for name, value in outputs.items():
            setattr(self, name, value)
702
703
class FPID:
    """ optional identifier carried alongside a value

        id_wid: width of the identifier.  when it is zero or None no
                signals are created, and in_mid / out_mid are None.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a synchronous copy of in_mid to out_mid, when the
            identifier exists
        """
        # same test as __init__.  the previous "is not None" test let
        # id_wid == 0 through, where in_mid / out_mid are None.
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
717
718