update comments
[ieee754fpu.git] / src / add / fpbase.py
index 8c10d7821c96dd0f1592d05df7b5e415e6f796a8..f49085921d914d600dea950d3707ab924eb0272b 100644 (file)
@@ -2,19 +2,23 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux, Module
+from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
 from math import log
 from operator import or_
 from functools import reduce
 
 from math import log
 from operator import or_
 from functools import reduce
 
+from singlepipe import PrevControl, NextControl
+from pipeline import ObjectProxy
+
+
 class MultiShiftR:
 
     def __init__(self, width):
         self.width = width
         self.smax = int(log(width) / log(2))
 class MultiShiftR:
 
     def __init__(self, width):
         self.width = width
         self.smax = int(log(width) / log(2))
-        self.i = Signal(width)
-        self.s = Signal(self.smax)
-        self.o = Signal(width)
+        self.i = Signal(width, reset_less=True)
+        self.s = Signal(self.smax, reset_less=True)
+        self.o = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -56,22 +60,13 @@ class MultiShift:
         return res
 
 
         return res
 
 
-class FPNum:
-    """ Floating-point Number Class, variable-width TODO (currently 32-bit)
-
-        Contains signals for an incoming copy of the value, decoded into
-        sign / exponent / mantissa.
-        Also contains encoding functions, creation and recognition of
-        zero, NaN and inf (all signed)
-
-        Four extra bits are included in the mantissa: the top bit
-        (m[-1]) is effectively a carry-overflow.  The other three are
-        guard (m[2]), round (m[1]), and sticky (m[0])
+class FPNumBase: #(Elaboratable):
+    """ Floating-point Base Number Class
     """
     def __init__(self, width, m_extra=True):
         self.width = width
     """
     def __init__(self, width, m_extra=True):
         self.width = width
-        m_width = {32: 24, 64: 53}[width]
-        e_width = {32: 10, 64: 13}[width]
+        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
+        e_width = {16: 7,  32: 10, 64: 13}[width] # 2 extra bits (overflow)
         e_max = 1<<(e_width-3)
         self.rmw = m_width # real mantissa width (not including extras)
         self.e_max = e_max
         e_max = 1<<(e_width-3)
         self.rmw = m_width # real mantissa width (not including extras)
         self.e_max = e_max
@@ -87,31 +82,97 @@ class FPNum:
         self.e_start = self.rmw - 1
         self.e_end = self.rmw + self.e_width - 3 # for decoding
 
         self.e_start = self.rmw - 1
         self.e_end = self.rmw + self.e_width - 3 # for decoding
 
-        self.v = Signal(width)      # Latched copy of value
-        self.m = Signal(m_width)    # Mantissa
-        self.e = Signal((e_width, True)) # Exponent: 10 bits, signed
-        self.s = Signal()           # Sign bit
+        self.v = Signal(width, reset_less=True)      # Latched copy of value
+        self.m = Signal(m_width, reset_less=True)    # Mantissa
+        self.e = Signal((e_width, True), reset_less=True) # Exponent: IEEE754exp+2 bits, signed
+        self.s = Signal(reset_less=True)           # Sign bit
 
         self.mzero = Const(0, (m_width, False))
 
         self.mzero = Const(0, (m_width, False))
+        m_msb = 1<<(self.m_width-2)
+        self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
         self.N127 = Const(-(e_max-1), (e_width, True))
         self.N126 = Const(-(e_max-2), (e_width, True))
 
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
         self.P127 = Const(e_max-1, (e_width, True))
         self.N127 = Const(-(e_max-1), (e_width, True))
         self.N126 = Const(-(e_max-2), (e_width, True))
 
