use ospec to make clone of out_z
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
index 121f61b22ab54654bd7147bf8972a524699a362b..e9929d5ca5969f6c4ca955d2bb591b1b1c57ac43 100644 (file)
@@ -2,9 +2,10 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Module, Signal, Cat, Mux, Array
+from nmigen import Module, Signal, Cat, Mux, Array, Const
 from nmigen.lib.coding import PriorityEncoder
 from nmigen.cli import main, verilog
 from nmigen.lib.coding import PriorityEncoder
 from nmigen.cli import main, verilog
+from math import log
 
 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
 from fpbase import MultiShiftRMerge, Trigger
 
 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
 from fpbase import MultiShiftRMerge, Trigger
@@ -26,6 +27,114 @@ class FPState(FPBase):
             setattr(self, k, v)
 
 
             setattr(self, k, v)
 
 
+class FPGetSyncOpsMod:
+    def __init__(self, width, num_ops=2):
+        self.width = width
+        self.num_ops = num_ops
+        inops = []
+        outops = []
+        for i in range(num_ops):
+            inops.append(Signal(width, reset_less=True))
+            outops.append(Signal(width, reset_less=True))
+        self.in_op = inops
+        self.out_op = outops
+        self.stb = Signal(num_ops)
+        self.ack = Signal()
+        self.ready = Signal(reset_less=True)
+        self.out_decode = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.ready.eq(self.stb == Const(-1, (self.num_ops, False)))
+        m.d.comb += self.out_decode.eq(self.ack & self.ready)
+        with m.If(self.out_decode):
+            for i in range(self.num_ops):
+                m.d.comb += [
+                        self.out_op[i].eq(self.in_op[i]),
+                ]
+        return m
+
+    def ports(self):
+        return self.in_op + self.out_op + [self.stb, self.ack]
+
+
+class FPOps(Trigger):
+    def __init__(self, width, num_ops):
+        Trigger.__init__(self)
+        self.width = width
+        self.num_ops = num_ops
+
+        res = []
+        for i in range(num_ops):
+            res.append(Signal(width))
+        self.v  = Array(res)
+
+    def ports(self):
+        res = []
+        for i in range(self.num_ops):
+            res.append(self.v[i])
+        res.append(self.ack)
+        res.append(self.stb)
+        return res
+
+
+class InputGroup:
+    def __init__(self, width, num_ops=2, num_rows=4):
+        self.width = width
+        self.num_ops = num_ops
+        self.num_rows = num_rows
+        self.mmax = int(log(self.num_rows) / log(2))
+        self.rs = []
+        self.mid = Signal(self.mmax, reset_less=True) # multiplex id
+        for i in range(num_rows):
+            self.rs.append(FPGetSyncOpsMod(width, num_ops))
+        self.rs = Array(self.rs)
+
+        self.out_op = FPOps(width, num_ops)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        pe = PriorityEncoder(self.num_rows)
+        m.submodules.selector = pe
+        m.submodules.out_op = self.out_op
+        m.submodules += self.rs
+
+        # connect priority encoder
+        in_ready = []
+        for i in range(self.num_rows):
+            in_ready.append(self.rs[i].ready)
+        m.d.comb += pe.i.eq(Cat(*in_ready))
+
+        active = Signal(reset_less=True)
+        out_en = Signal(reset_less=True)
+        m.d.comb += active.eq(~pe.n) # encoder active
+        m.d.comb += out_en.eq(active & self.out_op.trigger)
+
+        # encoder active: ack relevant input, record MID, pass output
+        with m.If(out_en):
+            rs = self.rs[pe.o]
+            m.d.sync += self.mid.eq(pe.o)
+            m.d.sync += rs.ack.eq(0)
+            m.d.sync += self.out_op.stb.eq(0)
+            for j in range(self.num_ops):
+                m.d.sync += self.out_op.v[j].eq(rs.out_op[j])
+        with m.Else():
+            m.d.sync += self.out_op.stb.eq(1)
+            # acks all default to zero
+            for i in range(self.num_rows):
+                m.d.sync += self.rs[i].ack.eq(1)
+
+        return m
+
+    def ports(self):
+        res = []
+        for i in range(self.num_rows):
+            inop = self.rs[i]
+            res += inop.in_op + [inop.stb]
+        return self.out_op.ports() + res + [self.mid]
+
+
 class FPGetOpMod:
     def __init__(self, width):
         self.in_op = FPOp(width)
 class FPGetOpMod:
     def __init__(self, width):
         self.in_op = FPOp(width)
@@ -60,7 +169,7 @@ class FPGetOp(FPState):
         """ links module to inputs and outputs
         """
         setattr(m.submodules, self.state_from, self.mod)
         """ links module to inputs and outputs
         """
         setattr(m.submodules, self.state_from, self.mod)
