add docstring Mul8_16_32_64 only for testing
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 5ffd608fefc697c1a7e9f95a3be92d26ddfa4c8c..215d18c6a1aacce049dd74c6331437c2e74e5853 100644 (file)
@@ -8,6 +8,8 @@ from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -50,15 +52,17 @@ class PartitionPoints(dict):
                     raise ValueError("point must be a non-negative integer")
                 self[point] = Value.wrap(enabled)
 
                     raise ValueError("point must be a non-negative integer")
                 self[point] = Value.wrap(enabled)
 
-    def like(self, name=None, src_loc_at=0):
+    def like(self, name=None, src_loc_at=0, mul=1):
         """Create a new ``PartitionPoints`` with ``Signal``s for all values.
 
         :param name: the base name for the new ``Signal``s.
         """Create a new ``PartitionPoints`` with ``Signal``s for all values.
 
         :param name: the base name for the new ``Signal``s.
+        :param mul: a multiplication factor on the indices
         """
         if name is None:
             name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
         retval = PartitionPoints()
         for point, enabled in self.items():
         """
         if name is None:
             name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
         retval = PartitionPoints()
         for point, enabled in self.items():
+            point *= mul
             retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
         return retval
 
             retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
         return retval
 
@@ -69,17 +73,20 @@ class PartitionPoints(dict):
         for point, enabled in self.items():
             yield enabled.eq(rhs[point])
 
         for point, enabled in self.items():
             yield enabled.eq(rhs[point])
 
-    def as_mask(self, width):
+    def as_mask(self, width, mul=1):
         """Create a bit-mask from `self`.
 
         Each bit in the returned mask is clear only if the partition point at
         the same bit-index is enabled.
 
         :param width: the bit width of the resulting mask
         """Create a bit-mask from `self`.
 
         Each bit in the returned mask is clear only if the partition point at
         the same bit-index is enabled.
 
         :param width: the bit width of the resulting mask
+        :param mul: a "multiplier" which in-place expands the partition points
+                    typically set to "2" when used for multipliers
         """
         bits = []
         for i in range(width):
         """
         bits = []
         for i in range(width):
-            if i in self:
+            i /= mul
+            if i.is_integer() and int(i) in self:
                 bits.append(~self[i])
             else:
                 bits.append(True)
                 bits.append(~self[i])
             else:
                 bits.append(True)
@@ -103,6 +110,12 @@ class PartitionPoints(dict):
                 return False
         return True
 
                 return False
         return True
 
+    def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
+        if index == -1 or index == 7:
+            return C(True, 1)
+        assert index >= 0 and index < 8
+        return self[(index * 8 + 8)*mfactor]
+
 
 class FullAdder(Elaboratable):
     """Full Adder.
 
 class FullAdder(Elaboratable):
     """Full Adder.
@@ -124,11 +137,11 @@ class FullAdder(Elaboratable):
 
         :param width: the bit width of the input and output
         """
 
         :param width: the bit width of the input and output
         """
-        self.in0 = Signal(width)
-        self.in1 = Signal(width)
-        self.in2 = Signal(width)
-        self.sum = Signal(width)
-        self.carry = Signal(width)
+        self.in0 = Signal(width, reset_less=True)
+        self.in1 = Signal(width, reset_less=True)
+        self.in2 = Signal(width, reset_less=True)
+        self.sum = Signal(width, reset_less=True)
+        self.carry = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         """Elaborate this module."""
 
     def elaborate(self, platform):
         """Elaborate this module."""
@@ -219,16 +232,19 @@ class PartitionedAdder(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, width, partition_points):
+    def __init__(self, width, partition_points, partition_step=1):
         """Create a ``PartitionedAdder``.
 
         :param width: the bit width of the input and output
         :param partition_points: the input partition points
         """Create a ``PartitionedAdder``.
 
         :param width: the bit width of the input and output
         :param partition_points: the input partition points
+        :param partition_step: a multiplier (typically double) step
+                               which in-place "expands" the partition points
         """
         self.width = width
         """
         self.width = width
-        self.a = Signal(width)
-        self.b = Signal(width)
-        self.output = Signal(width)
+        self.pmul = partition_step
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        self.output = Signal(width, reset_less=True)
         self.partition_points = PartitionPoints(partition_points)
         if not self.partition_points.fits_in_width(width):
             raise ValueError("partition_points doesn't fit in width")
         self.partition_points = PartitionPoints(partition_points)
         if not self.partition_points.fits_in_width(width):
             raise ValueError("partition_points doesn't fit in width")
