ab6cc705dc5ae7e0113cd31736efca9bd0420676
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 """IEEE754 Floating Point Library
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
5
6 """
7
8
9 from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
10 Value, Shape)
11 from nmigen.utils import bits_for
12 from operator import or_
13 from functools import reduce
14
15 from nmutil.singlepipe import PrevControl, NextControl
16 from nmutil.pipeline import ObjectProxy
17 import unittest
18 import math
19 import enum
20
21 try:
22 from nmigen.hdl.smtlib2 import RoundingModeEnum
23 _HAVE_SMTLIB2 = True
24 except ImportError:
25 _HAVE_SMTLIB2 = False
26
# Sentinel used as the default of FPRoundingMode.to_smtlib2's `default`
# argument: that method checks `default is _raise_err` to tell "no default
# supplied" (raise ValueError) apart from any caller-supplied default,
# including None.
_raise_err = object()
29
30
class FPRoundingMode(enum.Enum):
    # Values 0b00..0b11 are the FPSCR.RN field encodings; the extra values
    # (>= 0b100) are only used by miscellaneous instructions.
    #
    # Member names follow smtlib2; each member's doc string gives the name
    # used by the OpenPower ISA specification (v3.1 section 7.3.2.6, whose
    # values match section 4.3.6).

    RNE = 0b00
    """Round to Nearest Even

    Rounds to the nearest representable floating-point number; ties are
    rounded to the number with the even mantissa. +-Infinity is treated as
    if it were a normalized floating-point number when deciding which
    number is closer. See the IEEE754 spec. for details.
    """

    ROUND_NEAREST_TIES_TO_EVEN = RNE
    DEFAULT = RNE

    RTZ = 0b01
    """Round towards Zero

    Returns the exact result if it is representable as a floating-point
    number, otherwise the nearest representable floating-point value whose
    magnitude is smaller than the exact answer.
    """

    ROUND_TOWARDS_ZERO = RTZ

    RTP = 0b10
    """Round towards +Infinity

    Returns the exact result if it is representable as a floating-point
    number, otherwise the nearest representable floating-point value that
    is numerically greater than the exact answer. May round up to
    +Infinity.
    """

    ROUND_TOWARDS_POSITIVE = RTP

    RTN = 0b11
    """Round towards -Infinity

    Returns the exact result if it is representable as a floating-point
    number, otherwise the nearest representable floating-point value that
    is numerically less than the exact answer. May round down to
    -Infinity.
    """

    ROUND_TOWARDS_NEGATIVE = RTN

    RNA = 0b100
    """Round to Nearest Away

    Rounds to the nearest representable floating-point number; ties are
    rounded to the number with the larger magnitude. +-Infinity is treated
    as if it were a normalized floating-point number when deciding which
    number is closer. See the IEEE754 spec. for details.
    """

    ROUND_NEAREST_TIES_TO_AWAY = RNA

    RTOP = 0b101
    """Round to Odd, unsigned zeros are Positive

    Not in smtlib2.

    Returns the exact result if it is representable as a floating-point
    number, otherwise the nearest representable floating-point value that
    has an odd mantissa.

    A zero result whose sign is otherwise undetermined
    (e.g. `1.0 - 1.0`) is positive.

    Used for instructions with Round To Odd enabled when
    `FPSCR.RN != RTN`.

    Helps avoid double-rounding errors when computing in a larger type
    (e.g. f128) when the answer is wanted in a smaller type (e.g. f80).
    """

    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE = RTOP

    RTON = 0b110
    """Round to Odd, unsigned zeros are Negative

    Not in smtlib2.

    Returns the exact result if it is representable as a floating-point
    number, otherwise the nearest representable floating-point value that
    has an odd mantissa.

    A zero result whose sign is otherwise undetermined
    (e.g. `1.0 - 1.0`) is negative.

    Used for instructions with Round To Odd enabled when
    `FPSCR.RN == RTN`.

    Helps avoid double-rounding errors when computing in a larger type
    (e.g. f128) when the answer is wanted in a smaller type (e.g. f80).
    """

    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE = RTON

    @staticmethod
    def make_array(f):
        """Build an nmigen ``Array`` indexed by rounding-mode value.

        Entry ``i`` of the result is ``f(mode)`` for the canonical mode
        whose ``.value`` is ``i``.
        """
        by_value = {mode.value: f(mode) for mode in FPRoundingMode}
        return Array([by_value[index] for index in range(len(by_value))])

    def overflow_rounds_to_inf(self, sign):
        """Tell whether an overflowing result becomes infinity.

        Returns true if an overflow should round to `inf`, false if it
        should round to `max_normal`. `sign` may be a Python value or an
        nmigen `Value`; for RTP/RTN the result is `~sign`/`sign` (or
        `not sign`/`sign`) and so has the same kind as `sign`.
        """
        if isinstance(sign, Value):
            not_sign = ~sign
        else:
            not_sign = not sign
        modes = FPRoundingMode
        if self is modes.RTP:
            return not_sign
        if self is modes.RTN:
            return sign
        if self is modes.RNE or self is modes.RNA:
            return True
        assert self in (modes.RTZ, modes.RTOP, modes.RTON)
        return False

    def underflow_rounds_to_zero(self, sign):
        """Tell whether an underflowing result becomes zero.

        Returns true if an underflow should round to `zero`, false if it
        should round to `min_denormal`. `sign` may be a Python value or an
        nmigen `Value`; for RTP/RTN the result is `sign`/`~sign` (or
        `sign`/`not sign`).
        """
        if isinstance(sign, Value):
            not_sign = ~sign
        else:
            not_sign = not sign
        modes = FPRoundingMode
        if self is modes.RTP:
            return sign
        if self is modes.RTN:
            return not_sign
        if self in (modes.RNE, modes.RTZ, modes.RNA):
            return True
        assert self in (modes.RTOP, modes.RTON)
        return False

    def zero_sign(self):
        """Sign (True = negative) given to an exact zero result whose sign
        is not otherwise determined, e.g. for `1.0 - 1.0`.

        Only RTN and RTON produce a negative zero.
        """
        return self is FPRoundingMode.RTN or self is FPRoundingMode.RTON

    if _HAVE_SMTLIB2:
        def to_smtlib2(self, default=_raise_err):
            """Return the smtlib2 rounding mode matching `self`.

            RTOP and RTON have no smtlib2 counterpart: for those, `default`
            is returned if one was given, otherwise `ValueError` is raised.
            The other modes share their name with the smtlib2 member.
            """
            if self in (FPRoundingMode.RTOP, FPRoundingMode.RTON):
                if default is _raise_err:
                    raise ValueError(
                        "no corresponding smtlib2 rounding mode", self)
                return default
            return getattr(RoundingModeEnum, self.name)