-        m.d.comb += self.mod.in_op.copy(in_op)
+        m.d.comb += self.mod.in_op.eq(in_op)
         #m.d.comb += self.out_op.eq(self.mod.out_op)
         m.d.comb += self.out_decode.eq(self.mod.out_decode)
 
         #m.d.comb += self.out_op.eq(self.mod.out_op)
         m.d.comb += self.out_decode.eq(self.mod.out_decode)
 
@@ -130,12 +239,21 @@ class FPGet2Op(FPState):
                 self.mod.ack.eq(0),
                 #self.out_op1.v.eq(self.mod.out_op1.v),
                 #self.out_op2.v.eq(self.mod.out_op2.v),
                 self.mod.ack.eq(0),
                 #self.out_op1.v.eq(self.mod.out_op1.v),
                 #self.out_op2.v.eq(self.mod.out_op2.v),
-                self.out_op1.copy(self.mod.out_op1),
-                self.out_op2.copy(self.mod.out_op2)
+                self.out_op1.eq(self.mod.out_op1),
+                self.out_op2.eq(self.mod.out_op2)
             ]
         with m.Else():
             m.d.sync += self.mod.ack.eq(1)
 
             ]
         with m.Else():
             m.d.sync += self.mod.ack.eq(1)
 
+class FPNumBase2Ops:
+
+    def __init__(self, width):
+        self.a = FPNumBase(width)
+        self.b = FPNumBase(width)
+
+    def eq(self, i):
+        return [self.a.eq(i.a), self.a.eq(i.b)]
+
 
 class FPAddSpecialCasesMod:
     """ special cases: NaNs, infs, zeros, denormalised
 
 class FPAddSpecialCasesMod:
     """ special cases: NaNs, infs, zeros, denormalised
@@ -144,34 +262,40 @@ class FPAddSpecialCasesMod:
     """
 
     def __init__(self, width):
     """
 
     def __init__(self, width):
-        self.in_a = FPNumBase(width)
-        self.in_b = FPNumBase(width)
-        self.out_z = FPNumOut(width, False)
+        self.width = width
+        self.i = self.ispec()
+        self.out_z = self.ospec()
         self.out_do_z = Signal(reset_less=True)
 
         self.out_do_z = Signal(reset_less=True)
 
+    def ispec(self):
+        return FPNumBase2Ops(self.width)
+
+    def ospec(self):
+        return FPNumOut(self.width, False)
+
     def setup(self, m, in_a, in_b, out_do_z):
         """ links module to inputs and outputs
         """
         m.submodules.specialcases = self
     def setup(self, m, in_a, in_b, out_do_z):
         """ links module to inputs and outputs
         """
         m.submodules.specialcases = self
-        m.d.comb += self.in_a.copy(in_a)
-        m.d.comb += self.in_b.copy(in_b)
+        m.d.comb += self.i.a.eq(in_a)
+        m.d.comb += self.i.b.eq(in_b)
         m.d.comb += out_do_z.eq(self.out_do_z)
 
     def elaborate(self, platform):
         m = Module()
 
         m.d.comb += out_do_z.eq(self.out_do_z)
 
     def elaborate(self, platform):
         m = Module()
 
-        m.submodules.sc_in_a = self.in_a
-        m.submodules.sc_in_b = self.in_b
+        m.submodules.sc_in_a = self.i.a
+        m.submodules.sc_in_b = self.i.b
         m.submodules.sc_out_z = self.out_z
 
         s_nomatch = Signal()
         m.submodules.sc_out_z = self.out_z
 
         s_nomatch = Signal()
-        m.d.comb += s_nomatch.eq(self.in_a.s != self.in_b.s)
+        m.d.comb += s_nomatch.eq(self.i.a.s != self.i.b.s)
 
         m_match = Signal()
 
         m_match = Signal()
-        m.d.comb += m_match.eq(self.in_a.m == self.in_b.m)
+        m.d.comb += m_match.eq(self.i.a.m == self.i.b.m)
 
         # if a is NaN or b is NaN return NaN
 
         # if a is NaN or b is NaN return NaN
-        with m.If(self.in_a.is_nan | self.in_b.is_nan):
+        with m.If(self.i.a.is_nan | self.i.b.is_nan):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_z.nan(0)
 
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_z.nan(0)
 
@@ -199,39 +323,39 @@ class FPAddSpecialCasesMod:
         #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))
 
         # if a is inf return inf (or NaN)
         #    m.d.comb += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))
 
         # if a is inf return inf (or NaN)
-        with m.Elif(self.in_a.is_inf):
+        with m.Elif(self.i.a.is_inf):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_do_z.eq(1)
-            m.d.comb += self.out_z.inf(self.in_a.s)
+            m.d.comb += self.out_z.inf(self.i.a.s)
             # if a is inf and signs don't match return NaN
             # if a is inf and signs don't match return NaN
