shuffle and comments
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 26 Feb 2020 17:44:15 +0000 (17:44 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 26 Feb 2020 17:44:15 +0000 (17:44 +0000)
src/ieee754/part_shift/part_shift_dynamic.py

index ba653fcf9e3012bd14d8fe070fba4a689700b009..e04083a2ff5644e7234f303820a89d340d4536b0 100644 (file)
@@ -19,6 +19,7 @@ from ieee754.part_shift.bitrev import GatedBitReverse
 import math
 
 class ShifterMask(Elaboratable):
+
     def __init__(self, pwid, bwid, max_bits, min_bits):
         self.max_bits = max_bits
         self.min_bits = min_bits
@@ -39,6 +40,7 @@ class ShifterMask(Elaboratable):
             comb += self.mask.eq(minm)
             return m
 
+        # create bit-cascade
         bits = Signal(self.pwid, reset_less=True)
         bl = []
         for j in range(self.pwid):
@@ -49,6 +51,7 @@ class ShifterMask(Elaboratable):
             else:
                 comb += bit.eq(~self.gates[j])
             bl.append(bit)
+
         # XXX ARGH, really annoying: simulation bug, can't use Cat(*bl).
         for j in range(bits.shape()[0]):
             comb += bits[j].eq(bl[j])
@@ -98,6 +101,7 @@ class PartialResult(Elaboratable):
 
 
 class PartitionedDynamicShift(Elaboratable):
+
     def __init__(self, width, partition_points):
         self.width = width
         self.partition_points = PartitionPoints(partition_points)
@@ -109,6 +113,8 @@ class PartitionedDynamicShift(Elaboratable):
 
     def elaborate(self, platform):
         m = Module()
+
+        # temporaries
         comb = m.d.comb
         width = self.width
         pwid = self.partition_points.get_max_partition_count(width)-1
@@ -119,6 +125,8 @@ class PartitionedDynamicShift(Elaboratable):
         keys = list(self.partition_points.keys()) + [self.width]
         start = 0
 
+        # create gated-reversed versions of a, b and the output
+        # left-shift is non-reversed, right-shift is reversed
         m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
         comb += a_br.data.eq(self.a)
         comb += a_br.reverse_en.eq(self.shift_right)
@@ -131,7 +139,6 @@ class PartitionedDynamicShift(Elaboratable):
         comb += gate_br.data.eq(gates)
         comb += gate_br.reverse_en.eq(self.shift_right)
 
-
         # break out both the input and output into partition-stratified blocks
         a_intervals = []
         b_intervals = []
@@ -198,21 +205,26 @@ class PartitionedDynamicShift(Elaboratable):
         b_shl_amount.append(element)
         for i in range(1, len(keys)):
             element = Mux(gates[i-1], masked_b[i], element)
-            b_shl_amount.append(element)
+            b_shl_amount.append(element) # FIXME: creates an O(N^2) cascade
         b_shr_amount = list(reversed(b_shl_amount))
 
+        # select shift-amount (b) for partition based on op being left or right
         shift_amounts = []
         for i in range(len(b_shl_amount)):
             shift_amount = Signal(masked_b[i].width, name="shift_amount%d" % i)
-            comb += shift_amount.eq(
-                Mux(self.shift_right, b_shr_amount[i], b_shl_amount[i]))
+            sel = Mux(self.shift_right, b_shr_amount[i], b_shl_amount[i])
+            comb += shift_amount.eq(sel)
             shift_amounts.append(shift_amount)
 
+        # now calculate partial results
+
+        # first item (simple)
         partial_results = []
         partial = Signal(width, name="partial0", reset_less=True)
         comb += partial.eq(a_intervals[0] << shift_amounts[0])
-
         partial_results.append(partial)
+
+        # rest of list
         for i in range(1, len(keys)):
             reswid = width - intervals[i][0]
             shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
@@ -225,11 +237,10 @@ class PartitionedDynamicShift(Elaboratable):
             comb += pr.a_interval.eq(a_intervals[i])
             partial_results.append(pr.partial)
 
-        out = []
-
         # This calculates the outputs o0-o3 from the partial results
         # table above.  Note: only relevant bits of the partial result equal
         # to the width of the output column are accumulated in a Mux-cascade.
+        out = []
         s,e = intervals[0]
         result = partial_results[0]
         out.append(result[s:e])