split out adder code (PartitionedAdder) into module, PartitionPoints too
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 """IEEE754 Floating Point Library
2
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019 Jake Lifshay
5
6 """
7
8
9 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
10 from math import log
11 from operator import or_
12 from functools import reduce
13
14 from nmutil.singlepipe import PrevControl, NextControl
15 from nmutil.pipeline import ObjectProxy
16 import unittest
17 import math
18
19
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality (all four format attributes must match). """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width
                and self.m_width == other.m_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    def __hash__(self):
        """ Hash over the same fields that ``__eq__`` compares.

        Defining ``__eq__`` alone makes a class unhashable, so this is
        required for formats to be usable as dict keys / set members.
        """
        return hash((self.e_width, self.m_width,
                     self.has_int_bit, self.has_sign))

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a
            multiple of 32 above 128 (up to the arbitrary limit below).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats are shown as ``FPFormat.standard(width)``; anything
        else lists the constructor arguments, omitting defaults.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            # width is not a standard width: fall through to the long form
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            # emitted as a keyword so that it cannot be mistaken for
            # has_int_bit when has_int_bit was omitted above
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
            (assumes FPFormat is set to signed - has_sign=True)
        """
        return x >> (self.e_width + self.m_width)

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        x = ((x >> self.m_width) & self.exponent_inf_nan)
        return x

    def get_exponent(self, x):
        """ returns the exponent of its input number, x (bias subtracted)
        """
        return self.get_exponent_field(x) - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum)
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan (top mantissa bit set)
        """
        highbit = 1 << (self.m_width-1)
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) != 0 and
                self.get_mantissa_field(x) & highbit != 0)

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan (top mantissa bit clear)
        """
        highbit = 1 << (self.m_width-1)
        return ((self.get_exponent(x) == self.e_max) and
                (self.get_mantissa_field(x) != 0) and
                (self.get_mantissa_field(x) & highbit) == 0)

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias)
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the denormal/zero exponent (minus bias)
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
199
200
class TestFPFormat(unittest.TestCase):
    """ Quick sanity checks of FPFormat field extraction and classification,
        using sfpy to produce reference bit patterns.
    """

    def test_fpformat_fp64(self):
        from sfpy import Float64
        f64 = FPFormat.standard(64)

        # 1.0: unbiased exponent is zero
        bits = Float64(1.0).bits
        print(hex(bits))
        self.assertEqual(f64.get_exponent(bits), 0)

        # 2.0: unbiased exponent is one
        bits = Float64(2.0).bits
        print(hex(bits))
        self.assertEqual(f64.get_exponent(bits), 1)

        # 1.5: only the top mantissa bit is set, sign is clear
        bits = Float64(1.5).bits
        mantissa = f64.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x8000000000000)

        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 0)

        # -1.5: sign is set
        bits = Float64(-1.5).bits
        sign = f64.get_sign_field(bits)
        print(hex(bits), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        from sfpy import Float32
        f32 = FPFormat.standard(32)

        # 1.0: unbiased exponent is zero
        bits = Float32(1.0).bits
        print(hex(bits))
        self.assertEqual(f32.get_exponent(bits), 0)

        # 2.0: unbiased exponent is one
        bits = Float32(2.0).bits
        print(hex(bits))
        self.assertEqual(f32.get_exponent(bits), 1)

        # 1.5: only the top mantissa bit is set
        bits = Float32(1.5).bits
        mantissa = f32.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x400000)

        # NaN test: sqrt(-1.0)
        bits = Float32(-1.0).sqrt().bits
        flag = f32.is_nan(bits)
        print(hex(bits), "nan", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # Inf test: product of three 1e36 values
        product = Float32(1e36) * Float32(1e36) * Float32(1e36)
        bits = product.bits
        flag = f32.is_inf(bits)
        print(hex(bits), "inf", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # subnormal: 1e-41 is expected to classify as subnormal
        bits = Float32(1e-41).bits
        flag = f32.is_subnormal(bits)
        print(hex(bits), "sub", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)

        # zero must not be classified as subnormal ...
        bits = Float32(0.0).bits
        flag = f32.is_subnormal(bits)
        print(hex(bits), "sub", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, False)

        # ... but must be classified as zero
        flag = f32.is_zero(bits)
        print(hex(bits), "zero", f32.get_exponent(bits), f32.e_max,
              f32.get_mantissa_field(bits), flag)
        self.assertEqual(flag, True)
282
283
class MultiShiftR(Elaboratable):
    """ Combinatorial variable right-shifter: o = i >> s.

    Derives from Elaboratable (like the other classes in this file that
    define elaborate()) so that it can be added as an nmigen submodule.

    :attribute i: input value, ``width`` bits.
    :attribute s: shift amount, ``smax`` bits (smax = int(log2(width))).
    :attribute o: shifted output, ``width`` bits.
    """

    def __init__(self, width):
        self.width = width
        # number of bits in the shift-amount signal
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ Build the module: a single combinatorial shift assignment. """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
297
298
class MultiShift:
    """ Helper producing variable-length single-cycle shift expressions.

        Both lshift and rshift apply the shift operator to the operand and
        then truncate the result back to the operand's own length.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        # int(log2(width)), computed as a ratio of natural logs
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted up by s, truncated to len(op) """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted down by s, truncated to len(op) """
        return (op >> s)[:len(op)]
321
322
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    This class is designed to be passed around in other data structures
    (between pipelines and between stages).  Its "friend" is FPNumBase,
    which is a *module*.  The reason for the discernment is because
    nmigen modules that are not added to submodules results in the
    irritating "Elaboration" warning.  Despite not *needing* FPNumBase
    in many cases to be added as a submodule (because it is just data)
    this was not possible to solve without splitting out the data from
    the module.
    """

    def __init__(self, width, m_extra=True, e_extra=False, name=None):
        """ Create the sign / exponent / mantissa / value signals.

        :param width: FP bit-width; must be 16, 32 or 64 (KeyError otherwise).
        :param m_extra: if true, widen the mantissa by 3 extra bits.
        :param e_extra: if true, widen the exponent by 6 extra bits.
        :param name: optional prefix for signal names ("<name>_v" etc).
        """
        if name is None:
            name = ""
            # assert false, "missing name"
        else:
            name += "_"
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        # print (m_width, e_width, e_max, self.rmw, self.m_extra)
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True,
                        name=name+"v")  # Latched copy of value
        self.m = Signal(m_width, reset_less=True, name=name+"m")  # Mantissa
        self.e = Signal((e_width, True),
                        reset_less=True, name=name+"e")  # exp+2 bits, signed
        self.s = Signal(reset_less=True, name=name+"s")  # Sign bit

        # a record is its own "fp": methods below always go through self.fp
        # so that they also work when inherited by FPNumBase
        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Copies this record's signals and size parameters onto ``fp``.

        The constants (mzero, msb1, m1s, P128, P127, N127, N126) are
        (re)created on *this* record only, not on ``fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        """ (sign, exponent, mantissa) tuple for NaN with sign s """
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        """ (sign, exponent, mantissa) tuple for infinity with sign s """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) tuple for zero with sign s """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  returns a Cat expression
            rather than a list of assignments to self.v.
        """
        # constants live on the record (self.fp), which is self for a
        # plain record and the wrapped record for FPNumBase subclasses
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        """ yields sign, exponent, mantissa (in that order) """
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from inp """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
460
461
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class (module form).

        Wraps a FPNumBaseRecord ``fp`` and derives a set of single-bit
        status signals from its exponent and mantissa.
    """

    def __init__(self, fp):
        # share the record's signals and sizes, and remember the record
        # itself (the constants such as P128 are reached through it)
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # classification outputs
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ Drive every status signal combinatorially. """
        m = Module()
        fp = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]

        return m

    def _is_nan(self):
        # exponent at P128 with a non-zero mantissa
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        # exponent at P128 with a zero mantissa
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # exponent at N127 with a zero mantissa
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        # exponent greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return (self.exp_n126) & (self.m_msbzero)
523
524
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Behaves exactly as FPNumBase: sign / exponent / mantissa signals
        taken from the record ``fp``, plus the FPNumBase status signals
        and the inherited encoding functions (zero, NaN and inf, all
        signed).  No extra signals or logic are added here.
    """

    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        # nothing beyond the base-class logic
        return FPNumBase.elaborate(self, platform)
545
546
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        :attribute inp: input value (width bits)
        :attribute diff: shift amount (s_max bits)
        :attribute m: output value (width bits)
    """

    def __init__(self, width, s_max=None):
        # default shift-amount width: int(log2(width))
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        shifter = MultiShift(self.width-1)
        zeros = Const(0, self.width-1)
        limit = Const(self.width-1, len(self.diff))
        over_limit = self.diff > limit
        upper = self.inp[1:]  # everything above the incoming low bit

        m.d.comb += [
            # clamp the shift amount to the limit, and compute its inverse
            maxslen.eq(Mux(over_limit, limit, self.diff)),
            maxsleni.eq(Mux(over_limit, 0, limit-self.diff)),
            # shift mantissa by maxslen, mask by inverse
            rs.eq(shifter.rshift(upper, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
589
590
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        NOTE(review): self.e / self.m are driven combinatorially from op in
        elaborate() and also assigned in the sync domain by shift_down;
        confirm the intended driver before using this class.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        """
        :param mainm: module providing the "align" State context
        :param op: operand whose s / e / m are copied into this number
        :param inv: other operand; its exponent is compared against
        :param width: FP bit-width (16, 32 or 64)
        :param m_extra: passed to FPNumBaseRecord (extra mantissa bits)
        """
        # FPNumBase.__init__ takes a record, not (width, m_extra)
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # op is an attribute, not a local or global name
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down requires its source operand: shift ourselves
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-1s constant lives on the record, not on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
662
663
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoding)

        Combinatorially decodes the value signal ``v`` into
        sign / exponent / mantissa, on top of the FPNumBase status
        signals and the inherited encoding functions.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        # continuously decode our own value signal
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            the mantissa is padded at the bottom with m_extra zero bits;
            bias (self.fp.P127) is subtracted here, from the exponent.
        """
        padding = [0 for _ in range(self.m_extra)]  # extra low zero bits
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
                ]
701
702
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        The mantissa is padded at the bottom with m_extra zero bits when
        decoding; shift_down / shift_down_multi merge shifted-out bits
        into m[0] (the "sticky" bit).
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the result is
            returned as an ObjectProxy with m / e / s attributes rather
            than being assigned to this object's signals.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        #print ("decode", self.e_end)
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                             # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]                                  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        #print ("decode", self.e_end)
        # P127 lives on the record (self.fp): drop_in does not copy it here
        return [self.m.eq(Cat(*args)),  # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param diff: shift amount
            :param inp: number to shift (defaults to self)
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # m1s lives on the record (self.fp): drop_in does not copy it here
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
801
802
class Trigger(Elaboratable):
    """ stb/ack pair with a combinatorial ``trigger`` = stb & ack. """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # trigger is asserted only while both stb and ack are set
        m.d.comb += self.trigger.eq(self.stb & self.ack)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack from inp """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ list of externally visible signals (stb, ack) """
        return [self.stb, self.ack]
822
823
class FPOpIn(PrevControl):
    """ Input-operand control: a PrevControl whose value is data_i. """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the operand value (alias for data_i) """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ assignments chaining from in_op, with the ACK sent back inverted.

            extra, if given, is ANDed into the received STB.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK (inverted)
                ]

    def chain_from(self, in_op, extra=None):
        """ assignments chaining from in_op, with the ACK sent back as-is.

            extra, if given, is ANDed into the received STB.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
850
851
class FPOpOut(NextControl):
    """ Output-operand control: a NextControl whose value is data_o. """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the operand value (alias for data_o) """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ assignments chaining from in_op, with the ACK sent back inverted.

            extra, if given, is ANDed into the received STB.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK (inverted)
                ]

    def chain_from(self, in_op, extra=None):
        """ assignments chaining from in_op, with the ACK sent back as-is.

            extra, if given, is ANDed into the received STB.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
878
879
class Overflow:
    """ Rounding bits: guard, round, sticky and the mantissa's bit 0.

        name, if given, is used as a prefix for the signal names.
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")     # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")   # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

        #self.roundz = Signal(reset_less=True)

    def __iter__(self):
        """ yields guard, round_bit, sticky, m0 (in that order) """
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assignments copying all four bits from inp """
        assigns = []
        assigns.append(self.guard.eq(inp.guard))
        assigns.append(self.round_bit.eq(inp.round_bit))
        assigns.append(self.sticky.eq(inp.sticky))
        assigns.append(self.m0.eq(inp.m0))
        return assigns

    @property
    def roundz(self):
        """ guard AND (round OR sticky OR m0) """
        any_lower = self.round_bit | self.sticky | self.m0
        return self.guard & any_lower
906
907
class OverflowMod(Elaboratable, Overflow):
    """ Overflow bits plus a ``roundz_out`` signal driven combinatorially
        from the Overflow.roundz expression.

        name, if given, is used as a prefix for the signal names.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        """ yields the four Overflow bits, then roundz_out """
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ assignments copying roundz_out and the four Overflow bits
            from inp
        """
        # inp must be forwarded: Overflow.eq takes the source as argument
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
926
927
class FPBase:
    """ IEEE754 Floating Point Base Class

        A collection of helpers, each of which adds statements to an
        nmigen module ``m`` (mostly inside an FSM state): operand fetch,
        normalisation, denormalisation, rounding, packing and output.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state when both op.ready_o and
            op.valid_i_test are set.
            acknowledgement is sent by setting ack to ZERO (it is ONE
            otherwise).

            returns [decoded operand (from v.decode2), ack signal]
        """
        res = v.decode2(m)
        ack = Signal()
        handshake = (op.ready_o) & (op.valid_i_test)
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # exponent at minimum: limit it to N126
            m.d.sync += a.e.eq(a.fp.N126)
        with m.Else():
            # otherwise set the top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the top mantissa bit is set
        """
        with m.If(op.m[-1] == 0):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),   # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        needs_shift_up = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(needs_shift_up):
            # statement order matters: z.m[0] overrides the shifted-in zero
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),    # INCREASE exponent
                z.m.eq(z.m >> 1),   # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)      # mantissa rounds up
            with m.If(z.m == z.fp.m1s):      # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises valid_o and
            waits for ready_i_test before moving to the next state.
            resets valid_o back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.valid_o & out_z.ready_i_test
        with m.If(accepted):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1058
1059
class FPState(FPBase):
    """ A named FSM state: FPBase helpers plus input/output bookkeeping. """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ records the inputs dict and exposes each entry as an attribute """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ records the outputs dict and exposes each entry as an attribute """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
1073
1074
class FPID:
    """ Optional id (in_mid -> out_mid) passed alongside the data.

        :param id_wid: width of the id signals.  If falsy (None or 0)
                       no signals are created and in_mid / out_mid
                       are None.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds out_mid <= in_mid to m's sync domain, if ids are in use """
        # same test as __init__: with id_wid == 0 the signals are None,
        # so "is not None" would wrongly try to call None.eq()
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
1088
1089
# run the unit tests in this module when executed as a script
if __name__ == "__main__":
    unittest.main()