copy context/roundz, a and b manually in fpmul align
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12 import math
13
14
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute exponent_width: the number of bits in the exponent field.
    :attribute mantissa_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)
    """

    def __init__(self,
                 exponent_width,
                 mantissa_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.exponent_width = exponent_width
        self.mantissa_width = mantissa_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality (all four attributes must match). """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.exponent_width == other.exponent_width
                and self.mantissa_width == other.mantissa_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises TypeError: if ``width`` is not an ``int``.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128, or a
            multiple of 32 above 128 (up to an arbitrary 1000000 limit).
        """
        if not isinstance(width, int):
            raise TypeError()
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            # IEEE 754-2008 exponent-width formula for binary{k}, k >= 128
            exponent_width = round(4 * math.log2(width)) - 13
            return FPFormat(exponent_width, width - 1 - exponent_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats are shown as ``FPFormat.standard(width)``. Other
        formats are shown positionally. ``has_int_bit`` is emitted whenever
        ``has_sign`` is, so that the output evaluates back to an equal
        object.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass
        retval = f"FPFormat({self.exponent_width}, {self.mantissa_width}"
        if self.has_int_bit is not False or self.has_sign is not True:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            retval += f", {self.has_sign}"
        return retval + ")"

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.exponent_width + self.mantissa_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.exponent_width) - 1

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.exponent_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.mantissa_width - self.has_int_bit
120
121
class MultiShiftR:
    """ Single-cycle variable right shifter producing ``o = i >> s``.

    ``i`` and ``o`` are ``width`` bits wide. The shift amount ``s`` is
    ``int(log(width) / log(2))`` bits wide.

    NOTE(review): for a ``width`` that is not a power of two, ``s`` cannot
    express every shift up to ``width - 1`` — confirm callers only use
    power-of-two widths.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)      # value to be shifted
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
135
136
class MultiShift:
    """ Helper for building variable-length, single-cycle shifts.

    ``smax`` records ``int(log(width) / log(2))``, the number of
    shift-amount bits for a ``width``-bit operand. The shifts use the
    operand's own ``<<`` / ``>>`` operators and truncate the result back to
    the operand's original length.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ Return ``op << s`` truncated to ``len(op)`` bits. """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ Return ``op >> s`` truncated to ``len(op)`` bits. """
        return (op >> s)[:len(op)]
159
160
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Holds the packed value ``v`` and the decoded sign ``s``, signed
    exponent ``e`` and mantissa ``m`` signals for a 16, 32 or 64-bit IEEE
    754 number. It also holds the layout parameters and the
    exponent/mantissa constants used to encode and decode that number.

    :param width: packed bit-width: 16, 32 or 64 (anything else raises
        ``KeyError``).
    :param m_extra: if true, the mantissa signal gets 3 extra low bits.
    :param e_extra: if true, the exponent signal gets 6 extra bits.
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # 3 extra low mantissa bits used during rounding
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Share this record's signals and layout parameters with ``fp``.

        After the call ``fp`` refers to the very same ``s``/``e``/``m``/``v``
        Signal objects. The encoding constants (``mzero``, ``msb1``, ``m1s``,
        ``P128``, ``P127``, ``N127``, ``N126``) are (re)built on *this* record
        only, so users of ``fp`` must reach them through its ``fp``
        attribute.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra
        fp.e_extra = self.e_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        # names reflect FP32 (e_max == 128); values scale with the format
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        NOTE: order is important, because e_start/e_end can be
        a bit too long (overwriting s).

        :returns: list of three assignments into ``self.v``.
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent. Constants are read via
        ``self.fp`` so this also works on objects that only received the
        signals through ``drop_in``.

        :returns: a ``Cat`` of mantissa, biased exponent and sign.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ Copy sign, exponent and mantissa (not ``v``) from ``inp``. """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