228
229
230
231
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)

    The field accessors and predicates below are written purely with
    shifts, masks and comparisons (presumably so that they accept nmigen
    Values as well as Python ints).
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width
                and self.m_width == other.m_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: unless ``width`` is 16, 32, 64, 128, or a
            multiple of 32 above 128 and no larger than 1000000.
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            # IEEE 754-2008 exponent-width formula for wide binary formats
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr. """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass  # not a standard width: fall back to the explicit form
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            retval += f", {self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
            (assumes FPFormat is set to signed - has_sign=True)
        """
        return x >> (self.e_width + self.m_width)

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        x = ((x >> self.m_width) & self.exponent_inf_nan)
        return x

    def get_exponent(self, x):
        """ returns the exponent of its input number, x
        """
        return self.get_exponent_field(x) - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def get_mantissa_value(self, x):
        """ returns the mantissa of its input number, x, but with the
        implicit bit, if any, made explicit.

        The implicit bit is 1 for normal numbers (and inf/NaN) and 0 when
        the exponent field is the denormal/zero encoding.
        """
        if self.has_int_bit:
            return self.get_mantissa_field(x)
        exponent_field = self.get_exponent_field(x)
        mantissa_field = self.get_mantissa_field(x)
        implicit_bit = exponent_field != self.exponent_denormal_zero
        return (implicit_bit << self.fraction_width) | mantissa_field

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent(x) == self.e_sub) & \
            (self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum)
        """
        return (self.get_exponent(x) == self.e_sub) & \
            (self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite
        """
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)
        """
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan
        """
        highbit = 1 << (self.m_width - 1)
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0) & \
            ((self.get_mantissa_field(x) & highbit) != 0)

    def to_quiet_nan(self, x):
        """ converts `x` to a quiet NaN """
        highbit = 1 << (self.m_width - 1)
        return x | highbit | self.exponent_mask

    def quiet_nan(self, sign=0):
        """ return the default quiet NaN with sign `sign` """
        return self.to_quiet_nan(self.zero(sign))

    def zero(self, sign=0):
        """ return zero with sign `sign` """
        return (sign != 0) << (self.e_width + self.m_width)

    def inf(self, sign=0):
        """ return infinity with sign `sign` """
        return self.zero(sign) | self.exponent_mask

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan
        """
        highbit = 1 << (self.m_width - 1)
        # the final comparison must be parenthesised: `&` binds tighter
        # than `==`, so without the parentheses the whole `&` chain would
        # be compared with 0 (which is true for almost every input).
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0) & \
            ((self.get_mantissa_field(x) & highbit) == 0)

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_mask(self):
        """ Get an exponent mask """
        return self.exponent_inf_nan << self.m_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias)
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the exponent (minus bias) of the denormal/zero encoding
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
444
445
class TestFPFormat(unittest.TestCase):
    """ Quick sanity checks of FPFormat against sfpy-generated bit patterns.
    """

    def _show(self, fmt, bits, label, flag):
        """ Print the diagnostic line used by the classification checks. """
        print(hex(bits), label, fmt.get_exponent(bits), fmt.e_max,
              fmt.get_mantissa_field(bits), flag)

    def test_fpformat_fp64(self):
        from sfpy import Float64
        fmt = FPFormat.standard(64)

        one = Float64(1.0).bits
        print(hex(one))
        self.assertEqual(fmt.get_exponent(one), 0)

        two = Float64(2.0).bits
        print(hex(two))
        self.assertEqual(fmt.get_exponent(two), 1)

        pos = Float64(1.5).bits
        mant = fmt.get_mantissa_field(pos)
        print(hex(pos), hex(mant))
        self.assertEqual(mant, 0x8000000000000)

        sign = fmt.get_sign_field(pos)
        print(hex(pos), hex(sign))
        self.assertEqual(sign, 0)

        neg = Float64(-1.5).bits
        sign = fmt.get_sign_field(neg)
        print(hex(neg), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        from sfpy import Float32
        fmt = FPFormat.standard(32)

        one = Float32(1.0).bits
        print(hex(one))
        self.assertEqual(fmt.get_exponent(one), 0)

        two = Float32(2.0).bits
        print(hex(two))
        self.assertEqual(fmt.get_exponent(two), 1)

        pos = Float32(1.5).bits
        mant = fmt.get_mantissa_field(pos)
        print(hex(pos), hex(mant))
        self.assertEqual(mant, 0x400000)

        # NaN: square root of a negative number
        nan_bits = Float32(-1.0).sqrt().bits
        flag = fmt.is_nan(nan_bits)
        self._show(fmt, nan_bits, "nan", flag)
        self.assertEqual(flag, True)

        # Inf: overflow through repeated multiplication
        big = Float32(1e36)
        inf_bits = (big * big * big).bits
        flag = fmt.is_inf(inf_bits)
        self._show(fmt, inf_bits, "inf", flag)
        self.assertEqual(flag, True)

        # subnormal
        sub_bits = Float32(1e-41).bits
        flag = fmt.is_subnormal(sub_bits)
        self._show(fmt, sub_bits, "sub", flag)
        self.assertEqual(flag, True)

        # zero is not subnormal...
        zero_bits = Float32(0.0).bits
        flag = fmt.is_subnormal(zero_bits)
        self._show(fmt, zero_bits, "sub", flag)
        self.assertEqual(flag, False)

        # ...but is zero
        flag = fmt.is_zero(zero_bits)
        self._show(fmt, zero_bits, "zero", flag)
        self.assertEqual(flag, True)
527
528
class MultiShiftR(Elaboratable):
    """ Combinatorial variable right shifter: ``o = i >> s``.

    The shift amount ``s`` is ``bits_for(width - 1)`` bits wide.
    """

    def __init__(self, width):
        self.width = width
        self.smax = bits_for(width - 1)
        self.i = Signal(width, reset_less=True)      # value to shift
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        module = Module()
        shifted = self.i >> self.s
        module.d.comb += self.o.eq(shifted)
        return module
542
543
class MultiShift:
    """ Helper for variable-length single-cycle shifts.

    ``lshift``/``rshift`` apply the operand's own ``<<``/``>>`` operator
    (so nmigen builds the shifter) and truncate the result back to the
    operand's width. ``smax`` is ``bits_for(width - 1)``, the width of a
    shift amount able to reach ``width - 1``.

    Could be adapted to do arithmetic shift by taking copies of the
    MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        self.smax = bits_for(width - 1)

    def lshift(self, op, s):
        """ Return ``op << s`` truncated to ``len(op)`` bits. """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ Return ``op >> s`` truncated to ``len(op)`` bits. """
        shifted = op >> s
        return shifted[:len(op)]
566
567
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    This class is designed to be passed around in other data structures
    (between pipelines and between stages). Its "friend" is FPNumBase,
    which is a *module*. The reason for the split is that nmigen modules
    which are not added as submodules produce the irritating
    "Elaboration" warning. Although FPNumBase does not *need* to be added
    as a submodule in many cases (because it is just data), this could not
    be solved without splitting the data out from the module.

    :param width: width of the packed IEEE754 value ``v``; only 16, 32 and
        64 are supported (any other width raises ``KeyError``).
    :param m_extra: if true, ``m`` gets 3 extra low bits.
    :param e_extra: if true, ``e`` gets 6 extra bits.
    :param name: optional prefix for the signal names (``<name>_v`` etc.).

    A record is its own ``fp`` (``self.fp is self``); methods access the
    constant table through ``self.fp`` so that they also work when
    inherited by FPNumBase, whose ``fp`` is the underlying record.
    """

    def __init__(self, width, m_extra=True, e_extra=False, name=None):
        if name is None:
            name = ""
        else:
            name += "_"
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        # bit range of the exponent field within v (for decoding); e_end
        # can run past the top of v when e_extra is set
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2

        self.v = Signal(width, reset_less=True,
                        name=name+"v")  # Latched copy of value
        self.m = Signal(m_width, reset_less=True, name=name+"m")  # Mantissa
        self.e = Signal((e_width, True),
                        reset_less=True, name=name+"e")  # exp+2 bits, signed
        self.s = Signal(reset_less=True, name=name+"s")  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Share this record's signals and geometry with ``fp``.

        Copies s/e/m/v and the width attributes onto ``fp``, then
        (re)computes the constant table (mzero, msb1, m1s, P128, P127,
        N127, N126) on *this record* -- ``fp`` does not receive the
        constants, which is why methods reach them via ``self.fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        NOTE: order is important, because e_start/e_end can be
        a bit too long (overwriting s).
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def quieted_nan(self, other):
        """ returns statements making ``v`` a quiet NaN that keeps the sign
        and payload of ``other`` (same-width record) with the top
        mantissa bit forced on.
        """
        assert isinstance(other, FPNumBaseRecord)
        assert self.width == other.width
        return self.create(other.s, self.fp.P128,
                           other.v[0:self.e_start] | (1 << (self.e_start - 1)))

    def inf(self, s):
        return self.create(*self._inf(s))

    def max_normal(self, s):
        return self.create(s, self.fp.P127, ~0)

    def min_denormal(self, s):
        return self.create(s, self.fp.N127, 1)

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent. Unlike ``create`` this
        returns a ``Cat`` expression rather than assignment statements.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
