re-add masking of the shift amount based on partition length
authorMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 19:29:25 +0000 (14:29 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 20:36:55 +0000 (15:36 -0500)
src/ieee754/part_shift/formal/proof_shift_dynamic.py
src/ieee754/part_shift/part_shift_dynamic.py

index d4e7a525282b7ab7a4e753ed7939202074b5dcd2..a836771c2262ebd0bf24f2d720561a5367257708 100644 (file)
@@ -2,7 +2,7 @@
 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
 
 from nmigen import Module, Signal, Elaboratable, Mux, Cat
-from nmigen.asserts import Assert, AnyConst, Assume
+from nmigen.asserts import Assert, AnyConst
 from nmigen.test.utils import FHDLTestCase
 from nmigen.cli import rtlil
 
@@ -63,71 +63,55 @@ class ShifterDriver(Elaboratable):
 
         with m.Switch(points.as_sig()):
             with m.Case(0b000):
-                comb += Assume(b <= 32)
-                comb += Assert(out == (a<<b[0:6]) & 0xffffffff)
+                comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
             with m.Case(0b001):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 24)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[1:4]) ==
                                (Cat(a_intervals[1:4])
-                                << b_intervals[1]) & 0xffffff)
+                                << b_intervals[1][0:5]) & 0xffffff)
             with m.Case(0b010):
-                comb += Assume(b_intervals[0] <= 16)
                 comb += Assert(Cat(out_intervals[0:2]) ==
                                (Cat(a_intervals[0:2])
-                                << b_intervals[0]) & 0xffff)
-                comb += Assume(b_intervals[2] <= 16)
+                                << (b_intervals[0] & 0xf)) & 0xffff)
                 comb += Assert(Cat(out_intervals[2:4]) ==
                                (Cat(a_intervals[2:4])
-                                << b_intervals[2]) & 0xffff)
+                                << (b_intervals[2] & 0xf)) & 0xffff)
             with m.Case(0b011):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 8)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(out_intervals[1] ==
-                               (a_intervals[1] << b_intervals[1]) & 0xff)
-                comb += Assume(b_intervals[2] <= 16)
+                               (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[2:4]) ==
                                (Cat(a_intervals[2:4])
-                                << b_intervals[2]) & 0xffff)
+                                << b_intervals[2][0:4]) & 0xffff)
             with m.Case(0b100):
-                comb += Assume(b_intervals[0] <= 24)
                 comb += Assert(Cat(out_intervals[0:3]) ==
                                (Cat(a_intervals[0:3])
-                                << b_intervals[0]) & 0xffffff)
-                comb += Assume(b_intervals[3] <= 8)
+                                << b_intervals[0][0:5]) & 0xffffff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b101):
-                comb += Assume(b_intervals[0] <= 8)
                 comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0]) & 0xff)
-                comb += Assume(b_intervals[1] <= 16)
+                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
                 comb += Assert(Cat(out_intervals[1:3]) ==
                                (Cat(a_intervals[1:3])
-                                << b_intervals[1]) & 0xffff)
-                comb += Assume(b_intervals[3] <= 8)
+                                << b_intervals[1][0:4]) & 0xffff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b110):
-                comb += Assume(b_intervals[0] <= 16)
                 comb += Assert(Cat(out_intervals[0:2]) ==
                                (Cat(a_intervals[0:2])
-                                << b_intervals[0]) & 0xffff)
-                comb += Assume(b_intervals[2] <= 8)
+                                << b_intervals[0][0:4]) & 0xffff)
                 comb += Assert(out_intervals[2] ==
-                               (a_intervals[2] << b_intervals[2]) & 0xff)
-                comb += Assume(b_intervals[3] <= 8)
+                               (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
                 comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3]) & 0xff)
+                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
             with m.Case(0b111):
                 for i, o in enumerate(out_intervals):
-                    comb += Assume(b_intervals[i] <= 8)
                     comb += Assert(o ==
-                                   (a_intervals[i] << b_intervals[i]) & 0xff)
+                                   (a_intervals[i] << b_intervals[i][0:3])
+                                   & 0xff)
 
         return m
 
@@ -135,6 +119,7 @@ class PartitionedDynamicShiftTestCase(FHDLTestCase):
     def test_shift(self):
         module = ShifterDriver()
         self.assertFormal(module, mode="bmc", depth=4)
+
     def test_ilang(self):
         width = 64
         mwidth = 8
@@ -152,4 +137,3 @@ class PartitionedDynamicShiftTestCase(FHDLTestCase):
 
 if __name__ == "__main__":
     unittest.main()
-
index edcb7162ac68f3014acda5b1bf6dcac7d4b41267..21a0a9efd49be59f4866eab0f5cd556f8d69f19b 100644 (file)
@@ -55,6 +55,25 @@ class PartitionedDynamicShift(Elaboratable):
             intervals.append([start,end])
             start = end
 
+        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
+        max_bits = math.ceil(math.log2(width))
+
+        shifter_masks = []
+        for i in range(len(b_intervals)):
+            mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i)
+            bits = []
+            for j in range(i, gates.width):
+                if bits:
+                    bits.append(~gates[j] & bits[-1])
+                else:
+                    bits.append(~gates[j])
+            comb += mask.eq(Cat((1 << min_bits)-1, bits)
+                            & ((1 << max_bits)-1))
+            shifter_masks.append(mask)
+
+        print(shifter_masks)
+
+
         # Instead of generating the matrix described in the wiki, I
         # instead calculate the shift amounts for each partition, then
         # calculate the partial results of each partition << shift
@@ -73,12 +92,16 @@ class PartitionedDynamicShift(Elaboratable):
         # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
         # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
         # those partial results to calculate a0, a1, a2, and a3
+        shiftbits = math.ceil(math.log2(width))
+        element = b_intervals[0] & shifter_masks[0]
         partial_results = []
-        partial_results.append(a_intervals[0] << b_intervals[0])
-        element = b_intervals[0]
+        partial_results.append(a_intervals[0] << element)
         for i in range(1, len(out_intervals)):
             s, e = intervals[i]
-            element = Mux(gates[i-1], b_intervals[i], element)
+            masked = Signal(b_intervals[i].shape(), name="masked%d" % i)
+            comb += masked.eq(b_intervals[i] & shifter_masks[i])
+            element = Mux(gates[i-1], masked,
+                          element)
 
             # This calculates which partition of b to select the
             # shifter from. According to the table above, the
@@ -86,9 +109,8 @@ class PartitionedDynamicShift(Elaboratable):
             # the partition mask, this calculates that with a mux
             # chain
 
-
             # This computes the partial results table
-            shifter = Signal(8, name="shifter%d" % i)
+            shifter = Signal(shiftbits, name="shifter%d" % i)
             comb += shifter.eq(element)
             partial = Signal(width, name="partial%d" % i)
             comb += partial.eq(a_intervals[i] << shifter)