pass in part_ops to AddReduce, so that it is syncd alongside the other data
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 10:44:41 +0000 (11:44 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 10:44:41 +0000 (11:44 +0100)
src/ieee754/part_mul_add/multiply.py

index 517e6cf9e2835b712b95c43e7832f04283a5bd3d..32e817a1b8f99de9bdda87eede9365cee8843deb 100644 (file)
@@ -305,7 +305,8 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points):
+    def __init__(self, inputs, output_width, register_levels, partition_points,
+                       part_ops):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -314,6 +315,9 @@ class AddReduceSingle(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
+        self.part_ops = part_ops
+        self._part_ops = [Signal(2, name=f"part_ops_{i}")
+                          for i in range(len(part_ops))]
         self.inputs = list(inputs)
         self._resized_inputs = [
             Signal(output_width, name=f"resized_inputs[{i}]")
@@ -367,10 +371,14 @@ class AddReduceSingle(Elaboratable):
         # pipeline registers
         resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                      for i in range(len(self.inputs))]
+        copy_part_ops = [self._part_ops[i].eq(self.part_ops[i])
+                                     for i in range(len(self.part_ops))]
         if 0 in self.register_levels:
+            m.d.sync += copy_part_ops
             m.d.sync += resized_input_assignments
             m.d.sync += self._reg_partition_points.eq(self.partition_points)
         else:
+            m.d.comb += copy_part_ops
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
@@ -465,7 +473,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points):
+    def __init__(self, inputs, output_width, register_levels, partition_points,
+                       part_ops):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -475,6 +484,7 @@ class AddReduce(Elaboratable):
         :param partition_points: the input partition points.
         """
         self.inputs = inputs
+        self.part_ops = part_ops
         self.output = Signal(output_width)
         self.output_width = output_width
         self.register_levels = register_levels
@@ -498,7 +508,7 @@ class AddReduce(Elaboratable):
         inputs = self.inputs
         while True:
             next_level = AddReduceSingle(inputs, self.output_width, next_levels,
-                                 partition_points)
+                                         partition_points, self.part_ops)
             mods.append(next_level)
             if len(next_level.groups) == 0:
                 break
@@ -1030,7 +1040,9 @@ class Mul8_16_32_64(Elaboratable):
         add_reduce = AddReduce(terms,
                                128,
                                self.register_levels,
-                               expanded_part_pts)
+                               expanded_part_pts,
+                               self.part_ops)
+
         m.submodules.add_reduce = add_reduce
         m.d.comb += self._intermediate_output.eq(add_reduce.output)
         # create _output_64