282
283
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Shares the signals of an ``FPNumBaseRecord`` (through ``drop_in``) and
    adds combinatorial classification flags (NaN, zero, inf, overflow,
    denormal) derived from its exponent and mantissa.
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        signed_exp = (fp.e_width, True)

        # classification outputs
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # intermediate exponent / mantissa comparisons
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal(signed_exp, reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        fp = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        """ Exponent at P128 with a non-zero mantissa. """
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        """ Exponent at P128 with a zero mantissa. """
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        """ Exponent at N127 with a zero mantissa. """
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        """ Exponent above P127. """
        return self.exp_gt127

    def _is_denormalised(self):
        """ Exponent at N126 with mantissa bit ``e_start`` clear. """
        return self.exp_n126 & self.m_msbzero
341
342
class FPNumOut(FPNumBase):
    """ Floating-point output number.

    Carries the decoded sign / exponent / mantissa of a result together
    with the encoding helpers (create, NaN, inf, zero — all signed)
    inherited from ``FPNumBaseRecord``.

    The mantissa has extra low bits beyond the stored fraction: the top
    bit (m[-1]) acts as a carry-overflow, and guard (m[2]), round (m[1])
    and sticky (m[0]) sit at the bottom.

    Adds nothing of its own to ``FPNumBase``.
    """

    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        # no extra logic beyond the base-class classification flags
        return FPNumBase.elaborate(self, platform)
363
364
class MultiShiftRMerge(Elaboratable):
    """ Right-shifter that ORs every bit shifted out into the sticky bit.

    ``inp[0]`` is treated as an existing sticky bit. ``inp[1:]`` is shifted
    down by ``diff`` (clamped to ``width - 1``). Every bit shifted out,
    plus ``inp[0]``, is ORed into the result's bit 0 (the "sticky" bit).

    NOTE(review): ``width - 1`` is stored as a ``Const`` of ``len(diff)``
    bits; if it does not fit there it is presumably truncated — confirm
    callers pass a large enough ``s_max``.
    """

    def __init__(self, width, s_max=None):
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.m = Signal(width, reset_less=True)              # result
        self.inp = Signal(width, reset_less=True)            # input value
        self.diff = Signal(self.smax, reset_less=True)       # shift amount
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # local Signal names are kept as-is: they become netlist names
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        shifter = MultiShift(self.width-1)
        all_zero = Const(0, self.width-1)
        limit = Const(self.width-1, len(self.diff))
        too_far = self.diff > limit
        payload = self.inp[1:]  # everything above the incoming sticky bit

        # clamped shift amount, and its complement used to build the mask
        m.d.comb += [maxslen.eq(Mux(too_far, limit, self.diff)),
                     maxsleni.eq(Mux(too_far, 0, limit-self.diff))]

        m.d.comb += [
            rs.eq(shifter.rshift(payload, maxslen)),
            # a run of ones covering exactly the bits that get shifted out
            m_mask.eq(shifter.rshift(~all_zero, maxsleni)),
            smask.eq(payload & m_mask),
            stickybit.eq(smask.bool() | self.inp[0]),
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
407
408
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    :param mainm: object whose ``State("align")`` context wraps the shift
        logic (NOTE(review): presumably the parent FSM module — confirm).
    :param op: source number whose ``s``/``e``/``m`` are copied in.
    :param inv: other number; this one is shifted down while its exponent
        is below ``inv.e``.
    :param width: packed bit-width (16/32/64), or an existing
        ``FPNumBaseRecord`` to share.
    :param m_extra: forwarded to ``FPNumBaseRecord`` when one is built.

    NOTE(review): ``elaborate`` drives ``e``/``m`` from both the comb domain
    (copy from ``op``) and the sync domain (``shift_down``). nmigen is
    expected to reject that as a driver conflict. This class is not
    referenced elsewhere in the visible part of this file.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a record, not (width, m_extra): build one unless
        # the caller already supplied a record.
        if isinstance(width, FPNumBaseRecord):
            fp = width
        else:
            fp = FPNumBaseRecord(width, m_extra)
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        :param inp: number to read ``e``/``m`` from; defaults to ``self``.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # m1s lives on the shared record, not on this instance
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
480
481
class FPNumDecode(FPNumBase):
    """ Floating-point number that decodes its packed value combinatorially.

    On elaboration the packed ``v`` is unpacked into sign / exponent /
    mantissa. Encoding helpers (create, NaN, inf, zero — all signed) come
    from ``FPNumBaseRecord``.

    The mantissa has extra low bits beyond the stored fraction: the top
    bit (m[-1]) acts as a carry-overflow, and guard (m[2]), round (m[1])
    and sticky (m[0]) sit at the bottom.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ Unpack ``v`` into ``m``, ``e`` and ``s``.

        The stored fraction is placed above ``m_extra`` zero bits, and the
        bias ``fp.P127`` is subtracted from the exponent field.

        :returns: list of three assignments (mantissa, exponent, sign).
        """
        low_zeros = [0] * self.m_extra
        mantissa = Cat(*low_zeros, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(v[-1])]
519
520
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Constants such as ``P127`` and ``m1s`` live on the shared record and
    are always read via ``self.fp``.

    :param op: operand source, stored as ``self.op``.
    :param fp: ``FPNumBaseRecord`` whose signals this instance shares.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

        bias (``fp.P127``) is subtracted here, from the exponent field.

        :returns: an ``ObjectProxy`` carrying ``m``, ``e`` and ``s``
            expressions (nothing is assigned).
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

        bias (``fp.P127``) is subtracted here, from the exponent field.

        :returns: list of assignments to ``self.m``, ``self.e``, ``self.s``.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s by
        the same amount, which will introduce zeros. it's then
        inverted and used as a mask to get the LSBs of the mantissa.
        those are then |'d into the sticky bit.

        :param inp: number to read ``e``/``m`` from; defaults to ``self``.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
619
620
class Trigger(Elaboratable):
    """ A strobe / acknowledge pair.

    ``trigger`` is driven combinatorially high whenever both ``stb`` and
    ``ack`` are high.
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_high = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_high)
        return m

    def eq(self, inp):
        """ Copy ``stb`` and ``ack`` from ``inp``. """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ Signals exposed at the module boundary. """
        boundary = [self.stb, self.ack]
        return boundary
640
641
class FPOpIn(PrevControl):
    """ Input-side handshake port whose value is ``data_i``. """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ Alias for ``data_i``. """
        return self.data_i

    def _chain_handshake(self, in_op, extra, ack):
        """ Take value and STB (optionally gated by ``extra``) from ``in_op``
        and drive ``ack`` back to it. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),   # receive value
                self.stb.eq(stb),     # receive STB
                in_op.ack.eq(ack),    # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ Chain from ``in_op``, sending it the *inverted* ACK. """
        return self._chain_handshake(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ Chain from ``in_op``, sending it the ACK unchanged. """
        return self._chain_handshake(in_op, extra, self.ack)
668
669
class FPOpOut(NextControl):
    """ Output-side handshake port whose value is ``data_o``. """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ Alias for ``data_o``. """
        return self.data_o

    def _chain_handshake(self, in_op, extra, ack):
        """ Take value and STB (optionally gated by ``extra``) from ``in_op``
        and drive ``ack`` back to it. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),   # receive value
                self.stb.eq(stb),     # receive STB
                in_op.ack.eq(ack),    # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ Chain from ``in_op``, sending it the *inverted* ACK. """
        return self._chain_handshake(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ Chain from ``in_op``, sending it the ACK unchanged. """
        return self._chain_handshake(in_op, extra, self.ack)
696
697
class Overflow:  # (Elaboratable):
    """ Guard / round / sticky / mantissa-bit-0 bundle used for rounding.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")  # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")  # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

    def __iter__(self):
        for field in ("guard", "round_bit", "sticky", "m0"):
            yield getattr(self, field)

    def eq(self, inp):
        """ Copy guard, round_bit, sticky and m0 from ``inp`` (in that order). """
        fields = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, f).eq(getattr(inp, f)) for f in fields]

    @property
    def roundz(self):
        """ Round-up condition: guard set and any of round/sticky/m0 set. """
        any_lower = self.round_bit | self.sticky | self.m0
        return self.guard & any_lower
