add enough to "extra" exponent to cover FP64 to FP16 fcvt
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Variable right-shifter: output o is input i shifted right by s.

        The shift-amount signal s is int(log2(width)) bits wide.
    """

    def __init__(self, width):
        self.width = width
        # number of bits in the shift-amount signal
        shift_bits = int(log(width) / log(2))
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)       # value to shift
        self.s = Signal(shift_bits, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)       # shifted result

    def elaborate(self, platform):
        """ combinatorially drives o with (i >> s)
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift and rshift apply the native shift operator to the operand
        and truncate the result back to the operand's own bit-length.

        width is stored as-is; smax is int(log2(width)).  Neither is
        used by the shift methods themselves.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns (op << s), cut down to the low len(op) entries
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns (op >> s), cut down to the low len(op) entries
        """
        res = op >> s
        return res[:len(op)]
61
62
class FPNumBaseRecord:
    """ Floating-point Base Number Class

        Holds the sign (s), exponent (e), mantissa (m) and raw value (v)
        signals of one number, together with the bit-widths and the
        exponent / mantissa constants derived from the format width.

        width:   16, 32 or 64 (anything else raises KeyError)
        m_extra: when True, 3 extra low bits are added to the mantissa
        e_extra: when True, 6 extra bits are added to the exponent
    """
    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1<<(e_width-3)
        self.rmw = m_width - 1 # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6 # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2 # for decoding

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
        self.s = Signal(reset_less=True)             # Sign bit

        # a record is its own "fp": methods below always go through self.fp
        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ shares this record's signals and widths with fp

            the signals and width attributes are assigned onto fp.
            the constants (mzero, msb1, m1s, P128, P127, N127, N126) are
            (re)created on *this record*, not on fp: an object that had
            a record dropped into it reaches them through its .fp
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1<<(self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            returns a list of assignments onto slices of self.v

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        return [
          self.v[0:self.e_start].eq(m),        # mantissa
          self.v[self.e_start:self.e_end].eq(e + self.fp.P127), # (add on bias)
          self.v[-1].eq(s),          # sign
        ]

    def _nan(self, s):
        """ (sign, exponent, mantissa) for NaN: top stored mantissa bit set
        """
        return (s, self.fp.P128, 1<<(self.e_start-1))

    def _inf(self, s):
        """ (sign, exponent, mantissa) for infinity
        """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) for zero
        """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            unlike create(), this returns one concatenated expression
            rather than a list of assignments to self.v
        """
        # constants live on the record: go through self.fp (as create()
        # does) so this also works on objects the record was dropped into
        e = e + self.fp.P127 # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assigns sign, exponent and mantissa from inp (v is not copied)
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
183
184
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

        Takes a record (fp), has its signals and widths dropped in,
        and adds status signals (is_nan, is_zero, is_inf, ...) that
        elaborate() derives from the exponent and mantissa.
    """
    def __init__(self, fp):
        # share fp's s/e/m/v signals and width attributes with this object
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # classification of the number
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the record's constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ drives every status signal combinatorially
        """
        m = Module()
        consts = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent == P128 with a non-zero mantissa
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        # exponent == P128 with a zero mantissa
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # exponent == N127 with a zero mantissa
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        # exponent greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent == N126 and mantissa bit e_start clear
        return (self.exp_n126) & (self.m_msbzero)
241
242
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

        Adds nothing of its own: construction and elaboration are
        both passed straight through to FPNumBase, so it carries the
        sign / exponent / mantissa signals, the encoding functions and
        the zero / NaN / inf recognition inherited from there.

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        # nothing beyond the base-class status logic
        return FPNumBase.elaborate(self, platform)
262
263
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp is the value to shift, diff the shift amount, m the result.
        s_max is the bit-width of diff; it defaults to int(log2(width)).
    """
    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, len(self.diff))

        # clamp the shift amount to mw; maxsleni is the complementary
        # amount (mw - diff), or zero once diff exceeds mw
        beyond = self.diff > mw
        m.d.comb += maxslen.eq(Mux(beyond, mw, self.diff))
        m.d.comb += maxsleni.eq(Mux(beyond, 0, mw-self.diff))

        upper = self.inp[1:]  # everything above the incoming low bit
        # shift mantissa by maxslen, mask by inverse
        m.d.comb += rs.eq(sm.rshift(upper, maxslen))
        m.d.comb += m_mask.eq(sm.rshift(~m0s, maxsleni))
        m.d.comb += smask.eq(upper & m_mask)
        # sticky bit combines all mask (and mantissa low bit)
        m.d.comb += stickybit.eq(smask.bool() | self.inp[0])
        # mantissa result contains m[0] already.
        m.d.comb += self.m.eq(Cat(stickybit, rs))
        return m
305
306
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        mainm: object providing the State("align") context used in
               elaborate
        op:    number whose s / e / m are copied in combinatorially
        inv:   number whose exponent this one is compared against
        width, m_extra: passed to the FPNumBaseRecord created here
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a record, not (width, m_extra): build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e and self.m are assigned in the comb domain
        # above and in the sync domain below - confirm this is intended.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp is the number read from; it defaults to self.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s right
            by (mw - shift amount) and using what is left as a mask to
            get the LSBs of the mantissa.  those are then |'d into the
            sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-1s constant lives on the record, not on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
