move and reorg create_next_terms in AddReduceSingle, call in elaborate
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 01:27:02 +0000 (02:27 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 01:27:02 +0000 (02:27 +0100)
src/ieee754/part_mul_add/multiply.py

index 6f770842f8aa048c6afc2b8e54b5013f2e91e985..c74c80e60cd25476466260d1c6a721c3792ada7f 100644 (file)
@@ -418,18 +418,20 @@ class AddReduceSingle(Elaboratable):
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
-        # this is annoying.  we have to create the modules (and terms)
-        # because we need to know what they are (in order to set up the
-        # interconnects back in AddReduce), but cannot do the m.d.comb +=
-        # etc because this is not in elaboratable.
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
-        self._intermediate_terms = []
-        self.adders = []
-        if len(self.groups) != 0:
-            self.create_next_terms()
+        n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
+        self.o = AddReduceData(partition_points, n_terms, output_width, n_parts)
 
 
-        self.o = AddReduceData(partition_points, len(self._intermediate_terms),
-                               output_width, n_parts)
+    @staticmethod
+    def calc_n_inputs(n_inputs, groups):
+        retval = len(groups)*2
+        if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            retval += 1
+        elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            retval += 2
+        else:
+            assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
+        return retval
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -454,12 +456,43 @@ class AddReduceSingle(Elaboratable):
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
+    def create_next_terms(self):
+        """ create next intermediate terms, for linking up in elaborate, below
+        """
+        terms = []
+        adders = []
+
+        # create full adders for this recursive level.
+        # this shrinks N terms to 2 * (N // 3) plus the remainder
+        for i in self.groups:
+            adder_i = MaskedFullAdder(self.output_width)
+            adders.append((i, adder_i))
+            # add both the sum and the masked-carry to the next level.
+            # 3 inputs have now been reduced to 2...
+            terms.append(adder_i.sum)
+            terms.append(adder_i.mcarry)
+        # handle the remaining inputs.
+        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            terms.append(self.i.inputs[-1])
+        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            # Just pass the terms to the next layer, since we wouldn't gain
+            # anything by using a half adder since there would still be 2 terms
+            # and just passing the terms to the next layer saves gates.
+            terms.append(self.i.inputs[-2])
+            terms.append(self.i.inputs[-1])
+        else:
+            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
+
+        return terms, adders
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        terms, adders = self.create_next_terms()
+
         # copy the intermediate terms to the output
         # copy the intermediate terms to the output
-        for i, value in enumerate(self._intermediate_terms):
+        for i, value in enumerate(terms):
             m.d.comb += self.o.inputs[i].eq(value)
 
         # copy reg part points and part ops to output
             m.d.comb += self.o.inputs[i].eq(value)
 
         # copy reg part points and part ops to output
@@ -474,7 +507,7 @@ class AddReduceSingle(Elaboratable):
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
-        for i, (iidx, adder_i) in enumerate(self.adders):
+        for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
             m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
             setattr(m.submodules, f"adder_{i}", adder_i)
 
             m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
@@ -484,36 +517,6 @@ class AddReduceSingle(Elaboratable):
 
         return m
 
 
         return m
 
-    def create_next_terms(self):
-
-        _intermediate_terms = []
-
-        def add_intermediate_term(value):
-            _intermediate_terms.append(value)
-
-        # create full adders for this recursive level.
-        # this shrinks N terms to 2 * (N // 3) plus the remainder
-        for i in self.groups:
-            adder_i = MaskedFullAdder(self.output_width)
-            self.adders.append((i, adder_i))
-            # add both the sum and the masked-carry to the next level.
-            # 3 inputs have now been reduced to 2...
-            add_intermediate_term(adder_i.sum)
-            add_intermediate_term(adder_i.mcarry)
-        # handle the remaining inputs.
-        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
-            add_intermediate_term(self.i.inputs[-1])
-        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
-            # Just pass the terms to the next layer, since we wouldn't gain
-            # anything by using a half adder since there would still be 2 terms
-            # and just passing the terms to the next layer saves gates.
-            add_intermediate_term(self.i.inputs[-2])
-            add_intermediate_term(self.i.inputs[-1])
-        else:
-            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
-
-        self._intermediate_terms = _intermediate_terms
-
 
 class AddReduce(Elaboratable):
     """Recursively Add list of numbers together.
 
 class AddReduce(Elaboratable):
     """Recursively Add list of numbers together.