@@ -238,17 +254,14 @@ class PartitionedAdder(Elaboratable):
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
-        # XXX these have to remain here due to some horrible nmigen
-        # simulation bugs involving sync.  it is *not* necessary to
-        # have them here, they should (under normal circumstances)
-        # be moved into elaborate, as they are entirely local
-        self._expanded_a = Signal(expanded_width) # includes extra part-points
-        self._expanded_b = Signal(expanded_width) # likewise.
-        self._expanded_o = Signal(expanded_width) # likewise.
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
+        expanded_a = Signal(self._expanded_width, reset_less=True)
+        expanded_b = Signal(self._expanded_width, reset_less=True)
+        expanded_o = Signal(self._expanded_width, reset_less=True)
+
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
         al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
         al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
@@ -262,17 +275,18 @@ class PartitionedAdder(Elaboratable):
         # carry has been carried *over* the break point.
 
         for i in range(self.width):
         # carry has been carried *over* the break point.
 
         for i in range(self.width):
-            if i in self.partition_points:
+            pi = i/self.pmul # double the range of the partition point test
+            if pi.is_integer() and pi in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
-                ea.append(self._expanded_a[expanded_index])
-                al.append(~self.partition_points[i]) # add extra bit in a
-                eb.append(self._expanded_b[expanded_index])
+                ea.append(expanded_a[expanded_index])
+                al.append(~self.partition_points[pi]) # add extra bit in a
+                eb.append(expanded_b[expanded_index])
                 bl.append(C(0)) # yes, add a zero
                 expanded_index += 1 # skip the extra point.  NOT in the output
                 bl.append(C(0)) # yes, add a zero
                 expanded_index += 1 # skip the extra point.  NOT in the output
-            ea.append(self._expanded_a[expanded_index])
-            eb.append(self._expanded_b[expanded_index])
-            eo.append(self._expanded_o[expanded_index])
+            ea.append(expanded_a[expanded_index])
+            eb.append(expanded_b[expanded_index])
+            eo.append(expanded_o[expanded_index])
             al.append(self.a[i])
             bl.append(self.b[i])
             ol.append(self.output[i])
             al.append(self.a[i])
             bl.append(self.b[i])
             ol.append(self.output[i])
@@ -285,15 +299,106 @@ class PartitionedAdder(Elaboratable):
 
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
 
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
-        m.d.comb += self._expanded_o.eq(
-            self._expanded_a + self._expanded_b)
+        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
         return m
 
 
 FULL_ADDER_INPUT_COUNT = 3
 
         return m
 
 
 FULL_ADDER_INPUT_COUNT = 3
 