377
378
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

        On top of the FPNumBase signals and status flags, elaborate()
        combinatorially decodes the latched value (self.v) into
        sign / exponent / mantissa.

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        decoded = self.decode(self.v)
        m.d.comb += decoded
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            returns the list of assignments to self.m, self.e and self.s
        """
        low_zeros = [0] * self.m_extra  # pad with extra zeros
        mantissa = Cat(*low_zeros, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]
415
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            the result is returned as an ObjectProxy with m, e and s
            attributes; self's own signals are not assigned.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                             # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
        res.s = v[-1]                                  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            returns the list of assignments to self.m, self.e and self.s
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        # P127 lives on the record (self.fp), exactly as decode2 uses it
        return [self.m.eq(Cat(*args)), # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s right
            by (mw - shift amount) and using what is left as a mask to
            get the LSBs of the mantissa.  those are then |'d into the
            sticky bit.

            inp is the number read from; it defaults to self.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-1s constant lives on the record, not on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
513
class Trigger(Elaboratable):
    """ stb / ack pair plus a trigger output that is stb AND ack
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_set = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_set)
        return m

    def eq(self, inp):
        """ assigns stb and ack from inp (trigger is not copied)
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        return [self.stb, self.ack]
533
534
class FPOpIn(PrevControl):
    """ PrevControl with a width attribute, a v alias for data_i and
        two helpers for chaining from another operand
    """
    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        # the value is the incoming data
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, except that the ack sent back to in_op
            is the inverse of self.ack
        """
        # optionally qualify the incoming strobe with "extra"
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments that take value and stb from in_op
            (stb ANDed with extra, when given) and pass self.ack back
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
561
562
class FPOpOut(NextControl):
    """ NextControl with a width attribute, a v alias for data_o and
        two helpers for chaining from another operand
    """
    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        # the value is the outgoing data
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, except that the ack sent back to in_op
            is the inverse of self.ack
        """
        # optionally qualify the incoming strobe with "extra"
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ returns assignments that take value and stb from in_op
            (stb ANDed with extra, when given) and pass self.ack back
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
589
590
class Overflow: #(Elaboratable):
    """ guard / round / sticky bits plus the mantissa's zero bit
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

    def __iter__(self):
        # same order as the fields are declared
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assigns all four fields from inp
        """
        copies = [
            self.guard.eq(inp.guard),
            self.round_bit.eq(inp.round_bit),
            self.sticky.eq(inp.sticky),
            self.m0.eq(inp.m0),
        ]
        return copies

    @property
    def roundz(self):
        # guard set, and any of round / sticky / m0 set
        any_lower = self.round_bit | self.sticky | self.m0
        return self.guard & any_lower
615
616
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every method takes the module (m) it adds statements to.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            returns [decoded operand, ack signal]
        """
        res = v.decode2(m)
        ack = Signal()
        op_taken = (op.ready_o) & (op.valid_i_test)
        with m.If(op_taken):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # limit a exponent
            m.d.sync += a.e.eq(a.fp.N126)
        with m.Else():
            # set top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
        """
        top_bit_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        keep_shifting = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(keep_shifting):
            # order matters: z.m[0] is assigned after the whole of z.m
            m.d.sync += z.e.eq(z.e - 1)            # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)           # shift mantissa UP
            m.d.sync += z.m[0].eq(of.guard)        # steal guard (was tot[2])
            m.d.sync += of.guard.eq(of.round_bit)  # steal round (was tot[1])
            m.d.sync += of.round_bit.eq(0)         # reset round bit
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        exp_too_small = z.e < z.fp.N126
        with m.If(exp_too_small):
            m.d.sync += z.e.eq(z.e + 1)    # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)   # shift mantissa DOWN
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        with m.If(roundz):
            # mantissa rounds up
            m.d.sync += z.m.eq(z.m + 1)
            mantissa_all_1s = (z.m == z.fp.m1s)
            with m.If(mantissa_all_1s):
                # exponent rounds up
                m.d.sync += z.e.eq(z.e + 1)

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += out_z.v.eq(z.v)
        accepted = out_z.valid_o & out_z.ready_i_test
        with m.If(accepted):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
747
748
class FPState(FPBase):
    """ FPBase that remembers the state it came from and can have
        named inputs / outputs attached as attributes
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ keeps the inputs dict and exposes each entry as an attribute
        """
        self.inputs = inputs
        for name, value in inputs.items():
            setattr(self, name, value)

    def set_outputs(self, outputs):
        """ keeps the outputs dict and exposes each entry as an attribute
        """
        self.outputs = outputs
        for name, value in outputs.items():
            setattr(self, name, value)
762
763
class FPID:
    """ optional in / out "mid" identifier signals

        id_wid: width of the identifier.  when falsy (None or 0) no
                signals are created: in_mid and out_mid are None and
                idsync does nothing.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync assignment of in_mid to out_mid, if they exist
        """
        # same truthiness test as __init__: with an id_wid of 0 the
        # signals are None, which an "is not None" check would let through
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
777
778