724
725
class OverflowMod(Elaboratable, Overflow):
    """ ``Overflow`` plus a combinatorial ``roundz_out`` output signal.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ Copy ``roundz_out`` and every ``Overflow`` field from ``inp``. """
        # Overflow.eq needs the source object as well as self.
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
744
745
class FPBase:
    """ IEEE754 Floating Point Base Class

    contains common functions for FP manipulation, such as
    extracting and packing operands, normalisation, denormalisation,
    rounding etc.

    Each method appends nmigen statements to the Module ``m`` it is given.
    Several of them also assign ``m.next``, so they are presumably meant
    to be called inside an FSM state (NOTE(review): confirm with callers).
    """

    def get_op(self, m, op, v, next_state):
        """ Wait for an operand handshake, then move to ``next_state``.

        When ``op.ready_o`` and ``op.valid_i_test`` are both set, the FSM
        moves to ``next_state`` and the local ``ack`` is driven to 0.
        Otherwise ``ack`` is driven to 1 (acknowledgement is sent by
        setting ack to ZERO).

        :param m: Module to add statements to.
        :param op: input handshake port.
        :param v: operand providing ``decode2(m)`` (e.g. ``FPNumIn``).
        :param next_state: FSM state entered when the handshake fires.
        :returns: ``[res, ack]``, where ``res`` is the decoded operand proxy
            from ``v.decode2(m)`` and ``ack`` is the new local Signal.
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.

            In code: if ``a.exp_n127`` the exponent is set to ``N126``,
            otherwise the top mantissa bit ``a.m[-1]`` is set (sync domain).
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            Each clock, while the top mantissa bit is 0, the mantissa is
            shifted up one bit and the exponent decremented. Once the top
            bit is set, the FSM moves to ``next_state``.

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            Each clock, while the top mantissa bit of ``z`` is 0 and its
            exponent is above ``N126``, shift the mantissa up one bit. The
            guard bit is pulled into ``m[0]``, round moves up into guard,
            and the exponent is decremented. Otherwise move to
            ``next_state``.

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                z.m[0].eq(of.guard),  # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),  # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            Each clock, while ``z``'s exponent is below ``N126``, shift the
            mantissa down one bit and increment the exponent. The bits
            shifted out feed guard/m0, round takes the old guard, and the
            old round is ORed into sticky. Otherwise move to ``next_state``.

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

            If ``roundz`` is set, the mantissa is incremented. If the
            mantissa was all ones before the increment (the comparison
            reads the current, pre-update sync value), the exponent is
            incremented too.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(z.m == z.fp.m1s):  # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Always moves to ``next_state``. If ``z.is_denormalised``, the
            exponent is set to ``N127``.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            Always moves to ``next_state``. Writes ``z.inf(z.s)`` if
            ``z.is_overflowed``, else ``z.create(z.s, z.e, z.m)``, into
            ``z.v`` on the sync domain.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output. raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            In code, "stb" is ``out_z.valid_o`` and "ack" is
            ``out_z.ready_i_test``. ``out_z.v`` is reloaded from ``z.v``
            every clock this runs.
        """
        m.d.sync += [
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
876
877
class FPState(FPBase):
    """ An ``FPBase`` that remembers its originating state and exposes its
    named inputs / outputs as attributes.

    :param state_from: stored unchanged as ``self.state_from``.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ Store ``inputs`` and set each key as an attribute of self. """
        self.inputs = inputs
        for attr_name, attr_value in inputs.items():
            setattr(self, attr_name, attr_value)

    def set_outputs(self, outputs):
        """ Store ``outputs`` and set each key as an attribute of self. """
        self.outputs = outputs
        for attr_name, attr_value in outputs.items():
            setattr(self, attr_name, attr_value)
891
892
class FPID:
    """ Optional pass-through of an operation ID ("mid") beside a result.

    :param id_wid: bit-width of the ID signals. When falsy (``None`` or
        ``0``) no ID signals are created and ``in_mid`` / ``out_mid`` are
        ``None``.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ Register ``in_mid`` into ``out_mid`` on ``m``'s sync domain.

        Uses the same truthiness test as ``__init__``, so an ``id_wid`` of 0
        (no ID signals) is a no-op instead of dereferencing ``None``.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)