Fix handling of injectivity for str.unit (#8899)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 23 Jun 2022 00:45:52 +0000 (19:45 -0500)
committerGitHub <noreply@github.com>
Thu, 23 Jun 2022 00:45:52 +0000 (00:45 +0000)
Fixes #8890.

src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/strings/base_solver.cpp
src/theory/strings/base_solver.h
src/theory/strings/core_solver.cpp
src/theory/strings/infer_proof_cons.cpp
src/theory/strings/term_registry.cpp
src/theory/strings/theory_strings_utils.cpp
src/theory/strings/theory_strings_utils.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/strings/issue8890-inj-oob.smt2 [new file with mode: 0644]

index 55cb72451486e96339aacfc0338fd8214fbb7bc9..495971dc95e372c08d2d799bfd719432da512dc4 100644 (file)
@@ -387,6 +387,8 @@ const char* toString(InferenceId i)
     case InferenceId::STRINGS_I_CONST_CONFLICT:
       return "STRINGS_I_CONST_CONFLICT";
     case InferenceId::STRINGS_I_NORM: return "STRINGS_I_NORM";
+    case InferenceId::STRINGS_UNIT_SPLIT: return "STRINGS_UNIT_SPLIT";
+    case InferenceId::STRINGS_UNIT_INJ_OOB: return "STRINGS_UNIT_INJ_OOB";
     case InferenceId::STRINGS_UNIT_INJ: return "STRINGS_UNIT_INJ";
     case InferenceId::STRINGS_UNIT_CONST_CONFLICT:
       return "STRINGS_UNIT_CONST_CONFLICT";
index 3a6452e45743f53f106e4f183e117189bca25c2e..4166560e3446a5b94e5a5e4de76c0b27bf1c7473 100644 (file)
@@ -561,6 +561,11 @@ enum class InferenceId
   // equal after e.g. removing strings that are currently empty. For example:
   //   y = "" ^ z = "" => x ++ y = z ++ x
   STRINGS_I_NORM,
+  // split between the argument of two equated str.unit terms
+  STRINGS_UNIT_SPLIT,
+  // a code point must be out of bounds due to (str.unit x) = (str.unit y) and
+  // x != y.
+  STRINGS_UNIT_INJ_OOB,
   // injectivity of seq.unit
   // (seq.unit x) = (seq.unit y) => x=y, or
   // (seq.unit x) = (seq.unit c) => x=c
index 9193551cfaec9f609223ae0781ac8c45633d51c1..998926802d17593d680fd54d12b8581ec705348c 100644 (file)
@@ -37,7 +37,12 @@ BaseSolver::BaseSolver(Env& env,
                        SolverState& s,
                        InferenceManager& im,
                        TermRegistry& tr)
-    : EnvObj(env), d_state(s), d_im(im), d_termReg(tr), d_congruent(context())
+    : EnvObj(env),
+      d_state(s),
+      d_im(im),
+      d_termReg(tr),
+      d_congruent(context()),
+      d_strUnitOobEq(userContext())
 {
   d_false = NodeManager::currentNM()->mkConst(false);
   d_cardSize = options().strings.stringsAlphaCard;
@@ -56,6 +61,7 @@ void BaseSolver::checkInit()
   // count of congruent, non-congruent per operator (independent of type),
   // for debugging.
   std::map<Kind, std::pair<uint32_t, uint32_t>> congruentCount;
+  NodeManager* nm = NodeManager::currentNM();
   eq::EqualityEngine* ee = d_state.getEqualityEngine();
   eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
   while (!eqcs_i.isFinished())
@@ -122,10 +128,39 @@ void BaseSolver::checkInit()
             }
             if (!d_state.areEqual(s, t))
             {
-              // (seq.unit x) = (seq.unit y) => x=y, or
-              // (seq.unit x) = (seq.unit c) => x=c
               Assert(s.getType() == t.getType());
-              d_im.sendInference(exp, s.eqNode(t), InferenceId::STRINGS_UNIT_INJ);
+              Node eq = s.eqNode(t);
+              if (n.getType().isString())
+              {
+                // String unit is not injective, due to invalid code points.
+                // We do an inference scheme in two parts.
+                // for (str.unit x), (str.unit y): x = y or x != y
+                if (!d_state.areDisequal(s, t))
+                {
+                  d_im.sendSplit(s, t, InferenceId::STRINGS_UNIT_SPLIT);
+                }
+                else if (d_strUnitOobEq.find(eq) == d_strUnitOobEq.end())
+                {
+                  // cache that we have performed this inference
+                  Node eqSym = t.eqNode(s);
+                  d_strUnitOobEq.insert(eq);
+                  d_strUnitOobEq.insert(eqSym);
+                  exp.push_back(eq.notNode());
+                  // (str.unit x) = (str.unit y) ^ x != y =>
+                  // x or y is not a valid code point
+                  Node scr = utils::mkCodeRange(s, d_cardSize);
+                  Node tcr = utils::mkCodeRange(t, d_cardSize);
+                  Node conc = nm->mkNode(OR, scr.notNode(), tcr.notNode());
+                  d_im.sendInference(
+                      exp, conc, InferenceId::STRINGS_UNIT_INJ_OOB);
+                }
+              }
+              else
+              {
+                // (seq.unit x) = (seq.unit y) => x=y, or
+                // (seq.unit x) = (seq.unit c) => x=c
+                d_im.sendInference(exp, eq, InferenceId::STRINGS_UNIT_INJ);
+              }
             }
           }
           // update best content
