update comments
[ieee754fpu.git] / src / add / fpbase.py
index 242f41d814a2b11d0bfb177c13dd95ee02bf88d2..f49085921d914d600dea950d3707ab924eb0272b 100644 (file)
@@ -2,11 +2,15 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
 from math import log
 from operator import or_
 from functools import reduce
 
 from math import log
 from operator import or_
 from functools import reduce
 
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
 class MultiShiftR:
 
     def __init__(self, width):
@@ -56,7 +60,7 @@ class MultiShift:
         return res
 
 
         return res
 
 
-class FPNumBase:
+class FPNumBase: #(Elaboratable):
     """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
     """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
@@ -84,6 +88,8 @@ class FPNumBase:
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
@@ -96,6 +102,9 @@ class FPNumBase:
         self.is_overflowed = Signal(reset_less=True)
         self.is_denormalised = Signal(reset_less=True)
         self.exp_128 = Signal(reset_less=True)
         self.is_overflowed = Signal(reset_less=True)
         self.is_denormalised = Signal(reset_less=True)
         self.exp_128 = Signal(reset_less=True)
+        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
+        self.exp_lt_n126 = Signal(reset_less=True)
+        self.exp_gt_n126 = Signal(reset_less=True)
         self.exp_gt127 = Signal(reset_less=True)
         self.exp_n127 = Signal(reset_less=True)
         self.exp_n126 = Signal(reset_less=True)
         self.exp_gt127 = Signal(reset_less=True)
         self.exp_n127 = Signal(reset_less=True)
         self.exp_n126 = Signal(reset_less=True)
@@ -110,6 +119,9 @@ class FPNumBase:
         m.d.comb += self.is_overflowed.eq(self._is_overflowed())
         m.d.comb += self.is_denormalised.eq(self._is_denormalised())
         m.d.comb += self.exp_128.eq(self.e == self.P128)
         m.d.comb += self.is_overflowed.eq(self._is_overflowed())
         m.d.comb += self.is_denormalised.eq(self._is_denormalised())
         m.d.comb += self.exp_128.eq(self.e == self.P128)
+        m.d.comb += self.exp_sub_n126.eq(self.e - self.N126)
+        m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
+        m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
         m.d.comb += self.exp_gt127.eq(self.e > self.P127)
         m.d.comb += self.exp_n127.eq(self.e == self.N127)
         m.d.comb += self.exp_n126.eq(self.e == self.N126)
         m.d.comb += self.exp_gt127.eq(self.e > self.P127)
         m.d.comb += self.exp_n127.eq(self.e == self.N127)
         m.d.comb += self.exp_n126.eq(self.e == self.N126)
@@ -133,7 +145,12 @@ class FPNumBase:
     def _is_denormalised(self):
         return (self.exp_n126) & (self.m_msbzero)
 
     def _is_denormalised(self):
         return (self.exp_n126) & (self.m_msbzero)
 
-    def copy(self, inp):
+    def __iter__(self):
+        yield self.s
+        yield self.e
+        yield self.m
+
+    def eq(self, inp):
         return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
 
 
         return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
 
 
@@ -177,8 +194,70 @@ class FPNumOut(FPNumBase):
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
+
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
+
+class MultiShiftRMerge(Elaboratable):
+    """ shifts down (right) and merges lower bits into m[0].
+        m[0] is the "sticky" bit, basically
+    """
+    def __init__(self, width, s_max=None):
+        if s_max is None:
+            s_max = int(log(width) / log(2))
+        self.smax = s_max
+        self.m = Signal(width, reset_less=True)
+        self.inp = Signal(width, reset_less=True)
+        self.diff = Signal(s_max, reset_less=True)
+        self.width = width
+
+    def elaborate(self, platform):
+        m = Module()
+
+        rs = Signal(self.width, reset_less=True)
+        m_mask = Signal(self.width, reset_less=True)
+        smask = Signal(self.width, reset_less=True)
+        stickybit = Signal(reset_less=True)
+        maxslen = Signal(self.smax, reset_less=True)
+        maxsleni = Signal(self.smax, reset_less=True)
+
+        sm = MultiShift(self.width-1)
+        m0s = Const(0, self.width-1)
+        mw = Const(self.width-1, len(self.diff))
+        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+                    ]
+
+        m.d.comb += [
+                # shift mantissa by maxslen, mask by inverse
+                rs.eq(sm.rshift(self.inp[1:], maxslen)),
+                m_mask.eq(sm.rshift(~m0s, maxsleni)),
+                smask.eq(self.inp[1:] & m_mask),
+                # sticky bit combines all mask (and mantissa low bit)
+                stickybit.eq(smask.bool() | self.inp[0]),
+                # mantissa result contains m[0] already.
+                self.m.eq(Cat(stickybit, rs))
+           ]
+        return m
 
 
-class FPNumShift(FPNumBase):
+
+class FPNumShift(FPNumBase, Elaboratable):
     """ Floating-point Number Class for shifting
     """
     def __init__(self, mainm, op, inv, width, m_extra=True):
     """ Floating-point Number Class for shifting
     """
     def __init__(self, mainm, op, inv, width, m_extra=True):