717
718
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Elaboratable view of a FPNumBaseRecord: ``__init__`` shares the
    record's s/e/m/v signals (via ``drop_in``), and ``elaborate`` derives
    the classification and exponent-comparison flags from them
    combinatorially, using the record's constants (``self.fp``).
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp

        # classification flags
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the record's constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((fp.e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        fp = self.fp
        e = self.e
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(e == fp.P128),
            self.exp_sub_n126.eq(e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(e == 0),
            self.exp_gt127.eq(e > fp.P127),
            self.exp_n127.eq(e == fp.N127),
            self.exp_n126.eq(e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            # NOTE(review): tests bit e_start of m; with m_extra low bits
            # present this may not be the mantissa MSB -- confirm intent
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return self.exp_n126 & self.m_msbzero
780
781
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    Adds nothing to FPNumBase: the shared s/e/m/v signals plus the
    create/nan/inf/zero encoding helpers and classification flags.

    The mantissa carries four extra bits: the top bit (m[-1]) is
    effectively a carry-overflow, and the other three are guard (m[2]),
    round (m[1]) and sticky (m[0]).
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
802
803
class MultiShiftRMerge(Elaboratable):
    """ Right-shifts ``inp[1:]`` by ``diff`` into ``m[1:]`` and ORs every bit
    shifted out, together with ``inp[0]``, into ``m[0]`` -- the "sticky" bit.

    Shift amounts above ``width - 1`` are clamped to ``width - 1``.
    """

    def __init__(self, width, s_max=None):
        s_max = bits_for(width - 1) if s_max is None else s_max
        self.smax = Shape.cast(s_max)
        self.m = Signal(width, reset_less=True)     # result
        self.inp = Signal(width, reset_less=True)   # value to shift
        self.diff = Signal(s_max, reset_less=True)  # shift amount
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        # clamped shift amounts are kept in their own signals as a
        # workaround for https://github.com/nmigen/nmigen/issues/302
        maxslen = Signal(self.smax.width, reset_less=True)
        maxsleni = Signal(self.smax.width, reset_less=True)

        shifter = MultiShift(self.width - 1)
        zeros = Const(0, self.width - 1)
        limit = Const(self.width - 1, len(self.diff))
        too_far = self.diff > limit
        payload = self.inp[1:]

        m.d.comb += [
            # clamped shift, and its complement for building the mask
            maxslen.eq(Mux(too_far, limit, self.diff)),
            maxsleni.eq(Mux(too_far, 0, limit - self.diff)),
            # shifted mantissa, and a mask selecting the bits shifted out
            rs.eq(shifter.rshift(payload, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(payload & m_mask),
            # any lost bit (or the incoming m[0]) sets the sticky bit
            stickybit.eq(smask.bool() | self.inp[0]),
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
847
848
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    Mirrors ``op`` combinatorially and, in the ``align`` state of
    ``mainm``, shifts itself down while its exponent is below ``inv.e``.

    NOTE(review): legacy code. ``elaborate`` drives ``self.e``/``self.m``
    from both ``d.comb`` and ``d.sync``; nmigen normally rejects driving a
    signal from two domains, so this class likely still cannot be
    elaborated as-is -- confirm before relying on it.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase.__init__ takes an FPNumBaseRecord, so build one from
        # (width, m_extra); this keeps the constructor's signature intact.
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # mirror the operand given to the constructor
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down needs the number to shift: shift self
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit): inp.m[1] is
        ORed into the new m[0].
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the complementary amount, giving a mask of the LSBs that are
        shifted out of the mantissa. those are then |'d into the
        sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (self.fp), not on self
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

        the shift is clamped to m_width bits.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
920
921
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoder)

    Splits the latched IEEE754 value ``v`` into sign / exponent / mantissa
    combinatorially (see ``decode``), on top of the classification flags
    provided by FPNumBase.

    The mantissa carries four extra bits: the top bit (m[-1]) is
    effectively a carry-overflow, and the other three are guard (m[2]),
    round (m[1]) and sticky (m[0]).
    """

    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ Return the statements splitting ``v`` into ``s`` / ``e`` / ``m``.

        The bias (``fp.P127``) is subtracted from the exponent here; ``e``
        is declared wider than the IEEE exponent field (e.g. 10 bits for
        FP32). ``m_extra`` zero bits are placed below the mantissa.
        """
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
                ]
959
960
961 class FPNumIn(FPNumBase):
962 """ Floating-point Number Class
963
964 Contains signals for an incoming copy of the value, decoded into
965 sign / exponent / mantissa.
966 Also contains encoding functions, creation and recognition of
967 zero, NaN and inf (all signed)
968
969 Four extra bits are included in the mantissa: the top bit
970 (m[-1]) is effectively a carry-overflow. The other three are
971 guard (m[2]), round (m[1]), and sticky (m[0])
972 """
973
974 def __init__(self, op, fp):
975 FPNumBase.__init__(self, fp)
976 self.latch_in = Signal()
977 self.op = op
978
979 def decode2(self, m):
980 """ decodes a latched value into sign / exponent / mantissa
981
982 bias is subtracted here, from the exponent. exponent
983 is extended to 10 bits so that subtract 127 is done on
984 a 10-bit number
985 """
986 v = self.v
987 args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
988 #print ("decode", self.e_end)
989 res = ObjectProxy(m, pipemode=False)
990 res.m = Cat(*args) # mantissa
991 res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
992 res.s = v[-1] # sign
993 return res
994
995 def decode(self, v):
996 """ decodes a latched value into sign / exponent / mantissa
997
998 bias is subtracted here, from the exponent. exponent
999 is extended to 10 bits so that subtract 127 is done on
1000 a 10-bit number
1001 """
1002 args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
1003 #print ("decode", self.e_end)
1004 return [self.m.eq(Cat(*args)), # mantissa
1005 self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
1006 self.s.eq(v[-1]), # sign
1007 ]
1008
1009 def shift_down(self, inp):
1010 """ shifts a mantissa down by one. exponent is increased to compensate
1011
1012 accuracy is lost as a result in the mantissa however there are 3
1013 guard bits (the latter of which is the "sticky" bit)
1014 """
1015 return [self.e.eq(inp.e + 1),
1016 self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
1017 ]
1018
def shift_down_multi(self, diff, inp=None):
    """ shifts a mantissa down. exponent is increased to compensate

    accuracy is lost as a result in the mantissa however there are 3
    guard bits (the latter of which is the "sticky" bit)

    :param diff: the shift amount (presumably non-negative -- TODO
        confirm with callers).  the exponent is increased by the full
        ``diff`` even when the mantissa shift itself is clamped.
    :param inp: source number providing ``.m`` and ``.e``; defaults to
        ``self``.
    :returns: list of two assignments, to ``self.e`` and ``self.m``.

    this code works by variable-shifting the mantissa (excluding the
    sticky bit ``m[0]``) by at most ``m_width-1`` bits: no point doing
    more (it'll still be zero).

    the sticky bit is computed from a mask: ``self.m1s[1:]`` is
    right-shifted by the *complementary* amount ``(m_width-1) - shift``.
    assuming ``m1s`` is an all-ones constant of the mantissa width, this
    leaves exactly ``shift`` low bits set, i.e. the bit positions that
    are shifted out.  those mantissa bits are OR-reduced together with
    the old sticky bit ``m[0]``.
    """
    if inp is None:
        inp = self
    sm = MultiShift(self.width)
    mw = Const(self.m_width-1, len(diff))
    maxslen = Mux(diff > mw, mw, diff)  # clamp shift to m_width-1
    rs = sm.rshift(inp.m[1:], maxslen)
    maxsleni = mw - maxslen  # complementary shift amount for the mask
    m_mask = sm.rshift(self.m1s[1:], maxsleni)  # mask of shifted-out bits

    #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
    stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
    return [self.e.eq(inp.e + diff),
            self.m.eq(Cat(stickybit, rs))
            ]
1048
def shift_up_multi(self, diff):
    """Return statements left-shifting ``self.m`` by ``diff``, lowering ``self.e``.

    The mantissa shift amount is clamped to ``self.m_width``, but the
    exponent is reduced by the full ``diff``.
    """
    shifter = MultiShift(self.width)
    limit = Const(self.m_width, len(diff))
    amount = Mux(diff > limit, limit, diff)
    return [self.e.eq(self.e - diff),
            self.m.eq(shifter.lshift(self.m, amount))]
1059
1060
class Trigger(Elaboratable):
    """Strobe/acknowledge pair with a combinatorial ``trigger`` output.

    ``trigger`` is high exactly when both ``stb`` and ``ack`` are high.
    """

    def __init__(self):
        # plain attribute assignments: nmigen's tracer derives each
        # signal's name from the assignment target
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        handshake = self.stb & self.ack
        m.d.comb += self.trigger.eq(handshake)
        return m

    def eq(self, inp):
        """Return statements copying ``stb`` and ``ack`` from ``inp``."""
        pairs = ((self.stb, inp.stb), (self.ack, inp.ack))
        return [dst.eq(src) for dst, src in pairs]

    def ports(self):
        """Return the handshake inputs (``trigger`` is derived from them)."""
        return list((self.stb, self.ack))
1080
1081
class FPOpIn(PrevControl):
    """PrevControl port that exposes its ``data_i`` payload as ``v``.

    ``chain_from`` / ``chain_inv`` return the statements connecting this
    port to an upstream operand ``in_op``.
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """The incoming data payload (``data_i``)."""
        return self.data_i

    def _link(self, in_op, extra, ack_back):
        """Shared body of chain_from / chain_inv.

        Copies ``in_op.v`` and ``in_op.stb`` (ANDed with ``extra`` when it
        is given) into this port and drives ``in_op.ack`` with ``ack_back``.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),      # receive value
                self.stb.eq(strobe),     # receive STB
                in_op.ack.eq(ack_back),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """Like chain_from, but the ACK sent upstream is inverted."""
        return self._link(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """Take value and strobe from ``in_op`` and send our ACK back."""
        return self._link(in_op, extra, self.ack)
1108
1109
class FPOpOut(NextControl):
    """NextControl port that exposes its ``data_o`` payload as ``v``.

    ``chain_from`` / ``chain_inv`` return the statements connecting this
    port to an operand ``in_op``.
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """The outgoing data payload (``data_o``)."""
        return self.data_o

    def _link(self, in_op, extra, ack_back):
        """Shared body of chain_from / chain_inv.

        Copies ``in_op.v`` and ``in_op.stb`` (ANDed with ``extra`` when it
        is given) into this port and drives ``in_op.ack`` with ``ack_back``.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),      # receive value
                self.stb.eq(strobe),     # receive STB
                in_op.ack.eq(ack_back),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """Like chain_from, but the ACK sent back to ``in_op`` is inverted."""
        return self._link(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """Take value and strobe from ``in_op`` and send our ACK back."""
        return self._link(in_op, extra, self.ack)
1136
1137
class Overflow:
    """Rounding state that accompanies a mantissa.

    Holds the guard / round / sticky bits below the mantissa LSB, the LSB
    itself (``m0``), the 5-bit exception flags, the sign and the rounding
    mode.  From these it derives ``roundz``: whether the mantissa must be
    incremented under the current rounding mode.
    """
    # TODO: change FFLAGS to be FPSCR's status flags
    FFLAGS_NV = Const(1<<4, 5)  # invalid operation
    FFLAGS_DZ = Const(1<<3, 5)  # divide by zero
    FFLAGS_OF = Const(1<<2, 5)  # overflow
    FFLAGS_UF = Const(1<<1, 5)  # underflow
    FFLAGS_NX = Const(1<<0, 5)  # inexact

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")      # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")    # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0
        self.fpflags = Signal(5, reset_less=True, name=prefix+"fflags")
        # sign bit -- 1 means negative, 0 means positive
        self.sign = Signal(reset_less=True, name=prefix+"sign")
        # rounding mode
        self.rm = Signal(FPRoundingMode, name=prefix+"rm",
                         reset=FPRoundingMode.DEFAULT)

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0,
                    self.fpflags, self.sign, self.rm)

    def eq(self, inp):
        """Return statements copying every field from ``inp``."""
        fields = ("guard", "round_bit", "sticky", "m0", "fpflags", "sign",
                  "rm")
        return [getattr(self, f).eq(getattr(inp, f)) for f in fields]

    @property
    def _inexact(self):
        # any of the bits below the mantissa LSB is set
        return self.guard | self.round_bit | self.sticky

    @property
    def roundz_rne(self):
        """True if the mantissa should be rounded up for ``rm == RNE``.

        Rounds up when the discarded part is above one half, or exactly
        one half with an odd LSB (``m0``).
        Assumes the rounding mode is ``ROUND_NEAREST_TIES_TO_EVEN``.
        """
        return self.guard & (self.round_bit | self.sticky | self.m0)

    @property
    def roundz_rna(self):
        """True if the mantissa should be rounded up for ``rm == RNA``.

        Rounds up whenever the guard bit is set.
        Assumes the rounding mode is ``ROUND_NEAREST_TIES_TO_AWAY``.
        """
        return self.guard

    @property
    def roundz_rtn(self):
        """True if the mantissa should be rounded up for ``rm == RTN``.

        Rounds the magnitude up for inexact negative results.
        Assumes the rounding mode is ``ROUND_TOWARDS_NEGATIVE``.
        """
        return self.sign & self._inexact

    @property
    def roundz_rto(self):
        """True if the mantissa should be rounded up for ``rm in (RTOP, RTON)``.

        Rounds up inexact results whose LSB is even, forcing the LSB odd.
        Assumes the rounding mode is
        ``ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`` or
        ``ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE``.
        """
        return ~self.m0 & self._inexact

    @property
    def roundz_rtp(self):
        """True if the mantissa should be rounded up for ``rm == RTP``.

        Rounds the magnitude up for inexact positive results.
        Assumes the rounding mode is ``ROUND_TOWARDS_POSITIVE``.
        """
        return ~self.sign & self._inexact

    @property
    def roundz_rtz(self):
        """True if the mantissa should be rounded up for ``rm == RTZ``.

        Truncation never rounds up.
        Assumes the rounding mode is ``ROUND_TOWARDS_ZERO``.
        """
        return False

    @property
    def roundz(self):
        """True if the mantissa should be rounded up under ``self.rm``.

        Selects among the per-mode ``roundz_*`` conditions by indexing an
        array built by ``FPRoundingMode.make_array`` with ``self.rm``.
        """
        table = {
            FPRoundingMode.RNA: self.roundz_rna,
            FPRoundingMode.RNE: self.roundz_rne,
            FPRoundingMode.RTN: self.roundz_rtn,
            FPRoundingMode.RTOP: self.roundz_rto,
            FPRoundingMode.RTON: self.roundz_rto,
            FPRoundingMode.RTP: self.roundz_rtp,
            FPRoundingMode.RTZ: self.roundz_rtz,
        }
        return FPRoundingMode.make_array(table.__getitem__)[self.rm]
