add name to Overflow class, also recreate OverflowMod
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Combinatorial variable right-shifter: ``o = i >> s``.

    :param width: bit width of the input ``i`` and the output ``o``.
    :param s_max: optional bit width of the shift amount ``s``.  Defaults
                  to floor(log2(width)), which is the previous behaviour.
                  Pass a larger value when ``width`` is not a power of two
                  and shifts of 2**floor(log2(width)) or more are needed
                  (same convention as MultiShiftRMerge).
    """

    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.width = width
        self.smax = s_max
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
27
28
class MultiShift:
    """ Width-preserving variable shift helpers.

    ``lshift`` and ``rshift`` apply the operand's own ``<<`` / ``>>``
    operator with a (possibly variable) shift amount.  They then slice
    the result back down to ``len(op)`` bits.

    ``smax`` is floor(log2(width)).  The helpers themselves do not use it;
    it is kept for callers that size shift-amount signals from it.
    """

    def __init__(self, width):
        self.width = width
        log2_width = log(width) / log(2)
        self.smax = int(log2_width)

    def lshift(self, op, s):
        """ returns ``op`` shifted left by ``s``, truncated to len(op) bits """
        n = len(op)
        return (op << s)[:n]

    def rshift(self, op, s):
        """ returns ``op`` shifted right by ``s``, truncated to len(op) bits """
        n = len(op)
        return (op >> s)[:n]
51
52
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Holds the raw value ``v`` and the decoded sign ``s``, exponent ``e``
    and mantissa ``m`` Signals for a 16, 32 or 64 bit IEEE754 number.  It
    also holds the field widths, plus the exponent-bias constants
    (created by ``drop_in``) used to encode and decode it.

    :param width: IEEE754 format width; must be 16, 32 or 64.
    :param m_extra: if True, 3 extra mantissa bits are added.
    :param e_extra: if True, 6 extra exponent bits are added (enough to
                    cover FP64 when converting to FP16).
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        # end of the exponent field in v, for decoding.  With e_extra this
        # runs past the sign bit, hence the ordering note in create().
        self.e_end = self.rmw + self.e_width - 2

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ shares this record's Signals and width parameters with ``fp``.

        Copies s/e/m/v and the width attributes onto ``fp``.  It then
        (re)creates the constants mzero, msb1, m1s, P128, P127, N127 and
        N126 on *this* record, not on ``fp``.  Objects that received the
        record this way reach those constants through their ``.fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra
        fp.e_extra = self.e_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        return [
          self.v[0:self.e_start].eq(m),         # mantissa
          self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
          self.v[-1].eq(s),                     # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        return self.create(*self._nan(s))

    def inf(self, s):
        return self.create(*self._inf(s))

    def zero(self, s):
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  The constants are read via
            ``self.fp`` so this also works on objects built by ``drop_in``
            (for a bare record, ``self.fp`` is ``self``).
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
174
175
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Shares the Signals and widths of an FPNumBaseRecord ``fp`` (via
    ``fp.drop_in``).  It adds combinatorial classification flags (NaN,
    zero, inf, overflowed, denormalised) and the exponent/mantissa
    comparison Signals those flags are built from.  The bias constants
    are read from ``self.fp``.
    """

    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        ew = fp.e_width

        # classification results
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((ew, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa checks
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        rec = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == rec.P128),
            self.exp_sub_n126.eq(self.e - rec.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > rec.P127),
            self.exp_n127.eq(self.e == rec.N127),
            self.exp_n126.eq(self.e == rec.N126),
            self.m_zero.eq(self.m == rec.mzero),
            # NOTE(review): tests mantissa bit e_start (== rmw); this is
            # the top bit only when the record has no m_extra bits.
            self.m_msbzero.eq(self.m[rec.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # maximum exponent with a zero mantissa
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # minimum exponent with a zero mantissa
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return self.exp_n126 & self.m_msbzero
233
234
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Adds nothing to FPNumBase; ``elaborate`` returns the base module.
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
255
256
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

    ``inp[1:]`` is shifted right by ``diff``, with the shift clamped to
    width-1.  Every bit of ``inp[1:]`` shifted out, plus ``inp[0]``, is
    OR-ed into ``m[0]``.

    :param width: width of ``inp`` and ``m``
    :param s_max: width of ``diff``; defaults to floor(log2(width))
    """

    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # The clamp limit width-1 must fit in the internal shift-amount
        # signals.  With the default s_max and a width that is not a power
        # of two (e.g. 27 -> s_max 4), width-1 does not fit in s_max bits
        # and Const(width-1, s_max) would be truncated.  When it already
        # fits, slen == s_max and the logic is unchanged.
        slen = max(self.smax, (self.width-1).bit_length())

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(slen, reset_less=True)
        maxsleni = Signal(slen, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, slen)
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
299
300
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    Builds its own FPNumBaseRecord of ``width`` bits and copies operand
    ``op`` into s/e/m combinatorially.  Inside state "align" of ``mainm``,
    it shifts itself down by one while ``e`` is below ``inv.e``.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase.__init__ takes a record, not (width, m_extra).
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): self.e / self.m are driven from comb above and from
        # sync below.  nmigen rejects a signal driven from two domains, so
        # this class likely needs restructuring before it can be used.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record, not on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
372
373
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    ``elaborate`` decodes ``self.v`` into s/e/m combinatorially.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ returns statements decoding ``v`` into sign / exponent / mantissa.

        The stored fraction is placed above ``m_extra`` zero bits.  The bias
        (``fp.P127``) is subtracted from the exponent field; the signed
        exponent Signal is wider than that field.
        """
        padding = [0] * self.m_extra
        fraction = v[0:self.e_start]
        exp_field = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*padding, fraction)),     # mantissa
                self.e.eq(exp_field - self.fp.P127),    # exp (remove bias)
                self.s.eq(v[-1]),                       # sign
                ]