index 69c14d34761acd33eb7e94df74731193ce9ba749..f687ed0eb422f861dc4153c47adc83b7ab609181 100644 (file)
@@ -233,6 +233,11 @@ class BaseSolver : protected EnvObj
    * various inference schemas implemented by this class.
    */
   NodeSet d_congruent;
+  /**
+   * Set of equalities that we have applied STRINGS_UNIT_INJ_OOB to
+   * in the current user context
+   */
+  NodeSet d_strUnitOobEq;
   /**
    * Maps equivalence classes to their info, see description of `BaseEqcInfo`
    * for more information.
index 8c74af50f0c0719d542c023fdb8667b815eb4ae5..99aac3815b4e5eeda4746aa1edb90c2a03f46e8d 100644 (file)
@@ -1276,7 +1276,10 @@ void CoreSolver::processSimpleNEq(NormalForm& nfi,
     Trace("strings-solve-debug")
         << "Process " << x << " ... " << y << std::endl;
 
-    if (x == y)
+    // require checking equal here, usually x == y, but this also holds the
+    // case of (str.unit a) and (str.unit b) for distinct a, b, where we do
+    // not want to unify these terms here.
+    if (d_state.areEqual(x, y))
     {
       // The normal forms have the same term at the current position. We just
       // continue with the next index. By construction of the normal forms, we
@@ -1290,7 +1293,6 @@ void CoreSolver::processSimpleNEq(NormalForm& nfi,
       index++;
       continue;
     }
-    Assert(!d_state.areEqual(x, y));
 
     std::vector<Node> lenExp;
     Node xLenTerm = d_state.getLength(x, lenExp);
index 212bc56d3afa82bf61915abbbf3e43b76366c1f2..567f61f2699137b1d6a56897af35dfb04bf18aa7 100644 (file)
@@ -664,6 +664,7 @@ void InferProofCons::convert(InferenceId infer,
     case InferenceId::STRINGS_DEQ_STRINGS_EQ:
     case InferenceId::STRINGS_DEQ_LENS_EQ:
     case InferenceId::STRINGS_DEQ_LENGTH_SP:
+    case InferenceId::STRINGS_UNIT_SPLIT:
     {
       if (conc.getKind() != OR)
       {
index b85e3460b13eba3406567f6622f7587a82d07079..4ca2ac5f7627343abf67eeaa0c7cd2bcdda1e41c 100644 (file)
@@ -77,14 +77,6 @@ uint32_t TermRegistry::getAlphabetCardinality() const { return d_alphaCard; }
 
 void TermRegistry::finishInit(InferenceManager* im) { d_im = im; }
 
-Node mkCodeRange(Node t, uint32_t alphaCard)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  return nm->mkNode(AND,
-                    nm->mkNode(GEQ, t, nm->mkConstInt(Rational(0))),
-                    nm->mkNode(LT, t, nm->mkConstInt(Rational(alphaCard))));
-}
-
 Node TermRegistry::eagerReduce(Node t, SkolemCache* sc, uint32_t alphaCard)
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -96,7 +88,7 @@ Node TermRegistry::eagerReduce(Node t, SkolemCache* sc, uint32_t alphaCard)
     Node len = nm->mkNode(STRING_LENGTH, t[0]);
     Node code_len = len.eqNode(nm->mkConstInt(Rational(1)));
     Node code_eq_neg1 = t.eqNode(nm->mkConstInt(Rational(-1)));
-    Node code_range = mkCodeRange(t, alphaCard);
+    Node code_range = utils::mkCodeRange(t, alphaCard);
     lemma = nm->mkNode(ITE, code_len, code_range, code_eq_neg1);
   }
   else if (tk == SEQ_NTH)
@@ -111,7 +103,7 @@ Node TermRegistry::eagerReduce(Node t, SkolemCache* sc, uint32_t alphaCard)
       Node c2 = nm->mkNode(GT, nm->mkNode(STRING_LENGTH, s), n);
       // check whether this application of seq.nth is defined.
       Node cond = nm->mkNode(AND, c1, c2);
-      Node code_range = mkCodeRange(t, alphaCard);
+      Node code_range = utils::mkCodeRange(t, alphaCard);
       // the lemma for `seq.nth`
       lemma = nm->mkNode(ITE, cond, code_range, t.eqNode(nm->mkConstInt(Rational(-1))));
       // IF: n >=0 AND n < len( s )
index e692e93c421b55a0b2975aa8e8c07a9da1adb048..a3ab50919bcb06c0eafffffe78d581c53058af9d 100644 (file)
@@ -449,6 +449,14 @@ Node mkAbstractStringValueForLength(Node n, Node len, size_t id)
   return quantifiers::mkNamedQuant(WITNESS, bvl, pred, ss.str());
 }
 
+Node mkCodeRange(Node t, uint32_t alphaCard)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  return nm->mkNode(AND,
+                    nm->mkNode(GEQ, t, nm->mkConstInt(Rational(0))),
+                    nm->mkNode(LT, t, nm->mkConstInt(Rational(alphaCard))));
+}
+
 }  // namespace utils
 }  // namespace strings
 }  // namespace theory
index e6671d87e2a380311bfa04b69e976a97fbb984e7..14172a3df74a5ead7e2c51d52af67674865b4033 100644 (file)
@@ -217,6 +217,11 @@ Node mkForallInternal(Node bvl, Node body);
  */
 Node mkAbstractStringValueForLength(Node n, Node len, size_t id);
 
