update comments
[ieee754fpu.git] / src / add / fpbase.py
index e6222c2fa5bb3e2d94fe3f99312825ca6bcac024..f49085921d914d600dea950d3707ab924eb0272b 100644 (file)
@@ -2,19 +2,23 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
 from math import log
 from operator import or_
 from functools import reduce
 
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
         self.width = width
         self.smax = int(log(width) / log(2))
-        self.i = Signal(width)
-        self.s = Signal(self.smax)
-        self.o = Signal(width)
+        self.i = Signal(width, reset_less=True)
+        self.s = Signal(self.smax, reset_less=True)
+        self.o = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
@@ -56,17 +60,8 @@ class MultiShift:
         return res
 
 
-class FPNum:
-    """ Floating-point Number Class, variable-width TODO (currently 32-bit)
-
-        Contains signals for an incoming copy of the value, decoded into
-        sign / exponent / mantissa.
-        Also contains encoding functions, creation and recognition of
-        zero, NaN and inf (all signed)
-
-        Four extra bits are included in the mantissa: the top bit
-        (m[-1]) is effectively a carry-overflow.  The other three are
-        guard (m[2]), round (m[1]), and sticky (m[0])
+class FPNumBase: #(Elaboratable):
+    """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
         self.width = width
@@ -87,31 +82,97 @@ class FPNum:
         self.e_start = self.rmw - 1
         self.e_end = self.rmw + self.e_width - 3 # for decoding
 
-        self.v = Signal(width)      # Latched copy of value
-        self.m = Signal(m_width)    # Mantissa
-        self.e = Signal((e_width, True)) # Exponent: IEEE754exp+2 bits, signed
-        self.s = Signal()           # Sign bit
+        self.v = Signal(width, reset_less=True)      # Latched copy of value
+        self.m = Signal(m_width, reset_less=True)    # Mantissa
+        self.e = Signal((e_width, True), reset_less=True) # Exponent: IEEE754exp+2 bits, signed
+        self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
         self.N127 = Const(-(e_max-1), (e_width, True))
         self.N126 = Const(-(e_max-2), (e_width, True))
 
-    def decode(self, v):
-        """ decodes a latched value into sign / exponent / mantissa
+        self.is_nan = Signal(reset_less=True)
+        self.is_zero = Signal(reset_less=True)
+        self.is_inf = Signal(reset_less=True)
+        self.is_overflowed = Signal(reset_less=True)
+        self.is_denormalised = Signal(reset_less=True)
+        self.exp_128 = Signal(reset_less=True)
+        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
+        self.exp_lt_n126 = Signal(reset_less=True)
+        self.exp_gt_n126 = Signal(reset_less=True)
+        self.exp_gt127 = Signal(reset_less=True)
+        self.exp_n127 = Signal(reset_less=True)
+        self.exp_n126 = Signal(reset_less=True)
+        self.m_zero = Signal(reset_less=True)
+        self.m_msbzero = Signal(reset_less=True)
 
-            bias is subtracted here, from the exponent.  exponent
-            is extended to 10 bits so that subtract 127 is done on
-            a 10-bit number
-        """
-        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
-        #print ("decode", self.e_end)
-        return [self.m.eq(Cat(*args)), # mantissa
-                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
-                self.s.eq(v[-1]),                 # sign
-                ]
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.is_nan.eq(self._is_nan())
+        m.d.comb += self.is_zero.eq(self._is_zero())
+        m.d.comb += self.is_inf.eq(self._is_inf())
+        m.d.comb += self.is_overflowed.eq(self._is_overflowed())
+        m.d.comb += self.is_denormalised.eq(self._is_denormalised())
+        m.d.comb += self.exp_128.eq(self.e == self.P128)
+        m.d.comb += self.exp_sub_n126.eq(self.e - self.N126)
+        m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
+        m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
+        m.d.comb += self.exp_gt127.eq(self.e > self.P127)
+        m.d.comb += self.exp_n127.eq(self.e == self.N127)
+        m.d.comb += self.exp_n126.eq(self.e == self.N126)
+        m.d.comb += self.m_zero.eq(self.m == self.mzero)
+        m.d.comb += self.m_msbzero.eq(self.m[self.e_start] == 0)
+
+        return m
+
+    def _is_nan(self):
+        return (self.exp_128) & (~self.m_zero)
+
+    def _is_inf(self):
+        return (self.exp_128) & (self.m_zero)
+
+    def _is_zero(self):
+        return (self.exp_n127) & (self.m_zero)
+
+    def _is_overflowed(self):
+        return self.exp_gt127
+
+    def _is_denormalised(self):
+        return (self.exp_n126) & (self.m_msbzero)
+
+    def __iter__(self):
+        yield self.s
+        yield self.e
+        yield self.m
+
+    def eq(self, inp):
+        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
+
+
+class FPNumOut(FPNumBase):
+    """ Floating-point Number Class
+
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
+
+        return m
 
     def create(self, s, e, m):
         """ creates a value from sign / exponent / mantissa