411
412
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    The bias / all-ones constants (P127, m1s, ...) live on the record
    ``self.fp``.  FPNumBase instances do not carry them directly, so they
    are always accessed through ``self.fp`` here.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes the latched value ``self.v`` into sign / exponent /
            mantissa, returned as attributes of an ObjectProxy on ``m``.

            bias (``fp.P127``) is subtracted here, from the exponent field.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                             # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]                                  # sign
        return res

    def decode(self, v):
        """ returns statements decoding ``v`` into sign / exponent / mantissa

            bias (``fp.P127``) is subtracted here, from the exponent field;
            the signed exponent Signal is wider than the stored field.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),       # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param inp: source number (defaults to ``self``)
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
511
512
class Trigger(Elaboratable):
    """ stb / ack pair with a ``trigger`` output that is their
        combinatorial AND.
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        fire = self.stb & self.ack
        m.d.comb += self.trigger.eq(fire)
        return m

    def eq(self, inp):
        """ returns statements copying stb and ack from ``inp`` """
        return [self.stb.eq(inp.stb), self.ack.eq(inp.ack)]

    def ports(self):
        return [self.stb, self.ack]
532
533
class FPOpIn(PrevControl):
    """ PrevControl whose ``v`` aliases ``data_i``.  It has helpers that
        chain value / strobe / acknowledge from another operand.
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def _chain(self, in_op, extra, ack_back):
        # common body of chain_inv / chain_from: only the ACK sent back
        # to in_op differs between them
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),      # receive value
                self.stb.eq(strobe),     # receive STB
                in_op.ack.eq(ack_back),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chains from ``in_op``, sending back the inverted ack """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ chains from ``in_op``, sending back ack unchanged """
        return self._chain(in_op, extra, self.ack)
560
561
class FPOpOut(NextControl):
    """ NextControl whose ``v`` aliases ``data_o``.  It has helpers that
        chain value / strobe / acknowledge from another operand.
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def _chain(self, in_op, extra, ack_back):
        # common body of chain_inv / chain_from: only the ACK sent back
        # to in_op differs between them
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),      # receive value
                self.stb.eq(strobe),     # receive STB
                in_op.ack.eq(ack_back),  # send ACK
                ]

    def chain_inv(self, in_op, extra=None):
        """ chains from ``in_op``, sending back the inverted ack """
        return self._chain(in_op, extra, ~self.ack)

    def chain_from(self, in_op, extra=None):
        """ chains from ``in_op``, sending back ack unchanged """
        return self._chain(in_op, extra, self.ack)
588
589
class Overflow:  # (Elaboratable):
    """ Rounding state: guard, round and sticky bits plus mantissa bit 0.

    ``roundz`` is the round-up decision: guard set and any of round,
    sticky or m0 set (the round-to-nearest-even rule).

    :param name: optional prefix prepended to every Signal name.
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")  # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")  # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ returns statements copying every rounding bit from ``inp`` """
        return [getattr(self, field).eq(getattr(inp, field))
                for field in ("guard", "round_bit", "sticky", "m0")]

    @property
    def roundz(self):
        return self.guard & (self.round_bit | self.sticky | self.m0)
