X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffpbase.py;h=db95eb13e2a4ef7ae7199c6a29e05865674fc62a;hb=892d640f8224e6a52907c6899ab6ab671f5f53af;hp=2124849ecac154f1c509bc329b71dfa9bfbbded3;hpb=978906052a938cb8c6f0056d1f0395a18e6acaf8;p=ieee754fpu.git diff --git a/src/add/fpbase.py b/src/add/fpbase.py index 2124849e..db95eb13 100644 --- a/src/add/fpbase.py +++ b/src/add/fpbase.py @@ -96,6 +96,9 @@ class FPNumBase: self.is_overflowed = Signal(reset_less=True) self.is_denormalised = Signal(reset_less=True) self.exp_128 = Signal(reset_less=True) + self.exp_sub_n126 = Signal((e_width, True), reset_less=True) + self.exp_lt_n126 = Signal(reset_less=True) + self.exp_gt_n126 = Signal(reset_less=True) self.exp_gt127 = Signal(reset_less=True) self.exp_n127 = Signal(reset_less=True) self.exp_n126 = Signal(reset_less=True) @@ -110,6 +113,9 @@ class FPNumBase: m.d.comb += self.is_overflowed.eq(self._is_overflowed()) m.d.comb += self.is_denormalised.eq(self._is_denormalised()) m.d.comb += self.exp_128.eq(self.e == self.P128) + m.d.comb += self.exp_sub_n126.eq(self.e - self.N126) + m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0) + m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0) m.d.comb += self.exp_gt127.eq(self.e > self.P127) m.d.comb += self.exp_n127.eq(self.e == self.N127) m.d.comb += self.exp_n126.eq(self.e == self.N126) @@ -133,6 +139,9 @@ class FPNumBase: def _is_denormalised(self): return (self.exp_n126) & (self.m_msbzero) + def copy(self, inp): + return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)] + class FPNumOut(FPNumBase): """ Floating-point Number Class @@ -175,6 +184,49 @@ class FPNumOut(FPNumBase): return self.create(s, self.N127, 0) +class MultiShiftRMerge: + """ shifts down (right) and merges lower bits into m[0]. + m[0] is the "sticky" bit, basically + """ + def __init__(self, width, s_max=None): + if s_max is None: + s_max = int(log(width) / log(2)) + self.smax = s_max + self.m = Signal(width, reset_less=True) + self.inp = Signal(width, reset_less=True) + self.diff = Signal(s_max, reset_less=True) + self.width = width + + def elaborate(self, platform): + m = Module() + + rs = Signal(self.width, reset_less=True) + m_mask = Signal(self.width, reset_less=True) + smask = Signal(self.width, reset_less=True) + stickybit = Signal(reset_less=True) + maxslen = Signal(self.smax, reset_less=True) + maxsleni = Signal(self.smax, reset_less=True) + + sm = MultiShift(self.width-1) + m0s = Const(0, self.width-1) + mw = Const(self.width-1, len(self.diff)) + m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)), + maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)), + ] + + m.d.comb += [ + # shift mantissa by maxslen, mask by inverse + rs.eq(sm.rshift(self.inp[1:], maxslen)), + m_mask.eq(sm.rshift(~m0s, maxsleni)), + smask.eq(self.inp[1:] & m_mask), + # sticky bit combines all mask (and mantissa low bit) + stickybit.eq(smask.bool() | self.inp[0]), + # mantissa result contains m[0] already. + self.m.eq(Cat(stickybit, rs)) + ] + return m + + class FPNumShift(FPNumBase): """ Floating-point Number Class for shifting """ @@ -198,14 +250,14 @@ class FPNumShift(FPNumBase): return m - def shift_down(self): + def shift_down(self, inp): """ shifts a mantissa down by one. exponent is increased to compensate accuracy is lost as a result in the mantissa however there are 3 guard bits (the latter of which is the "sticky" bit) """ - return [self.e.eq(self.e + 1), - self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0)) + return [self.e.eq(inp.e + 1), + self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0)) ] def shift_down_multi(self, diff): @@ -266,9 +318,9 @@ class FPNumIn(FPNumBase): def elaborate(self, platform): m = FPNumBase.elaborate(self, platform) - m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb) - with m.If(self.latch_in): - m.d.sync += self.decode(self.v) + #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb) + #with m.If(self.latch_in): + # m.d.sync += self.decode(self.v) return m @@ -286,17 +338,17 @@ class FPNumIn(FPNumBase): self.s.eq(v[-1]), # sign ] - def shift_down(self): + def shift_down(self, inp): """ shifts a mantissa down by one. exponent is increased to compensate accuracy is lost as a result in the mantissa however there are 3 guard bits (the latter of which is the "sticky" bit) """ - return [self.e.eq(self.e + 1), - self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0)) + return [self.e.eq(inp.e + 1), + self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0)) ] - def shift_down_multi(self, diff): + def shift_down_multi(self, diff, inp=None): """ shifts a mantissa down. exponent is increased to compensate accuracy is lost as a result in the mantissa however there are 3 @@ -311,16 +363,19 @@ class FPNumIn(FPNumBase): inverted and used as a mask to get the LSBs of the mantissa. those are then |'d into the sticky bit. """ + if inp is None: + inp = self sm = MultiShift(self.width) mw = Const(self.m_width-1, len(diff)) maxslen = Mux(diff > mw, mw, diff) - rs = sm.rshift(self.m[1:], maxslen) + rs = sm.rshift(inp.m[1:], maxslen) maxsleni = mw - maxslen m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert - stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0] - return [self.e.eq(self.e + diff), - self.m.eq(Cat(stickybits, rs)) + #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0] + stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0] + return [self.e.eq(inp.e + diff), + self.m.eq(Cat(stickybit, rs)) ] def shift_up_multi(self, diff): @@ -334,13 +389,57 @@ class FPNumIn(FPNumBase): self.m.eq(sm.lshift(self.m, maxslen)) ] -class FPOp: +class Trigger: + def __init__(self): + + self.stb = Signal(reset=0) + self.ack = Signal() + self.trigger = Signal(reset_less=True) + + def elaborate(self, platform): + m = Module() + m.d.comb += self.trigger.eq(self.stb & self.ack) + return m + + def copy(self, inp): + return [self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] + + def ports(self): + return [self.stb, self.ack] + + +class FPOp(Trigger): def __init__(self, width): + Trigger.__init__(self) self.width = width self.v = Signal(width) - self.stb = Signal() - self.ack = Signal() + + def chain_inv(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(~self.ack), # send ACK + ] + + def chain_from(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(self.ack), # send ACK + ] + + def copy(self, inp): + return [self.v.eq(inp.v), + self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] def ports(self): return [self.v, self.stb, self.ack] @@ -355,6 +454,12 @@ class Overflow: self.roundz = Signal(reset_less=True) + def copy(self, inp): + return [self.guard.eq(inp.guard), + self.round_bit.eq(inp.round_bit), + self.sticky.eq(inp.sticky), + self.m0.eq(inp.m0)] + def elaborate(self, platform): m = Module() m.d.comb += self.roundz.eq(self.guard & \ @@ -379,6 +484,7 @@ class FPBase: m.next = next_state m.d.sync += [ # op is latched in from FPNumIn class on same ack/stb + v.decode(op.v), op.ack.eq(0) ] with m.Else(): @@ -394,7 +500,7 @@ class FPBase: both cases *effectively multiply the number stored by 2*, which has to be taken into account when extracting the result. """ - with m.If(a.e == a.N127): + with m.If(a.exp_n127): m.d.sync += a.e.eq(a.N126) # limit a exponent with m.Else(): m.d.sync += a.m[-1].eq(1) # set top mantissa bit @@ -452,11 +558,10 @@ class FPBase: with m.Else(): m.next = next_state - def roundz(self, m, z, of, next_state): + def roundz(self, m, z, roundz): """ performs rounding on the output. TODO: different kinds of rounding """ - m.next = next_state - with m.If(of.roundz): + with m.If(roundz): m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up with m.If(z.m == z.m1s): # all 1s m.d.sync += z.e.eq(z.e + 1) # exponent rounds up