# run tests in parallel
# [ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
"""IEEE754 Floating Point Library

Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Copyright (C) 2019 Jake Lifshay

"""
7
8
import math
import unittest
from functools import reduce
from math import log
from operator import or_

from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable

from nmutil.singlepipe import PrevControl, NextControl
from nmutil.pipeline import ObjectProxy
18
19
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def _key(self):
        """ Tuple of the fields that define the format (for eq / hash). """
        return (self.e_width, self.m_width, self.has_int_bit, self.has_sign)

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        """ Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone would leave instances unhashable.  The
        attributes are plain mutable fields: do not modify an instance
        while it is stored in a set or used as a dict key.
        """
        return hash(self._key())

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a
            multiple of 32 above 128 (and at most 1000000).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats print as ``FPFormat.standard(width)``; anything
        else prints as a constructor call that omits default arguments.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
            if self.has_sign is not True:
                retval += f", {self.has_sign}"
        elif self.has_sign is not True:
            # has_int_bit was omitted, so has_sign must be passed by keyword:
            # a bare third positional value would read as has_int_bit.
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
            (assumes FPFormat is set to signed - has_sign=True)
        """
        # mask to a single bit so stray bits above the format are ignored
        return (x >> (self.e_width + self.m_width)) & 1

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        return (x >> self.m_width) & self.exponent_inf_nan

    def get_exponent(self, x):
        """ returns the exponent of its input number, x (bias subtracted)
        """
        return self.get_exponent_field(x) - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum)
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan (top mantissa bit set)
        """
        highbit = 1 << (self.m_width-1)
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) != 0 and
                self.get_mantissa_field(x) & highbit != 0)

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan (top mantissa bit clear)
        """
        highbit = 1 << (self.m_width-1)
        return ((self.get_exponent(x) == self.e_max) and
                (self.get_mantissa_field(x) != 0) and
                (self.get_mantissa_field(x) & highbit) == 0)

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias)
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the denormal/zero exponent (minus bias)
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
199
200
class TestFPFormat(unittest.TestCase):
    """ Quick sanity checks of FPFormat field extraction, using sfpy
        to produce the reference bit patterns.
    """

    def test_fpformat_fp64(self):
        from sfpy import Float64
        fmt = FPFormat.standard(64)

        # exponent of 1.0 and 2.0
        bits = Float64(1.0).bits
        print (hex(bits))
        self.assertEqual(fmt.get_exponent(bits), 0)

        bits = Float64(2.0).bits
        print (hex(bits))
        self.assertEqual(fmt.get_exponent(bits), 1)

        # mantissa and sign of +1.5
        bits = Float64(1.5).bits
        mantissa = fmt.get_mantissa_field(bits)
        print (hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x8000000000000)

        sign = fmt.get_sign_field(bits)
        print (hex(bits), hex(sign))
        self.assertEqual(sign, 0)

        # sign of -1.5
        bits = Float64(-1.5).bits
        sign = fmt.get_sign_field(bits)
        print (hex(bits), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        from sfpy import Float32
        fmt = FPFormat.standard(32)

        # exponent of 1.0 and 2.0
        bits = Float32(1.0).bits
        print (hex(bits))
        self.assertEqual(fmt.get_exponent(bits), 0)

        bits = Float32(2.0).bits
        print (hex(bits))
        self.assertEqual(fmt.get_exponent(bits), 1)

        # mantissa of 1.5
        bits = Float32(1.5).bits
        mantissa = fmt.get_mantissa_field(bits)
        print (hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x400000)

        # NaN test
        bits = Float32(-1.0).sqrt().bits
        flag = fmt.is_nan(bits)
        print (hex(bits), "nan", fmt.get_exponent(bits), fmt.e_max,
               fmt.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # Inf test
        huge = Float32(1e36) * Float32(1e36) * Float32(1e36)
        bits = huge.bits
        flag = fmt.is_inf(bits)
        print (hex(bits), "inf", fmt.get_exponent(bits), fmt.e_max,
               fmt.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # subnormal
        bits = Float32(1e-41).bits
        flag = fmt.is_subnormal(bits)
        print (hex(bits), "sub", fmt.get_exponent(bits), fmt.e_max,
               fmt.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # zero must not be reported as subnormal
        bits = Float32(0.0).bits
        flag = fmt.is_subnormal(bits)
        print (hex(bits), "sub", fmt.get_exponent(bits), fmt.e_max,
               fmt.get_mantissa_field(bits), flag)
        self.assertEqual(flag, False)

        # zero (same bit pattern as above)
        flag = fmt.is_zero(bits)
        print (hex(bits), "zero", fmt.get_exponent(bits), fmt.e_max,
               fmt.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)
282
283
class MultiShiftR(Elaboratable):
    """ Combinatorial variable right-shifter.

    Inherits from Elaboratable like every other class in this file that
    provides an ``elaborate`` method.

    :attribute i: input value, ``width`` bits.
    :attribute s: shift amount, ``smax`` bits wide.
    :attribute o: output, ``i >> s``, ``width`` bits.
    """

    def __init__(self, width):
        self.width = width
        # number of bits of the shift-amount signal: floor(log2(width))
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ Build the module: o is combinatorially driven with i >> s. """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
297
298
class MultiShift:
    """ Variable-length single-cycle shift helper.

        lshift / rshift shift an operand by a variable amount and cut
        the result back down to the operand's own length.

        NOTE: the shifts are written with the plain << and >> operators;
        no explicit per-bit Mux cascade is constructed in this class.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        # smax is floor(log2(width))
        self.smax = int(log(width) / log(2))
        self.width = width

    def lshift(self, op, s):
        """ op shifted left by s, truncated to len(op) bits """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ op shifted right by s, truncated to len(op) bits """
        return (op >> s)[:len(op)]
