X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffpbase.py;h=db95eb13e2a4ef7ae7199c6a29e05865674fc62a;hb=892d640f8224e6a52907c6899ab6ab671f5f53af;hp=aba8c9676e33b1aac1527a78b298138f2ecd2e5b;hpb=f68605ababcf961defbcf83005dff4699f83d373;p=ieee754fpu.git diff --git a/src/add/fpbase.py b/src/add/fpbase.py index aba8c967..db95eb13 100644 --- a/src/add/fpbase.py +++ b/src/add/fpbase.py @@ -2,43 +2,86 @@ # Copyright (C) Jonathan P Dawson 2013 # 2013-12-12 -from nmigen import Signal, Cat, Const +from nmigen import Signal, Cat, Const, Mux, Module +from math import log +from operator import or_ +from functools import reduce +class MultiShiftR: -class FPNum: - """ Floating-point Number Class, variable-width TODO (currently 32-bit) - - Contains signals for an incoming copy of the value, decoded into - sign / exponent / mantissa. - Also contains encoding functions, creation and recognition of - zero, NaN and inf (all signed) + def __init__(self, width): + self.width = width + self.smax = int(log(width) / log(2)) + self.i = Signal(width, reset_less=True) + self.s = Signal(self.smax, reset_less=True) + self.o = Signal(width, reset_less=True) + + def elaborate(self, platform): + m = Module() + m.d.comb += self.o.eq(self.i >> self.s) + return m + + +class MultiShift: + """ Generates variable-length single-cycle shifter from a series + of conditional tests on each bit of the left/right shift operand. + Each bit tested produces output shifted by that number of bits, + in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set + shifts by 2 bits, each partial result cascading to the next Mux. + + Could be adapted to do arithmetic shift by taking copies of the + MSB instead of zeros. + """ - Four extra bits are included in the mantissa: the top bit - (m[-1]) is effectively a carry-overflow. The other three are - guard (m[2]), round (m[1]), and sticky (m[0]) + def __init__(self, width): + self.width = width + self.smax = int(log(width) / log(2)) + + def lshift(self, op, s): + res = op << s + return res[:len(op)] + res = op + for i in range(self.smax): + zeros = [0] * (1<> s + return res[:len(op)] + res = op + for i in range(self.smax): + zeros = [0] * (1< 0) + m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0) + m.d.comb += self.exp_gt127.eq(self.e > self.P127) + m.d.comb += self.exp_n127.eq(self.e == self.N127) + m.d.comb += self.exp_n126.eq(self.e == self.N126) + m.d.comb += self.m_zero.eq(self.m == self.mzero) + m.d.comb += self.m_msbzero.eq(self.m[self.e_start] == 0) + + return m + + def _is_nan(self): + return (self.exp_128) & (~self.m_zero) + + def _is_inf(self): + return (self.exp_128) & (self.m_zero) + + def _is_zero(self): + return (self.exp_n127) & (self.m_zero) + + def _is_overflowed(self): + return self.exp_gt127 + + def _is_denormalised(self): + return (self.exp_n126) & (self.m_msbzero) + + def copy(self, inp): + return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)] + + +class FPNumOut(FPNumBase): + """ Floating-point Number Class + + Contains signals for an incoming copy of the value, decoded into + sign / exponent / mantissa. + Also contains encoding functions, creation and recognition of + zero, NaN and inf (all signed) + + Four extra bits are included in the mantissa: the top bit + (m[-1]) is effectively a carry-overflow. The other three are + guard (m[2]), round (m[1]), and sticky (m[0]) + """ + def __init__(self, width, m_extra=True): + FPNumBase.__init__(self, width, m_extra) + + def elaborate(self, platform): + m = FPNumBase.elaborate(self, platform) + + return m + + def create(self, s, e, m): + """ creates a value from sign / exponent / mantissa + + bias is added here, to the exponent + """ + return [ + self.v[-1].eq(s), # sign + self.v[self.e_start:self.e_end].eq(e + self.P127), # exp (add on bias) + self.v[0:self.e_start].eq(m) # mantissa + ] + + def nan(self, s): + return self.create(s, self.P128, 1<<(self.e_start-1)) + + def inf(self, s): + return self.create(s, self.P128, 0) + + def zero(self, s): + return self.create(s, self.N127, 0) + + +class MultiShiftRMerge: + """ shifts down (right) and merges lower bits into m[0]. + m[0] is the "sticky" bit, basically + """ + def __init__(self, width, s_max=None): + if s_max is None: + s_max = int(log(width) / log(2)) + self.smax = s_max + self.m = Signal(width, reset_less=True) + self.inp = Signal(width, reset_less=True) + self.diff = Signal(s_max, reset_less=True) + self.width = width + + def elaborate(self, platform): + m = Module() + + rs = Signal(self.width, reset_less=True) + m_mask = Signal(self.width, reset_less=True) + smask = Signal(self.width, reset_less=True) + stickybit = Signal(reset_less=True) + maxslen = Signal(self.smax, reset_less=True) + maxsleni = Signal(self.smax, reset_less=True) + + sm = MultiShift(self.width-1) + m0s = Const(0, self.width-1) + mw = Const(self.width-1, len(self.diff)) + m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)), + maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)), + ] + + m.d.comb += [ + # shift mantissa by maxslen, mask by inverse + rs.eq(sm.rshift(self.inp[1:], maxslen)), + m_mask.eq(sm.rshift(~m0s, maxsleni)), + smask.eq(self.inp[1:] & m_mask), + # sticky bit combines all mask (and mantissa low bit) + stickybit.eq(smask.bool() | self.inp[0]), + # mantissa result contains m[0] already. + self.m.eq(Cat(stickybit, rs)) + ] + return m + + +class FPNumShift(FPNumBase): + """ Floating-point Number Class for shifting + """ + def __init__(self, mainm, op, inv, width, m_extra=True): + FPNumBase.__init__(self, width, m_extra) + self.latch_in = Signal() + self.mainm = mainm + self.inv = inv + self.op = op + + def elaborate(self, platform): + m = FPNumBase.elaborate(self, platform) + + m.d.comb += self.s.eq(op.s) + m.d.comb += self.e.eq(op.e) + m.d.comb += self.m.eq(op.m) + + with self.mainm.State("align"): + with m.If(self.e < self.inv.e): + m.d.sync += self.shift_down() + + return m + + def shift_down(self, inp): + """ shifts a mantissa down by one. exponent is increased to compensate + + accuracy is lost as a result in the mantissa however there are 3 + guard bits (the latter of which is the "sticky" bit) + """ + return [self.e.eq(inp.e + 1), + self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0)) + ] + + def shift_down_multi(self, diff): + """ shifts a mantissa down. exponent is increased to compensate + + accuracy is lost as a result in the mantissa however there are 3 + guard bits (the latter of which is the "sticky" bit) + + this code works by variable-shifting the mantissa by up to + its maximum bit-length: no point doing more (it'll still be + zero). + + the sticky bit is computed by shifting a batch of 1s by + the same amount, which will introduce zeros. it's then + inverted and used as a mask to get the LSBs of the mantissa. + those are then |'d into the sticky bit. + """ + sm = MultiShift(self.width) + mw = Const(self.m_width-1, len(diff)) + maxslen = Mux(diff > mw, mw, diff) + rs = sm.rshift(self.m[1:], maxslen) + maxsleni = mw - maxslen + m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert + + stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0] + return [self.e.eq(self.e + diff), + self.m.eq(Cat(stickybits, rs)) + ] + + def shift_up_multi(self, diff): + """ shifts a mantissa up. exponent is decreased to compensate + """ + sm = MultiShift(self.width) + mw = Const(self.m_width, len(diff)) + maxslen = Mux(diff > mw, mw, diff) + + return [self.e.eq(self.e - diff), + self.m.eq(sm.lshift(self.m, maxslen)) + ] + +class FPNumIn(FPNumBase): + """ Floating-point Number Class + + Contains signals for an incoming copy of the value, decoded into + sign / exponent / mantissa. + Also contains encoding functions, creation and recognition of + zero, NaN and inf (all signed) + + Four extra bits are included in the mantissa: the top bit + (m[-1]) is effectively a carry-overflow. The other three are + guard (m[2]), round (m[1]), and sticky (m[0]) + """ + def __init__(self, op, width, m_extra=True): + FPNumBase.__init__(self, width, m_extra) + self.latch_in = Signal() + self.op = op + + def elaborate(self, platform): + m = FPNumBase.elaborate(self, platform) + + #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb) + #with m.If(self.latch_in): + # m.d.sync += self.decode(self.v) + + return m + def decode(self, v): """ decodes a latched value into sign / exponent / mantissa @@ -55,65 +332,114 @@ class FPNum: a 10-bit number """ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros - print (self.e_end) + #print ("decode", self.e_end) return [self.m.eq(Cat(*args)), # mantissa self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp self.s.eq(v[-1]), # sign ] - def create(self, s, e, m): - """ creates a value from sign / exponent / mantissa + def shift_down(self, inp): + """ shifts a mantissa down by one. exponent is increased to compensate - bias is added here, to the exponent + accuracy is lost as a result in the mantissa however there are 3 + guard bits (the latter of which is the "sticky" bit) """ - return [ - self.v[-1].eq(s), # sign - self.v[self.e_start:self.e_end].eq(e + self.P127), # exp (add on bias) - self.v[0:self.e_start].eq(m) # mantissa - ] + return [self.e.eq(inp.e + 1), + self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0)) + ] - def shift_down(self): - """ shifts a mantissa down by one. exponent is increased to compensate + def shift_down_multi(self, diff, inp=None): + """ shifts a mantissa down. exponent is increased to compensate accuracy is lost as a result in the mantissa however there are 3 guard bits (the latter of which is the "sticky" bit) + + this code works by variable-shifting the mantissa by up to + its maximum bit-length: no point doing more (it'll still be + zero). + + the sticky bit is computed by shifting a batch of 1s by + the same amount, which will introduce zeros. it's then + inverted and used as a mask to get the LSBs of the mantissa. + those are then |'d into the sticky bit. """ - return [self.e.eq(self.e + 1), - self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0)) + if inp is None: + inp = self + sm = MultiShift(self.width) + mw = Const(self.m_width-1, len(diff)) + maxslen = Mux(diff > mw, mw, diff) + rs = sm.rshift(inp.m[1:], maxslen) + maxsleni = mw - maxslen + m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert + + #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0] + stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0] + return [self.e.eq(inp.e + diff), + self.m.eq(Cat(stickybit, rs)) ] - def nan(self, s): - return self.create(s, self.P128, 1<<(self.e_start-1)) - - def inf(self, s): - return self.create(s, self.P128, 0) + def shift_up_multi(self, diff): + """ shifts a mantissa up. exponent is decreased to compensate + """ + sm = MultiShift(self.width) + mw = Const(self.m_width, len(diff)) + maxslen = Mux(diff > mw, mw, diff) - def zero(self, s): - return self.create(s, self.N127, 0) + return [self.e.eq(self.e - diff), + self.m.eq(sm.lshift(self.m, maxslen)) + ] - def is_nan(self): - return (self.e == self.P128) & (self.m != 0) +class Trigger: + def __init__(self): - def is_inf(self): - return (self.e == self.P128) & (self.m == 0) + self.stb = Signal(reset=0) + self.ack = Signal() + self.trigger = Signal(reset_less=True) - def is_zero(self): - return (self.e == self.N127) & (self.m == self.mzero) + def elaborate(self, platform): + m = Module() + m.d.comb += self.trigger.eq(self.stb & self.ack) + return m - def is_overflowed(self): - return (self.e > self.P127) + def copy(self, inp): + return [self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] - def is_denormalised(self): - return (self.e == self.N126) & (self.m[self.e_start] == 0) + def ports(self): + return [self.stb, self.ack] -class FPOp: +class FPOp(Trigger): def __init__(self, width): + Trigger.__init__(self) self.width = width self.v = Signal(width) - self.stb = Signal() - self.ack = Signal() + + def chain_inv(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(~self.ack), # send ACK + ] + + def chain_from(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(self.ack), # send ACK + ] + + def copy(self, inp): + return [self.v.eq(inp.v), + self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] def ports(self): return [self.v, self.stb, self.ack] @@ -121,9 +447,24 @@ class FPOp: class Overflow: def __init__(self): - self.guard = Signal() # tot[2] - self.round_bit = Signal() # tot[1] - self.sticky = Signal() # tot[0] + self.guard = Signal(reset_less=True) # tot[2] + self.round_bit = Signal(reset_less=True) # tot[1] + self.sticky = Signal(reset_less=True) # tot[0] + self.m0 = Signal(reset_less=True) # mantissa zero bit + + self.roundz = Signal(reset_less=True) + + def copy(self, inp): + return [self.guard.eq(inp.guard), + self.round_bit.eq(inp.round_bit), + self.sticky.eq(inp.sticky), + self.m0.eq(inp.m0)] + + def elaborate(self, platform): + m = Module() + m.d.comb += self.roundz.eq(self.guard & \ + (self.round_bit | self.sticky | self.m0)) + return m class FPBase: @@ -142,6 +483,7 @@ class FPBase: with m.If((op.ack) & (op.stb)): m.next = next_state m.d.sync += [ + # op is latched in from FPNumIn class on same ack/stb v.decode(op.v), op.ack.eq(0) ] @@ -149,9 +491,16 @@ class FPBase: m.d.sync += op.ack.eq(1) def denormalise(self, m, a): - """ denormalises a number + """ denormalises a number. this is probably the wrong name for + this function. for normalised numbers (exponent != minimum) + one *extra* bit (the implicit 1) is added *back in*. + for denormalised numbers, the mantissa is left alone + and the exponent increased by 1. + + both cases *effectively multiply the number stored by 2*, + which has to be taken into account when extracting the result. """ - with m.If(a.e == a.N127): + with m.If(a.exp_n127): m.d.sync += a.e.eq(a.N126) # limit a exponent with m.Else(): m.d.sync += a.m[-1].eq(1) # set top mantissa bit @@ -178,12 +527,13 @@ class FPBase: the extra mantissa bits coming from tot[0..2] """ with m.If((z.m[-1] == 0) & (z.e > z.N126)): - m.d.sync +=[ + m.d.sync += [ z.e.eq(z.e - 1), # DECREASE exponent z.m.eq(z.m << 1), # shift mantissa UP z.m[0].eq(of.guard), # steal guard bit (was tot[2]) of.guard.eq(of.round_bit), # steal round_bit (was tot[1]) of.round_bit.eq(0), # reset round bit + of.m0.eq(of.guard), ] with m.Else(): m.next = next_state @@ -201,17 +551,17 @@ class FPBase: z.e.eq(z.e + 1), # INCREASE exponent z.m.eq(z.m >> 1), # shift mantissa DOWN of.guard.eq(z.m[0]), + of.m0.eq(z.m[1]), of.round_bit.eq(of.guard), of.sticky.eq(of.sticky | of.round_bit) ] with m.Else(): m.next = next_state - def roundz(self, m, z, of, next_state): + def roundz(self, m, z, roundz): """ performs rounding on the output. TODO: different kinds of rounding """ - m.next = next_state - with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])): + with m.If(roundz): m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up with m.If(z.m == z.m1s): # all 1s m.d.sync += z.e.eq(z.e + 1) # exponent rounds up @@ -221,19 +571,16 @@ class FPBase: """ m.next = next_state # denormalised, correct exponent to zero - with m.If(z.is_denormalised()): - m.d.sync += z.m.eq(z.N127) - # FIX SIGN BUG: -a + a = +0. - with m.If((z.e == z.N126) & (z.m[0:] == 0)): - m.d.sync += z.s.eq(0) + with m.If(z.is_denormalised): + m.d.sync += z.e.eq(z.N127) def pack(self, m, z, next_state): """ packs the result into the output (detects overflow->Inf) """ m.next = next_state # if overflow occurs, return inf - with m.If(z.is_overflowed()): - m.d.sync += z.inf(0) + with m.If(z.is_overflowed): + m.d.sync += z.inf(z.s) with m.Else(): m.d.sync += z.create(z.s, z.e, z.m) @@ -243,11 +590,12 @@ class FPBase: resets stb back to zero when that occurs, as acknowledgement. """ m.d.sync += [ - out_z.stb.eq(1), out_z.v.eq(z.v) ] with m.If(out_z.stb & out_z.ack): m.d.sync += out_z.stb.eq(0) m.next = next_state + with m.Else(): + m.d.sync += out_z.stb.eq(1)