rename inputs_ to terms_
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 95fa45a6e5093fa06a34c001485dbd0570bb0032..89e4c4d322782ff1e16b2e88646405d6dcce0bd3 100644 (file)
@@ -8,6 +8,8 @@ from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -308,7 +310,7 @@ class AddReduceData:
     def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
     def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.terms = [Signal(output_width, name=f"inputs_{i}",
+        self.terms = [Signal(output_width, name=f"terms_{i}",
                               reset_less=True)
                         for i in range(n_inputs)]
         self.part_pts = part_pts.like()
                               reset_less=True)
                         for i in range(n_inputs)]
         self.part_pts = part_pts.like()
@@ -342,21 +344,31 @@ class FinalReduceData:
         return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
         return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
     """ Final stage of add reduce
     """
 
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points):
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.output_width = pspec.width * 2
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
+        self.n_parts = pspec.n_parts
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
+        super().__init__(pspec, "finaladd")
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return FinalReduceData(self.partition_points,
+                                 self.output_width, self.n_parts)
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -373,7 +385,7 @@ class FinalAdd(Elaboratable):
             # base case for adding 2 inputs
             assert self.n_inputs == 2
             adder = PartitionedAdder(output_width,
             # base case for adding 2 inputs
             assert self.n_inputs == 2
             adder = PartitionedAdder(output_width,
-                                     self.i.part_pts, 2)
+                                     self.i.part_pts, self.partition_step)
             m.submodules.final_adder = adder
             m.d.comb += adder.a.eq(self.i.terms[0])
             m.d.comb += adder.b.eq(self.i.terms[1])
             m.submodules.final_adder = adder
             m.d.comb += adder.a.eq(self.i.terms[0])
             m.d.comb += adder.b.eq(self.i.terms[1])
@@ -386,7 +398,7 @@ class FinalAdd(Elaboratable):
         return m
 
 
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -398,25 +410,35 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points):
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
+        self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
             raise ValueError("partition_points doesn't fit in output_width")
 
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
-        n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
-        self.o = AddReduceData(partition_points, n_terms, output_width, n_parts)
+        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+
+        super().__init__(pspec, "addreduce_%d" % lidx)
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return AddReduceData(self.partition_points, self.n_terms,
+                             self.output_width, self.n_parts)
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
@@ -500,7 +522,8 @@ class AddReduceSingle(Elaboratable):
         part_mask = Signal(self.output_width, reset_less=True)
 
         # get partition points as a mask
         part_mask = Signal(self.output_width, reset_less=True)
 
         # get partition points as a mask
-        mask = self.i.part_pts.as_mask(self.output_width, mul=2)
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
@@ -516,7 +539,7 @@ class AddReduceSingle(Elaboratable):
 
 
 class AddReduceInternal:
 
 
 class AddReduceInternal:
-    """Recursively Add list of numbers together.
+    """Iteratively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
         supported, except for by ``Signal.eq``.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
         supported, except for by ``Signal.eq``.
@@ -527,18 +550,18 @@ class AddReduceInternal:
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, partition_points,
-                       part_ops):
+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
-        self.inputs = inputs
-        self.part_ops = part_ops
-        self.output_width = output_width
-        self.partition_points = partition_points
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
+        self.partition_step = partition_step
 
         self.create_levels()
 
 
         self.create_levels()
 
@@ -547,24 +570,22 @@ class AddReduceInternal:
 
         mods = []
         partition_points = self.partition_points
 
         mods = []
         partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
+        ilen = self.n_inputs
         while True:
         while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
+            groups = AddReduceSingle.full_adder_groups(ilen)
             if len(groups) == 0:
                 break
             if len(groups) == 0:
                 break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         partition_points)
+            lidx = len(mods)
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
+                                         partition_points,
+                                         self.partition_step)
             mods.append(next_level)
             partition_points = next_level.i.part_pts
             mods.append(next_level)
             partition_points = next_level.i.part_pts
-            inputs = next_level.o.terms
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
+            ilen = len(next_level.o.terms)
 
 
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              partition_points)
+        lidx = len(mods)
+        next_level = FinalAdd(self.pspec, lidx, ilen,
+                              partition_points, self.partition_step)
         mods.append(next_level)
 
         self.levels = mods
         mods.append(next_level)
 
         self.levels = mods