+/**
+ * Make the formula (and (>= t 0) (< t alphaCard)).
+ */
+Node mkCodeRange(Node t, uint32_t alphaCard);
+
 }  // namespace utils
 }  // namespace strings
 }  // namespace theory
index 45eb213b86ae41454557578292a957d19d392641..30f1d7177757a153982433628f5789ac775e2c9d 100644 (file)
@@ -2622,6 +2622,7 @@ set(regress_1_tests
   regress1/strings/issue8094-witness-model.smt2
   regress1/strings/issue8347-has-skolem.smt2
   regress1/strings/issue8434-nterm-str-rw.smt2
+  regress1/strings/issue8890-inj-oob.smt2
   regress1/strings/kaluza-fl.smt2
   regress1/strings/loop002.smt2
   regress1/strings/loop003.smt2
diff --git a/test/regress/cli/regress1/strings/issue8890-inj-oob.smt2 b/test/regress/cli/regress1/strings/issue8890-inj-oob.smt2
new file mode 100644 (file)
index 0000000..4705a91
--- /dev/null
@@ -0,0 +1,9 @@
+; COMMAND-LINE: --strings-exp --mbqi --strings-fmf
+; EXPECT: sat
+(set-logic ALL)
+(declare-fun b (Int) Bool)
+(declare-fun v () String)
+(declare-fun a () String)
+(assert (or (b 0) (= 0 (ite (= 1 (str.len v)) 1 0))))
+(assert (forall ((e String) (va Int)) (or (= va 0) (distinct 0 (ite (str.prefixof "-" a) (str.to_int (str.substr v 1 (str.len e))) (str.to_int e))))))
+(check-sat)