1245
1246
class OverflowMod(Elaboratable, Overflow):
    """Overflow whose rounding decision is also available as a signal.

    ``elaborate`` drives ``roundz_out`` combinatorially from the
    ``roundz`` property, which selects the round-up condition by ``rm``.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        """Yield the Overflow signals followed by ``roundz_out``."""
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """Return statements copying ``roundz_out`` and all Overflow fields.

        :param inp: source object providing ``roundz_out`` and the Overflow
            fields (guard, round_bit, sticky, m0, fpflags, sign, rm).
        """
        # the Overflow fields must be copied from the same source, inp
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)  # roundz is a property
        return m
1265
1266
class FPBase:
    """ IEEE754 Floating Point Base Class

    contains common functions for FP manipulation, such as
    extracting and packing operands, normalisation, denormalisation,
    rounding etc.

    each helper adds statements to the nmigen Module ``m`` it is given.
    most of them assign ``m.next``, so they are presumably meant to be
    called from inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ decode an operand and hand-shake it in.

        ``v.decode2(m)`` builds the decoded sign/exponent/mantissa proxy.
        when ``op.ready_o`` and ``op.valid_i_test`` are both asserted the
        FSM moves to ``next_state``.  acknowledgement is signalled by a
        local ``ack`` signal driven to ZERO on that cycle (and 1 otherwise).

        :returns: ``[res, ack]`` -- the decoded proxy and the ack signal.
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.

            (sync domain: when ``a.exp_n127`` the exponent is set to
            ``a.fp.N126``; otherwise the top mantissa bit is set.)
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            each clock the mantissa is shifted up by one and the exponent
            decremented, until the mantissa's top bit is 1.
            NOTE(review): unlike normalise_1 there is no lower bound on the
            exponent, and an all-zero mantissa never sets its top bit --
            callers presumably never pass zero here; confirm.
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            each clock, while the mantissa's top bit is 0 and the exponent
            is above ``z.fp.N126``, the mantissa is shifted up by one with
            the guard bit moved into its LSB.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # statement order matters: the later z.m[0] assignment
            # overrides bit 0 of the shifted value.  all right-hand sides
            # read the values from before this clock edge.
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                z.m[0].eq(of.guard),  # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),  # reset round bit
                of.m0.eq(of.guard),  # new LSB is the old guard bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            each clock, while the exponent is below ``z.fp.N126``, the
            mantissa is shifted down by one; the bits falling off the
            bottom cascade into guard -> round -> sticky.
        """
        with m.If(z.e < z.fp.N126):
            # all right-hand sides read the pre-shift values
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),  # old LSB becomes guard
                of.m0.eq(z.m[1]),  # new LSB
                of.round_bit.eq(of.guard),  # old guard becomes round
                of.sticky.eq(of.sticky | of.round_bit)  # old round -> sticky
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

            when ``roundz`` is true the mantissa is incremented (sync).  if
            the mantissa was all ones (``z.fp.m1s``) the exponent is also
            incremented; the increment presumably wraps the mantissa to
            zero, as nmigen truncates the sum to the target width.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(z.m == z.fp.m1s):  # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to ``next_state``; for a denormalised ``z`` the
            exponent is set to ``z.fp.N127``.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to ``next_state``; emits ``z.inf(z.s)`` when
            ``z.is_overflowed``, otherwise ``z.create(z.s, z.e, z.m)``.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            in signal terms: ``z.v`` is registered into ``out_z.v`` every
            clock in this state; ``out_z.valid_o`` is held at 1 until
            ``valid_o & ready_i_test``, then cleared and the FSM advances.
        """
        m.d.sync += [
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1397
1398
class FPState(FPBase):
    """FPBase stage that records the state it came from and exposes its
    input/output mappings as attributes of itself.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    @staticmethod
    def _expose(obj, mapping):
        # copy every mapping entry onto obj as an attribute, in order
        for attr_name, value in mapping.items():
            setattr(obj, attr_name, value)

    def set_inputs(self, inputs):
        """Store ``inputs`` and expose each entry as an attribute."""
        self.inputs = inputs
        FPState._expose(self, inputs)

    def set_outputs(self, outputs):
        """Store ``outputs`` and expose each entry as an attribute."""
        self.outputs = outputs
        FPState._expose(self, outputs)
1412
1413
class FPID:
    """Optional ID tag of ``id_wid`` bits passed through as in_mid -> out_mid.

    When ``id_wid`` is falsy (``None`` or ``0``) no signals are created and
    both ``in_mid`` and ``out_mid`` are ``None``.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """Register ``in_mid`` into ``out_mid`` in the sync domain of ``m``.

        Does nothing when no ID signals were created.
        """
        # guard on the signal itself: __init__ leaves it None for any
        # falsy id_wid (including 0), not only for None
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
1427
1428
if __name__ == '__main__':
    # running this module directly invokes unittest's command-line runner
    # on any TestCase classes defined in it
    unittest.main()