1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019 Jake Lifshay
9 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
11 from operator
import or_
12 from functools
import reduce
14 from nmutil
.singlepipe
import PrevControl
, NextControl
15 from nmutil
.pipeline
import ObjectProxy
21 """ Class describing binary floating-point formats based on IEEE 754.
23 :attribute e_width: the number of bits in the exponent field.
24 :attribute m_width: the number of bits stored in the mantissa
26 :attribute has_int_bit: if the FP format has an explicit integer bit (like
27 the x87 80-bit format). The bit is considered part of the mantissa.
28 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
29 Image/Buffer formats are FP numbers without a sign bit.)
37 """ Create ``FPFormat`` instance. """
38 self
.e_width
= e_width
39 self
.m_width
= m_width
40 self
.has_int_bit
= has_int_bit
41 self
.has_sign
= has_sign
43 def __eq__(self
, other
):
44 """ Check for equality. """
45 if not isinstance(other
, FPFormat
):
47 return (self
.e_width
== other
.e_width
48 and self
.m_width
== other
.m_width
49 and self
.has_int_bit
== other
.has_int_bit
50 and self
.has_sign
== other
.has_sign
)
54 """ Get standard IEEE 754-2008 format.
56 :param width: bit-width of requested format.
57 :returns: the requested ``FPFormat`` instance.
60 return FPFormat(5, 10)
62 return FPFormat(8, 23)
64 return FPFormat(11, 52)
66 return FPFormat(15, 112)
67 if width
> 128 and width
% 32 == 0:
68 if width
> 1000000: # arbitrary upper limit
69 raise ValueError("width too big")
70 e_width
= round(4 * math
.log2(width
)) - 13
71 return FPFormat(e_width
, width
- 1 - e_width
)
72 raise ValueError("width must be the bit-width of a valid IEEE"
73 " 754-2008 binary format")
78 if self
== self
.standard(self
.width
):
79 return f
"FPFormat.standard({self.width})"
82 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
83 if self
.has_int_bit
is not False:
84 retval
+= f
", {self.has_int_bit}"
85 if self
.has_sign
is not True:
86 retval
+= f
", {self.has_sign}"
def get_sign_field(self, x):
    """ Return the sign bit of the raw bit pattern ``x``.

    The exponent and mantissa fields are shifted away, leaving whatever
    sits above them.  No masking is applied, so this assumes the format
    is signed (``has_sign=True``) and that ``x`` has no bits set above
    the sign position.
    """
    bits_below_sign = self.e_width + self.m_width
    return x >> bits_below_sign
def get_exponent_field(self, x):
    """ Return the raw exponent field of the bit pattern ``x``.

    The mantissa bits are shifted out and the result is masked with
    ``exponent_inf_nan`` (an all-ones value of the exponent's width),
    which also strips the sign bit.  No bias is subtracted here.
    """
    # The value must be returned explicitly: get_exponent() subtracts
    # the bias from this method's result.
    return (x >> self.m_width) & self.exponent_inf_nan
def get_exponent(self, x):
    """ Return the exponent of the bit pattern ``x`` with the bias removed.

    This is the raw exponent field (see ``get_exponent_field``) minus
    ``exponent_bias``.
    """
    biased = self.get_exponent_field(x)
    return biased - self.exponent_bias
def get_mantissa_field(self, x):
    """ Return the mantissa field of the bit pattern ``x``.

    Everything outside ``mantissa_mask`` (exponent and sign) is cleared.
    """
    mask = self.mantissa_mask
    return x & mask
def is_zero(self, x):
    """ Return True if the bit pattern ``x`` encodes +/- zero.

    That is the case when the unbiased exponent equals ``e_sub`` and
    the mantissa field is zero.  The sign bit is not examined.
    """
    # Exponent is checked first; the mantissa is only extracted when
    # the exponent matches.
    if not (self.get_exponent(x) == self.e_sub):
        return False
    return self.get_mantissa_field(x) == 0
def is_subnormal(self, x):
    """ Return True if the bit pattern ``x`` encodes a subnormal number.

    That is the case when the unbiased exponent equals ``e_sub`` (the
    minimum) and the mantissa field is non-zero.
    """
    # Exponent is checked first; the mantissa is only extracted when
    # the exponent matches.
    if not (self.get_exponent(x) == self.e_sub):
        return False
    return self.get_mantissa_field(x) != 0