-    def decode(self, v):
-        """ decodes a latched value into sign / exponent / mantissa
+        self.is_nan = Signal(reset_less=True)
+        self.is_zero = Signal(reset_less=True)
+        self.is_inf = Signal(reset_less=True)
+        self.is_overflowed = Signal(reset_less=True)
+        self.is_denormalised = Signal(reset_less=True)
+        self.exp_128 = Signal(reset_less=True)
+        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
+        self.exp_lt_n126 = Signal(reset_less=True)
+        self.exp_gt_n126 = Signal(reset_less=True)
+        self.exp_gt127 = Signal(reset_less=True)
+        self.exp_n127 = Signal(reset_less=True)
+        self.exp_n126 = Signal(reset_less=True)
+        self.m_zero = Signal(reset_less=True)
+        self.m_msbzero = Signal(reset_less=True)
 
 
-            bias is subtracted here, from the exponent.  exponent
-            is extended to 10 bits so that subtract 127 is done on
-            a 10-bit number
-        """
-        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
-        #print ("decode", self.e_end)
-        return [self.m.eq(Cat(*args)), # mantissa
-                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
-                self.s.eq(v[-1]),                 # sign
-                ]
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.is_nan.eq(self._is_nan())
+        m.d.comb += self.is_zero.eq(self._is_zero())
+        m.d.comb += self.is_inf.eq(self._is_inf())
+        m.d.comb += self.is_overflowed.eq(self._is_overflowed())
+        m.d.comb += self.is_denormalised.eq(self._is_denormalised())
+        m.d.comb += self.exp_128.eq(self.e == self.P128)
+        m.d.comb += self.exp_sub_n126.eq(self.e - self.N126)
+        m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
+        m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
+        m.d.comb += self.exp_gt127.eq(self.e > self.P127)
+        m.d.comb += self.exp_n127.eq(self.e == self.N127)
+        m.d.comb += self.exp_n126.eq(self.e == self.N126)
+        m.d.comb += self.m_zero.eq(self.m == self.mzero)
+        m.d.comb += self.m_msbzero.eq(self.m[self.e_start] == 0)
+
+        return m
+
+    def _is_nan(self):
+        return (self.exp_128) & (~self.m_zero)
+
+    def _is_inf(self):
+        return (self.exp_128) & (self.m_zero)
+
+    def _is_zero(self):
+        return (self.exp_n127) & (self.m_zero)
+
+    def _is_overflowed(self):
+        return self.exp_gt127
+
+    def _is_denormalised(self):
+        return (self.exp_n126) & (self.m_msbzero)
+
+    def __iter__(self):
+        yield self.s
+        yield self.e
+        yield self.m
+
+    def eq(self, inp):
+        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
+
+
+class FPNumOut(FPNumBase):
+    """ Floating-point Number Class
+
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
+
+        return m
 
     def create(self, s, e, m):
         """ creates a value from sign / exponent / mantissa
 
     def create(self, s, e, m):
         """ creates a value from sign / exponent / mantissa
@@ -124,14 +185,109 @@ class FPNum:
           self.v[0:self.e_start].eq(m)         # mantissa
         ]
 
           self.v[0:self.e_start].eq(m)         # mantissa
         ]
 
-    def shift_down(self):
+    def nan(self, s):
+        return self.create(s, self.P128, 1<<(self.e_start-1))
+
+    def inf(self, s):
+        return self.create(s, self.P128, 0)
+
+    def zero(self, s):
+        return self.create(s, self.N127, 0)
+
+    def create2(self, s, e, m):
+        """ creates a value from sign / exponent / mantissa
+
+            bias is added here, to the exponent
+        """
+        e = e + self.P127 # exp (add on bias)
+        return Cat(m[0:self.e_start],
+                   e[0:self.e_end-self.e_start],
+                   s)
+
+    def nan2(self, s):
+        return self.create2(s, self.P128, self.msb1)
+
+    def inf2(self, s):
+        return self.create2(s, self.P128, self.mzero)
+
+    def zero2(self, s):
+        return self.create2(s, self.N127, self.mzero)
+
+
+class MultiShiftRMerge(Elaboratable):
+    """ shifts down (right) and merges lower bits into m[0].
+        m[0] is the "sticky" bit, basically
+    """
+    def __init__(self, width, s_max=None):
+        if s_max is None:
+            s_max = int(log(width) / log(2))
+        self.smax = s_max
+        self.m = Signal(width, reset_less=True)
+        self.inp = Signal(width, reset_less=True)
+        self.diff = Signal(s_max, reset_less=True)
+        self.width = width
+
+    def elaborate(self, platform):
+        m = Module()
+
+        rs = Signal(self.width, reset_less=True)
+        m_mask = Signal(self.width, reset_less=True)
+        smask = Signal(self.width, reset_less=True)
+        stickybit = Signal(reset_less=True)
+        maxslen = Signal(self.smax, reset_less=True)
+        maxsleni = Signal(self.smax, reset_less=True)
+
+        sm = MultiShift(self.width-1)
+        m0s = Const(0, self.width-1)
+        mw = Const(self.width-1, len(self.diff))
+        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+                    ]
+
+        m.d.comb += [
+                # shift mantissa by maxslen, mask by inverse
+                rs.eq(sm.rshift(self.inp[1:], maxslen)),
+                m_mask.eq(sm.rshift(~m0s, maxsleni)),
+                smask.eq(self.inp[1:] & m_mask),
+                # sticky bit combines all mask (and mantissa low bit)
+                stickybit.eq(smask.bool() | self.inp[0]),
+                # mantissa result contains m[0] already.
+                self.m.eq(Cat(stickybit, rs))
+           ]
+        return m
+
+
+class FPNumShift(FPNumBase, Elaboratable):
+    """ Floating-point Number Class for shifting
+    """
+    def __init__(self, mainm, op, inv, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.mainm = mainm
+        self.inv = inv
+        self.op = op
+
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
+
+        m.d.comb += self.s.eq(op.s)
+        m.d.comb += self.e.eq(op.e)
+        m.d.comb += self.m.eq(op.m)
+
+        with self.mainm.State("align"):
+            with m.If(self.e < self.inv.e):
+                m.d.sync += self.shift_down()
+
+        return m
+
+    def shift_down(self, inp):
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
         """ shifts a mantissa down by one. exponent is increased to compensate
 
             accuracy is lost as a result in the mantissa however there are 3
             guard bits (the latter of which is the "sticky" bit)
         """
-        return [self.e.eq(self.e + 1),
-                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]
 
     def shift_down_multi(self, diff):
                ]
 
     def shift_down_multi(self, diff):
@@ -161,48 +317,256 @@ class FPNum:
                 self.m.eq(Cat(stickybits, rs))
                ]
 
                 self.m.eq(Cat(stickybits, rs))
                ]
 
-    def nan(self, s):
-        return self.create(s, self.P128, 1<<(self.e_start-1))
+    def shift_up_multi(self, diff):
+        """ shifts a mantissa up. exponent is decreased to compensate
+        """
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
 
 
-    def inf(self, s):
-        return self.create(s, self.P128, 0)
+        return [self.e.eq(self.e - diff),
+                self.m.eq(sm.lshift(self.m, maxslen))
+               ]
 
 
-    def zero(self, s):
-        return self.create(s, self.N127, 0)
 
 
-    def is_nan(self):
-        return (self.e == self.P128) & (self.m != 0)
+class FPNumDecode(FPNumBase):
+    """ Floating-point Number Class
 
 
-    def is_inf(self):
-        return (self.e == self.P128) & (self.m == 0)
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
 
 
-    def is_zero(self):
-        return (self.e == self.N127) & (self.m == self.mzero)
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.op = op
 
 
-    def is_overflowed(self):
-        return (self.e > self.P127)
+    def elaborate(self, platform):
+        m = FPNumBase.elaborate(self, platform)
 
 
-    def is_denormalised(self):
-        return (self.e == self.N126) & (self.m[self.e_start] == 0)
+        m.d.comb += self.decode(self.v)
 
 
+        return m
 
 
-class FPOp:
-    def __init__(self, width):
-        self.width = width
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
+
+class FPNumIn(FPNumBase):
+    """ Floating-point Number Class
+
+        Contains signals for an incoming copy of the value, decoded into
+        sign / exponent / mantissa.
+        Also contains encoding functions, creation and recognition of
+        zero, NaN and inf (all signed)
+
+        Four extra bits are included in the mantissa: the top bit
+        (m[-1]) is effectively a carry-overflow.  The other three are
+        guard (m[2]), round (m[1]), and sticky (m[0])
+    """
+    def __init__(self, op, width, m_extra=True):
+        FPNumBase.__init__(self, width, m_extra)
+        self.latch_in = Signal()
+        self.op = op
+
+    def decode2(self, m):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        v = self.v
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        res = ObjectProxy(m, pipemode=False)
+        res.m = Cat(*args)                             # mantissa
+        res.e = v[self.e_start:self.e_end] - self.P127 # exp
+        res.s = v[-1]                                  # sign
+        return res
+
+    def decode(self, v):
+        """ decodes a latched value into sign / exponent / mantissa
+
+            bias is subtracted here, from the exponent.  exponent
+            is extended to 10 bits so that subtract 127 is done on
+            a 10-bit number
+        """
+        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        #print ("decode", self.e_end)
+        return [self.m.eq(Cat(*args)), # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+                self.s.eq(v[-1]),                 # sign
+                ]
+
+    def shift_down(self, inp):
+        """ shifts a mantissa down by one. exponent is increased to compensate
+
+            accuracy is lost as a result in the mantissa however there are 3
+            guard bits (the latter of which is the "sticky" bit)
+        """
+        return [self.e.eq(inp.e + 1),
+                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
+               ]
+
+    def shift_down_multi(self, diff, inp=None):
+        """ shifts a mantissa down. exponent is increased to compensate
+
+            accuracy is lost as a result in the mantissa however there are 3
+            guard bits (the latter of which is the "sticky" bit)
+
+            this code works by variable-shifting the mantissa by up to
+            its maximum bit-length: no point doing more (it'll still be
+            zero).
+
+            the sticky bit is computed by shifting a batch of 1s by
+            the same amount, which will introduce zeros.  it's then
+            inverted and used as a mask to get the LSBs of the mantissa.
+            those are then |'d into the sticky bit.
+        """
+        if inp is None:
+            inp = self
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width-1, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
+        rs = sm.rshift(inp.m[1:], maxslen)
+        maxsleni = mw - maxslen
+        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+
+        #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
+        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
+        return [self.e.eq(inp.e + diff),
+                self.m.eq(Cat(stickybit, rs))
+               ]
+
+    def shift_up_multi(self, diff):
+        """ shifts a mantissa up. exponent is decreased to compensate
+        """
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
+
+        return [self.e.eq(self.e - diff),
+                self.m.eq(sm.lshift(self.m, maxslen))
+               ]
+
+class Trigger(Elaboratable):
+    def __init__(self):
 
 
-        self.v   = Signal(width)
-        self.stb = Signal()
+        self.stb = Signal(reset=0)
         self.ack = Signal()
         self.ack = Signal()
