Keep the valid signal from the formal engine ALU stable, until read
[soc.git] / src / soc / experiment / formal / proof_compalu_multi.py
index d81f35e12151518a77a2bf5358016acef9c7952c..bf6d2615b8142c15561cc63ca946ce718547e227 100644 (file)
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: LGPLv3+
 # Copyright (C) 2022 Cesar Strauss <cestrauss@gmail.com>
-# Sponsored by NLnet and NGI POINTER under EU Grants 871528 and 957073
+# Sponsored by NLnet under EU Grant and 957073
 # Part of the Libre-SOC Project.
 
 """
@@ -35,7 +35,7 @@ https://bugs.libre-soc.org/show_bug.cgi?id=197
 import unittest
 
 from nmigen import Signal, Module
-from nmigen.hdl.ast import Cover
+from nmigen.hdl.ast import Cover, Const, Assume, Assert
 from nmutil.formaltest import FHDLTestCase
 from nmutil.singlepipe import ControlBase
 
@@ -139,10 +139,55 @@ class CompALUMultiTestCase(FHDLTestCase):
         cnt_alu_read = Signal(4)
         m.d.sync += cnt_alu_read.eq(cnt_alu_read + do_alu_read)
         cnt_masked_read = []
+        do_masked_read = Signal(dut.n_src)
         for i in range(dut.n_src):
             cnt = Signal(4, name="cnt_masked_read_%d" % i)
-            m.d.sync += cnt.eq(cnt + (do_issue & dut.rdmaskn[i]))
+            if i == 0:
+                extra = dut.oper_i.zero_a
+            elif i == 1:
+                extra = dut.oper_i.imm_data.ok
+            else:
+                extra = Const(0, 1)
+            m.d.comb += do_masked_read[i].eq(do_issue &
+                                             (dut.rdmaskn[i] | extra))
+            m.d.sync += cnt.eq(cnt + do_masked_read[i])
             cnt_masked_read.append(cnt)
+        # If the ALU is idle, do not assert valid
+        with m.If(cnt_alu_read == cnt_alu_write):
+            m.d.comb += Assume(~alu.n.o_valid)
+        # Keep ALU valid high, until read
+        last_alu_valid = Signal()
+        m.d.sync += last_alu_valid.eq(alu.n.o_valid & ~alu.n.i_ready)
+        with m.If(last_alu_valid):
+            m.d.comb += Assume(alu.n.o_valid)
+
+        # Invariant checks
+
+        # For every instruction issued, at any point in time,
+        # each operand was either:
+        # 1) Already read
+        # 2) Not read yet, but the read is pending (rel_o high)
+        # 3) Masked
+        for i in range(dut.n_src):
+            sum_read = Signal(4)
+            m.d.comb += sum_read.eq(
+                cnt_read[i] + cnt_masked_read[i] + dut.cu.rd.rel_o[i])
+            m.d.comb += Assert(sum_read == cnt_issue)
+
+        # For every instruction, either:
+        # 1) The ALU is executing the instruction
+        # 2) Otherwise, execution is pending (alu.p.i_valid is high)
+        # 3) Otherwise, it is waiting for operands
+        #    (some dut.cu.rd.rel_o are still high)
+        # 4) ... unless all operands are masked, in which case there is a one
+        #    cycle delay
+        all_masked = Signal()
+        m.d.sync += all_masked.eq(do_masked_read.all())
+        sum_alu_write = Signal(4)
+        m.d.comb += sum_alu_write.eq(
+            cnt_alu_write +
+            (dut.cu.rd.rel_o.any() | all_masked | alu.p.i_valid))
+        m.d.comb += Assert(sum_alu_write == cnt_issue)
 
         # Ask the formal engine to give an example
         m.d.comb += Cover((cnt_issue == 2)
@@ -154,7 +199,12 @@ class CompALUMultiTestCase(FHDLTestCase):
                           & (cnt_alu_read == 1)
                           & (cnt_masked_read[0] == 1)
                           & (cnt_masked_read[1] == 1))
-        self.assertFormal(m, mode="cover", depth=10)
+        with self.subTest("cover"):
+            self.assertFormal(m, mode="cover", depth=10)
+
+        # Check assertions
+        with self.subTest("bmc"):
+            self.assertFormal(m, mode="bmc", depth=10)
 
 
 if __name__ == "__main__":