1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019 Jake Lifshay
9 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
11 from operator
import or_
12 from functools
import reduce
14 from nmutil
.singlepipe
import PrevControl
, NextControl
15 from nmutil
.pipeline
import ObjectProxy
21 """ Class describing binary floating-point formats based on IEEE 754.
23 :attribute e_width: the number of bits in the exponent field.
24 :attribute m_width: the number of bits stored in the mantissa
26 :attribute has_int_bit: if the FP format has an explicit integer bit (like
27 the x87 80-bit format). The bit is considered part of the mantissa.
28 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
29 Image/Buffer formats are FP numbers without a sign bit.)
37 """ Create ``FPFormat`` instance. """
38 self
.e_width
= e_width
39 self
.m_width
= m_width
40 self
.has_int_bit
= has_int_bit
41 self
.has_sign
= has_sign
43 def __eq__(self
, other
):
44 """ Check for equality. """
45 if not isinstance(other
, FPFormat
):
47 return (self
.e_width
== other
.e_width
48 and self
.m_width
== other
.m_width
49 and self
.has_int_bit
== other
.has_int_bit
50 and self
.has_sign
== other
.has_sign
)
54 """ Get standard IEEE 754-2008 format.
56 :param width: bit-width of requested format.
57 :returns: the requested ``FPFormat`` instance.
60 return FPFormat(5, 10)
62 return FPFormat(8, 23)
64 return FPFormat(11, 52)
66 return FPFormat(15, 112)
67 if width
> 128 and width
% 32 == 0:
68 if width
> 1000000: # arbitrary upper limit
69 raise ValueError("width too big")
70 e_width
= round(4 * math
.log2(width
)) - 13
71 return FPFormat(e_width
, width
- 1 - e_width
)
72 raise ValueError("width must be the bit-width of a valid IEEE"
73 " 754-2008 binary format")
78 if self
== self
.standard(self
.width
):
79 return f
"FPFormat.standard({self.width})"
82 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
83 if self
.has_int_bit
is not False:
84 retval
+= f
", {self.has_int_bit}"
85 if self
.has_sign
is not True:
86 retval
+= f
", {self.has_sign}"
def get_sign_field(self, x):
    """ Extract the sign field from the raw bit pattern ``x``.

    ``x`` is shifted right past the mantissa and exponent fields and
    whatever remains is returned unmasked, so the result is a single
    bit only when ``x`` has no bits set above the sign position.
    (assumes FPFormat is set to signed - has_sign=True)

    :param x: raw encoded number.
    :returns: ``x`` shifted right by ``e_width + m_width`` bits.
    """
    sign_position = self.e_width + self.m_width
    return x >> sign_position
