fix up FPNumBase by creating a Record class (not derived from Elaboratable)
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Variable right-shifter: combinatorially drives o with (i >> s).

        i and o are "width" bits wide.  the shift amount, s, is
        int(log(width) / log(2)) bits wide.
    """

    def __init__(self, width):
        self.width = width
        # number of bits used for the shift-amount signal
        self.smax = int(log(width) / log(2))
        self.i = Signal(width, reset_less=True)      # value to shift
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        """ creates the module: a single combinatorial assignment to o
        """
        shifted = self.i >> self.s
        m = Module()
        m.d.comb += self.o.eq(shifted)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift and rshift apply the << / >> operators to the operand
        and truncate the result to the operand's own length.

        (a hand-built cascade of Muxes, one per bit of the shift amount,
        used to follow the return statement of each method.  it could
        never execute, so it has been removed.)
    """

    def __init__(self, width):
        self.width = width
        # number of shift-amount bits: int(log2(width))
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op)
        """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op)
        """
        return (op >> s)[:len(op)]
61
62
class FPNumBaseRecord:
    """ Floating-point Base Number Class

        holds the signals of one floating-point number (v: packed value,
        s: sign, e: exponent, m: mantissa) together with the field
        positions and constants derived from the bit-width.

        width must be 16, 32 or 64 (anything else raises KeyError from
        the lookup tables).  when m_extra is True, 3 additional bits are
        added to the mantissa.

        the constants are named after their single-precision values:
        P128 is e_max, P127 is e_max-1, N127 is -(e_max-1) and N126 is
        -(e_max-2), where e_max is 1<<(e_width-3).
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
        e_max = 1<<(e_width-3)
        self.rmw = m_width # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw - 1 # first exponent bit within v
        self.e_end = self.rmw + self.e_width - 3 # for decoding

        self.v = Signal(width, reset_less=True) # Latched copy of value
        self.m = Signal(m_width, reset_less=True) # Mantissa
        self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
        self.s = Signal(reset_less=True) # Sign bit

        self.mzero = Const(0, (m_width, False))
        m_msb = 1<<(self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def drop_in(self, fp):
        """ copies this record's signals, geometry and constants onto fp

            the constants are included so that methods of this class
            (create, nan, inf, zero and the "2" variants) also work on an
            object that was populated only through drop_in.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.width = self.width
        fp.e_width = self.e_width
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra
        # previously omitted: needed by create/nan/inf/zero on the target
        fp.rmw = self.rmw
        fp.e_max = self.e_max
        fp.mzero = self.mzero
        fp.msb1 = self.msb1
        fp.m1s = self.m1s
        fp.P128 = self.P128
        fp.P127 = self.P127
        fp.N127 = self.N127
        fp.N126 = self.N126

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  returns a list of
            assignments to the sign, exponent and mantissa fields of v.
        """
        return [
          self.v[-1].eq(s),          # sign
          self.v[self.e_start:self.e_end].eq(e + self.P127), # exp (add on bias)
          self.v[0:self.e_start].eq(m)         # mantissa
        ]

    def nan(self, s):
        """ assignments setting v to NaN with sign s (top mantissa bit set)
        """
        return self.create(s, self.P128, 1<<(self.e_start-1))

    def inf(self, s):
        """ assignments setting v to infinity with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments setting v to zero with sign s
        """
        return self.create(s, self.N127, 0)

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  unlike create, this
            returns the packed value as an expression: the low e_start
            bits of m, then the low (e_end-e_start) bits of the biased
            exponent, then s.
        """
        e = e + self.P127 # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ packed NaN with sign s

            create2 only keeps bits [0:e_start] of the mantissa.  msb1
            has its single set bit at m_width-2, which lies outside that
            range when the extra mantissa bits are present, and would
            give an all-zero mantissa field (i.e. infinity).  the same
            bit as nan() uses is therefore passed in instead.
        """
        qnan = Const(1<<(self.e_start-1), (self.m_width, False))
        return self.create2(s, self.P128, qnan)

    def inf2(self, s):
        """ packed infinity with sign s
        """
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        """ packed zero with sign s
        """
        return self.create2(s, self.N127, self.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying s, e and m (not v) from inp
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
157
158
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

        wraps a record (fp): the record's attributes are copied onto
        this object by fp.drop_in, and a set of single-bit status flags
        is derived combinatorially from the exponent and mantissa.
    """
    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # classification of the number
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        # exponent comparisons against the record's constants
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        # mantissa tests
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ creates the module driving every status flag combinatorially
        """
        m = Module()
        fp = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent is P128 and the mantissa is non-zero
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        # exponent is P128 and the mantissa is zero
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        # exponent is N127 and the mantissa is zero
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        # exponent is greater than P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent is N126 and mantissa bit e_start is clear
        return self.exp_n126 & self.m_msbzero
215
216
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        adds nothing of its own: construction and elaboration are both
        handed straight to FPNumBase, so an instance carries the sign /
        exponent / mantissa signals, the status flags, and the encoding
        functions (creation of zero, NaN and inf, all signed).

        as described for the other number classes in this file, extra
        bits may be included in the mantissa: the top bit (m[-1]) is
        effectively a carry-overflow, the others are guard (m[2]),
        round (m[1]), and sticky (m[0])
    """
    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        # no extra logic: the base-class module is returned as-is
        return FPNumBase.elaborate(self, platform)
