From a60f0c986284e28b204c1057ba9a501353683449 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Mon, 6 Jan 2020 21:16:08 +0000 Subject: [PATCH] split out adder code (PartitionedAdder) into module, PartitionPoints too --- src/ieee754/part_mul_add/adder.py | 202 +++++++++++++++++ src/ieee754/part_mul_add/multiply.py | 292 +------------------------ src/ieee754/part_mul_add/partpoints.py | 112 ++++++++++ 3 files changed, 316 insertions(+), 290 deletions(-) create mode 100644 src/ieee754/part_mul_add/adder.py create mode 100644 src/ieee754/part_mul_add/partpoints.py diff --git a/src/ieee754/part_mul_add/adder.py b/src/ieee754/part_mul_add/adder.py new file mode 100644 index 00000000..0c28e6c4 --- /dev/null +++ b/src/ieee754/part_mul_add/adder.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information +"""Integer Multiplication.""" + +from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl +from nmigen.hdl.ast import Assign +from abc import ABCMeta, abstractmethod +from nmigen.cli import main +from functools import reduce +from operator import or_ +from ieee754.pipeline import PipelineSpec +from nmutil.pipemodbase import PipeModBase + +from ieee754.part_mul_add.partpoints import PartitionPoints + + +class FullAdder(Elaboratable): + """Full Adder. + + :attribute in0: the first input + :attribute in1: the second input + :attribute in2: the third input + :attribute sum: the sum output + :attribute carry: the carry output + + Rather than do individual full adders (and have an array of them, + which would be very slow to simulate), this module can specify the + bit width of the inputs and outputs: in effect it performs multiple + Full 3-2 Add operations "in parallel". + """ + + def __init__(self, width): + """Create a ``FullAdder``. + + :param width: the bit width of the input and output + """ + self.in0 = Signal(width, reset_less=True) + self.in1 = Signal(width, reset_less=True) + self.in2 = Signal(width, reset_less=True) + self.sum = Signal(width, reset_less=True) + self.carry = Signal(width, reset_less=True) + + def elaborate(self, platform): + """Elaborate this module.""" + m = Module() + m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) + m.d.comb += self.carry.eq((self.in0 & self.in1) + | (self.in1 & self.in2) + | (self.in2 & self.in0)) + return m + + +class MaskedFullAdder(Elaboratable): + """Masked Full Adder. + + :attribute mask: the carry partition mask + :attribute in0: the first input + :attribute in1: the second input + :attribute in2: the third input + :attribute sum: the sum output + :attribute mcarry: the masked carry output + + FullAdders are always used with a "mask" on the output. To keep + the graphviz "clean", this class performs the masking here rather + than inside a large for-loop. + + See the following discussion as to why this is no longer derived + from FullAdder. Each carry is shifted here *before* being ANDed + with the mask, so that an AOI cell may be used (which is more + gate-efficient) + https://en.wikipedia.org/wiki/AND-OR-Invert + https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ + """ + + def __init__(self, width): + """Create a ``MaskedFullAdder``. + + :param width: the bit width of the input and output + """ + self.width = width + self.mask = Signal(width, reset_less=True) + self.mcarry = Signal(width, reset_less=True) + self.in0 = Signal(width, reset_less=True) + self.in1 = Signal(width, reset_less=True) + self.in2 = Signal(width, reset_less=True) + self.sum = Signal(width, reset_less=True) + + def elaborate(self, platform): + """Elaborate this module.""" + m = Module() + s1 = Signal(self.width, reset_less=True) + s2 = Signal(self.width, reset_less=True) + s3 = Signal(self.width, reset_less=True) + c1 = Signal(self.width, reset_less=True) + c2 = Signal(self.width, reset_less=True) + c3 = Signal(self.width, reset_less=True) + m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) + m.d.comb += s1.eq(Cat(0, self.in0)) + m.d.comb += s2.eq(Cat(0, self.in1)) + m.d.comb += s3.eq(Cat(0, self.in2)) + m.d.comb += c1.eq(s1 & s2 & self.mask) + m.d.comb += c2.eq(s2 & s3 & self.mask) + m.d.comb += c3.eq(s3 & s1 & self.mask) + m.d.comb += self.mcarry.eq(c1 | c2 | c3) + return m + + +class PartitionedAdder(Elaboratable): + """Partitioned Adder. + + Performs the final add. The partition points are included in the + actual add (in one of the operands only), which causes a carry over + to the next bit. Then the final output *removes* the extra bits from + the result. + + partition: .... P... P... P... P... (32 bits) + a : .... .... .... .... .... (32 bits) + b : .... .... .... .... .... (32 bits) + exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition) + exp-b : ....0....0....0....0.... (32 bits plus 4 zeros) + exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded) + o : .... N... N... N... N... (32 bits - x ignored, N is carry-over) + + :attribute width: the bit width of the input and output. Read-only. + :attribute a: the first input to the adder + :attribute b: the second input to the adder + :attribute output: the sum output + :attribute partition_points: the input partition points. Modification not + supported, except for by ``Signal.eq``. + """ + + def __init__(self, width, partition_points, partition_step=1): + """Create a ``PartitionedAdder``. + + :param width: the bit width of the input and output + :param partition_points: the input partition points + :param partition_step: a multiplier (typically double) step + which in-place "expands" the partition points + """ + self.width = width + self.pmul = partition_step + self.a = Signal(width, reset_less=True) + self.b = Signal(width, reset_less=True) + self.output = Signal(width, reset_less=True) + self.partition_points = PartitionPoints(partition_points) + if not self.partition_points.fits_in_width(width): + raise ValueError("partition_points doesn't fit in width") + expanded_width = 0 + for i in range(self.width): + if i in self.partition_points: + expanded_width += 1 + expanded_width += 1 + self._expanded_width = expanded_width + + def elaborate(self, platform): + """Elaborate this module.""" + m = Module() + expanded_a = Signal(self._expanded_width, reset_less=True) + expanded_b = Signal(self._expanded_width, reset_less=True) + expanded_o = Signal(self._expanded_width, reset_less=True) + + expanded_index = 0 + # store bits in a list, use Cat later. graphviz is much cleaner + al, bl, ol, ea, eb, eo = [],[],[],[],[],[] + + # partition points are "breaks" (extra zeros or 1s) in what would + # otherwise be a massive long add. when the "break" points are 0, + # whatever is in it (in the output) is discarded. however when + # there is a "1", it causes a roll-over carry to the *next* bit. + # we still ignore the "break" bit in the [intermediate] output, + # however by that time we've got the effect that we wanted: the + # carry has been carried *over* the break point. + + for i in range(self.width): + pi = i/self.pmul # double the range of the partition point test + if pi.is_integer() and pi in self.partition_points: + # add extra bit set to 0 + 0 for enabled partition points + # and 1 + 0 for disabled partition points + ea.append(expanded_a[expanded_index]) + al.append(~self.partition_points[pi]) # add extra bit in a + eb.append(expanded_b[expanded_index]) + bl.append(C(0)) # yes, add a zero + expanded_index += 1 # skip the extra point. NOT in the output + ea.append(expanded_a[expanded_index]) + eb.append(expanded_b[expanded_index]) + eo.append(expanded_o[expanded_index]) + al.append(self.a[i]) + bl.append(self.b[i]) + ol.append(self.output[i]) + expanded_index += 1 + + # combine above using Cat + m.d.comb += Cat(*ea).eq(Cat(*al)) + m.d.comb += Cat(*eb).eq(Cat(*bl)) + m.d.comb += Cat(*ol).eq(Cat(*eo)) + + # use only one addition to take advantage of look-ahead carry and + # special hardware on FPGAs + m.d.comb += expanded_o.eq(expanded_a + expanded_b) + return m + + diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 215d18c6..c9239fa6 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -11,296 +11,8 @@ from operator import or_ from ieee754.pipeline import PipelineSpec from nmutil.pipemodbase import PipeModBase - -class PartitionPoints(dict): - """Partition points and corresponding ``Value``s. - - The points at where an ALU is partitioned along with ``Value``s that - specify if the corresponding partition points are enabled. - - For example: ``{1: True, 5: True, 10: True}`` with - ``width == 16`` specifies that the ALU is split into 4 sections: - * bits 0 <= ``i`` < 1 - * bits 1 <= ``i`` < 5 - * bits 5 <= ``i`` < 10 - * bits 10 <= ``i`` < 16 - - If the partition_points were instead ``{1: True, 5: a, 10: True}`` - where ``a`` is a 1-bit ``Signal``: - * If ``a`` is asserted: - * bits 0 <= ``i`` < 1 - * bits 1 <= ``i`` < 5 - * bits 5 <= ``i`` < 10 - * bits 10 <= ``i`` < 16 - * Otherwise - * bits 0 <= ``i`` < 1 - * bits 1 <= ``i`` < 10 - * bits 10 <= ``i`` < 16 - """ - - def __init__(self, partition_points=None): - """Create a new ``PartitionPoints``. - - :param partition_points: the input partition points to values mapping. - """ - super().__init__() - if partition_points is not None: - for point, enabled in partition_points.items(): - if not isinstance(point, int): - raise TypeError("point must be a non-negative integer") - if point < 0: - raise ValueError("point must be a non-negative integer") - self[point] = Value.wrap(enabled) - - def like(self, name=None, src_loc_at=0, mul=1): - """Create a new ``PartitionPoints`` with ``Signal``s for all values. - - :param name: the base name for the new ``Signal``s. - :param mul: a multiplication factor on the indices - """ - if name is None: - name = Signal(src_loc_at=1+src_loc_at).name # get variable name - retval = PartitionPoints() - for point, enabled in self.items(): - point *= mul - retval[point] = Signal(enabled.shape(), name=f"{name}_{point}") - return retval - - def eq(self, rhs): - """Assign ``PartitionPoints`` using ``Signal.eq``.""" - if set(self.keys()) != set(rhs.keys()): - raise ValueError("incompatible point set") - for point, enabled in self.items(): - yield enabled.eq(rhs[point]) - - def as_mask(self, width, mul=1): - """Create a bit-mask from `self`. - - Each bit in the returned mask is clear only if the partition point at - the same bit-index is enabled. - - :param width: the bit width of the resulting mask - :param mul: a "multiplier" which in-place expands the partition points - typically set to "2" when used for multipliers - """ - bits = [] - for i in range(width): - i /= mul - if i.is_integer() and int(i) in self: - bits.append(~self[i]) - else: - bits.append(True) - return Cat(*bits) - - def get_max_partition_count(self, width): - """Get the maximum number of partitions. - - Gets the number of partitions when all partition points are enabled. - """ - retval = 1 - for point in self.keys(): - if point < width: - retval += 1 - return retval - - def fits_in_width(self, width): - """Check if all partition points are smaller than `width`.""" - for point in self.keys(): - if point >= width: - return False - return True - - def part_byte(self, index, mfactor=1): # mfactor used for "expanding" - if index == -1 or index == 7: - return C(True, 1) - assert index >= 0 and index < 8 - return self[(index * 8 + 8)*mfactor] - - -class FullAdder(Elaboratable): - """Full Adder. - - :attribute in0: the first input - :attribute in1: the second input - :attribute in2: the third input - :attribute sum: the sum output - :attribute carry: the carry output - - Rather than do individual full adders (and have an array of them, - which would be very slow to simulate), this module can specify the - bit width of the inputs and outputs: in effect it performs multiple - Full 3-2 Add operations "in parallel". - """ - - def __init__(self, width): - """Create a ``FullAdder``. - - :param width: the bit width of the input and output - """ - self.in0 = Signal(width, reset_less=True) - self.in1 = Signal(width, reset_less=True) - self.in2 = Signal(width, reset_less=True) - self.sum = Signal(width, reset_less=True) - self.carry = Signal(width, reset_less=True) - - def elaborate(self, platform): - """Elaborate this module.""" - m = Module() - m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) - m.d.comb += self.carry.eq((self.in0 & self.in1) - | (self.in1 & self.in2) - | (self.in2 & self.in0)) - return m - - -class MaskedFullAdder(Elaboratable): - """Masked Full Adder. - - :attribute mask: the carry partition mask - :attribute in0: the first input - :attribute in1: the second input - :attribute in2: the third input - :attribute sum: the sum output - :attribute mcarry: the masked carry output - - FullAdders are always used with a "mask" on the output. To keep - the graphviz "clean", this class performs the masking here rather - than inside a large for-loop. - - See the following discussion as to why this is no longer derived - from FullAdder. Each carry is shifted here *before* being ANDed - with the mask, so that an AOI cell may be used (which is more - gate-efficient) - https://en.wikipedia.org/wiki/AND-OR-Invert - https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ - """ - - def __init__(self, width): - """Create a ``MaskedFullAdder``. - - :param width: the bit width of the input and output - """ - self.width = width - self.mask = Signal(width, reset_less=True) - self.mcarry = Signal(width, reset_less=True) - self.in0 = Signal(width, reset_less=True) - self.in1 = Signal(width, reset_less=True) - self.in2 = Signal(width, reset_less=True) - self.sum = Signal(width, reset_less=True) - - def elaborate(self, platform): - """Elaborate this module.""" - m = Module() - s1 = Signal(self.width, reset_less=True) - s2 = Signal(self.width, reset_less=True) - s3 = Signal(self.width, reset_less=True) - c1 = Signal(self.width, reset_less=True) - c2 = Signal(self.width, reset_less=True) - c3 = Signal(self.width, reset_less=True) - m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) - m.d.comb += s1.eq(Cat(0, self.in0)) - m.d.comb += s2.eq(Cat(0, self.in1)) - m.d.comb += s3.eq(Cat(0, self.in2)) - m.d.comb += c1.eq(s1 & s2 & self.mask) - m.d.comb += c2.eq(s2 & s3 & self.mask) - m.d.comb += c3.eq(s3 & s1 & self.mask) - m.d.comb += self.mcarry.eq(c1 | c2 | c3) - return m - - -class PartitionedAdder(Elaboratable): - """Partitioned Adder. - - Performs the final add. The partition points are included in the - actual add (in one of the operands only), which causes a carry over - to the next bit. Then the final output *removes* the extra bits from - the result. - - partition: .... P... P... P... P... (32 bits) - a : .... .... .... .... .... (32 bits) - b : .... .... .... .... .... (32 bits) - exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition) - exp-b : ....0....0....0....0.... (32 bits plus 4 zeros) - exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded) - o : .... N... N... N... N... (32 bits - x ignored, N is carry-over) - - :attribute width: the bit width of the input and output. Read-only. - :attribute a: the first input to the adder - :attribute b: the second input to the adder - :attribute output: the sum output - :attribute partition_points: the input partition points. Modification not - supported, except for by ``Signal.eq``. - """ - - def __init__(self, width, partition_points, partition_step=1): - """Create a ``PartitionedAdder``. - - :param width: the bit width of the input and output - :param partition_points: the input partition points - :param partition_step: a multiplier (typically double) step - which in-place "expands" the partition points - """ - self.width = width - self.pmul = partition_step - self.a = Signal(width, reset_less=True) - self.b = Signal(width, reset_less=True) - self.output = Signal(width, reset_less=True) - self.partition_points = PartitionPoints(partition_points) - if not self.partition_points.fits_in_width(width): - raise ValueError("partition_points doesn't fit in width") - expanded_width = 0 - for i in range(self.width): - if i in self.partition_points: - expanded_width += 1 - expanded_width += 1 - self._expanded_width = expanded_width - - def elaborate(self, platform): - """Elaborate this module.""" - m = Module() - expanded_a = Signal(self._expanded_width, reset_less=True) - expanded_b = Signal(self._expanded_width, reset_less=True) - expanded_o = Signal(self._expanded_width, reset_less=True) - - expanded_index = 0 - # store bits in a list, use Cat later. graphviz is much cleaner - al, bl, ol, ea, eb, eo = [],[],[],[],[],[] - - # partition points are "breaks" (extra zeros or 1s) in what would - # otherwise be a massive long add. when the "break" points are 0, - # whatever is in it (in the output) is discarded. however when - # there is a "1", it causes a roll-over carry to the *next* bit. - # we still ignore the "break" bit in the [intermediate] output, - # however by that time we've got the effect that we wanted: the - # carry has been carried *over* the break point. - - for i in range(self.width): - pi = i/self.pmul # double the range of the partition point test - if pi.is_integer() and pi in self.partition_points: - # add extra bit set to 0 + 0 for enabled partition points - # and 1 + 0 for disabled partition points - ea.append(expanded_a[expanded_index]) - al.append(~self.partition_points[pi]) # add extra bit in a - eb.append(expanded_b[expanded_index]) - bl.append(C(0)) # yes, add a zero - expanded_index += 1 # skip the extra point. NOT in the output - ea.append(expanded_a[expanded_index]) - eb.append(expanded_b[expanded_index]) - eo.append(expanded_o[expanded_index]) - al.append(self.a[i]) - bl.append(self.b[i]) - ol.append(self.output[i]) - expanded_index += 1 - - # combine above using Cat - m.d.comb += Cat(*ea).eq(Cat(*al)) - m.d.comb += Cat(*eb).eq(Cat(*bl)) - m.d.comb += Cat(*ol).eq(Cat(*eo)) - - # use only one addition to take advantage of look-ahead carry and - # special hardware on FPGAs - m.d.comb += expanded_o.eq(expanded_a + expanded_b) - return m +from ieee754.part_mul_add.partpoints import PartitionPoints +from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder FULL_ADDER_INPUT_COUNT = 3 diff --git a/src/ieee754/part_mul_add/partpoints.py b/src/ieee754/part_mul_add/partpoints.py new file mode 100644 index 00000000..0ea3668f --- /dev/null +++ b/src/ieee754/part_mul_add/partpoints.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information +"""Integer Multiplication.""" + +from nmigen import Signal, Value, Cat, C + + +class PartitionPoints(dict): + """Partition points and corresponding ``Value``s. + + The points at where an ALU is partitioned along with ``Value``s that + specify if the corresponding partition points are enabled. + + For example: ``{1: True, 5: True, 10: True}`` with + ``width == 16`` specifies that the ALU is split into 4 sections: + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 5 + * bits 5 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + + If the partition_points were instead ``{1: True, 5: a, 10: True}`` + where ``a`` is a 1-bit ``Signal``: + * If ``a`` is asserted: + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 5 + * bits 5 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + * Otherwise + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + """ + + def __init__(self, partition_points=None): + """Create a new ``PartitionPoints``. + + :param partition_points: the input partition points to values mapping. + """ + super().__init__() + if partition_points is not None: + for point, enabled in partition_points.items(): + if not isinstance(point, int): + raise TypeError("point must be a non-negative integer") + if point < 0: + raise ValueError("point must be a non-negative integer") + self[point] = Value.wrap(enabled) + + def like(self, name=None, src_loc_at=0, mul=1): + """Create a new ``PartitionPoints`` with ``Signal``s for all values. + + :param name: the base name for the new ``Signal``s. + :param mul: a multiplication factor on the indices + """ + if name is None: + name = Signal(src_loc_at=1+src_loc_at).name # get variable name + retval = PartitionPoints() + for point, enabled in self.items(): + point *= mul + retval[point] = Signal(enabled.shape(), name=f"{name}_{point}") + return retval + + def eq(self, rhs): + """Assign ``PartitionPoints`` using ``Signal.eq``.""" + if set(self.keys()) != set(rhs.keys()): + raise ValueError("incompatible point set") + for point, enabled in self.items(): + yield enabled.eq(rhs[point]) + + def as_mask(self, width, mul=1): + """Create a bit-mask from `self`. + + Each bit in the returned mask is clear only if the partition point at + the same bit-index is enabled. + + :param width: the bit width of the resulting mask + :param mul: a "multiplier" which in-place expands the partition points + typically set to "2" when used for multipliers + """ + bits = [] + for i in range(width): + i /= mul + if i.is_integer() and int(i) in self: + bits.append(~self[i]) + else: + bits.append(True) + return Cat(*bits) + + def get_max_partition_count(self, width): + """Get the maximum number of partitions. + + Gets the number of partitions when all partition points are enabled. + """ + retval = 1 + for point in self.keys(): + if point < width: + retval += 1 + return retval + + def fits_in_width(self, width): + """Check if all partition points are smaller than `width`.""" + for point in self.keys(): + if point >= width: + return False + return True + + def part_byte(self, index, mfactor=1): # mfactor used for "expanding" + if index == -1 or index == 7: + return C(True, 1) + assert index >= 0 and index < 8 + return self[(index * 8 + 8)*mfactor] + + -- 2.30.2