yosys-smtbmc: Option to keep going after failed assertions in BMC mode
authorJannis Harder <me@jix.one>
Mon, 21 Mar 2022 17:26:27 +0000 (18:26 +0100)
committerJannis Harder <me@jix.one>
Thu, 24 Mar 2022 15:01:14 +0000 (16:01 +0100)
backends/smt2/smtbmc.py

index 98804bb32ee476ae60dd0ed096d1abaeb69533b1..e878c3561d4a16890e95a56583295c53204226e6 100644 (file)
@@ -50,6 +50,7 @@ smtcinit = False
 smtctop = None
 noinit = False
 binarymode = False
+keep_going = False
 so = SmtOpts()
 
 
@@ -153,6 +154,13 @@ def usage():
 
     --binary
         dump anyconst values as raw bit strings
+
+    --keep-going
+        continue BMC after the first failed assertion and report
+        further failed assertions. To output multiple traces
+        covering all found failed assertions, the character '%' is
+        replaced in all dump filenames with an increasing number.
+
 """ + so.helpmsg())
     sys.exit(1)
 
@@ -161,7 +169,7 @@ try:
     opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:igcm:", so.longopts +
             ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "btorwit=", "presat",
              "dump-vcd=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
-             "smtc-init", "smtc-top=", "noinit", "binary"])
+             "smtc-init", "smtc-top=", "noinit", "binary", "keep-going"])
 except:
     usage()
 
@@ -234,6 +242,8 @@ for o, a in opts:
         topmod = a
     elif o == "--binary":
         binarymode = True
+    elif o == "--keep-going":
+        keep_going = True
     elif so.handle(o, a):
         pass
     else:
@@ -341,13 +351,13 @@ for fn in inconstr:
             assert False
 
 
-def get_constr_expr(db, state, final=False, getvalues=False):
+def get_constr_expr(db, state, final=False, getvalues=False, individual=False):
     if final:
         if ("final-%d" % state) not in db:
-            return ([], [], []) if getvalues else "true"
+            return ([], [], []) if getvalues or individual else "true"
     else:
         if state not in db:
-            return ([], [], []) if getvalues else "true"
+            return ([], [], []) if getvalues or individual else "true"
 
     netref_regex = re.compile(r'(^|[( ])\[(-?[0-9]+:|)([^\]]*|\S*)\](?=[ )]|$)')
 
@@ -368,15 +378,18 @@ def get_constr_expr(db, state, final=False, getvalues=False):
     expr_list = list()
     for loc, expr in db[("final-%d" % state) if final else state]:
         actual_expr = netref_regex.sub(replace_netref, expr)
-        if getvalues:
+        if getvalues or individual:
             expr_list.append((loc, expr, actual_expr))
         else:
             expr_list.append(actual_expr)
 
-    if getvalues:
-        loc_list, expr_list, acual_expr_list = zip(*expr_list)
-        value_list = smt.get_list(acual_expr_list)
-        return loc_list, expr_list, value_list
+    if getvalues or individual:
+        loc_list, expr_list, actual_expr_list = zip(*expr_list)
+        if individual:
+            return loc_list, expr_list, actual_expr_list
+        else:
+            value_list = smt.get_list(actual_expr_list)
+            return loc_list, expr_list, value_list
 
     if len(expr_list) == 0:
         return "true"
@@ -1071,7 +1084,7 @@ def write_trace(steps_start, steps_stop, index):
         write_constr_trace(steps_start, steps_stop, index)
 
 
-def print_failed_asserts_worker(mod, state, path, extrainfo):
+def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()):
     assert mod in smt.modinfo
     found_failed_assert = False
 
@@ -1079,29 +1092,31 @@ def print_failed_asserts_worker(mod, state, path, extrainfo):
         return
 
     for cellname, celltype in smt.modinfo[mod].cells.items():
-        if print_failed_asserts_worker(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname, extrainfo):
+        cell_infokey = (mod, cellname, infokey)
+        if print_failed_asserts_worker(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname, extrainfo, infomap, cell_infokey):
             found_failed_assert = True
 
     for assertfun, assertinfo in smt.modinfo[mod].asserts.items():
         if smt.get("(|%s| %s)" % (assertfun, state)) in ["false", "#b0"]:
-            print_msg("Assert failed in %s: %s%s" % (path, assertinfo, extrainfo))
+            assert_key = (assertfun, infokey)
+            print_msg("Assert failed in %s: %s%s%s" % (path, assertinfo, extrainfo, infomap.get(assert_key, '')))
             found_failed_assert = True
 
     return found_failed_assert
 
 
-def print_failed_asserts(state, final=False, extrainfo=""):
+def print_failed_asserts(state, final=False, extrainfo="", infomap={}):
     if noinfo: return
     loc_list, expr_list, value_list = get_constr_expr(constr_asserts, state, final=final, getvalues=True)
     found_failed_assert = False
 
     for loc, expr, value in zip(loc_list, expr_list, value_list):
         if smt.bv2int(value) == 0:
-            print_msg("Assert %s failed: %s%s" % (loc, expr, extrainfo))
+            print_msg("Assert %s failed: %s%s%s" % (loc, expr, extrainfo, infomap.get(loc, '')))
             found_failed_assert = True
 
     if not final:
-        if print_failed_asserts_worker(topmod, "s%d" % state, topmod, extrainfo):
+        if print_failed_asserts_worker(topmod, "s%d" % state, topmod, extrainfo, infomap):
             found_failed_assert = True
 
     return found_failed_assert
@@ -1148,6 +1163,43 @@ def get_cover_list(mod, base):
 
     return cover_expr, cover_desc
 
+
+def get_assert_map(mod, base, path, key_base=()):
+    assert mod in smt.modinfo
+
+    assert_map = dict()
+
+    for expr, desc in smt.modinfo[mod].asserts.items():
+        assert_map[(expr, key_base)] = ("(|%s| %s)" % (expr, base), path, desc)
+
+    for cell, submod in smt.modinfo[mod].cells.items():
+        assert_map.update(get_assert_map(submod, "(|%s_h %s| %s)" % (mod, cell, base), path + "." + cell, (mod, cell, key_base)))
+
+    return assert_map
+
+
+def get_assert_keys():
+    keys = set()
+    keys.update(get_assert_map(topmod, 'state', topmod).keys())
+    for step_constr_asserts in constr_asserts.values():
+        keys.update(loc for loc, expr in step_constr_asserts)
+
+    return keys
+
+
+def get_active_assert_map(step, active):
+    assert_map = dict()
+    for key, assert_data in get_assert_map(topmod, "s%s" % step, topmod).items():
+        if key in active:
+            assert_map[key] = assert_data
+
+    for loc, expr, actual_expr in zip(*get_constr_expr(constr_asserts, step, individual=True)):
+        if loc in active:
+            assert_map[loc] = (actual_expr, None, (expr, loc))
+
+    return assert_map
+
+
 states = list()
 asserts_antecedent_cache = [list()]
 asserts_consequent_cache = [list()]
@@ -1457,6 +1509,10 @@ elif covermode:
                 print_msg("Unreached cover statement at %s." % cover_desc[i])
 
 else:  # not tempind, covermode
+    active_assert_keys = get_assert_keys()
+    failed_assert_infomap = dict()
+    traceidx = 0
+
     step = 0
     retstatus = "PASSED"
     while step < num_steps:
@@ -1510,44 +1566,81 @@ else:  # not tempind, covermode
                     break
 
             if not final_only:
-                if last_check_step == step:
-                    print_msg("Checking assertions in step %d.." % (step))
-                else:
-                    print_msg("Checking assertions in steps %d to %d.." % (step, last_check_step))
-                smt_push()
-
-                smt_assert("(not (and %s))" % " ".join(["(|%s_a| s%d)" % (topmod, i) for i in range(step, last_check_step+1)] +
-                        [get_constr_expr(constr_asserts, i) for i in range(step, last_check_step+1)]))
-
-                if smt_check_sat() == "sat":
-                    print("%s BMC failed!" % smt.timestamp())
-                    if append_steps > 0:
-                        for i in range(last_check_step+1, last_check_step+1+append_steps):
-                            print_msg("Appending additional step %d." % i)
-                            smt_state(i)
-                            smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
-                            smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
-                            smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
-                            smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
-                            smt_assert_consequent(get_constr_expr(constr_assumes, i))
-                        print_msg("Re-solving with appended steps..")
-                        if smt_check_sat() == "unsat":
-                            print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
-                            retstatus = "FAILED"
-                            break
-                    print_anyconsts(step)
+                recheck_current_step = True
+                while recheck_current_step:
+                    recheck_current_step = False
+                    if last_check_step == step:
+                        print_msg("Checking assertions in step %d.." % (step))
+                    else:
+                        print_msg("Checking assertions in steps %d to %d.." % (step, last_check_step))
+                    smt_push()
+
+                    active_assert_maps = dict()
+                    active_assert_exprs = list()
                     for i in range(step, last_check_step+1):
-                        print_failed_asserts(i)
-                    write_trace(0, last_check_step+1+append_steps, '%')
-                    retstatus = "FAILED"
-                    break
+                        assert_expr_map = get_active_assert_map(i, active_assert_keys)
+                        active_assert_maps[i] = assert_expr_map
+                        active_assert_exprs.extend(assert_data[0] for assert_data in assert_expr_map.values())
 
-                smt_pop()
+                    if active_assert_exprs:
+                        if len(active_assert_exprs) == 1:
+                            active_assert_expr = active_assert_exprs[0]
+                        else:
+                            active_assert_expr = "(and %s)" % " ".join(active_assert_exprs)
+
+                        smt_assert("(not %s)" % active_assert_expr)
+
+
+                    if smt_check_sat() == "sat":
+                        if retstatus != "FAILED":
+                            print("%s BMC failed!" % smt.timestamp())
+
+                        if append_steps > 0:
+                            for i in range(last_check_step+1, last_check_step+1+append_steps):
+                                print_msg("Appending additional step %d." % i)
+                                smt_state(i)
+                                smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
+                                smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
+                                smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
+                                smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
+                                smt_assert_consequent(get_constr_expr(constr_assumes, i))
+                            print_msg("Re-solving with appended steps..")
+                            if smt_check_sat() == "unsat":
+                                print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
+                                retstatus = "FAILED"
+                                break
+                        print_anyconsts(step)
+
+                        for i in range(step, last_check_step+1):
+                            print_failed_asserts(i, infomap=failed_assert_infomap)
+
+                        if keep_going:
+                            for i in range(step, last_check_step+1):
+                                for key, (expr, path, desc) in active_assert_maps[i].items():
+                                    if key in active_assert_keys and not smt.bv2int(smt.get(expr)):
+                                        failed_assert_infomap[key] = " [failed before]"
+
+                                        active_assert_keys.remove(key)
+
+                            if active_assert_keys:
+                                recheck_current_step = True
+
+                        write_trace(0, last_check_step+1+append_steps, "%d" % traceidx if keep_going else '%')
+                        traceidx += 1
+                        retstatus = "FAILED"
+
+                    smt_pop()
+                    if recheck_current_step:
+                        print_msg("Checking remaining assertions..")
+
+                if retstatus == "FAILED" and not (keep_going and active_assert_keys):
+                    break
 
             if (constr_final_start is not None) or (last_check_step+1 != num_steps):
                 for i in range(step, last_check_step+1):
-                    smt_assert("(|%s_a| s%d)" % (topmod, i))
-                    smt_assert(get_constr_expr(constr_asserts, i))
+                    assert_expr_map = get_active_assert_map(i, active_assert_keys)
+                    for assert_data in assert_expr_map.values():
+                        smt_assert(assert_data[0])
 
             if constr_final_start is not None:
                 for i in range(step, last_check_step+1):