switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 """IEEE754 Floating Point Library
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
5
6 """
7
8
9 from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
10 Value, Shape, signed, unsigned)
11 from nmigen.utils import bits_for
12 from operator import or_
13 from functools import reduce
14
15 from nmutil.singlepipe import PrevControl, NextControl
16 from nmutil.pipeline import ObjectProxy
17 import unittest
18 import math
19 import enum
20
# The smtlib2 backend is an optional part of nmigen: record whether it could
# be imported so that FPRoundingMode only defines `to_smtlib2` when
# RoundingModeEnum is actually available.
try:
    from nmigen.hdl.smtlib2 import RoundingModeEnum
    _HAVE_SMTLIB2 = True
except ImportError:
    _HAVE_SMTLIB2 = False

# unique sentinel value so FPRoundingMode.to_smtlib2 can detect when no
# default is supplied (any other object, including None, counts as a
# caller-supplied default)
_raise_err = object()
29
30
class FPRoundingMode(enum.Enum):
    """Floating-point rounding modes.

    Values 0b00-0b11 match the FPSCR.RN field values; values >= 0b100 are
    extra modes used in miscellaneous instructions.

    Member names match the smtlib2 names; the per-member doc strings are the
    OpenPower ISA specification's names (v3.1 section 7.3.2.6 -- matches
    values in section 4.3.6).
    """

    RNE = 0b00
    """Round to Nearest Even

    Rounds to the nearest representable floating-point number, ties are
    rounded to the number with the even mantissa. Treats +-Infinity as if
    it were a normalized floating-point number when deciding which number
    is closer when rounding. See IEEE754 spec. for details.
    """

    ROUND_NEAREST_TIES_TO_EVEN = RNE
    DEFAULT = RNE

    RTZ = 0b01
    """Round towards Zero

    If the result is exactly representable as a floating-point number, return
    that, otherwise return the nearest representable floating-point value
    with magnitude smaller than the exact answer.
    """

    ROUND_TOWARDS_ZERO = RTZ

    RTP = 0b10
    """Round towards +Infinity

    If the result is exactly representable as a floating-point number, return
    that, otherwise return the nearest representable floating-point value
    that is numerically greater than the exact answer. This can round up to
    +Infinity.
    """

    ROUND_TOWARDS_POSITIVE = RTP

    RTN = 0b11
    """Round towards -Infinity

    If the result is exactly representable as a floating-point number, return
    that, otherwise return the nearest representable floating-point value
    that is numerically less than the exact answer. This can round down to
    -Infinity.
    """

    ROUND_TOWARDS_NEGATIVE = RTN

    RNA = 0b100
    """Round to Nearest Away

    Rounds to the nearest representable floating-point number, ties are
    rounded to the number with the maximum magnitude. Treats +-Infinity as if
    it were a normalized floating-point number when deciding which number
    is closer when rounding. See IEEE754 spec. for details.
    """

    ROUND_NEAREST_TIES_TO_AWAY = RNA

    RTOP = 0b101
    """Round to Odd, unsigned zeros are Positive

    Not in smtlib2.

    If the result is exactly representable as a floating-point number, return
    that, otherwise return the nearest representable floating-point value
    that has an odd mantissa.

    If the result is zero but with otherwise undetermined sign
    (e.g. `1.0 - 1.0`), the sign is positive.

    This rounding mode is used for instructions with Round To Odd enabled,
    and `FPSCR.RN != RTN`.

    This is useful to avoid double-rounding errors when doing arithmetic in a
    larger type (e.g. f128) but where the answer should be a smaller type
    (e.g. f80).
    """

    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE = RTOP

    RTON = 0b110
    """Round to Odd, unsigned zeros are Negative

    Not in smtlib2.

    If the result is exactly representable as a floating-point number, return
    that, otherwise return the nearest representable floating-point value
    that has an odd mantissa.

    If the result is zero but with otherwise undetermined sign
    (e.g. `1.0 - 1.0`), the sign is negative.

    This rounding mode is used for instructions with Round To Odd enabled,
    and `FPSCR.RN == RTN`.

    This is useful to avoid double-rounding errors when doing arithmetic in a
    larger type (e.g. f128) but where the answer should be a smaller type
    (e.g. f80).
    """

    ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE = RTON

    @staticmethod
    def make_array(f):
        """build an nmigen `Array` indexed by rounding-mode value, where
        entry `rm.value` holds `f(rm)` for every (non-alias) rounding mode.
        """
        slots = [None] * len(FPRoundingMode)
        for mode in FPRoundingMode:
            slots[mode.value] = f(mode)
        return Array(slots)

    def overflow_rounds_to_inf(self, sign):
        """returns true if an overflow should round to `inf`,
        false if it should round to `max_normal`

        `sign` may be a python value or an nmigen `Value`; for the directed
        modes (RTP/RTN) the answer is `sign` or its negation, in kind.
        """
        # nmigen Values cannot be used with python's `not`: use `~` instead
        if isinstance(sign, Value):
            not_sign = ~sign
        else:
            not_sign = not sign
        modes = FPRoundingMode
        if self in (modes.RNE, modes.RNA):
            # both round-to-nearest modes overflow to infinity
            return True
        if self is modes.RTP:
            return not_sign
        if self is modes.RTN:
            return sign
        # round-towards-zero and both round-to-odd modes stay finite
        assert self in (modes.RTZ, modes.RTOP, modes.RTON)
        return False

    def underflow_rounds_to_zero(self, sign):
        """returns true if an underflow should round to `zero`,
        false if it should round to `min_denormal`

        `sign` may be a python value or an nmigen `Value`; for the directed
        modes (RTP/RTN) the answer is `sign` or its negation, in kind.
        """
        # nmigen Values cannot be used with python's `not`: use `~` instead
        if isinstance(sign, Value):
            not_sign = ~sign
        else:
            not_sign = not sign
        modes = FPRoundingMode
        if self in (modes.RNE, modes.RTZ, modes.RNA):
            return True
        if self is modes.RTP:
            return sign
        if self is modes.RTN:
            return not_sign
        # both round-to-odd modes round to min_denormal
        assert self in (modes.RTOP, modes.RTON)
        return False

    def zero_sign(self):
        """which sign an exact zero result should have when it isn't
        otherwise determined, e.g. for `1.0 - 1.0`.

        True (negative) only for RTN and RTON, False (positive) otherwise.
        """
        return self in (FPRoundingMode.RTN, FPRoundingMode.RTON)

    if _HAVE_SMTLIB2:
        def to_smtlib2(self, default=_raise_err):
            """return the corresponding smtlib2 rounding mode for `self`. If
            there is no corresponding smtlib2 rounding mode, then return
            `default` if specified, else raise `ValueError`.
            """
            # the five modes below exist under the same name in smtlib2
            if self.name in ("RNE", "RTZ", "RTP", "RTN", "RNA"):
                return getattr(RoundingModeEnum, self.name)
            # the round-to-odd modes have no smtlib2 equivalent
            assert self in (FPRoundingMode.RTOP, FPRoundingMode.RTON)
            if default is _raise_err:
                raise ValueError(
                    "no corresponding smtlib2 rounding mode", self)
            return default