+        self.trigger = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.trigger.eq(self.stb & self.ack)
+        return m
+
+    def eq(self, inp):
+        return [self.stb.eq(inp.stb),
+                self.ack.eq(inp.ack)
+               ]
 
     def ports(self):
 
     def ports(self):
-        return [self.v, self.stb, self.ack]
+        return [self.stb, self.ack]
+
+
+class FPOpIn(PrevControl):
+    def __init__(self, width):
+        PrevControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_i
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
+               ]
+
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
 
 
 
 
-class Overflow:
+class FPOpOut(NextControl):
+    def __init__(self, width):
+        NextControl.__init__(self)
+        self.width = width
+
+    @property
+    def v(self):
+        return self.data_o
+
+    def chain_inv(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(~self.ack), # send ACK
+               ]
+
+    def chain_from(self, in_op, extra=None):
+        stb = in_op.stb
+        if extra is not None:
+            stb = stb & extra
+        return [self.v.eq(in_op.v),          # receive value
+                self.stb.eq(stb),      # receive STB
+                in_op.ack.eq(self.ack), # send ACK
+               ]
+
+
+class Overflow: #(Elaboratable):
     def __init__(self):
     def __init__(self):
-        self.guard = Signal()     # tot[2]
-        self.round_bit = Signal() # tot[1]
-        self.sticky = Signal()    # tot[0]
+        self.guard = Signal(reset_less=True)     # tot[2]
+        self.round_bit = Signal(reset_less=True) # tot[1]
+        self.sticky = Signal(reset_less=True)    # tot[0]
+        self.m0 = Signal(reset_less=True)        # mantissa zero bit
+
+        self.roundz = Signal(reset_less=True)
+
+    def __iter__(self):
+        yield self.guard
+        yield self.round_bit
+        yield self.sticky
+        yield self.m0
+
+    def eq(self, inp):
+        return [self.guard.eq(inp.guard),
+                self.round_bit.eq(inp.round_bit),
+                self.sticky.eq(inp.sticky),
+                self.m0.eq(inp.m0)]
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.roundz.eq(self.guard & \
+                                   (self.round_bit | self.sticky | self.m0))
+        return m
 
 
 class FPBase:
 
 
 class FPBase:
@@ -218,14 +582,15 @@ class FPBase:
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
             when both stb and ack are 1.
             acknowledgement is sent by setting ack to ZERO.
         """
-        with m.If((op.ack) & (op.stb)):
+        res = v.decode2(m)
+        ack = Signal()
+        with m.If((op.ready_o) & (op.valid_i_test)):
             m.next = next_state
             m.next = next_state
-            m.d.sync += [
-                v.decode(op.v),
-                op.ack.eq(0)
-            ]
+            # op is latched in from FPNumIn class on same ack/stb
+            m.d.comb += ack.eq(0)
         with m.Else():
         with m.Else():
-            m.d.sync += op.ack.eq(1)
+            m.d.comb += ack.eq(1)
+        return [res, ack]
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
 
     def denormalise(self, m, a):
         """ denormalises a number.  this is probably the wrong name for
@@ -237,7 +602,7 @@ class FPBase:
             both cases *effectively multiply the number stored by 2*,
             which has to be taken into account when extracting the result.
         """
             both cases *effectively multiply the number stored by 2*,
             which has to be taken into account when extracting the result.
         """
-        with m.If(a.e == a.N127):
+        with m.If(a.exp_n127):
             m.d.sync += a.e.eq(a.N126) # limit a exponent
         with m.Else():
             m.d.sync += a.m[-1].eq(1) # set top mantissa bit
             m.d.sync += a.e.eq(a.N126) # limit a exponent
         with m.Else():
             m.d.sync += a.m[-1].eq(1) # set top mantissa bit
