X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffpbase.py;h=f49085921d914d600dea950d3707ab924eb0272b;hb=6bff1a997f3846872cf489c24b5c01426c4dc97c;hp=d8073d5355ae1f787c9b07e73eae0cb02f49d276;hpb=7cef607cae22586ffa4b376fc167fc668f59be14;p=ieee754fpu.git diff --git a/src/add/fpbase.py b/src/add/fpbase.py index d8073d53..f4908592 100644 --- a/src/add/fpbase.py +++ b/src/add/fpbase.py @@ -2,11 +2,15 @@ # Copyright (C) Jonathan P Dawson 2013 # 2013-12-12 -from nmigen import Signal, Cat, Const, Mux, Module +from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable from math import log from operator import or_ from functools import reduce +from singlepipe import PrevControl, NextControl +from pipeline import ObjectProxy + + class MultiShiftR: def __init__(self, width): @@ -56,7 +60,7 @@ class MultiShift: return res -class FPNumBase: +class FPNumBase: #(Elaboratable): """ Floating-point Base Number Class """ def __init__(self, width, m_extra=True): @@ -84,6 +88,8 @@ class FPNumBase: self.s = Signal(reset_less=True) # Sign bit self.mzero = Const(0, (m_width, False)) + m_msb = 1<<(self.m_width-2) + self.msb1 = Const(m_msb, (m_width, False)) self.m1s = Const(-1, (m_width, False)) self.P128 = Const(e_max, (e_width, True)) self.P127 = Const(e_max-1, (e_width, True)) @@ -96,6 +102,7 @@ class FPNumBase: self.is_overflowed = Signal(reset_less=True) self.is_denormalised = Signal(reset_less=True) self.exp_128 = Signal(reset_less=True) + self.exp_sub_n126 = Signal((e_width, True), reset_less=True) self.exp_lt_n126 = Signal(reset_less=True) self.exp_gt_n126 = Signal(reset_less=True) self.exp_gt127 = Signal(reset_less=True) @@ -112,8 +119,9 @@ class FPNumBase: m.d.comb += self.is_overflowed.eq(self._is_overflowed()) m.d.comb += self.is_denormalised.eq(self._is_denormalised()) m.d.comb += self.exp_128.eq(self.e == self.P128) - m.d.comb += self.exp_gt_n126.eq(self.e > self.N126) - m.d.comb += self.exp_lt_n126.eq(self.e < self.N126) + m.d.comb += self.exp_sub_n126.eq(self.e - self.N126) + m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0) + m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0) m.d.comb += self.exp_gt127.eq(self.e > self.P127) m.d.comb += self.exp_n127.eq(self.e == self.N127) m.d.comb += self.exp_n126.eq(self.e == self.N126) @@ -137,7 +145,12 @@ class FPNumBase: def _is_denormalised(self): return (self.exp_n126) & (self.m_msbzero) - def copy(self, inp): + def __iter__(self): + yield self.s + yield self.e + yield self.m + + def eq(self, inp): return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)] @@ -181,57 +194,70 @@ class FPNumOut(FPNumBase): def zero(self, s): return self.create(s, self.N127, 0) + def create2(self, s, e, m): + """ creates a value from sign / exponent / mantissa + + bias is added here, to the exponent + """ + e = e + self.P127 # exp (add on bias) + return Cat(m[0:self.e_start], + e[0:self.e_end-self.e_start], + s) -class FPNumShiftMultiRight(FPNumBase): - """ shifts a mantissa down. exponent is increased to compensate + def nan2(self, s): + return self.create2(s, self.P128, self.msb1) - accuracy is lost as a result in the mantissa however there are 3 - guard bits (the latter of which is the "sticky" bit) + def inf2(self, s): + return self.create2(s, self.P128, self.mzero) - this code works by variable-shifting the mantissa by up to - its maximum bit-length: no point doing more (it'll still be - zero). + def zero2(self, s): + return self.create2(s, self.N127, self.mzero) - the sticky bit is computed by shifting a batch of 1s by - the same amount, which will introduce zeros. it's then - inverted and used as a mask to get the LSBs of the mantissa. - those are then |'d into the sticky bit. + +class MultiShiftRMerge(Elaboratable): + """ shifts down (right) and merges lower bits into m[0]. + m[0] is the "sticky" bit, basically """ - def __init__(self, inp, diff, width): + def __init__(self, width, s_max=None): + if s_max is None: + s_max = int(log(width) / log(2)) + self.smax = s_max self.m = Signal(width, reset_less=True) - self.inp = inp - self.diff = diff + self.inp = Signal(width, reset_less=True) + self.diff = Signal(s_max, reset_less=True) self.width = width def elaborate(self, platform): m = Module() - #m.submodules.inp = self.inp rs = Signal(self.width, reset_less=True) m_mask = Signal(self.width, reset_less=True) smask = Signal(self.width, reset_less=True) stickybit = Signal(reset_less=True) + maxslen = Signal(self.smax, reset_less=True) + maxsleni = Signal(self.smax, reset_less=True) sm = MultiShift(self.width-1) + m0s = Const(0, self.width-1) mw = Const(self.width-1, len(self.diff)) - maxslen = Mux(self.diff > mw, mw, self.diff) - maxsleni = mw - maxslen + m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)), + maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)), + ] + m.d.comb += [ # shift mantissa by maxslen, mask by inverse - rs.eq(sm.rshift(self.inp.m[1:], maxslen)), - m_mask.eq(sm.rshift(self.inp.m1s[1:], maxsleni)), - smask.eq(self.inp.m[1:] & m_mask), + rs.eq(sm.rshift(self.inp[1:], maxslen)), + m_mask.eq(sm.rshift(~m0s, maxsleni)), + smask.eq(self.inp[1:] & m_mask), # sticky bit combines all mask (and mantissa low bit) - stickybit.eq(smask.bool() | self.inp.m[0]), - #self.s.eq(self.inp.s), - #self.e.eq(self.inp.e + diff), + stickybit.eq(smask.bool() | self.inp[0]), # mantissa result contains m[0] already. self.m.eq(Cat(stickybit, rs)) ] return m -class FPNumShift(FPNumBase): +class FPNumShift(FPNumBase, Elaboratable): """ Floating-point Number Class for shifting """ def __init__(self, mainm, op, inv, width, m_extra=True): @@ -302,7 +328,8 @@ class FPNumShift(FPNumBase): self.m.eq(sm.lshift(self.m, maxslen)) ] -class FPNumIn(FPNumBase): + +class FPNumDecode(FPNumBase): """ Floating-point Number Class Contains signals for an incoming copy of the value, decoded into @@ -316,15 +343,12 @@ class FPNumIn(FPNumBase): """ def __init__(self, op, width, m_extra=True): FPNumBase.__init__(self, width, m_extra) - self.latch_in = Signal() self.op = op def elaborate(self, platform): m = FPNumBase.elaborate(self, platform) - #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb) - #with m.If(self.latch_in): - # m.d.sync += self.decode(self.v) + m.d.comb += self.decode(self.v) return m @@ -342,6 +366,53 @@ class FPNumIn(FPNumBase): self.s.eq(v[-1]), # sign ] +class FPNumIn(FPNumBase): + """ Floating-point Number Class + + Contains signals for an incoming copy of the value, decoded into + sign / exponent / mantissa. + Also contains encoding functions, creation and recognition of + zero, NaN and inf (all signed) + + Four extra bits are included in the mantissa: the top bit + (m[-1]) is effectively a carry-overflow. The other three are + guard (m[2]), round (m[1]), and sticky (m[0]) + """ + def __init__(self, op, width, m_extra=True): + FPNumBase.__init__(self, width, m_extra) + self.latch_in = Signal() + self.op = op + + def decode2(self, m): + """ decodes a latched value into sign / exponent / mantissa + + bias is subtracted here, from the exponent. exponent + is extended to 10 bits so that subtract 127 is done on + a 10-bit number + """ + v = self.v + args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros + #print ("decode", self.e_end) + res = ObjectProxy(m, pipemode=False) + res.m = Cat(*args) # mantissa + res.e = v[self.e_start:self.e_end] - self.P127 # exp + res.s = v[-1] # sign + return res + + def decode(self, v): + """ decodes a latched value into sign / exponent / mantissa + + bias is subtracted here, from the exponent. exponent + is extended to 10 bits so that subtract 127 is done on + a 10-bit number + """ + args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros + #print ("decode", self.e_end) + return [self.m.eq(Cat(*args)), # mantissa + self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp + self.s.eq(v[-1]), # sign + ] + def shift_down(self, inp): """ shifts a mantissa down by one. exponent is increased to compensate @@ -376,10 +447,10 @@ class FPNumIn(FPNumBase): maxsleni = mw - maxslen m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert - #stickybits = reduce(or_, inp.m[1:] & m_mask) | inp.m[0] - stickybits = (inp.m[1:] & m_mask).bool() | inp.m[0] + #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0] + stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0] return [self.e.eq(inp.e + diff), - self.m.eq(Cat(stickybits, rs)) + self.m.eq(Cat(stickybit, rs)) ] def shift_up_multi(self, diff): @@ -393,20 +464,36 @@ class FPNumIn(FPNumBase): self.m.eq(sm.lshift(self.m, maxslen)) ] -class FPOp: - def __init__(self, width): - self.width = width +class Trigger(Elaboratable): + def __init__(self): - self.v = Signal(width) self.stb = Signal(reset=0) self.ack = Signal() self.trigger = Signal(reset_less=True) def elaborate(self, platform): m = Module() - m.d.sync += self.trigger.eq(self.stb & self.ack) + m.d.comb += self.trigger.eq(self.stb & self.ack) return m + def eq(self, inp): + return [self.stb.eq(inp.stb), + self.ack.eq(inp.ack) + ] + + def ports(self): + return [self.stb, self.ack] + + +class FPOpIn(PrevControl): + def __init__(self, width): + PrevControl.__init__(self) + self.width = width + + @property + def v(self): + return self.data_i + def chain_inv(self, in_op, extra=None): stb = in_op.stb if extra is not None: @@ -425,17 +512,36 @@ class FPOp: in_op.ack.eq(self.ack), # send ACK ] - def copy(self, inp): - return [self.v.eq(inp.v), - self.stb.eq(inp.stb), - self.ack.eq(inp.ack) + +class FPOpOut(NextControl): + def __init__(self, width): + NextControl.__init__(self) + self.width = width + + @property + def v(self): + return self.data_o + + def chain_inv(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(~self.ack), # send ACK ] - def ports(self): - return [self.v, self.stb, self.ack] + def chain_from(self, in_op, extra=None): + stb = in_op.stb + if extra is not None: + stb = stb & extra + return [self.v.eq(in_op.v), # receive value + self.stb.eq(stb), # receive STB + in_op.ack.eq(self.ack), # send ACK + ] -class Overflow: +class Overflow: #(Elaboratable): def __init__(self): self.guard = Signal(reset_less=True) # tot[2] self.round_bit = Signal(reset_less=True) # tot[1] @@ -444,7 +550,13 @@ class Overflow: self.roundz = Signal(reset_less=True) - def copy(self, inp): + def __iter__(self): + yield self.guard + yield self.round_bit + yield self.sticky + yield self.m0 + + def eq(self, inp): return [self.guard.eq(inp.guard), self.round_bit.eq(inp.round_bit), self.sticky.eq(inp.sticky), @@ -470,15 +582,15 @@ class FPBase: when both stb and ack are 1. acknowledgement is sent by setting ack to ZERO. """ - with m.If((op.ack) & (op.stb)): + res = v.decode2(m) + ack = Signal() + with m.If((op.ready_o) & (op.valid_i_test)): m.next = next_state - m.d.sync += [ - # op is latched in from FPNumIn class on same ack/stb - v.decode(op.v), - op.ack.eq(0) - ] + # op is latched in from FPNumIn class on same ack/stb + m.d.comb += ack.eq(0) with m.Else(): - m.d.sync += op.ack.eq(1) + m.d.comb += ack.eq(1) + return [res, ack] def denormalise(self, m, a): """ denormalises a number. this is probably the wrong name for @@ -548,14 +660,13 @@ class FPBase: with m.Else(): m.next = next_state - def roundz(self, m, z, out_z, roundz): + def roundz(self, m, z, roundz): """ performs rounding on the output. TODO: different kinds of rounding """ - m.d.comb += out_z.copy(z) # copies input to output first with m.If(roundz): - m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up + m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up with m.If(z.m == z.m1s): # all 1s - m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up + m.d.sync += z.e.eq(z.e + 1) # exponent rounds up def corrections(self, m, z, next_state): """ denormalisation and sign-bug corrections @@ -583,10 +694,40 @@ class FPBase: m.d.sync += [ out_z.v.eq(z.v) ] - with m.If(out_z.stb & out_z.ack): - m.d.sync += out_z.stb.eq(0) + with m.If(out_z.valid_o & out_z.ready_i_test): + m.d.sync += out_z.valid_o.eq(0) m.next = next_state with m.Else(): - m.d.sync += out_z.stb.eq(1) + m.d.sync += out_z.valid_o.eq(1) + + +class FPState(FPBase): + def __init__(self, state_from): + self.state_from = state_from + + def set_inputs(self, inputs): + self.inputs = inputs + for k,v in inputs.items(): + setattr(self, k, v) + + def set_outputs(self, outputs): + self.outputs = outputs + for k,v in outputs.items(): + setattr(self, k, v) + + +class FPID: + def __init__(self, id_wid): + self.id_wid = id_wid + if self.id_wid: + self.in_mid = Signal(id_wid, reset_less=True) + self.out_mid = Signal(id_wid, reset_less=True) + else: + self.in_mid = None + self.out_mid = None + + def idsync(self, m): + if self.id_wid is not None: + m.d.sync += self.out_mid.eq(self.in_mid)