rename inputs_ to terms_
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 7b8782e220797f43db32f2fdae442c90bf464980..89e4c4d322782ff1e16b2e88646405d6dcce0bd3 100644 (file)
@@ -8,6 +8,8 @@ from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -71,17 +73,20 @@ class PartitionPoints(dict):
         for point, enabled in self.items():
             yield enabled.eq(rhs[point])
 
         for point, enabled in self.items():
             yield enabled.eq(rhs[point])
 
-    def as_mask(self, width):
+    def as_mask(self, width, mul=1):
         """Create a bit-mask from `self`.
 
         Each bit in the returned mask is clear only if the partition point at
         the same bit-index is enabled.
 
         :param width: the bit width of the resulting mask
         """Create a bit-mask from `self`.
 
         Each bit in the returned mask is clear only if the partition point at
         the same bit-index is enabled.
 
         :param width: the bit width of the resulting mask
+        :param mul: a "multiplier" which in-place expands the partition points
+                    typically set to "2" when used for multipliers
         """
         bits = []
         for i in range(width):
         """
         bits = []
         for i in range(width):
-            if i in self:
+            i /= mul
+            if i.is_integer() and int(i) in self:
                 bits.append(~self[i])
             else:
                 bits.append(True)
                 bits.append(~self[i])
             else:
                 bits.append(True)
@@ -227,13 +232,16 @@ class PartitionedAdder(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, width, partition_points):
+    def __init__(self, width, partition_points, partition_step=1):
         """Create a ``PartitionedAdder``.
 
         :param width: the bit width of the input and output
         :param partition_points: the input partition points
         """Create a ``PartitionedAdder``.
 
         :param width: the bit width of the input and output
         :param partition_points: the input partition points
+        :param partition_step: a multiplier (typically double) step
+                               which in-place "expands" the partition points
         """
         self.width = width
         """
         self.width = width
+        self.pmul = partition_step
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
         self.output = Signal(width, reset_less=True)
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
         self.output = Signal(width, reset_less=True)
@@ -267,11 +275,12 @@ class PartitionedAdder(Elaboratable):
         # carry has been carried *over* the break point.
 
         for i in range(self.width):
         # carry has been carried *over* the break point.
 
         for i in range(self.width):
-            if i in self.partition_points:
+            pi = i/self.pmul # double the range of the partition point test
+            if pi.is_integer() and pi in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(expanded_a[expanded_index])
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(expanded_a[expanded_index])
-                al.append(~self.partition_points[i]) # add extra bit in a
+                al.append(~self.partition_points[pi]) # add extra bit in a
                 eb.append(expanded_b[expanded_index])
                 bl.append(C(0)) # yes, add a zero
                 expanded_index += 1 # skip the extra point.  NOT in the output
                 eb.append(expanded_b[expanded_index])
                 bl.append(C(0)) # yes, add a zero
                 expanded_index += 1 # skip the extra point.  NOT in the output
@@ -298,60 +307,68 @@ FULL_ADDER_INPUT_COUNT = 3
 
 class AddReduceData:
 
 
 class AddReduceData:
 
