create AllTermsData class and use it
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 23:45:38 +0000 (00:45 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 23:45:38 +0000 (00:45 +0100)
src/ieee754/part_mul_add/multiply.py

index 0d65e30fd79ed49ab53d2b077bff5b6156c5bd51..8e4ea8305ed43441254d33e1002c0c3d92c285a5 100644 (file)
@@ -1070,6 +1070,24 @@ class IntermediateData:
                             rhs.intermediate_output, rhs.part_ops)
 
 
+class AllTermsData:
+
+    def __init__(self, partition_points):
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self.epps = partition_points.like()
+        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+
+    def eq_from(self, epps, inputs, part_ops):
+        return [self.epps.eq(epps)] + \
+               [self.a.eq(a), self.b.eq(b)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                                     for i in range(len(self.part_ops))]
+
+    def eq(self, rhs):
+        return self.eq_from(rhs.epps, rhs.a, rhs.b, rhs.part_ops)
+
+
 class AllTerms(Elaboratable):
     """Set of terms to be added together
     """
@@ -1084,23 +1102,18 @@ class AllTerms(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.epps = partition_points.like()
+        self.i = AllTermsData(partition_points)
         self.register_levels = register_levels
         self.n_inputs = n_inputs
         self.n_parts = n_parts
         self.output_width = output_width
-        self.o = AddReduceData(self.epps, n_inputs,
+        self.o = AddReduceData(self.i.epps, n_inputs,
                                output_width, n_parts)
 
-        self.a = Signal(64)
-        self.b = Signal(64)
-
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-
     def elaborate(self, platform):
         m = Module()
 
-        eps = self.epps
+        eps = self.i.epps
 
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
@@ -1117,7 +1130,7 @@ class AllTerms(Elaboratable):
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
-            m.d.comb += s.part_ops.eq(self.part_ops[i])
+            m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
         n_levels = len(self.register_levels)+1
         m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
@@ -1126,8 +1139,8 @@ class AllTerms(Elaboratable):
         m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
-            m.d.comb += mod.a.eq(self.a)
-            m.d.comb += mod.b.eq(self.b)
+            m.d.comb += mod.a.eq(self.i.a)
+            m.d.comb += mod.b.eq(self.i.b)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
@@ -1143,8 +1156,8 @@ class AllTerms(Elaboratable):
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
-            m.d.comb += t.a.eq(self.a)
-            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.a.eq(self.i.a)
+            m.d.comb += t.b.eq(self.i.b)
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
@@ -1170,8 +1183,8 @@ class AllTerms(Elaboratable):
 
         # copy reg part points and part ops to output
         m.d.comb += self.o.reg_partition_points.eq(eps)
-        m.d.comb += [self.o.part_ops[i].eq(self.part_ops[i])
-                                     for i in range(len(self.part_ops))]
+        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+                                     for i in range(len(self.i.part_ops))]
 
         return m
 
@@ -1291,11 +1304,11 @@ class Mul8_16_32_64(Elaboratable):
         t = AllTerms(n_inputs, 128, n_parts, self.register_levels,
                        eps)
         m.submodules.allterms = t
-        m.d.comb += t.a.eq(self.a)
-        m.d.comb += t.b.eq(self.b)
-        m.d.comb += t.epps.eq(eps)
+        m.d.comb += t.i.a.eq(self.a)
+        m.d.comb += t.i.b.eq(self.b)
+        m.d.comb += t.i.epps.eq(eps)
         for i in range(8):
-            m.d.comb += t.part_ops[i].eq(self.part_ops[i])
+            m.d.comb += t.i.part_ops[i].eq(self.part_ops[i])
 
         terms = t.o.inputs