remove use of AddReduce in Mul8_16_32_64, use AddReduceInternal instead
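
Rather than instantiating the AddReduce Elaboratable wrapper,
Mul8_16_32_64 now uses AddReduceInternal directly and wires up the
generated levels itself: levels whose index appears in register_levels
are linked through the sync domain (a pipeline register), all others
combinatorially.  The previously hard-coded factor of 2 on the
partition points becomes a partition_step parameter, threaded through
FinalAdd, AddReduceSingle and AddReduceInternal; Mul8_16_32_64 passes
partition_step=2 because its 128-bit result is twice the width on
which the partition points are defined.
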
diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py
index 4c6b570ce4474008c4fc590110afb1719cb818cb..2c828c187f2747bde3285df64599636719e3be72 100644
--- a/src/ieee754/part_mul_add/multiply.py
+++ b/src/ieee754/part_mul_add/multiply.py
@@ -346,7 +346,9 @@ class FinalAdd(Elaboratable):
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points):
+    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+                       partition_step=1):
+        self.partition_step = partition_step  # multiplier applied to each partition point
         self.output_width = output_width
         self.n_inputs = n_inputs
         self.n_parts = n_parts
@@ -381,7 +383,7 @@ class FinalAdd(Elaboratable):
             # base case for adding 2 inputs
             assert self.n_inputs == 2
             adder = PartitionedAdder(output_width,
-                                     self.i.part_pts, 2)
+                                     self.i.part_pts, self.partition_step)
             m.submodules.final_adder = adder
             m.d.comb += adder.a.eq(self.i.terms[0])
             m.d.comb += adder.b.eq(self.i.terms[1])
@@ -406,13 +408,17 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points):
+    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+                       partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
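+        :param partition_step: multiplier applied to each partition point
+                               (2 when the output is twice the input width).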
         """
+        self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_parts = n_parts
         self.output_width = output_width
@@ -516,7 +520,8 @@ class AddReduceSingle(Elaboratable):
         part_mask = Signal(self.output_width, reset_less=True)
 
-        # get partition points as a mask
+        # get partition points as a mask (points scaled by partition_step)
-        mask = self.i.part_pts.as_mask(self.output_width, mul=2)
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
@@ -543,7 +548,7 @@ class AddReduceInternal:
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, i, output_width):
+    def __init__(self, i, output_width, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -555,6 +560,7 @@ class AddReduceInternal:
         self.part_ops = i.part_ops
         self.output_width = output_width
         self.partition_points = i.part_pts
+        self.partition_step = partition_step  # passed down to every level
 
         self.create_levels()
 
@@ -572,7 +578,8 @@ class AddReduceInternal:
             if len(groups) == 0:
                 break
             next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         partition_points)
+                                         partition_points,
+                                         self.partition_step)
             mods.append(next_level)
             partition_points = next_level.i.part_pts
             inputs = next_level.o.terms
@@ -580,7 +587,7 @@ class AddReduceInternal:
             part_ops = next_level.i.part_ops
 
         next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              partition_points)
+                              partition_points, self.partition_step)
         mods.append(next_level)
 
         self.levels = mods
@@ -599,7 +606,7 @@ class AddReduce(AddReduceInternal, Elaboratable):
     """
 
     def __init__(self, inputs, output_width, register_levels, part_pts,
-                       part_ops):
+                       part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -614,7 +621,7 @@ class AddReduce(AddReduceInternal, Elaboratable):
         n_parts = len(part_ops)
         self.i = AddReduceData(part_pts, len(inputs),
                              output_width, n_parts)
-        AddReduceInternal.__init__(self, self.i, output_width)
+        AddReduceInternal.__init__(self, self.i, output_width, partition_step)
         self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
 
@@ -1382,17 +1389,23 @@ class Mul8_16_32_64(Elaboratable):
 
         terms = t.o.terms
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               t.o.part_pts,
-                               t.o.part_ops)
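+        # create the reduce levels directly (no AddReduce wrapper):
+        # partition_step=2 because the 128-bit result is twice the width
+        # the partition points are defined on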
+        at = AddReduceInternal(t.o, 128, partition_step=2)
 
-        m.submodules.add_reduce = add_reduce
+        i = at.i
+        for idx, mcur in enumerate(at.levels):
+            setattr(m.submodules, "addreduce_%d" % idx, mcur)
+            if idx in self.register_levels:
+                m.d.sync += mcur.i.eq(i)  # pipeline register at this level
+            else:
+                m.d.comb += mcur.i.eq(i)  # combinatorial pass-through
+            i = mcur.o  # output feeds the next level's input
 
         interm = Intermediates(128, 8, part_pts)
         m.submodules.intermediates = interm
-        m.d.comb += interm.i.eq(add_reduce.o)
+        m.d.comb += interm.i.eq(i)
 
         # final output
         m.submodules.finalout = finalout = FinalOut(128, 8, part_pts)
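
Why partition_step=2: the partition points are defined on the 64-bit
operand width, but the reduction above runs at the 128-bit product
width, so every partition boundary sits at twice its nominal bit
position.  Below is a minimal pure-Python sketch of that scaling (a
model only; as_mask_model and its exact boundary convention are
illustrative assumptions, not the repo's PartitionPoints.as_mask
implementation):

    def as_mask_model(points, width, mul=1):
        """Mask with a 0 at each scaled partition boundary, 1 elsewhere.

        `points` holds partition bit-positions on the narrow (input)
        width; `mul` scales them up for an output `mul` times wider.
        """
        mask = 0
        for i in range(width):
            if not (i % mul == 0 and (i // mul) in points):
                mask |= 1 << i
        return mask

    # 64-bit operands partitioned at {16, 32, 48}: the 128-bit product
    # must break its carry chain at {32, 64, 96} instead.
    assert as_mask_model({16, 32, 48}, 128, mul=2) == \
           as_mask_model({32, 64, 96}, 128)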