use intermediate data from finalout, move AllTerms class
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 18:46:27 +0000 (19:46 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 18:46:27 +0000 (19:46 +0100)
src/ieee754/part_mul_add/multiply.py

index f55d1f7670a75a3faadd35f0b00abefee8d3bbf8..7b8782e220797f43db32f2fdae442c90bf464980 100644 (file)
@@ -1070,6 +1070,106 @@ class IntermediateData:
                             rhs.intermediate_output, rhs.part_ops)
 
 
+class AllTerms(Elaboratable):
+    """Set of terms to be added together
+    """
+
+    def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels,
+                       partition_points):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param register_levels: List of nesting levels that should have
+            pipeline registers.
+        :param partition_points: the input partition points.
+        """
+        self.epps = partition_points.like()
+        self.register_levels = register_levels
+        self.pbwid = pbwid
+        self.n_inputs = n_inputs
+        self.n_parts = n_parts
+        self.output_width = output_width
+        self.o = AddReduceData(self.epps, n_inputs,
+                               output_width, n_parts)
+
+        self.a = Signal(64)
+        self.b = Signal(64)
+
+        self.pbs = Signal(pbwid, reset_less=True)
+        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+
+    def elaborate(self, platform):
+        m = Module()
+
+        pbs = self.pbs
+        eps = self.epps
+
+        # local variables
+        signs = []
+        for i in range(8):
+            s = Signs()
+            signs.append(s)
+            setattr(m.submodules, "signs%d" % i, s)
+            m.d.comb += s.part_ops.eq(self.part_ops[i])
+
+        n_levels = len(self.register_levels)+1
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
+        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+        for mod in [part_8, part_16, part_32, part_64]:
+            m.d.comb += mod.a.eq(self.a)
+            m.d.comb += mod.b.eq(self.b)
+            for i in range(len(signs)):
+                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
+                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
+            m.d.comb += mod.pbs.eq(pbs)
+            nat_l.append(mod.not_a_term)
+            nbt_l.append(mod.not_b_term)
+            nla_l.append(mod.neg_lsb_a_term)
+            nlb_l.append(mod.neg_lsb_b_term)
+
+        terms = []
+
+        for a_index in range(8):
+            t = ProductTerms(8, 128, 8, a_index, 8)
+            setattr(m.submodules, "terms_%d" % a_index, t)
+
+            m.d.comb += t.a.eq(self.a)
+            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.pb_en.eq(pbs)
+
+            for term in t.terms:
+                terms.append(term)
+
+        # it's fine to bitwise-or data together since they are never enabled
+        # at the same time
+        m.submodules.nat_or = nat_or = OrMod(128)
+        m.submodules.nbt_or = nbt_or = OrMod(128)
+        m.submodules.nla_or = nla_or = OrMod(128)
+        m.submodules.nlb_or = nlb_or = OrMod(128)
+        for l, mod in [(nat_l, nat_or),
+                             (nbt_l, nbt_or),
+                             (nla_l, nla_or),
+                             (nlb_l, nlb_or)]:
+            for i in range(len(l)):
+                m.d.comb += mod.orin[i].eq(l[i])
+            terms.append(mod.orout)
+
+        # copy the intermediate terms to the output
+        for i, value in enumerate(terms):
+            m.d.comb += self.o.inputs[i].eq(value)
+
+        # copy reg part points and part ops to output
+        m.d.comb += self.o.reg_partition_points.eq(eps)
+        m.d.comb += [self.o.part_ops[i].eq(self.part_ops[i])
+                                     for i in range(len(self.part_ops))]
+
+        return m
+
+
 class Intermediates(Elaboratable):
     """ Intermediate output modules
     """
@@ -1213,7 +1313,6 @@ class Mul8_16_32_64(Elaboratable):
         out_part_pts = add_reduce.o.reg_partition_points
 
         m.submodules.add_reduce = add_reduce
-        m.d.comb += self.intermediate_output.eq(add_reduce.o.output)
 
         interm = Intermediates(128, 8, expanded_part_pts)
         m.submodules.intermediates = interm
@@ -1223,106 +1322,7 @@ class Mul8_16_32_64(Elaboratable):
         m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts)
         m.d.comb += finalout.i.eq(interm.o)
         m.d.comb += self.output.eq(finalout.out)
-
-        return m
-
-
-class AllTerms(Elaboratable):
-    """Set of terms to be added together
-    """
-
-    def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        """Create an ``AddReduce``.
-
-        :param inputs: input ``Signal``s to be summed.
-        :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
-        :param partition_points: the input partition points.
-        """
-        self.epps = partition_points.like()
-        self.register_levels = register_levels
-        self.pbwid = pbwid
-        self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.o = AddReduceData(self.epps, n_inputs,
-                               output_width, n_parts)
-
-        self.a = Signal(64)
-        self.b = Signal(64)
-
-        self.pbs = Signal(pbwid, reset_less=True)
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-
-    def elaborate(self, platform):
-        m = Module()
-
-        pbs = self.pbs
-        eps = self.epps
-
-        # local variables
-        signs = []
-        for i in range(8):
-            s = Signs()
-            signs.append(s)
-            setattr(m.submodules, "signs%d" % i, s)
-            m.d.comb += s.part_ops.eq(self.part_ops[i])
-
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
-        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
-        for mod in [part_8, part_16, part_32, part_64]:
-            m.d.comb += mod.a.eq(self.a)
-            m.d.comb += mod.b.eq(self.b)
-            for i in range(len(signs)):
-                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
-                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
-            m.d.comb += mod.pbs.eq(pbs)
-            nat_l.append(mod.not_a_term)
-            nbt_l.append(mod.not_b_term)
-            nla_l.append(mod.neg_lsb_a_term)
-            nlb_l.append(mod.neg_lsb_b_term)
-
-        terms = []
-
-        for a_index in range(8):
-            t = ProductTerms(8, 128, 8, a_index, 8)
-            setattr(m.submodules, "terms_%d" % a_index, t)
-
-            m.d.comb += t.a.eq(self.a)
-            m.d.comb += t.b.eq(self.b)
-            m.d.comb += t.pb_en.eq(pbs)
-
-            for term in t.terms:
-                terms.append(term)
-
-        # it's fine to bitwise-or data together since they are never enabled
-        # at the same time
-        m.submodules.nat_or = nat_or = OrMod(128)
-        m.submodules.nbt_or = nbt_or = OrMod(128)
-        m.submodules.nla_or = nla_or = OrMod(128)
-        m.submodules.nlb_or = nlb_or = OrMod(128)
-        for l, mod in [(nat_l, nat_or),
-                             (nbt_l, nbt_or),
-                             (nla_l, nla_or),
-                             (nlb_l, nlb_or)]:
-            for i in range(len(l)):
-                m.d.comb += mod.orin[i].eq(l[i])
-            terms.append(mod.orout)
-
-        # copy the intermediate terms to the output
-        for i, value in enumerate(terms):
-            m.d.comb += self.o.inputs[i].eq(value)
-
-        # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(eps)
-        m.d.comb += [self.o.part_ops[i].eq(self.part_ops[i])
-                                     for i in range(len(self.part_ops))]
+        m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)
 
         return m