update comments
[ieee754fpu.git] / src / add / fpbase.py
index dc2c9020dd5fc3644da3a9f7f36e2532ea896619..f49085921d914d600dea950d3707ab924eb0272b 100644 (file)
@@ -2,11 +2,15 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
 from math import log
 from operator import or_
 from functools import reduce
 
 from math import log
 from operator import or_
 from functools import reduce
 
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
 class MultiShiftR:
 
     def __init__(self, width):
@@ -56,7 +60,7 @@ class MultiShift:
         return res
 
 
         return res
 
 
-class FPNumBase:
+class FPNumBase: #(Elaboratable):
     """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
     """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
@@ -84,6 +88,8 @@ class FPNumBase:
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
@@ -139,6 +145,11 @@ class FPNumBase:
     def _is_denormalised(self):
         return (self.exp_n126) & (self.m_msbzero)
 
     def _is_denormalised(self):
         return (self.exp_n126) & (self.m_msbzero)
 
+    def __iter__(self):
+        yield self.s
+        yield self.e
+        yield self.m
+
     def eq(self, inp):
         return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
 
     def eq(self, inp):
         return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
 
@@ -183,8 +194,27 @@ class FPNumOut(FPNumBase):
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
     def zero(self, s):
         return self.create(s, self.N127, 0)
 
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
 
 
-class MultiShiftRMerge:
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
+
+class MultiShiftRMerge(Elaboratable):
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
@@ -227,7 +257,7 @@ class MultiShiftRMerge:
         return m
 
 
         return m
 
 
-class FPNumShift(FPNumBase):
+class FPNumShift(FPNumBase, Elaboratable):
     """ Floating-point Number Class for shifting
     """
     def __init__(self, mainm, op, inv, width, m_extra=True):
     """ Floating-point Number Class for shifting
     """
     def __init__(self, mainm, op, inv, width, m_extra=True):
@@ -353,6 +383,22 @@ class FPNumIn(FPNumBase):
         self.latch_in = Signal()
         self.op = op
 
         self.latch_in = Signal()
         self.op = op
 
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
     def decode(self, v):
         """ decodes a latched value into sign / exponent / mantissa
 
     def decode(self, v):
         """ decodes a latched value into sign / exponent / mantissa
 
@@ -418,7 +464,7 @@ class FPNumIn(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class Trigger:
+class Trigger(Elaboratable):
     def __init__(self):
 
         self.stb = Signal(reset=0)
     def __init__(self):
 
         self.stb = Signal(reset=0)
@@ -439,12 +485,14 @@ class Trigger:
         return [self.stb, self.ack]
 
 
         return [self.stb, self.ack]
 
 
-class FPOp(Trigger):
+class FPOpIn(PrevControl):
     def __init__(self, width):
     def __init__(self, width):
-        Trigger.__init__(self)
+        PrevControl.__init__(self)
         self.width = width
 
         self.width = width
 
-        self.v   = Signal(width)
+    @property
+    def v(self):
+        return self.data_i
 
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
 
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
@@ -464,17 +512,36 @@ class FPOp(Trigger):
                 in_op.ack.eq(self.ack), # send ACK
                ]
 
                 in_op.ack.eq(self.ack), # send ACK
                ]
 
-    def eq(self, inp):
-        return [self.v.eq(inp.v),
-                self.stb.eq(inp.stb),
-                self.ack.eq(inp.ack)
+
+class FPOpOut(NextControl):
+    def __init__(self, width):
+        NextControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_o
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
                ]
 
                ]
 
-    def ports(self):
-        return [self.v, self.stb, self.ack]
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
 
 
 
 
-class Overflow:
+class Overflow: #(Elaboratable):
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
         self.round_bit = Signal(reset_less=True) # tot[1]
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
         self.round_bit = Signal(reset_less=True) # tot[1]
@@ -483,6 +550,12 @@ class Overflow:
 
         self.roundz = Signal(reset_less=True)
 
 
         self.roundz = Signal(reset_less=True)
 
+    def __iter__(self):
+        yield self.guard
+        yield self.round_bit
+        yield self.sticky
+        yield self.m0
+
     def eq(self, inp):
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
     def eq(self, inp):
         return [self.guard.eq(inp.guard),
                 self.round_bit.eq(inp.round_bit),
@@ -509,15 +582,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
-        with m.If((op.ack) & (op.stb)):
+        res = v.decode2(m)
+        ack = Signal()
+        with m.If((op.ready_o) & (op.valid_i_test)):
             m.next = next_state
             m.next = next_state
-            m.d.sync += [
-                # op is latched in from FPNumIn class on same ack/stb
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
@@ -621,11 +694,11 @@ class FPBase:
         m.d.sync += [
           out_z.v.eq(z.v)
         ]
         m.d.sync += [
           out_z.v.eq(z.v)
         ]
-        with m.If(out_z.stb & out_z.ack):
-            m.d.sync += out_z.stb.eq(0)
+        with m.If(out_z.valid_o & out_z.ready_i_test):
+            m.d.sync += out_z.valid_o.eq(0)
             m.next = next_state
         with m.Else():
             m.next = next_state
         with m.Else():
-            m.d.sync += out_z.stb.eq(1)
+            m.d.sync += out_z.valid_o.eq(1)
 
 
 class FPState(FPBase):
 
 
 class FPState(FPBase):