-    def __init__(self, ppoints, n_inputs, output_width, n_parts):
+    def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.inputs = [Signal(output_width, name=f"inputs_{i}",
+        self.terms = [Signal(output_width, name=f"terms_{i}",
                               reset_less=True)
                         for i in range(n_inputs)]
                               reset_less=True)
                         for i in range(n_inputs)]
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
 
 
-    def eq_from(self, reg_partition_points, inputs, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
-               [self.inputs[i].eq(inputs[i])
-                                     for i in range(len(self.inputs))] + \
+    def eq_from(self, part_pts, inputs, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.terms[i].eq(inputs[i])
+                                     for i in range(len(self.terms))] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.inputs, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
 
 
 class FinalReduceData:
 
 
 
 class FinalReduceData:
 
-    def __init__(self, ppoints, output_width, n_parts):
+    def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
 
 
-    def eq_from(self, reg_partition_points, output, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
+    def eq_from(self, part_pts, output, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.output, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
 
 
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
     """ Final stage of add reduce
     """
 
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.output_width = pspec.width * 2
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
+        super().__init__(pspec, "finaladd")
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return FinalReduceData(self.partition_points,
+                                 self.output_width, self.n_parts)
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -363,24 +380,25 @@ class FinalAdd(Elaboratable):
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
-            m.d.comb += output.eq(self.i.inputs[0])
+            m.d.comb += output.eq(self.i.terms[0])
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
-            adder = PartitionedAdder(output_width, self.i.reg_partition_points)
+            adder = PartitionedAdder(output_width,
+                                     self.i.part_pts, self.partition_step)
             m.submodules.final_adder = adder
             m.submodules.final_adder = adder
-            m.d.comb += adder.a.eq(self.i.inputs[0])
-            m.d.comb += adder.b.eq(self.i.inputs[1])
+            m.d.comb += adder.a.eq(self.i.terms[0])
+            m.d.comb += adder.b.eq(self.i.terms[1])
             m.d.comb += output.eq(adder.output)
 
         # create output
             m.d.comb += output.eq(adder.output)
 
         # create output
-        m.d.comb += self.o.eq_from(self.i.reg_partition_points, output,
+        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                    self.i.part_ops)
 
         return m
 
 
                                    self.i.part_ops)
 
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -392,35 +410,35 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
         :param partition_points: the input partition points.
         """
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
+        self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
-        max_level = AddReduceSingle.get_max_level(n_inputs)
-        for level in self.register_levels:
-            if level > max_level:
-                raise ValueError(
-                    "not enough adder levels for specified register levels")
-
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
-        n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
-        self.o = AddReduceData(partition_points, n_terms, output_width, n_parts)
+        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+
+        super().__init__(pspec, "addreduce_%d" % lidx)
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return AddReduceData(self.partition_points, self.n_terms,
+                             self.output_width, self.n_parts)
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
@@ -473,13 +491,13 @@ class AddReduceSingle(Elaboratable):
             terms.append(adder_i.mcarry)
         # handle the remaining inputs.
         if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
             terms.append(adder_i.mcarry)
         # handle the remaining inputs.
         if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
-            terms.append(self.i.inputs[-1])
+            terms.append(self.i.terms[-1])
         elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
             # Just pass the terms to the next layer, since we wouldn't gain
             # anything by using a half adder since there would still be 2 terms
             # and just passing the terms to the next layer saves gates.
         elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
             # Just pass the terms to the next layer, since we wouldn't gain
             # anything by using a half adder since there would still be 2 terms
             # and just passing the terms to the next layer saves gates.
-            terms.append(self.i.inputs[-2])
-            terms.append(self.i.inputs[-1])
+            terms.append(self.i.terms[-2])
+            terms.append(self.i.terms[-1])
         else:
             assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
 
         else:
             assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
 
@@ -493,32 +511,87 @@ class AddReduceSingle(Elaboratable):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
-            m.d.comb += self.o.inputs[i].eq(value)
+            m.d.comb += self.o.terms[i].eq(value)
 
         # copy reg part points and part ops to output
 
         # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(self.i.reg_partition_points)
+        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
 
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
 
-        mask = self.i.reg_partition_points.as_mask(self.output_width)
+        # get partition points as a mask
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
-            m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
-            m.d.comb += adder_i.in1.eq(self.i.inputs[iidx + 1])
-            m.d.comb += adder_i.in2.eq(self.i.inputs[iidx + 2])
+            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
+            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
 
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
 
-class AddReduce(Elaboratable):
+class AddReduceInternal:
+    """Iteratively Add list of numbers together.
+
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param partition_points: the input partition points.
+        """
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
+        self.partition_step = partition_step
+
+        self.create_levels()
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        partition_points = self.partition_points
+        ilen = self.n_inputs
+        while True:
+            groups = AddReduceSingle.full_adder_groups(ilen)
+            if len(groups) == 0:
+                break
+            lidx = len(mods)
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
+                                         partition_points,
+                                         self.partition_step)
+            mods.append(next_level)
+            partition_points = next_level.i.part_pts
+            ilen = len(next_level.o.terms)
+
+        lidx = len(mods)
+        next_level = FinalAdd(self.pspec, lidx, ilen,
+                              partition_points, self.partition_step)
+        mods.append(next_level)
+
+        self.levels = mods
+
+
+class AddReduce(AddReduceInternal, Elaboratable):
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -530,8 +603,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, inputs, output_width, register_levels, part_pts,
+                       part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -540,15 +613,16 @@ class AddReduce(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.inputs = inputs
-        self.part_ops = part_ops
+        self._inputs = inputs
+        self._part_pts = part_pts
+        self._part_ops = part_ops
         n_parts = len(part_ops)
         n_parts = len(part_ops)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+        self.i = AddReduceData(part_pts, len(inputs),
+                             output_width, n_parts)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+        self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
         self.register_levels = register_levels
-        self.partition_points = partition_points
-
-        self.create_levels()
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -561,53 +635,19 @@ class AddReduce(Elaboratable):
             if level > 0:
                 yield level - 1
 
             if level > 0:
                 yield level - 1
 
-    def create_levels(self):
-        """creates reduction levels"""
-
-        mods = []
-        next_levels = self.register_levels
-        partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
-        while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
-            if len(groups) == 0:
-                break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         next_levels, partition_points)
-            mods.append(next_level)
-            next_levels = list(AddReduce.next_register_levels(next_levels))
-            partition_points = next_level.i.reg_partition_points
-            inputs = next_level.o.inputs
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
-
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              next_levels, partition_points)
-        mods.append(next_level)
-
-        self.levels = mods
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
-        partition_points = self.partition_points
-        inputs = self.inputs
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        n_inputs = len(inputs)
-        output_width = self.output_width
-        i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
-        m.d.comb += i.eq_from(partition_points, inputs, part_ops)
+        i = self.i
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
-            if 0 in mcur.register_levels:
+            if idx in self.register_levels:
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
@@ -680,8 +720,8 @@ class ProductTerm(Elaboratable):
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
-        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
@@ -695,8 +735,8 @@ class ProductTerm(Elaboratable):
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
-        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
@@ -777,10 +817,10 @@ class LSBNegTerm(Elaboratable):
 
 class Parts(Elaboratable):
 
 
 class Parts(Elaboratable):
 
-    def __init__(self, pbwid, epps, n_parts):
+    def __init__(self, pbwid, part_pts, n_parts):
         self.pbwid = pbwid
         # inputs
         self.pbwid = pbwid
         # inputs
-        self.epps = PartitionPoints.like(epps, name="epps") # expanded points
+        self.part_pts = PartitionPoints.like(part_pts)
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
@@ -788,13 +828,13 @@ class Parts(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
-        epps, parts = self.epps, self.parts
+        part_pts, parts = self.part_pts, self.parts
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
+            m.d.comb += pb.eq(part_pts.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -830,10 +870,10 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
-    def __init__(self, epps, width, n_parts, n_levels, pbwid):
+    def __init__(self, part_pts, width, n_parts, pbwid):
 
         self.pbwid = pbwid
 
         self.pbwid = pbwid
-        self.epps = epps
+        self.part_pts = part_pts
 
         # inputs
         self.a = Signal(64, reset_less=True)
 
         # inputs
         self.a = Signal(64, reset_less=True)
@@ -857,9 +897,9 @@ class Part(Elaboratable):
         m = Module()
 
         pbs, parts = self.pbs, self.parts
         m = Module()
 
         pbs, parts = self.pbs, self.parts
-        epps = self.epps
-        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
-        m.d.comb += p.epps.eq(epps)
+        part_pts = self.part_pts
+        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
+        m.d.comb += p.part_pts.eq(part_pts)
         parts = p.parts
 
         byte_count = 8 // len(parts)
         parts = p.parts
 
         byte_count = 8 // len(parts)
@@ -876,7 +916,7 @@ class Part(Elaboratable):
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
-            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
@@ -886,7 +926,7 @@ class Part(Elaboratable):
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
-            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
@@ -924,39 +964,46 @@ class IntermediateOut(Elaboratable):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
-                    self.intermed.part(i * w*2, w),
-                    self.intermed.part(i * w*2 + w, w)))
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, output_width, n_parts, partition_points):
-        self.expanded_part_points = partition_points
-        self.i = IntermediateData(partition_points, output_width, n_parts)
-        self.out_wid = output_width//2
-        # output
-        self.out = Signal(self.out_wid, reset_less=True)
-        self.intermediate_output = Signal(output_width, reset_less=True)
+    def __init__(self, pspec, part_pts):
+
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
+
+        super().__init__(pspec, "finalout")
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        eps = self.expanded_part_points
-        m.submodules.p_8 = p_8 = Parts(8, eps, 8)
-        m.submodules.p_16 = p_16 = Parts(8, eps, 4)
-        m.submodules.p_32 = p_32 = Parts(8, eps, 2)
-        m.submodules.p_64 = p_64 = Parts(8, eps, 1)
+        part_pts = self.part_pts
+        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
+        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
+        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
+        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
 
 
-        out_part_pts = self.i.reg_partition_points
+        out_part_pts = self.i.part_pts
 
         # temporaries
         d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
 
         # temporaries
         d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
@@ -968,10 +1015,10 @@ class FinalOut(Elaboratable):
         i32 = Signal(self.out_wid, reset_less=True)
         i64 = Signal(self.out_wid, reset_less=True)
 
         i32 = Signal(self.out_wid, reset_less=True)
         i64 = Signal(self.out_wid, reset_less=True)
 
-        m.d.comb += p_8.epps.eq(out_part_pts)
-        m.d.comb += p_16.epps.eq(out_part_pts)
-        m.d.comb += p_32.epps.eq(out_part_pts)
-        m.d.comb += p_64.epps.eq(out_part_pts)
+        m.d.comb += p_8.part_pts.eq(out_part_pts)
+        m.d.comb += p_16.part_pts.eq(out_part_pts)
+        m.d.comb += p_32.part_pts.eq(out_part_pts)
+        m.d.comb += p_64.part_pts.eq(out_part_pts)
 
         for i in range(len(p_8.parts)):
             m.d.comb += d8[i].eq(p_8.parts[i])
 
         for i in range(len(p_8.parts)):
             m.d.comb += d8[i].eq(p_8.parts[i])
@@ -994,11 +1041,16 @@ class FinalOut(Elaboratable):
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
-                    Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)),
-                    Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8))))
+                    Mux(d8[i], i8.bit_select(i * 8, 8),
+                               i16.bit_select(i * 8, 8)),
+                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+                                      i64.bit_select(i * 8, 8))))
             ol.append(op)
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
-        m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
         return m
 
 
@@ -1047,18 +1099,18 @@ class Signs(Elaboratable):
 
 class IntermediateData:
 
 
 class IntermediateData:
 
-    def __init__(self, ppoints, output_width, n_parts):
+    def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
         self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                           for i in range(4)]
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(output_width)
 
         self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                           for i in range(4)]
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(output_width)
 
-    def eq_from(self, reg_partition_points, outputs, intermediate_output,
+    def eq_from(self, part_pts, outputs, intermediate_output,
                       part_ops):
                       part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
+        return [self.part_pts.eq(part_pts)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
                [self.outputs[i].eq(outputs[i])
                                      for i in range(4)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
                [self.outputs[i].eq(outputs[i])
                                      for i in range(4)] + \
@@ -1066,44 +1118,73 @@ class IntermediateData:
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.outputs,
+        return self.eq_from(rhs.part_pts, rhs.outputs,
                             rhs.intermediate_output, rhs.part_ops)
 
 
                             rhs.intermediate_output, rhs.part_ops)
 
 
-class AllTerms(Elaboratable):
+class InputData:
+
+    def __init__(self):
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self.part_pts = PartitionPoints()
+        for i in range(8, 64, 8):
+            self.part_pts[i] = Signal(name=f"part_pts_{i}")
+        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+
+    def eq_from(self, part_pts, a, b, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.a.eq(a), self.b.eq(b)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
+
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
+
+
+class OutputData:
+
+    def __init__(self):
+        self.intermediate_output = Signal(128) # needed for unit tests
+        self.output = Signal(64)
+
+    def eq(self, rhs):
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
+
+
+class AllTerms(PipeModBase):
     """Set of terms to be added together
     """
 
     """Set of terms to be added together
     """
 
-    def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        """Create an ``AddReduce``.
-
-        :param inputs: input ``Signal``s to be summed.
-        :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
-        :param partition_points: the input partition points.
+    def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
         """
         """
-        self.epps = partition_points.like()
-        self.register_levels = register_levels
-        self.pbwid = pbwid
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.o = AddReduceData(self.epps, n_inputs,
-                               output_width, n_parts)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
+        super().__init__(pspec, "allterms")
 
 
-        self.a = Signal(64)
-        self.b = Signal(64)
+    def ispec(self):
+        return InputData()
 
 
-        self.pbs = Signal(pbwid, reset_less=True)
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+    def ospec(self):
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        pbs = self.pbs
-        eps = self.epps
+        eps = self.i.part_pts
+
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(eps.part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
 
         # local variables
         signs = []
 
         # local variables
         signs = []
@@ -1111,17 +1192,16 @@ class AllTerms(Elaboratable):
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
-            m.d.comb += s.part_ops.eq(self.part_ops[i])
+            m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
 
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
-            m.d.comb += mod.a.eq(self.a)
-            m.d.comb += mod.b.eq(self.b)
+            m.d.comb += mod.a.eq(self.i.a)
+            m.d.comb += mod.b.eq(self.i.b)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
@@ -1137,8 +1217,8 @@ class AllTerms(Elaboratable):
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
-            m.d.comb += t.a.eq(self.a)
-            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.a.eq(self.i.a)
+            m.d.comb += t.b.eq(self.i.b)
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
@@ -1160,29 +1240,38 @@ class AllTerms(Elaboratable):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
-            m.d.comb += self.o.inputs[i].eq(value)
+            m.d.comb += self.o.terms[i].eq(value)
 
         # copy reg part points and part ops to output
 
         # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(eps)
-        m.d.comb += [self.o.part_ops[i].eq(self.part_ops[i])
-                                     for i in range(len(self.part_ops))]
+        m.d.comb += self.o.part_pts.eq(eps)
+        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+                                     for i in range(len(self.i.part_ops))]
 
         return m
 
 
 
         return m
 
 
-class Intermediates(Elaboratable):
+class Intermediates(PipeModBase):
     """ Intermediate output modules
     """
 
     """ Intermediate output modules
     """
 
-    def __init__(self, output_width, n_parts, partition_points):
-        self.i = FinalReduceData(partition_points, output_width, n_parts)
-        self.o = IntermediateData(partition_points, output_width, n_parts)
+    def __init__(self, pspec, part_pts):
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+
+        super().__init__(pspec, "intermediates")
+
+    def ispec(self):
+        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
         out_part_ops = self.i.part_ops
 
     def elaborate(self, platform):
         m = Module()
 
         out_part_ops = self.i.part_ops
-        out_part_pts = self.i.reg_partition_points
+        out_part_pts = self.i.part_pts
 
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
 
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
@@ -1214,7 +1303,7 @@ class Intermediates(Elaboratable):
 
         for i in range(8):
             m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
 
         for i in range(8):
             m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
-        m.d.comb += self.o.reg_partition_points.eq(out_part_pts)
+        m.d.comb += self.o.part_pts.eq(out_part_pts)
         m.d.comb += self.o.intermediate_output.eq(self.i.output)
 
         return m
         m.d.comb += self.o.intermediate_output.eq(self.i.output)
 
         return m
@@ -1253,76 +1342,65 @@ class Mul8_16_32_64(Elaboratable):
             flip-flops are to be inserted.
         """
 
             flip-flops are to be inserted.
         """
 
+        self.id_wid = 0 # num_bits(num_rows)
+        self.op_wid = 0
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+        self.pspec.n_parts = 8
+
         # parameter(s)
         self.register_levels = list(register_levels)
 
         # parameter(s)
         self.register_levels = list(register_levels)
 
-        # inputs
-        self.part_pts = PartitionPoints()
-        for i in range(8, 64, 8):
-            self.part_pts[i] = Signal(name=f"part_pts_{i}")
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-        self.a = Signal(64)
-        self.b = Signal(64)
+        self.i = self.ispec()
+        self.o = self.ospec()
 
 
-        # intermediates (needed for unit tests)
-        self.intermediate_output = Signal(128)
+        # inputs
+        self.part_pts = self.i.part_pts
+        self.part_ops = self.i.part_ops
+        self.a = self.i.a
+        self.b = self.i.b
 
         # output
 
         # output
-        self.output = Signal(64)
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        # collect part-bytes
-        pbs = Signal(8, reset_less=True)
-        tl = []
-        for i in range(8):
-            pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(self.part_pts.part_byte(i))
-            tl.append(pb)
-        m.d.comb += pbs.eq(Cat(*tl))
-
-        # create (doubled) PartitionPoints (output is double input width)
-        expanded_part_pts = eps = PartitionPoints()
-        for i, v in self.part_pts.items():
-            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
-            expanded_part_pts[i * 2] = ep
-            m.d.comb += ep.eq(v)
+        part_pts = self.part_pts
 
         n_inputs = 64 + 4
 
         n_inputs = 64 + 4
-        n_parts = 8 #len(self.part_pts)
-        t = AllTerms(8, n_inputs, 128, n_parts, self.register_levels,
-                       eps)
-        m.submodules.allterms = t
-        m.d.comb += t.a.eq(self.a)
-        m.d.comb += t.b.eq(self.b)
-        m.d.comb += t.pbs.eq(pbs)
-        m.d.comb += t.epps.eq(eps)
-        for i in range(8):
-            m.d.comb += t.part_ops[i].eq(self.part_ops[i])
+        t = AllTerms(self.pspec, n_inputs)
+        t.setup(m, self.i)
 
 
-        terms = t.o.inputs
+        terms = t.o.terms
 
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               t.o.reg_partition_points,
-                               t.o.part_ops)
+        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
 
 
-        out_part_ops = add_reduce.o.part_ops
-        out_part_pts = add_reduce.o.reg_partition_points
-
-        m.submodules.add_reduce = add_reduce
+        i = t.o
+        for idx in range(len(at.levels)):
+            mcur = at.levels[idx]
+            mcur.setup(m, i)
+            o = mcur.ospec()
+            if idx in self.register_levels:
+                m.d.sync += o.eq(mcur.process(i))
+            else:
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
 
 
-        interm = Intermediates(128, 8, expanded_part_pts)
-        m.submodules.intermediates = interm
-        m.d.comb += interm.i.eq(add_reduce.o)
+        interm = Intermediates(self.pspec, part_pts)
+        interm.setup(m, i)
+        o = interm.process(interm.i)
 
         # final output
 
         # final output
-        m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts)
-        m.d.comb += finalout.i.eq(interm.o)
-        m.d.comb += self.output.eq(finalout.out)
-        m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)
+        finalout = FinalOut(self.pspec, part_pts)
+        finalout.setup(m, o)
+        m.d.comb += self.o.eq(finalout.process(o))
 
         return m
 
 
         return m