run tests in parallel
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 7202c9cf653d99c30ded41442eab875dcfd12532..c9239fa6cccf970c34cdf2cf2c71339a4d25d8c0 100644 (file)
@@ -11,296 +11,8 @@ from operator import or_
 from ieee754.pipeline import PipelineSpec
 from nmutil.pipemodbase import PipeModBase
 
-
-class PartitionPoints(dict):
-    """Partition points and corresponding ``Value``s.
-
-    The points at where an ALU is partitioned along with ``Value``s that
-    specify if the corresponding partition points are enabled.
-
-    For example: ``{1: True, 5: True, 10: True}`` with
-    ``width == 16`` specifies that the ALU is split into 4 sections:
-    * bits 0 <= ``i`` < 1
-    * bits 1 <= ``i`` < 5
-    * bits 5 <= ``i`` < 10
-    * bits 10 <= ``i`` < 16
-
-    If the partition_points were instead ``{1: True, 5: a, 10: True}``
-    where ``a`` is a 1-bit ``Signal``:
-    * If ``a`` is asserted:
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 5
-        * bits 5 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    * Otherwise
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    """
-
-    def __init__(self, partition_points=None):
-        """Create a new ``PartitionPoints``.
-
-        :param partition_points: the input partition points to values mapping.
-        """
-        super().__init__()
-        if partition_points is not None:
-            for point, enabled in partition_points.items():
-                if not isinstance(point, int):
-                    raise TypeError("point must be a non-negative integer")
-                if point < 0:
-                    raise ValueError("point must be a non-negative integer")
-                self[point] = Value.wrap(enabled)
-
-    def like(self, name=None, src_loc_at=0, mul=1):
-        """Create a new ``PartitionPoints`` with ``Signal``s for all values.
-
-        :param name: the base name for the new ``Signal``s.
-        :param mul: a multiplication factor on the indices
-        """
-        if name is None:
-            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
-        retval = PartitionPoints()
-        for point, enabled in self.items():
-            point *= mul
-            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
-        return retval
-
-    def eq(self, rhs):
-        """Assign ``PartitionPoints`` using ``Signal.eq``."""
-        if set(self.keys()) != set(rhs.keys()):
-            raise ValueError("incompatible point set")
-        for point, enabled in self.items():
-            yield enabled.eq(rhs[point])
-
-    def as_mask(self, width, mul=1):
-        """Create a bit-mask from `self`.
-
-        Each bit in the returned mask is clear only if the partition point at
-        the same bit-index is enabled.
-
-        :param width: the bit width of the resulting mask
-        :param mul: a "multiplier" which in-place expands the partition points
-                    typically set to "2" when used for multipliers
-        """
-        bits = []
-        for i in range(width):
-            i /= mul
-            if i.is_integer() and int(i) in self:
-                bits.append(~self[i])
-            else:
-                bits.append(True)
-        return Cat(*bits)
-
-    def get_max_partition_count(self, width):
-        """Get the maximum number of partitions.
-
-        Gets the number of partitions when all partition points are enabled.
-        """
-        retval = 1
-        for point in self.keys():
-            if point < width:
-                retval += 1
-        return retval
-
-    def fits_in_width(self, width):
-        """Check if all partition points are smaller than `width`."""
-        for point in self.keys():
-            if point >= width:
-                return False
-        return True
-
-    def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
-        if index == -1 or index == 7:
-            return C(True, 1)
-        assert index >= 0 and index < 8
-        return self[(index * 8 + 8)*mfactor]
-
-
-class FullAdder(Elaboratable):
-    """Full Adder.
-
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute carry: the carry output
-
-    Rather than do individual full adders (and have an array of them,
-    which would be very slow to simulate), this module can specify the
-    bit width of the inputs and outputs: in effect it performs multiple
-    Full 3-2 Add operations "in parallel".
-    """
-
-    def __init__(self, width):
-        """Create a ``FullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-        self.carry = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += self.carry.eq((self.in0 & self.in1)
-                                  | (self.in1 & self.in2)
-                                  | (self.in2 & self.in0))
-        return m
-
-
-class MaskedFullAdder(Elaboratable):
-    """Masked Full Adder.
-
-    :attribute mask: the carry partition mask
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute mcarry: the masked carry output
-
-    FullAdders are always used with a "mask" on the output.  To keep
-    the graphviz "clean", this class performs the masking here rather
-    than inside a large for-loop.
-
-    See the following discussion as to why this is no longer derived
-    from FullAdder.  Each carry is shifted here *before* being ANDed
-    with the mask, so that an AOI cell may be used (which is more
-    gate-efficient)
-    https://en.wikipedia.org/wiki/AND-OR-Invert
-    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
-    """
-
-    def __init__(self, width):
-        """Create a ``MaskedFullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.width = width
-        self.mask = Signal(width, reset_less=True)
-        self.mcarry = Signal(width, reset_less=True)
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        s1 = Signal(self.width, reset_less=True)
-        s2 = Signal(self.width, reset_less=True)
-        s3 = Signal(self.width, reset_less=True)
-        c1 = Signal(self.width, reset_less=True)
-        c2 = Signal(self.width, reset_less=True)
-        c3 = Signal(self.width, reset_less=True)
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += s1.eq(Cat(0, self.in0))
-        m.d.comb += s2.eq(Cat(0, self.in1))
-        m.d.comb += s3.eq(Cat(0, self.in2))
-        m.d.comb += c1.eq(s1 & s2 & self.mask)
-        m.d.comb += c2.eq(s2 & s3 & self.mask)
-        m.d.comb += c3.eq(s3 & s1 & self.mask)
-        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
-        return m
-
-
-class PartitionedAdder(Elaboratable):
-    """Partitioned Adder.
-
-    Performs the final add.  The partition points are included in the
-    actual add (in one of the operands only), which causes a carry over
-    to the next bit.  Then the final output *removes* the extra bits from
-    the result.
-
-    partition: .... P... P... P... P... (32 bits)
-    a        : .... .... .... .... .... (32 bits)
-    b        : .... .... .... .... .... (32 bits)
-    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
-    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
-    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
-    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
-
-    :attribute width: the bit width of the input and output. Read-only.
-    :attribute a: the first input to the adder
-    :attribute b: the second input to the adder
-    :attribute output: the sum output
-    :attribute partition_points: the input partition points. Modification not
-        supported, except for by ``Signal.eq``.
-    """
-
-    def __init__(self, width, partition_points, partition_step=1):
-        """Create a ``PartitionedAdder``.
-
-        :param width: the bit width of the input and output
-        :param partition_points: the input partition points
-        :param partition_step: a multiplier (typically double) step
-                               which in-place "expands" the partition points
-        """
-        self.width = width
-        self.pmul = partition_step
-        self.a = Signal(width, reset_less=True)
-        self.b = Signal(width, reset_less=True)
-        self.output = Signal(width, reset_less=True)
-        self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(width):
-            raise ValueError("partition_points doesn't fit in width")
-        expanded_width = 0
-        for i in range(self.width):
-            if i in self.partition_points:
-                expanded_width += 1
-            expanded_width += 1
-        self._expanded_width = expanded_width
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        expanded_a = Signal(self._expanded_width, reset_less=True)
-        expanded_b = Signal(self._expanded_width, reset_less=True)
-        expanded_o = Signal(self._expanded_width, reset_less=True)
-
-        expanded_index = 0
-        # store bits in a list, use Cat later.  graphviz is much cleaner
-        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
-
-        # partition points are "breaks" (extra zeros or 1s) in what would
-        # otherwise be a massive long add.  when the "break" points are 0,
-        # whatever is in it (in the output) is discarded.  however when
-        # there is a "1", it causes a roll-over carry to the *next* bit.
-        # we still ignore the "break" bit in the [intermediate] output,
-        # however by that time we've got the effect that we wanted: the
-        # carry has been carried *over* the break point.
-
-        for i in range(self.width):
-            pi = i/self.pmul # double the range of the partition point test
-            if pi.is_integer() and pi in self.partition_points:
-                # add extra bit set to 0 + 0 for enabled partition points
-                # and 1 + 0 for disabled partition points
-                ea.append(expanded_a[expanded_index])
-                al.append(~self.partition_points[pi]) # add extra bit in a
-                eb.append(expanded_b[expanded_index])
-                bl.append(C(0)) # yes, add a zero
-                expanded_index += 1 # skip the extra point.  NOT in the output
-            ea.append(expanded_a[expanded_index])
-            eb.append(expanded_b[expanded_index])
-            eo.append(expanded_o[expanded_index])
-            al.append(self.a[i])
-            bl.append(self.b[i])
-            ol.append(self.output[i])
-            expanded_index += 1
-
-        # combine above using Cat
-        m.d.comb += Cat(*ea).eq(Cat(*al))
-        m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*ol).eq(Cat(*eo))
-
-        # use only one addition to take advantage of look-ahead carry and
-        # special hardware on FPGAs
-        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
-        return m
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
 
 
 FULL_ADDER_INPUT_COUNT = 3
