from nmigen.cli import main
from functools import reduce
from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
-
-class PartitionPoints(dict):
- """Partition points and corresponding ``Value``s.
-
- The points at where an ALU is partitioned along with ``Value``s that
- specify if the corresponding partition points are enabled.
-
- For example: ``{1: True, 5: True, 10: True}`` with
- ``width == 16`` specifies that the ALU is split into 4 sections:
- * bits 0 <= ``i`` < 1
- * bits 1 <= ``i`` < 5
- * bits 5 <= ``i`` < 10
- * bits 10 <= ``i`` < 16
-
- If the partition_points were instead ``{1: True, 5: a, 10: True}``
- where ``a`` is a 1-bit ``Signal``:
- * If ``a`` is asserted:
- * bits 0 <= ``i`` < 1
- * bits 1 <= ``i`` < 5
- * bits 5 <= ``i`` < 10
- * bits 10 <= ``i`` < 16
- * Otherwise
- * bits 0 <= ``i`` < 1
- * bits 1 <= ``i`` < 10
- * bits 10 <= ``i`` < 16
- """
-
- def __init__(self, partition_points=None):
- """Create a new ``PartitionPoints``.
-
- :param partition_points: the input partition points to values mapping.
- """
- super().__init__()
- if partition_points is not None:
- for point, enabled in partition_points.items():
- if not isinstance(point, int):
- raise TypeError("point must be a non-negative integer")
- if point < 0:
- raise ValueError("point must be a non-negative integer")
- self[point] = Value.wrap(enabled)
-
- def like(self, name=None, src_loc_at=0, mul=1):
- """Create a new ``PartitionPoints`` with ``Signal``s for all values.
-
- :param name: the base name for the new ``Signal``s.
- :param mul: a multiplication factor on the indices
- """
- if name is None:
- name = Signal(src_loc_at=1+src_loc_at).name # get variable name
- retval = PartitionPoints()
- for point, enabled in self.items():
- point *= mul
- retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
- return retval
-
- def eq(self, rhs):
- """Assign ``PartitionPoints`` using ``Signal.eq``."""
- if set(self.keys()) != set(rhs.keys()):
- raise ValueError("incompatible point set")
- for point, enabled in self.items():
- yield enabled.eq(rhs[point])
-
- def as_mask(self, width):
- """Create a bit-mask from `self`.
-
- Each bit in the returned mask is clear only if the partition point at
- the same bit-index is enabled.
-
- :param width: the bit width of the resulting mask
- """
- bits = []
- for i in range(width):
- if i in self:
- bits.append(~self[i])
- else:
- bits.append(True)
- return Cat(*bits)
-
- def get_max_partition_count(self, width):
- """Get the maximum number of partitions.
-
- Gets the number of partitions when all partition points are enabled.
- """
- retval = 1
- for point in self.keys():
- if point < width:
- retval += 1
- return retval
-
- def fits_in_width(self, width):
- """Check if all partition points are smaller than `width`."""
- for point in self.keys():
- if point >= width:
- return False
- return True
-
- def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
- if index == -1 or index == 7:
- return C(True, 1)
- assert index >= 0 and index < 8
- return self[(index * 8 + 8)*mfactor]
-
-
-class FullAdder(Elaboratable):
- """Full Adder.
-
- :attribute in0: the first input
- :attribute in1: the second input
- :attribute in2: the third input
- :attribute sum: the sum output
- :attribute carry: the carry output
-
- Rather than do individual full adders (and have an array of them,
- which would be very slow to simulate), this module can specify the
- bit width of the inputs and outputs: in effect it performs multiple
- Full 3-2 Add operations "in parallel".
- """
-
- def __init__(self, width):
- """Create a ``FullAdder``.
-
- :param width: the bit width of the input and output
- """
- self.in0 = Signal(width, reset_less=True)
- self.in1 = Signal(width, reset_less=True)
- self.in2 = Signal(width, reset_less=True)
- self.sum = Signal(width, reset_less=True)
- self.carry = Signal(width, reset_less=True)
-
- def elaborate(self, platform):
- """Elaborate this module."""
- m = Module()
- m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
- m.d.comb += self.carry.eq((self.in0 & self.in1)
- | (self.in1 & self.in2)
- | (self.in2 & self.in0))
- return m
-
-
-class MaskedFullAdder(Elaboratable):
- """Masked Full Adder.
-
- :attribute mask: the carry partition mask
- :attribute in0: the first input
- :attribute in1: the second input
- :attribute in2: the third input
- :attribute sum: the sum output
- :attribute mcarry: the masked carry output
-
- FullAdders are always used with a "mask" on the output. To keep
- the graphviz "clean", this class performs the masking here rather
- than inside a large for-loop.
-
- See the following discussion as to why this is no longer derived
- from FullAdder. Each carry is shifted here *before* being ANDed
- with the mask, so that an AOI cell may be used (which is more
- gate-efficient)
- https://en.wikipedia.org/wiki/AND-OR-Invert
- https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
- """
-
- def __init__(self, width):
- """Create a ``MaskedFullAdder``.
-
- :param width: the bit width of the input and output
- """
- self.width = width
- self.mask = Signal(width, reset_less=True)
- self.mcarry = Signal(width, reset_less=True)
- self.in0 = Signal(width, reset_less=True)
- self.in1 = Signal(width, reset_less=True)
- self.in2 = Signal(width, reset_less=True)
- self.sum = Signal(width, reset_less=True)
-
- def elaborate(self, platform):
- """Elaborate this module."""
- m = Module()
- s1 = Signal(self.width, reset_less=True)
- s2 = Signal(self.width, reset_less=True)
- s3 = Signal(self.width, reset_less=True)
- c1 = Signal(self.width, reset_less=True)
- c2 = Signal(self.width, reset_less=True)
- c3 = Signal(self.width, reset_less=True)
- m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
- m.d.comb += s1.eq(Cat(0, self.in0))
- m.d.comb += s2.eq(Cat(0, self.in1))
- m.d.comb += s3.eq(Cat(0, self.in2))
- m.d.comb += c1.eq(s1 & s2 & self.mask)
- m.d.comb += c2.eq(s2 & s3 & self.mask)
- m.d.comb += c3.eq(s3 & s1 & self.mask)
- m.d.comb += self.mcarry.eq(c1 | c2 | c3)
- return m
-
-
-class PartitionedAdder(Elaboratable):
- """Partitioned Adder.
-
- Performs the final add. The partition points are included in the
- actual add (in one of the operands only), which causes a carry over
- to the next bit. Then the final output *removes* the extra bits from
- the result.
-
- partition: .... P... P... P... P... (32 bits)
- a : .... .... .... .... .... (32 bits)
- b : .... .... .... .... .... (32 bits)
- exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
- exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
- exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
- o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
-
- :attribute width: the bit width of the input and output. Read-only.
- :attribute a: the first input to the adder
- :attribute b: the second input to the adder
- :attribute output: the sum output
- :attribute partition_points: the input partition points. Modification not
- supported, except for by ``Signal.eq``.
- """
-
- def __init__(self, width, partition_points):
- """Create a ``PartitionedAdder``.
-
- :param width: the bit width of the input and output
- :param partition_points: the input partition points
- """
- self.width = width
- self.a = Signal(width, reset_less=True)
- self.b = Signal(width, reset_less=True)
- self.output = Signal(width, reset_less=True)
- self.partition_points = PartitionPoints(partition_points)
- if not self.partition_points.fits_in_width(width):
- raise ValueError("partition_points doesn't fit in width")
- expanded_width = 0
- for i in range(self.width):
- if i in self.partition_points:
- expanded_width += 1
- expanded_width += 1
- self._expanded_width = expanded_width
-
- def elaborate(self, platform):
- """Elaborate this module."""
- m = Module()
- expanded_a = Signal(self._expanded_width, reset_less=True)
- expanded_b = Signal(self._expanded_width, reset_less=True)
- expanded_o = Signal(self._expanded_width, reset_less=True)
-
- expanded_index = 0
- # store bits in a list, use Cat later. graphviz is much cleaner
- al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
-
- # partition points are "breaks" (extra zeros or 1s) in what would
- # otherwise be a massive long add. when the "break" points are 0,
- # whatever is in it (in the output) is discarded. however when
- # there is a "1", it causes a roll-over carry to the *next* bit.
- # we still ignore the "break" bit in the [intermediate] output,
- # however by that time we've got the effect that we wanted: the
- # carry has been carried *over* the break point.
-
- for i in range(self.width):
- if i in self.partition_points:
- # add extra bit set to 0 + 0 for enabled partition points
- # and 1 + 0 for disabled partition points
- ea.append(expanded_a[expanded_index])
- al.append(~self.partition_points[i]) # add extra bit in a
- eb.append(expanded_b[expanded_index])
- bl.append(C(0)) # yes, add a zero
- expanded_index += 1 # skip the extra point. NOT in the output
- ea.append(expanded_a[expanded_index])
- eb.append(expanded_b[expanded_index])
- eo.append(expanded_o[expanded_index])
- al.append(self.a[i])
- bl.append(self.b[i])
- ol.append(self.output[i])
- expanded_index += 1
-
- # combine above using Cat
- m.d.comb += Cat(*ea).eq(Cat(*al))
- m.d.comb += Cat(*eb).eq(Cat(*bl))
- m.d.comb += Cat(*ol).eq(Cat(*eo))
-
- # use only one addition to take advantage of look-ahead carry and
- # special hardware on FPGAs
- m.d.comb += expanded_o.eq(expanded_a + expanded_b)
- return m
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
FULL_ADDER_INPUT_COUNT = 3
class AddReduceData:
- def __init__(self, ppoints, n_inputs, output_width, n_parts):
+ def __init__(self, part_pts, n_inputs, output_width, n_parts):
self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
for i in range(n_parts)]
- self.inputs = [Signal(output_width, name=f"inputs_{i}",
+ self.terms = [Signal(output_width, name=f"terms_{i}",
reset_less=True)
for i in range(n_inputs)]
- self.reg_partition_points = ppoints.like()
+ self.part_pts = part_pts.like()
- def eq_from(self, reg_partition_points, inputs, part_ops):
- return [self.reg_partition_points.eq(reg_partition_points)] + \
- [self.inputs[i].eq(inputs[i])
- for i in range(len(self.inputs))] + \
+ def eq_from(self, part_pts, inputs, part_ops):
+ return [self.part_pts.eq(part_pts)] + \
+ [self.terms[i].eq(inputs[i])
+ for i in range(len(self.terms))] + \
[self.part_ops[i].eq(part_ops[i])
for i in range(len(self.part_ops))]
def eq(self, rhs):
- return self.eq_from(rhs.reg_partition_points, rhs.inputs, rhs.part_ops)
+ return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
class FinalReduceData:
- def __init__(self, ppoints, output_width, n_parts):
+ def __init__(self, part_pts, output_width, n_parts):
self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
for i in range(n_parts)]
self.output = Signal(output_width, reset_less=True)
- self.reg_partition_points = ppoints.like()
+ self.part_pts = part_pts.like()
- def eq_from(self, reg_partition_points, output, part_ops):
- return [self.reg_partition_points.eq(reg_partition_points)] + \
+ def eq_from(self, part_pts, output, part_ops):
+ return [self.part_pts.eq(part_pts)] + \
[self.output.eq(output)] + \
[self.part_ops[i].eq(part_ops[i])
for i in range(len(self.part_ops))]
def eq(self, rhs):
- return self.eq_from(rhs.reg_partition_points, rhs.output, rhs.part_ops)
+ return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
""" Final stage of add reduce
"""
- def __init__(self, n_inputs, output_width, n_parts, register_levels,
- partition_points):
- self.i = AddReduceData(partition_points, n_inputs,
- output_width, n_parts)
- self.o = FinalReduceData(partition_points, output_width, n_parts)
- self.output_width = output_width
+ def __init__(self, pspec, lidx, n_inputs, partition_points,
+ partition_step=1):
+ self.lidx = lidx
+ self.partition_step = partition_step
+ self.output_width = pspec.width * 2
self.n_inputs = n_inputs
- self.n_parts = n_parts
- self.register_levels = list(register_levels)
+ self.n_parts = pspec.n_parts
self.partition_points = PartitionPoints(partition_points)
- if not self.partition_points.fits_in_width(output_width):
+ if not self.partition_points.fits_in_width(self.output_width):
raise ValueError("partition_points doesn't fit in output_width")
+ super().__init__(pspec, "finaladd")
+
+ def ispec(self):
+ return AddReduceData(self.partition_points, self.n_inputs,
+ self.output_width, self.n_parts)
+
+ def ospec(self):
+ return FinalReduceData(self.partition_points,
+ self.output_width, self.n_parts)
+
def elaborate(self, platform):
"""Elaborate this module."""
m = Module()
m.d.comb += output.eq(0)
elif self.n_inputs == 1:
# handle single input
- m.d.comb += output.eq(self.i.inputs[0])
+ m.d.comb += output.eq(self.i.terms[0])
else:
# base case for adding 2 inputs
assert self.n_inputs == 2
- adder = PartitionedAdder(output_width, self.i.reg_partition_points)
+ adder = PartitionedAdder(output_width,
+ self.i.part_pts, self.partition_step)
m.submodules.final_adder = adder
- m.d.comb += adder.a.eq(self.i.inputs[0])
- m.d.comb += adder.b.eq(self.i.inputs[1])
+ m.d.comb += adder.a.eq(self.i.terms[0])
+ m.d.comb += adder.b.eq(self.i.terms[1])
m.d.comb += output.eq(adder.output)
# create output
- m.d.comb += self.o.eq_from(self.i.reg_partition_points, output,
+ m.d.comb += self.o.eq_from(self.i.part_pts, output,
self.i.part_ops)
return m
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
"""Add list of numbers together.
:attribute inputs: input ``Signal``s to be summed. Modification not
supported, except for by ``Signal.eq``.
"""
- def __init__(self, n_inputs, output_width, n_parts, register_levels,
- partition_points):
+ def __init__(self, pspec, lidx, n_inputs, partition_points,
+ partition_step=1):
"""Create an ``AddReduce``.
:param inputs: input ``Signal``s to be summed.
:param output_width: bit-width of ``output``.
- :param register_levels: List of nesting levels that should have
- pipeline registers.
:param partition_points: the input partition points.
"""
+ self.lidx = lidx
+ self.partition_step = partition_step
self.n_inputs = n_inputs
- self.n_parts = n_parts
- self.output_width = output_width
- self.i = AddReduceData(partition_points, n_inputs,
- output_width, n_parts)
- self.register_levels = list(register_levels)
+ self.n_parts = pspec.n_parts
+ self.output_width = pspec.width * 2
self.partition_points = PartitionPoints(partition_points)
- if not self.partition_points.fits_in_width(output_width):
+ if not self.partition_points.fits_in_width(self.output_width):
raise ValueError("partition_points doesn't fit in output_width")
- max_level = AddReduceSingle.get_max_level(n_inputs)
- for level in self.register_levels:
- if level > max_level:
- raise ValueError(
- "not enough adder levels for specified register levels")
-
self.groups = AddReduceSingle.full_adder_groups(n_inputs)
- n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
- self.o = AddReduceData(partition_points, n_terms, output_width, n_parts)
+ self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+
+ super().__init__(pspec, "addreduce_%d" % lidx)
+
+ def ispec(self):
+ return AddReduceData(self.partition_points, self.n_inputs,
+ self.output_width, self.n_parts)
+
+ def ospec(self):
+ return AddReduceData(self.partition_points, self.n_terms,
+ self.output_width, self.n_parts)
@staticmethod
def calc_n_inputs(n_inputs, groups):
terms.append(adder_i.mcarry)
# handle the remaining inputs.
if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
- terms.append(self.i.inputs[-1])
+ terms.append(self.i.terms[-1])
elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
# Just pass the terms to the next layer, since we wouldn't gain
# anything by using a half adder since there would still be 2 terms
# and just passing the terms to the next layer saves gates.
- terms.append(self.i.inputs[-2])
- terms.append(self.i.inputs[-1])
+ terms.append(self.i.terms[-2])
+ terms.append(self.i.terms[-1])
else:
assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
# copy the intermediate terms to the output
for i, value in enumerate(terms):
- m.d.comb += self.o.inputs[i].eq(value)
+ m.d.comb += self.o.terms[i].eq(value)
# copy reg part points and part ops to output
- m.d.comb += self.o.reg_partition_points.eq(self.i.reg_partition_points)
+ m.d.comb += self.o.part_pts.eq(self.i.part_pts)
m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
for i in range(len(self.i.part_ops))]
# set up the partition mask (for the adders)
part_mask = Signal(self.output_width, reset_less=True)
- mask = self.i.reg_partition_points.as_mask(self.output_width)
+ # get partition points as a mask
+ mask = self.i.part_pts.as_mask(self.output_width,
+ mul=self.partition_step)
m.d.comb += part_mask.eq(mask)
# add and link the intermediate term modules
for i, (iidx, adder_i) in enumerate(adders):
setattr(m.submodules, f"adder_{i}", adder_i)
- m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
- m.d.comb += adder_i.in1.eq(self.i.inputs[iidx + 1])
- m.d.comb += adder_i.in2.eq(self.i.inputs[iidx + 2])
+ m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
+ m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
+ m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
m.d.comb += adder_i.mask.eq(part_mask)
return m
-class AddReduce(Elaboratable):
+class AddReduceInternal:
+ """Iteratively Add list of numbers together.
+
+    :attribute n_inputs: number of input terms to be summed
+    :attribute output_width: bit-width of the (intermediate) output,
+        i.e. double ``pspec.width``
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute partition_step: multiplication factor applied to the
+        partition-point indices when building carry masks
+    :attribute levels: reduction-stage modules built by ``create_levels``
+        (a chain of ``AddReduceSingle`` instances ending in a ``FinalAdd``)
+    """
+
+ def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
+        """Create an ``AddReduceInternal``.
+
+        :param pspec: pipeline spec; ``pspec.width`` determines the
+            (doubled) output width
+        :param n_inputs: number of input terms to be summed
+        :param part_pts: the input partition points
+        :param partition_step: multiplication factor on partition indices
+        """
+ self.pspec = pspec
+ self.n_inputs = n_inputs
+ self.output_width = pspec.width * 2
+ self.partition_points = part_pts
+ self.partition_step = partition_step
+
+ self.create_levels()
+
+ def create_levels(self):
+ """creates reduction levels"""
+
+ mods = []
+ partition_points = self.partition_points
+ ilen = self.n_inputs
+ while True:
+ groups = AddReduceSingle.full_adder_groups(ilen)
+ if len(groups) == 0:
+ break
+ lidx = len(mods)
+ next_level = AddReduceSingle(self.pspec, lidx, ilen,
+ partition_points,
+ self.partition_step)
+ mods.append(next_level)
+ partition_points = next_level.i.part_pts
+ ilen = len(next_level.o.terms)
+
+ lidx = len(mods)
+ next_level = FinalAdd(self.pspec, lidx, ilen,
+ partition_points, self.partition_step)
+ mods.append(next_level)
+
+ self.levels = mods
+
+
+class AddReduce(AddReduceInternal, Elaboratable):
"""Recursively Add list of numbers together.
:attribute inputs: input ``Signal``s to be summed. Modification not
supported, except for by ``Signal.eq``.
"""
- def __init__(self, inputs, output_width, register_levels, partition_points,
- part_ops):
+ def __init__(self, inputs, output_width, register_levels, part_pts,
+ part_ops, partition_step=1):
"""Create an ``AddReduce``.
:param inputs: input ``Signal``s to be summed.
pipeline registers.
:param partition_points: the input partition points.
"""
- self.inputs = inputs
- self.part_ops = part_ops
+ self._inputs = inputs
+ self._part_pts = part_pts
+ self._part_ops = part_ops
n_parts = len(part_ops)
- self.o = FinalReduceData(partition_points, output_width, n_parts)
- self.output_width = output_width
+ self.i = AddReduceData(part_pts, len(inputs),
+ output_width, n_parts)
+        # NOTE(review): ``pspec`` and ``n_inputs`` are not defined anywhere
+        # in this scope, so this call raises NameError at construction time.
+        # Presumably ``n_inputs`` should be ``len(inputs)`` and a
+        # PipelineSpec needs to be passed in (or built from output_width)
+        # -- confirm and fix before merging.
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+ self.o = FinalReduceData(part_pts, output_width, n_parts)
self.register_levels = register_levels
- self.partition_points = partition_points
-
- self.create_levels()
@staticmethod
def get_max_level(input_count):
if level > 0:
yield level - 1
- def create_levels(self):
- """creates reduction levels"""
-
- mods = []
- next_levels = self.register_levels
- partition_points = self.partition_points
- part_ops = self.part_ops
- n_parts = len(part_ops)
- inputs = self.inputs
- ilen = len(inputs)
- while True:
- groups = AddReduceSingle.full_adder_groups(len(inputs))
- if len(groups) == 0:
- break
- next_level = AddReduceSingle(ilen, self.output_width, n_parts,
- next_levels, partition_points)
- mods.append(next_level)
- next_levels = list(AddReduce.next_register_levels(next_levels))
- partition_points = next_level.i.reg_partition_points
- inputs = next_level.o.inputs
- ilen = len(inputs)
- part_ops = next_level.i.part_ops
-
- next_level = FinalAdd(ilen, self.output_width, n_parts,
- next_levels, partition_points)
- mods.append(next_level)
-
- self.levels = mods
-
def elaborate(self, platform):
"""Elaborate this module."""
m = Module()
+ m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+
for i, next_level in enumerate(self.levels):
setattr(m.submodules, "next_level%d" % i, next_level)
- partition_points = self.partition_points
- inputs = self.inputs
- part_ops = self.part_ops
- n_parts = len(part_ops)
- n_inputs = len(inputs)
- output_width = self.output_width
- i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
- m.d.comb += i.eq_from(partition_points, inputs, part_ops)
+ i = self.i
for idx in range(len(self.levels)):
mcur = self.levels[idx]
- if 0 in mcur.register_levels:
+ if idx in self.register_levels:
m.d.sync += mcur.i.eq(i)
else:
m.d.comb += mcur.i.eq(i)
bsb = Signal(self.width, reset_less=True)
a_index, b_index = self.a_index, self.b_index
pwidth = self.pwidth
- m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
- m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+ m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+ m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
m.d.comb += self.ti.eq(bsa * bsb)
m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
"""
asel = Signal(width, reset_less=True)
bsel = Signal(width, reset_less=True)
a_index, b_index = self.a_index, self.b_index
- m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
- m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+ m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+ m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
m.d.comb += self.ti.eq(bsa * bsb)
class Parts(Elaboratable):
- def __init__(self, pbwid, epps, n_parts):
+ def __init__(self, pbwid, part_pts, n_parts):
self.pbwid = pbwid
# inputs
- self.epps = PartitionPoints.like(epps, name="epps") # expanded points
+ self.part_pts = PartitionPoints.like(part_pts)
# outputs
self.parts = [Signal(name=f"part_{i}", reset_less=True)
for i in range(n_parts)]
def elaborate(self, platform):
m = Module()
- epps, parts = self.epps, self.parts
+ part_pts, parts = self.part_pts, self.parts
# collect part-bytes (double factor because the input is extended)
pbs = Signal(self.pbwid, reset_less=True)
tl = []
for i in range(self.pbwid):
pb = Signal(name="pb%d" % i, reset_less=True)
- m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
+ m.d.comb += pb.eq(part_pts.part_byte(i))
tl.append(pb)
m.d.comb += pbs.eq(Cat(*tl))
the extra terms - as separate terms - are then thrown at the
AddReduce alongside the multiplication part-results.
"""
- def __init__(self, epps, width, n_parts, n_levels, pbwid):
+ def __init__(self, part_pts, width, n_parts, pbwid):
self.pbwid = pbwid
- self.epps = epps
+ self.part_pts = part_pts
# inputs
self.a = Signal(64, reset_less=True)
m = Module()
pbs, parts = self.pbs, self.parts
- epps = self.epps
- m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
- m.d.comb += p.epps.eq(epps)
+ part_pts = self.part_pts
+ m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
+ m.d.comb += p.part_pts.eq(part_pts)
parts = p.parts
byte_count = 8 // len(parts)
pa = LSBNegTerm(bit_wid)
setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
m.d.comb += pa.part.eq(parts[i])
- m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
+ m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
nat.append(pa.nt)
pb = LSBNegTerm(bit_wid)
setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
m.d.comb += pb.part.eq(parts[i])
- m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
+ m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
nbt.append(pb.nt)
op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
m.d.comb += op.eq(
Mux(self.part_ops[sel * i] == OP_MUL_LOW,
- self.intermed.part(i * w*2, w),
- self.intermed.part(i * w*2 + w, w)))
+ self.intermed.bit_select(i * w*2, w),
+ self.intermed.bit_select(i * w*2 + w, w)))
ol.append(op)
m.d.comb += self.output.eq(Cat(*ol))
return m
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
""" selects the final output based on the partitioning.
each byte is selectable independently, i.e. it is possible
that some partitions requested 8-bit computation whilst others
requested 16 or 32 bit.
"""
- def __init__(self, output_width, n_parts, partition_points):
- self.expanded_part_points = partition_points
- self.i = IntermediateData(partition_points, output_width, n_parts)
- self.out_wid = output_width//2
- # output
- self.out = Signal(self.out_wid, reset_less=True)
- self.intermediate_output = Signal(output_width, reset_less=True)
+ def __init__(self, pspec, part_pts):
+
+ self.part_pts = part_pts
+ self.output_width = pspec.width * 2
+ self.n_parts = pspec.n_parts
+ self.out_wid = pspec.width
+
+ super().__init__(pspec, "finalout")
+
+ def ispec(self):
+ return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+ def ospec(self):
+ return OutputData()
def elaborate(self, platform):
m = Module()
- eps = self.expanded_part_points
- m.submodules.p_8 = p_8 = Parts(8, eps, 8)
- m.submodules.p_16 = p_16 = Parts(8, eps, 4)
- m.submodules.p_32 = p_32 = Parts(8, eps, 2)
- m.submodules.p_64 = p_64 = Parts(8, eps, 1)
+ part_pts = self.part_pts
+ m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
+ m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
+ m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
+ m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
- out_part_pts = self.i.reg_partition_points
+ out_part_pts = self.i.part_pts
# temporaries
d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
i32 = Signal(self.out_wid, reset_less=True)
i64 = Signal(self.out_wid, reset_less=True)
- m.d.comb += p_8.epps.eq(out_part_pts)
- m.d.comb += p_16.epps.eq(out_part_pts)
- m.d.comb += p_32.epps.eq(out_part_pts)
- m.d.comb += p_64.epps.eq(out_part_pts)
+ m.d.comb += p_8.part_pts.eq(out_part_pts)
+ m.d.comb += p_16.part_pts.eq(out_part_pts)
+ m.d.comb += p_32.part_pts.eq(out_part_pts)
+ m.d.comb += p_64.part_pts.eq(out_part_pts)
for i in range(len(p_8.parts)):
m.d.comb += d8[i].eq(p_8.parts[i])
op = Signal(8, reset_less=True, name="op_%d" % i)
m.d.comb += op.eq(
Mux(d8[i] | d16[i // 2],
- Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)),
- Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8))))
+ Mux(d8[i], i8.bit_select(i * 8, 8),
+ i16.bit_select(i * 8, 8)),
+ Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+ i64.bit_select(i * 8, 8))))
ol.append(op)
- m.d.comb += self.out.eq(Cat(*ol))
- m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
+
+ # create outputs
+ m.d.comb += self.o.output.eq(Cat(*ol))
+ m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
return m
class IntermediateData:
- def __init__(self, ppoints, output_width, n_parts):
+ def __init__(self, part_pts, output_width, n_parts):
self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
for i in range(n_parts)]
- self.reg_partition_points = ppoints.like()
+ self.part_pts = part_pts.like()
self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
for i in range(4)]
# intermediates (needed for unit tests)
self.intermediate_output = Signal(output_width)
- def eq_from(self, reg_partition_points, outputs, intermediate_output,
+ def eq_from(self, part_pts, outputs, intermediate_output,
part_ops):
- return [self.reg_partition_points.eq(reg_partition_points)] + \
+ return [self.part_pts.eq(part_pts)] + \
[self.intermediate_output.eq(intermediate_output)] + \
[self.outputs[i].eq(outputs[i])
for i in range(4)] + \
for i in range(len(self.part_ops))]
def eq(self, rhs):
- return self.eq_from(rhs.reg_partition_points, rhs.outputs,
+ return self.eq_from(rhs.part_pts, rhs.outputs,
rhs.intermediate_output, rhs.part_ops)
-class Intermediates(Elaboratable):
+class InputData:
+
+ def __init__(self):
+ self.a = Signal(64)
+ self.b = Signal(64)
+ self.part_pts = PartitionPoints()
+ for i in range(8, 64, 8):
+ self.part_pts[i] = Signal(name=f"part_pts_{i}")
+ self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+
+ def eq_from(self, part_pts, a, b, part_ops):
+ return [self.part_pts.eq(part_pts)] + \
+ [self.a.eq(a), self.b.eq(b)] + \
+ [self.part_ops[i].eq(part_ops[i])
+ for i in range(len(self.part_ops))]
+
+ def eq(self, rhs):
+ return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
+
+
+class OutputData:
+
+ def __init__(self):
+ self.intermediate_output = Signal(128) # needed for unit tests
+ self.output = Signal(64)
+
+ def eq(self, rhs):
+ return [self.intermediate_output.eq(rhs.intermediate_output),
+ self.output.eq(rhs.output)]
+
+
+class AllTerms(PipeModBase):
+ """Set of terms to be added together
+ """
+
+ def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
+
+        :param pspec: pipeline spec (supplies ``width`` and ``n_parts``)
+        :param n_inputs: number of partial-product terms in the output
+        """
+ self.n_inputs = n_inputs
+ self.n_parts = pspec.n_parts
+ self.output_width = pspec.width * 2
+ super().__init__(pspec, "allterms")
+
+ def ispec(self):
+ return InputData()
+
+ def ospec(self):
+ return AddReduceData(self.i.part_pts, self.n_inputs,
+ self.output_width, self.n_parts)
+
+ def elaborate(self, platform):
+ m = Module()
+
+ eps = self.i.part_pts
+
+ # collect part-bytes
+ pbs = Signal(8, reset_less=True)
+ tl = []
+ for i in range(8):
+ pb = Signal(name="pb%d" % i, reset_less=True)
+ m.d.comb += pb.eq(eps.part_byte(i))
+ tl.append(pb)
+ m.d.comb += pbs.eq(Cat(*tl))
+
+ # local variables
+ signs = []
+ for i in range(8):
+ s = Signs()
+ signs.append(s)
+ setattr(m.submodules, "signs%d" % i, s)
+ m.d.comb += s.part_ops.eq(self.i.part_ops[i])
+
+ m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+ m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+ m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+ m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
+ nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+ for mod in [part_8, part_16, part_32, part_64]:
+ m.d.comb += mod.a.eq(self.i.a)
+ m.d.comb += mod.b.eq(self.i.b)
+ for i in range(len(signs)):
+ m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
+ m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
+ m.d.comb += mod.pbs.eq(pbs)
+ nat_l.append(mod.not_a_term)
+ nbt_l.append(mod.not_b_term)
+ nla_l.append(mod.neg_lsb_a_term)
+ nlb_l.append(mod.neg_lsb_b_term)
+
+ terms = []
+
+ for a_index in range(8):
+ t = ProductTerms(8, 128, 8, a_index, 8)
+ setattr(m.submodules, "terms_%d" % a_index, t)
+
+ m.d.comb += t.a.eq(self.i.a)
+ m.d.comb += t.b.eq(self.i.b)
+ m.d.comb += t.pb_en.eq(pbs)
+
+ for term in t.terms:
+ terms.append(term)
+
+ # it's fine to bitwise-or data together since they are never enabled
+ # at the same time
+ m.submodules.nat_or = nat_or = OrMod(128)
+ m.submodules.nbt_or = nbt_or = OrMod(128)
+ m.submodules.nla_or = nla_or = OrMod(128)
+ m.submodules.nlb_or = nlb_or = OrMod(128)
+ for l, mod in [(nat_l, nat_or),
+ (nbt_l, nbt_or),
+ (nla_l, nla_or),
+ (nlb_l, nlb_or)]:
+ for i in range(len(l)):
+ m.d.comb += mod.orin[i].eq(l[i])
+ terms.append(mod.orout)
+
+ # copy the intermediate terms to the output
+ for i, value in enumerate(terms):
+ m.d.comb += self.o.terms[i].eq(value)
+
+ # copy reg part points and part ops to output
+ m.d.comb += self.o.part_pts.eq(eps)
+ m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+ for i in range(len(self.i.part_ops))]
+
+ return m
+
+
+class Intermediates(PipeModBase):
""" Intermediate output modules
"""
- def __init__(self, output_width, n_parts, partition_points):
- self.i = FinalReduceData(partition_points, output_width, n_parts)
- self.o = IntermediateData(partition_points, output_width, n_parts)
+ def __init__(self, pspec, part_pts):
+ self.part_pts = part_pts
+ self.output_width = pspec.width * 2
+ self.n_parts = pspec.n_parts
+
+ super().__init__(pspec, "intermediates")
+
+ def ispec(self):
+ return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
+
+ def ospec(self):
+ return IntermediateData(self.part_pts, self.output_width, self.n_parts)
def elaborate(self, platform):
m = Module()
out_part_ops = self.i.part_ops
- out_part_pts = self.i.reg_partition_points
+ out_part_pts = self.i.part_pts
# create _output_64
m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
for i in range(8):
m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
- m.d.comb += self.o.reg_partition_points.eq(out_part_pts)
+ m.d.comb += self.o.part_pts.eq(out_part_pts)
m.d.comb += self.o.intermediate_output.eq(self.i.output)
return m
class Mul8_16_32_64(Elaboratable):
"""Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
+ XXX NOTE: this class is intended for unit test purposes ONLY.
+
Supports partitioning into any combination of 8, 16, 32, and 64-bit
partitions on naturally-aligned boundaries. Supports the operation being
set for each partition independently.
flip-flops are to be inserted.
"""
+ self.id_wid = 0 # num_bits(num_rows)
+ self.op_wid = 0
+ self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+ self.pspec.n_parts = 8
+
# parameter(s)
self.register_levels = list(register_levels)
- # inputs
- self.part_pts = PartitionPoints()
- for i in range(8, 64, 8):
- self.part_pts[i] = Signal(name=f"part_pts_{i}")
- self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
- self.a = Signal(64)
- self.b = Signal(64)
+ self.i = self.ispec()
+ self.o = self.ospec()
- # intermediates (needed for unit tests)
- self.intermediate_output = Signal(128)
+ # inputs
+ self.part_pts = self.i.part_pts
+ self.part_ops = self.i.part_ops
+ self.a = self.i.a
+ self.b = self.i.b
# output
- self.output = Signal(64)
+ self.intermediate_output = self.o.intermediate_output
+ self.output = self.o.output
- def elaborate(self, platform):
- m = Module()
+ def ispec(self):
+ return InputData()
- # collect part-bytes
- pbs = Signal(8, reset_less=True)
- tl = []
- for i in range(8):
- pb = Signal(name="pb%d" % i, reset_less=True)
- m.d.comb += pb.eq(self.part_pts.part_byte(i))
- tl.append(pb)
- m.d.comb += pbs.eq(Cat(*tl))
-
- # create (doubled) PartitionPoints (output is double input width)
- expanded_part_pts = eps = PartitionPoints()
- for i, v in self.part_pts.items():
- ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
- expanded_part_pts[i * 2] = ep
- m.d.comb += ep.eq(v)
-
- n_inputs = 64 + 4
- n_parts = 8 #len(self.part_pts)
- t = AllTerms(8, n_inputs, 128, n_parts, self.register_levels,
- eps)
- m.submodules.allterms = t
- m.d.comb += t.a.eq(self.a)
- m.d.comb += t.b.eq(self.b)
- m.d.comb += t.pbs.eq(pbs)
- m.d.comb += t.epps.eq(eps)
- for i in range(8):
- m.d.comb += t.part_ops[i].eq(self.part_ops[i])
-
- terms = t.o.inputs
-
- add_reduce = AddReduce(terms,
- 128,
- self.register_levels,
- t.o.reg_partition_points,
- t.o.part_ops)
-
- out_part_ops = add_reduce.o.part_ops
- out_part_pts = add_reduce.o.reg_partition_points
-
- m.submodules.add_reduce = add_reduce
- m.d.comb += self.intermediate_output.eq(add_reduce.o.output)
-
- interm = Intermediates(128, 8, expanded_part_pts)
- m.submodules.intermediates = interm
- m.d.comb += interm.i.eq(add_reduce.o)
-
- # final output
- m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts)
- m.d.comb += finalout.i.eq(interm.o)
- m.d.comb += self.output.eq(finalout.out)
-
- return m
-
-
-class AllTerms(Elaboratable):
- """Set of terms to be added together
- """
-
- def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels,
- partition_points):
- """Create an ``AddReduce``.
-
- :param inputs: input ``Signal``s to be summed.
- :param output_width: bit-width of ``output``.
- :param register_levels: List of nesting levels that should have
- pipeline registers.
- :param partition_points: the input partition points.
- """
- self.epps = partition_points.like()
- self.register_levels = register_levels
- self.pbwid = pbwid
- self.n_inputs = n_inputs
- self.n_parts = n_parts
- self.output_width = output_width
- self.o = AddReduceData(self.epps, n_inputs,
- output_width, n_parts)
-
- self.a = Signal(64)
- self.b = Signal(64)
-
- self.pbs = Signal(pbwid, reset_less=True)
- self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+ def ospec(self):
+ return OutputData()
def elaborate(self, platform):
m = Module()
- pbs = self.pbs
- eps = self.epps
-
- # local variables
- signs = []
- for i in range(8):
- s = Signs()
- signs.append(s)
- setattr(m.submodules, "signs%d" % i, s)
- m.d.comb += s.part_ops.eq(self.part_ops[i])
-
- n_levels = len(self.register_levels)+1
- m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
- m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
- m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
- m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
- nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
- for mod in [part_8, part_16, part_32, part_64]:
- m.d.comb += mod.a.eq(self.a)
- m.d.comb += mod.b.eq(self.b)
- for i in range(len(signs)):
- m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
- m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
- m.d.comb += mod.pbs.eq(pbs)
- nat_l.append(mod.not_a_term)
- nbt_l.append(mod.not_b_term)
- nla_l.append(mod.neg_lsb_a_term)
- nlb_l.append(mod.neg_lsb_b_term)
-
- terms = []
+ part_pts = self.part_pts
- for a_index in range(8):
- t = ProductTerms(8, 128, 8, a_index, 8)
- setattr(m.submodules, "terms_%d" % a_index, t)
+ n_inputs = 64 + 4
+ t = AllTerms(self.pspec, n_inputs)
+ t.setup(m, self.i)
- m.d.comb += t.a.eq(self.a)
- m.d.comb += t.b.eq(self.b)
- m.d.comb += t.pb_en.eq(pbs)
+ terms = t.o.terms
- for term in t.terms:
- terms.append(term)
+ at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
- # it's fine to bitwise-or data together since they are never enabled
- # at the same time
- m.submodules.nat_or = nat_or = OrMod(128)
- m.submodules.nbt_or = nbt_or = OrMod(128)
- m.submodules.nla_or = nla_or = OrMod(128)
- m.submodules.nlb_or = nlb_or = OrMod(128)
- for l, mod in [(nat_l, nat_or),
- (nbt_l, nbt_or),
- (nla_l, nla_or),
- (nlb_l, nlb_or)]:
- for i in range(len(l)):
- m.d.comb += mod.orin[i].eq(l[i])
- terms.append(mod.orout)
+ i = t.o
+ for idx in range(len(at.levels)):
+ mcur = at.levels[idx]
+ mcur.setup(m, i)
+ o = mcur.ospec()
+ if idx in self.register_levels:
+ m.d.sync += o.eq(mcur.process(i))
+ else:
+ m.d.comb += o.eq(mcur.process(i))
+ i = o # for next loop
- # copy the intermediate terms to the output
- for i, value in enumerate(terms):
- m.d.comb += self.o.inputs[i].eq(value)
+ interm = Intermediates(self.pspec, part_pts)
+ interm.setup(m, i)
+ o = interm.process(interm.i)
- # copy reg part points and part ops to output
- m.d.comb += self.o.reg_partition_points.eq(eps)
- m.d.comb += [self.o.part_ops[i].eq(self.part_ops[i])
- for i in range(len(self.part_ops))]
+ # final output
+ finalout = FinalOut(self.pspec, part_pts)
+ finalout.setup(m, o)
+ m.d.comb += self.o.eq(finalout.process(o))
return m