@@ -264,12 +629,13 @@ class FPBase:
                   the extra mantissa bits coming from tot[0..2]
         """
         with m.If((z.m[-1] == 0) & (z.e > z.N126)):
                   the extra mantissa bits coming from tot[0..2]
         """
         with m.If((z.m[-1] == 0) & (z.e > z.N126)):
-            m.d.sync +=[
+            m.d.sync += [
                 z.e.eq(z.e - 1),  # DECREASE exponent
                 z.m.eq(z.m << 1), # shift mantissa UP
                 z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                 of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                 of.round_bit.eq(0),        # reset round bit
                 z.e.eq(z.e - 1),  # DECREASE exponent
                 z.m.eq(z.m << 1), # shift mantissa UP
                 z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                 of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
                 of.round_bit.eq(0),        # reset round bit
+                of.m0.eq(of.guard),
             ]
         with m.Else():
             m.next = next_state
             ]
         with m.Else():
             m.next = next_state
@@ -287,17 +653,17 @@ class FPBase:
                 z.e.eq(z.e + 1),  # INCREASE exponent
                 z.m.eq(z.m >> 1), # shift mantissa DOWN
                 of.guard.eq(z.m[0]),
                 z.e.eq(z.e + 1),  # INCREASE exponent
                 z.m.eq(z.m >> 1), # shift mantissa DOWN
                 of.guard.eq(z.m[0]),