@@ -582,8 +603,8 @@ class AddReduce(AddReduceInternal, Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, inputs, output_width, register_levels, part_pts,
+                       part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -592,10 +613,15 @@ class AddReduce(AddReduceInternal, Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        AddReduceInternal.__init__(self, inputs, output_width,
-                                   partition_points, part_ops)
+        self._inputs = inputs
+        self._part_pts = part_pts
+        self._part_ops = part_ops
         n_parts = len(part_ops)
         n_parts = len(part_ops)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
+        self.i = AddReduceData(part_pts, len(inputs),
+                             output_width, n_parts)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+        self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
 
     @staticmethod
         self.register_levels = register_levels
 
     @staticmethod
@@ -609,48 +635,16 @@ class AddReduce(AddReduceInternal, Elaboratable):
             if level > 0:
                 yield level - 1
 
             if level > 0:
                 yield level - 1
 
-    def create_levels(self):
-        """creates reduction levels"""
-
-        mods = []
-        partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
-        while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
-            if len(groups) == 0:
-                break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         partition_points)
-            mods.append(next_level)
-            partition_points = next_level.i.part_pts
-            inputs = next_level.o.terms
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
-
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              partition_points)
-        mods.append(next_level)
-
-        self.levels = mods
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
-        partition_points = self.partition_points
-        inputs = self.inputs
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        n_inputs = len(inputs)
-        output_width = self.output_width
-        i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
-        m.d.comb += i.eq_from(partition_points, inputs, part_ops)
+        i = self.i
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
             if idx in self.register_levels:
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
             if idx in self.register_levels:
@@ -726,8 +720,8 @@ class ProductTerm(Elaboratable):
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
-        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
@@ -741,8 +735,8 @@ class ProductTerm(Elaboratable):
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
-        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
@@ -876,7 +870,7 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
-    def __init__(self, part_pts, width, n_parts, n_levels, pbwid):
+    def __init__(self, part_pts, width, n_parts, pbwid):
 
         self.pbwid = pbwid
         self.part_pts = part_pts
 
         self.pbwid = pbwid
         self.part_pts = part_pts
@@ -922,7 +916,7 @@ class Part(Elaboratable):
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
-            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
@@ -932,7 +926,7 @@ class Part(Elaboratable):
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
-            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
@@ -970,28 +964,35 @@ class IntermediateOut(Elaboratable):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
-                    self.intermed.part(i * w*2, w),
-                    self.intermed.part(i * w*2 + w, w)))
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, output_width, n_parts, part_pts):
+    def __init__(self, pspec, part_pts):
+
         self.part_pts = part_pts
         self.part_pts = part_pts
-        self.i = IntermediateData(part_pts, output_width, n_parts)
-        self.out_wid = output_width//2
-        # output
-        self.out = Signal(self.out_wid, reset_less=True)
-        self.intermediate_output = Signal(output_width, reset_less=True)
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
+
+        super().__init__(pspec, "finalout")
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -1040,11 +1041,16 @@ class FinalOut(Elaboratable):
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
-                    Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)),
-                    Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8))))
+                    Mux(d8[i], i8.bit_select(i * 8, 8),
+                               i16.bit_select(i * 8, 8)),
+                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+                                      i64.bit_select(i * 8, 8))))
             ol.append(op)
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
-        m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
         return m
 
 