95 def get_exponent_field(self
, x
):
96 """ returns the raw exponent of its input number, x (no bias subtracted)
98 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
def get_exponent(self, x):
    """ Return the unbiased exponent of the raw bit pattern ``x``.

    The raw exponent field is fetched with ``get_exponent_field`` and
    ``exponent_bias`` is subtracted from it.

    :param x: raw encoded number.
    :returns: exponent field of ``x`` minus the bias.
    """
    raw_exponent = self.get_exponent_field(x)
    return raw_exponent - self.exponent_bias
def get_mantissa_field(self, x):
    """ Return the mantissa field of the raw bit pattern ``x``.

    :param x: raw encoded number.
    :returns: ``x`` AND-ed with ``mantissa_mask``.
    """
    mask = self.mantissa_mask
    return x & mask
def is_zero(self, x):
    """ Return true if ``x`` is +/- zero.

    That is: the unbiased exponent equals ``e_sub`` and the mantissa
    field is zero. The sign field is not examined.

    :param x: raw encoded number.
    """
    # exponent must sit at the subnormal/zero value...
    if self.get_exponent(x) != self.e_sub:
        return False
    # ...and no mantissa bit may be set
    return self.get_mantissa_field(x) == 0
def is_subnormal(self, x):
    """ Return true if ``x`` is subnormal.

    That is: the unbiased exponent equals ``e_sub`` (exp at minimum)
    and the mantissa field is non-zero.

    :param x: raw encoded number.
    """
    # exponent must sit at the subnormal/zero value...
    if self.get_exponent(x) != self.e_sub:
        return False
    # ...and at least one mantissa bit must be set (otherwise it is zero)
    return self.get_mantissa_field(x) != 0
124 """ returns true if x is infinite
126 return (self
.get_exponent(x
) == self
.e_max
and
127 self
.get_mantissa_field(x
) == 0)
130 """ returns true if x is a nan (quiet or signalling)
132 return (self
.get_exponent(x
) == self
.e_max
and
133 self
.get_mantissa_field(x
) != 0)
def is_quiet_nan(self, x):
    """ Return true if ``x`` is a quiet nan.

    That is: the unbiased exponent equals ``e_max``, the mantissa field
    is non-zero, and bit ``m_width - 1`` of the mantissa field is set.

    :param x: raw encoded number.
    """
    quiet_bit = 1 << (self.m_width - 1)
    if self.get_exponent(x) != self.e_max:
        return False
    mantissa = self.get_mantissa_field(x)
    if mantissa == 0:
        # all-zero mantissa with this exponent is not a nan
        return False
    return (mantissa & quiet_bit) != 0
def is_nan_signaling(self, x):
    """ Return true if ``x`` is a signalling nan.

    That is: the unbiased exponent equals ``e_max``, the mantissa field
    is non-zero, and bit ``m_width - 1`` of the mantissa field is clear.

    :param x: raw encoded number.
    """
    quiet_bit = 1 << (self.m_width - 1)
    if self.get_exponent(x) != self.e_max:
        return False
    mantissa = self.get_mantissa_field(x)
    if mantissa == 0:
        # all-zero mantissa with this exponent is not a nan
        return False
    return (mantissa & quiet_bit) == 0
153 """ Get the total number of bits in the FP format. """
154 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """ Get a mask with the low ``m_width`` bits set (the mantissa field). """
    field_limit = 1 << self.m_width
    return field_limit - 1
def exponent_inf_nan(self):
    """ Get the exponent-field value designating infinity/NaN.

    This is the all-ones pattern of ``e_width`` bits.
    """
    field_limit = 1 << self.e_width
    return field_limit - 1
168 """ get the maximum exponent (minus bias)
170 return self
.exponent_inf_nan
- self
.exponent_bias
174 return self
.exponent_denormal_zero
- self
.exponent_bias
176 def exponent_denormal_zero(self
):
177 """ Get the value of the exponent field designating denormal/zero. """
181 def exponent_min_normal(self
):
182 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """ Get the largest exponent-field value used by normal numbers.

    It is one below the value reserved for infinity/NaN.
    """
    reserved = self.exponent_inf_nan
    return reserved - 1
def exponent_bias(self):
    """ Get the exponent bias: ``2**(e_width - 1) - 1``. """
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """ Get the number of mantissa bits that are fraction bits.

    ``has_int_bit`` is subtracted from ``m_width``, so an explicit
    integer bit (when present) is not counted.
    """
    stored_bits = self.m_width
    return stored_bits - self.has_int_bit
