in-place expansion of partition points
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 09:39:35 +0000 (10:39 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 09:39:35 +0000 (10:39 +0100)
src/ieee754/part_mul_add/multiply.py

index 8e4ea8305ed43441254d33e1002c0c3d92c285a5..d0d84e4f203a875ae5cfde7eb8df478e2d79c96a 100644 (file)
@@ -71,17 +71,20 @@ class PartitionPoints(dict):
         for point, enabled in self.items():
             yield enabled.eq(rhs[point])
 
-    def as_mask(self, width):
+    def as_mask(self, width, mul=1):
         """Create a bit-mask from `self`.
 
         Each bit in the returned mask is clear only if the partition point at
         the same bit-index is enabled.
 
         :param width: the bit width of the resulting mask
+        :param mul: a "multiplier" which in-place expands the partition points
+                    typically set to "2" when used for multipliers
         """
         bits = []
         for i in range(width):
-            if i in self:
+            i /= mul
+            if i.is_integer() and int(i) in self:
                 bits.append(~self[i])
             else:
                 bits.append(True)
@@ -227,13 +230,16 @@ class PartitionedAdder(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, width, partition_points):
+    def __init__(self, width, partition_points, partition_step=1):
         """Create a ``PartitionedAdder``.
 
         :param width: the bit width of the input and output
         :param partition_points: the input partition points
+        :param partition_step: a multiplier (typically double) step
+                               which in-place "expands" the partition points
         """
         self.width = width
+        self.pmul = partition_step
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
         self.output = Signal(width, reset_less=True)
@@ -267,11 +273,12 @@ class PartitionedAdder(Elaboratable):
         # carry has been carried *over* the break point.
 
         for i in range(self.width):
-            if i in self.partition_points:
+            pi = i/self.pmul # double the range of the partition point test
+            if pi.is_integer() and pi in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(expanded_a[expanded_index])
-                al.append(~self.partition_points[i]) # add extra bit in a
+                al.append(~self.partition_points[pi]) # add extra bit in a
                 eb.append(expanded_b[expanded_index])
                 bl.append(C(0)) # yes, add a zero
                 expanded_index += 1 # skip the extra point.  NOT in the output
@@ -503,7 +510,8 @@ class AddReduceSingle(Elaboratable):
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
 
-        mask = self.i.reg_partition_points.as_mask(self.output_width)
+        # get partition points as a mask
+        mask = self.i.reg_partition_points.as_mask(self.output_width, mul=2)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
@@ -794,7 +802,7 @@ class Parts(Elaboratable):
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
+            m.d.comb += pb.eq(epps.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -1120,7 +1128,7 @@ class AllTerms(Elaboratable):
         tl = []
         for i in range(8):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(eps.part_byte(i, mfactor=2))
+            m.d.comb += pb.eq(eps.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -1296,7 +1304,7 @@ class Mul8_16_32_64(Elaboratable):
         expanded_part_pts = eps = PartitionPoints()
         for i, v in self.part_pts.items():
             ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
-            expanded_part_pts[i * 2] = ep
+            expanded_part_pts[i] = ep
             m.d.comb += ep.eq(v)
 
         n_inputs = 64 + 4