quick debug session on FP div stub pipeline
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR(Elaboratable):
    """ Combinatorial variable right-shifter.

    Attributes:
    * width: bit-width of the input and output
    * smax:  bit-width of the shift amount, int(log2(width))
    * i:     value to be shifted
    * s:     shift amount
    * o:     result, i >> s

    Derives from Elaboratable (like the other classes in this file
    that provide elaborate()) so that it can be added as a submodule.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ returns a Module that drives o combinatorially from i and s
        """
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shifter helper.

    lshift and rshift use the native shift operators of the operand
    and truncate the result to the operand's own length.

    (An earlier Mux-cascade implementation, testing each bit of the
    shift amount in turn, sat unreachable after the return statements
    and has been removed.)
    """

    def __init__(self, width):
        self.width = width
        # retained as a public attribute; no longer used by the shifts
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op)
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op)
        """
        res = op >> s
        return res[:len(op)]
61
62
class FPNumBaseRecord:
    """ Floating-point Base Number Record

    Owns the sign, exponent, mantissa and raw-value signals for one
    number, together with the field widths and the exponent / mantissa
    constants derived from the overall bit-width.

    The constants (mzero, msb1, m1s, P128, P127, N127, N126) are only
    ever created on the record itself, so every method here reaches
    them through self.fp.  On a record self.fp is the record; on an
    object populated via drop_in, self.fp is expected to point back at
    the originating record.
    """
    def __init__(self, width, m_extra=True):
        """ width:   total bit-width: 16, 32 or 64 (anything else raises
                     KeyError from the lookup tables below)
            m_extra: when True the mantissa is widened by 3 extra bits
        """
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1<<(e_width-3)
        self.rmw = m_width # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
        self.s = Signal(reset_less=True)             # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ shares this record's signals and field widths with fp, by
            assigning them as attributes of fp.

            the constants are (re)created on this record only, NOT on
            fp: users of fp must fetch them from the record.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1<<(self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  returns a list of
            assignments into the fields of self.v
        """
        return [
          self.v[-1].eq(s),          # sign
          self.v[self.e_start:self.e_end].eq(e + self.fp.P127), # (add on bias)
          self.v[0:self.e_start].eq(m)         # mantissa
        ]

    def nan(self, s):
        """ assignments setting v to NaN (top mantissa-field bit set)
        """
        return self.create(s, self.fp.P128, 1<<(self.e_start-1))

    def inf(self, s):
        """ assignments setting v to infinity, with sign s
        """
        return self.create(s, self.fp.P128, 0)

    def zero(self, s):
        """ assignments setting v to zero, with sign s
        """
        return self.create(s, self.fp.N127, 0)

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  unlike create, this
            returns the packed value (a Cat expression) rather than
            assignments into self.v
        """
        # constants live on the record: go through self.fp (see class doc)
        e = e + self.fp.P127 # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ packed NaN value, with sign s
        """
        # NOTE(review): msb1 sets bit m_width-2, which is above the
        # e_start bits that create2 keeps when m_extra is 3 - confirm
        # this is only used on records created with m_extra=False.
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        """ packed infinity value, with sign s
        """
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        """ packed zero value, with sign s
        """
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        """ yields sign, exponent, mantissa (in that order)
        """
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from inp
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
166
167
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Wraps an existing record (fp): the record's signals and widths are
    dropped in as attributes of this object, and a set of single-bit
    status flags about the exponent and mantissa is derived from them
    combinatorially.
    """
    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp

        # overall classification of the number
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the record's constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((fp.e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ returns a Module driving every status flag combinatorially
        """
        m = Module()
        consts = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent is P128 and mantissa is non-zero
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        # exponent is P128 and mantissa is zero
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # exponent is N127 and mantissa is zero
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        # exponent is greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent is N126 and mantissa bit e_start is clear
        return (self.exp_n126) & (self.m_msbzero)
224
225
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    Adds nothing of its own: construction and elaboration are both
    handed straight to FPNumBase, so an instance carries the record's
    sign / exponent / mantissa signals plus the status flags.
    """
    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
245
246
class MultiShiftRMerge(Elaboratable):
    """ Right-shifter that folds the shifted-out bits into bit 0.

    inp is shifted down by diff; any bits lost off the bottom are
    OR-ed together (along with inp[0]) to become m[0], the "sticky"
    bit of the result.
    """
    def __init__(self, width, s_max=None):
        """ width: bit-width of inp and m
            s_max: bit-width of diff, defaulting to int(log2(width))
        """
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # NOTE: the local names below become the signal names
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        shifter = MultiShift(self.width-1)
        all_zero = Const(0, self.width-1)
        # NOTE(review): the limit is sized to match diff, so it is
        # presumably truncated if diff is too narrow to hold width-1
        # - confirm callers pass a large enough s_max.
        limit = Const(self.width-1, len(self.diff))
        upper = self.inp[1:]

        m.d.comb += [
            # clamp the shift amount, and work out its complement
            maxslen.eq(Mux(self.diff > limit, limit, self.diff)),
            maxsleni.eq(Mux(self.diff > limit, 0, limit-self.diff)),
            # shift everything above bit 0 down by the clamped amount
            rs.eq(shifter.rshift(upper, maxslen)),
            # all-1s shifted by the complement selects the lost bits
            m_mask.eq(shifter.rshift(~all_zero, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky: any lost bit set, or the original bit 0
            stickybit.eq(smask.bool() | self.inp[0]),
            # result: sticky in bit 0, shifted value above it
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
288
289
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    Mirrors the sign / exponent / mantissa of op and, whilst mainm is
    in its "align" state, shifts down by one place whenever the
    exponent is below that of inv.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        """ mainm:   module providing the "align" FSM state
            op:      number whose s / e / m are copied
            inv:     number whose exponent is compared against
            width:   total bit-width (16, 32 or 64)
            m_extra: passed on to the underlying record
        """
        # FPNumBase takes a record, not (width, m_extra): build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        op = self.op  # was referenced as a bare (undefined) name
        m.d.comb += self.s.eq(op.s)
        m.d.comb += self.e.eq(op.e)
        m.d.comb += self.m.eq(op.m)

        # NOTE(review): e and m are assigned in comb above and in sync
        # below, and the If belongs to m rather than mainm - confirm
        # this is what the enclosing design expects.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the two lowest bits are OR-ed into the new bit 0, so that
            a set bit is not lost off the bottom.

            inp defaults to self (elaborate calls this with no argument)
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            the shift amount is clamped to m_width-1.

            an all-1s constant is shifted by the complement of the
            (clamped) amount and used as a mask on the mantissa; any
            bit set under the mask is OR-ed, with m[0], into the new
            bit 0 (the "sticky" bit).
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # m1s exists only on the record, so fetch it through self.fp
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
360
361
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class, with combinatorial decode

    On top of the FPNumBase status flags, the latched value v is
    continuously split into sign / exponent / mantissa.
    """
    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            returns the assignments for m, e and s.  the mantissa
            field is padded with m_extra zero bits at the bottom, and
            the bias (P127) is subtracted from the exponent field.
        """
        low_padding = [0] * self.m_extra
        mantissa = Cat(*low_padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]
398
class FPNumIn(FPNumBase):
    """ Floating-point Number Class (input side)

    Provides decoding of the latched value v into sign / exponent /
    mantissa, plus single- and multi-place mantissa shifts that adjust
    the exponent to compensate.

    The constants (P127, m1s, ...) exist only on the record, so they
    are always fetched through self.fp.
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes the latched value into sign / exponent / mantissa

            returns an ObjectProxy (created on module m) holding the
            decoded m, e and s expressions, rather than assignments.
            bias is subtracted here, from the exponent.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                               # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
        res.s = v[-1]                                    # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            returns the assignments for m, e and s.
            bias is subtracted here, from the exponent.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        return [self.m.eq(Cat(*args)), # mantissa
                # was self.P127, which is never set on this object
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            the two lowest bits of inp.m are OR-ed into the new bit 0,
            so that a set bit is not lost off the bottom.
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            inp is the number to read from (defaults to self); the
            result is always assigned to self.  the shift amount is
            clamped to m_width-1.

            an all-1s constant is shifted by the complement of the
            (clamped) amount and used as a mask on the mantissa; any
            bit set under the mask is OR-ed, with m[0], into the new
            bit 0 (the "sticky" bit).
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # was self.m1s, which is never set on this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
496
class Trigger(Elaboratable):
    """ stb / ack pair, with a trigger output that is set when both are.
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_set = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_set)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack from inp
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ the externally visible signals: stb, then ack
        """
        return [self.stb, self.ack]
516
517
class FPOpIn(PrevControl):
    """ Input operand port: a PrevControl whose data_i is exposed as v.
    """
    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the incoming value (alias for data_i)
        """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ assignments linking in_op to this port, with ack INVERTED

            extra, if given, is AND-ed with in_op's stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
               ]

    def chain_from(self, in_op, extra=None):
        """ assignments linking in_op to this port

            extra, if given, is AND-ed with in_op's stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
               ]
544
545
class FPOpOut(NextControl):
    """ Output operand port: a NextControl whose data_o is exposed as v.
    """
    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ the outgoing value (alias for data_o)
        """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ assignments linking in_op to this port, with ack INVERTED

            extra, if given, is AND-ed with in_op's stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
               ]

    def chain_from(self, in_op, extra=None):
        """ assignments linking in_op to this port

            extra, if given, is AND-ed with in_op's stb.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(stb),            # receive STB
                in_op.ack.eq(self.ack),      # send ACK
               ]
572
573
class Overflow: #(Elaboratable):
    """ Holds the guard / round / sticky bits plus the mantissa's
        lowest bit, and derives the round-up decision from them.
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

    def __iter__(self):
        """ yields guard, round_bit, sticky, m0 (in that order)
        """
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assignments copying all four bits from inp
        """
        pairs = ((self.guard, inp.guard),
                 (self.round_bit, inp.round_bit),
                 (self.sticky, inp.sticky),
                 (self.m0, inp.m0))
        return [dest.eq(src) for dest, src in pairs]

    @property
    def roundz(self):
        """ guard is set, and so is at least one of round_bit, sticky, m0
        """
        any_lower = self.round_bit | self.sticky | self.m0
        return self.guard & any_lower
598
599
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every method takes the Module (m) to add statements to.  the
        methods that take a next_state set m.next, so are presumably
        meant to be called from inside one of m's FSM states.
    """

    def get_op(self, m, op, v, next_state):
        """ moves to next_state when both op.ready_o and op.valid_i_test
            are set.

            returns [res, ack], where res is the decoded operand (from
            v.decode2) and ack is a new Signal driven combinatorially:
            ZERO when the condition above holds, 1 otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126) # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            whilst the top mantissa bit is clear: shift the mantissa up
            one place and decrease the exponent.  once it is set, move
            to next_state.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
        """
        with m.If((op.m[-1] == 0)): # check last bit of mantissa
            m.d.sync +=[
                 op.e.eq(op.e - 1),  # DECREASE exponent
                 op.m.eq(op.m << 1), # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            whilst the top mantissa bit is clear and the exponent is
            above N126: shift up one place, pulling the guard bit in at
            the bottom.  otherwise move to next_state.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # order matters: z.m[0] is assigned after (so overrides
            # bit 0 of) the whole-mantissa shift
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1), # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            whilst the exponent is below N126: shift down one place,
            passing the bits that drop off through guard, round and
            into sticky.  otherwise move to next_state.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented; if the
            mantissa was all 1s the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.fp.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.

            out_z.v is assigned from z.v on every call.  when both
            out_z.valid_o and out_z.ready_i_test are set, valid_o is
            reset back to zero and the state moves to next_state;
            otherwise valid_o is raised.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
730
731
class FPState(FPBase):
    """ A named state: FPBase's helper functions, plus dictionaries of
        inputs and outputs that are also exposed as attributes.
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ records the inputs dictionary, then sets each entry as an
            attribute of this object (named after its key)
        """
        self.inputs = inputs
        for name, value in inputs.items():
            setattr(self, name, value)

    def set_outputs(self, outputs):
        """ records the outputs dictionary, then sets each entry as an
            attribute of this object (named after its key)
        """
        self.outputs = outputs
        for name, value in outputs.items():
            setattr(self, name, value)
745
746
class FPID:
    """ Optional pass-through identifier (in_mid -> out_mid).

    When id_wid is zero or None no signals are created: in_mid and
    out_mid are both None and idsync does nothing.
    """
    def __init__(self, id_wid):
        """ id_wid: bit-width of the identifier, or 0 / None for no id
        """
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync assignment of in_mid to out_mid on module m

            uses the same test as __init__: the previous "is not None"
            check let id_wid=0 through, where out_mid is None.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
760
761