-            with m.If(self.in_b.exp_128 & s_nomatch):
+            with m.If(self.i.b.exp_128 & s_nomatch):
                 m.d.comb += self.out_z.nan(0)
 
         # if b is inf return inf
                 m.d.comb += self.out_z.nan(0)
 
         # if b is inf return inf
-        with m.Elif(self.in_b.is_inf):
+        with m.Elif(self.i.b.is_inf):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_do_z.eq(1)
-            m.d.comb += self.out_z.inf(self.in_b.s)
+            m.d.comb += self.out_z.inf(self.i.b.s)
 
         # if a is zero and b zero return signed-a/b
 
         # if a is zero and b zero return signed-a/b
-        with m.Elif(self.in_a.is_zero & self.in_b.is_zero):
+        with m.Elif(self.i.a.is_zero & self.i.b.is_zero):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_do_z.eq(1)
-            m.d.comb += self.out_z.create(self.in_a.s & self.in_b.s,
-                                          self.in_b.e,
-                                          self.in_b.m[3:-1])
+            m.d.comb += self.out_z.create(self.i.a.s & self.i.b.s,
+                                          self.i.b.e,
+                                          self.i.b.m[3:-1])
 
         # if a is zero return b
 
         # if a is zero return b
-        with m.Elif(self.in_a.is_zero):
+        with m.Elif(self.i.a.is_zero):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_do_z.eq(1)
-            m.d.comb += self.out_z.create(self.in_b.s, self.in_b.e,
-                                      self.in_b.m[3:-1])
+            m.d.comb += self.out_z.create(self.i.b.s, self.i.b.e,
+                                      self.i.b.m[3:-1])
 
         # if b is zero return a
 
         # if b is zero return a
-        with m.Elif(self.in_b.is_zero):
+        with m.Elif(self.i.b.is_zero):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_do_z.eq(1)
-            m.d.comb += self.out_z.create(self.in_a.s, self.in_a.e,
-                                      self.in_a.m[3:-1])
+            m.d.comb += self.out_z.create(self.i.a.s, self.i.a.e,
+                                      self.i.a.m[3:-1])
 
         # if a equal to -b return zero (+ve zero)
 
         # if a equal to -b return zero (+ve zero)
-        with m.Elif(s_nomatch & m_match & (self.in_a.e == self.in_b.e)):
+        with m.Elif(s_nomatch & m_match & (self.i.a.e == self.i.b.e)):
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_z.zero(0)
 
             m.d.comb += self.out_do_z.eq(1)
             m.d.comb += self.out_z.zero(0)
 
@@ -267,7 +391,7 @@ class FPAddSpecialCases(FPState, FPID):
         FPState.__init__(self, "special_cases")
         FPID.__init__(self, id_wid)
         self.mod = FPAddSpecialCasesMod(width)
         FPState.__init__(self, "special_cases")
         FPID.__init__(self, id_wid)
         self.mod = FPAddSpecialCasesMod(width)
-        self.out_z = FPNumOut(width, False)
+        self.out_z = self.mod.ospec()
         self.out_do_z = Signal(reset_less=True)
 
     def setup(self, m, in_a, in_b, in_mid):
         self.out_do_z = Signal(reset_less=True)
 
     def setup(self, m, in_a, in_b, in_mid):
@@ -318,8 +442,8 @@ class FPAddSpecialCasesDeNorm(FPState, FPID):
             m.next = "put_z"
         with m.Else():
             m.next = "align"
             m.next = "put_z"
         with m.Else():
             m.next = "align"
-            m.d.sync += self.out_a.copy(self.dmod.out_a)
-            m.d.sync += self.out_b.copy(self.dmod.out_b)
+            m.d.sync += self.out_a.eq(self.dmod.out_a)
+            m.d.sync += self.out_b.eq(self.dmod.out_b)
 
 
 class FPAddDeNormMod(FPState):
 
 
 class FPAddDeNormMod(FPState):
@@ -334,8 +458,8 @@ class FPAddDeNormMod(FPState):
         """ links module to inputs and outputs
         """
         m.submodules.denormalise = self
         """ links module to inputs and outputs
         """
         m.submodules.denormalise = self
-        m.d.comb += self.in_a.copy(in_a)
-        m.d.comb += self.in_b.copy(in_b)
+        m.d.comb += self.in_a.eq(in_a)
+        m.d.comb += self.in_b.eq(in_b)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -344,13 +468,13 @@ class FPAddDeNormMod(FPState):
         m.submodules.denorm_out_a = self.out_a
         m.submodules.denorm_out_b = self.out_b
         # hmmm, don't like repeating identical code
         m.submodules.denorm_out_a = self.out_a
         m.submodules.denorm_out_b = self.out_b
         # hmmm, don't like repeating identical code
