X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffpbase.py;h=db95eb13e2a4ef7ae7199c6a29e05865674fc62a;hb=892d640f8224e6a52907c6899ab6ab671f5f53af;hp=bfd946219b65ec8a026eeab04f56c497d558775d;hpb=559a675c570f9373777426ee80cedc9fed13690f;p=ieee754fpu.git diff --git a/src/add/fpbase.py b/src/add/fpbase.py index bfd94621..db95eb13 100644 --- a/src/add/fpbase.py +++ b/src/add/fpbase.py @@ -96,6 +96,7 @@ class FPNumBase: self.is_overflowed = Signal(reset_less=True) self.is_denormalised = Signal(reset_less=True) self.exp_128 = Signal(reset_less=True) + self.exp_sub_n126 = Signal((e_width, True), reset_less=True) self.exp_lt_n126 = Signal(reset_less=True) self.exp_gt_n126 = Signal(reset_less=True) self.exp_gt127 = Signal(reset_less=True) @@ -112,8 +113,9 @@ class FPNumBase: m.d.comb += self.is_overflowed.eq(self._is_overflowed()) m.d.comb += self.is_denormalised.eq(self._is_denormalised()) m.d.comb += self.exp_128.eq(self.e == self.P128) - m.d.comb += self.exp_gt_n126.eq(self.e > self.N126) - m.d.comb += self.exp_lt_n126.eq(self.e < self.N126) + m.d.comb += self.exp_sub_n126.eq(self.e - self.N126) + m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0) + m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0) m.d.comb += self.exp_gt127.eq(self.e > self.P127) m.d.comb += self.exp_n127.eq(self.e == self.N127) m.d.comb += self.exp_n126.eq(self.e == self.N126) @@ -182,6 +184,49 @@ class FPNumOut(FPNumBase): return self.create(s, self.N127, 0) +class MultiShiftRMerge: + """ shifts down (right) and merges lower bits into m[0]. + m[0] is the "sticky" bit, basically + """ + def __init__(self, width, s_max=None): + if s_max is None: + s_max = int(log(width) / log(2)) + self.smax = s_max + self.m = Signal(width, reset_less=True) + self.inp = Signal(width, reset_less=True) + self.diff = Signal(s_max, reset_less=True) + self.width = width + + def elaborate(self, platform): + m = Module() + + rs = Signal(self.width, reset_less=True) + m_mask = Signal(self.width, reset_less=True) + smask = Signal(self.width, reset_less=True) + stickybit = Signal(reset_less=True) + maxslen = Signal(self.smax, reset_less=True) + maxsleni = Signal(self.smax, reset_less=True) + + sm = MultiShift(self.width-1) + m0s = Const(0, self.width-1) + mw = Const(self.width-1, len(self.diff)) + m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)), + maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)), + ] + + m.d.comb += [ + # shift mantissa by maxslen, mask by inverse + rs.eq(sm.rshift(self.inp[1:], maxslen)), + m_mask.eq(sm.rshift(~m0s, maxsleni)), + smask.eq(self.inp[1:] & m_mask), + # sticky bit combines all mask (and mantissa low bit) + stickybit.eq(smask.bool() | self.inp[0]), + # mantissa result contains m[0] already. + self.m.eq(Cat(stickybit, rs)) + ] + return m + + class FPNumShift(FPNumBase): """ Floating-point Number Class for shifting """ @@ -303,7 +348,7 @@ class FPNumIn(FPNumBase): self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0)) ] - def shift_down_multi(self, diff): + def shift_down_multi(self, diff, inp=None): """ shifts a mantissa down. exponent is increased to compensate accuracy is lost as a result in the mantissa however there are 3 @@ -318,16 +363,19 @@ class FPNumIn(FPNumBase): inverted and used as a mask to get the LSBs of the mantissa. those are then |'d into the sticky bit. """ + if inp is None: + inp = self sm = MultiShift(self.width) mw = Const(self.m_width-1, len(diff)) maxslen = Mux(diff > mw, mw, diff) - rs = sm.rshift(self.m[1:], maxslen) + rs = sm.rshift(inp.m[1:], maxslen) maxsleni = mw - maxslen m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert - stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0] - return [self.e.eq(self.e + diff), - self.m.eq(Cat(stickybits, rs)) + #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0] + stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0] + return [self.e.eq(inp.e + diff), + self.m.eq(Cat(stickybit, rs)) ] def shift_up_multi(self, diff): @@ -341,20 +389,34 @@ class FPNumIn(FPNumBase): self.m.eq(sm.lshift(self.m, maxslen)) ] -class FPOp: - def __init__(self, width): - self.width = width +class Trigger: + def __init__(self): - self.v = Signal(width) self.stb = Signal(reset=0) self.ack = Signal() self.trigger = Signal(reset_less=True) def elaborate(self, platform): m = Module() - m.d.sync += self.trigger.eq(self.stb & self.ack) + m.d.comb += self.trigger.eq(self.stb & self.ack) return m + def copy(self, inp): + return [self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] + + def ports(self): + return [self.stb, self.ack] + + +class FPOp(Trigger): + def __init__(self, width): + Trigger.__init__(self) + self.width = width + + self.v = Signal(width) + def chain_inv(self, in_op, extra=None): stb = in_op.stb if extra is not None: @@ -496,14 +558,13 @@ class FPBase: with m.Else(): m.next = next_state - def roundz(self, m, z, out_z, roundz): + def roundz(self, m, z, roundz): """ performs rounding on the output. TODO: different kinds of rounding """ - m.d.comb += out_z.copy(z) # copies input to output first with m.If(roundz): - m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up + m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up with m.If(z.m == z.m1s): # all 1s - m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up + m.d.sync += z.e.eq(z.e + 1) # exponent rounds up def corrections(self, m, z, next_state): """ denormalisation and sign-bug corrections