# src/ieee754/fpcommon/fpbase.py (ieee754fpu.git)
# IEEE Floating Point Adder (Single Precision)
# Copyright (C) Jonathan P Dawson 2013
# 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Variable right-shifter.

    ``o`` is driven combinatorially with ``i >> s``.  ``i`` and ``o`` are
    ``width`` bits wide; the shift amount ``s`` is ``smax`` bits wide,
    where ``smax`` is ``int(log2(width))``.

    NOTE(review): for a width that is not a power of two, ``s`` cannot
    express every shift amount up to ``width-1`` - confirm callers never
    need the larger shifts.
    """

    def __init__(self, width):
        self.width = width
        # floor(log2(width)): bit-width of the shift-amount signal
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)  # value to be shifted
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)  # shifted result

    def elaborate(self, platform):
        mod = Module()
        shifted = self.i >> self.s
        mod.d.comb += self.o.eq(shifted)
        return mod
27
28
class MultiShift:
    """ Variable-length single-cycle shifter helper.

    ``lshift`` / ``rshift`` shift a value by a (possibly variable) amount
    using nmigen's native ``<<`` / ``>>`` operators, then truncate the
    result back to the operand's width, so the output is always
    ``len(op)`` bits wide.

    A hand-built cascade of Muxes (one per bit of the shift amount) used
    to follow the early ``return`` in each method; it was unreachable and
    has been removed.

    ``smax`` (``int(log2(width))``) is kept as an attribute for existing
    users; the shift methods do not use it.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns ``op`` shifted left by ``s``, truncated to ``len(op)``
        bits (bits shifted past the top are discarded).
        """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns ``op`` shifted right by ``s``, truncated to
        ``len(op)`` bits.
        """
        return (op >> s)[:len(op)]
61
62
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Holds the sign (``s``), signed exponent (``e``) and mantissa (``m``)
    signals of an IEEE754 number, a latched copy of the packed value
    (``v``), and the constants used to encode and recognise zero, Inf
    and NaN.

    :param width: bit-width of the packed value: 16, 32 or 64 (any other
                  value raises KeyError).
    :param m_extra: if True, ``m`` is widened by ``m_extra = 3`` bits,
                    otherwise ``m_extra = 0``.
    :param e_extra: if True, ``e`` is widened by ``e_extra = 6`` bits
                    (enough to cover FP64 when converting to FP16),
                    otherwise ``e_extra = 0``.
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        # exponent used for Inf/NaN (128 for FP32); the bias is e_max-1
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        # NOTE: e_end is computed from e_width *after* e_extra is added, so
        # with e_extra the v[e_start:e_end] range runs past the packed
        # exponent field (see the ordering note in create()).
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ makes ``fp`` share this record's signals and width parameters
        (by attribute assignment), then (re)creates this record's
        encoding constants.  The constants are set on this record only,
        not on ``fp``: users reach them through ``fp.fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra
        fp.e_extra = self.e_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))  # all-zero mantissa
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))  # all-ones mantissa
        # exponent constants, named after their FP32 values (e_max == 128)
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.

        NOTE: order is important, because e_start/e_end can be
        a bit too long (overwriting s).
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

        bias is added here, to the exponent.  Returns a Cat() of the low
        e_start bits of ``m``, the biased exponent field and ``s``.
        Constants are read via ``self.fp`` so that this also works on
        FPNumBase instances, which do not hold the constants themselves.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        # NaN mantissa: top bit of the stored mantissa, as in _nan().
        # (msb1 lies above bit e_start-1 when m_extra is set, so create2's
        # m[0:e_start] slice of it would be 0 and encode Inf instead.)
        nan_m = Const(1 << (self.e_start-1), (self.e_start, False))
        return self.create2(s, self.fp.P128, nan_m)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