-        m.d.comb += self.out_a.copy(self.in_a)
+        m.d.comb += self.out_a.eq(self.in_a)
         with m.If(self.in_a.exp_n127):
             m.d.comb += self.out_a.e.eq(self.in_a.N126) # limit a exponent
         with m.Else():
             m.d.comb += self.out_a.m[-1].eq(1) # set top mantissa bit
 
         with m.If(self.in_a.exp_n127):
             m.d.comb += self.out_a.e.eq(self.in_a.N126) # limit a exponent
         with m.Else():
             m.d.comb += self.out_a.m[-1].eq(1) # set top mantissa bit
 
-        m.d.comb += self.out_b.copy(self.in_b)
+        m.d.comb += self.out_b.eq(self.in_b)
         with m.If(self.in_b.exp_n127):
             m.d.comb += self.out_b.e.eq(self.in_b.N126) # limit a exponent
         with m.Else():
         with m.If(self.in_b.exp_n127):
             m.d.comb += self.out_b.e.eq(self.in_b.N126) # limit a exponent
         with m.Else():
@@ -379,8 +503,8 @@ class FPAddDeNorm(FPState, FPID):
         self.idsync(m)
         # Denormalised Number checks
         m.next = "align"
         self.idsync(m)
         # Denormalised Number checks
         m.next = "align"
-        m.d.sync += self.out_a.copy(self.mod.out_a)
-        m.d.sync += self.out_b.copy(self.mod.out_b)
+        m.d.sync += self.out_a.eq(self.mod.out_a)
+        m.d.sync += self.out_b.eq(self.mod.out_b)
 
 
 class FPAddAlignMultiMod(FPState):
 
 
 class FPAddAlignMultiMod(FPState):
@@ -408,8 +532,8 @@ class FPAddAlignMultiMod(FPState):
 
         # exponent of a greater than b: shift b down
         m.d.comb += self.exp_eq.eq(0)
 
         # exponent of a greater than b: shift b down
         m.d.comb += self.exp_eq.eq(0)
-        m.d.comb += self.out_a.copy(self.in_a)
-        m.d.comb += self.out_b.copy(self.in_b)
+        m.d.comb += self.out_a.eq(self.in_a)
+        m.d.comb += self.out_b.eq(self.in_b)
         agtb = Signal(reset_less=True)
         altb = Signal(reset_less=True)
         m.d.comb += agtb.eq(self.in_a.e > self.in_b.e)
         agtb = Signal(reset_less=True)
         altb = Signal(reset_less=True)
         m.d.comb += agtb.eq(self.in_a.e > self.in_b.e)
@@ -439,18 +563,18 @@ class FPAddAlignMulti(FPState, FPID):
         """ links module to inputs and outputs
         """
         m.submodules.align = self.mod
         """ links module to inputs and outputs
         """
         m.submodules.align = self.mod
-        m.d.comb += self.mod.in_a.copy(in_a)
-        m.d.comb += self.mod.in_b.copy(in_b)
-        #m.d.comb += self.out_a.copy(self.mod.out_a)
-        #m.d.comb += self.out_b.copy(self.mod.out_b)
+        m.d.comb += self.mod.in_a.eq(in_a)
+        m.d.comb += self.mod.in_b.eq(in_b)
+        #m.d.comb += self.out_a.eq(self.mod.out_a)
+        #m.d.comb += self.out_b.eq(self.mod.out_b)
         m.d.comb += self.exp_eq.eq(self.mod.exp_eq)
         if self.in_mid is not None:
             m.d.comb += self.in_mid.eq(in_mid)
 
     def action(self, m):
         self.idsync(m)
         m.d.comb += self.exp_eq.eq(self.mod.exp_eq)
         if self.in_mid is not None:
             m.d.comb += self.in_mid.eq(in_mid)
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_a.copy(self.mod.out_a)
-        m.d.sync += self.out_b.copy(self.mod.out_b)
+        m.d.sync += self.out_a.eq(self.mod.out_a)
+        m.d.sync += self.out_b.eq(self.mod.out_b)
         with m.If(self.exp_eq):
             m.next = "add_0"
 
         with m.If(self.exp_eq):
             m.next = "add_0"
 
@@ -468,8 +592,8 @@ class FPAddAlignSingleMod:
         """ links module to inputs and outputs
         """
         m.submodules.align = self
         """ links module to inputs and outputs
         """
         m.submodules.align = self
-        m.d.comb += self.in_a.copy(in_a)
-        m.d.comb += self.in_b.copy(in_b)
+        m.d.comb += self.in_a.eq(in_a)
+        m.d.comb += self.in_b.eq(in_b)
 
     def elaborate(self, platform):
         """ Aligns A against B or B against A, depending on which has the
 
     def elaborate(self, platform):
         """ Aligns A against B or B against A, depending on which has the
