From 92ca6002d7435795ef3e20c8be00567786baec5d Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Fri, 23 Aug 2019 13:26:40 +0100 Subject: [PATCH] munge AddReduce internals --- src/ieee754/part_mul_add/multiply.py | 33 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index aef5c16e..4c6b570c 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -543,18 +543,18 @@ class AddReduceInternal: supported, except for by ``Signal.eq``. """ - def __init__(self, inputs, output_width, partition_points, - part_ops): + def __init__(self, i, output_width): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. :param output_width: bit-width of ``output``. :param partition_points: the input partition points. """ - self.inputs = inputs - self.part_ops = part_ops + self.i = i + self.inputs = i.terms + self.part_ops = i.part_ops self.output_width = output_width - self.partition_points = partition_points + self.partition_points = i.part_pts self.create_levels() @@ -598,7 +598,7 @@ class AddReduce(AddReduceInternal, Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, inputs, output_width, register_levels, partition_points, + def __init__(self, inputs, output_width, register_levels, part_pts, part_ops): """Create an ``AddReduce``. @@ -608,10 +608,14 @@ class AddReduce(AddReduceInternal, Elaboratable): pipeline registers. :param partition_points: the input partition points. """ - AddReduceInternal.__init__(self, inputs, output_width, - partition_points, part_ops) + self._inputs = inputs + self._part_pts = part_pts + self._part_ops = part_ops n_parts = len(part_ops) - self.o = FinalReduceData(partition_points, output_width, n_parts) + self.i = AddReduceData(part_pts, len(inputs), + output_width, n_parts) + AddReduceInternal.__init__(self, self.i, output_width) + self.o = FinalReduceData(part_pts, output_width, n_parts) self.register_levels = register_levels @staticmethod @@ -629,17 +633,12 @@ class AddReduce(AddReduceInternal, Elaboratable): """Elaborate this module.""" m = Module() + m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops) + for i, next_level in enumerate(self.levels): setattr(m.submodules, "next_level%d" % i, next_level) - partition_points = self.partition_points - inputs = self.inputs - part_ops = self.part_ops - n_parts = len(part_ops) - n_inputs = len(inputs) - output_width = self.output_width - i = AddReduceData(partition_points, n_inputs, output_width, n_parts) - m.d.comb += i.eq_from(partition_points, inputs, part_ops) + i = self.i for idx in range(len(self.levels)): mcur = self.levels[idx] if idx in self.register_levels: -- 2.30.2