rename inputs_ to terms_
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index d0d84e4f203a875ae5cfde7eb8df478e2d79c96a..89e4c4d322782ff1e16b2e88646405d6dcce0bd3 100644 (file)
@@ -8,6 +8,8 @@ from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -305,60 +307,68 @@ FULL_ADDER_INPUT_COUNT = 3
 
 class AddReduceData:
 
 
 class AddReduceData:
 
-    def __init__(self, ppoints, n_inputs, output_width, n_parts):
+    def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.inputs = [Signal(output_width, name=f"inputs_{i}",
+        self.terms = [Signal(output_width, name=f"terms_{i}",
                               reset_less=True)
                         for i in range(n_inputs)]
                               reset_less=True)
                         for i in range(n_inputs)]
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
 
 
-    def eq_from(self, reg_partition_points, inputs, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
-               [self.inputs[i].eq(inputs[i])
-                                     for i in range(len(self.inputs))] + \
+    def eq_from(self, part_pts, inputs, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.terms[i].eq(inputs[i])
+                                     for i in range(len(self.terms))] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.inputs, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
 
 
 class FinalReduceData:
 
 
 
 class FinalReduceData:
 
-    def __init__(self, ppoints, output_width, n_parts):
+    def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
 
 
-    def eq_from(self, reg_partition_points, output, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
+    def eq_from(self, part_pts, output, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.output, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
 
 
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
     """ Final stage of add reduce
     """
 
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.output_width = pspec.width * 2
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
+        super().__init__(pspec, "finaladd")
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return FinalReduceData(self.partition_points,
+                                 self.output_width, self.n_parts)
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -370,24 +380,25 @@ class FinalAdd(Elaboratable):
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
-            m.d.comb += output.eq(self.i.inputs[0])
+            m.d.comb += output.eq(self.i.terms[0])
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
-            adder = PartitionedAdder(output_width, self.i.reg_partition_points)
+            adder = PartitionedAdder(output_width,
+                                     self.i.part_pts, self.partition_step)
             m.submodules.final_adder = adder
             m.submodules.final_adder = adder
-            m.d.comb += adder.a.eq(self.i.inputs[0])
-            m.d.comb += adder.b.eq(self.i.inputs[1])
+            m.d.comb += adder.a.eq(self.i.terms[0])
+            m.d.comb += adder.b.eq(self.i.terms[1])
             m.d.comb += output.eq(adder.output)
 
         # create output
             m.d.comb += output.eq(adder.output)
 
         # create output
-        m.d.comb += self.o.eq_from(self.i.reg_partition_points, output,
+        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                    self.i.part_ops)
 
         return m
 
 
                                    self.i.part_ops)
 
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -399,35 +410,35 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                       partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
         :param partition_points: the input partition points.
         """
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
+        self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
-        max_level = AddReduceSingle.get_max_level(n_inputs)
-        for level in self.register_levels:
-            if level > max_level:
-                raise ValueError(
-                    "not enough adder levels for specified register levels")
-
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
-        n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
-        self.o = AddReduceData(partition_points, n_terms, output_width, n_parts)
+        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+
+        super().__init__(pspec, "addreduce_%d" % lidx)
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return AddReduceData(self.partition_points, self.n_terms,
+                             self.output_width, self.n_parts)
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
 
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
@@ -480,13 +491,13 @@ class AddReduceSingle(Elaboratable):
             terms.append(adder_i.mcarry)
         # handle the remaining inputs.
         if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
             terms.append(adder_i.mcarry)
         # handle the remaining inputs.
         if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
-            terms.append(self.i.inputs[-1])
+            terms.append(self.i.terms[-1])
         elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
             # Just pass the terms to the next layer, since we wouldn't gain
             # anything by using a half adder since there would still be 2 terms
             # and just passing the terms to the next layer saves gates.
         elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
             # Just pass the terms to the next layer, since we wouldn't gain
             # anything by using a half adder since there would still be 2 terms
             # and just passing the terms to the next layer saves gates.
-            terms.append(self.i.inputs[-2])
-            terms.append(self.i.inputs[-1])
+            terms.append(self.i.terms[-2])
+            terms.append(self.i.terms[-1])
         else:
             assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
 
         else:
             assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
 
@@ -500,10 +511,10 @@ class AddReduceSingle(Elaboratable):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
-            m.d.comb += self.o.inputs[i].eq(value)
+            m.d.comb += self.o.terms[i].eq(value)
 
         # copy reg part points and part ops to output
 
         # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(self.i.reg_partition_points)
+        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
@@ -511,22 +522,76 @@ class AddReduceSingle(Elaboratable):
         part_mask = Signal(self.output_width, reset_less=True)
 
         # get partition points as a mask
         part_mask = Signal(self.output_width, reset_less=True)
 
         # get partition points as a mask
-        mask = self.i.reg_partition_points.as_mask(self.output_width, mul=2)
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
-            m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
-            m.d.comb += adder_i.in1.eq(self.i.inputs[iidx + 1])
-            m.d.comb += adder_i.in2.eq(self.i.inputs[iidx + 2])
+            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
+            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
 
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
 
-class AddReduce(Elaboratable):
+class AddReduceInternal:
+    """Iteratively build the levels that add a list of numbers together.
+
+    :attribute pspec: pipeline specification handed to every level.
+    :attribute n_inputs: number of input terms fed to the first level.
+    :attribute output_width: bit-width of the sum (``pspec.width * 2``).
+    :attribute partition_step: step value handed to every level.
+    :attribute levels: ``AddReduceSingle`` levels followed by a ``FinalAdd``.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
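+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
+        """Create an ``AddReduceInternal`` and build its reduction levels.
+
+        :param pspec: pipeline spec; ``pspec.width * 2`` is the output width.
+        :param n_inputs: number of input terms to be summed.
+        :param part_pts: the input partition points.
+        :param partition_step: step value handed to every level (default 1).
+        """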
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
+        self.partition_step = partition_step
+
+        self.create_levels()
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        partition_points = self.partition_points
+        ilen = self.n_inputs
+        while True:
+            groups = AddReduceSingle.full_adder_groups(ilen)
+            if len(groups) == 0:
+                break
+            lidx = len(mods)
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
+                                         partition_points,
+                                         self.partition_step)
+            mods.append(next_level)
+            partition_points = next_level.i.part_pts
+            ilen = len(next_level.o.terms)
+
+        lidx = len(mods)
+        next_level = FinalAdd(self.pspec, lidx, ilen,
+                              partition_points, self.partition_step)
+        mods.append(next_level)
+
+        self.levels = mods
+
+
+class AddReduce(AddReduceInternal, Elaboratable):
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -538,8 +603,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, inputs, output_width, register_levels, part_pts,
+                       part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -548,15 +613,16 @@ class AddReduce(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.inputs = inputs
-        self.part_ops = part_ops
+        self._inputs = inputs
+        self._part_pts = part_pts
+        self._part_ops = part_ops
         n_parts = len(part_ops)
         n_parts = len(part_ops)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+        self.i = AddReduceData(part_pts, len(inputs),
+                             output_width, n_parts)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+        self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
         self.register_levels = register_levels
-        self.partition_points = partition_points
-
-        self.create_levels()
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -569,53 +635,19 @@ class AddReduce(Elaboratable):
             if level > 0:
                 yield level - 1
 
             if level > 0:
                 yield level - 1
 
-    def create_levels(self):
-        """creates reduction levels"""
-
-        mods = []
-        next_levels = self.register_levels
-        partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
-        while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
-            if len(groups) == 0:
-                break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         next_levels, partition_points)
-            mods.append(next_level)
-            next_levels = list(AddReduce.next_register_levels(next_levels))
-            partition_points = next_level.i.reg_partition_points
-            inputs = next_level.o.inputs
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
-
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              next_levels, partition_points)
-        mods.append(next_level)
-
-        self.levels = mods
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
-        partition_points = self.partition_points
-        inputs = self.inputs
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        n_inputs = len(inputs)
-        output_width = self.output_width
-        i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
-        m.d.comb += i.eq_from(partition_points, inputs, part_ops)
+        i = self.i
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
-            if 0 in mcur.register_levels:
+            if idx in self.register_levels:
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
@@ -688,8 +720,8 @@ class ProductTerm(Elaboratable):
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
-        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
@@ -703,8 +735,8 @@ class ProductTerm(Elaboratable):
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
-        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
@@ -785,10 +817,10 @@ class LSBNegTerm(Elaboratable):
 
 class Parts(Elaboratable):
 
 
 class Parts(Elaboratable):
 
-    def __init__(self, pbwid, epps, n_parts):
+    def __init__(self, pbwid, part_pts, n_parts):
         self.pbwid = pbwid
         # inputs
         self.pbwid = pbwid
         # inputs
-        self.epps = PartitionPoints.like(epps, name="epps") # expanded points
+        self.part_pts = PartitionPoints.like(part_pts)
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
@@ -796,13 +828,13 @@ class Parts(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
-        epps, parts = self.epps, self.parts
+        part_pts, parts = self.part_pts, self.parts
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(epps.part_byte(i))
+            m.d.comb += pb.eq(part_pts.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -838,10 +870,10 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
-    def __init__(self, epps, width, n_parts, n_levels, pbwid):
+    def __init__(self, part_pts, width, n_parts, pbwid):
 
         self.pbwid = pbwid
 
         self.pbwid = pbwid
-        self.epps = epps
+        self.part_pts = part_pts
 
         # inputs
         self.a = Signal(64, reset_less=True)
 
         # inputs
         self.a = Signal(64, reset_less=True)
@@ -865,9 +897,9 @@ class Part(Elaboratable):
         m = Module()
 
         pbs, parts = self.pbs, self.parts
         m = Module()
 
         pbs, parts = self.pbs, self.parts
-        epps = self.epps
-        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
-        m.d.comb += p.epps.eq(epps)
+        part_pts = self.part_pts
+        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
+        m.d.comb += p.part_pts.eq(part_pts)
         parts = p.parts
 
         byte_count = 8 // len(parts)
         parts = p.parts
 
         byte_count = 8 // len(parts)
@@ -884,7 +916,7 @@ class Part(Elaboratable):
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
-            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
             m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
             nat.append(pa.nt)
@@ -894,7 +926,7 @@ class Part(Elaboratable):
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
-            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
             m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
             nbt.append(pb.nt)
@@ -932,39 +964,46 @@ class IntermediateOut(Elaboratable):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
-                    self.intermed.part(i * w*2, w),
-                    self.intermed.part(i * w*2 + w, w)))
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, output_width, n_parts, partition_points):
-        self.expanded_part_points = partition_points
-        self.i = IntermediateData(partition_points, output_width, n_parts)
-        self.out_wid = output_width//2
-        # output
-        self.out = Signal(self.out_wid, reset_less=True)
-        self.intermediate_output = Signal(output_width, reset_less=True)
+    def __init__(self, pspec, part_pts):
+
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
+
+        super().__init__(pspec, "finalout")
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        eps = self.expanded_part_points
-        m.submodules.p_8 = p_8 = Parts(8, eps, 8)
-        m.submodules.p_16 = p_16 = Parts(8, eps, 4)
-        m.submodules.p_32 = p_32 = Parts(8, eps, 2)
-        m.submodules.p_64 = p_64 = Parts(8, eps, 1)
+        part_pts = self.part_pts
+        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
+        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
+        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
+        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
 
 
-        out_part_pts = self.i.reg_partition_points
+        out_part_pts = self.i.part_pts
 
         # temporaries
         d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
 
         # temporaries
         d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
@@ -976,10 +1015,10 @@ class FinalOut(Elaboratable):
         i32 = Signal(self.out_wid, reset_less=True)
         i64 = Signal(self.out_wid, reset_less=True)
 
         i32 = Signal(self.out_wid, reset_less=True)
         i64 = Signal(self.out_wid, reset_less=True)
 
-        m.d.comb += p_8.epps.eq(out_part_pts)
-        m.d.comb += p_16.epps.eq(out_part_pts)
-        m.d.comb += p_32.epps.eq(out_part_pts)
-        m.d.comb += p_64.epps.eq(out_part_pts)
+        m.d.comb += p_8.part_pts.eq(out_part_pts)
+        m.d.comb += p_16.part_pts.eq(out_part_pts)
+        m.d.comb += p_32.part_pts.eq(out_part_pts)
+        m.d.comb += p_64.part_pts.eq(out_part_pts)
 
         for i in range(len(p_8.parts)):
             m.d.comb += d8[i].eq(p_8.parts[i])
 
         for i in range(len(p_8.parts)):
             m.d.comb += d8[i].eq(p_8.parts[i])
@@ -1002,11 +1041,16 @@ class FinalOut(Elaboratable):
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
-                    Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)),
-                    Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8))))
+                    Mux(d8[i], i8.bit_select(i * 8, 8),
+                               i16.bit_select(i * 8, 8)),
+                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+                                      i64.bit_select(i * 8, 8))))
             ol.append(op)
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
-        m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
         return m
 
 
@@ -1055,18 +1099,18 @@ class Signs(Elaboratable):
 
 class IntermediateData:
 
 
 class IntermediateData:
 
-    def __init__(self, ppoints, output_width, n_parts):
+    def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
         self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                           for i in range(4)]
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(output_width)
 
         self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                           for i in range(4)]
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(output_width)
 
-    def eq_from(self, reg_partition_points, outputs, intermediate_output,
+    def eq_from(self, part_pts, outputs, intermediate_output,
                       part_ops):
                       part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
+        return [self.part_pts.eq(part_pts)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
                [self.outputs[i].eq(outputs[i])
                                      for i in range(4)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
                [self.outputs[i].eq(outputs[i])
                                      for i in range(4)] + \
@@ -1074,54 +1118,64 @@ class IntermediateData:
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.outputs,
+        return self.eq_from(rhs.part_pts, rhs.outputs,
                             rhs.intermediate_output, rhs.part_ops)
 
 
                             rhs.intermediate_output, rhs.part_ops)
 
 
-class AllTermsData:
+class InputData:
 
 
-    def __init__(self, partition_points):
+    def __init__(self):
         self.a = Signal(64)
         self.b = Signal(64)
         self.a = Signal(64)
         self.b = Signal(64)
-        self.epps = partition_points.like()
+        self.part_pts = PartitionPoints()
+        for i in range(8, 64, 8):
+            self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
-    def eq_from(self, epps, inputs, part_ops):
-        return [self.epps.eq(epps)] + \
+    def eq_from(self, part_pts, a, b, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
                [self.a.eq(a), self.b.eq(b)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
                [self.a.eq(a), self.b.eq(b)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
-        return self.eq_from(rhs.epps, rhs.a, rhs.b, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
+
+
+class OutputData:
+    """Bundle of two signals: ``output`` and ``intermediate_output``.
+    """
+
+    def __init__(self):
+        # both signals are created with fixed, hard-coded sizes (128 and 64)
+        self.intermediate_output = Signal(128) # needed for unit tests
+        self.output = Signal(64)
+
+    def eq(self, rhs):
+        """Return a list of two assignments copying ``rhs`` into ``self``.
+
+        :param rhs: object providing ``intermediate_output`` and ``output``
+            attributes.
+        """
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
 
 
 
 
-class AllTerms(Elaboratable):
+class AllTerms(PipeModBase):
     """Set of terms to be added together
     """
 
     """Set of terms to be added together
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        """Create an ``AddReduce``.
-
-        :param inputs: input ``Signal``s to be summed.
-        :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
-        :param partition_points: the input partition points.
+    def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
         """
         """
-        self.i = AllTermsData(partition_points)
-        self.register_levels = register_levels
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.o = AddReduceData(self.i.epps, n_inputs,
-                               output_width, n_parts)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
+        super().__init__(pspec, "allterms")
+
+    def ispec(self):
+        """Build the input data object for this stage.
+
+        Returns a new ``InputData`` (takes no arguments).
+        NOTE(review): presumably called by the base class to create
+        ``self.i`` -- the caller is not visible here, confirm.
+        """
+        return InputData()
+
+    def ospec(self):
+        """Build the output data object for this stage.
+
+        Returns a new ``AddReduceData`` constructed from ``self.i.part_pts``,
+        ``n_inputs``, ``output_width`` and ``n_parts``.
+        """
+        # reads self.i, so self.i must already exist when this is called
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        eps = self.i.epps
+        eps = self.i.part_pts
 
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
 
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
@@ -1140,11 +1194,10 @@ class AllTerms(Elaboratable):
             setattr(m.submodules, "signs%d" % i, s)
             m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
             setattr(m.submodules, "signs%d" % i, s)
             m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
             m.d.comb += mod.a.eq(self.i.a)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
             m.d.comb += mod.a.eq(self.i.a)
@@ -1187,29 +1240,38 @@ class AllTerms(Elaboratable):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
 
         # copy the intermediate terms to the output
         for i, value in enumerate(terms):
-            m.d.comb += self.o.inputs[i].eq(value)
+            m.d.comb += self.o.terms[i].eq(value)
 
         # copy reg part points and part ops to output
 
         # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(eps)
+        m.d.comb += self.o.part_pts.eq(eps)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
         return m
 
 
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                                      for i in range(len(self.i.part_ops))]
 
         return m
 
 
-class Intermediates(Elaboratable):
+class Intermediates(PipeModBase):
     """ Intermediate output modules
     """
 
     """ Intermediate output modules
     """
 
-    def __init__(self, output_width, n_parts, partition_points):
-        self.i = FinalReduceData(partition_points, output_width, n_parts)
-        self.o = IntermediateData(partition_points, output_width, n_parts)
+    def __init__(self, pspec, part_pts):
+        """Create an ``Intermediates`` stage.
+
+        :param pspec: specification object; ``pspec.width`` and
+            ``pspec.n_parts`` are read here, and it is passed on to the
+            base-class constructor.
+        :param part_pts: partition points, stored as ``self.part_pts``.
+        """
+        self.part_pts = part_pts
+        # output width is twice the width given in pspec
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+
+        # NOTE(review): attributes are set before the base constructor runs,
+        # presumably because it calls ispec()/ospec(), which read them
+        # -- confirm against PipeModBase.
+        super().__init__(pspec, "intermediates")
+
+    def ispec(self):
+        """Build the input data object for this stage.
+
+        Returns a new ``FinalReduceData`` constructed from this instance's
+        ``part_pts``, ``output_width`` and ``n_parts``.
+        NOTE(review): presumably called by the base class to create
+        ``self.i`` -- the caller is not visible here, confirm.
+        """
+        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        """Build the output data object for this stage.
+
+        Returns a new ``IntermediateData`` constructed from this instance's
+        ``part_pts``, ``output_width`` and ``n_parts``.
+        NOTE(review): presumably called by the base class to create
+        ``self.o`` -- the caller is not visible here, confirm.
+        """
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
 
         out_part_ops = self.i.part_ops
 
     def elaborate(self, platform):
         m = Module()
 
         out_part_ops = self.i.part_ops
-        out_part_pts = self.i.reg_partition_points
+        out_part_pts = self.i.part_pts
 
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
 
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
@@ -1241,7 +1303,7 @@ class Intermediates(Elaboratable):
 
         for i in range(8):
             m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
 
         for i in range(8):
             m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
-        m.d.comb += self.o.reg_partition_points.eq(out_part_pts)
+        m.d.comb += self.o.part_pts.eq(out_part_pts)
         m.d.comb += self.o.intermediate_output.eq(self.i.output)
 
         return m
         m.d.comb += self.o.intermediate_output.eq(self.i.output)
 
         return m
@@ -1280,66 +1342,65 @@ class Mul8_16_32_64(Elaboratable):
             flip-flops are to be inserted.
         """
 
             flip-flops are to be inserted.
         """
 
+        self.id_wid = 0 # num_bits(num_rows)
+        self.op_wid = 0
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+        self.pspec.n_parts = 8
+
         # parameter(s)
         self.register_levels = list(register_levels)
 
         # parameter(s)
         self.register_levels = list(register_levels)
 
-        # inputs
-        self.part_pts = PartitionPoints()
-        for i in range(8, 64, 8):
-            self.part_pts[i] = Signal(name=f"part_pts_{i}")
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-        self.a = Signal(64)
-        self.b = Signal(64)
+        self.i = self.ispec()
+        self.o = self.ospec()
 
 
-        # intermediates (needed for unit tests)
-        self.intermediate_output = Signal(128)
+        # inputs
+        self.part_pts = self.i.part_pts
+        self.part_ops = self.i.part_ops
+        self.a = self.i.a
+        self.b = self.i.b
 
         # output
 
         # output
-        self.output = Signal(64)
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
+
+    def ispec(self):
+        """Return a new ``InputData``; ``__init__`` stores it as ``self.i``.
+        """
+        return InputData()
+
+    def ospec(self):
+        """Return a new ``OutputData``; ``__init__`` stores it as ``self.o``.
+        """
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        # create (doubled) PartitionPoints (output is double input width)
-        expanded_part_pts = eps = PartitionPoints()
-        for i, v in self.part_pts.items():
-            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
-            expanded_part_pts[i] = ep
-            m.d.comb += ep.eq(v)
+        part_pts = self.part_pts
 
         n_inputs = 64 + 4
 
         n_inputs = 64 + 4
-        n_parts = 8 #len(self.part_pts)
-        t = AllTerms(n_inputs, 128, n_parts, self.register_levels,
-                       eps)
-        m.submodules.allterms = t
-        m.d.comb += t.i.a.eq(self.a)
-        m.d.comb += t.i.b.eq(self.b)
-        m.d.comb += t.i.epps.eq(eps)
-        for i in range(8):
-            m.d.comb += t.i.part_ops[i].eq(self.part_ops[i])
-
-        terms = t.o.inputs
+        t = AllTerms(self.pspec, n_inputs)
+        t.setup(m, self.i)
 
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               t.o.reg_partition_points,
-                               t.o.part_ops)
+        terms = t.o.terms
 
 
-        out_part_ops = add_reduce.o.part_ops
-        out_part_pts = add_reduce.o.reg_partition_points
+        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
 
 
-        m.submodules.add_reduce = add_reduce
+        i = t.o
+        for idx in range(len(at.levels)):
+            mcur = at.levels[idx]
+            mcur.setup(m, i)
+            o = mcur.ospec()
+            if idx in self.register_levels:
+                m.d.sync += o.eq(mcur.process(i))
+            else:
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
 
 
-        interm = Intermediates(128, 8, expanded_part_pts)
-        m.submodules.intermediates = interm
-        m.d.comb += interm.i.eq(add_reduce.o)
+        interm = Intermediates(self.pspec, part_pts)
+        interm.setup(m, i)
+        o = interm.process(interm.i)
 
         # final output
 
         # final output
-        m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts)
-        m.d.comb += finalout.i.eq(interm.o)
-        m.d.comb += self.output.eq(finalout.out)
-        m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)
+        finalout = FinalOut(self.pspec, part_pts)
+        finalout.setup(m, o)
+        m.d.comb += self.o.eq(finalout.process(o))
 
         return m
 
 
         return m