228
229
230
231
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)

    Instances compare equal (and hash equal) when all four attributes match.
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width
                and self.m_width == other.m_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    def __hash__(self):
        """ Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone would make instances unhashable, so
        formats could not be used as dict keys or set members.
        """
        return hash((self.e_width, self.m_width,
                     self.has_int_bit, self.has_sign))

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a
            multiple of 32 above 128 (up to an arbitrary limit of 1000000).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Uses the short ``FPFormat.standard(width)`` form when this format is
        exactly the standard format of its own width.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            # width is not a standard one: fall through to the long form
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            retval += f", {self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
            (assumes FPFormat is set to signed - has_sign=True)
        """
        return x >> (self.e_width + self.m_width)

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        x = ((x >> self.m_width) & self.exponent_inf_nan)
        return x

    def get_exponent(self, x):
        """ returns the exponent of its input number, x
        """
        x = self.get_exponent_field(x)
        if isinstance(x, Value) and not x.shape().signed:
            # convert x to signed without changing its value,
            # since exponents can be negative
            x |= Const(0, signed(1))
        return x - self.exponent_bias

    def get_exponent_value(self, x):
        """ returns the exponent of its input number, x, adjusted for the
            mathematically correct subnormal exponent.
        """
        x = self.get_exponent_field(x)
        if isinstance(x, Value) and not x.shape().signed:
            # convert x to signed without changing its value,
            # since exponents can be negative
            x |= Const(0, signed(1))
        # a denormal/zero exponent field counts as one higher
        return x + (x == self.exponent_denormal_zero) - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def get_mantissa_value(self, x):
        """ returns the mantissa of its input number, x, but with the
            implicit bit, if any, made explicit.
        """
        if self.has_int_bit:
            return self.get_mantissa_field(x)
        exponent_field = self.get_exponent_field(x)
        mantissa_field = self.get_mantissa_field(x)
        implicit_bit = exponent_field != self.exponent_denormal_zero
        return (implicit_bit << self.fraction_width) | mantissa_field

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent(x) == self.e_sub) & \
            (self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum)
        """
        return (self.get_exponent(x) == self.e_sub) & \
            (self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite
        """
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)
        """
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan
        """
        highbit = 1 << (self.m_width - 1)
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0) & \
            ((self.get_mantissa_field(x) & highbit) != 0)

    def to_quiet_nan(self, x):
        """ converts `x` to a quiet NaN """
        highbit = 1 << (self.m_width - 1)
        return x | highbit | self.exponent_mask

    def quiet_nan(self, sign=0):
        """ return the default quiet NaN with sign `sign` """
        return self.to_quiet_nan(self.zero(sign))

    def zero(self, sign=0):
        """ return zero with sign `sign` """
        return (sign != 0) << (self.e_width + self.m_width)

    def inf(self, sign=0):
        """ return infinity with sign `sign` """
        return self.zero(sign) | self.exponent_mask

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan
        """
        highbit = 1 << (self.m_width - 1)
        # the explicit parentheses around the last term matter: `==` binds
        # more loosely than `&`, so without them the comparison would apply
        # to the whole and-ed expression and report true for any non-quiet-NaN
        return (self.get_exponent(x) == self.e_max) & \
            (self.get_mantissa_field(x) != 0) & \
            ((self.get_mantissa_field(x) & highbit) == 0)

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_mask(self):
        """ Get an exponent mask """
        return self.exponent_inf_nan << self.m_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias)
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the exponent (minus bias) of the denormal/zero encoding
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit

    @staticmethod
    def from_pspec(pspec):
        """ Get the ``FPFormat`` described by a pipeline spec.

        :param pspec: object with an optional integer ``width`` attribute
            and/or an optional ``fpformat`` attribute (an ``FPFormat``).
        :returns: ``pspec.fpformat`` if set, otherwise the standard format
            for ``pspec.width``.

        At least one of the two attributes must be set; when both are set
        they must agree on the total bit-width.
        """
        width = getattr(pspec, "width", None)
        assert width is None or isinstance(width, int)
        fpformat = getattr(pspec, "fpformat", None)
        if fpformat is None:
            assert width is not None, \
                "neither pspec.width nor pspec.fpformat were set"
            fpformat = FPFormat.standard(width)
        else:
            assert isinstance(fpformat, FPFormat)
        # width is optional when fpformat is given, so only cross-check the
        # two when a width was actually supplied
        assert width is None or width == fpformat.width
        return fpformat
