add new Parts class
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 5be8768c56aca65ff9c00a690f0c286ee1915579..64994ad36c9feff26d84cd6b52acd05193456fa9 100644 (file)
@@ -6,6 +6,8 @@ from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
 from nmigen.hdl.ast import Assign
 from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from nmigen.hdl.ast import Assign
 from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
+from functools import reduce
+from operator import or_
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -101,6 +103,12 @@ class PartitionPoints(dict):
                 return False
         return True
 
                 return False
         return True
 
+    def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
+        if index == -1 or index == 7:
+            return C(True, 1)
+        assert index >= 0 and index < 8
+        return self[(index * 8 + 8)*mfactor]
+
 
 class FullAdder(Elaboratable):
     """Full Adder.
 
 class FullAdder(Elaboratable):
     """Full Adder.
@@ -110,6 +118,11 @@ class FullAdder(Elaboratable):
     :attribute in2: the third input
     :attribute sum: the sum output
     :attribute carry: the carry output
     :attribute in2: the third input
     :attribute sum: the sum output
     :attribute carry: the carry output
+
+    Rather than do individual full adders (and have an array of them,
+    which would be very slow to simulate), this module can specify the
+    bit width of the inputs and outputs: in effect it performs multiple
+    Full 3-2 Add operations "in parallel".
     """
 
     def __init__(self, width):
     """
 
     def __init__(self, width):
@@ -133,9 +146,77 @@ class FullAdder(Elaboratable):
         return m
 
 
         return m
 
 
+class MaskedFullAdder(Elaboratable):
+    """Masked Full Adder.
+
+    :attribute mask: the carry partition mask
+    :attribute in0: the first input
+    :attribute in1: the second input
+    :attribute in2: the third input
+    :attribute sum: the sum output
+    :attribute mcarry: the masked carry output
+
+    FullAdders are always used with a "mask" on the output.  To keep
+    the graphviz "clean", this class performs the masking here rather
+    than inside a large for-loop.
+
+    See the following discussion as to why this is no longer derived
+    from FullAdder.  Each carry is shifted here *before* being ANDed
+    with the mask, so that an AOI cell may be used (which is more
+    gate-efficient)
+    https://en.wikipedia.org/wiki/AND-OR-Invert
+    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
+    """
+
+    def __init__(self, width):
+        """Create a ``MaskedFullAdder``.
+
+        :param width: the bit width of the input and output
+        """
+        self.width = width
+        self.mask = Signal(width, reset_less=True)
+        self.mcarry = Signal(width, reset_less=True)
+        self.in0 = Signal(width, reset_less=True)
+        self.in1 = Signal(width, reset_less=True)
+        self.in2 = Signal(width, reset_less=True)
+        self.sum = Signal(width, reset_less=True)
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        m = Module()
+        s1 = Signal(self.width, reset_less=True)
+        s2 = Signal(self.width, reset_less=True)
+        s3 = Signal(self.width, reset_less=True)
+        c1 = Signal(self.width, reset_less=True)
+        c2 = Signal(self.width, reset_less=True)
+        c3 = Signal(self.width, reset_less=True)
+        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
+        m.d.comb += s1.eq(Cat(0, self.in0))
+        m.d.comb += s2.eq(Cat(0, self.in1))
+        m.d.comb += s3.eq(Cat(0, self.in2))
+        m.d.comb += c1.eq(s1 & s2 & self.mask)
+        m.d.comb += c2.eq(s2 & s3 & self.mask)
+        m.d.comb += c3.eq(s3 & s1 & self.mask)
+        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
+        return m
+
+
 class PartitionedAdder(Elaboratable):
     """Partitioned Adder.
 
 class PartitionedAdder(Elaboratable):
     """Partitioned Adder.
 
+    Performs the final add.  The partition points are included in the
+    actual add (in one of the operands only), which causes a carry over
+    to the next bit.  Then the final output *removes* the extra bits from
+    the result.
+
+    partition: .... P... P... P... P... (32 bits)
+    a        : .... .... .... .... .... (32 bits)
+    b        : .... .... .... .... .... (32 bits)
+    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
+    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
+    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
+    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
+
     :attribute width: the bit width of the input and output. Read-only.
     :attribute a: the first input to the adder
     :attribute b: the second input to the adder
     :attribute width: the bit width of the input and output. Read-only.
     :attribute a: the first input to the adder
     :attribute b: the second input to the adder
@@ -163,46 +244,54 @@ class PartitionedAdder(Elaboratable):
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
-        self._expanded_a = Signal(expanded_width)
-        self._expanded_b = Signal(expanded_width)
-        self._expanded_output = Signal(expanded_width)
+        # XXX these have to remain here due to some horrible nmigen
+        # simulation bugs involving sync.  it is *not* necessary to
+        # have them here, they should (under normal circumstances)
+        # be moved into elaborate, as they are entirely local
+        self._expanded_a = Signal(expanded_width) # includes extra part-points
+        self._expanded_b = Signal(expanded_width) # likewise.
+        self._expanded_o = Signal(expanded_width) # likewise.
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
-        al = []
-        bl = []
-        ol = []
-        ea = []
-        eb = []
-        eo = []
-        # partition points are "breaks" (extra zeros) in what would otherwise
-        # be a massive long add.
+        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
+
+        # partition points are "breaks" (extra zeros or 1s) in what would
+        # otherwise be a massive long add.  when the "break" points are 0,
+        # whatever is in it (in the output) is discarded.  however when
+        # there is a "1", it causes a roll-over carry to the *next* bit.
+        # we still ignore the "break" bit in the [intermediate] output,
+        # however by that time we've got the effect that we wanted: the
+        # carry has been carried *over* the break point.
+
         for i in range(self.width):
             if i in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(self._expanded_a[expanded_index])
         for i in range(self.width):
             if i in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(self._expanded_a[expanded_index])