321
322
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    This class is designed to be passed around in other data structures
    (between pipelines and between stages).  Its "friend" is FPNumBase,
    which is a *module*.  The reason for the discernment is because
    nmigen modules that are not added to submodules results in the
    irritating "Elaboration" warning.  Despite not *needing* FPNumBase
    in many cases to be added as a submodule (because it is just data)
    this was not possible to solve without splitting out the data from
    the module.

    The constants (mzero, msb1, m1s, P128, P127, N127, N126) are stored
    on the record only; every method here reaches them through
    ``self.fp`` so that the methods also work when inherited by a class
    whose ``fp`` attribute points at a separate record.
    """

    def __init__(self, width, m_extra=True, e_extra=False, name=None):
        """ Create the s/e/m/v signals for a ``width``-bit FP number.

        :param width: 16, 32 or 64 (any other value raises KeyError).
        :param m_extra: if true, 3 extra mantissa bits are added.
        :param e_extra: if true, 6 extra exponent bits are added.
        :param name: optional prefix for signal names ("_" is appended).
        """
        if name is None:
            name = ""
        else:
            name += "_"
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True,
                        name=name+"v")  # Latched copy of value
        self.m = Signal(m_width, reset_less=True, name=name+"m")  # Mantissa
        self.e = Signal((e_width, True),
                        reset_less=True, name=name+"e")  # exp+2 bits, signed
        self.s = Signal(reset_less=True, name=name+"s")  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Share this record's signals and size parameters with ``fp``.

        The signals and widths are assigned as attributes of ``fp``.
        The constants are (re)created on *this* record, not on ``fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).

            :returns: list of assignments to slices of ``self.v``.
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        """ (sign, exponent, mantissa) tuple for NaN """
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        """ (sign, exponent, mantissa) tuple for infinity """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) tuple for zero """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ assignments setting v to NaN with sign s """
        return self.create(*self._nan(s))

    def inf(self, s):
        """ assignments setting v to infinity with sign s """
        return self.create(*self._inf(s))

    def zero(self, s):
        """ assignments setting v to zero with sign s """
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            :returns: a Cat() expression (mantissa, exponent, sign)
                      rather than a list of assignments.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ Cat() expression for NaN with sign s """
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        """ Cat() expression for infinity with sign s """
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        """ Cat() expression for zero with sign s """
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying s, e and m from inp (v is not copied) """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
460
461
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class (the module "friend" of
        FPNumBaseRecord): adds combinatorial status flags derived from
        the exponent and mantissa of the record passed in as fp.
    """

    def __init__(self, fp):
        # borrow the record's signals and size parameters
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # number classification
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        consts = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # maximum exponent with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # minimum exponent with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return self.exp_n126 & self.m_msbzero
523
524
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    Holds a value split into sign / exponent / mantissa, with the
    encoding functions and the zero / NaN / inf recognition inherited
    from FPNumBase.  Nothing is added to the base class here.

    Mantissa layout as described by the original author: four extra
    bits, of which the top one (m[-1]) is effectively a carry-overflow
    and the other three are guard (m[2]), round (m[1]), sticky (m[0]).
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        # no additional logic beyond the base class flags
        return super().elaborate(platform)
