add Stage API setup/process to AddReduceInternal
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 14:34:23 +0000 (15:34 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 14:34:23 +0000 (15:34 +0100)
src/ieee754/part_mul_add/multiply.py

index 92ff7d305b059a78ff49e81fccd6d06ff4125c2a..8aecf8ae4a87dd50d3bff4d6daddb98fa7131946 100644 (file)
@@ -346,8 +346,9 @@ class FinalAdd(Elaboratable):
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                        partition_step=1):
+        self.lidx = lidx
         self.partition_step = partition_step
         self.output_width = output_width
         self.n_inputs = n_inputs
@@ -367,6 +368,13 @@ class FinalAdd(Elaboratable):
         return FinalReduceData(self.partition_points,
                                  self.output_width, self.n_parts)
 
+    def setup(self, m, i):
+        m.submodules.finaladd = self
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        return self.o
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -408,7 +416,7 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                        partition_step=1):
         """Create an ``AddReduce``.
 
@@ -416,6 +424,7 @@ class AddReduceSingle(Elaboratable):
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
         self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_parts = n_parts
@@ -438,6 +447,13 @@ class AddReduceSingle(Elaboratable):
         return AddReduceData(self.partition_points, self.n_terms,
                              self.output_width, self.n_parts)
 
+    def setup(self, m, i):
+        setattr(m.submodules, "addreduce_%d" % self.lidx, self)
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        return self.o
+
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
         retval = len(groups)*2
@@ -577,7 +593,8 @@ class AddReduceInternal:
             groups = AddReduceSingle.full_adder_groups(len(inputs))
             if len(groups) == 0:
                 break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
+            lidx = len(mods)
+            next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts,
                                          partition_points,
                                          self.partition_step)
             mods.append(next_level)
@@ -586,7 +603,8 @@ class AddReduceInternal:
             ilen = len(inputs)
             part_ops = next_level.i.part_ops
 
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
+        lidx = len(mods)
+        next_level = FinalAdd(lidx, ilen, self.output_width, n_parts,
                               partition_points, self.partition_step)
         mods.append(next_level)
 
@@ -1414,12 +1432,13 @@ class Mul8_16_32_64(Elaboratable):
         i = at.i
         for idx in range(len(at.levels)):
             mcur = at.levels[idx]
-            setattr(m.submodules, "addreduce_%d" % idx, mcur)
+            mcur.setup(m, i)
+            o = mcur.ospec()
             if idx in self.register_levels:
-                m.d.sync += mcur.i.eq(i)
+                m.d.sync += o.eq(mcur.process(i))
             else:
-                m.d.comb += mcur.i.eq(i)
-            i = mcur.o # for next loop
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
 
         interm = Intermediates(128, 8, part_pts)
         interm.setup(m, i)