-                al.append(~self.partition_points[i])
+                al.append(~self.partition_points[i]) # add extra bit in a
                 eb.append(self._expanded_b[expanded_index])
                 eb.append(self._expanded_b[expanded_index])
-                bl.append(C(0))
-                expanded_index += 1
+                bl.append(C(0)) # yes, add a zero
+                expanded_index += 1 # skip the extra point.  NOT in the output
             ea.append(self._expanded_a[expanded_index])
             ea.append(self._expanded_a[expanded_index])
-            al.append(self.a[i])
             eb.append(self._expanded_b[expanded_index])
             eb.append(self._expanded_b[expanded_index])
+            eo.append(self._expanded_o[expanded_index])
+            al.append(self.a[i])
             bl.append(self.b[i])
             bl.append(self.b[i])
-            eo.append(self._expanded_output[expanded_index])
             ol.append(self.output[i])
             expanded_index += 1
             ol.append(self.output[i])
             expanded_index += 1
+
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*eo).eq(Cat(*ol))
+        m.d.comb += Cat(*ol).eq(Cat(*eo))
+
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
-        m.d.comb += self._expanded_output.eq(
+        m.d.comb += self._expanded_o.eq(
             self._expanded_a + self._expanded_b)
         return m
 
             self._expanded_a + self._expanded_b)
         return m
 
@@ -210,7 +299,7 @@ class PartitionedAdder(Elaboratable):
 FULL_ADDER_INPUT_COUNT = 3
 
 
 FULL_ADDER_INPUT_COUNT = 3
 
 
-class AddReduce(Elaboratable):
+class AddReduceSingle(Elaboratable):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -222,7 +311,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points):
+    def __init__(self, inputs, output_width, register_levels, partition_points,
+                       part_ops):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -231,6 +321,9 @@ class AddReduce(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
+        self.part_ops = part_ops
+        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
+                          for i in range(len(part_ops))]
         self.inputs = list(inputs)
         self._resized_inputs = [
             Signal(output_width, name=f"resized_inputs[{i}]")
         self.inputs = list(inputs)
         self._resized_inputs = [
             Signal(output_width, name=f"resized_inputs[{i}]")
@@ -241,12 +334,22 @@ class AddReduce(Elaboratable):
         if not self.partition_points.fits_in_width(output_width):
             raise ValueError("partition_points doesn't fit in output_width")
         self._reg_partition_points = self.partition_points.like()
         if not self.partition_points.fits_in_width(output_width):
             raise ValueError("partition_points doesn't fit in output_width")
         self._reg_partition_points = self.partition_points.like()
-        max_level = AddReduce.get_max_level(len(self.inputs))
+
+        max_level = AddReduceSingle.get_max_level(len(self.inputs))
         for level in self.register_levels:
             if level > max_level:
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
         for level in self.register_levels:
             if level > max_level:
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
+        # this is annoying.  we have to create the modules (and terms)
+        # because we need to know what they are (in order to set up the
+        # interconnects back in AddReduce), but cannot do the m.d.comb +=
+        # etc because this is not in elaboratable.
+        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
+        self._intermediate_terms = []
+        if len(self.groups) != 0:
+            self.create_next_terms()
+
     @staticmethod
     def get_max_level(input_count):
         """Get the maximum level.
     @staticmethod
     def get_max_level(input_count):
         """Get the maximum level.
@@ -256,19 +359,13 @@ class AddReduce(Elaboratable):
         """
         retval = 0
         while True:
         """
         retval = 0
         while True:
-            groups = AddReduce.full_adder_groups(input_count)
+            groups = AddReduceSingle.full_adder_groups(input_count)
             if len(groups) == 0:
                 return retval
             input_count %= FULL_ADDER_INPUT_COUNT
             input_count += 2 * len(groups)
             retval += 1
 
             if len(groups) == 0:
                 return retval
             input_count %= FULL_ADDER_INPUT_COUNT
             input_count += 2 * len(groups)
             retval += 1
 
-    def next_register_levels(self):
-        """``Iterable`` of ``register_levels`` for next recursive level."""
-        for level in self.register_levels:
-            if level > 0:
-                yield level - 1
-
     @staticmethod
     def full_adder_groups(input_count):
         """Get ``inputs`` indices for which a full adder should be built."""
     @staticmethod
     def full_adder_groups(input_count):
         """Get ``inputs`` indices for which a full adder should be built."""
@@ -284,17 +381,23 @@ class AddReduce(Elaboratable):
         # pipeline registers
         resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                      for i in range(len(self.inputs))]
         # pipeline registers
         resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                      for i in range(len(self.inputs))]
+        copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
+                                     for i in range(len(self.part_ops))]
         if 0 in self.register_levels:
         if 0 in self.register_levels:
+            m.d.sync += copy_part_ops
             m.d.sync += resized_input_assignments
             m.d.sync += self._reg_partition_points.eq(self.partition_points)
         else:
             m.d.sync += resized_input_assignments
             m.d.sync += self._reg_partition_points.eq(self.partition_points)
         else:
+            m.d.comb += copy_part_ops
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
-        groups = AddReduce.full_adder_groups(len(self.inputs))
+        for (value, term) in self._intermediate_terms:
+            m.d.comb += term.eq(value)
+
         # if there are no full adders to create, then we handle the base cases
         # and return, otherwise we go on to the recursive case
         # if there are no full adders to create, then we handle the base cases
         # and return, otherwise we go on to the recursive case
-        if len(groups) == 0:
+        if len(self.groups) == 0:
             if len(self.inputs) == 0:
                 # use 0 as the default output value
                 m.d.comb += self.output.eq(0)
             if len(self.inputs) == 0:
                 # use 0 as the default output value
                 m.d.comb += self.output.eq(0)
@@ -302,8 +405,7 @@ class AddReduce(Elaboratable):
                 # handle single input
                 m.d.comb += self.output.eq(self._resized_inputs[0])
             else:
                 # handle single input
                 m.d.comb += self.output.eq(self._resized_inputs[0])
             else:
-                # base case for adding 2 or more inputs, which get recursively
-                # reduced to 2 inputs
+                # base case for adding 2 inputs
                 assert len(self.inputs) == 2
                 adder = PartitionedAdder(len(self.output),
                                          self._reg_partition_points)
                 assert len(self.inputs) == 2
                 adder = PartitionedAdder(len(self.output),
                                          self._reg_partition_points)
@@ -312,33 +414,47 @@ class AddReduce(Elaboratable):
                 m.d.comb += adder.b.eq(self._resized_inputs[1])
                 m.d.comb += self.output.eq(adder.output)
             return m
                 m.d.comb += adder.b.eq(self._resized_inputs[1])
                 m.d.comb += self.output.eq(adder.output)
             return m
-        # go on to handle recursive case
+
+        mask = self._reg_partition_points.as_mask(len(self.output))
+        m.d.comb += self.part_mask.eq(mask)
+
+        # add and link the intermediate term modules
+        for i, (iidx, adder_i) in enumerate(self.adders):
+            setattr(m.submodules, f"adder_{i}", adder_i)
+
+            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
+            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
+            m.d.comb += adder_i.mask.eq(self.part_mask)
+
+        return m
+
+    def create_next_terms(self):
+
+        # go on to prepare recursive case
         intermediate_terms = []
         intermediate_terms = []
+        _intermediate_terms = []
 
         def add_intermediate_term(value):
             intermediate_term = Signal(
                 len(self.output),
                 name=f"intermediate_terms[{len(intermediate_terms)}]")
 
         def add_intermediate_term(value):
             intermediate_term = Signal(
                 len(self.output),
                 name=f"intermediate_terms[{len(intermediate_terms)}]")
+            _intermediate_terms.append((value, intermediate_term))
             intermediate_terms.append(intermediate_term)
             intermediate_terms.append(intermediate_term)
-            m.d.comb += intermediate_term.eq(value)
 
         # store mask in intermediary (simplifies graph)
 
         # store mask in intermediary (simplifies graph)
-        part_mask = Signal(len(self.output), reset_less=True)
-        mask = self._reg_partition_points.as_mask(len(self.output))
-        m.d.comb += part_mask.eq(mask)
+        self.part_mask = Signal(len(self.output), reset_less=True)
 
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
 
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
-        for i in groups:
-            adder_i = FullAdder(len(self.output))
-            setattr(m.submodules, f"adder_{i}", adder_i)
-            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
-            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
-            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
+        self.adders = []
+        for i in self.groups:
+            adder_i = MaskedFullAdder(len(self.output))
+            self.adders.append((i, adder_i))
+            # add both the sum and the masked-carry to the next level.
+            # 3 inputs have now been reduced to 2...
             add_intermediate_term(adder_i.sum)
             add_intermediate_term(adder_i.sum)
-            shifted_carry = adder_i.carry << 1
-            # mask out carry bits to prevent carries between partitions
-            add_intermediate_term((adder_i.carry << 1) & part_mask)
+            add_intermediate_term(adder_i.mcarry)
         # handle the remaining inputs.
         if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
             add_intermediate_term(self._resized_inputs[-1])
         # handle the remaining inputs.
         if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
             add_intermediate_term(self._resized_inputs[-1])
@@ -350,13 +466,89 @@ class AddReduce(Elaboratable):
             add_intermediate_term(self._resized_inputs[-1])
         else:
             assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
             add_intermediate_term(self._resized_inputs[-1])
         else:
             assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