+class AddReduceData:
+
+    def __init__(self, part_pts, n_inputs, output_width, n_parts):
+        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
+                          for i in range(n_parts)]
+        self.terms = [Signal(output_width, name=f"terms_{i}",
+                              reset_less=True)
+                        for i in range(n_inputs)]
+        self.part_pts = part_pts.like()
+
+    def eq_from(self, part_pts, inputs, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.terms[i].eq(inputs[i])
+                                     for i in range(len(self.terms))] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
+
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
+
+
+class FinalReduceData:
+
+    def __init__(self, part_pts, output_width, n_parts):
+        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
+                          for i in range(n_parts)]
+        self.output = Signal(output_width, reset_less=True)
+        self.part_pts = part_pts.like()
+
+    def eq_from(self, part_pts, output, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.output.eq(output)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
+
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
+
+
+class FinalAdd(PipeModBase):
+    """ Final stage of add reduce
+    """
+
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.output_width = pspec.width * 2
+        self.n_inputs = n_inputs
+        self.n_parts = pspec.n_parts
+        self.partition_points = PartitionPoints(partition_points)
+        if not self.partition_points.fits_in_width(self.output_width):
+            raise ValueError("partition_points doesn't fit in output_width")
+
+        super().__init__(pspec, "finaladd")
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return FinalReduceData(self.partition_points,
+                                 self.output_width, self.n_parts)
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        m = Module()
+
+        output_width = self.output_width
+        output = Signal(output_width, reset_less=True)
+        if self.n_inputs == 0:
+            # use 0 as the default output value
+            m.d.comb += output.eq(0)
+        elif self.n_inputs == 1:
+            # handle single input
+            m.d.comb += output.eq(self.i.terms[0])
+        else:
+            # base case for adding 2 inputs
+            assert self.n_inputs == 2
+            adder = PartitionedAdder(output_width,
+                                     self.i.part_pts, self.partition_step)
+            m.submodules.final_adder = adder
+            m.d.comb += adder.a.eq(self.i.terms[0])
+            m.d.comb += adder.b.eq(self.i.terms[1])
+            m.d.comb += output.eq(adder.output)
+
+        # create output
+        m.d.comb += self.o.eq_from(self.i.part_pts, output,
+                                   self.i.part_ops)
 
 
-class AddReduceSingle(Elaboratable):
+        return m
+
+
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -305,44 +410,46 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
         :param partition_points: the input partition points.
         """
         :param partition_points: the input partition points.
         """
-        self.part_ops = part_ops
-        self.out_part_ops = [Signal(2, name=f"part_ops_{i}")
-                          for i in range(len(part_ops))]
-        self.inputs = list(inputs)
-        self._resized_inputs = [
-            Signal(output_width, name=f"resized_inputs[{i}]")
-            for i in range(len(self.inputs))]
-        self.register_levels = list(register_levels)
-        self.output = Signal(output_width)
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.n_inputs = n_inputs
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
             raise ValueError("partition_points doesn't fit in output_width")
-        self._reg_partition_points = self.partition_points.like()
-
-        max_level = AddReduceSingle.get_max_level(len(self.inputs))
-        for level in self.register_levels:
-            if level > max_level:
-                raise ValueError(
-                    "not enough adder levels for specified register levels")
-
-        # this is annoying.  we have to create the modules (and terms)
-        # because we need to know what they are (in order to set up the
-        # interconnects back in AddReduce), but cannot do the m.d.comb +=
-        # etc because this is not in elaboratable.
-        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
-        self._intermediate_terms = []
-        if len(self.groups) != 0:
-            self.create_next_terms()
+
+        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
+        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+
+        super().__init__(pspec, "addreduce_%d" % lidx)
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return AddReduceData(self.partition_points, self.n_terms,
+                             self.output_width, self.n_parts)
+
+    @staticmethod
+    def calc_n_inputs(n_inputs, groups):
+        retval = len(groups)*2
+        if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            retval += 1
+        elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            retval += 2
+        else:
+            assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
+        return retval
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -367,105 +474,124 @@ class AddReduceSingle(Elaboratable):
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
+    def create_next_terms(self):
+        """ create next intermediate terms, for linking up in elaborate, below
+        """
+        terms = []
+        adders = []
+
+        # create full adders for this recursive level.
+        # this shrinks N terms to 2 * (N // 3) plus the remainder
+        for i in self.groups:
+            adder_i = MaskedFullAdder(self.output_width)
+            adders.append((i, adder_i))
+            # add both the sum and the masked-carry to the next level.
+            # 3 inputs have now been reduced to 2...
+            terms.append(adder_i.sum)
+            terms.append(adder_i.mcarry)
+        # handle the remaining inputs.
+        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            terms.append(self.i.terms[-1])
+        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            # Just pass the terms to the next layer, since we wouldn't gain
+            # anything by using a half adder since there would still be 2 terms
+            # and just passing the terms to the next layer saves gates.
+            terms.append(self.i.terms[-2])
+            terms.append(self.i.terms[-1])
+        else:
+            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
+
+        return terms, adders
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
-        # resize inputs to correct bit-width and optionally add in
-        # pipeline registers
-        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
-                                     for i in range(len(self.inputs))]
-        copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
-                                     for i in range(len(self.part_ops))]
-        if 0 in self.register_levels:
-            m.d.sync += copy_part_ops
-            m.d.sync += resized_input_assignments
-            m.d.sync += self._reg_partition_points.eq(self.partition_points)
-        else:
-            m.d.comb += copy_part_ops
-            m.d.comb += resized_input_assignments
-            m.d.comb += self._reg_partition_points.eq(self.partition_points)
-
-        for (value, term) in self._intermediate_terms:
-            m.d.comb += term.eq(value)
-
-        # if there are no full adders to create, then we handle the base cases
-        # and return, otherwise we go on to the recursive case
-        if len(self.groups) == 0:
-            if len(self.inputs) == 0:
-                # use 0 as the default output value
-                m.d.comb += self.output.eq(0)
-            elif len(self.inputs) == 1:
-                # handle single input
-                m.d.comb += self.output.eq(self._resized_inputs[0])
-            else:
-                # base case for adding 2 inputs
-                assert len(self.inputs) == 2
-                adder = PartitionedAdder(len(self.output),
-                                         self._reg_partition_points)
-                m.submodules.final_adder = adder
-                m.d.comb += adder.a.eq(self._resized_inputs[0])
-                m.d.comb += adder.b.eq(self._resized_inputs[1])
-                m.d.comb += self.output.eq(adder.output)
-            return m
-
-        mask = self._reg_partition_points.as_mask(len(self.output))
-        m.d.comb += self.part_mask.eq(mask)
+        terms, adders = self.create_next_terms()
+
+        # copy the intermediate terms to the output
+        for i, value in enumerate(terms):
+            m.d.comb += self.o.terms[i].eq(value)
+
+        # copy reg part points and part ops to output
+        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
+        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+                                     for i in range(len(self.i.part_ops))]
+
+        # set up the partition mask (for the adders)
+        part_mask = Signal(self.output_width, reset_less=True)
+
+        # get partition points as a mask
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
+        m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
 
         # add and link the intermediate term modules
