Allow a variable number of operands in the proof driver
authorCesar Strauss <cestrauss@gmail.com>
Sat, 16 Jan 2021 17:06:41 +0000 (14:06 -0300)
committerCesar Strauss <cestrauss@gmail.com>
Sat, 16 Jan 2021 17:13:22 +0000 (14:13 -0300)
It can now check unary operations, as well as binary.

Check the "all" method by defining an operation that applies it to
its argument.

src/ieee754/part/formal/proof_partition.py

index f257d0243b897456d4aded6de62b5d12cae2e1db..df1b6f871c15c1cf16e338bb918b3e976d919434 100644 (file)
@@ -323,7 +323,7 @@ class GeneratorDriver(Elaboratable):
 
 class ComparisonOpDriver(Elaboratable):
     """Checks comparison operations on partitioned signals"""
-    def __init__(self, op, width, mwidth):
+    def __init__(self, op, width=64, mwidth=8, nops=2):
         self.op = op
         """Operation to perform. Must accept two integer-like inputs and give
         a predicate-like output (1-bit partitions in case of
@@ -332,37 +332,43 @@ class ComparisonOpDriver(Elaboratable):
         """Partition full width"""
         self.mwidth = mwidth
         """Maximum number of equally sized partitions"""
-
+        self.nops = nops
+        """Number of input operands"""
     def elaborate(self, _):
         m = Module()
         comb = m.d.comb
         width = self.width
         mwidth = self.mwidth
+        nops = self.nops
         # setup partition points and gates
         step = int(width/mwidth)
         points, gates = make_partitions(step, mwidth)
         # setup inputs and outputs
-        a = PartitionedSignal(points, width)
-        b = PartitionedSignal(points, width)
+        operands = list()
+        for i in range(nops):
+            inp = PartitionedSignal(points, width, name=f"i_{i+1}")
+            inp.set_module(m)
+            operands.append(inp)
         output = Signal(mwidth)
-        a.set_module(m)
-        b.set_module(m)
         # perform the operation on the partitioned signals
-        comb += output.eq(self.op(a, b))
+        comb += output.eq(self.op(*operands))
         # instantiate the partitioned gate generator and connect the gates
         m.submodules.gen = gen = GateGenerator(mwidth)
         comb += gates.eq(gen.gates)
         p_offset = gen.p_offset
         p_width = gen.p_width
         # generate shifted down inputs and outputs
+        p_operands = list()
+        for i in range(nops):
+            p_i = Signal(width, name=f"p_{i+1}")
+            p_operands.append(p_i)
+            for pos in range(mwidth):
+                with m.If(p_offset == pos):
+                    comb += p_i.eq(operands[i].sig[pos * step:])
         p_output = Signal(mwidth)
-        p_a = Signal(width)
-        p_b = Signal(width)
         for pos in range(mwidth):
             with m.If(p_offset == pos):
                 comb += p_output.eq(output[pos:])
-                comb += p_a.eq(a.sig[pos * step:])
-                comb += p_b.eq(b.sig[pos * step:])
         # generate and check expected values for all possible partition sizes
         for w in range(1, mwidth+1):
             with m.If(p_width == w):
@@ -371,18 +377,19 @@ class ComparisonOpDriver(Elaboratable):
                 input_bit_width = w * step
                 output_bit_width = w
                 expected = Signal(output_bit_width, name=f"expected_{w}")
-                a = Signal(input_bit_width, name=f"a_{w}")
-                b = Signal(input_bit_width, name=f"b_{w}")
+                trunc_operands = list()
+                for i in range(nops):
+                    t_i = Signal(input_bit_width, name=f"t{w}_{i+1}")
+                    trunc_operands.append(t_i)
+                    comb += t_i.eq(p_operands[i][:input_bit_width])
                 lsb = Signal(name=f"lsb_{w}")
-                comb += a.eq(p_a[:input_bit_width])
-                comb += b.eq(p_b[:input_bit_width])
-                comb += lsb.eq(self.op(a, b))
+                comb += lsb.eq(self.op(*trunc_operands))
                 comb += expected.eq(Repl(lsb, output_bit_width))
                 # truncate the output, compare and assert
                 comb += Assert(p_output[:output_bit_width] == expected)
         # output a test case
         comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1) &
-                      (p_a != 0) & (p_b != 0) & (p_output != 0))
+                      (p_output != 0))
         return m
 
 
@@ -465,12 +472,15 @@ class PartitionTestCase(FHDLTestCase):
             ('p_offset[2:0]', 'dec'),
             ('p_width[3:0]', 'dec'),
             ('p_gates[8:0]', 'bin'),
+            'i_1[63:0]', 'i_2[63:0]',
             ('eq_1', {'submodule': 'eq_1'}, [
                 ('gates[6:0]', 'bin'),
                 'a[63:0]', 'b[63:0]',
                 ('output[7:0]', 'bin')]),
-            'p_a[63:0]', 'p_b[63:0]',
-            ('p_output[7:0]', 'bin')]
+            'p_1[63:0]', 'p_2[63:0]',
+            ('p_output[7:0]', 'bin'),
+            't3_1[23:0]', 't3_2[23:0]', 'lsb_3',
+            ('expected_3[2:0]', 'bin')]
         write_gtkw(
             'proof_partsig_eq_cover.gtkw',
             os.path.dirname(__file__) +
@@ -487,29 +497,63 @@ class PartitionTestCase(FHDLTestCase):
             module='top',
             zoom=-3
         )
-        module = ComparisonOpDriver(operator.eq, 64, 8)
+        module = ComparisonOpDriver(operator.eq)
         self.assertFormal(module, mode="bmc", depth=1)
         self.assertFormal(module, mode="cover", depth=1)
 
     def test_partsig_ne(self):
-        module = ComparisonOpDriver(operator.ne, 64, 8)
+        module = ComparisonOpDriver(operator.ne)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_gt(self):
-        module = ComparisonOpDriver(operator.gt, 64, 8)
+        module = ComparisonOpDriver(operator.gt)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_ge(self):
-        module = ComparisonOpDriver(operator.ge, 64, 8)
+        module = ComparisonOpDriver(operator.ge)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_lt(self):
-        module = ComparisonOpDriver(operator.lt, 64, 8)
+        module = ComparisonOpDriver(operator.lt)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_le(self):
-        module = ComparisonOpDriver(operator.le, 64, 8)
+        module = ComparisonOpDriver(operator.le)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_all(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
+        traces = [
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
+            'i_1[63:0]',
+            ('eq_1', {'submodule': 'eq_1'}, [
+                ('gates[6:0]', 'bin'),
+                'a[63:0]', 'b[63:0]',
+                ('output[7:0]', 'bin')]),
+            'p_1[63:0]',
+            ('p_output[7:0]', 'bin'),
+            't3_1[23:0]', 'lsb_3',
+            ('expected_3[2:0]', 'bin')]
+        write_gtkw(
+            'proof_partsig_all_cover.gtkw',
+            os.path.dirname(__file__) +
+            '/proof_partition_partsig_all/engine_0/trace0.vcd',
+            traces, style,
+            module='top',
+            zoom=-3
+        )
+
+        def op_all(obj):
+            return obj.all()
+
+        module = ComparisonOpDriver(op_all, nops=1)
         self.assertFormal(module, mode="bmc", depth=1)
+        self.assertFormal(module, mode="cover", depth=1)
 
 
 if __name__ == '__main__':