From 3feb6cc6b8fa0ac8ea0ac0d8a91a91722a07a6c5 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Tue, 20 Aug 2019 07:13:47 +0100 Subject: [PATCH] split "actionable" part of AddReduce out from "recursive" part --- src/ieee754/part_mul_add/multiply.py | 59 ++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 4c3a3cf1..e86f655b 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -269,7 +269,7 @@ class PartitionedAdder(Elaboratable): FULL_ADDER_INPUT_COUNT = 3 -class AddReduce(Elaboratable): +class AddReduceSingle(Elaboratable): """Add list of numbers together. :attribute inputs: input ``Signal``s to be summed. Modification not @@ -300,6 +300,7 @@ class AddReduce(Elaboratable): if not self.partition_points.fits_in_width(output_width): raise ValueError("partition_points doesn't fit in output_width") self._reg_partition_points = self.partition_points.like() + max_level = AddReduce.get_max_level(len(self.inputs)) for level in self.register_levels: if level > max_level: @@ -321,13 +322,6 @@ class AddReduce(Elaboratable): input_count %= FULL_ADDER_INPUT_COUNT input_count += 2 * len(groups) retval += 1 - - def next_register_levels(self): - """``Iterable`` of ``register_levels`` for next recursive level.""" - for level in self.register_levels: - if level > 0: - yield level - 1 - @staticmethod def full_adder_groups(input_count): """Get ``inputs`` indices for which a full adder should be built.""" @@ -335,7 +329,7 @@ class AddReduce(Elaboratable): input_count - FULL_ADDER_INPUT_COUNT + 1, FULL_ADDER_INPUT_COUNT) - def elaborate(self, platform): + def _elaborate(self, platform): """Elaborate this module.""" m = Module() @@ -350,7 +344,7 @@ class AddReduce(Elaboratable): m.d.comb += resized_input_assignments m.d.comb += self._reg_partition_points.eq(self.partition_points) - groups = AddReduce.full_adder_groups(len(self.inputs)) + groups = AddReduceSingle.full_adder_groups(len(self.inputs)) # if there are no full adders to create, then we handle the base cases # and return, otherwise we go on to the recursive case if len(groups) == 0: @@ -370,8 +364,9 @@ class AddReduce(Elaboratable): m.d.comb += adder.a.eq(self._resized_inputs[0]) m.d.comb += adder.b.eq(self._resized_inputs[1]) m.d.comb += self.output.eq(adder.output) - return m - # go on to handle recursive case + return None, m + + # go on to prepare recursive case intermediate_terms = [] def add_intermediate_term(value): @@ -410,6 +405,46 @@ class AddReduce(Elaboratable): add_intermediate_term(self._resized_inputs[-1]) else: assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0 + + return intermediate_terms, m + + +class AddReduce(AddReduceSingle): + """Recursively Add list of numbers together. + + :attribute inputs: input ``Signal``s to be summed. Modification not + supported, except for by ``Signal.eq``. + :attribute register_levels: List of nesting levels that should have + pipeline registers. + :attribute output: output sum. + :attribute partition_points: the input partition points. Modification not + supported, except for by ``Signal.eq``. + """ + + def __init__(self, inputs, output_width, register_levels, partition_points): + """Create an ``AddReduce``. + + :param inputs: input ``Signal``s to be summed. + :param output_width: bit-width of ``output``. + :param register_levels: List of nesting levels that should have + pipeline registers. + :param partition_points: the input partition points. + """ + AddReduceSingle.__init__(self, inputs, output_width, register_levels, + partition_points) + + def next_register_levels(self): + """``Iterable`` of ``register_levels`` for next recursive level.""" + for level in self.register_levels: + if level > 0: + yield level - 1 + + def elaborate(self, platform): + """Elaborate this module.""" + intermediate_terms, m = AddReduceSingle._elaborate(self, platform) + if intermediate_terms is None: + return m + # recursive invocation of ``AddReduce`` next_level = AddReduce(intermediate_terms, len(self.output), -- 2.30.2