@@ -124,14 +185,109 @@ class FPNum:
           self.v[0:self.e_start].eq(m)         # mantissa
         ]
 
-    def shift_down(self):
+    def nan(self, s):
+        return self.create(s, self.P128, 1<<(self.e_start-1))
+
+    def inf(self, s):
+        return self.create(s, self.P128, 0)
+
+    def zero(self, s):
+        return self.create(s, self.N127, 0)
+
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
+
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
+
+class MultiShiftRMerge(Elaboratable):
+    """ shifts down (right) and merges lower bits into m[0].
+        m[0] is the "sticky" bit, basically
+    """
+    def __init__(self, width, s_max=None):
+        if s_max is None:
+            s_max = int(log(width) / log(2))
+        self.smax = s_max
+        self.m = Signal(width, reset_less=True)
+        self.inp = Signal(width, reset_less=True)
+        self.diff = Signal(s_max, reset_less=True)
+        self.width = width
+
+    def elaborate(self, platform):
+        m = Module()
+
+        rs = Signal(self.width, reset_less=True)
+        m_mask = Signal(self.width, reset_less=True)
+        smask = Signal(self.width, reset_less=True)
+        stickybit = Signal(reset_less=True)
+        maxslen = Signal(self.smax, reset_less=True)
+        maxsleni = Signal(self.smax, reset_less=True)
+
+        sm = MultiShift(self.width-1)
+        m0s = Const(0, self.width-1)
+        mw = Const(self.width-1, len(self.diff))
+        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+                    ]
+
+        m.d.comb += [
+                # shift mantissa by maxslen, mask by inverse
+                rs.eq(sm.rshift(self.inp[1:], maxslen)),
+                m_mask.eq(sm.rshift(~m0s, maxsleni)),
+                smask.eq(self.inp[1:] & m_mask),
+                # sticky bit combines all mask (and mantissa low bit)
+                stickybit.eq(smask.bool() | self.inp[0]),
+                # mantissa result contains m[0] already.
+                self.m.eq(Cat(stickybit, rs))
+           ]
+        return m
+
+
+class FPNumShift(FPNumBase, Elaboratable):
+    """ Floating-point Number Class for shifting
+    """
+    def __init__(self, mainm, op, inv, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.mainm = mainm
+        self.inv = inv
+        self.op = op
+
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
+
+        m.d.comb += self.s.eq(op.s)
+        m.d.comb += self.e.eq(op.e)
+        m.d.comb += self.m.eq(op.m)
+
+        with self.mainm.State("align"):
+            with m.If(self.e < self.inv.e):
+                m.d.sync += self.shift_down()
+
+        return m
+
+    def shift_down(self, inp):
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
-        return [self.e.eq(self.e + 1),
-                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]
 
     def shift_down_multi(self, diff):
@@ -172,48 +328,245 @@ class FPNum:
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-    def nan(self, s):
-        return self.create(s, self.P128, 1<<(self.e_start-1))
 
-    def inf(self, s):
-        return self.create(s, self.P128, 0)
+class FPNumDecode(FPNumBase):
+    """ Floating-point Number Class
 
-    def zero(self, s):
-        return self.create(s, self.N127, 0)
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
 
-    def is_nan(self):
-        return (self.e == self.P128) & (self.m != 0)
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.op = op
 
-    def is_inf(self):
-        return (self.e == self.P128) & (self.m == 0)
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
 
-    def is_zero(self):
-        return (self.e == self.N127) & (self.m == self.mzero)
+        m.d.comb += self.decode(self.v)
 
-    def is_overflowed(self):
-        return (self.e > self.P127)
+        return m
 
-    def is_denormalised(self):
-        return (self.e == self.N126) & (self.m[self.e_start] == 0)
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
 
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
 
-class FPOp:
-    def __init__(self, width):
-        self.width = width
+class FPNumIn(FPNumBase):
+    """ Floating-point Number Class
 
-        self.v   = Signal(width)
-        self.stb = Signal()
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.op = op
+
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
+
+    def shift_down(self, inp):
+        """ shifts a mantissa down by one. exponent is increased to compensate
+
+            accuracy is lost as a result in the mantissa however there are 3
+            guard bits (the latter of which is the "sticky" bit)
+        """
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
+               ]
+
+    def shift_down_multi(self, diff, inp=None):
+        """ shifts a mantissa down. exponent is increased to compensate
+
+            accuracy is lost as a result in the mantissa however there are 3
+            guard bits (the latter of which is the "sticky" bit)
+
+            this code works by variable-shifting the mantissa by up to
+            its maximum bit-length: no point doing more (it'll still be
+            zero).
+
+            the sticky bit is computed by shifting a batch of 1s by
+            the same amount, which will introduce zeros.  it's then
+            inverted and used as a mask to get the LSBs of the mantissa.
+            those are then |'d into the sticky bit.
+        """
+        if inp is None:
+            inp = self
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width-1, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
+        rs = sm.rshift(inp.m[1:], maxslen)
+        maxsleni = mw - maxslen
+        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+
+        #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
+        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
+        return [self.e.eq(inp.e + diff),
+                self.m.eq(Cat(stickybit, rs))
+               ]
+
+    def shift_up_multi(self, diff):
+        """ shifts a mantissa up. exponent is decreased to compensate
+        """
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
+
+        return [self.e.eq(self.e - diff),
+                self.m.eq(sm.lshift(self.m, maxslen))
+               ]
+
+class Trigger(Elaboratable):
+    def __init__(self):
+
+        self.stb = Signal(reset=0)
         self.ack = Signal()