@@ -1126,7 +1132,7 @@ class InputData:
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
-    def eq_from(self, part_pts, inputs, part_ops):
+    def eq_from(self, part_pts, a, b, part_ops):
         return [self.part_pts.eq(part_pts)] + \
                [self.a.eq(a), self.b.eq(b)] + \
                [self.part_ops[i].eq(part_ops[i])
         return [self.part_pts.eq(part_pts)] + \
                [self.a.eq(a), self.b.eq(b)] + \
                [self.part_ops[i].eq(part_ops[i])
@@ -1136,26 +1142,35 @@ class InputData:
         return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
 
 
         return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
 
 
-class AllTerms(Elaboratable):
+class OutputData:
+
+    def __init__(self):
+        self.intermediate_output = Signal(128) # needed for unit tests
+        self.output = Signal(64)
+
+    def eq(self, rhs):
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
+
+
+class AllTerms(PipeModBase):
     """Set of terms to be added together
     """
 
     """Set of terms to be added together
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels):
-        """Create an ``AddReduce``.
-
-        :param inputs: input ``Signal``s to be summed.
-        :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
-        :param partition_points: the input partition points.
+    def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
         """
         """
-        self.i = InputData()
-        self.register_levels = register_levels
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.o = AddReduceData(self.i.part_pts, n_inputs,
-                               output_width, n_parts)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
+        super().__init__(pspec, "allterms")
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -1179,11 +1194,10 @@ class AllTerms(Elaboratable):
             setattr(m.submodules, "signs%d" % i, s)
             m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
             setattr(m.submodules, "signs%d" % i, s)
             m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
             m.d.comb += mod.a.eq(self.i.a)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
             m.d.comb += mod.a.eq(self.i.a)
@@ -1236,13 +1250,22 @@ class AllTerms(Elaboratable):
         return m
 
 
         return m
 
 
-class Intermediates(Elaboratable):
+class Intermediates(PipeModBase):
     """ Intermediate output modules
     """
 
     """ Intermediate output modules
     """
 
-    def __init__(self, output_width, n_parts, partition_points):
-        self.i = FinalReduceData(partition_points, output_width, n_parts)
-        self.o = IntermediateData(partition_points, output_width, n_parts)
+    def __init__(self, pspec, part_pts):
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+
+        super().__init__(pspec, "intermediates")
+
+    def ispec(self):
+        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -1319,21 +1342,32 @@ class Mul8_16_32_64(Elaboratable):
             flip-flops are to be inserted.
         """
 
             flip-flops are to be inserted.
         """
 
+        self.id_wid = 0 # num_bits(num_rows)
+        self.op_wid = 0
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+        self.pspec.n_parts = 8
+
         # parameter(s)
         self.register_levels = list(register_levels)
 
         # parameter(s)
         self.register_levels = list(register_levels)
 
+        self.i = self.ispec()
+        self.o = self.ospec()
+
         # inputs
         # inputs
-        self.i = InputData()
         self.part_pts = self.i.part_pts
         self.part_ops = self.i.part_ops
         self.a = self.i.a
         self.b = self.i.b
 
         self.part_pts = self.i.part_pts
         self.part_ops = self.i.part_ops
         self.a = self.i.a
         self.b = self.i.b
 
-        # intermediates (needed for unit tests)
-        self.intermediate_output = Signal(128)
-
         # output
         # output
-        self.output = Signal(64)
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
@@ -1341,37 +1375,32 @@ class Mul8_16_32_64(Elaboratable):
         part_pts = self.part_pts
 
         n_inputs = 64 + 4
         part_pts = self.part_pts
 
         n_inputs = 64 + 4
-        n_parts = 8 #len(self.part_pts)
-        t = AllTerms(n_inputs, 128, n_parts, self.register_levels)
-        m.submodules.allterms = t
-        m.d.comb += t.i.a.eq(self.a)
-        m.d.comb += t.i.b.eq(self.b)
-        m.d.comb += t.i.part_pts.eq(part_pts)
-        for i in range(8):
-            m.d.comb += t.i.part_ops[i].eq(self.part_ops[i])
+        t = AllTerms(self.pspec, n_inputs)
+        t.setup(m, self.i)
 
         terms = t.o.terms
 
 
         terms = t.o.terms
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               t.o.part_pts,
-                               t.o.part_ops)
-
-        out_part_ops = add_reduce.o.part_ops
-        out_part_pts = add_reduce.o.part_pts
+        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
 
 
-        m.submodules.add_reduce = add_reduce
+        i = t.o
+        for idx in range(len(at.levels)):
+            mcur = at.levels[idx]
+            mcur.setup(m, i)
+            o = mcur.ospec()
+            if idx in self.register_levels:
+                m.d.sync += o.eq(mcur.process(i))
+            else:
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
 
 
-        interm = Intermediates(128, 8, part_pts)
-        m.submodules.intermediates = interm
-        m.d.comb += interm.i.eq(add_reduce.o)
+        interm = Intermediates(self.pspec, part_pts)
+        interm.setup(m, i)
+        o = interm.process(interm.i)
 
         # final output
 
         # final output
-        m.submodules.finalout = finalout = FinalOut(128, 8, part_pts)
-        m.d.comb += finalout.i.eq(interm.o)
-        m.d.comb += self.output.eq(finalout.out)
-        m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)
+        finalout = FinalOut(self.pspec, part_pts)
+        finalout.setup(m, o)
+        m.d.comb += self.o.eq(finalout.process(o))
 
         return m
 
 
         return m