124 """ returns true if x is infinite
126 return (self
.get_exponent(x
) == self
.e_max
and
127 self
.get_mantissa_field(x
) == 0)
130 """ returns true if x is a nan (quiet or signalling)
132 return (self
.get_exponent(x
) == self
.e_max
and
133 self
.get_mantissa_field(x
) != 0)
def is_quiet_nan(self, x):
    """ Return True if the bit pattern ``x`` encodes a quiet NaN.

    A quiet NaN has the unbiased exponent equal to ``e_max``, a
    non-zero mantissa, and the top mantissa bit set.
    """
    # Top bit of the mantissa field distinguishes quiet from signalling.
    top_bit = 1 << (self.m_width - 1)
    if not (self.get_exponent(x) == self.e_max):
        return False
    mantissa = self.get_mantissa_field(x)
    if mantissa == 0:
        return False
    return (mantissa & top_bit) != 0
def is_nan_signaling(self, x):
    """ Return True if the bit pattern ``x`` encodes a signalling NaN.

    A signalling NaN has the unbiased exponent equal to ``e_max``, a
    non-zero mantissa, and the top mantissa bit clear.
    """
    # Top bit of the mantissa field distinguishes quiet from signalling.
    top_bit = 1 << (self.m_width - 1)
    if not (self.get_exponent(x) == self.e_max):
        return False
    mantissa = self.get_mantissa_field(x)
    if mantissa == 0:
        return False
    return (mantissa & top_bit) == 0
153 """ Get the total number of bits in the FP format. """
154 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """ Get a mask with the low ``m_width`` bits set (the mantissa field). """
    # -1 is all ones; shifting it up clears the low m_width bits, and
    # inverting leaves exactly those bits set.
    above_mantissa = -1 << self.m_width
    return ~above_mantissa
def exponent_inf_nan(self):
    """ Get the exponent field value designating infinity/NaN.

    This is the all-ones pattern of ``e_width`` bits.
    """
    above_exponent = -1 << self.e_width
    return ~above_exponent