-        # recursive invocation of ``AddReduce``
-        next_level = AddReduce(intermediate_terms,
-                               len(self.output),
-                               self.next_register_levels(),
-                               self._reg_partition_points)
-        m.submodules.next_level = next_level
+
+        self.intermediate_terms = intermediate_terms
+        self._intermediate_terms = _intermediate_terms
+
+
+class AddReduce(Elaboratable):
+    """Recursively Add list of numbers together.
+
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self, inputs, output_width, register_levels, partition_points,
+                       part_ops):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param register_levels: List of nesting levels that should have
+            pipeline registers.
+        :param partition_points: the input partition points.
+        """
+        self.inputs = inputs
+        self.part_ops = part_ops
+        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
+                          for i in range(len(part_ops))]
+        self.output = Signal(output_width)
+        self.output_width = output_width
+        self.register_levels = register_levels
+        self.partition_points = partition_points
+
+        self.create_levels()
+
+    @staticmethod
+    def get_max_level(input_count):
+        return AddReduceSingle.get_max_level(input_count)
+
+    @staticmethod
+    def next_register_levels(register_levels):
+        """``Iterable`` of ``register_levels`` for next recursive level."""
+        for level in register_levels:
+            if level > 0:
+                yield level - 1
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        next_levels = self.register_levels
+        partition_points = self.partition_points
+        inputs = self.inputs
+        part_ops = self.part_ops
+        while True:
+            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
+                                         partition_points, part_ops)
+            mods.append(next_level)
+            if len(next_level.groups) == 0:
+                break
+            next_levels = list(AddReduce.next_register_levels(next_levels))
+            partition_points = next_level._reg_partition_points
+            inputs = next_level.intermediate_terms
+            part_ops = next_level.out_part_ops
+
+        self.levels = mods
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        m = Module()
+
+        for i, next_level in enumerate(self.levels):
+            setattr(m.submodules, "next_level%d" % i, next_level)
+
+        # output comes from last module
         m.d.comb += self.output.eq(next_level.output)
         m.d.comb += self.output.eq(next_level.output)
+        copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
+                                     for i in range(len(self.part_ops))]
+        m.d.comb += copy_part_ops
+
         return m
 
 
         return m
 
 
@@ -366,6 +558,409 @@ OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
 OP_MUL_UNSIGNED_HIGH = 3
 
 
 OP_MUL_UNSIGNED_HIGH = 3
 
 
