update comments
[ieee754fpu.git] / src / add / fpbase.py
index 30c3a5c47e288c8224f3dc0acd8da7fd8a765fd1..f49085921d914d600dea950d3707ab924eb0272b 100644 (file)
@@ -2,11 +2,15 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
 from math import log
 from operator import or_
 from functools import reduce
 
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
@@ -56,7 +60,7 @@ class MultiShift:
         return res
 
 
-class FPNumBase:
+class FPNumBase: #(Elaboratable):
     """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
@@ -84,6 +88,8 @@ class FPNumBase:
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
@@ -139,6 +145,11 @@ class FPNumBase:
     def _is_denormalised(self):
         return (self.exp_n126) & (self.m_msbzero)
 
+    def __iter__(self):
+        yield self.s
+        yield self.e
+        yield self.m
+
     def eq(self, inp):
         return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
 
@@ -183,8 +194,27 @@ class FPNumOut(FPNumBase):
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
+
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
 
-class MultiShiftRMerge:
+class MultiShiftRMerge(Elaboratable):
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
@@ -227,7 +257,7 @@ class MultiShiftRMerge:
         return m
 
 
-class FPNumShift(FPNumBase):
+class FPNumShift(FPNumBase, Elaboratable):
     """ Floating-point Number Class for shifting
     """
     def __init__(self, mainm, op, inv, width, m_extra=True):
@@ -298,7 +328,8 @@ class FPNumShift(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class FPNumIn(FPNumBase):
+
+class FPNumDecode(FPNumBase):
     """ Floating-point Number Class
 
         Contains signals for an incoming copy of the value, decoded into
@@ -312,15 +343,12 @@ class FPNumIn(FPNumBase):
     """
     def __init__(self, op, width, m_extra=True):
         FPNumBase.__init__(self, width, m_extra)
-        self.latch_in = Signal()
         self.op = op
 
     def elaborate(self, platform):
         m = FPNumBase.elaborate(self, platform)
 
-        #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
-        #with m.If(self.latch_in):
-        #    m.d.sync += self.decode(self.v)
+        m.d.comb += self.decode(self.v)
 
         return m
 
@@ -338,6 +366,53 @@ class FPNumIn(FPNumBase):
                 self.s.eq(v[-1]),                 # sign
                 ]
 
+class FPNumIn(FPNumBase):
+    """ Floating-point Number Class
+
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.op = op
+
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
+
     def shift_down(self, inp):
         """ shifts a mantissa down by one. exponent is increased to compensate
 
@@ -389,7 +464,7 @@ class FPNumIn(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class Trigger:
+class Trigger(Elaboratable):
     def __init__(self):
 
         self.stb = Signal(reset=0)
@@ -410,12 +485,14 @@ class Trigger:
         return [self.stb, self.ack]
 
 
-class FPOp(Trigger):
+class FPOpIn(PrevControl):
     def __init__(self, width):
-        Trigger.__init__(self)
+        PrevControl.__init__(self)
         self.width = width
 
-        self.v   = Signal(width)
+    @property
+    def v(self):
+        return self.data_i
 
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
@@ -435,17 +512,36 @@ class FPOp(Trigger):
                 in_op.ack.eq(self.ack), # send ACK
                ]
 
-    def eq(self, inp):
-        return [self.v.eq(inp.v),
-                self.stb.eq(inp.stb),
-                self.ack.eq(inp.ack)
+
+class FPOpOut(NextControl):
+    def __init__(self, width):
+        NextControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_o
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
                ]
 
-    def ports(self):
-        return [self.v, self.stb, self.ack]
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
 
 
-class Overflow:
+class Overflow: #(Elaboratable):
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
         self.round_bit = Signal(reset_less=True) # tot[1]
@@ -454,6 +550,12 @@ class Overflow:
 
         self.roundz = Signal(reset_less=True)
 
+    def __iter__(self):
+        yield self.guard
+        yield self.round_bit
+        yield self.sticky
+        yield self.m0
+
     def eq(self, inp):
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
@@ -480,15 +582,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
-        with m.If((op.ack) & (op.stb)):
+        res = v.decode2(m)
+        ack = Signal()
+        with m.If((op.ready_o) & (op.valid_i_test)):
             m.next = next_state
-            m.d.sync += [
-                # op is latched in from FPNumIn class on same ack/stb
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
@@ -592,11 +694,11 @@ class FPBase:
         m.d.sync += [
           out_z.v.eq(z.v)
         ]
-        with m.If(out_z.stb & out_z.ack):
-            m.d.sync += out_z.stb.eq(0)
+        with m.If(out_z.valid_o & out_z.ready_i_test):
+            m.d.sync += out_z.valid_o.eq(0)
             m.next = next_state
         with m.Else():
-            m.d.sync += out_z.stb.eq(1)
+            m.d.sync += out_z.valid_o.eq(1)
 
 
 class FPState(FPBase):