pass in flatten/processing function into _connect_in/out
[ieee754fpu.git] / src / add / fpbase.py
index dc2c9020dd5fc3644da3a9f7f36e2532ea896619..b31ab21cb0b1d004e7b628e38dce79c764aa0659 100644 (file)
@@ -7,6 +7,9 @@ from math import log
 from operator import or_
 from functools import reduce
 
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
@@ -84,6 +87,8 @@ class FPNumBase:
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
@@ -183,6 +188,25 @@ class FPNumOut(FPNumBase):
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
+
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
 
 class MultiShiftRMerge:
     """ shifts down (right) and merges lower bits into m[0].
@@ -353,6 +377,22 @@ class FPNumIn(FPNumBase):
         self.latch_in = Signal()
         self.op = op
 
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
     def decode(self, v):
         """ decodes a latched value into sign / exponent / mantissa
 
@@ -509,15 +549,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
+        res = v.decode2(m)
+        ack = Signal()
         with m.If((op.ack) & (op.stb)):
             m.next = next_state
-            m.d.sync += [
-                # op is latched in from FPNumIn class on same ack/stb
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for