+def get_term(value, shift=0, enabled=None):
+    if enabled is not None:
+        value = Mux(enabled, value, 0)
+    if shift > 0:
+        value = Cat(Repl(C(0, 1), shift), value)
+    else:
+        assert shift == 0
+    return value
+
+
+class ProductTerm(Elaboratable):
+    """ this class creates a single product term (a[..]*b[..]).
+        it has a design flaw in that is the *output* that is selected,
+        where the multiplication(s) are combinatorially generated
+        all the time.
+    """
+
+    def __init__(self, width, twidth, pbwid, a_index, b_index):
+        self.a_index = a_index
+        self.b_index = b_index
+        shift = 8 * (self.a_index + self.b_index)
+        self.pwidth = width
+        self.twidth = twidth
+        self.width = width*2
+        self.shift = shift
+
+        self.ti = Signal(self.width, reset_less=True)
+        self.term = Signal(twidth, reset_less=True)
+        self.a = Signal(twidth//2, reset_less=True)
+        self.b = Signal(twidth//2, reset_less=True)
+        self.pb_en = Signal(pbwid, reset_less=True)
+
+        self.tl = tl = []
+        min_index = min(self.a_index, self.b_index)
+        max_index = max(self.a_index, self.b_index)
+        for i in range(min_index, max_index):
+            tl.append(self.pb_en[i])
+        name = "te_%d_%d" % (self.a_index, self.b_index)
+        if len(tl) > 0:
+            term_enabled = Signal(name=name, reset_less=True)
+        else:
+            term_enabled = None
+        self.enabled = term_enabled
+        self.term.name = "term_%d_%d" % (a_index, b_index) # rename
+
+    def elaborate(self, platform):
+
+        m = Module()
+        if self.enabled is not None:
+            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
+
+        bsa = Signal(self.width, reset_less=True)
+        bsb = Signal(self.width, reset_less=True)
+        a_index, b_index = self.a_index, self.b_index
+        pwidth = self.pwidth
+        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += self.ti.eq(bsa * bsb)
+        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
+        """
+        #TODO: sort out width issues, get inputs a/b switched on/off.
+        #data going into Muxes is 1/2 the required width
+
+        pwidth = self.pwidth
+        width = self.width
+        bsa = Signal(self.twidth//2, reset_less=True)
+        bsb = Signal(self.twidth//2, reset_less=True)
+        asel = Signal(width, reset_less=True)
+        bsel = Signal(width, reset_less=True)
+        a_index, b_index = self.a_index, self.b_index
+        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
+        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
+        m.d.comb += self.ti.eq(bsa * bsb)
+        m.d.comb += self.term.eq(self.ti)
+        """
+
+        return m
+
+
+class ProductTerms(Elaboratable):
+    """ creates a bank of product terms.  also performs the actual bit-selection
+        this class is to be wrapped with a for-loop on the "a" operand.
+        it creates a second-level for-loop on the "b" operand.
+    """
+    def __init__(self, width, twidth, pbwid, a_index, blen):
+        self.a_index = a_index
+        self.blen = blen
+        self.pwidth = width
+        self.twidth = twidth
+        self.pbwid = pbwid
+        self.a = Signal(twidth//2, reset_less=True)
+        self.b = Signal(twidth//2, reset_less=True)
+        self.pb_en = Signal(pbwid, reset_less=True)
+        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
+                            for i in range(blen)]
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        for b_index in range(self.blen):
+            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
+                            self.a_index, b_index)
+            setattr(m.submodules, "term_%d" % b_index, t)
+
+            m.d.comb += t.a.eq(self.a)
+            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.pb_en.eq(self.pb_en)
+
+            m.d.comb += self.terms[b_index].eq(t.term)
+
+        return m
+
+
+class LSBNegTerm(Elaboratable):
+
+    def __init__(self, bit_width):
+        self.bit_width = bit_width
+        self.part = Signal(reset_less=True)
+        self.signed = Signal(reset_less=True)
+        self.op = Signal(bit_width, reset_less=True)
+        self.msb = Signal(reset_less=True)
+        self.nt = Signal(bit_width*2, reset_less=True)
+        self.nl = Signal(bit_width*2, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        bit_wid = self.bit_width
+        ext = Repl(0, bit_wid) # extend output to HI part
+
+        # determine sign of each incoming number *in this partition*
+        enabled = Signal(reset_less=True)
+        m.d.comb += enabled.eq(self.part & self.msb & self.signed)
+
+        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
+        # negation operation is split into a bitwise not and a +1.
+        # likewise for 16, 32, and 64-bit values.
+
+        # width-extended 1s complement if a is signed, otherwise zero
+        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
+
+        # add 1 if signed, otherwise add zero
+        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
+
+        return m
+
+
+class Parts(Elaboratable):
+
+    def __init__(self, pbwid, epps, n_parts):
+        self.pbwid = pbwid
+        # inputs
+        self.epps = PartitionPoints.like(epps, name="epps") # expanded points
+        # outputs
+        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
+
+    def elaborate(self, platform):
+        m = Module()
+
+        epps, parts = self.epps, self.parts
+        # collect part-bytes (double factor because the input is extended)
+        pbs = Signal(self.pbwid, reset_less=True)
+        tl = []
+        for i in range(self.pbwid):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
+
+        # negated-temporary copy of partition bits
+        npbs = Signal.like(pbs, reset_less=True)
+        m.d.comb += npbs.eq(~pbs)
+        byte_count = 8 // len(parts)
+        for i in range(len(parts)):
+            pbl = []
+            pbl.append(npbs[i * byte_count - 1])
+            for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                pbl.append(pbs[j])
+            pbl.append(npbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
+            m.d.comb += value.eq(Cat(*pbl))
+            m.d.comb += parts[i].eq(~(value).bool())
+
+        return m
+
+
+class Part(Elaboratable):
+    """ a key class which, depending on the partitioning, will determine
+        what action to take when parts of the output are signed or unsigned.
+
+        this requires 2 pieces of data *per operand, per partition*:
+        whether the MSB is HI/LO (per partition!), and whether a signed
+        or unsigned operation has been *requested*.
+
+        once that is determined, signed is basically carried out
+        by splitting 2's complement into 1's complement plus one.
+        1's complement is just a bit-inversion.
+
+        the extra terms - as separate terms - are then thrown at the
+        AddReduce alongside the multiplication part-results.
+    """
+    def __init__(self, width, n_parts, n_levels, pbwid):
+
+        # inputs
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
+        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+        self.pbs = Signal(pbwid, reset_less=True)
+
+        # outputs
+        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
+        self.delayed_parts = [
+            [Signal(name=f"delayed_part_{delay}_{i}")
+             for i in range(n_parts)]
+                for delay in range(n_levels)]
+        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
+        self.dplast = [Signal(name=f"dplast_{i}")
+                         for i in range(n_parts)]
+
+        self.not_a_term = Signal(width)
+        self.neg_lsb_a_term = Signal(width)
+        self.not_b_term = Signal(width)
+        self.neg_lsb_b_term = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
+        # negated-temporary copy of partition bits
+        npbs = Signal.like(pbs, reset_less=True)
+        m.d.comb += npbs.eq(~pbs)
+        byte_count = 8 // len(parts)
+        for i in range(len(parts)):
+            pbl = []
+            pbl.append(npbs[i * byte_count - 1])
+            for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                pbl.append(pbs[j])
+            pbl.append(npbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
+            m.d.comb += value.eq(Cat(*pbl))
+            m.d.comb += parts[i].eq(~(value).bool())
+            m.d.comb += delayed_parts[0][i].eq(parts[i])
+            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
+                         for j in range(len(delayed_parts)-1)]
+            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
+
+        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
+                self.not_a_term, self.neg_lsb_a_term, \
+                self.not_b_term, self.neg_lsb_b_term
+
+        byte_width = 8 // len(parts) # byte width
+        bit_wid = 8 * byte_width     # bit width
+        nat, nbt, nla, nlb = [], [], [], []
+        for i in range(len(parts)):
+            # work out bit-inverted and +1 term for a.
+            pa = LSBNegTerm(bit_wid)
+            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
+            m.d.comb += pa.part.eq(parts[i])
+            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
+            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
+            nat.append(pa.nt)
+            nla.append(pa.nl)
+
+            # work out bit-inverted and +1 term for b
+            pb = LSBNegTerm(bit_wid)
+            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
+            m.d.comb += pb.part.eq(parts[i])
+            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
+            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
+            nbt.append(pb.nt)
+            nlb.append(pb.nl)
+
+        # concatenate together and return all 4 results.
+        m.d.comb += [not_a_term.eq(Cat(*nat)),
+                     not_b_term.eq(Cat(*nbt)),
+                     neg_lsb_a_term.eq(Cat(*nla)),
+                     neg_lsb_b_term.eq(Cat(*nlb)),
+                    ]
+
+        return m
+
+
+class IntermediateOut(Elaboratable):
+    """ selects the HI/LO part of the multiplication, for a given bit-width
+        the output is also reconstructed in its SIMD (partition) lanes.
+    """
+    def __init__(self, width, out_wid, n_parts):
+        self.width = width
+        self.n_parts = n_parts
+        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
+                                     for i in range(8)]
+        self.intermed = Signal(out_wid, reset_less=True)
+        self.output = Signal(out_wid//2, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        ol = []
+        w = self.width
+        sel = w // 8
+        for i in range(self.n_parts):
+            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
+            m.d.comb += op.eq(
+                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
+                    self.intermed.part(i * w*2, w),
+                    self.intermed.part(i * w*2 + w, w)))
+            ol.append(op)
+        m.d.comb += self.output.eq(Cat(*ol))
+
+        return m
+
+
+class FinalOut(Elaboratable):
+    """ selects the final output based on the partitioning.
+
+        each byte is selectable independently, i.e. it is possible
+        that some partitions requested 8-bit computation whilst others
+        requested 16 or 32 bit.
+    """
+    def __init__(self, out_wid):
+        # inputs
+        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+        self.i8 = Signal(out_wid, reset_less=True)
+        self.i16 = Signal(out_wid, reset_less=True)
+        self.i32 = Signal(out_wid, reset_less=True)
+        self.i64 = Signal(out_wid, reset_less=True)
+
+        # output
+        self.out = Signal(out_wid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        ol = []
+        for i in range(8):
+            # select one of the outputs: d8 selects i8, d16 selects i16
+            # d32 selects i32, and the default is i64.
+            # d8 and d16 are ORed together in the first Mux
+            # then the 2nd selects either i8 or i16.
+            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
+            op = Signal(8, reset_less=True, name="op_%d" % i)
+            m.d.comb += op.eq(
+                Mux(self.d8[i] | self.d16[i // 2],
+                    Mux(self.d8[i], self.i8.part(i * 8, 8),
+                                     self.i16.part(i * 8, 8)),
+                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
+                                          self.i64.part(i * 8, 8))))
+            ol.append(op)
+        m.d.comb += self.out.eq(Cat(*ol))
+        return m
+
+
+class OrMod(Elaboratable):
+    """ ORs four values together in a hierarchical tree
+    """
+    def __init__(self, wid):
+        self.wid = wid
+        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
+                     for i in range(4)]
+        self.orout = Signal(wid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        or1 = Signal(self.wid, reset_less=True)
+        or2 = Signal(self.wid, reset_less=True)
+        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
+        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
+        m.d.comb += self.orout.eq(or1 | or2)
+
+        return m
+
+
+class Signs(Elaboratable):
+    """ determines whether a or b are signed numbers
+        based on the required operation type (OP_MUL_*)
+    """
+
+    def __init__(self):
+        self.part_ops = Signal(2, reset_less=True)
+        self.a_signed = Signal(reset_less=True)
+        self.b_signed = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
+        bsig = (self.part_ops == OP_MUL_LOW) \
+                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
+        m.d.comb += self.a_signed.eq(asig)
+        m.d.comb += self.b_signed.eq(bsig)
+
+        return m
+
+
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
 
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
 
@@ -394,270 +989,148 @@ class Mul8_16_32_64(Elaboratable):
             instruction.
     """
 
             instruction.
     """
 
-    def __init__(self, register_levels= ()):
+    def __init__(self, register_levels=()):
+        """ register_levels: specifies the points in the cascade at which
+            flip-flops are to be inserted.
+        """
+
+        # parameter(s)
+        self.register_levels = list(register_levels)
+
+        # inputs
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.a = Signal(64)
         self.b = Signal(64)
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.a = Signal(64)
         self.b = Signal(64)
-        self.output = Signal(64)
-        self.register_levels = list(register_levels)
+
+        # intermediates (needed for unit tests)
         self._intermediate_output = Signal(128)
         self._intermediate_output = Signal(128)
-        self._delayed_part_ops = [
-            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)]
-        self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)]
-        self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)]
-        self._part_64 = [Signal(name=f"_part_64")]
-        self._delayed_part_8 = [
-            [Signal(name=f"_delayed_part_8_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_16 = [
-            [Signal(name=f"_delayed_part_16_{delay}_{i}")
-             for i in range(4)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_32 = [
-            [Signal(name=f"_delayed_part_32_{delay}_{i}")
-             for i in range(2)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_64 = [
-            [Signal(name=f"_delayed_part_64_{delay}")]
-            for delay in range(1 + len(self.register_levels))]
-        self._output_64 = Signal(64)
-        self._output_32 = Signal(64)
-        self._output_16 = Signal(64)
-        self._output_8 = Signal(64)
-        self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
-        self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
-        self._not_a_term_8 = Signal(128)
-        self._neg_lsb_a_term_8 = Signal(128)
-        self._not_b_term_8 = Signal(128)
-        self._neg_lsb_b_term_8 = Signal(128)
-        self._not_a_term_16 = Signal(128)
-        self._neg_lsb_a_term_16 = Signal(128)
-        self._not_b_term_16 = Signal(128)
-        self._neg_lsb_b_term_16 = Signal(128)
-        self._not_a_term_32 = Signal(128)
-        self._neg_lsb_a_term_32 = Signal(128)
-        self._not_b_term_32 = Signal(128)
-        self._neg_lsb_b_term_32 = Signal(128)
-        self._not_a_term_64 = Signal(128)
-        self._neg_lsb_a_term_64 = Signal(128)
-        self._not_b_term_64 = Signal(128)
-        self._neg_lsb_b_term_64 = Signal(128)
-
-    def _part_byte(self, index):
-        if index == -1 or index == 7:
-            return C(True, 1)
-        assert index >= 0 and index < 8
-        return self.part_pts[index * 8 + 8]
+
+        # output
+        self.output = Signal(64)
 
     def elaborate(self, platform):
         m = Module()
 
 
     def elaborate(self, platform):
         m = Module()
 
-        for i in range(len(self.part_ops)):
-            m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i])
-            m.d.sync += [self._delayed_part_ops[j + 1][i]
-                         .eq(self._delayed_part_ops[j][i])
-                         for j in range(len(self.register_levels))]
-
-        def add_intermediate_value(value):
-            intermediate_value = Signal(len(value), reset_less=True)
-            m.d.comb += intermediate_value.eq(value)
-            return intermediate_value
-
-        for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
-                                     (self._part_32, self._delayed_part_32),
-                                     (self._part_16, self._delayed_part_16),
-                                     (self._part_8, self._delayed_part_8)]:
-            byte_count = 8 // len(parts)
-            for i in range(len(parts)):
-                pb = self._part_byte(i * byte_count - 1)
-                value = add_intermediate_value(pb)
-                for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                    pb = add_intermediate_value(~self._part_byte(j))
-                    value = add_intermediate_value(value & pb)
-                pb = self._part_byte((i + 1) * byte_count - 1)
-                value = add_intermediate_value(value & pb)
-                m.d.comb += parts[i].eq(value)
-                m.d.comb += delayed_parts[0][i].eq(parts[i])
-                m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
-                             for j in range(len(self.register_levels))]
-
-        products = [[
-                Signal(16, name=f"products_{i}_{j}")
-                for j in range(8)]
-            for i in range(8)]
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(self.part_pts.part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
 
 
-        for a_index in range(8):
-            for b_index in range(8):
-                a = self.a.part(a_index * 8, 8)
-                b = self.b.part(b_index * 8, 8)
-                m.d.comb += products[a_index][b_index].eq(a * b)
+        # create (doubled) PartitionPoints (output is double input width)
+        expanded_part_pts = PartitionPoints()
+        for i, v in self.part_pts.items():
+            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
+            expanded_part_pts[i * 2] = ep
+            m.d.comb += ep.eq(v)
 
 
-        terms = []
+        # local variables
+        signs = []
+        for i in range(8):
+            s = Signs()
+            signs.append(s)
+            setattr(m.submodules, "signs%d" % i, s)
+            m.d.comb += s.part_ops.eq(self.part_ops[i])
+
+        n_levels = len(self.register_levels)+1
+        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
+        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
+        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
+        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
+        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+        for mod in [part_8, part_16, part_32, part_64]:
+            m.d.comb += mod.a.eq(self.a)
+            m.d.comb += mod.b.eq(self.b)
+            for i in range(len(signs)):
+                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
+                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
+            m.d.comb += mod.pbs.eq(pbs)
+            nat_l.append(mod.not_a_term)
+            nbt_l.append(mod.not_b_term)
+            nla_l.append(mod.neg_lsb_a_term)
+            nlb_l.append(mod.neg_lsb_b_term)
 
 
-        def add_term(value, shift=0, enabled=None):
-            term = Signal(128)
-            terms.append(term)
-            if enabled is not None:
-                value = Mux(enabled, value, 0)
-            if shift > 0:
-                value = Cat(Repl(C(0, 1), shift), value)
-            else:
-                assert shift == 0
-            m.d.comb += term.eq(value)
+        terms = []
 
         for a_index in range(8):
 
         for a_index in range(8):
-            for b_index in range(8):
-                term_enabled: Value = C(True, 1)
-                min_index = min(a_index, b_index)
-                max_index = max(a_index, b_index)
-                for i in range(min_index, max_index):
-                    term_enabled &= ~self._part_byte(i)
-                add_term(products[a_index][b_index],
-                         8 * (a_index + b_index),
-                         term_enabled)
+            t = ProductTerms(8, 128, 8, a_index, 8)
+            setattr(m.submodules, "terms_%d" % a_index, t)
 
 
-        for i in range(8):
-            a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
-            b_signed = (self.part_ops[i] == OP_MUL_LOW) \
-                | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
-            m.d.comb += self._a_signed[i].eq(a_signed)
-            m.d.comb += self._b_signed[i].eq(b_signed)
+            m.d.comb += t.a.eq(self.a)
+            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.pb_en.eq(pbs)
 
 
-        # it's fine to bitwise-or these together since they are never enabled
-        # at the same time
-        add_term(self._not_a_term_8 | self._not_a_term_16
-                 | self._not_a_term_32 | self._not_a_term_64)
-        add_term(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16
-                 | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64)
-        add_term(self._not_b_term_8 | self._not_b_term_16
-                 | self._not_b_term_32 | self._not_b_term_64)
-        add_term(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16
-                 | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64)
-
-        for not_a_term, \
-            neg_lsb_a_term, \
-            not_b_term, \
-            neg_lsb_b_term, \
-            parts in [
-                (self._not_a_term_8,
-                 self._neg_lsb_a_term_8,
-                 self._not_b_term_8,
-                 self._neg_lsb_b_term_8,
-                 self._part_8),
-                (self._not_a_term_16,
-                 self._neg_lsb_a_term_16,
-                 self._not_b_term_16,
-                 self._neg_lsb_b_term_16,
-                 self._part_16),
-                (self._not_a_term_32,
-                 self._neg_lsb_a_term_32,
-                 self._not_b_term_32,
-                 self._neg_lsb_b_term_32,
-                 self._part_32),
-                (self._not_a_term_64,
-                 self._neg_lsb_a_term_64,
-                 self._not_b_term_64,
-                 self._neg_lsb_b_term_64,
-                 self._part_64),
-                ]:
-            byte_width = 8 // len(parts)
-            bit_width = 8 * byte_width
-            for i in range(len(parts)):
-                ae = parts[i] & self.a[(i + 1) * bit_width - 1] \
-                    & self._a_signed[i * byte_width]
-                be = parts[i] & self.b[(i + 1) * bit_width - 1] \
-                    & self._b_signed[i * byte_width]
-                a_enabled = Signal(name="a_enabled_%d" % i, reset_less=True)
-                b_enabled = Signal(name="b_enabled_%d" % i, reset_less=True)
-                m.d.comb += a_enabled.eq(ae)
-                m.d.comb += b_enabled.eq(be)
-
-                # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
-                # negation operation is split into a bitwise not and a +1.
-                # likewise for 16, 32, and 64-bit values.
-                m.d.comb += [
-                    not_a_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(a_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.a.part(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_a_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), a_enabled)),
-
-                    not_b_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(b_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.b.part(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_b_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), b_enabled))]
+            for term in t.terms:
+                terms.append(term)
 
 
-        expanded_part_pts = PartitionPoints()
-        for i, v in self.part_pts.items():
-            signal = Signal(name=f"expanded_part_pts_{i*2}")
-            expanded_part_pts[i * 2] = signal
-            m.d.comb += signal.eq(v)
+        # it's fine to bitwise-or data together since they are never enabled
+        # at the same time
+        m.submodules.nat_or = nat_or = OrMod(128)
+        m.submodules.nbt_or = nbt_or = OrMod(128)
+        m.submodules.nla_or = nla_or = OrMod(128)
+        m.submodules.nlb_or = nlb_or = OrMod(128)
+        for l, mod in [(nat_l, nat_or),
+                             (nbt_l, nbt_or),
+                             (nla_l, nla_or),
+                             (nlb_l, nlb_or)]:
+            for i in range(len(l)):
+                m.d.comb += mod.orin[i].eq(l[i])
+            terms.append(mod.orout)
 
         add_reduce = AddReduce(terms,
                                128,
                                self.register_levels,
 
         add_reduce = AddReduce(terms,
                                128,
                                self.register_levels,
-                               expanded_part_pts)
+                               expanded_part_pts,
+                               self.part_ops)
+
+        out_part_ops = add_reduce.levels[-1].out_part_ops
+
         m.submodules.add_reduce = add_reduce
         m.d.comb += self._intermediate_output.eq(add_reduce.output)
         m.submodules.add_reduce = add_reduce
         m.d.comb += self._intermediate_output.eq(add_reduce.output)
-        m.d.comb += self._output_64.eq(
-            Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW,
-                self._intermediate_output.part(0, 64),
-                self._intermediate_output.part(64, 64)))
+        # create _output_64
+        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
+        m.d.comb += io64.intermed.eq(self._intermediate_output)
+        for i in range(8):
+            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
 
         # create _output_32
 
         # create _output_32