616
617
class OverflowMod(Elaboratable, Overflow):
    """ Overflow plus a ``roundz_out`` Signal.  ``elaborate`` drives it
        combinatorially from the ``roundz`` property.

    :param name: optional prefix prepended to every Signal name.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ returns statements copying roundz_out and the rounding bits
            from ``inp``.
        """
        # Overflow.eq needs the source too (it was called without inp)
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
636
637
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        Every helper adds statements to the nmigen Module ``m`` it is
        given.  The ones that take ``next_state`` also set ``m.next``, so
        they are meant to be called from inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ moves to ``next_state`` once ``op.ready_o & op.valid_i_test``.

        ``ack`` is driven combinatorially: 0 when that condition holds,
        1 otherwise (acknowledgement is signalled by ZERO).

        :returns: ``[res, ack]`` where ``res`` is ``v.decode2(m)``.
        """
        decoded = v.decode2(m)
        ack = Signal()
        handshake = (op.ready_o) & (op.valid_i_test)
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [decoded, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            While the top mantissa bit is clear, shifts the mantissa up and
            decrements the exponent once per clock; then moves on.
        """
        top_bit_clear = op.m[-1] == 0
        with m.If(top_bit_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            While the top mantissa bit is clear and the exponent is above
            N126, shifts up by one per clock.  The guard bit moves into
            m[0], and round_bit moves into guard.  (The extra mantissa bits
            come from tot[0..2].)
        """
        can_shift_up = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(can_shift_up):
            # statement order matters: z.m[0] must follow z.m
            m.d.sync += [
                z.e.eq(z.e - 1),              # DECREASE exponent
                z.m.eq(z.m << 1),             # shift mantissa UP
                z.m[0].eq(of.guard),          # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),    # steal round_bit (was tot[1])
                of.round_bit.eq(0),           # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            While the exponent is below N126, shifts the mantissa down by
            one per clock.  The bits shifted out feed guard / m0; the old
            guard moves into round_bit; round_bit is folded into sticky.
        """
        exp_too_small = z.e < z.fp.N126
        with m.If(exp_too_small):
            m.d.sync += [
                z.e.eq(z.e + 1),              # INCREASE exponent
                z.m.eq(z.m >> 1),             # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            When ``roundz`` is set, the mantissa is incremented.  If the
            mantissa was all ones, the exponent is incremented as well.
        """
        all_ones = z.m == z.fp.m1s
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            Moves to ``next_state``; a denormalised result gets its
            exponent set to N127.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            Writes ``z.inf(z.s)`` into ``z.v`` when ``z.is_overflowed``,
            otherwise ``z.create(z.s, z.e, z.m)``; moves to ``next_state``.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.

            Registers ``z.v`` into ``out_z.v`` and raises ``valid_o``.  Once
            ``valid_o & ready_i_test`` holds, it drops ``valid_o`` and moves
            to ``next_state``.
        """
        m.d.sync += out_z.v.eq(z.v)
        delivered = out_z.valid_o & out_z.ready_i_test
        with m.If(delivered):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
768
769
class FPState(FPBase):
    """ FPBase helper that records the state it came from.  It also
        exposes the entries of its input/output mappings as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def _absorb(self, mapping):
        # expose every key of ``mapping`` as an attribute of the same name
        for attr, val in mapping.items():
            setattr(self, attr, val)

    def set_inputs(self, inputs):
        self.inputs = inputs
        self._absorb(inputs)

    def set_outputs(self, outputs):
        self.outputs = outputs
        self._absorb(outputs)
783
784
class FPID:
    """ Optional ID Signals ``in_mid`` / ``out_mid``.

    When ``id_wid`` is non-zero, both are ``id_wid``-bit Signals and
    ``idsync`` registers ``in_mid`` into ``out_mid``.  Otherwise both are
    None and ``idsync`` does nothing.

    :param id_wid: width of the ID Signals; 0 or None disables them.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds ``out_mid <= in_mid`` to the sync domain of ``m``. """
        # Use the same truthiness test as __init__.  An ``is not None``
        # check lets id_wid=0 through, but then out_mid is None.
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)