@@ -201,14 +280,14 @@ class FPNumShift(FPNumBase):
 
         return m
 
 
         return m
 
-    def shift_down(self):
+    def shift_down(self, inp):
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
-        return [self.e.eq(self.e + 1),
-                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]
 
     def shift_down_multi(self, diff):
                ]
 
     def shift_down_multi(self, diff):
@@ -249,7 +328,8 @@ class FPNumShift(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class FPNumIn(FPNumBase):
+
+class FPNumDecode(FPNumBase):
     """ Floating-point Number Class
 
         Contains signals for an incoming copy of the value, decoded into
     """ Floating-point Number Class
 
         Contains signals for an incoming copy of the value, decoded into
@@ -263,15 +343,12 @@ class FPNumIn(FPNumBase):
     """
     def __init__(self, op, width, m_extra=True):
         FPNumBase.__init__(self, width, m_extra)
     """
     def __init__(self, op, width, m_extra=True):
         FPNumBase.__init__(self, width, m_extra)
-        self.latch_in = Signal()
         self.op = op
 
     def elaborate(self, platform):
         m = FPNumBase.elaborate(self, platform)
 
         self.op = op
 
     def elaborate(self, platform):
         m = FPNumBase.elaborate(self, platform)
 
-        #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
-        #with m.If(self.latch_in):
-        #    m.d.sync += self.decode(self.v)
+        m.d.comb += self.decode(self.v)
 
         return m
 
 
         return m
 
@@ -289,17 +366,64 @@ class FPNumIn(FPNumBase):
                 self.s.eq(v[-1]),                 # sign
                 ]
 
                 self.s.eq(v[-1]),                 # sign
                 ]
 
