split out adder code (PartitionedAdder) into module, PartitionPoints too
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 215d18c6a1aacce049dd74c6331437c2e74e5853..c9239fa6cccf970c34cdf2cf2c71339a4d25d8c0 100644 (file)
@@ -11,296 +11,8 @@ from operator import or_
 from ieee754.pipeline import PipelineSpec
 from nmutil.pipemodbase import PipeModBase
 
-
-class PartitionPoints(dict):
-    """Partition points and corresponding ``Value``s.
-
-    The points at where an ALU is partitioned along with ``Value``s that
-    specify if the corresponding partition points are enabled.
-
-    For example: ``{1: True, 5: True, 10: True}`` with
-    ``width == 16`` specifies that the ALU is split into 4 sections:
-    * bits 0 <= ``i`` < 1
-    * bits 1 <= ``i`` < 5
-    * bits 5 <= ``i`` < 10
-    * bits 10 <= ``i`` < 16
-
-    If the partition_points were instead ``{1: True, 5: a, 10: True}``
-    where ``a`` is a 1-bit ``Signal``:
-    * If ``a`` is asserted:
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 5
-        * bits 5 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    * Otherwise
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    """
-
-    def __init__(self, partition_points=None):
-        """Create a new ``PartitionPoints``.
-
-        :param partition_points: the input partition points to values mapping.
-        """
-        super().__init__()
-        if partition_points is not None:
-            for point, enabled in partition_points.items():
-                if not isinstance(point, int):
-                    raise TypeError("point must be a non-negative integer")
-                if point < 0:
-                    raise ValueError("point must be a non-negative integer")
-                self[point] = Value.wrap(enabled)
-
-    def like(self, name=None, src_loc_at=0, mul=1):
-        """Create a new ``PartitionPoints`` with ``Signal``s for all values.
-
-        :param name: the base name for the new ``Signal``s.
-        :param mul: a multiplication factor on the indices
-        """
-        if name is None:
-            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
-        retval = PartitionPoints()
-        for point, enabled in self.items():
-            point *= mul
-            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
-        return retval
-
-    def eq(self, rhs):
-        """Assign ``PartitionPoints`` using ``Signal.eq``."""
-        if set(self.keys()) != set(rhs.keys()):
-            raise ValueError("incompatible point set")
-        for point, enabled in self.items():
-            yield enabled.eq(rhs[point])
-
-    def as_mask(self, width, mul=1):
-        """Create a bit-mask from `self`.
-
-        Each bit in the returned mask is clear only if the partition point at
-        the same bit-index is enabled.
-
-        :param width: the bit width of the resulting mask
-        :param mul: a "multiplier" which in-place expands the partition points
-                    typically set to "2" when used for multipliers
-        """
-        bits = []
-        for i in range(width):
-            i /= mul
-            if i.is_integer() and int(i) in self:
-                bits.append(~self[i])
-            else:
-                bits.append(True)
-        return Cat(*bits)
-
-    def get_max_partition_count(self, width):
-        """Get the maximum number of partitions.
-
-        Gets the number of partitions when all partition points are enabled.
-        """
-        retval = 1
-        for point in self.keys():
-            if point < width:
-                retval += 1
-        return retval
-
-    def fits_in_width(self, width):
-        """Check if all partition points are smaller than `width`."""
-        for point in self.keys():
-            if point >= width:
-                return False
-        return True
-
-    def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
-        if index == -1 or index == 7:
-            return C(True, 1)
-        assert index >= 0 and index < 8
-        return self[(index * 8 + 8)*mfactor]
-
-
-class FullAdder(Elaboratable):
-    """Full Adder.
-
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute carry: the carry output
-
-    Rather than do individual full adders (and have an array of them,
-    which would be very slow to simulate), this module can specify the
-    bit width of the inputs and outputs: in effect it performs multiple
-    Full 3-2 Add operations "in parallel".
-    """
-
-    def __init__(self, width):
-        """Create a ``FullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-        self.carry = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += self.carry.eq((self.in0 & self.in1)
-                                  | (self.in1 & self.in2)
-                                  | (self.in2 & self.in0))
-        return m
-
-
-class MaskedFullAdder(Elaboratable):
-    """Masked Full Adder.
-
-    :attribute mask: the carry partition mask
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute mcarry: the masked carry output
-
-    FullAdders are always used with a "mask" on the output.  To keep
-    the graphviz "clean", this class performs the masking here rather
-    than inside a large for-loop.
-
-    See the following discussion as to why this is no longer derived
-    from FullAdder.  Each carry is shifted here *before* being ANDed
-    with the mask, so that an AOI cell may be used (which is more
-    gate-efficient)
-    https://en.wikipedia.org/wiki/AND-OR-Invert
-    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
-    """
-
-    def __init__(self, width):
-        """Create a ``MaskedFullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.width = width
-        self.mask = Signal(width, reset_less=True)
-        self.mcarry = Signal(width, reset_less=True)
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        s1 = Signal(self.width, reset_less=True)
-        s2 = Signal(self.width, reset_less=True)
-        s3 = Signal(self.width, reset_less=True)
-        c1 = Signal(self.width, reset_less=True)
-        c2 = Signal(self.width, reset_less=True)
-        c3 = Signal(self.width, reset_less=True)
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += s1.eq(Cat(0, self.in0))
-        m.d.comb += s2.eq(Cat(0, self.in1))
-        m.d.comb += s3.eq(Cat(0, self.in2))
-        m.d.comb += c1.eq(s1 & s2 & self.mask)
-        m.d.comb += c2.eq(s2 & s3 & self.mask)
-        m.d.comb += c3.eq(s3 & s1 & self.mask)
-        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
-        return m
-
-
-class PartitionedAdder(Elaboratable):
-    """Partitioned Adder.
-
-    Performs the final add.  The partition points are included in the
-    actual add (in one of the operands only), which causes a carry over
-    to the next bit.  Then the final output *removes* the extra bits from
-    the result.
-
-    partition: .... P... P... P... P... (32 bits)
-    a        : .... .... .... .... .... (32 bits)
-    b        : .... .... .... .... .... (32 bits)
-    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
-    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
-    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
-    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
-
-    :attribute width: the bit width of the input and output. Read-only.
-    :attribute a: the first input to the adder
-    :attribute b: the second input to the adder
-    :attribute output: the sum output
-    :attribute partition_points: the input partition points. Modification not
-        supported, except for by ``Signal.eq``.
-    """
-
-    def __init__(self, width, partition_points, partition_step=1):
-        """Create a ``PartitionedAdder``.
-
-        :param width: the bit width of the input and output
-        :param partition_points: the input partition points
-        :param partition_step: a multiplier (typically double) step
-                               which in-place "expands" the partition points
-        """
-        self.width = width
-        self.pmul = partition_step
-        self.a = Signal(width, reset_less=True)
-        self.b = Signal(width, reset_less=True)
-        self.output = Signal(width, reset_less=True)
-        self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(width):
-            raise ValueError("partition_points doesn't fit in width")
-        expanded_width = 0
-        for i in range(self.width):
-            if i in self.partition_points:
-                expanded_width += 1
-            expanded_width += 1
-        self._expanded_width = expanded_width
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        expanded_a = Signal(self._expanded_width, reset_less=True)
-        expanded_b = Signal(self._expanded_width, reset_less=True)
-        expanded_o = Signal(self._expanded_width, reset_less=True)
-
-        expanded_index = 0
-        # store bits in a list, use Cat later.  graphviz is much cleaner
-        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
-
-        # partition points are "breaks" (extra zeros or 1s) in what would
-        # otherwise be a massive long add.  when the "break" points are 0,
-        # whatever is in it (in the output) is discarded.  however when
-        # there is a "1", it causes a roll-over carry to the *next* bit.
-        # we still ignore the "break" bit in the [intermediate] output,
-        # however by that time we've got the effect that we wanted: the
-        # carry has been carried *over* the break point.
-
-        for i in range(self.width):
-            pi = i/self.pmul # double the range of the partition point test
-            if pi.is_integer() and pi in self.partition_points:
-                # add extra bit set to 0 + 0 for enabled partition points
-                # and 1 + 0 for disabled partition points
-                ea.append(expanded_a[expanded_index])
-                al.append(~self.partition_points[pi]) # add extra bit in a
-                eb.append(expanded_b[expanded_index])
-                bl.append(C(0)) # yes, add a zero
-                expanded_index += 1 # skip the extra point.  NOT in the output
-            ea.append(expanded_a[expanded_index])
-            eb.append(expanded_b[expanded_index])
-            eo.append(expanded_o[expanded_index])
-            al.append(self.a[i])
-            bl.append(self.b[i])
-            ol.append(self.output[i])
-            expanded_index += 1
-
-        # combine above using Cat
-        m.d.comb += Cat(*ea).eq(Cat(*al))
-        m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*ol).eq(Cat(*eo))
-
-        # use only one addition to take advantage of look-ahead carry and
-        # special hardware on FPGAs
-        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
-        return m
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
 
 
 FULL_ADDER_INPUT_COUNT = 3