+        self.trigger = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.trigger.eq(self.stb & self.ack)
+        return m
+
+    def eq(self, inp):
+        return [self.stb.eq(inp.stb),
+                self.ack.eq(inp.ack)
+               ]
 
     def ports(self):
-        return [self.v, self.stb, self.ack]
+        return [self.stb, self.ack]
+
+
+class FPOpIn(PrevControl):
+    def __init__(self, width):
+        PrevControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_i
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
+               ]
 
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
+
+
+class FPOpOut(NextControl):
+    def __init__(self, width):
+        NextControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_o
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
+               ]
 
-class Overflow:
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
+
+
+class Overflow: #(Elaboratable):
     def __init__(self):
-        self.guard = Signal()     # tot[2]
-        self.round_bit = Signal() # tot[1]
-        self.sticky = Signal()    # tot[0]
+        self.guard = Signal(reset_less=True)     # tot[2]
+        self.round_bit = Signal(reset_less=True) # tot[1]
+        self.sticky = Signal(reset_less=True)    # tot[0]
+        self.m0 = Signal(reset_less=True)        # mantissa zero bit
+
+        self.roundz = Signal(reset_less=True)
+
+    def __iter__(self):
+        yield self.guard
+        yield self.round_bit
+        yield self.sticky
+        yield self.m0
+
+    def eq(self, inp):
+        return [self.guard.eq(inp.guard),
+                self.round_bit.eq(inp.round_bit),
+                self.sticky.eq(inp.sticky),
+                self.m0.eq(inp.m0)]
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.roundz.eq(self.guard & \
+                                   (self.round_bit | self.sticky | self.m0))
+        return m
 
 
 class FPBase:
@@ -229,14 +582,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
-        with m.If((op.ack) & (op.stb)):
+        res = v.decode2(m)
+        ack = Signal()
+        with m.If((op.ready_o) & (op.valid_i_test)):
             m.next = next_state
-            m.d.sync += [
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
@@ -248,7 +602,7 @@ class FPBase:
             both cases *effectively multiply the number stored by 2*,
             which has to be taken into account when extracting the result.
         """
-        with m.If(a.e == a.N127):
+        with m.If(a.exp_n127):
             m.d.sync += a.e.eq(a.N126) # limit a exponent
         with m.Else():
             m.d.sync += a.m[-1].eq(1) # set top mantissa bit
@@ -281,6 +635,7 @@ class FPBase:
                 z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                 of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                 of.round_bit.eq(0),        # reset round bit
+                of.m0.eq(of.guard),
             ]
         with m.Else():
             m.next = next_state
@@ -298,17 +653,17 @@ class FPBase:
                 z.e.eq(z.e + 1),  # INCREASE exponent
                 z.m.eq(z.m >> 1), # shift mantissa DOWN
                 of.guard.eq(z.m[0]),
+                of.m0.eq(z.m[1]),
                 of.round_bit.eq(of.guard),
                 of.sticky.eq(of.sticky | of.round_bit)
             ]
         with m.Else():
             m.next = next_state
 
-    def roundz(self, m, z, of, next_state):
+    def roundz(self, m, z, roundz):
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
-        m.next = next_state
-        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
+        with m.If(roundz):
             m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
             with m.If(z.m == z.m1s): # all 1s
                 m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
@@ -318,7 +673,7 @@ class FPBase:
         """
         m.next = next_state
         # denormalised, correct exponent to zero
-        with m.If(z.is_denormalised()):
+        with m.If(z.is_denormalised):
             m.d.sync += z.e.eq(z.N127)
 
     def pack(self, m, z, next_state):
@@ -326,7 +681,7 @@ class FPBase:
         """
         m.next = next_state
         # if overflow occurs, return inf
-        with m.If(z.is_overflowed()):
+        with m.If(z.is_overflowed):
             m.d.sync += z.inf(z.s)
         with m.Else():
             m.d.sync += z.create(z.s, z.e, z.m)
@@ -337,11 +692,42 @@ class FPBase:
             resets stb back to zero when that occurs, as acknowledgement.
         """
         m.d.sync += [
-          out_z.stb.eq(1),
           out_z.v.eq(z.v)
         ]
-        with m.If(out_z.stb & out_z.ack):
-            m.d.sync += out_z.stb.eq(0)
+        with m.If(out_z.valid_o & out_z.ready_i_test):
+            m.d.sync += out_z.valid_o.eq(0)
             m.next = next_state
+        with m.Else():
+            m.d.sync += out_z.valid_o.eq(1)
+
+
+class FPState(FPBase):
+    def __init__(self, state_from):
+        self.state_from = state_from
+
+    def set_inputs(self, inputs):
+        self.inputs = inputs
+        for k,v in inputs.items():
+            setattr(self, k, v)
+
+    def set_outputs(self, outputs):
+        self.outputs = outputs
+        for k,v in outputs.items():
+            setattr(self, k, v)
+
+
+class FPID:
+    def __init__(self, id_wid):
+        self.id_wid = id_wid
+        if self.id_wid:
+            self.in_mid = Signal(id_wid, reset_less=True)
+            self.out_mid = Signal(id_wid, reset_less=True)
+        else:
+            self.in_mid = None
+            self.out_mid = None
+
+    def idsync(self, m):
+        if self.id_wid is not None:
+            m.d.sync += self.out_mid.eq(self.in_mid)