remove extra arg from old roundz function
[ieee754fpu.git] / src / add / fpbase.py
index d73183f3ec8b712fcab5b51954c2822326a9bf0d..db95eb13e2a4ef7ae7199c6a29e05865674fc62a 100644 (file)
@@ -96,6 +96,7 @@ class FPNumBase:
         self.is_overflowed = Signal(reset_less=True)
         self.is_denormalised = Signal(reset_less=True)
         self.exp_128 = Signal(reset_less=True)
+        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
         self.exp_lt_n126 = Signal(reset_less=True)
         self.exp_gt_n126 = Signal(reset_less=True)
         self.exp_gt127 = Signal(reset_less=True)
@@ -112,8 +113,9 @@ class FPNumBase:
         m.d.comb += self.is_overflowed.eq(self._is_overflowed())
         m.d.comb += self.is_denormalised.eq(self._is_denormalised())
         m.d.comb += self.exp_128.eq(self.e == self.P128)
-        m.d.comb += self.exp_gt_n126.eq(self.e > self.N126)
-        m.d.comb += self.exp_lt_n126.eq(self.e < self.N126)
+        m.d.comb += self.exp_sub_n126.eq(self.e - self.N126)
+        m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
+        m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
         m.d.comb += self.exp_gt127.eq(self.e > self.P127)
         m.d.comb += self.exp_n127.eq(self.e == self.N127)
         m.d.comb += self.exp_n126.eq(self.e == self.N126)
@@ -182,6 +184,49 @@ class FPNumOut(FPNumBase):
         return self.create(s, self.N127, 0)
 
 
+class MultiShiftRMerge:
+    """ shifts down (right) and merges lower bits into m[0].
+        m[0] is the "sticky" bit, basically
+    """
+    def __init__(self, width, s_max=None):
+        if s_max is None:
+            s_max = int(log(width) / log(2))
+        self.smax = s_max
+        self.m = Signal(width, reset_less=True)
+        self.inp = Signal(width, reset_less=True)
+        self.diff = Signal(s_max, reset_less=True)
+        self.width = width
+
+    def elaborate(self, platform):
+        m = Module()
+
+        rs = Signal(self.width, reset_less=True)
+        m_mask = Signal(self.width, reset_less=True)
+        smask = Signal(self.width, reset_less=True)
+        stickybit = Signal(reset_less=True)
+        maxslen = Signal(self.smax, reset_less=True)
+        maxsleni = Signal(self.smax, reset_less=True)
+
+        sm = MultiShift(self.width-1)
+        m0s = Const(0, self.width-1)
+        mw = Const(self.width-1, len(self.diff))
+        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+                    ]
+
+        m.d.comb += [
+                # shift mantissa by maxslen, mask by inverse
+                rs.eq(sm.rshift(self.inp[1:], maxslen)),
+                m_mask.eq(sm.rshift(~m0s, maxsleni)),
+                smask.eq(self.inp[1:] & m_mask),
+                # sticky bit combines all mask (and mantissa low bit)
+                stickybit.eq(smask.bool() | self.inp[0]),
+                # mantissa result contains m[0] already.
+                self.m.eq(Cat(stickybit, rs))
+           ]
+        return m
+
+
 class FPNumShift(FPNumBase):
     """ Floating-point Number Class for shifting
     """
@@ -327,10 +372,10 @@ class FPNumIn(FPNumBase):
         maxsleni = mw - maxslen
         m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
 
-        #stickybits = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
-        stickybits = (inp.m[1:] & m_mask).bool() | inp.m[0]
+        #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
+        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
         return [self.e.eq(inp.e + diff),
-                self.m.eq(Cat(stickybits, rs))
+                self.m.eq(Cat(stickybit, rs))
                ]
 
     def shift_up_multi(self, diff):
@@ -344,20 +389,34 @@ class FPNumIn(FPNumBase):
                 self.m.eq(sm.lshift(self.m, maxslen))
                ]
 
-class FPOp:
-    def __init__(self, width):
-        self.width = width
+class Trigger:
+    def __init__(self):
 
-        self.v   = Signal(width)
         self.stb = Signal(reset=0)
         self.ack = Signal()
         self.trigger = Signal(reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
-        m.d.sync += self.trigger.eq(self.stb & self.ack)
+        m.d.comb += self.trigger.eq(self.stb & self.ack)
         return m
 
+    def copy(self, inp):
+        return [self.stb.eq(inp.stb),
+                self.ack.eq(inp.ack)
+               ]
+
+    def ports(self):
+        return [self.stb, self.ack]
+
+
+class FPOp(Trigger):
+    def __init__(self, width):
+        Trigger.__init__(self)
+        self.width = width
+
+        self.v   = Signal(width)
+
     def chain_inv(self, in_op, extra=None):
         stb = in_op.stb
         if extra is not None:
@@ -499,14 +558,13 @@ class FPBase:
         with m.Else():
             m.next = next_state
 
-    def roundz(self, m, z, out_z, roundz):
+    def roundz(self, m, z, roundz):
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
-        m.d.comb += out_z.copy(z) # copies input to output first
         with m.If(roundz):
-            m.d.comb += out_z.m.eq(z.m + 1) # mantissa rounds up
+            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
             with m.If(z.m == z.m1s): # all 1s
-                m.d.comb += out_z.e.eq(z.e + 1) # exponent rounds up
+                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
 
     def corrections(self, m, z, next_state):
         """ denormalisation and sign-bug corrections