184
185
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class (elaboratable)

    Shares the s / e / m / v signals and width parameters of the
    FPNumBaseRecord ``fp`` (via ``fp.drop_in``) and adds combinatorial
    flags that classify the current exponent and mantissa: NaN, zero,
    Inf, overflowed and denormalised, plus the exponent / mantissa
    comparisons those flags are built from.
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        exp_width = fp.e_width

        # classification flags
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the FP32-named constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((exp_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        mod = Module()
        fp = self.fp
        mod.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            # NOTE(review): tests bit e_start of m; when m_extra is set that
            # is not the top bit of m - confirm this is the intended bit.
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return mod

    def _is_nan(self):
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return self.exp_n126 & self.m_msbzero
243
244
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

    Currently identical in behaviour to FPNumBase: it exposes the shared
    sign / exponent / mantissa signals, the zero / NaN / Inf recognition
    flags and the encoding helpers (create, zero, inf, nan).

    As described for the other number classes: the top mantissa bit
    (m[-1]) is effectively a carry-overflow, and with m_extra the low
    bits are guard (m[2]), round (m[1]) and sticky (m[0]).
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
265
266
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
    m[0] is the "sticky" bit, basically

    ``m`` receives ``inp[1:]`` shifted right by min(diff, width-1). Bit 0
    of ``m`` is the OR of ``inp[0]`` and every bit of ``inp[1:]`` that
    was shifted out.

    :param width: bit-width of ``inp`` and ``m``.
    :param s_max: bit-width of ``diff`` (the shift amount).  Defaults to
                  the number of bits needed to hold width-1, the largest
                  shift applied.  For power-of-two widths this equals
                  log2(width), the previous default.
    """

    def __init__(self, width, s_max=None):
        if s_max is None:
            # bits needed to represent width-1.  The previous default,
            # int(log2(width)), could not hold width-1 for widths that are
            # not a power of two (e.g. width 27 -> 4 bits, max 15), so the
            # clamp constant built from len(diff) in elaborate() was wrong.
            s_max = (width - 1).bit_length()
        self.smax = s_max
        self.m = Signal(width, reset_less=True)  # result (sticky in m[0])
        self.inp = Signal(width, reset_less=True)  # value to shift
        self.diff = Signal(s_max, reset_less=True)  # shift amount
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, len(self.diff))  # largest useful shift
        # clamp the shift to mw.  maxsleni is the complementary amount:
        # shifting all-ones right by it leaves maxslen low bits set, which
        # selects exactly the bits that the main shift discards.
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
309
310
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    NOTE(review): this looks like legacy code.  The outright errors are
    fixed here: the FPNumBase constructor call, the undefined ``op`` in
    elaborate, the missing argument to shift_down, and the use of
    ``self.m1s``, which FPNumBase instances do not have.  elaborate()
    still mixes ``self.mainm.State`` with the local module ``m``, and it
    drives e / m both combinatorially (from ``op``) and synchronously
    (shift_down).  Confirm the intended design before relying on it.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase.__init__ takes an FPNumBaseRecord, not (width, m_extra)
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s down by the
        complementary amount, leaving a mask of the low bits that the
        main shift discards.  those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # constants live on the shared record, not on this instance
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
382
383
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class: decodes the packed value ``v``.

    elaborate() drives s / e / m combinatorially from ``v`` (see
    decode()).  The encoding helpers (create, zero, inf, nan) and the
    zero / NaN / Inf recognition flags come from FPNumBase.

    decode() pads ``m_extra`` zero bits below the stored mantissa.  As
    originally described: the top bit (m[-1]) is effectively a
    carry-overflow, and the low bits are guard (m[2]), round (m[1]) and
    sticky (m[0]).

    :param op: stored as ``self.op``; not used by the methods here.
    :param fp: the FPNumBaseRecord whose signals are shared.
    """

    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ returns the assignments that unpack ``v`` into s / e / m.

        mantissa: ``v[0:e_start]`` with ``m_extra`` zero bits below it.
        exponent: ``v[e_start:e_end]`` minus the bias (the P127 constant,
        i.e. e_max-1).  sign: the top bit of ``v``.
        """
        padding = [0] * self.m_extra  # pad with extra zeros
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign)]
421
422
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Encoding constants (P127, m1s, ...) are reached through ``self.fp``:
    FPNumBase.__init__ (via drop_in) does not copy them onto this
    instance.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes the latched value ``self.v`` into sign / exponent /
        mantissa, returned as attributes of an ObjectProxy over ``m``.

        bias (the P127 constant) is subtracted here, from the exponent;
        ``m_extra`` zero bits are padded below the mantissa.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]  # sign
        return res

    def decode(self, v):
        """ returns the assignments that decode ``v`` into sign /
        exponent / mantissa.

        bias (the P127 constant) is subtracted here, from the exponent;
        ``m_extra`` zero bits are padded below the mantissa.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # P127 lives on the record (self.fp), not on this instance
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

        accuracy is lost as a result in the mantissa however there are 3
        guard bits (the latter of which is the "sticky" bit)

        this code works by variable-shifting the mantissa by up to
        its maximum bit-length: no point doing more (it'll still be
        zero).

        the sticky bit is computed by shifting a batch of 1s down by the
        complementary amount, leaving a mask of the low bits that the
        main shift discards.  those are then |'d into the sticky bit.

        :param inp: source of e / m (defaults to self).
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # maxslen low bits set: the bits of inp.m[1:] that get shifted out.
        # m1s lives on the record (self.fp), not on this instance.
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
521
522
class Trigger(Elaboratable):
    """ stb / ack pair whose ``trigger`` output is driven combinatorially
    high whenever both ``stb`` and ``ack`` are high.
    """

    def __init__(self):
        self.stb = Signal(reset=0)  # strobe
        self.ack = Signal()  # acknowledge
        self.trigger = Signal(reset_less=True)  # stb & ack

    def elaborate(self, platform):
        mod = Module()
        both_high = self.stb & self.ack
        mod.d.comb += self.trigger.eq(both_high)
        return mod

    def eq(self, inp):
        """ returns assignments copying ``inp.stb`` / ``inp.ack`` """
        pairs = ((self.stb, inp.stb), (self.ack, inp.ack))
        return [dst.eq(src) for dst, src in pairs]

    def ports(self):
        return [self.stb, self.ack]
542
543
class FPOpIn(PrevControl):
    """ PrevControl carrying an operand ``width``, with stb/ack chaining
    helpers.  ``v`` is an alias for ``data_i``.
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def _chain_stb_ack(self, in_op, ack_value, extra):
        """ shared body of chain_inv / chain_from: receive the value and
        the strobe (ANDed with ``extra`` when given) from ``in_op``, and
        send ``ack_value`` back on ``in_op.ack``.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),  # receive value
                self.stb.eq(stb),  # receive STB
                in_op.ack.eq(ack_value),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chains from ``in_op``, sending back the inverted ack """
        return self._chain_stb_ack(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        """ chains from ``in_op``, sending back ack unchanged """
        return self._chain_stb_ack(in_op, self.ack, extra)
570
571
class FPOpOut(NextControl):
    """ NextControl carrying an operand ``width``, with stb/ack chaining
    helpers.  ``v`` is an alias for ``data_o``.
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def _chain_stb_ack(self, in_op, ack_value, extra):
        """ shared body of chain_inv / chain_from: receive the value and
        the strobe (ANDed with ``extra`` when given) from ``in_op``, and
        send ``ack_value`` back on ``in_op.ack``.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),  # receive value
                self.stb.eq(stb),  # receive STB
                in_op.ack.eq(ack_value),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chains from ``in_op``, sending back the inverted ack """
        return self._chain_stb_ack(in_op, ~self.ack, extra)

    def chain_from(self, in_op, extra=None):
        """ chains from ``in_op``, sending back ack unchanged """
        return self._chain_stb_ack(in_op, self.ack, extra)
598
599
class Overflow:  # (Elaboratable):
    """ guard / round / sticky bits plus the mantissa LSB, from which the
    round-up decision (``roundz``) is made.
    """

    def __init__(self):
        self.guard = Signal(reset_less=True)  # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)  # tot[0]
        self.m0 = Signal(reset_less=True)  # mantissa zero bit

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ returns assignments copying all four bits from ``inp`` """
        names = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, name).eq(getattr(inp, name)) for name in names]

    @property
    def roundz(self):
        """ round-up condition: guard is set and at least one of round,
        sticky or the mantissa LSB is set.  On an exact tie (round and
        sticky clear) this rounds up only for an odd mantissa, i.e.
        round-to-nearest, ties-to-even.
        """
        return self.guard & (self.round_bit | self.sticky | self.m0)
624
625
class FPBase:
    """ IEEE754 Floating Point Base Class

    Shared helpers for FP manipulation FSMs: operand fetch,
    denormalisation, normalisation, rounding, corrections, packing and
    output.  Every helper adds statements to the Module ``m`` it is
    given.  Helpers taking ``next_state`` assign ``m.next`` and must
    therefore be called from inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ operand fetch.

        decodes ``v`` via ``v.decode2(m)``.  When ``op.ready_o`` and
        ``op.valid_i_test`` are both high, moves to ``next_state``.
        Returns ``[res, ack]``: the decoded proxy, and a comb signal that
        is 0 on the transfer cycle and 1 otherwise (acknowledgement is
        sent by setting ack to ZERO).
        """
        res = v.decode2(m)
        ack = Signal()
        handshake = op.ready_o & op.valid_i_test
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number (probably the wrong name for it).

        if the exponent is N127 it is raised to N126 and the mantissa is
        left alone; otherwise the top mantissa bit (the implicit 1) is
        set.  Both cases effectively multiply the stored number by 2,
        which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation.

        each clock, while the top mantissa bit of ``op`` is 0: shift the
        mantissa up by one and decrement the exponent.  Once the top bit
        is 1, moves to ``next_state``.
        """
        msb_clear = op.m[-1] == 0
        with m.If(msb_clear):
            m.d.sync += [op.e.eq(op.e - 1),  # DECREASE exponent
                         op.m.eq(op.m << 1)]  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation.

        each clock, while the top mantissa bit of ``z`` is 0 and its
        exponent is above N126:
          - decrement the exponent and shift the mantissa up;
          - feed the extra bits held in ``of`` back in: guard into m[0],
            round_bit into guard, round_bit cleared, and m0 taking the
            old guard.
        Otherwise moves to ``next_state``.
        """
        shift_up = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(shift_up):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                # order matters: this bit assignment must follow the
                # whole-mantissa assignment above so that it wins.
                z.m[0].eq(of.guard),  # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),  # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation.

        each clock, while the exponent of ``z`` is below N126:
          - increment the exponent and shift the mantissa down;
          - pass the shifted-out bits through ``of``: m[0] into guard,
            m[1] into m0, the old guard into round_bit, and the old
            round_bit ORed into sticky.
        Otherwise moves to ``next_state``.
        """
        shift_down = z.e < z.fp.N126
        with m.If(shift_down):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output. TODO: different kinds of rounding

        when ``roundz`` is set, the mantissa is incremented.  If the
        mantissa is currently all ones (the sync comparison sees the
        pre-increment value), the exponent is incremented too.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            all_ones = z.m == z.fp.m1s
            with m.If(all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections.

        moves to ``next_state``.  If ``z`` is denormalised, its exponent
        is set to N127.
        """
        m.next = next_state
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into ``z.v`` and moves to ``next_state``.

        writes Inf (with z's sign) when ``z.is_overflowed``; otherwise
        packs z.s / z.e / z.m via ``z.create``.
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ output handshake.

        registers ``z.v`` into ``out_z.v`` and raises ``out_z.valid_o``.
        Once ``valid_o`` and ``out_z.ready_i_test`` are both high, clears
        ``valid_o`` and moves to ``next_state``.
        """
        m.d.sync += out_z.v.eq(z.v)
        transfer = out_z.valid_o & out_z.ready_i_test
        with m.If(transfer):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
756
757
class FPState(FPBase):
    """ FPBase that records the state it came from (``state_from``) and
    can attach named inputs / outputs as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def _attach(self, mapping):
        """ sets each key of ``mapping`` as an attribute holding its value """
        for name, value in mapping.items():
            setattr(self, name, value)

    def set_inputs(self, inputs):
        """ stores ``inputs`` as ``self.inputs`` and attaches each entry """
        self.inputs = inputs
        self._attach(inputs)

    def set_outputs(self, outputs):
        """ stores ``outputs`` as ``self.outputs`` and attaches each entry """
        self.outputs = outputs
        self._attach(outputs)
771
772
class FPID:
    """ optional id signal pair (``in_mid`` / ``out_mid``).

    When ``id_wid`` is truthy, two ``id_wid``-bit signals are created.
    Otherwise (None or 0) both attributes are None and ``idsync`` does
    nothing.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ registers ``in_mid`` into ``out_mid`` in ``m``'s sync domain.

        Uses the same truthiness test as __init__: with ``id_wid == 0``
        no signals exist, and the previous ``is not None`` check tried to
        call ``.eq`` on None.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)