545
546
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

    :attribute inp: value to shift, ``width`` bits.
    :attribute diff: shift amount, declared with shape ``s_max``.
    :attribute m: result: shifted value with the sticky bit in m[0].
    """

    def __init__(self, width, s_max=None):
        """
        :param width: bit-width of inp and m.
        :param s_max: shape of the shift-amount signal: either a plain
                      bit count or a (width, signed) tuple.  Defaults
                      to int(log2(width)) bits.
        """
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        # The clamped shift amounts only take the bit count of the shape.
        # smax may be an int (the default) or a (width, signed) tuple:
        # indexing an int with [0] would raise TypeError.
        # (original author's note on why the copies exist:
        #  https://github.com/nmigen/nmigen/issues/302)
        smax_width = self.smax if isinstance(self.smax, int) else self.smax[0]
        maxslen = Signal(smax_width, reset_less=True)
        maxsleni = Signal(smax_width, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, len(self.diff))
        # clamp the shift to width-1, and compute its complement
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
590
591
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        """
        :param mainm: module whose State("align") the shift is placed in.
        :param op: number whose s/e/m are copied into this one.
        :param inv: number whose exponent this one is compared against.
        :param width: FP bit-width, passed to FPNumBaseRecord.
        :param m_extra: passed to FPNumBaseRecord.
        """
        # FPNumBase takes a record, not (width, m_extra): build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e / self.m are assigned both combinatorially
        # above and synchronously below - confirm this is the intended
        # design before relying on this class.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            :param inp: number to shift; defaults to self.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-1s constant lives on the record (self.fp), not on self
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
663
664
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoder)

    Combinatorially decodes the latched value v into sign / exponent /
    mantissa, on top of the zero / NaN / inf recognition inherited
    from FPNumBase.

    Mantissa layout as described by the original author: four extra
    bits, of which the top one (m[-1]) is effectively a carry-overflow
    and the other three are guard (m[2]), round (m[1]), sticky (m[0]).
    """

    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        # low m_extra bits of the mantissa are padded with zeros
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
                ]
702
703
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    The constants (P127, m1s, ...) belong to the record self.fp and are
    not attributes of this object, so they are always read via self.fp.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            :returns: an ObjectProxy with m, e and s attributes set.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                                 # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]                                      # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            :returns: list of assignments to self.m, self.e and self.s.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # P127 is on the record (same as decode2 uses)
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param inp: number to shift; defaults to self.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-1s constant is on the record, not on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
802
803
class Trigger(Elaboratable):
    """ stb/ack pair with a combinatorial "trigger" output that is
        set when both stb and ack are set.
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_set = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_set)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack (not trigger) from inp """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ list of the externally driven signals: stb and ack """
        return [self.stb, self.ack]
823
824
class FPOpIn(PrevControl):
    """ PrevControl carrying an FP operand; ``v`` is an alias of data_i.
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the operand value (data_i) """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ connect from in_op, sending back the *inverted* ack.

            extra, when given, is ANDed with the incoming stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
                ]

    def chain_from(self, in_op, extra=None):
        """ connect from in_op, sending back the ack unchanged.

            extra, when given, is ANDed with the incoming stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