-        ol = []
-        for i in range(2):
-            ol.append(
-                Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 64, 32),
-                    self._intermediate_output.part(i * 64 + 32, 32)))
-        m.d.comb += self._output_32.eq(Cat(*ol))
+        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
+        m.d.comb += io32.intermed.eq(self._intermediate_output)
+        for i in range(8):
+            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
 
         # create _output_16
 
         # create _output_16
-        ol = []
-        for i in range(4):
-            ol.append(
-                Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 32, 16),
-                    self._intermediate_output.part(i * 32 + 16, 16)))
-        m.d.comb += self._output_16.eq(Cat(*ol))
+        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
+        m.d.comb += io16.intermed.eq(self._intermediate_output)
+        for i in range(8):
+            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
 
         # create _output_8
 
         # create _output_8
-        ol = []
+        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
+        m.d.comb += io8.intermed.eq(self._intermediate_output)
         for i in range(8):
         for i in range(8):
-            ol.append(
-                Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 16, 8),
-                    self._intermediate_output.part(i * 16 + 8, 8)))
-        m.d.comb += self._output_8.eq(Cat(*ol))
+            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
 
         # final output
 
         # final output
-        ol = []
-        for i in range(8):
-            ol.append(
-                Mux(self._delayed_part_8[-1][i]
-                    | self._delayed_part_16[-1][i // 2],
-                    Mux(self._delayed_part_8[-1][i],
-                        self._output_8.part(i * 8, 8),
-                        self._output_16.part(i * 8, 8)),
-                    Mux(self._delayed_part_32[-1][i // 4],
-                        self._output_32.part(i * 8, 8),
-                        self._output_64.part(i * 8, 8))))
-        m.d.comb += self.output.eq(Cat(*ol))
+        m.submodules.finalout = finalout = FinalOut(64)
+        for i in range(len(part_8.delayed_parts[-1])):
+            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
+        for i in range(len(part_16.delayed_parts[-1])):
+            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
+        for i in range(len(part_32.delayed_parts[-1])):
+            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
+        m.d.comb += finalout.i8.eq(io8.output)
+        m.d.comb += finalout.i16.eq(io16.output)
+        m.d.comb += finalout.i32.eq(io32.output)
+        m.d.comb += finalout.i64.eq(io64.output)
+        m.d.comb += self.output.eq(finalout.out)
+
         return m
 
 
         return m