+                of.m0.eq(z.m[1]),
                 of.round_bit.eq(of.guard),
                 of.sticky.eq(of.sticky | of.round_bit)
             ]
         with m.Else():
             m.next = next_state
 
                 of.round_bit.eq(of.guard),
                 of.sticky.eq(of.sticky | of.round_bit)
             ]
         with m.Else():
             m.next = next_state
 
-    def roundz(self, m, z, of, next_state):
+    def roundz(self, m, z, roundz):
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
-        m.next = next_state
-        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
+        with m.If(roundz):
             m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
             with m.If(z.m == z.m1s): # all 1s
                 m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
             m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
             with m.If(z.m == z.m1s): # all 1s
                 m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
@@ -307,19 +673,16 @@ class FPBase:
         """
         m.next = next_state
         # denormalised, correct exponent to zero
         """
         m.next = next_state
         # denormalised, correct exponent to zero
-        with m.If(z.is_denormalised()):
-            m.d.sync += z.m.eq(z.N127)
-        # FIX SIGN BUG: -a + a = +0.
-        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
-            m.d.sync += z.s.eq(0)
+        with m.If(z.is_denormalised):
+            m.d.sync += z.e.eq(z.N127)
 
     def pack(self, m, z, next_state):
         """ packs the result into the output (detects overflow->Inf)
         """
         m.next = next_state
         # if overflow occurs, return inf
 
     def pack(self, m, z, next_state):
         """ packs the result into the output (detects overflow->Inf)
         """
         m.next = next_state
         # if overflow occurs, return inf
-        with m.If(z.is_overflowed()):
-            m.d.sync += z.inf(0)
+        with m.If(z.is_overflowed):
+            m.d.sync += z.inf(z.s)
         with m.Else():
             m.d.sync += z.create(z.s, z.e, z.m)
 
         with m.Else():
             m.d.sync += z.create(z.s, z.e, z.m)
 
@@ -329,11 +692,42 @@ class FPBase:
             resets stb back to zero when that occurs, as acknowledgement.
         """
         m.d.sync += [
             resets stb back to zero when that occurs, as acknowledgement.
         """
         m.d.sync += [
-          out_z.stb.eq(1),
           out_z.v.eq(z.v)
         ]
           out_z.v.eq(z.v)
         ]
-        with m.If(out_z.stb & out_z.ack):
-            m.d.sync += out_z.stb.eq(0)
+        with m.If(out_z.valid_o & out_z.ready_i_test):
+            m.d.sync += out_z.valid_o.eq(0)
             m.next = next_state
             m.next = next_state
+        with m.Else():
+            m.d.sync += out_z.valid_o.eq(1)
+
+
+class FPState(FPBase):
+    def __init__(self, state_from):
+        self.state_from = state_from
+
+    def set_inputs(self, inputs):
+        self.inputs = inputs
+        for k,v in inputs.items():
+            setattr(self, k, v)
+
+    def set_outputs(self, outputs):
+        self.outputs = outputs
+        for k,v in outputs.items():
+            setattr(self, k, v)
+
+
+class FPID:
+    def __init__(self, id_wid):
+        self.id_wid = id_wid
+        if self.id_wid:
+            self.in_mid = Signal(id_wid, reset_less=True)
+            self.out_mid = Signal(id_wid, reset_less=True)
+        else:
+            self.in_mid = None
+            self.out_mid = None
+
+    def idsync(self, m):
+        if self.id_wid is not None:
+            m.d.sync += self.out_mid.eq(self.in_mid)