-    def shift_down(self):
+class FPNumIn(FPNumBase):
+    """ Floating-point Number Class
+
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.op = op
+
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
+
+    def shift_down(self, inp):
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
-        return [self.e.eq(self.e + 1),
-                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]
 
                ]
 
-    def shift_down_multi(self, diff):
+    def shift_down_multi(self, diff, inp=None):
         """ shifts a mantissa down. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
         """ shifts a mantissa down. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
@@ -314,16 +438,19 @@ class FPNumIn(FPNumBase):
             inverted and used as a mask to get the LSBs of the mantissa.
             those are then |'d into the sticky bit.
         """
             inverted and used as a mask to get the LSBs of the mantissa.
             those are then |'d into the sticky bit.
         """
+        if inp is None:
+            inp = self
         sm = MultiShift(self.width)
         mw = Const(self.m_width-1, len(diff))
         maxslen = Mux(diff > mw, mw, diff)
         sm = MultiShift(self.width)
         mw = Const(self.m_width-1, len(diff))
         maxslen = Mux(diff > mw, mw, diff)
-        rs = sm.rshift(self.m[1:], maxslen)
+        rs = sm.rshift(inp.m[1:], maxslen)
         maxsleni = mw - maxslen
         m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
 
         maxsleni = mw - maxslen
         m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
 
-        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
-        return [self.e.eq(self.e + diff),
-                self.m.eq(Cat(stickybits, rs))
+        #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
+        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
+        return [self.e.eq(inp.e + diff),
+                self.m.eq(Cat(stickybit, rs))
                ]
 
     def shift_up_multi(self, diff):
                ]
 
     def shift_up_multi(self, diff):
@@ -337,13 +464,35 @@ class FPNumIn(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class FPOp:
-    def __init__(self, width):
-        self.width = width
+class Trigger(Elaboratable):
+    def __init__(self):
 
 
-        self.v   = Signal(width)
         self.stb = Signal(reset=0)
         self.ack = Signal()
         self.stb = Signal(reset=0)
         self.ack = Signal()
+        self.trigger = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.trigger.eq(self.stb & self.ack)
+        return m
+
+    def eq(self, inp):
+        return [self.stb.eq(inp.stb),
+                self.ack.eq(inp.ack)
+               ]
+
+    def ports(self):
+        return [self.stb, self.ack]
+
+
+class FPOpIn(PrevControl):
+    def __init__(self, width):
+        PrevControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_i
 
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
 
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
@@ -363,11 +512,36 @@ class FPOp:
                 in_op.ack.eq(self.ack), # send ACK
                ]
 
                 in_op.ack.eq(self.ack), # send ACK
                ]
 
-    def ports(self):
-        return [self.v, self.stb, self.ack]
 
 
+class FPOpOut(NextControl):
+    def __init__(self, width):
+        NextControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_o
 
 
-class Overflow:
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
+               ]
+
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
+
+
+class Overflow: #(Elaboratable):
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
         self.round_bit = Signal(reset_less=True) # tot[1]
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
         self.round_bit = Signal(reset_less=True) # tot[1]
@@ -376,7 +550,13 @@ class Overflow:
 
         self.roundz = Signal(reset_less=True)
 
 
         self.roundz = Signal(reset_less=True)
 
-    def copy(self, inp):
+    def __iter__(self):
+        yield self.guard
+        yield self.round_bit
+        yield self.sticky
+        yield self.m0
+
+    def eq(self, inp):
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
                 self.sticky.eq(inp.sticky),
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
                 self.sticky.eq(inp.sticky),
@@ -402,15 +582,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
-        with m.If((op.ack) & (op.stb)):
+        res = v.decode2(m)
+        ack = Signal()
+        with m.If((op.ready_o) & (op.valid_i_test)):
             m.next = next_state
             m.next = next_state
-            m.d.sync += [
-                # op is latched in from FPNumIn class on same ack/stb
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
@@ -422,7 +602,7 @@ class FPBase:
             both cases *effectively multiply the number stored by 2*,
             which has to be taken into account when extracting the result.
         """
             both cases *effectively multiply the number stored by 2*,
             which has to be taken into account when extracting the result.
         """
-        with m.If(a.e == a.N127):
+        with m.If(a.exp_n127):
             m.d.sync += a.e.eq(a.N126) # limit a exponent
         with m.Else():
             m.d.sync += a.m[-1].eq(1) # set top mantissa bit
             m.d.sync += a.e.eq(a.N126) # limit a exponent
         with m.Else():
             m.d.sync += a.m[-1].eq(1) # set top mantissa bit
@@ -480,14 +660,13 @@ class FPBase:
         with m.Else():
             m.next = next_state
 
         with m.Else():
             m.next = next_state
 
-    def roundz(self, m, z, out_z, roundz):
+    def roundz(self, m, z, roundz):
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
-        m.d.comb += out_z.copy(z) # copies input to output first
         with m.If(roundz):
         with m.If(roundz):
-            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
+            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
             with m.If(z.m == z.m1s): # all 1s
             with m.If(z.m == z.m1s): # all 1s
-                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up
+                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
 
     def corrections(self, m, z, next_state):
         """ denormalisation and sign-bug corrections
 
     def corrections(self, m, z, next_state):
         """ denormalisation and sign-bug corrections
@@ -515,10 +694,40 @@ class FPBase:
         m.d.sync += [
           out_z.v.eq(z.v)
         ]
         m.d.sync += [
           out_z.v.eq(z.v)
         ]
-        with m.If(out_z.stb & out_z.ack):
-            m.d.sync += out_z.stb.eq(0)
+        with m.If(out_z.valid_o & out_z.ready_i_test):
+            m.d.sync += out_z.valid_o.eq(0)
             m.next = next_state
         with m.Else():
             m.next = next_state
         with m.Else():
-            m.d.sync += out_z.stb.eq(1)
+            m.d.sync += out_z.valid_o.eq(1)
+
+
+class FPState(FPBase):
+    def __init__(self, state_from):
+        self.state_from = state_from
+
+    def set_inputs(self, inputs):
+        self.inputs = inputs
+        for k,v in inputs.items():
+            setattr(self, k, v)
+
+    def set_outputs(self, outputs):
+        self.outputs = outputs
+        for k,v in outputs.items():
+            setattr(self, k, v)
+
+
+class FPID:
+    def __init__(self, id_wid):
+        self.id_wid = id_wid
+        if self.id_wid:
+            self.in_mid = Signal(id_wid, reset_less=True)
+            self.out_mid = Signal(id_wid, reset_less=True)
+        else:
+            self.in_mid = None
+            self.out_mid = None
+
+    def idsync(self, m):
+        if self.id_wid is not None:
+            m.d.sync += self.out_mid.eq(self.in_mid)