@@ -310,7 +22,7 @@ class AddReduceData:
     def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                           for i in range(n_parts)]
-        self.terms = [Signal(output_width, name=f"inputs_{i}",
+        self.terms = [Signal(output_width, name=f"terms_{i}",
                               reset_less=True)
                         for i in range(n_inputs)]
         self.part_pts = part_pts.like()
@@ -344,23 +56,22 @@ class FinalReduceData:
         return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
     """ Final stage of add reduce
     """
 
-    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
                        partition_step=1):
         self.lidx = lidx
         self.partition_step = partition_step
-        self.output_width = output_width
+        self.output_width = pspec.width * 2
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
+        self.n_parts = pspec.n_parts
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "finaladd")
 
     def ispec(self):
         return AddReduceData(self.partition_points, self.n_inputs,
@@ -370,13 +81,6 @@ class FinalAdd(Elaboratable):
         return FinalReduceData(self.partition_points,
                                  self.output_width, self.n_parts)
 
-    def setup(self, m, i):
-        m.submodules.finaladd = self
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -406,7 +110,7 @@ class FinalAdd(Elaboratable):
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -418,7 +122,7 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
                        partition_step=1):
         """Create an ``AddReduce``.
 
@@ -429,17 +133,16 @@ class AddReduceSingle(Elaboratable):
         self.lidx = lidx
         self.partition_step = partition_step
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
 
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "addreduce_%d" % lidx)
 
     def ispec(self):
         return AddReduceData(self.partition_points, self.n_inputs,
@@ -449,13 +152,6 @@ class AddReduceSingle(Elaboratable):
         return AddReduceData(self.partition_points, self.n_terms,
                              self.output_width, self.n_parts)
 
-    def setup(self, m, i):
-        setattr(m.submodules, "addreduce_%d" % self.lidx, self)
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
         retval = len(groups)*2
@@ -555,7 +251,7 @@ class AddReduceSingle(Elaboratable):
 
 
 class AddReduceInternal:
-    """Recursively Add list of numbers together.
+    """Iteratively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
         supported, except for by ``Signal.eq``.
@@ -566,18 +262,17 @@ class AddReduceInternal:
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, i, output_width, partition_step=1):
+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
-        self.i = i
-        self.inputs = i.terms
-        self.part_ops = i.part_ops
-        self.output_width = output_width
-        self.partition_points = i.part_pts
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
         self.partition_step = partition_step
 
         self.create_levels()
@@ -587,26 +282,21 @@ class AddReduceInternal:
 
         mods = []
         partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
+        ilen = self.n_inputs
         while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
+            groups = AddReduceSingle.full_adder_groups(ilen)
             if len(groups) == 0:
                 break
             lidx = len(mods)
-            next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts,
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
                                          partition_points,
                                          self.partition_step)
             mods.append(next_level)
             partition_points = next_level.i.part_pts
-            inputs = next_level.o.terms
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
+            ilen = len(next_level.o.terms)
 
         lidx = len(mods)
-        next_level = FinalAdd(lidx, ilen, self.output_width, n_parts,
+        next_level = FinalAdd(self.pspec, lidx, ilen,
                               partition_points, self.partition_step)
         mods.append(next_level)
 
@@ -641,7 +331,8 @@ class AddReduce(AddReduceInternal, Elaboratable):
         n_parts = len(part_ops)
         self.i = AddReduceData(part_pts, len(inputs),
                              output_width, n_parts)
-        AddReduceInternal.__init__(self, self.i, output_width, partition_step)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
         self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
 
@@ -993,21 +684,21 @@ class IntermediateOut(Elaboratable):
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, output_width, n_parts, part_pts):
+    def __init__(self, pspec, part_pts):
+
         self.part_pts = part_pts
-        self.output_width = output_width
-        self.n_parts = n_parts
-        self.out_wid = output_width//2
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
 
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "finalout")
 
     def ispec(self):
         return IntermediateData(self.part_pts, self.output_width, self.n_parts)
@@ -1015,13 +706,6 @@ class FinalOut(Elaboratable):
     def ospec(self):
         return OutputData()
 
-    def setup(self, m, i):
-        m.submodules.finalout = self
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     def elaborate(self, platform):
         m = Module()
 
@@ -1185,12 +869,12 @@ class AllTerms(PipeModBase):
     """Set of terms to be added together
     """
 
-    def __init__(self, pspec):
+    def __init__(self, pspec, n_inputs):
         """Create an ``AllTerms``.
         """
-        self.n_inputs = pspec.n_inputs
+        self.n_inputs = n_inputs
         self.n_parts = pspec.n_parts
-        self.output_width = pspec.width
+        self.output_width = pspec.width * 2
         super().__init__(pspec, "allterms")
 
     def ispec(self):
@@ -1278,17 +962,16 @@ class AllTerms(PipeModBase):
         return m
 
 
-class Intermediates(Elaboratable):
+class Intermediates(PipeModBase):
     """ Intermediate output modules
     """
 
-    def __init__(self, output_width, n_parts, part_pts):
+    def __init__(self, pspec, part_pts):
         self.part_pts = part_pts
-        self.output_width = output_width
-        self.n_parts = n_parts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
 
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "intermediates")
 
     def ispec(self):
         return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
@@ -1296,13 +979,6 @@ class Intermediates(Elaboratable):
     def ospec(self):
         return IntermediateData(self.part_pts, self.output_width, self.n_parts)
 
-    def setup(self, m, i):
-        m.submodules.intermediates = self
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     def elaborate(self, platform):
         m = Module()
 
@@ -1348,6 +1024,8 @@ class Intermediates(Elaboratable):
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
 
+    XXX NOTE: this class is intended for unit test purposes ONLY.
+
     Supports partitioning into any combination of 8, 16, 32, and 64-bit
     partitions on naturally-aligned boundaries. Supports the operation being
     set for each partition independently.
@@ -1380,8 +1058,7 @@ class Mul8_16_32_64(Elaboratable):
 
         self.id_wid = 0 # num_bits(num_rows)
         self.op_wid = 0
-        self.pspec = PipelineSpec(128, self.id_wid, self.op_wid, n_ops=3)
-        self.pspec.n_inputs = 64 + 4
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
         self.pspec.n_parts = 8
 
         # parameter(s)
@@ -1411,17 +1088,15 @@ class Mul8_16_32_64(Elaboratable):
 
         part_pts = self.part_pts
 
-        n_parts = self.pspec.n_parts
-        n_inputs = self.pspec.n_inputs
-        output_width = self.pspec.width
-        t = AllTerms(self.pspec)
+        n_inputs = 64 + 4
+        t = AllTerms(self.pspec, n_inputs)
         t.setup(m, self.i)
 
         terms = t.o.terms
 
-        at = AddReduceInternal(t.process(self.i), 128, partition_step=2)
+        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
 
-        i = at.i
+        i = t.o
         for idx in range(len(at.levels)):
             mcur = at.levels[idx]
             mcur.setup(m, i)
@@ -1432,12 +1107,12 @@ class Mul8_16_32_64(Elaboratable):
                 m.d.comb += o.eq(mcur.process(i))
             i = o # for next loop
 
-        interm = Intermediates(128, 8, part_pts)
+        interm = Intermediates(self.pspec, part_pts)
         interm.setup(m, i)
         o = interm.process(interm.i)
 
         # final output
-        finalout = FinalOut(128, 8, part_pts)
+        finalout = FinalOut(self.pspec, part_pts)
         finalout.setup(m, o)
         m.d.comb += self.o.eq(finalout.process(o))