@@ -515,22 +639,22 @@ class FPAddAlignSingleMod:
         m.d.comb += egz.eq(self.in_a.e > self.in_b.e)
 
         # default: A-exp == B-exp, A and B untouched (fall through)
         m.d.comb += egz.eq(self.in_a.e > self.in_b.e)
 
         # default: A-exp == B-exp, A and B untouched (fall through)
-        m.d.comb += self.out_a.copy(self.in_a)
-        m.d.comb += self.out_b.copy(self.in_b)
+        m.d.comb += self.out_a.eq(self.in_a)
+        m.d.comb += self.out_b.eq(self.in_b)
         # only one shifter (muxed)
         #m.d.comb += t_out.shift_down_multi(tdiff, t_inp)
         # exponent of a greater than b: shift b down
         with m.If(egz):
         # only one shifter (muxed)
         #m.d.comb += t_out.shift_down_multi(tdiff, t_inp)
         # exponent of a greater than b: shift b down
         with m.If(egz):
-            m.d.comb += [t_inp.copy(self.in_b),
+            m.d.comb += [t_inp.eq(self.in_b),
                          tdiff.eq(ediff),
                          tdiff.eq(ediff),
-                         self.out_b.copy(t_out),
+                         self.out_b.eq(t_out),
                          self.out_b.s.eq(self.in_b.s), # whoops forgot sign
                         ]
         # exponent of b greater than a: shift a down
         with m.Elif(elz):
                          self.out_b.s.eq(self.in_b.s), # whoops forgot sign
                         ]
         # exponent of b greater than a: shift a down
         with m.Elif(elz):
-            m.d.comb += [t_inp.copy(self.in_a),
+            m.d.comb += [t_inp.eq(self.in_a),
                          tdiff.eq(ediffr),
                          tdiff.eq(ediffr),
-                         self.out_a.copy(t_out),
+                         self.out_a.eq(t_out),
                          self.out_a.s.eq(self.in_a.s), # whoops forgot sign
                         ]
         return m
                          self.out_a.s.eq(self.in_a.s), # whoops forgot sign
                         ]
         return m
@@ -555,8 +679,8 @@ class FPAddAlignSingle(FPState, FPID):
     def action(self, m):
         self.idsync(m)
         # NOTE: could be done as comb
     def action(self, m):
         self.idsync(m)
         # NOTE: could be done as comb
-        m.d.sync += self.out_a.copy(self.mod.out_a)
-        m.d.sync += self.out_b.copy(self.mod.out_b)
+        m.d.sync += self.out_a.eq(self.mod.out_a)
+        m.d.sync += self.out_b.eq(self.mod.out_b)
         m.next = "add_0"
 
 
         m.next = "add_0"
 
 
@@ -582,11 +706,11 @@ class FPAddAlignSingleAdd(FPState, FPID):
         """ links module to inputs and outputs
         """
         self.mod.setup(m, in_a, in_b)
         """ links module to inputs and outputs
         """
         self.mod.setup(m, in_a, in_b)
-        m.d.comb += self.out_a.copy(self.mod.out_a)
-        m.d.comb += self.out_b.copy(self.mod.out_b)
+        m.d.comb += self.out_a.eq(self.mod.out_a)
+        m.d.comb += self.out_b.eq(self.mod.out_b)
 
         self.a0mod.setup(m, self.out_a, self.out_b)
 
         self.a0mod.setup(m, self.out_a, self.out_b)
-        m.d.comb += self.a0_out_z.copy(self.a0mod.out_z)
+        m.d.comb += self.a0_out_z.eq(self.a0mod.out_z)
         m.d.comb += self.out_tot.eq(self.a0mod.out_tot)
 
         self.a1mod.setup(m, self.out_tot, self.a0_out_z)
         m.d.comb += self.out_tot.eq(self.a0mod.out_tot)
 
         self.a1mod.setup(m, self.out_tot, self.a0_out_z)