201 class TestFPFormat(unittest
.TestCase
):
202 """ very quick test for FPFormat
205 def test_fpformat_fp64(self
):
206 f64
= FPFormat
.standard(64)
207 from sfpy
import Float64
208 x
= Float64(1.0).bits
211 self
.assertEqual(f64
.get_exponent(x
), 0)
212 x
= Float64(2.0).bits
214 self
.assertEqual(f64
.get_exponent(x
), 1)
216 x
= Float64(1.5).bits
217 m
= f64
.get_mantissa_field(x
)
218 print (hex(x
), hex(m
))
219 self
.assertEqual(m
, 0x8000000000000)
221 s
= f64
.get_sign_field(x
)
222 print (hex(x
), hex(s
))
223 self
.assertEqual(s
, 0)
225 x
= Float64(-1.5).bits
226 s
= f64
.get_sign_field(x
)
227 print (hex(x
), hex(s
))
228 self
.assertEqual(s
, 1)
230 def test_fpformat_fp32(self
):
231 f32
= FPFormat
.standard(32)
232 from sfpy
import Float32
233 x
= Float32(1.0).bits
236 self
.assertEqual(f32
.get_exponent(x
), 0)
237 x
= Float32(2.0).bits
239 self
.assertEqual(f32
.get_exponent(x
), 1)
241 x
= Float32(1.5).bits
242 m
= f32
.get_mantissa_field(x
)
243 print (hex(x
), hex(m
))
244 self
.assertEqual(m
, 0x400000)
247 x
= Float32(-1.0).sqrt()
250 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
251 f32
.get_mantissa_field(x
), i
)
252 self
.assertEqual(i
, True)
255 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
258 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
259 f32
.get_mantissa_field(x
), i
)
260 self
.assertEqual(i
, True)
265 i
= f32
.is_subnormal(x
)
266 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
267 f32
.get_mantissa_field(x
), i
)
268 self
.assertEqual(i
, True)
272 i
= f32
.is_subnormal(x
)
273 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
274 f32
.get_mantissa_field(x
), i
)
275 self
.assertEqual(i
, False)
279 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
280 f32
.get_mantissa_field(x
), i
)
281 self
.assertEqual(i
, True)
286 def __init__(self
, width
):
288 self
.smax
= int(log(width
) / log(2))
289 self
.i
= Signal(width
, reset_less
=True)
290 self
.s
= Signal(self
.smax
, reset_less
=True)
291 self
.o
= Signal(width
, reset_less
=True)
293 def elaborate(self
, platform
):
295 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
300 """ Generates variable-length single-cycle shifter from a series
301 of conditional tests on each bit of the left/right shift operand.
302 Each bit tested produces output shifted by that number of bits,
303 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
304 shifts by 2 bits, each partial result cascading to the next Mux.
306 Could be adapted to do arithmetic shift by taking copies of the
307 MSB instead of zeros.
310 def __init__(self
, width
):
312 self
.smax
= int(log(width
) / log(2))
314 def lshift(self
, op
, s
):
318 def rshift(self
, op
, s
):
323 class FPNumBaseRecord
:
324 """ Floating-point Base Number Class.
326 This class is designed to be passed around in other data structures
327 (between pipelines and between stages). Its "friend" is FPNumBase,
328 which is a *module*. The reason for the discernment is because
329 nmigen modules that are not added to submodules results in the
330 irritating "Elaboration" warning. Despite not *needing* FPNumBase
331 in many cases to be added as a submodule (because it is just data)
332 this was not possible to solve without splitting out the data from
336 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
339 # assert false, "missing name"
343 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
344 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
345 e_max
= 1 << (e_width
-3)
346 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
349 # mantissa extra bits (top,guard,round)
351 m_width
+= self
.m_extra
355 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
356 e_width
+= self
.e_extra
359 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
360 self
.m_width
= m_width
361 self
.e_width
= e_width
362 self
.e_start
= self
.rmw
363 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
365 self
.v
= Signal(width
, reset_less
=True,
366 name
=name
+"v") # Latched copy of value
367 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
368 self
.e
= Signal((e_width
, True),
369 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
370 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
375 def drop_in(self
, fp
):
381 fp
.width
= self
.width
382 fp
.e_width
= self
.e_width
383 fp
.e_max
= self
.e_max
384 fp
.m_width
= self
.m_width
385 fp
.e_start
= self
.e_start
386 fp
.e_end
= self
.e_end
387 fp
.m_extra
= self
.m_extra
389 m_width
= self
.m_width
391 e_width
= self
.e_width
393 self
.mzero
= Const(0, (m_width
, False))
394 m_msb
= 1 << (self
.m_width
-2)
395 self
.msb1
= Const(m_msb
, (m_width
, False))
396 self
.m1s
= Const(-1, (m_width
, False))
397 self
.P128
= Const(e_max
, (e_width
, True))
398 self
.P127
= Const(e_max
-1, (e_width
, True))
399 self
.N127
= Const(-(e_max
-1), (e_width
, True))
400 self
.N126
= Const(-(e_max
-2), (e_width
, True))
402 def create(self
, s
, e
, m
):
403 """ creates a value from sign / exponent / mantissa
405 bias is added here, to the exponent.
407 NOTE: order is important, because e_start/e_end can be
408 a bit too long (overwriting s).
411 self
.v
[0:self
.e_start
].eq(m
), # mantissa
412 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
413 self
.v
[-1].eq(s
), # sign
417 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
420 return (s
, self
.fp
.P128
, 0)
423 return (s
, self
.fp
.N127
, 0)
426 return self
.create(*self
._nan
(s
))
429 return self
.create(*self
._inf
(s
))
432 return self
.create(*self
._zero
(s
))
434 def create2(self
, s
, e
, m
):
435 """ creates a value from sign / exponent / mantissa
437 bias is added here, to the exponent
439 e
= e
+ self
.P127
# exp (add on bias)
440 return Cat(m
[0:self
.e_start
],
441 e
[0:self
.e_end
-self
.e_start
],
445 return self
.create2(s
, self
.P128
, self
.msb1
)
448 return self
.create2(s
, self
.P128
, self
.mzero
)
451 return self
.create2(s
, self
.N127
, self
.mzero
)
459 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
462 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
463 """ Floating-point Base Number Class
466 def __init__(self
, fp
):
471 self
.is_nan
= Signal(reset_less
=True)
472 self
.is_zero
= Signal(reset_less
=True)
473 self
.is_inf
= Signal(reset_less
=True)
474 self
.is_overflowed
= Signal(reset_less
=True)
475 self
.is_denormalised
= Signal(reset_less
=True)
476 self
.exp_128
= Signal(reset_less
=True)
477 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
478 self
.exp_lt_n126
= Signal(reset_less
=True)
479 self
.exp_zero
= Signal(reset_less
=True)
480 self
.exp_gt_n126
= Signal(reset_less
=True)
481 self
.exp_gt127
= Signal(reset_less
=True)
482 self
.exp_n127
= Signal(reset_less
=True)
483 self
.exp_n126
= Signal(reset_less
=True)
484 self
.m_zero
= Signal(reset_less
=True)
485 self
.m_msbzero
= Signal(reset_less
=True)
487 def elaborate(self
, platform
):
489 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
490 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
491 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
492 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
493 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
494 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
495 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
496 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
497 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
498 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
499 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
500 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
501 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
502 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
503 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
508 return (self
.exp_128
) & (~self
.m_zero
)
511 return (self
.exp_128
) & (self
.m_zero
)
514 return (self
.exp_n127
) & (self
.m_zero
)
516 def _is_overflowed(self
):
517 return self
.exp_gt127
519 def _is_denormalised(self
):
520 # XXX NOT to be used for "official" quiet NaN tests!
521 # particularly when the MSB has been extended
522 return (self
.exp_n126
) & (self
.m_msbzero
)
525 class FPNumOut(FPNumBase
):
526 """ Floating-point Number Class
528 Contains signals for an incoming copy of the value, decoded into
529 sign / exponent / mantissa.
530 Also contains encoding functions, creation and recognition of
531 zero, NaN and inf (all signed)
533 Four extra bits are included in the mantissa: the top bit
534 (m[-1]) is effectively a carry-overflow. The other three are
535 guard (m[2]), round (m[1]), and sticky (m[0])
538 def __init__(self
, fp
):
539 FPNumBase
.__init
__(self
, fp
)
541 def elaborate(self
, platform
):
542 m
= FPNumBase
.elaborate(self
, platform
)
547 class MultiShiftRMerge(Elaboratable
):
548 """ shifts down (right) and merges lower bits into m[0].
549 m[0] is the "sticky" bit, basically
552 def __init__(self
, width
, s_max
=None):
554 s_max
= int(log(width
) / log(2))
556 self
.m
= Signal(width
, reset_less
=True)
557 self
.inp
= Signal(width
, reset_less
=True)
558 self
.diff
= Signal(s_max
, reset_less
=True)
561 def elaborate(self
, platform
):
564 rs
= Signal(self
.width
, reset_less
=True)
565 m_mask
= Signal(self
.width
, reset_less
=True)
566 smask
= Signal(self
.width
, reset_less
=True)
567 stickybit
= Signal(reset_less
=True)
568 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
569 maxslen
= Signal(self
.smax
[0], reset_less
=True)
570 maxsleni
= Signal(self
.smax
[0], reset_less
=True)
572 sm
= MultiShift(self
.width
-1)
573 m0s
= Const(0, self
.width
-1)
574 mw
= Const(self
.width
-1, len(self
.diff
))
575 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
576 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
580 # shift mantissa by maxslen, mask by inverse
581 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
582 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
583 smask
.eq(self
.inp
[1:] & m_mask
),
584 # sticky bit combines all mask (and mantissa low bit)
585 stickybit
.eq(smask
.bool() | self
.inp
[0]),
586 # mantissa result contains m[0] already.
587 self
.m
.eq(Cat(stickybit
, rs
))
592 class FPNumShift(FPNumBase
, Elaboratable
):
593 """ Floating-point Number Class for shifting
596 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
597 FPNumBase
.__init
__(self
, width
, m_extra
)
598 self
.latch_in
= Signal()
603 def elaborate(self
, platform
):
604 m
= FPNumBase
.elaborate(self
, platform
)
606 m
.d
.comb
+= self
.s
.eq(op
.s
)
607 m
.d
.comb
+= self
.e
.eq(op
.e
)
608 m
.d
.comb
+= self
.m
.eq(op
.m
)
610 with self
.mainm
.State("align"):
611 with m
.If(self
.e
< self
.inv
.e
):
612 m
.d
.sync
+= self
.shift_down()
def shift_down(self, inp):
    """ Build the assignments that shift a mantissa down by one.

    The exponent is increased by one to compensate. The two lowest
    mantissa bits of ``inp`` are OR-ed together into the new bit 0
    (the "sticky" bit), the remaining bits move down one place and a
    zero is inserted at the top.

    :param inp: object providing the source ``e`` and ``m``.
    :returns: list of two assignments, to ``self.e`` and to ``self.m``.
    """
    exp_stmt = self.e.eq(inp.e + 1)
    merged_low = inp.m[0] | inp.m[1]
    man_stmt = self.m.eq(Cat(merged_low, inp.m[2:], 0))
    return [exp_stmt, man_stmt]
626 def shift_down_multi(self
, diff
):
627 """ shifts a mantissa down. exponent is increased to compensate
629 accuracy is lost as a result in the mantissa however there are 3
630 guard bits (the latter of which is the "sticky" bit)
632 this code works by variable-shifting the mantissa by up to
633 its maximum bit-length: no point doing more (it'll still be
636 the sticky bit is computed by shifting a batch of 1s by
637 the same amount, which will introduce zeros. it's then
638 inverted and used as a mask to get the LSBs of the mantissa.
639 those are then |'d into the sticky bit.
641 sm
= MultiShift(self
.width
)
642 mw
= Const(self
.m_width
-1, len(diff
))
643 maxslen
= Mux(diff
> mw
, mw
, diff
)
644 rs
= sm
.rshift(self
.m
[1:], maxslen
)
645 maxsleni
= mw
- maxslen
646 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
648 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
649 return [self
.e
.eq(self
.e
+ diff
),
650 self
.m
.eq(Cat(stickybits
, rs
))
def shift_up_multi(self, diff):
    """ Build the assignments that shift the mantissa up by ``diff``.

    The exponent is decreased by ``diff`` to compensate. The shift
    amount applied to the mantissa is limited to ``m_width``; the
    exponent adjustment uses ``diff`` as given.

    :param diff: shift amount (must support ``len()``).
    :returns: list of two assignments, to ``self.e`` and to ``self.m``.
    """
    shifter = MultiShift(self.width)
    limit = Const(self.m_width, len(diff))
    clamped = Mux(diff > limit, limit, diff)

    exp_stmt = self.e.eq(self.e - diff)
    man_stmt = self.m.eq(shifter.lshift(self.m, clamped))
    return [exp_stmt, man_stmt]
665 class FPNumDecode(FPNumBase
):
666 """ Floating-point Number Class
668 Contains signals for an incoming copy of the value, decoded into
669 sign / exponent / mantissa.
670 Also contains encoding functions, creation and recognition of
671 zero, NaN and inf (all signed)
673 Four extra bits are included in the mantissa: the top bit
674 (m[-1]) is effectively a carry-overflow. The other three are
675 guard (m[2]), round (m[1]), and sticky (m[0])
678 def __init__(self
, op
, fp
):
679 FPNumBase
.__init
__(self
, fp
)
682 def elaborate(self
, platform
):
683 m
= FPNumBase
.elaborate(self
, platform
)
685 m
.d
.comb
+= self
.decode(self
.v
)
690 """ decodes a latched value into sign / exponent / mantissa
692 bias is subtracted here, from the exponent. exponent
693 is extended to 10 bits so that subtract 127 is done on
696 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
697 #print ("decode", self.e_end)
698 return [self
.m
.eq(Cat(*args
)), # mantissa
699 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
700 self
.s
.eq(v
[-1]), # sign
704 class FPNumIn(FPNumBase
):
705 """ Floating-point Number Class
707 Contains signals for an incoming copy of the value, decoded into
708 sign / exponent / mantissa.
709 Also contains encoding functions, creation and recognition of
710 zero, NaN and inf (all signed)
712 Four extra bits are included in the mantissa: the top bit
713 (m[-1]) is effectively a carry-overflow. The other three are
714 guard (m[2]), round (m[1]), and sticky (m[0])
717 def __init__(self
, op
, fp
):
718 FPNumBase
.__init
__(self
, fp
)
719 self
.latch_in
= Signal()
722 def decode2(self
, m
):
723 """ decodes a latched value into sign / exponent / mantissa
725 bias is subtracted here, from the exponent. exponent
726 is extended to 10 bits so that subtract 127 is done on
730 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
731 #print ("decode", self.e_end)
732 res
= ObjectProxy(m
, pipemode
=False)
733 res
.m
= Cat(*args
) # mantissa
734 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
739 """ decodes a latched value into sign / exponent / mantissa
741 bias is subtracted here, from the exponent. exponent
742 is extended to 10 bits so that subtract 127 is done on
745 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
746 #print ("decode", self.e_end)
747 return [self
.m
.eq(Cat(*args
)), # mantissa
748 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
749 self
.s
.eq(v
[-1]), # sign
def shift_down(self, inp):
    """ Build the assignments that shift a mantissa down by one.

    The exponent is increased by one to compensate. The two lowest
    mantissa bits of ``inp`` are OR-ed together into the new bit 0
    (the "sticky" bit), the remaining bits move down one place and a
    zero is inserted at the top.

    :param inp: object providing the source ``e`` and ``m``.
    :returns: list of two assignments, to ``self.e`` and to ``self.m``.
    """
    exp_stmt = self.e.eq(inp.e + 1)
    merged_low = inp.m[0] | inp.m[1]
    man_stmt = self.m.eq(Cat(merged_low, inp.m[2:], 0))
    return [exp_stmt, man_stmt]
762 def shift_down_multi(self
, diff
, inp
=None):
763 """ shifts a mantissa down. exponent is increased to compensate
765 accuracy is lost as a result in the mantissa however there are 3
766 guard bits (the latter of which is the "sticky" bit)
768 this code works by variable-shifting the mantissa by up to
769 its maximum bit-length: no point doing more (it'll still be
772 the sticky bit is computed by shifting a batch of 1s by
773 the same amount, which will introduce zeros. it's then
774 inverted and used as a mask to get the LSBs of the mantissa.
775 those are then |'d into the sticky bit.
779 sm
= MultiShift(self
.width
)
780 mw
= Const(self
.m_width
-1, len(diff
))
781 maxslen
= Mux(diff
> mw
, mw
, diff
)
782 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
783 maxsleni
= mw
- maxslen
784 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
786 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
787 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
788 return [self
.e
.eq(inp
.e
+ diff
),
789 self
.m
.eq(Cat(stickybit
, rs
))
def shift_up_multi(self, diff):
    """ Build the assignments that shift the mantissa up by ``diff``.

    The exponent is decreased by ``diff`` to compensate. The shift
    amount applied to the mantissa is limited to ``m_width``; the
    exponent adjustment uses ``diff`` as given.

    :param diff: shift amount (must support ``len()``).
    :returns: list of two assignments, to ``self.e`` and to ``self.m``.
    """
    shifter = MultiShift(self.width)
    limit = Const(self.m_width, len(diff))
    clamped = Mux(diff > limit, limit, diff)

    exp_stmt = self.e.eq(self.e - diff)
    man_stmt = self.m.eq(shifter.lshift(self.m, clamped))
    return [exp_stmt, man_stmt]
804 class Trigger(Elaboratable
):
807 self
.stb
= Signal(reset
=0)
809 self
.trigger
= Signal(reset_less
=True)
811 def elaborate(self
, platform
):
813 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
817 return [self
.stb
.eq(inp
.stb
),
822 return [self
.stb
, self
.ack
]
825 class FPOpIn(PrevControl
):
826 def __init__(self
, width
):
827 PrevControl
.__init
__(self
)
834 def chain_inv(self
, in_op
, extra
=None):
836 if extra
is not None:
838 return [self
.v
.eq(in_op
.v
), # receive value
839 self
.stb
.eq(stb
), # receive STB
840 in_op
.ack
.eq(~self
.ack
), # send ACK
843 def chain_from(self
, in_op
, extra
=None):
845 if extra
is not None:
847 return [self
.v
.eq(in_op
.v
), # receive value
848 self
.stb
.eq(stb
), # receive STB
849 in_op
.ack
.eq(self
.ack
), # send ACK
853 class FPOpOut(NextControl
):
854 def __init__(self
, width
):
855 NextControl
.__init
__(self
)
862 def chain_inv(self
, in_op
, extra
=None):
864 if extra
is not None:
866 return [self
.v
.eq(in_op
.v
), # receive value
867 self
.stb
.eq(stb
), # receive STB
868 in_op
.ack
.eq(~self
.ack
), # send ACK
871 def chain_from(self
, in_op
, extra
=None):
873 if extra
is not None:
875 return [self
.v
.eq(in_op
.v
), # receive value
876 self
.stb
.eq(stb
), # receive STB
877 in_op
.ack
.eq(self
.ack
), # send ACK
882 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
883 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
884 FFLAGS_OF
= Const(1<<2, 5) # overflow
885 FFLAGS_UF
= Const(1<<1, 5) # underflow
886 FFLAGS_NX
= Const(1<<0, 5) # inexact
887 def __init__(self
, name
=None):
890 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
891 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
892 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
893 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
894 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
896 #self.roundz = Signal(reset_less=True)
906 return [self
.guard
.eq(inp
.guard
),
907 self
.round_bit
.eq(inp
.round_bit
),
908 self
.sticky
.eq(inp
.sticky
),
910 self
.fpflags
.eq(inp
.fpflags
)]
914 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
917 class OverflowMod(Elaboratable
, Overflow
):
918 def __init__(self
, name
=None):
919 Overflow
.__init
__(self
, name
)
922 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
925 yield from Overflow
.__iter
__(self
)
926 yield self
.roundz_out
929 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
931 def elaborate(self
, platform
):
933 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
) # roundz is a property
938 """ IEEE754 Floating Point Base Class
940 contains common functions for FP manipulation, such as
941 extracting and packing operands, normalisation, denormalisation,
945 def get_op(self
, m
, op
, v
, next_state
):
946 """ this function moves to the next state and copies the operand
947 when both stb and ack are 1.
948 acknowledgement is sent by setting ack to ZERO.
952 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
954 # op is latched in from FPNumIn class on same ack/stb
955 m
.d
.comb
+= ack
.eq(0)
957 m
.d
.comb
+= ack
.eq(1)
960 def denormalise(self
, m
, a
):
961 """ denormalises a number. this is probably the wrong name for
962 this function. for normalised numbers (exponent != minimum)
963 one *extra* bit (the implicit 1) is added *back in*.
964 for denormalised numbers, the mantissa is left alone
965 and the exponent increased by 1.
967 both cases *effectively multiply the number stored by 2*,
968 which has to be taken into account when extracting the result.
970 with m
.If(a
.exp_n127
):
971 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
973 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
975 def op_normalise(self
, m
, op
, next_state
):
976 """ operand normalisation
977 NOTE: just like "align", this one keeps going round every clock
978 until the result's exponent is within acceptable "range"
980 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
982 op
.e
.eq(op
.e
- 1), # DECREASE exponent
983 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
988 def normalise_1(self
, m
, z
, of
, next_state
):
989 """ first stage normalisation
991 NOTE: just like "align", this one keeps going round every clock
992 until the result's exponent is within acceptable "range"
993 NOTE: the weirdness of reassigning guard and round is due to
994 the extra mantissa bits coming from tot[0..2]
996 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
998 z
.e
.eq(z
.e
- 1), # DECREASE exponent
999 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1000 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1001 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1002 of
.round_bit
.eq(0), # reset round bit
1008 def normalise_2(self
, m
, z
, of
, next_state
):
1009 """ second stage normalisation
1011 NOTE: just like "align", this one keeps going round every clock
1012 until the result's exponent is within acceptable "range"
1013 NOTE: the weirdness of reassigning guard and round is due to
1014 the extra mantissa bits coming from tot[0..2]
1016 with m
.If(z
.e
< z
.fp
.N126
):
1018 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1019 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1020 of
.guard
.eq(z
.m
[0]),
1022 of
.round_bit
.eq(of
.guard
),
1023 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1028 def roundz(self
, m
, z
, roundz
):
1029 """ performs rounding on the output. TODO: different kinds of rounding
1032 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1033 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1034 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1036 def corrections(self
, m
, z
, next_state
):
1037 """ denormalisation and sign-bug corrections
1040 # denormalised, correct exponent to zero
1041 with m
.If(z
.is_denormalised
):
1042 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1044 def pack(self
, m
, z
, next_state
):
1045 """ packs the result into the output (detects overflow->Inf)
1048 # if overflow occurs, return inf
1049 with m
.If(z
.is_overflowed
):
1050 m
.d
.sync
+= z
.inf(z
.s
)
1052 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1054 def put_z(self
, m
, z
, out_z
, next_state
):
1055 """ put_z: stores the result in the output. raises stb and waits
1056 for ack to be set to 1 before moving to the next state.
1057 resets stb back to zero when that occurs, as acknowledgement.
1062 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1063 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1066 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1069 class FPState(FPBase
):
1070 def __init__(self
, state_from
):
1071 self
.state_from
= state_from
1073 def set_inputs(self
, inputs
):
1074 self
.inputs
= inputs
1075 for k
, v
in inputs
.items():
1078 def set_outputs(self
, outputs
):
1079 self
.outputs
= outputs
1080 for k
, v
in outputs
.items():
1085 def __init__(self
, id_wid
):
1086 self
.id_wid
= id_wid
1088 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1089 self
.out_mid
= Signal(id_wid
, reset_less
=True)
def idsync(self, m):
    """ Copy ``in_mid`` to ``out_mid`` in the sync domain of ``m``.

    Nothing is added when ``id_wid`` is None.

    :param m: module whose ``d.sync`` receives the assignment.
    """
    if self.id_wid is None:
        return
    m.d.sync += self.out_mid.eq(self.in_mid)
1099 if __name__
== '__main__':