236
237
class MultiShiftRMerge(Elaboratable):
    """ right-shifts inp by diff and ORs the bits that fall off the
        bottom into bit 0 of the result, m.

        m[0] is the "sticky" bit, basically.  s_max is the width of the
        shift amount (diff) and defaults to int(log(width) / log(2)).
    """
    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)     # result
        self.inp = Signal(width, reset_less=True)   # value to shift
        self.diff = Signal(s_max, reset_less=True)  # shift amount
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        # everything above bit 0 is shifted: bit 0 is handled separately
        upper_w = self.width - 1
        shifter = MultiShift(upper_w)
        zeros = Const(0, upper_w)
        limit = Const(upper_w, len(self.diff))
        upper = self.inp[1:]

        m.d.comb += [
            # clamp the shift amount to limit.  maxsleni is the
            # complementary amount (limit-diff), or zero when clamped
            maxslen.eq(Mux(self.diff > limit, limit, self.diff)),
            maxsleni.eq(Mux(self.diff > limit, 0, limit - self.diff)),
            # shift mantissa by maxslen, mask by inverse
            rs.eq(shifter.rshift(upper, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
279
280
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        mirrors the sign / exponent / mantissa of "op" and, while the
        FSM "mainm" is in its "align" state and this exponent is below
        that of "inv", shifts the mantissa down by one.

        NOTE(review): elaborate drives e and m from both the comb and
        the sync domain, exactly as the code it replaces did.  confirm
        whether this class is still in use before relying on it.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a single record argument (it calls
        # fp.drop_in(self)), not (width, m_extra): build the record here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # "op" is an attribute, not a global (was a NameError)
        op = self.op
        m.d.comb += self.s.eq(op.s)
        m.d.comb += self.e.eq(op.e)
        m.d.comb += self.m.eq(op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down requires the number to shift from.
                # NOTE(review): "self" is assumed here - confirm
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            returns assignments to this object's e and m, computed from inp.
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by right-shifting a batch of 1s
            by (maximum - shift amount).  what remains is used as a mask
            on the mantissa, and the masked bits are then |'d into the
            sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record: drop_in does not
        # copy it onto this object
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
351
352
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoding)

        on top of the FPNumBase status flags, the packed value v is
        combinatorially decoded into sign / exponent / mantissa.

        extra bits may be included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow, the others are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        # s, e and m continuously follow the packed value
        decoded = self.decode(self.v)
        m.d.comb += decoded
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  returns a list
            of assignments to m, e and s (in that order).
        """
        # m_extra zero bits go below the stored mantissa field
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
               ]
389
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        extra bits may be included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow, the others are
        guard (m[2]), round (m[1]), and sticky (m[0])

        the constants (P127, m1s) are read from the record, self.fp:
        FPNumBase.__init__ populates this object through fp.drop_in,
        which does not copy them.
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the result is
            returned as an ObjectProxy (created on module m) with
            attributes m, e and s, rather than as assignments.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                             # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
        res.s = v[-1]                                  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  returns a list
            of assignments to m, e and s.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
        return [self.m.eq(Cat(*args)), # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            returns assignments to this object's e and m, computed from inp.
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by right-shifting a batch of 1s
            by (maximum - shift amount).  what remains is used as a mask
            on the mantissa, and the masked bits are then |'d into the
            sticky bit.

            inp is the number to shift from: it defaults to self.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
487
class Trigger(Elaboratable):
    """ strobe / acknowledge pair: "trigger" is combinatorially set to
        (stb & ack)
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        fire = self.stb & self.ack
        m = Module()
        m.d.comb += self.trigger.eq(fire)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack (not trigger) from inp
        """
        pairs = ((self.stb, inp.stb), (self.ack, inp.ack))
        return [mine.eq(theirs) for mine, theirs in pairs]

    def ports(self):
        """ list of the externally-facing signals: stb then ack
        """
        return [self.stb, self.ack]
507
508
class FPOpIn(PrevControl):
    """ PrevControl that records a width and exposes data_i as "v"

        NOTE(review): chain_inv / chain_from read self.stb and self.ack,
        presumably provided by PrevControl - confirm
    """
    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ assignments connecting in_op to this operand, with the
            acknowledgement sent back inverted.  if extra is given it
            is ANDed with the incoming strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ as chain_inv, but the acknowledgement is sent back as-is
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
535
536
class FPOpOut(NextControl):
    """ NextControl that records a width and exposes data_o as "v"

        NOTE(review): chain_inv / chain_from read self.stb and self.ack,
        presumably provided by NextControl - confirm
    """
    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ assignments connecting in_op to this operand, with the
            acknowledgement sent back inverted.  if extra is given it
            is ANDed with the incoming strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ as chain_inv, but the acknowledgement is sent back as-is
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
563
564
class Overflow: #(Elaboratable):
    """ guard / round / sticky bits plus the mantissa's bit zero, and
        "roundz", which is set to guard & (round_bit | sticky | m0)
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)     # tot[2]
        self.round_bit = Signal(reset_less=True) # tot[1]
        self.sticky = Signal(reset_less=True)    # tot[0]
        self.m0 = Signal(reset_less=True)        # mantissa zero bit

        self.roundz = Signal(reset_less=True)

    def __iter__(self):
        # roundz is deliberately absent: it is derived from the others
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assignments copying guard, round_bit, sticky and m0 from inp
        """
        names = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, name).eq(getattr(inp, name)) for name in names]

    def elaborate(self, platform):
        m = Module()
        lower = self.round_bit | self.sticky | self.m0
        m.d.comb += self.roundz.eq(self.guard & lower)
        return m
591
592
class FPBase:
    """ IEEE754 Floating Point Base Class

        a collection of helpers, each of which adds the logic for one
        step of a floating-point state machine to the module "m":
        fetching operands, denormalisation, normalisation, rounding,
        corrections, packing and output.
    """

    def get_op(self, m, op, v, next_state):
        """ decodes v (through v.decode2) and moves to next_state when
            both op.ready_o and op.valid_i_test are set.

            returns [decoded result, ack], where ack is a new signal set
            to ZERO when the move takes place and to one otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        handshake = (op.ready_o) & (op.valid_i_test)
        with m.If(handshake):
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
            m.next = next_state
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            # exponent at minimum: limit it to N126
            m.d.sync += a.e.eq(a.N126)
        with m.Else():
            # otherwise set the top mantissa bit
            m.d.sync += a.m[-1].eq(1)

    def op_normalise(self, m, op, next_state):
        """ operand normalisation: while the top mantissa bit is clear,
            the mantissa is shifted up and the exponent decreased.

            NOTE: just like "align", this one keeps going round every clock
                  until the top mantissa bit is set
        """
        top_clear = op.m[-1] == 0
        with m.If(top_clear):
            m.d.sync += op.e.eq(op.e - 1)   # DECREASE exponent
            m.d.sync += op.m.eq(op.m << 1)  # shift mantissa UP
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation: shifts the mantissa up (pulling
            the guard bit in at the bottom) while its top bit is clear
            and the exponent is above N126.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        can_shift_up = (z.m[-1] == 0) & (z.e > z.N126)
        with m.If(can_shift_up):
            # order matters: the assignment to z.m[0] must follow
            # the assignment to the whole of z.m
            m.d.sync += z.e.eq(z.e - 1)             # DECREASE exponent
            m.d.sync += z.m.eq(z.m << 1)            # shift mantissa UP
            m.d.sync += z.m[0].eq(of.guard)         # steal guard (was tot[2])
            m.d.sync += of.guard.eq(of.round_bit)   # steal round_bit (tot[1])
            m.d.sync += of.round_bit.eq(0)          # reset round bit
            m.d.sync += of.m0.eq(of.guard)
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation: shifts the mantissa down (the
            bits falling off move through guard, round and sticky)
            while the exponent is below N126.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        below_range = z.e < z.N126
        with m.If(below_range):
            m.d.sync += z.e.eq(z.e + 1)   # INCREASE exponent
            m.d.sync += z.m.eq(z.m >> 1)  # shift mantissa DOWN
            m.d.sync += of.guard.eq(z.m[0])
            m.d.sync += of.m0.eq(z.m[1])
            m.d.sync += of.round_bit.eq(of.guard)
            m.d.sync += of.sticky.eq(of.sticky | of.round_bit)
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented.  if the
            mantissa was all 1s the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)      # mantissa rounds up
            mantissa_full = z.m == z.m1s     # all 1s
            with m.If(mantissa_full):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state.
        """
        m.next = next_state
        with m.If(z.is_overflowed):
            # if overflow occurs, return inf
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises valid_o and
            waits for ready_i_test before moving to the next state.
            valid_o is reset back to zero when that occurs.
        """
        m.d.sync += [out_z.v.eq(z.v)]
        accepted = out_z.valid_o & out_z.ready_i_test
        with m.If(accepted):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
723
724
class FPState(FPBase):
    """ FPBase that remembers the state it was entered from, and has
        dictionaries of inputs / outputs attached as attributes
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores the dictionary as self.inputs and also sets each of
            its entries as an attribute of the same name
        """
        self.inputs = inputs
        for name in inputs:
            setattr(self, name, inputs[name])

    def set_outputs(self, outputs):
        """ stores the dictionary as self.outputs and also sets each of
            its entries as an attribute of the same name
        """
        self.outputs = outputs
        for name in outputs:
            setattr(self, name, outputs[name])
738
739
class FPID:
    """ optional pair of identifier signals, in_mid and out_mid

        when id_wid is zero or None no signals are created: in_mid and
        out_mid are both None, and idsync does nothing.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a synchronous copy of in_mid to out_mid on module m

            the test is the same one that __init__ uses.  checking
            "is not None" instead meant that an id_wid of zero went on
            to call .eq on a signal that was never created.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
753
754