class AddReduceData:
    """Signal bundle holding the data fed into an add-reduce stage.

    Attributes:
        part_ops: list of ``n_parts`` 2-bit signals named ``part_ops_<i>``.
        terms: list of ``n_inputs`` signals, each ``output_width`` bits wide,
            named ``inputs_<i>``.
        part_pts: a fresh object obtained from ``part_pts.like()``.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        """Allocate the signals.

        :param part_pts: object providing ``like()``; its result is stored
            as ``self.part_pts``.
        :param n_inputs: number of term signals to create.
        :param output_width: bit width of every term signal.
        :param n_parts: number of 2-bit ``part_ops`` signals to create.
        """
        # Creation order (part_ops, terms, part_pts) is kept as in the
        # original layout of this class.
        ops = []
        for idx in range(n_parts):
            ops.append(Signal(2, name=f"part_ops_{idx}", reset_less=True))
        self.part_ops = ops

        terms = []
        for idx in range(n_inputs):
            terms.append(Signal(output_width, name=f"inputs_{idx}",
                                reset_less=True))
        self.terms = terms

        # NOTE(review): keep this as a direct attribute assignment; like()
        # may derive names from the assignment target -- confirm.
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return the list of assignments loading this bundle.

        The list holds, in order: the ``part_pts`` assignment, one assignment
        per entry of ``self.terms`` (from ``inputs[i]``), and one per entry of
        ``self.part_ops`` (from ``part_ops[i]``).  Only as many entries of
        ``inputs`` / ``part_ops`` as this bundle owns are read.
        """
        assignments = [self.part_pts.eq(part_pts)]
        assignments.extend(term.eq(inputs[idx])
                           for idx, term in enumerate(self.terms))
        assignments.extend(op.eq(part_ops[idx])
                           for idx, op in enumerate(self.part_ops))
        return assignments

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs* into self."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
+
+
class FinalReduceData:
    """Signal bundle holding the result of the final add-reduce stage.

    Attributes:
        part_ops: list of ``n_parts`` 2-bit signals named ``part_ops_<i>``.
        output: a single signal ``output_width`` bits wide.
        part_pts: a fresh object obtained from ``part_pts.like()``.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """Allocate the signals.

        :param part_pts: object providing ``like()``; its result is stored
            as ``self.part_pts``.
        :param output_width: bit width of the ``output`` signal.
        :param n_parts: number of 2-bit ``part_ops`` signals to create.
        """
        ops = []
        for idx in range(n_parts):
            ops.append(Signal(2, name=f"part_ops_{idx}", reset_less=True))
        self.part_ops = ops

        # NOTE(review): the signal is created without an explicit name, so
        # keep the direct assignment to ``self.output`` -- nmigen presumably
        # infers the name from the assignment target; confirm.
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return the list of assignments loading this bundle.

        The list holds, in order: the ``part_pts`` assignment, the ``output``
        assignment, and one assignment per entry of ``self.part_ops`` (from
        ``part_ops[i]``).
        """
        assignments = [self.part_pts.eq(part_pts), self.output.eq(output)]
        assignments.extend(op.eq(part_ops[idx])
                           for idx, op in enumerate(self.part_ops))
        return assignments

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs* into self."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
+
+
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Collapses the remaining terms (zero, one or two of them) into a single
    ``output`` value and forwards ``part_pts`` and ``part_ops`` unchanged
    from the input bundle to the output bundle.
    """

    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                 partition_step=1):
        """Create the stage and its input/output bundles.

        :param lidx: stored as ``self.lidx``; not otherwise used by this class.
        :param n_inputs: number of input terms; ``elaborate`` supports
            0, 1 or 2.
        :param output_width: bit width of the terms and of the output.
        :param n_parts: number of ``part_ops`` signals in each bundle.
        :param partition_points: passed to ``PartitionPoints``.
        :param partition_step: forwarded to ``PartitionedAdder``.
        :raises ValueError: if the partition points do not fit in
            ``output_width``.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.output_width = output_width
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a new input bundle (``AddReduceData``) for this stage."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Return a new output bundle (``FinalReduceData``) for this stage."""
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def setup(self, m, i):
        """Register this stage as submodule ``finaladd`` of *m* and
        combinatorially drive ``self.i`` from *i*."""
        m.submodules.finaladd = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """Return the output bundle; the argument *i* is not used."""
        return self.o

    def elaborate(self, platform):
        """Elaborate this module.

        :raises ValueError: if ``n_inputs`` is not 0, 1 or 2.
        """
        m = Module()

        output_width = self.output_width
        # NOTE(review): unnamed signal -- keep the local name ``output``;
        # nmigen presumably infers the signal name from it; confirm.
        output = Signal(output_width, reset_less=True)
        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += output.eq(0)
        elif self.n_inputs == 1:
            # handle single input
            m.d.comb += output.eq(self.i.terms[0])
        elif self.n_inputs == 2:
            # base case for adding 2 inputs
            adder = PartitionedAdder(output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self.i.terms[0])
            m.d.comb += adder.b.eq(self.i.terms[1])
            m.d.comb += output.eq(adder.output)
        else:
            # An explicit raise (rather than ``assert``) so the check still
            # runs under ``python -O``; otherwise extra terms would be
            # silently dropped from the sum.
            raise ValueError(
                f"FinalAdd supports 0, 1 or 2 inputs, got {self.n_inputs}")

        # create output
        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                   self.i.part_ops)

        return m
+