-        for i, (iidx, adder_i) in enumerate(self.adders):
+        for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
             setattr(m.submodules, f"adder_{i}", adder_i)
 
-            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
-            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
-            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
-            m.d.comb += adder_i.mask.eq(self.part_mask)
+            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
+            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
+            m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
 
         return m
 
-    def create_next_terms(self):
 
 
-        # go on to prepare recursive case
-        intermediate_terms = []
-        _intermediate_terms = []
+class AddReduceInternal:
+    """Iteratively Add list of numbers together.
 
 
-        def add_intermediate_term(value):
-            intermediate_term = Signal(
-                len(self.output),
-                name=f"intermediate_terms[{len(intermediate_terms)}]")
-            _intermediate_terms.append((value, intermediate_term))
-            intermediate_terms.append(intermediate_term)
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
 
 
-        # store mask in intermediary (simplifies graph)
-        self.part_mask = Signal(len(self.output), reset_less=True)
+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
+        """Create an ``AddReduce``.
 
 
-        # create full adders for this recursive level.
-        # this shrinks N terms to 2 * (N // 3) plus the remainder
-        self.adders = []
-        for i in self.groups:
-            adder_i = MaskedFullAdder(len(self.output))
-            self.adders.append((i, adder_i))
-            # add both the sum and the masked-carry to the next level.
-            # 3 inputs have now been reduced to 2...
-            add_intermediate_term(adder_i.sum)
-            add_intermediate_term(adder_i.mcarry)
-        # handle the remaining inputs.
-        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
-            add_intermediate_term(self._resized_inputs[-1])
-        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
-            # Just pass the terms to the next layer, since we wouldn't gain
-            # anything by using a half adder since there would still be 2 terms
-            # and just passing the terms to the next layer saves gates.
-            add_intermediate_term(self._resized_inputs[-2])
-            add_intermediate_term(self._resized_inputs[-1])
-        else:
-            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param partition_points: the input partition points.
+        """
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
+        self.partition_step = partition_step
+
+        self.create_levels()
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        partition_points = self.partition_points
+        ilen = self.n_inputs
+        while True:
+            groups = AddReduceSingle.full_adder_groups(ilen)
+            if len(groups) == 0:
+                break
+            lidx = len(mods)
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
+                                         partition_points,
+                                         self.partition_step)
+            mods.append(next_level)
+            partition_points = next_level.i.part_pts
+            ilen = len(next_level.o.terms)
+
+        lidx = len(mods)
+        next_level = FinalAdd(self.pspec, lidx, ilen,
+                              partition_points, self.partition_step)
+        mods.append(next_level)
 
 
-        self.intermediate_terms = intermediate_terms
-        self._intermediate_terms = _intermediate_terms
+        self.levels = mods
 
 
 
 
-class AddReduce(Elaboratable):
+class AddReduce(AddReduceInternal, Elaboratable):
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -477,8 +603,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, inputs, output_width, register_levels, part_pts,
+                       part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -487,16 +613,16 @@ class AddReduce(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.inputs = inputs
-        self.part_ops = part_ops
-        self.out_part_ops = [Signal(2, name=f"part_ops_{i}")
-                          for i in range(len(part_ops))]
-        self.output = Signal(output_width)
-        self.output_width = output_width
+        self._inputs = inputs
+        self._part_pts = part_pts
+        self._part_ops = part_ops
+        n_parts = len(part_ops)
+        self.i = AddReduceData(part_pts, len(inputs),
+                             output_width, n_parts)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+        self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
         self.register_levels = register_levels
-        self.partition_points = partition_points
-
-        self.create_levels()
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -509,39 +635,26 @@ class AddReduce(Elaboratable):
             if level > 0:
                 yield level - 1
 
             if level > 0:
                 yield level - 1
 
-    def create_levels(self):
-        """creates reduction levels"""
-
-        mods = []
-        next_levels = self.register_levels
-        partition_points = self.partition_points
-        inputs = self.inputs
-        part_ops = self.part_ops
-        while True:
-            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
-                                         partition_points, part_ops)
-            mods.append(next_level)
-            if len(next_level.groups) == 0:
-                break
-            next_levels = list(AddReduce.next_register_levels(next_levels))
-            partition_points = next_level._reg_partition_points
-            inputs = next_level.intermediate_terms
-            part_ops = next_level.part_ops
-
-        self.levels = mods
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
+        i = self.i
+        for idx in range(len(self.levels)):
+            mcur = self.levels[idx]
+            if idx in self.register_levels:
+                m.d.sync += mcur.i.eq(i)
+            else:
+                m.d.comb += mcur.i.eq(i)
+            i = mcur.o # for next loop
+
         # output comes from last module
         # output comes from last module
-        m.d.comb += self.output.eq(next_level.output)
-        copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
-                                     for i in range(len(self.part_ops))]
-        m.d.comb += copy_part_ops
+        m.d.comb += self.o.eq(i)
 
         return m
 
 
         return m
 
@@ -607,8 +720,8 @@ class ProductTerm(Elaboratable):
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
-        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
@@ -622,8 +735,8 @@ class ProductTerm(Elaboratable):
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
-        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
@@ -702,6 +815,46 @@ class LSBNegTerm(Elaboratable):
         return m
 
 
         return m
 
 
+class Parts(Elaboratable):
+
+    def __init__(self, pbwid, part_pts, n_parts):
+        self.pbwid = pbwid
+        # inputs
+        self.part_pts = PartitionPoints.like(part_pts)
+        # outputs
+        self.parts = [Signal(name=f"part_{i}", reset_less=True)
+                      for i in range(n_parts)]
+
+    def elaborate(self, platform):
+        m = Module()
+
+        part_pts, parts = self.part_pts, self.parts
+        # collect part-bytes (double factor because the input is extended)
+        pbs = Signal(self.pbwid, reset_less=True)
+        tl = []
+        for i in range(self.pbwid):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(part_pts.part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
+
+        # negated-temporary copy of partition bits
+        npbs = Signal.like(pbs, reset_less=True)
+        m.d.comb += npbs.eq(~pbs)
+        byte_count = 8 // len(parts)
+        for i in range(len(parts)):
+            pbl = []
+            pbl.append(npbs[i * byte_count - 1])
+            for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                pbl.append(pbs[j])
+            pbl.append(npbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
+            m.d.comb += value.eq(Cat(*pbl))
+            m.d.comb += parts[i].eq(~(value).bool())
+
+        return m
+
+
 class Part(Elaboratable):
     """ a key class which, depending on the partitioning, will determine
         what action to take when parts of the output are signed or unsigned.
 class Part(Elaboratable):
     """ a key class which, depending on the partitioning, will determine
         what action to take when parts of the output are signed or unsigned.
@@ -717,55 +870,43 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
-    def __init__(self, width, n_parts, n_levels, pbwid):
+    def __init__(self, part_pts, width, n_parts, pbwid):
+
+        self.pbwid = pbwid
+        self.part_pts = part_pts
 
         # inputs
 
         # inputs
-        self.a = Signal(64)
-        self.b = Signal(64)
-        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
-        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+        self.a = Signal(64, reset_less=True)
+        self.b = Signal(64, reset_less=True)
+        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
+                            for i in range(8)]
+        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
+                            for i in range(8)]
         self.pbs = Signal(pbwid, reset_less=True)
 
         # outputs
         self.pbs = Signal(pbwid, reset_less=True)
 
         # outputs
-        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
-        self.delayed_parts = [
-            [Signal(name=f"delayed_part_{delay}_{i}")
-             for i in range(n_parts)]
-                for delay in range(n_levels)]
-        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
-        self.dplast = [Signal(name=f"dplast_{i}")
-                         for i in range(n_parts)]
-
-        self.not_a_term = Signal(width)
-        self.neg_lsb_a_term = Signal(width)
-        self.not_b_term = Signal(width)
-        self.neg_lsb_b_term = Signal(width)
+        self.parts = [Signal(name=f"part_{i}", reset_less=True)
+                            for i in range(n_parts)]
+
+        self.not_a_term = Signal(width, reset_less=True)
+        self.neg_lsb_a_term = Signal(width, reset_less=True)
+        self.not_b_term = Signal(width, reset_less=True)
+        self.neg_lsb_b_term = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
-        # negated-temporary copy of partition bits
-        npbs = Signal.like(pbs, reset_less=True)
-        m.d.comb += npbs.eq(~pbs)
+        pbs, parts = self.pbs, self.parts
+        part_pts = self.part_pts
+        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
+        m.d.comb += p.part_pts.eq(part_pts)
+        parts = p.parts
+
         byte_count = 8 // len(parts)
         byte_count = 8 // len(parts)
-        for i in range(len(parts)):
-            pbl = []
-            pbl.append(npbs[i * byte_count - 1])
-            for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                pbl.append(pbs[j])
-            pbl.append(npbs[(i + 1) * byte_count - 1])
-            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
-            m.d.comb += value.eq(Cat(*pbl))
-            m.d.comb += parts[i].eq(~(value).bool())
-            m.d.comb += delayed_parts[0][i].eq(parts[i])
-            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
-                         for j in range(len(delayed_parts)-1)]
-            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
 
 
-        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
-                self.not_a_term, self.neg_lsb_a_term, \
-                self.not_b_term, self.neg_lsb_b_term
+        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
+                self.not_a_term, self.neg_lsb_a_term,
+                self.not_b_term, self.neg_lsb_b_term)
 
         byte_width = 8 // len(parts) # byte width
         bit_wid = 8 * byte_width     # bit width
 
         byte_width = 8 // len(parts) # byte width
         bit_wid = 8 * byte_width     # bit width
@@ -775,7 +916,7 @@ class Part(Elaboratable):
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
-            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
@@ -785,7 +926,7 @@ class Part(Elaboratable):
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
-            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
@@ -808,7 +949,7 @@ class IntermediateOut(Elaboratable):
     def __init__(self, width, out_wid, n_parts):
         self.width = width
         self.n_parts = n_parts
     def __init__(self, width, out_wid, n_parts):
         self.width = width
         self.n_parts = n_parts
-        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
+        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                      for i in range(8)]
         self.intermed = Signal(out_wid, reset_less=True)
         self.output = Signal(out_wid//2, reset_less=True)
                                      for i in range(8)]
         self.intermed = Signal(out_wid, reset_less=True)
         self.output = Signal(out_wid//2, reset_less=True)
@@ -822,38 +963,74 @@ class IntermediateOut(Elaboratable):
         for i in range(self.n_parts):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
         for i in range(self.n_parts):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
-                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
-                    self.intermed.part(i * w*2, w),
-                    self.intermed.part(i * w*2 + w, w)))
+                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, out_wid):
-        # inputs
-        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
-        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
-        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+    def __init__(self, pspec, part_pts):
 
 
-        self.i8 = Signal(out_wid, reset_less=True)
-        self.i16 = Signal(out_wid, reset_less=True)
-        self.i32 = Signal(out_wid, reset_less=True)
-        self.i64 = Signal(out_wid, reset_less=True)
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
 
 
-        # output
-        self.out = Signal(out_wid, reset_less=True)
+        super().__init__(pspec, "finalout")
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
+
+        part_pts = self.part_pts
+        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
+        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
+        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
+        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
+
+        out_part_pts = self.i.part_pts
+
+        # temporaries
+        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+        i8 = Signal(self.out_wid, reset_less=True)
+        i16 = Signal(self.out_wid, reset_less=True)
+        i32 = Signal(self.out_wid, reset_less=True)
+        i64 = Signal(self.out_wid, reset_less=True)
+
+        m.d.comb += p_8.part_pts.eq(out_part_pts)
+        m.d.comb += p_16.part_pts.eq(out_part_pts)
+        m.d.comb += p_32.part_pts.eq(out_part_pts)
+        m.d.comb += p_64.part_pts.eq(out_part_pts)
+
+        for i in range(len(p_8.parts)):
+            m.d.comb += d8[i].eq(p_8.parts[i])
+        for i in range(len(p_16.parts)):
+            m.d.comb += d16[i].eq(p_16.parts[i])
+        for i in range(len(p_32.parts)):
+            m.d.comb += d32[i].eq(p_32.parts[i])
+        m.d.comb += i8.eq(self.i.outputs[0])
+        m.d.comb += i16.eq(self.i.outputs[1])
+        m.d.comb += i32.eq(self.i.outputs[2])
+        m.d.comb += i64.eq(self.i.outputs[3])
+
         ol = []
         for i in range(8):
             # select one of the outputs: d8 selects i8, d16 selects i16
         ol = []
         for i in range(8):
             # select one of the outputs: d8 selects i8, d16 selects i16
@@ -863,13 +1040,17 @@ class FinalOut(Elaboratable):
             # if neither d8 nor d16 are set, d32 selects either i32 or i64.
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
             # if neither d8 nor d16 are set, d32 selects either i32 or i64.
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
-                Mux(self.d8[i] | self.d16[i // 2],
-                    Mux(self.d8[i], self.i8.part(i * 8, 8),
-                                     self.i16.part(i * 8, 8)),
-                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
-                                          self.i64.part(i * 8, 8))))
+                Mux(d8[i] | d16[i // 2],
+                    Mux(d8[i], i8.bit_select(i * 8, 8),
+                               i16.bit_select(i * 8, 8)),
+                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+                                      i64.bit_select(i * 8, 8))))
             ol.append(op)
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
         return m
 
 
@@ -916,71 +1097,92 @@ class Signs(Elaboratable):
         return m
 
 
         return m
 
 
-class Mul8_16_32_64(Elaboratable):
-    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
-
-    Supports partitioning into any combination of 8, 16, 32, and 64-bit
-    partitions on naturally-aligned boundaries. Supports the operation being
-    set for each partition independently.
+class IntermediateData:
 
 
-    :attribute part_pts: the input partition points. Has a partition point at
-        multiples of 8 in 0 < i < 64. Each partition point's associated
-        ``Value`` is a ``Signal``. Modification not supported, except for by
-        ``Signal.eq``.
-    :attribute part_ops: the operation for each byte. The operation for a
-        particular partition is selected by assigning the selected operation
-        code to each byte in the partition. The allowed operation codes are:
+    def __init__(self, part_pts, output_width, n_parts):
+        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
+                          for i in range(n_parts)]
+        self.part_pts = part_pts.like()
+        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
+                          for i in range(4)]
+        # intermediates (needed for unit tests)
+        self.intermediate_output = Signal(output_width)
+
+    def eq_from(self, part_pts, outputs, intermediate_output,
+                      part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.intermediate_output.eq(intermediate_output)] + \
+               [self.outputs[i].eq(outputs[i])
+                                     for i in range(4)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
 
 
-        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
-            RISC-V's `mul` instruction.
-        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
-            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
-            instruction.
-        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
-            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
-            `mulhsu` instruction.
-        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
-            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
-            instruction.
-    """
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.outputs,
+                            rhs.intermediate_output, rhs.part_ops)
 
 
-    def __init__(self, register_levels=()):
-        """ register_levels: specifies the points in the cascade at which
-            flip-flops are to be inserted.
-        """
 
 
-        # parameter(s)
-        self.register_levels = list(register_levels)
+class InputData:
 
 
-        # inputs
+    def __init__(self):
+        self.a = Signal(64)
+        self.b = Signal(64)
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-        self.a = Signal(64)
-        self.b = Signal(64)
 
 
-        # intermediates (needed for unit tests)
-        self._intermediate_output = Signal(128)
+    def eq_from(self, part_pts, a, b, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.a.eq(a), self.b.eq(b)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
 
 
-        # output
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
+
+
+class OutputData:
+
+    def __init__(self):
+        self.intermediate_output = Signal(128) # needed for unit tests
         self.output = Signal(64)
 
         self.output = Signal(64)
 
-    def _part_byte(self, index):
-        if index == -1 or index == 7:
-            return C(True, 1)
-        assert index >= 0 and index < 8
-        return self.part_pts[index * 8 + 8]
+    def eq(self, rhs):
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
+
+
+class AllTerms(PipeModBase):
+    """Set of terms to be added together
+    """
+
+    def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
+        """
+        self.n_inputs = n_inputs
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
+        super().__init__(pspec, "allterms")
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
+        eps = self.i.part_pts
+
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
         tl = []
         for i in range(8):
             pb = Signal(name="pb%d" % i, reset_less=True)
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
         tl = []
         for i in range(8):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(self._part_byte(i))
+            m.d.comb += pb.eq(eps.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -990,26 +1192,16 @@ class Mul8_16_32_64(Elaboratable):
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
-            m.d.comb += s.part_ops.eq(self.part_ops[i])
-
-        delayed_part_ops = [
-            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        for i in range(len(self.part_ops)):
-            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
-            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
-                         for j in range(len(self.register_levels))]
-
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
+            m.d.comb += s.part_ops.eq(self.i.part_ops[i])
+
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
-            m.d.comb += mod.a.eq(self.a)
-            m.d.comb += mod.b.eq(self.b)
+            m.d.comb += mod.a.eq(self.i.a)
+            m.d.comb += mod.b.eq(self.i.b)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
@@ -1025,8 +1217,8 @@ class Mul8_16_32_64(Elaboratable):
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
-            m.d.comb += t.a.eq(self.a)
-            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.a.eq(self.i.a)
+            m.d.comb += t.b.eq(self.i.b)
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
@@ -1046,60 +1238,171 @@ class Mul8_16_32_64(Elaboratable):
                 m.d.comb += mod.orin[i].eq(l[i])
             terms.append(mod.orout)
 
                 m.d.comb += mod.orin[i].eq(l[i])
             terms.append(mod.orout)
 
-        expanded_part_pts = PartitionPoints()
-        for i, v in self.part_pts.items():
-            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
-            expanded_part_pts[i * 2] = signal
-            m.d.comb += signal.eq(v)
+        # copy the intermediate terms to the output
+        for i, value in enumerate(terms):
+            m.d.comb += self.o.terms[i].eq(value)
+
+        # copy reg part points and part ops to output
+        m.d.comb += self.o.part_pts.eq(eps)
+        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+                                     for i in range(len(self.i.part_ops))]
+
+        return m
+
+
+class Intermediates(PipeModBase):
+    """ Intermediate output modules
+    """
+
+    def __init__(self, pspec, part_pts):
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+
+        super().__init__(pspec, "intermediates")
 
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               expanded_part_pts,
-                               self.part_ops)
+    def ispec(self):
+        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
 
 
-        #out_part_ops = add_reduce.levels[-1].out_part_ops
-        out_part_ops = delayed_part_ops[-1]
+    def ospec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        out_part_ops = self.i.part_ops
+        out_part_pts = self.i.part_pts
 
 
-        m.submodules.add_reduce = add_reduce
-        m.d.comb += self._intermediate_output.eq(add_reduce.output)
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
-        m.d.comb += io64.intermed.eq(self._intermediate_output)
+        m.d.comb += io64.intermed.eq(self.i.output)
         for i in range(8):
         for i in range(8):
-            m.d.comb += io64.delayed_part_ops[i].eq(out_part_ops[i])
+            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[3].eq(io64.output)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
-        m.d.comb += io32.intermed.eq(self._intermediate_output)
+        m.d.comb += io32.intermed.eq(self.i.output)
         for i in range(8):
         for i in range(8):
-            m.d.comb += io32.delayed_part_ops[i].eq(out_part_ops[i])
+            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[2].eq(io32.output)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
-        m.d.comb += io16.intermed.eq(self._intermediate_output)
+        m.d.comb += io16.intermed.eq(self.i.output)
         for i in range(8):
         for i in range(8):
-            m.d.comb += io16.delayed_part_ops[i].eq(out_part_ops[i])
+            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[1].eq(io16.output)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
-        m.d.comb += io8.intermed.eq(self._intermediate_output)
+        m.d.comb += io8.intermed.eq(self.i.output)
+        for i in range(8):
+            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[0].eq(io8.output)
+
         for i in range(8):
         for i in range(8):
-            m.d.comb += io8.delayed_part_ops[i].eq(out_part_ops[i])
+            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.part_pts.eq(out_part_pts)
+        m.d.comb += self.o.intermediate_output.eq(self.i.output)
+
+        return m
+
+
+class Mul8_16_32_64(Elaboratable):
+    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
+
+    XXX NOTE: this class is intended for unit test purposes ONLY.
+
+    Supports partitioning into any combination of 8, 16, 32, and 64-bit
+    partitions on naturally-aligned boundaries. Supports the operation being
+    set for each partition independently.
+
+    :attribute part_pts: the input partition points. Has a partition point at
+        multiples of 8 in 0 < i < 64. Each partition point's associated
+        ``Value`` is a ``Signal``. Modification not supported, except for by
+        ``Signal.eq``.
+    :attribute part_ops: the operation for each byte. The operation for a
+        particular partition is selected by assigning the selected operation
+        code to each byte in the partition. The allowed operation codes are:
+
+        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
+            RISC-V's `mul` instruction.
+        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
+            instruction.
+        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
+            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
+            `mulhsu` instruction.
+        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
+            instruction.
+    """
+
+    def __init__(self, register_levels=()):
+        """ register_levels: specifies the points in the cascade at which
+            flip-flops are to be inserted.
+        """
+
+        self.id_wid = 0 # num_bits(num_rows)
+        self.op_wid = 0
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+        self.pspec.n_parts = 8
+
+        # parameter(s)
+        self.register_levels = list(register_levels)
+
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+        # inputs
+        self.part_pts = self.i.part_pts
+        self.part_ops = self.i.part_ops
+        self.a = self.i.a
+        self.b = self.i.b
+
+        # output
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return OutputData()
+
+    def elaborate(self, platform):
+        m = Module()
+
+        part_pts = self.part_pts
+
+        n_inputs = 64 + 4
+        t = AllTerms(self.pspec, n_inputs)
+        t.setup(m, self.i)
+
+        terms = t.o.terms
+
+        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
+
+        i = t.o
+        for idx in range(len(at.levels)):
+            mcur = at.levels[idx]
+            mcur.setup(m, i)
+            o = mcur.ospec()
+            if idx in self.register_levels:
+                m.d.sync += o.eq(mcur.process(i))
+            else:
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
+
+        interm = Intermediates(self.pspec, part_pts)
+        interm.setup(m, i)
+        o = interm.process(interm.i)
 
         # final output
 
         # final output
-        m.submodules.finalout = finalout = FinalOut(64)
-        for i in range(len(part_8.delayed_parts[-1])):
-            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
-        for i in range(len(part_16.delayed_parts[-1])):
-            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
-        for i in range(len(part_32.delayed_parts[-1])):
-            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
-        m.d.comb += finalout.i8.eq(io8.output)
-        m.d.comb += finalout.i16.eq(io16.output)
-        m.d.comb += finalout.i32.eq(io32.output)
-        m.d.comb += finalout.i64.eq(io64.output)
-        m.d.comb += self.output.eq(finalout.out)
+        finalout = FinalOut(self.pspec, part_pts)
+        finalout.setup(m, o)
+        m.d.comb += self.o.eq(finalout.process(o))
 
         return m
 
 
         return m
 
@@ -1108,7 +1411,7 @@ if __name__ == "__main__":
     m = Mul8_16_32_64()
     main(m, ports=[m.a,
                    m.b,
     m = Mul8_16_32_64()
     main(m, ports=[m.a,
                    m.b,
-                   m._intermediate_output,
+                   m.intermediate_output,
                    m.output,
                    *m.part_ops,
                    *m.part_pts.values()])
                    m.output,
                    *m.part_ops,
                    *m.part_pts.values()])