851
852
class FPOpOut(NextControl):
    """ NextControl carrying an FP result; ``v`` is an alias of data_o.
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the result value (data_o) """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ connect from in_op, sending back the *inverted* ack.

            extra, when given, is ANDed with the incoming stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
                ]

    def chain_from(self, in_op, extra=None):
        """ connect from in_op, sending back the ack unchanged.

            extra, when given, is ANDed with the incoming stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
879
880
class Overflow:
    """ Rounding bits (guard / round / sticky / m0) plus a 5-bit
        floating-point flags field.
    """

    # bit positions within the 5-bit flags field
    FFLAGS_NV = Const(1<<4, 5)  # invalid operation
    FFLAGS_DZ = Const(1<<3, 5)  # divide by zero
    FFLAGS_OF = Const(1<<2, 5)  # overflow
    FFLAGS_UF = Const(1<<1, 5)  # underflow
    FFLAGS_NX = Const(1<<0, 5)  # inexact

    def __init__(self, name=None):
        # name is used as a plain prefix for the signal names
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")      # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")    # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0
        self.fpflags = Signal(5, reset_less=True, name=prefix+"fflags")

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky,
                    self.m0, self.fpflags)

    def eq(self, inp):
        """ assignments copying every field from inp """
        fields = ("guard", "round_bit", "sticky", "m0", "fpflags")
        return [getattr(self, field).eq(getattr(inp, field))
                for field in fields]

    @property
    def roundz(self):
        """ expression: guard set, and any of round / sticky / m0 set """
        tie_breaker = self.round_bit | self.sticky | self.m0
        return self.guard & tie_breaker
915
916
class OverflowMod(Elaboratable, Overflow):
    """ Overflow as a module: adds roundz_out, a signal combinatorially
        driven with the value of the ``roundz`` property.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ assignments copying roundz_out and every Overflow field
            from inp
        """
        # inp must be forwarded: Overflow.eq(self) alone raises TypeError
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)  # roundz is a property
        return m
935
936
class FPBase:
    """ IEEE754 Floating Point Base Class

        A collection of FSM-state helpers for FP manipulation:
        fetching operands, normalisation, denormalisation, rounding,
        packing and output of the result.  Each helper adds statements
        to the module m that it is handed.
    """

    def get_op(self, m, op, v, next_state):
        """ decodes operand v and moves to next_state once op.ready_o
            and op.valid_i_test are both set.

            returns [decoded operand, ack], where ack is a new signal
            driven to ZERO in the cycle the operand is accepted and to
            one otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        accepted = op.ready_o & op.valid_i_test
        with m.If(accepted):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # exponent at minimum: limit a exponent
            m.d.sync += a.e.eq(a.fp.N126)
        with m.Else():
            # otherwise: set top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
        """
        top_bit_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # statement order is kept: the z.m[0] assignment must come
            # after the whole-mantissa shift.
            m.d.sync += z.e.eq(z.e - 1)            # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)           # shift mantissa UP
            m.d.sync += z.m[0].eq(of.guard)        # steal guard (was tot[2])
            m.d.sync += of.guard.eq(of.round_bit)  # steal round (was tot[1])
            m.d.sync += of.round_bit.eq(0)         # reset round bit
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += z.e.eq(z.e + 1)   # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)  # shift mantissa DOWN
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            mantissa_all_ones = (z.m == z.fp.m1s)
            with m.If(mantissa_all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises valid_o and
            moves to the next state once valid_o and ready_i_test are
            both set, resetting valid_o back to zero when that occurs.
        """
        m.d.sync += [out_z.v.eq(z.v)]
        handshake_done = out_z.valid_o & out_z.ready_i_test
        with m.If(handshake_done):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1067
1068
class FPState(FPBase):
    """ An FPBase that records the state it is entered from, and whose
        inputs / outputs are also exposed as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ store the inputs dict and set each entry as an attribute """
        self.inputs = inputs
        for attr_name in inputs:
            setattr(self, attr_name, inputs[attr_name])

    def set_outputs(self, outputs):
        """ store the outputs dict and set each entry as an attribute """
        self.outputs = outputs
        for attr_name in outputs:
            setattr(self, attr_name, outputs[attr_name])
1082
1083
class FPID:
    """ Optional ID (in_mid / out_mid) carried alongside an operation.

    :attribute id_wid: width of the ID signals; when falsy (None or 0)
                       no signals are created and in_mid / out_mid are
                       None.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ synchronously copy in_mid to out_mid; does nothing when the
            ID signals were not created.
        """
        # test the signal itself: the constructor skips creation for any
        # falsy id_wid (0 as well as None), so "id_wid is not None" would
        # let id_wid=0 through and fail on None.eq
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
1097
1098
if __name__ == "__main__":
    # executed as a script: run the unit tests defined in this file
    unittest.main()