@@ -596,8 +720,8 @@ class FPAddAlignSingleAdd(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_of.copy(self.a1mod.out_of)
-        m.d.sync += self.out_z.copy(self.a1mod.out_z)
+        m.d.sync += self.out_of.eq(self.a1mod.out_of)
+        m.d.sync += self.out_z.eq(self.a1mod.out_z)
         m.next = "normalise_1"
 
 
         m.next = "normalise_1"
 
 
@@ -614,8 +738,8 @@ class FPAddStage0Mod:
         """ links module to inputs and outputs
         """
         m.submodules.add0 = self
         """ links module to inputs and outputs
         """
         m.submodules.add0 = self
-        m.d.comb += self.in_a.copy(in_a)
-        m.d.comb += self.in_b.copy(in_b)
+        m.d.comb += self.in_a.eq(in_a)
+        m.d.comb += self.in_b.eq(in_b)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -679,7 +803,7 @@ class FPAddStage0(FPState, FPID):
     def action(self, m):
         self.idsync(m)
         # NOTE: these could be done as combinatorial (merge add0+add1)
     def action(self, m):
         self.idsync(m)
         # NOTE: these could be done as combinatorial (merge add0+add1)
-        m.d.sync += self.out_z.copy(self.mod.out_z)
+        m.d.sync += self.out_z.eq(self.mod.out_z)
         m.d.sync += self.out_tot.eq(self.mod.out_tot)
         m.next = "add_1"
 
         m.d.sync += self.out_tot.eq(self.mod.out_tot)
         m.next = "add_1"
 
@@ -702,7 +826,7 @@ class FPAddStage1Mod(FPState):
         m.submodules.add1 = self
         m.submodules.add1_out_overflow = self.out_of
 
         m.submodules.add1 = self
         m.submodules.add1_out_overflow = self.out_of
 
-        m.d.comb += self.in_z.copy(in_z)
+        m.d.comb += self.in_z.eq(in_z)
         m.d.comb += self.in_tot.eq(in_tot)
 
     def elaborate(self, platform):
         m.d.comb += self.in_tot.eq(in_tot)
 
     def elaborate(self, platform):
@@ -711,8 +835,8 @@ class FPAddStage1Mod(FPState):
         #m.submodules.norm1_out_overflow = self.out_of
         #m.submodules.norm1_in_z = self.in_z
         #m.submodules.norm1_out_z = self.out_z
         #m.submodules.norm1_out_overflow = self.out_of
         #m.submodules.norm1_in_z = self.in_z
         #m.submodules.norm1_out_z = self.out_z
-        m.d.comb += self.out_z.copy(self.in_z)
-        # tot[27] gets set when the sum overflows. shift result down
+        m.d.comb += self.out_z.eq(self.in_z)
+        # tot[-1] (MSB) gets set when the sum overflows. shift result down
         with m.If(self.in_tot[-1]):
             m.d.comb += [
                 self.out_z.m.eq(self.in_tot[4:]),
         with m.If(self.in_tot[-1]):
             m.d.comb += [
                 self.out_z.m.eq(self.in_tot[4:]),
@@ -722,7 +846,7 @@ class FPAddStage1Mod(FPState):
                 self.out_of.sticky.eq(self.in_tot[1] | self.in_tot[0]),
                 self.out_z.e.eq(self.in_z.e + 1)
         ]
                 self.out_of.sticky.eq(self.in_tot[1] | self.in_tot[0]),
                 self.out_z.e.eq(self.in_z.e + 1)
         ]
-        # tot[27] zero case
+        # tot[-1] (MSB) zero case
         with m.Else():
             m.d.comb += [
                 self.out_z.m.eq(self.in_tot[3:]),
         with m.Else():
             m.d.comb += [
                 self.out_z.m.eq(self.in_tot[3:]),
@@ -756,12 +880,75 @@ class FPAddStage1(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_of.copy(self.mod.out_of)
-        m.d.sync += self.out_z.copy(self.mod.out_z)
+        m.d.sync += self.out_of.eq(self.mod.out_of)
+        m.d.sync += self.out_z.eq(self.mod.out_z)
         m.d.sync += self.norm_stb.eq(1)
         m.next = "normalise_1"
 
 
         m.d.sync += self.norm_stb.eq(1)
         m.next = "normalise_1"
 
 
+class FPNormaliseModSingle:
+
+    def __init__(self, width):
+        self.width = width
+        self.in_z = FPNumBase(width, False)
+        self.out_z = FPNumBase(width, False)
+
+    def setup(self, m, in_z, out_z, modname):
+        """ links module to inputs and outputs
+        """
+        m.submodules.normalise = self
+        m.d.comb += self.in_z.eq(in_z)
+        m.d.comb += out_z.eq(self.out_z)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        mwid = self.out_z.m_width+2
+        pe = PriorityEncoder(mwid)
+        m.submodules.norm_pe = pe
+
+        m.submodules.norm1_out_z = self.out_z
+        m.submodules.norm1_in_z = self.in_z
+
+        in_z = FPNumBase(self.width, False)
+        in_of = Overflow()
+        m.submodules.norm1_insel_z = in_z
+        m.submodules.norm1_insel_overflow = in_of
+
+        espec = (len(in_z.e), True)
+        ediff_n126 = Signal(espec, reset_less=True)
+        msr = MultiShiftRMerge(mwid, espec)
+        m.submodules.multishift_r = msr
+
+        m.d.comb += in_z.eq(self.in_z)
+        m.d.comb += in_of.eq(self.in_of)
+        # initialise out from in (overridden below)
+        m.d.comb += self.out_z.eq(in_z)
+        m.d.comb += self.out_of.eq(in_of)
+        # normalisation increase/decrease conditions
+        decrease = Signal(reset_less=True)
+        m.d.comb += decrease.eq(in_z.m_msbzero)
+        # decrease exponent
+        with m.If(decrease):
+            # *sigh* not entirely obvious: count leading zeros (clz)
+            # with a PriorityEncoder: to find from the MSB
+            # we reverse the order of the bits.
+            temp_m = Signal(mwid, reset_less=True)
+            temp_s = Signal(mwid+1, reset_less=True)
+            clz = Signal((len(in_z.e), True), reset_less=True)
+            m.d.comb += [
+                # cat round and guard bits back into the mantissa
+                temp_m.eq(Cat(in_of.round_bit, in_of.guard, in_z.m)),
+                pe.i.eq(temp_m[::-1]),          # inverted
+                clz.eq(pe.o),                   # count zeros from MSB down
+                temp_s.eq(temp_m << clz),       # shift mantissa UP
+                self.out_z.e.eq(in_z.e - clz),  # DECREASE exponent
+                self.out_z.m.eq(temp_s[2:]),    # exclude bits 0&1
+            ]
+
+        return m
+
+
 class FPNorm1ModSingle:
 
     def __init__(self, width):
 class FPNorm1ModSingle:
 
     def __init__(self, width):
@@ -777,10 +964,10 @@ class FPNorm1ModSingle:
         """
         m.submodules.normalise_1 = self
 
         """
         m.submodules.normalise_1 = self
 
-        m.d.comb += self.in_z.copy(in_z)
-        m.d.comb += self.in_of.copy(in_of)
+        m.d.comb += self.in_z.eq(in_z)
+        m.d.comb += self.in_of.eq(in_of)
 
 
-        m.d.comb += out_z.copy(self.out_z)
+        m.d.comb += out_z.eq(self.out_z)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -804,11 +991,11 @@ class FPNorm1ModSingle:
         msr = MultiShiftRMerge(mwid, espec)
         m.submodules.multishift_r = msr
 
         msr = MultiShiftRMerge(mwid, espec)
         m.submodules.multishift_r = msr
 
-        m.d.comb += in_z.copy(self.in_z)
-        m.d.comb += in_of.copy(self.in_of)
+        m.d.comb += in_z.eq(self.in_z)
+        m.d.comb += in_of.eq(self.in_of)
         # initialise out from in (overridden below)
         # initialise out from in (overridden below)
-        m.d.comb += self.out_z.copy(in_z)
-        m.d.comb += self.out_of.copy(in_of)
+        m.d.comb += self.out_z.eq(in_z)
+        m.d.comb += self.out_of.eq(in_of)
         # normalisation increase/decrease conditions
         decrease = Signal(reset_less=True)
         increase = Signal(reset_less=True)
         # normalisation increase/decrease conditions
         decrease = Signal(reset_less=True)
         increase = Signal(reset_less=True)
@@ -891,14 +1078,14 @@ class FPNorm1ModMulti:
 
         # select which of temp or in z/of to use
         with m.If(self.in_select):
 
         # select which of temp or in z/of to use
         with m.If(self.in_select):
-            m.d.comb += in_z.copy(self.in_z)
-            m.d.comb += in_of.copy(self.in_of)
+            m.d.comb += in_z.eq(self.in_z)
+            m.d.comb += in_of.eq(self.in_of)
         with m.Else():
         with m.Else():
-            m.d.comb += in_z.copy(self.temp_z)
-            m.d.comb += in_of.copy(self.temp_of)
+            m.d.comb += in_z.eq(self.temp_z)
+            m.d.comb += in_of.eq(self.temp_of)
         # initialise out from in (overridden below)
         # initialise out from in (overridden below)
-        m.d.comb += self.out_z.copy(in_z)
-        m.d.comb += self.out_of.copy(in_of)
+        m.d.comb += self.out_z.eq(in_z)
+        m.d.comb += self.out_of.eq(in_of)
         # normalisation increase/decrease conditions
         decrease = Signal(reset_less=True)
         increase = Signal(reset_less=True)
         # normalisation increase/decrease conditions
         decrease = Signal(reset_less=True)
         increase = Signal(reset_less=True)
@@ -984,8 +1171,8 @@ class FPNorm1Multi(FPState, FPID):
     def action(self, m):
         self.idsync(m)
         m.d.comb += self.in_accept.eq((~self.ack) & (self.stb))
     def action(self, m):
         self.idsync(m)
         m.d.comb += self.in_accept.eq((~self.ack) & (self.stb))
-        m.d.sync += self.temp_of.copy(self.mod.out_of)
-        m.d.sync += self.temp_z.copy(self.out_z)
+        m.d.sync += self.temp_of.eq(self.mod.out_of)
+        m.d.sync += self.temp_z.eq(self.out_z)
         with m.If(self.out_norm):
             with m.If(self.in_accept):
                 m.d.sync += [
         with m.If(self.out_norm):
             with m.If(self.in_accept):
                 m.d.sync += [
@@ -1022,13 +1209,13 @@ class FPNormToPack(FPState, FPID):
         r_out_z = FPNumBase(self.width)
         rmod.setup(m, n_out_z, n_out_roundz)
         m.d.comb += n_out_roundz.eq(nmod.out_of.roundz)
         r_out_z = FPNumBase(self.width)
         rmod.setup(m, n_out_z, n_out_roundz)
         m.d.comb += n_out_roundz.eq(nmod.out_of.roundz)
-        m.d.comb += r_out_z.copy(rmod.out_z)
+        m.d.comb += r_out_z.eq(rmod.out_z)
 
         # Corrections (chained to rounding)
         cmod = FPCorrectionsMod(self.width)
         c_out_z = FPNumBase(self.width)
         cmod.setup(m, r_out_z)
 
         # Corrections (chained to rounding)
         cmod = FPCorrectionsMod(self.width)
         c_out_z = FPNumBase(self.width)
         cmod.setup(m, r_out_z)
-        m.d.comb += c_out_z.copy(cmod.out_z)
+        m.d.comb += c_out_z.eq(cmod.out_z)
 
         # Pack (chained to corrections)
         self.pmod = FPPackMod(self.width)
 
         # Pack (chained to corrections)
         self.pmod = FPPackMod(self.width)
@@ -1055,12 +1242,12 @@ class FPRoundMod:
     def setup(self, m, in_z, roundz):
         m.submodules.roundz = self
 
     def setup(self, m, in_z, roundz):
         m.submodules.roundz = self
 
-        m.d.comb += self.in_z.copy(in_z)
+        m.d.comb += self.in_z.eq(in_z)
         m.d.comb += self.in_roundz.eq(roundz)
 
     def elaborate(self, platform):
         m = Module()
         m.d.comb += self.in_roundz.eq(roundz)
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.out_z.copy(self.in_z)
+        m.d.comb += self.out_z.eq(self.in_z)
         with m.If(self.in_roundz):
             m.d.comb += self.out_z.m.eq(self.in_z.m + 1) # mantissa rounds up
             with m.If(self.in_z.m == self.in_z.m1s): # all 1s
         with m.If(self.in_roundz):
             m.d.comb += self.out_z.m.eq(self.in_z.m + 1) # mantissa rounds up
             with m.If(self.in_z.m == self.in_z.m1s): # all 1s
@@ -1086,7 +1273,7 @@ class FPRound(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_z.copy(self.mod.out_z)
+        m.d.sync += self.out_z.eq(self.mod.out_z)
         m.next = "corrections"
 
 
         m.next = "corrections"
 
 
@@ -1100,13 +1287,13 @@ class FPCorrectionsMod:
         """ links module to inputs and outputs
         """
         m.submodules.corrections = self
         """ links module to inputs and outputs
         """
         m.submodules.corrections = self
-        m.d.comb += self.in_z.copy(in_z)
+        m.d.comb += self.in_z.eq(in_z)
 
     def elaborate(self, platform):
         m = Module()
         m.submodules.corr_in_z = self.in_z
         m.submodules.corr_out_z = self.out_z
 
     def elaborate(self, platform):
         m = Module()
         m.submodules.corr_in_z = self.in_z
         m.submodules.corr_out_z = self.out_z
-        m.d.comb += self.out_z.copy(self.in_z)
+        m.d.comb += self.out_z.eq(self.in_z)
         with m.If(self.in_z.is_denormalised):
             m.d.comb += self.out_z.e.eq(self.in_z.N127)
         return m
         with m.If(self.in_z.is_denormalised):
             m.d.comb += self.out_z.e.eq(self.in_z.N127)
         return m
@@ -1129,7 +1316,7 @@ class FPCorrections(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_z.copy(self.mod.out_z)
+        m.d.sync += self.out_z.eq(self.mod.out_z)
         m.next = "pack"
 
 
         m.next = "pack"
 
 
@@ -1143,7 +1330,7 @@ class FPPackMod:
         """ links module to inputs and outputs
         """
         m.submodules.pack = self
         """ links module to inputs and outputs
         """
         m.submodules.pack = self
-        m.d.comb += self.in_z.copy(in_z)
+        m.d.comb += self.in_z.eq(in_z)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -1452,7 +1639,7 @@ class ResArray:
         self.in_mid = Signal(self.id_wid, reset_less=True)
 
     def setup(self, m, in_z, in_mid):
         self.in_mid = Signal(self.id_wid, reset_less=True)
 
     def setup(self, m, in_z, in_mid):
-        m.d.comb += [self.in_z.copy(in_z),
+        m.d.comb += [self.in_z.eq(in_z),
                      self.in_mid.eq(in_mid)]
 
     def get_fragment(self, platform=None):
                      self.in_mid.eq(in_mid)]
 
     def get_fragment(self, platform=None):