168 """ get the maximum exponent (minus bias)
170 return self
.exponent_inf_nan
- self
.exponent_bias
174 return self
.exponent_denormal_zero
- self
.exponent_bias
176 def exponent_denormal_zero(self
):
177 """ Get the value of the exponent field designating denormal/zero. """
181 def exponent_min_normal(self
):
182 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """ Get the largest exponent field value used by normal numbers.

    This is one below ``exponent_inf_nan``, the value designating
    infinity/NaN.
    """
    inf_nan_field = self.exponent_inf_nan
    return inf_nan_field - 1
def exponent_bias(self):
    """ Get the exponent bias: ``2**(e_width - 1) - 1``. """
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """ Get the number of mantissa bits that are fraction bits.

    ``has_int_bit`` is subtracted from ``m_width``, so an explicit
    integer bit (when present) is not counted as a fraction bit.
    """
    int_bit_count = self.has_int_bit
    return self.m_width - int_bit_count
201 class TestFPFormat(unittest
.TestCase
):
202 """ very quick test for FPFormat
205 def test_fpformat_fp64(self
):
206 f64
= FPFormat
.standard(64)
207 from sfpy
import Float64
208 x
= Float64(1.0).bits
211 self
.assertEqual(f64
.get_exponent(x
), 0)
212 x
= Float64(2.0).bits
214 self
.assertEqual(f64
.get_exponent(x
), 1)
216 x
= Float64(1.5).bits
217 m
= f64
.get_mantissa_field(x
)
218 print (hex(x
), hex(m
))
219 self
.assertEqual(m
, 0x8000000000000)
221 s
= f64
.get_sign_field(x
)
222 print (hex(x
), hex(s
))
223 self
.assertEqual(s
, 0)
225 x
= Float64(-1.5).bits
226 s
= f64
.get_sign_field(x
)
227 print (hex(x
), hex(s
))
228 self
.assertEqual(s
, 1)
230 def test_fpformat_fp32(self
):
231 f32
= FPFormat
.standard(32)
232 from sfpy
import Float32
233 x
= Float32(1.0).bits
236 self
.assertEqual(f32
.get_exponent(x
), 0)
237 x
= Float32(2.0).bits
239 self
.assertEqual(f32
.get_exponent(x
), 1)
241 x
= Float32(1.5).bits
242 m
= f32
.get_mantissa_field(x
)
243 print (hex(x
), hex(m
))
244 self
.assertEqual(m
, 0x400000)
247 x
= Float32(-1.0).sqrt()
250 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
251 f32
.get_mantissa_field(x
), i
)
252 self
.assertEqual(i
, True)
255 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
258 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
259 f32
.get_mantissa_field(x
), i
)
260 self
.assertEqual(i
, True)
265 i
= f32
.is_subnormal(x
)
266 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
267 f32
.get_mantissa_field(x
), i
)
268 self
.assertEqual(i
, True)
272 i
= f32
.is_subnormal(x
)
273 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
274 f32
.get_mantissa_field(x
), i
)
275 self
.assertEqual(i
, False)
279 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
280 f32
.get_mantissa_field(x
), i
)
281 self
.assertEqual(i
, True)
286 def __init__(self
, width
):
288 self
.smax
= int(log(width
) / log(2))
289 self
.i
= Signal(width
, reset_less
=True)
290 self
.s
= Signal(self
.smax
, reset_less
=True)
291 self
.o
= Signal(width
, reset_less
=True)
293 def elaborate(self
, platform
):
295 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
300 """ Generates variable-length single-cycle shifter from a series
301 of conditional tests on each bit of the left/right shift operand.
302 Each bit tested produces output shifted by that number of bits,
303 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
304 shifts by 2 bits, each partial result cascading to the next Mux.
306 Could be adapted to do arithmetic shift by taking copies of the
307 MSB instead of zeros.
310 def __init__(self
, width
):
312 self
.smax
= int(log(width
) / log(2))
314 def lshift(self
, op
, s
):
318 def rshift(self
, op
, s
):
323 class FPNumBaseRecord
:
324 """ Floating-point Base Number Class.
326 This class is designed to be passed around in other data structures
327 (between pipelines and between stages). Its "friend" is FPNumBase,
328 which is a *module*. The reason for the discernment is because
329 nmigen modules that are not added to submodules results in the
330 irritating "Elaboration" warning. Despite not *needing* FPNumBase
331 in many cases to be added as a submodule (because it is just data)
332 this was not possible to solve without splitting out the data from
336 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
339 # assert false, "missing name"
343 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
344 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
345 e_max
= 1 << (e_width
-3)
346 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
349 # mantissa extra bits (top,guard,round)
351 m_width
+= self
.m_extra
355 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
356 e_width
+= self
.e_extra
359 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
360 self
.m_width
= m_width
361 self
.e_width
= e_width
362 self
.e_start
= self
.rmw
363 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
365 self
.v
= Signal(width
, reset_less
=True,
366 name
=name
+"v") # Latched copy of value
367 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
368 self
.e
= Signal((e_width
, True),
369 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
370 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
375 def drop_in(self
, fp
):
381 fp
.width
= self
.width
382 fp
.e_width
= self
.e_width
383 fp
.e_max
= self
.e_max
384 fp
.m_width
= self
.m_width
385 fp
.e_start
= self
.e_start
386 fp
.e_end
= self
.e_end
387 fp
.m_extra
= self
.m_extra
389 m_width
= self
.m_width
391 e_width
= self
.e_width
393 self
.mzero
= Const(0, (m_width
, False))
394 m_msb
= 1 << (self
.m_width
-2)
395 self
.msb1
= Const(m_msb
, (m_width
, False))
396 self
.m1s
= Const(-1, (m_width
, False))
397 self
.P128
= Const(e_max
, (e_width
, True))
398 self
.P127
= Const(e_max
-1, (e_width
, True))
399 self
.N127
= Const(-(e_max
-1), (e_width
, True))
400 self
.N126
= Const(-(e_max
-2), (e_width
, True))
402 def create(self
, s
, e
, m
):
403 """ creates a value from sign / exponent / mantissa
405 bias is added here, to the exponent.
407 NOTE: order is important, because e_start/e_end can be
408 a bit too long (overwriting s).
411 self
.v
[0:self
.e_start
].eq(m
), # mantissa
412 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
413 self
.v
[-1].eq(s
), # sign
417 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
420 return (s
, self
.fp
.P128
, 0)
423 return (s
, self
.fp
.N127
, 0)
426 return self
.create(*self
._nan
(s
))
429 return self
.create(*self
._inf
(s
))
432 return self
.create(*self
._zero
(s
))
434 def create2(self
, s
, e
, m
):
435 """ creates a value from sign / exponent / mantissa
437 bias is added here, to the exponent
439 e
= e
+ self
.P127
# exp (add on bias)
440 return Cat(m
[0:self
.e_start
],
441 e
[0:self
.e_end
-self
.e_start
],
445 return self
.create2(s
, self
.P128
, self
.msb1
)
448 return self
.create2(s
, self
.P128
, self
.mzero
)
451 return self
.create2(s
, self
.N127
, self
.mzero
)
459 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
462 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
463 """ Floating-point Base Number Class
466 def __init__(self
, fp
):
471 self
.is_nan
= Signal(reset_less
=True)
472 self
.is_zero
= Signal(reset_less
=True)
473 self
.is_inf
= Signal(reset_less
=True)
474 self
.is_overflowed
= Signal(reset_less
=True)
475 self
.is_denormalised
= Signal(reset_less
=True)
476 self
.exp_128
= Signal(reset_less
=True)
477 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
478 self
.exp_lt_n126
= Signal(reset_less
=True)
479 self
.exp_zero
= Signal(reset_less
=True)
480 self
.exp_gt_n126
= Signal(reset_less
=True)
481 self
.exp_gt127
= Signal(reset_less
=True)
482 self
.exp_n127
= Signal(reset_less
=True)
483 self
.exp_n126
= Signal(reset_less
=True)
484 self
.m_zero
= Signal(reset_less
=True)
485 self
.m_msbzero
= Signal(reset_less
=True)
487 def elaborate(self
, platform
):
489 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
490 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
491 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
492 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
493 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
494 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
495 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
496 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
497 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
498 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
499 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
500 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
501 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
502 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
503 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
508 return (self
.exp_128
) & (~self
.m_zero
)
511 return (self
.exp_128
) & (self
.m_zero
)
514 return (self
.exp_n127
) & (self
.m_zero
)
516 def _is_overflowed(self
):
517 return self
.exp_gt127
519 def _is_denormalised(self
):
520 # XXX NOT to be used for "official" quiet NaN tests!
521 # particularly when the MSB has been extended
522 return (self
.exp_n126
) & (self
.m_msbzero
)
525 class FPNumOut(FPNumBase
):
526 """ Floating-point Number Class
528 Contains signals for an incoming copy of the value, decoded into
529 sign / exponent / mantissa.
530 Also contains encoding functions, creation and recognition of
531 zero, NaN and inf (all signed)
533 Four extra bits are included in the mantissa: the top bit
534 (m[-1]) is effectively a carry-overflow. The other three are
535 guard (m[2]), round (m[1]), and sticky (m[0])
538 def __init__(self
, fp
):
539 FPNumBase
.__init
__(self
, fp
)
541 def elaborate(self
, platform
):
542 m
= FPNumBase
.elaborate(self
, platform
)
547 class MultiShiftRMerge(Elaboratable
):
548 """ shifts down (right) and merges lower bits into m[0].
549 m[0] is the "sticky" bit, basically
552 def __init__(self
, width
, s_max
=None):
554 s_max
= int(log(width
) / log(2))
556 self
.m
= Signal(width
, reset_less
=True)
557 self
.inp
= Signal(width
, reset_less
=True)
558 self
.diff
= Signal(s_max
, reset_less
=True)
561 def elaborate(self
, platform
):
564 rs
= Signal(self
.width
, reset_less
=True)
565 m_mask
= Signal(self
.width
, reset_less
=True)
566 smask
= Signal(self
.width
, reset_less
=True)
567 stickybit
= Signal(reset_less
=True)
568 maxslen
= Signal(self
.smax
, reset_less
=True)
569 maxsleni
= Signal(self
.smax
, reset_less
=True)
571 sm
= MultiShift(self
.width
-1)
572 m0s
= Const(0, self
.width
-1)
573 mw
= Const(self
.width
-1, len(self
.diff
))
574 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
575 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
579 # shift mantissa by maxslen, mask by inverse
580 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
581 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
582 smask
.eq(self
.inp
[1:] & m_mask
),
583 # sticky bit combines all mask (and mantissa low bit)
584 stickybit
.eq(smask
.bool() | self
.inp
[0]),
585 # mantissa result contains m[0] already.
586 self
.m
.eq(Cat(stickybit
, rs
))
591 class FPNumShift(FPNumBase
, Elaboratable
):
592 """ Floating-point Number Class for shifting
595 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
596 FPNumBase
.__init
__(self
, width
, m_extra
)
597 self
.latch_in
= Signal()
602 def elaborate(self
, platform
):
603 m
= FPNumBase
.elaborate(self
, platform
)
605 m
.d
.comb
+= self
.s
.eq(op
.s
)
606 m
.d
.comb
+= self
.e
.eq(op
.e
)
607 m
.d
.comb
+= self
.m
.eq(op
.m
)
609 with self
.mainm
.State("align"):
610 with m
.If(self
.e
< self
.inv
.e
):
611 m
.d
.sync
+= self
.shift_down()
615 def shift_down(self
, inp
):
616 """ shifts a mantissa down by one. exponent is increased to compensate
618 accuracy is lost as a result in the mantissa however there are 3
619 guard bits (the latter of which is the "sticky" bit)
621 return [self
.e
.eq(inp
.e
+ 1),
622 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
625 def shift_down_multi(self
, diff
):
626 """ shifts a mantissa down. exponent is increased to compensate
628 accuracy is lost as a result in the mantissa however there are 3
629 guard bits (the latter of which is the "sticky" bit)
631 this code works by variable-shifting the mantissa by up to
632 its maximum bit-length: no point doing more (it'll still be
635 the sticky bit is computed by shifting a batch of 1s by
636 the same amount, which will introduce zeros. it's then
637 inverted and used as a mask to get the LSBs of the mantissa.
638 those are then |'d into the sticky bit.
640 sm
= MultiShift(self
.width
)
641 mw
= Const(self
.m_width
-1, len(diff
))
642 maxslen
= Mux(diff
> mw
, mw
, diff
)
643 rs
= sm
.rshift(self
.m
[1:], maxslen
)
644 maxsleni
= mw
- maxslen
645 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
647 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
648 return [self
.e
.eq(self
.e
+ diff
),
649 self
.m
.eq(Cat(stickybits
, rs
))
652 def shift_up_multi(self
, diff
):
653 """ shifts a mantissa up. exponent is decreased to compensate
655 sm
= MultiShift(self
.width
)
656 mw
= Const(self
.m_width
, len(diff
))
657 maxslen
= Mux(diff
> mw
, mw
, diff
)
659 return [self
.e
.eq(self
.e
- diff
),
660 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
664 class FPNumDecode(FPNumBase
):
665 """ Floating-point Number Class
667 Contains signals for an incoming copy of the value, decoded into
668 sign / exponent / mantissa.
669 Also contains encoding functions, creation and recognition of
670 zero, NaN and inf (all signed)
672 Four extra bits are included in the mantissa: the top bit
673 (m[-1]) is effectively a carry-overflow. The other three are
674 guard (m[2]), round (m[1]), and sticky (m[0])
677 def __init__(self
, op
, fp
):
678 FPNumBase
.__init
__(self
, fp
)
681 def elaborate(self
, platform
):
682 m
= FPNumBase
.elaborate(self
, platform
)
684 m
.d
.comb
+= self
.decode(self
.v
)
689 """ decodes a latched value into sign / exponent / mantissa
691 bias is subtracted here, from the exponent. exponent
692 is extended to 10 bits so that subtract 127 is done on
695 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
696 #print ("decode", self.e_end)
697 return [self
.m
.eq(Cat(*args
)), # mantissa
698 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
699 self
.s
.eq(v
[-1]), # sign
703 class FPNumIn(FPNumBase
):
704 """ Floating-point Number Class
706 Contains signals for an incoming copy of the value, decoded into
707 sign / exponent / mantissa.
708 Also contains encoding functions, creation and recognition of
709 zero, NaN and inf (all signed)
711 Four extra bits are included in the mantissa: the top bit
712 (m[-1]) is effectively a carry-overflow. The other three are
713 guard (m[2]), round (m[1]), and sticky (m[0])
716 def __init__(self
, op
, fp
):
717 FPNumBase
.__init
__(self
, fp
)
718 self
.latch_in
= Signal()
721 def decode2(self
, m
):
722 """ decodes a latched value into sign / exponent / mantissa
724 bias is subtracted here, from the exponent. exponent
725 is extended to 10 bits so that subtract 127 is done on
729 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
730 #print ("decode", self.e_end)
731 res
= ObjectProxy(m
, pipemode
=False)
732 res
.m
= Cat(*args
) # mantissa
733 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
738 """ decodes a latched value into sign / exponent / mantissa
740 bias is subtracted here, from the exponent. exponent
741 is extended to 10 bits so that subtract 127 is done on
744 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
745 #print ("decode", self.e_end)
746 return [self
.m
.eq(Cat(*args
)), # mantissa
747 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
748 self
.s
.eq(v
[-1]), # sign
751 def shift_down(self
, inp
):
752 """ shifts a mantissa down by one. exponent is increased to compensate
754 accuracy is lost as a result in the mantissa however there are 3
755 guard bits (the latter of which is the "sticky" bit)
757 return [self
.e
.eq(inp
.e
+ 1),
758 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
761 def shift_down_multi(self
, diff
, inp
=None):
762 """ shifts a mantissa down. exponent is increased to compensate
764 accuracy is lost as a result in the mantissa however there are 3
765 guard bits (the latter of which is the "sticky" bit)
767 this code works by variable-shifting the mantissa by up to
768 its maximum bit-length: no point doing more (it'll still be
771 the sticky bit is computed by shifting a batch of 1s by
772 the same amount, which will introduce zeros. it's then
773 inverted and used as a mask to get the LSBs of the mantissa.
774 those are then |'d into the sticky bit.
778 sm
= MultiShift(self
.width
)
779 mw
= Const(self
.m_width
-1, len(diff
))
780 maxslen
= Mux(diff
> mw
, mw
, diff
)
781 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
782 maxsleni
= mw
- maxslen
783 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
785 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
786 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
787 return [self
.e
.eq(inp
.e
+ diff
),
788 self
.m
.eq(Cat(stickybit
, rs
))
791 def shift_up_multi(self
, diff
):
792 """ shifts a mantissa up. exponent is decreased to compensate
794 sm
= MultiShift(self
.width
)
795 mw
= Const(self
.m_width
, len(diff
))
796 maxslen
= Mux(diff
> mw
, mw
, diff
)
798 return [self
.e
.eq(self
.e
- diff
),
799 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
803 class Trigger(Elaboratable
):
806 self
.stb
= Signal(reset
=0)
808 self
.trigger
= Signal(reset_less
=True)
810 def elaborate(self
, platform
):
812 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
816 return [self
.stb
.eq(inp
.stb
),
821 return [self
.stb
, self
.ack
]
824 class FPOpIn(PrevControl
):
825 def __init__(self
, width
):
826 PrevControl
.__init
__(self
)
833 def chain_inv(self
, in_op
, extra
=None):
835 if extra
is not None:
837 return [self
.v
.eq(in_op
.v
), # receive value
838 self
.stb
.eq(stb
), # receive STB
839 in_op
.ack
.eq(~self
.ack
), # send ACK
842 def chain_from(self
, in_op
, extra
=None):
844 if extra
is not None:
846 return [self
.v
.eq(in_op
.v
), # receive value
847 self
.stb
.eq(stb
), # receive STB
848 in_op
.ack
.eq(self
.ack
), # send ACK
852 class FPOpOut(NextControl
):
853 def __init__(self
, width
):
854 NextControl
.__init
__(self
)
861 def chain_inv(self
, in_op
, extra
=None):
863 if extra
is not None:
865 return [self
.v
.eq(in_op
.v
), # receive value
866 self
.stb
.eq(stb
), # receive STB
867 in_op
.ack
.eq(~self
.ack
), # send ACK
870 def chain_from(self
, in_op
, extra
=None):
872 if extra
is not None:
874 return [self
.v
.eq(in_op
.v
), # receive value
875 self
.stb
.eq(stb
), # receive STB
876 in_op
.ack
.eq(self
.ack
), # send ACK
881 def __init__(self
, name
=None):
884 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
885 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
886 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
887 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
889 #self.roundz = Signal(reset_less=True)
898 return [self
.guard
.eq(inp
.guard
),
899 self
.round_bit
.eq(inp
.round_bit
),
900 self
.sticky
.eq(inp
.sticky
),
905 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
908 class OverflowMod(Elaboratable
, Overflow
):
909 def __init__(self
, name
=None):
910 Overflow
.__init
__(self
, name
)
913 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
916 yield from Overflow
.__iter
__(self
)
917 yield self
.roundz_out
920 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
922 def elaborate(self
, platform
):
924 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
)
929 """ IEEE754 Floating Point Base Class
931 contains common functions for FP manipulation, such as
932 extracting and packing operands, normalisation, denormalisation,
936 def get_op(self
, m
, op
, v
, next_state
):
937 """ this function moves to the next state and copies the operand
938 when both stb and ack are 1.
939 acknowledgement is sent by setting ack to ZERO.
943 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
945 # op is latched in from FPNumIn class on same ack/stb
946 m
.d
.comb
+= ack
.eq(0)
948 m
.d
.comb
+= ack
.eq(1)
951 def denormalise(self
, m
, a
):
952 """ denormalises a number. this is probably the wrong name for
953 this function. for normalised numbers (exponent != minimum)
954 one *extra* bit (the implicit 1) is added *back in*.
955 for denormalised numbers, the mantissa is left alone
956 and the exponent increased by 1.
958 both cases *effectively multiply the number stored by 2*,
959 which has to be taken into account when extracting the result.
961 with m
.If(a
.exp_n127
):
962 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
964 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
966 def op_normalise(self
, m
, op
, next_state
):
967 """ operand normalisation
968 NOTE: just like "align", this one keeps going round every clock
969 until the result's exponent is within acceptable "range"
971 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
973 op
.e
.eq(op
.e
- 1), # DECREASE exponent
974 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
979 def normalise_1(self
, m
, z
, of
, next_state
):
980 """ first stage normalisation
982 NOTE: just like "align", this one keeps going round every clock
983 until the result's exponent is within acceptable "range"
984 NOTE: the weirdness of reassigning guard and round is due to
985 the extra mantissa bits coming from tot[0..2]
987 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
989 z
.e
.eq(z
.e
- 1), # DECREASE exponent
990 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
991 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
992 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
993 of
.round_bit
.eq(0), # reset round bit
999 def normalise_2(self
, m
, z
, of
, next_state
):
1000 """ second stage normalisation
1002 NOTE: just like "align", this one keeps going round every clock
1003 until the result's exponent is within acceptable "range"
1004 NOTE: the weirdness of reassigning guard and round is due to
1005 the extra mantissa bits coming from tot[0..2]
1007 with m
.If(z
.e
< z
.fp
.N126
):
1009 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1010 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1011 of
.guard
.eq(z
.m
[0]),
1013 of
.round_bit
.eq(of
.guard
),
1014 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1019 def roundz(self
, m
, z
, roundz
):
1020 """ performs rounding on the output. TODO: different kinds of rounding
1023 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1024 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1025 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1027 def corrections(self
, m
, z
, next_state
):
1028 """ denormalisation and sign-bug corrections
1031 # denormalised, correct exponent to zero
1032 with m
.If(z
.is_denormalised
):
1033 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1035 def pack(self
, m
, z
, next_state
):
1036 """ packs the result into the output (detects overflow->Inf)
1039 # if overflow occurs, return inf
1040 with m
.If(z
.is_overflowed
):
1041 m
.d
.sync
+= z
.inf(z
.s
)
1043 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1045 def put_z(self
, m
, z
, out_z
, next_state
):
1046 """ put_z: stores the result in the output. raises stb and waits
1047 for ack to be set to 1 before moving to the next state.
1048 resets stb back to zero when that occurs, as acknowledgement.
1053 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1054 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1057 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1060 class FPState(FPBase
):
1061 def __init__(self
, state_from
):
1062 self
.state_from
= state_from
1064 def set_inputs(self
, inputs
):
1065 self
.inputs
= inputs
1066 for k
, v
in inputs
.items():
1069 def set_outputs(self
, outputs
):
1070 self
.outputs
= outputs
1071 for k
, v
in outputs
.items():
1076 def __init__(self
, id_wid
):
1077 self
.id_wid
= id_wid
1079 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1080 self
.out_mid
= Signal(id_wid
, reset_less
=True)
def idsync(self, m):
    """ Copy ``in_mid`` to ``out_mid`` in the sync domain of ``m``.

    Nothing is added when ``id_wid`` is None.
    """
    if self.id_wid is None:
        return
    m.d.sync += self.out_mid.eq(self.in_mid)
1090 if __name__
== '__main__':