474
475
class TestFPFormat(unittest.TestCase):
    """ quick sanity checks of FPFormat field extraction and classification,
        using sfpy to produce reference bit patterns
    """

    def test_fpformat_fp64(self):
        """ exponent, mantissa and sign extraction on 64-bit values """
        f64 = FPFormat.standard(64)
        from sfpy import Float64

        # 1.0 has an unbiased exponent of 0
        bits = Float64(1.0).bits
        print(hex(bits))
        self.assertEqual(f64.get_exponent(bits), 0)

        # 2.0 has an unbiased exponent of 1
        bits = Float64(2.0).bits
        print(hex(bits))
        self.assertEqual(f64.get_exponent(bits), 1)

        # 1.5 sets only the top mantissa bit, and is positive
        bits = Float64(1.5).bits
        mantissa = f64.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x8000000000000)

        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 0)

        # -1.5 has the sign bit set
        bits = Float64(-1.5).bits
        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        """ field extraction and nan/inf/subnormal/zero tests on 32-bit values
        """
        f32 = FPFormat.standard(32)
        from sfpy import Float32

        bits = Float32(1.0).bits
        print(hex(bits))
        self.assertEqual(f32.get_exponent(bits), 0)

        bits = Float32(2.0).bits
        print(hex(bits))
        self.assertEqual(f32.get_exponent(bits), 1)

        bits = Float32(1.5).bits
        mantissa = f32.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x400000)

        # NaN test: sqrt(-1.0)
        bits = Float32(-1.0).sqrt().bits
        flag = f32.is_nan(bits)
        print(hex(bits), "nan", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # Inf test: 1e36 cubed
        product = Float32(1e36) * Float32(1e36) * Float32(1e36)
        bits = product.bits
        flag = f32.is_inf(bits)
        print(hex(bits), "inf", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # subnormal
        bits = Float32(1e-41).bits
        flag = f32.is_subnormal(bits)
        print(hex(bits), "sub", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # zero must not be classified as subnormal...
        bits = Float32(0.0).bits
        flag = f32.is_subnormal(bits)
        print(hex(bits), "sub", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, False)

        # ...but must be classified as zero
        flag = f32.is_zero(bits)
        print(hex(bits), "zero", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)
557
558
class MultiShiftR(Elaboratable):
    """ Combinatorial variable right-shifter: `o` is `i` shifted right
        by `s` bits.

        `i` and `o` are `width` bits wide; `s` is `bits_for(width - 1)`
        bits wide.
    """

    def __init__(self, width):
        shift_bits = bits_for(width - 1)
        self.width = width
        self.smax = shift_bits
        # attribute names are kept: nmigen derives the signal names from them
        self.i = Signal(width, reset_less=True)
        self.s = Signal(shift_bits, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
572
573
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        `lshift` / `rshift` apply the `<<` / `>>` operator to the operand
        and then truncate the result back to the operand's own length, so
        the shifted value never grows wider than the input.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        self.smax = bits_for(width - 1)

    def lshift(self, op, s):
        """ shift `op` left by `s`, keeping only the low `len(op)` items """
        keep = len(op)
        return (op << s)[:keep]

    def rshift(self, op, s):
        """ shift `op` right by `s`, keeping only the low `len(op)` items """
        keep = len(op)
        return (op >> s)[:keep]
596
597
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    Pure data holder (value, sign, exponent, mantissa signals plus a few
    constants) meant to be passed around inside other data structures,
    between pipelines and between stages.  Its "friend" is FPNumBase,
    which is a *module*.  The two are kept apart because nmigen modules
    that are never added as submodules trigger the irritating
    "Elaboration" warning; FPNumBase often does not *need* to be a
    submodule (it is just data), and splitting the data out of the module
    was the only way to solve that.
    """

    def __init__(self, width=None, m_extra=True, e_extra=False, name=None,
                 fpformat=None):
        # signal-name prefix: empty when no name is given
        prefix = "" if name is None else name + "_"

        # resolve width / fpformat: either may be omitted, both must agree
        if fpformat is None:
            assert isinstance(width, int)
            fpformat = FPFormat.standard(width)
        else:
            assert isinstance(fpformat, FPFormat)
            if width is None:
                width = fpformat.width
        assert isinstance(width, int)
        assert width == fpformat.width
        self.width = width
        self.fpformat = fpformat
        assert not fpformat.has_int_bit
        assert fpformat.has_sign

        # real mantissa width (not including extras)
        self.rmw = fpformat.m_width
        self.e_max = 1 << (fpformat.e_width - 1)
        # mantissa extra bits (top,guard,round)
        self.m_extra = 3 if m_extra else 0
        # 6 is enough to cover FP64 when converting to FP16
        self.e_extra = 6 if e_extra else 0
        # mantissa: 1 extra bit (overflow); exponent: 2 extra bits (overflow)
        self.m_width = fpformat.m_width + 1 + self.m_extra
        self.e_width = fpformat.e_width + 2 + self.e_extra
        # exponent bit-range inside v, for decoding
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2

        # Latched copy of value
        self.v = Signal(width, reset_less=True, name=prefix + "v")
        # Mantissa
        self.m = Signal(self.m_width, reset_less=True, name=prefix + "m")
        # exp+2 bits, signed
        self.e = Signal(signed(self.e_width),
                        reset_less=True, name=prefix + "e")
        # Sign bit
        self.s = Signal(reset_less=True, name=prefix + "s")

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ copies this record's signals and size parameters onto `fp`,
            then (re)creates the mantissa / exponent constants on *this*
            record (not on `fp`).
        """
        shared = ("s", "e", "m", "v", "rmw", "width", "e_width", "e_max",
                  "m_width", "e_start", "e_end", "m_extra")
        for attr in shared:
            setattr(fp, attr, getattr(self, attr))

        mant_shape = unsigned(self.m_width)
        exp_shape = signed(self.e_width)
        biggest = self.e_max

        # mantissa constants: all-zero, top-but-one bit set, all-ones
        self.mzero = Const(0, mant_shape)
        self.msb1 = Const(1 << (self.m_width - 2), mant_shape)
        self.m1s = Const(-1, mant_shape)
        # exponent constants: +e_max, +(e_max-1), -(e_max-1), -(e_max-2)
        self.P128 = Const(biggest, exp_shape)
        self.P127 = Const(biggest - 1, exp_shape)
        self.N127 = Const(-(biggest - 1), exp_shape)
        self.N126 = Const(-(biggest - 2), exp_shape)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        set_mantissa = self.v[0:self.e_start].eq(m)
        # (add bias)
        set_exponent = self.v[self.e_start:self.e_end].eq(e + self.fp.P127)
        set_sign = self.v[-1].eq(s)
        return [set_mantissa, set_exponent, set_sign]

    def _nan(self, s):
        """ (sign, exponent, mantissa) of a NaN with the top mantissa bit set
        """
        top_bit = 1 << (self.e_start - 1)
        return (s, self.fp.P128, top_bit)

    def _inf(self, s):
        """ (sign, exponent, mantissa) of infinity """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) of zero """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ assignments setting v to a NaN with sign `s` """
        sign, exp, mant = self._nan(s)
        return self.create(sign, exp, mant)

    def quieted_nan(self, other):
        """ assignments setting v to `other`'s value with the exponent forced
            to P128 and the top mantissa bit set
        """
        assert isinstance(other, FPNumBaseRecord)
        assert self.width == other.width
        top_bit = 1 << (self.e_start - 1)
        mant = other.v[0:self.e_start] | top_bit
        return self.create(other.s, self.fp.P128, mant)

    def inf(self, s):
        """ assignments setting v to infinity with sign `s` """
        sign, exp, mant = self._inf(s)
        return self.create(sign, exp, mant)

    def max_normal(self, s):
        """ assignments setting v to exponent P127, all-ones mantissa """
        return self.create(s, self.fp.P127, ~0)

    def min_denormal(self, s):
        """ assignments setting v to exponent N127, mantissa 1 """
        return self.create(s, self.fp.N127, 1)

    def zero(self, s):
        """ assignments setting v to zero with sign `s` """
        sign, exp, mant = self._zero(s)
        return self.create(sign, exp, mant)

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent
        """
        biased = e + self.P127  # exp (add on bias)
        exp_bits = self.e_end - self.e_start
        return Cat(m[0:self.e_start], biased[0:exp_bits], s)

    def nan2(self, s):
        """ concatenated NaN value (see create2) with sign `s` """
        return self.create2(s, self.P128, self.msb1)

    def inf2(self, s):
        """ concatenated infinity value (see create2) with sign `s` """
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        """ concatenated zero value (see create2) with sign `s` """
        return self.create2(s, self.N127, self.mzero)

    def __iter__(self):
        """ yields the sign, exponent and mantissa signals, in that order """
        yield from (self.s, self.e, self.m)

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from `inp` """
        return [getattr(self, field).eq(getattr(inp, field))
                for field in ("s", "e", "m")]
760
761
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

        Module companion of FPNumBaseRecord: shares the record's signals
        (via `drop_in`) and derives a set of one-bit classification flags
        from the exponent and mantissa.
    """

    def __init__(self, fp):
        # share fp's signals and size parameters; constants stay on self.fp
        fp.drop_in(self)
        self.fp = fp

        # attribute names are kept: nmigen derives the signal names from them
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal(signed(fp.e_width), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        consts = self.fp
        m.d.comb += [
            # classification flags
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            # exponent comparisons against the record's constants
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            # mantissa tests: all zero / bit number e_start clear
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        """ exponent is P128 and mantissa is non-zero """
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        """ exponent is P128 and mantissa is zero """
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        """ exponent is N127 and mantissa is zero """
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        """ exponent is greater than P127 """
        return self.exp_gt127

    def _is_denormalised(self):
        """ exponent is N126 and mantissa bit number e_start is clear """
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return (self.exp_n126) & (self.m_msbzero)
823
824
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

        Holds signals for a copy of the value, decoded into
        sign / exponent / mantissa, together with the inherited encoding
        functions and the creation / recognition of zero, NaN and inf
        (all signed).

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        # nothing beyond the base-class classification logic
        return super().elaborate(platform)
845
846
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        `inp` is the input, `diff` the shift amount and `m` the result
        (`inp` and `m` are `width` bits wide).
    """

    def __init__(self, width, s_max=None):
        # default shift-amount shape: just enough bits to express width-1
        if s_max is None:
            s_max = bits_for(width - 1)
        self.width = width
        self.smax = Shape.cast(s_max)
        # attribute names are kept: nmigen derives the signal names from them
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # local names are kept: nmigen derives the signal names from them
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
        maxslen = Signal(self.smax.width, reset_less=True)
        maxsleni = Signal(self.smax.width, reset_less=True)

        upper_width = self.width - 1
        sm = MultiShift(upper_width)
        m0s = Const(0, upper_width)
        mw = Const(upper_width, len(self.diff))
        upper_bits = self.inp[1:]

        # clamp the shift amount to width-1, and compute its complement
        too_far = self.diff > mw
        m.d.comb += maxslen.eq(Mux(too_far, mw, self.diff))
        m.d.comb += maxsleni.eq(Mux(too_far, 0, mw - self.diff))

        # shift mantissa by maxslen, mask by inverse
        m.d.comb += rs.eq(sm.rshift(upper_bits, maxslen))
        m.d.comb += m_mask.eq(sm.rshift(~m0s, maxsleni))
        m.d.comb += smask.eq(upper_bits & m_mask)
        # sticky bit combines all mask (and mantissa low bit)
        m.d.comb += stickybit.eq(smask.bool() | self.inp[0])
        # mantissa result contains m[0] already.
        m.d.comb += self.m.eq(Cat(stickybit, rs))
        return m
890
891
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        Copies sign / exponent / mantissa from `op` and, in the "align"
        state of `mainm`, shifts the mantissa down while this number's
        exponent is below that of `inv`.

        NOTE(review): `elaborate` drives `self.e` / `self.m` from both the
        comb and the sync domain, and uses `mainm.State(...)` from a
        different module than the one being built -- this looks like legacy
        code; confirm it is still meant to be usable before relying on it.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a record (not a width), so build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # `op` is only available as an attribute here, not as a local
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            `inp` is the number to shift; it defaults to this number itself.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (self.fp), not on self
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
963
964
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

        Holds signals for an incoming copy of the value and decodes it
        (combinatorially) into sign / exponent / mantissa.  Inherits the
        encoding functions and the creation / recognition of zero, NaN
        and inf (all signed).

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        # decode the latched value v into s / e / m
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        # pad with extra zeros below the stored mantissa bits
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
                ]
1002
1003
1004 class FPNumIn(FPNumBase):
1005 """ Floating-point Number Class
1006
1007 Contains signals for an incoming copy of the value, decoded into
1008 sign / exponent / mantissa.
1009 Also contains encoding functions, creation and recognition of
1010 zero, NaN and inf (all signed)
1011
1012 Four extra bits are included in the mantissa: the top bit
1013 (m[-1]) is effectively a carry-overflow. The other three are
1014 guard (m[2]), round (m[1]), and sticky (m[0])
1015 """
1016
def __init__(self, op, fp):
    # share fp's signals and create the classification flags
    super().__init__(fp)
    self.op = op
    # attribute name is kept: nmigen derives the signal name from it
    self.latch_in = Signal()
1021
def decode2(self, m):
    """ decodes the latched value (self.v) into sign / exponent / mantissa,
        returned as attributes of a new ObjectProxy created on `m`

        bias is subtracted here, from the exponent.  exponent
        is extended to 10 bits so that subtract 127 is done on
        a 10-bit number
    """
    # pad with extra zeros below the stored mantissa bits
    padding = [0] * self.m_extra
    stored_mantissa = self.v[0:self.e_start]
    raw_exponent = self.v[self.e_start:self.e_end]

    res = ObjectProxy(m, pipemode=False)
    res.m = Cat(*padding, stored_mantissa)   # mantissa
    res.e = raw_exponent - self.fp.P127      # exp
    res.s = self.v[-1]                       # sign
    return res
1037
1038 def decode(self, v):
1039 """ decodes a latched value into sign / exponent / mantissa
1040
1041 bias is subtracted here, from the exponent. exponent
1042 is extended to 10 bits so that subtract 127 is done on
1043 a 10-bit number
1044 """
1045 args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
1046 #print ("decode", self.e_end)
1047 return [self.m.eq(Cat(*args)), # mantissa
1048 self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
1049 self.s.eq(v[-1]), # sign
1050 ]
1051
1052 def shift_down(self, inp):
1053 """ shifts a mantissa down by one. exponent is increased to compensate
1054
1055 accuracy is lost as a result in the mantissa however there are 3
1056 guard bits (the latter of which is the "sticky" bit)
1057 """
1058 return [self.e.eq(inp.e + 1),
1059 self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
1060 ]
1061
1062 def shift_down_multi(self, diff, inp=None):
1063 """ shifts a mantissa down. exponent is increased to compensate
1064
1065 accuracy is lost as a result in the mantissa however there are 3
1066 guard bits (the latter of which is the "sticky" bit)
1067
1068 this code works by variable-shifting the mantissa by up to
1069 its maximum bit-length: no point doing more (it'll still be
1070 zero).
1071
1072 the sticky bit is computed by shifting a batch of 1s by
1073 the same amount, which will introduce zeros. it's then
1074 inverted and used as a mask to get the LSBs of the mantissa.
1075 those are then |'d into the sticky bit.
1076 """
1077 if inp is None:
1078 inp = self
1079 sm = MultiShift(self.width)
1080 mw = Const(self.m_width-1, len(diff))
1081 maxslen = Mux(diff > mw, mw, diff)
1082 rs = sm.rshift(inp.m[1:], maxslen)
1083 maxsleni = mw - maxslen
1084 m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
1085
1086 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1087 stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
1088 return [self.e.eq(inp.e + diff),
1089 self.m.eq(Cat(stickybit, rs))
1090 ]
1091
1092 def shift_up_multi(self, diff):
1093 """ shifts a mantissa up. exponent is decreased to compensate
1094 """
1095 sm = MultiShift(self.width)
1096 mw = Const(self.m_width, len(diff))
1097 maxslen = Mux(diff > mw, mw, diff)
1098
1099 return [self.e.eq(self.e - diff),
1100 self.m.eq(sm.lshift(self.m, maxslen))
1101 ]
1102
1103
class Trigger(Elaboratable):
    """ strobe / acknowledge pair with a combined "trigger" output

        trigger is driven combinatorially to (stb & ack).
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        """ builds the module driving trigger from stb and ack
        """
        module = Module()
        both_set = self.stb & self.ack
        module.d.comb += self.trigger.eq(both_set)
        return module

    def eq(self, inp):
        """ returns the assignments copying stb and ack from inp
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ returns [stb, ack]
        """
        return [self.stb, self.ack]
1123
1124
class FPOpIn(PrevControl):
    """ PrevControl with a width attribute, a "v" alias for data_i and
        stb/ack chaining helpers
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ alias for self.data_i
        """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ returns assignments chaining this op from in_op, with the
            acknowledgement sent back *inverted* (~self.ack)

            when extra is given it is ANDed with the incoming strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        take_value = self.v.eq(in_op.v)         # receive value
        take_strobe = self.stb.eq(strobe)       # receive STB
        give_ack = in_op.ack.eq(~self.ack)      # send ACK
        return [take_value, take_strobe, give_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments chaining this op from in_op, with the
            acknowledgement sent back unmodified (self.ack)

            when extra is given it is ANDed with the incoming strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        take_value = self.v.eq(in_op.v)         # receive value
        take_strobe = self.stb.eq(strobe)       # receive STB
        give_ack = in_op.ack.eq(self.ack)       # send ACK
        return [take_value, take_strobe, give_ack]
1151
1152
class FPOpOut(NextControl):
    """ NextControl with a width attribute, a "v" alias for data_o and
        stb/ack chaining helpers
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ alias for self.data_o
        """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ returns assignments chaining this op from in_op, with the
            acknowledgement sent back *inverted* (~self.ack)

            when extra is given it is ANDed with the incoming strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        take_value = self.v.eq(in_op.v)         # receive value
        take_strobe = self.stb.eq(strobe)       # receive STB
        give_ack = in_op.ack.eq(~self.ack)      # send ACK
        return [take_value, take_strobe, give_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments chaining this op from in_op, with the
            acknowledgement sent back unmodified (self.ack)

            when extra is given it is ANDed with the incoming strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        take_value = self.v.eq(in_op.v)         # receive value
        take_strobe = self.stb.eq(strobe)       # receive STB
        give_ack = in_op.ack.eq(self.ack)       # send ACK
        return [take_value, take_strobe, give_ack]
1179
1180
class Overflow:
    """ rounding state: guard / round / sticky bits, mantissa bit 0,
        flags, sign and rounding mode, together with the per-rounding-mode
        "should the mantissa be rounded up" expressions.
    """
    # TODO: change FFLAGS to be FPSCR's status flags
    FFLAGS_NV = Const(0b10000, 5)  # invalid operation
    FFLAGS_DZ = Const(0b01000, 5)  # divide by zero
    FFLAGS_OF = Const(0b00100, 5)  # overflow
    FFLAGS_UF = Const(0b00010, 5)  # underflow
    FFLAGS_NX = Const(0b00001, 5)  # inexact

    def __init__(self, name=None):
        """ creates the signals; name (if given) prefixes every signal name
        """
        prefix = "" if name is None else name
        # tot[2]
        self.guard = Signal(reset_less=True, name=prefix + "guard")
        # tot[1]
        self.round_bit = Signal(reset_less=True, name=prefix + "round")
        # tot[0]
        self.sticky = Signal(reset_less=True, name=prefix + "sticky")
        # mantissa bit 0
        self.m0 = Signal(reset_less=True, name=prefix + "m0")
        self.fpflags = Signal(5, reset_less=True, name=prefix + "fflags")

        # sign bit -- 1 means negative, 0 means positive
        self.sign = Signal(reset_less=True, name=prefix + "sign")

        # rounding mode
        self.rm = Signal(FPRoundingMode, name=prefix + "rm",
                         reset=FPRoundingMode.DEFAULT)

    def __iter__(self):
        """ yields guard, round_bit, sticky, m0, fpflags, sign, rm
        """
        yield from (self.guard, self.round_bit, self.sticky, self.m0,
                    self.fpflags, self.sign, self.rm)

    def eq(self, inp):
        """ returns the assignments copying every field from inp
        """
        fields = ("guard", "round_bit", "sticky", "m0",
                  "fpflags", "sign", "rm")
        return [getattr(self, field).eq(getattr(inp, field))
                for field in fields]

    @property
    def roundz_rne(self):
        """true if the mantissa should be rounded up for `rm == RNE`

        assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`
        """
        tie_or_odd = self.round_bit | self.sticky | self.m0
        return self.guard & tie_or_odd

    @property
    def roundz_rna(self):
        """true if the mantissa should be rounded up for `rm == RNA`

        assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
        """
        return self.guard

    @property
    def roundz_rtn(self):
        """true if the mantissa should be rounded up for `rm == RTN`

        assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`
        """
        any_extra = self.guard | self.round_bit | self.sticky
        return self.sign & any_extra

    @property
    def roundz_rto(self):
        """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`

        assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
        or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`
        """
        any_extra = self.guard | self.round_bit | self.sticky
        return ~self.m0 & any_extra

    @property
    def roundz_rtp(self):
        """true if the mantissa should be rounded up for `rm == RTP`

        assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`
        """
        any_extra = self.guard | self.round_bit | self.sticky
        return ~self.sign & any_extra

    @property
    def roundz_rtz(self):
        """true if the mantissa should be rounded up for `rm == RTZ`

        assumes the rounding mode is `ROUND_TOWARDS_ZERO`
        (never rounds up, so this is the constant False)
        """
        return False

    @property
    def roundz(self):
        """true if the mantissa should be rounded up for the current rounding
        mode `self.rm`
        """
        per_mode = {
            FPRoundingMode.RNA: self.roundz_rna,
            FPRoundingMode.RNE: self.roundz_rne,
            FPRoundingMode.RTN: self.roundz_rtn,
            FPRoundingMode.RTOP: self.roundz_rto,
            FPRoundingMode.RTON: self.roundz_rto,
            FPRoundingMode.RTP: self.roundz_rtp,
            FPRoundingMode.RTZ: self.roundz_rtz,
        }
        # one entry per rounding mode, selected by the rm signal
        choices = FPRoundingMode.make_array(lambda rm: per_mode[rm])
        return choices[self.rm]
1288
1289
class OverflowMod(Elaboratable, Overflow):
    """ Overflow with an extra roundz_out signal, driven combinatorially
        from the roundz property
    """

    def __init__(self, name=None):
        """ creates the Overflow signals plus roundz_out; name (if given)
            prefixes every signal name
        """
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        """ yields the Overflow fields followed by roundz_out
        """
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ returns the assignments copying roundz_out and every Overflow
            field from inp
        """
        # inp must be forwarded: Overflow.eq takes (self, inp), and calling
        # it with self alone raises TypeError.
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        """ builds the module driving roundz_out
        """
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)  # roundz is a property
        return m
1308
1309
class FPBase:
    """ IEEE754 Floating Point Base Class

        helper methods shared by FP state machines: operand fetch,
        normalisation / denormalisation, rounding, corrections, packing
        and handing the result to the output.  each helper adds
        statements to the module "m" passed in.
    """

    def get_op(self, m, op, v, next_state):
        """ decodes operand v and moves on once the operand is accepted

            when op.ready_o and op.valid_i_test are both set the FSM moves
            to next_state and the (locally created) ack signal is driven
            to ZERO; otherwise ack is driven to 1.

            returns [decoded operand (from v.decode2), ack]
        """
        res = v.decode2(m)
        ack = Signal()
        accepted = (op.ready_o) & (op.valid_i_test)
        with m.If(accepted):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.

            when a.exp_n127 is set the exponent is replaced by a.fp.N126
            and the mantissa is left alone; otherwise the top mantissa
            bit (the implicit 1) is set.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        at_minimum_exp = a.exp_n127
        with m.If(at_minimum_exp):
            m.d.sync += a.e.eq(a.fp.N126)   # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)       # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            while the top mantissa bit is zero, shifts the mantissa up by
            one and decrements the exponent (once per clock); moves to
            next_state as soon as the top bit is set.
        """
        top_bit_clear = (op.m[-1] == 0)
        with m.If(top_bit_clear):
            step = [
                op.e.eq(op.e - 1),      # DECREASE exponent
                op.m.eq(op.m << 1),     # shift mantissa UP
            ]
            m.d.sync += step
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            repeats every clock while the top mantissa bit is zero and the
            exponent is above z.fp.N126: the mantissa moves up one place,
            pulling the guard bit into bit 0, and guard / round are
            shuffled along to match (they come from tot[0..2]).
            moves to next_state otherwise.
        """
        can_shift_up = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(can_shift_up):
            step = [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
                of.m0.eq(of.guard),
            ]
            m.d.sync += step
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            repeats every clock while the exponent is below z.fp.N126:
            the mantissa moves down one place, its old bit 0 becoming the
            guard bit, guard becoming round, and round being merged into
            sticky.  moves to next_state otherwise.
        """
        below_minimum = z.e < z.fp.N126
        with m.If(below_minimum):
            step = [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
            m.d.sync += step
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented; if the mantissa
            was all 1s (z.fp.m1s) the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)         # mantissa rounds up
            mantissa_all_ones = (z.m == z.fp.m1s)
            with m.If(mantissa_all_ones):
                m.d.sync += z.e.eq(z.e + 1)     # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state; when z.is_denormalised is set the
            exponent is replaced by z.fp.N127.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.

            out_z.valid_o is raised until both it and out_z.ready_i_test
            are set; at that point valid_o is reset back to zero (as
            acknowledgement) and the FSM moves to next_state.
        """
        m.d.sync += out_z.v.eq(z.v)
        handed_over = out_z.valid_o & out_z.ready_i_test
        with m.If(handed_over):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1440
1441
class FPState(FPBase):
    """ FPBase that remembers which state it came from and exposes its
        inputs / outputs dictionaries as attributes
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores inputs and also sets each of its entries as an
            attribute of the same name on self
        """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ stores outputs and also sets each of its entries as an
            attribute of the same name on self
        """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
1455
1456
class FPID:
    """ optional pair of id signals (in_mid / out_mid)

        when id_wid is falsy (None or 0) no signals are created and
        in_mid / out_mid are both None.
    """

    def __init__(self, id_wid):
        """ id_wid: width of the id signals, or None / 0 for "no id"
        """
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync assignment of in_mid to out_mid on module m

            does nothing when no id signals were created.
        """
        # same test as the constructor: with id_wid == 0 the signals are
        # None, so an "is not None" check here would dereference None.
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
1470
1471
# when executed as a script, hand control to unittest's command-line runner
if __name__ == '__main__':
    unittest.main()