Fix mod elimination in learned rewrite preprocessing pass (#8938)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 7 Jul 2022 20:09:36 +0000 (15:09 -0500)
committerGitHub <noreply@github.com>
Thu, 7 Jul 2022 20:09:36 +0000 (20:09 +0000)
Fixes #8934.

Also makes this rewrite cover more cases.

src/preprocessing/passes/learned_rewrite.cpp
test/regress/cli/CMakeLists.txt
test/regress/cli/regress0/nl/issue8934-lr-int-mod-range.smt2 [new file with mode: 0644]
test/regress/cli/regress1/nl/learned-rewrite-int-mod-range.smt2 [new file with mode: 0644]

index 194cd7baea1b676996f8d5ec43054690db192cfa..b24534cc8434ae0fffa63a60b75903eed1ebd8ba 100644 (file)
@@ -292,22 +292,32 @@ Node LearnedRewrite::rewriteLearned(Node n,
     Node num = n[0];
     Node den = n[1];
     arith::Bounds db = binfer.get(den);
-    if ((!db.lower_value.isNull()
-         && db.lower_value.getConst<Rational>().sgn() == 1)
-        || (!db.upper_value.isNull()
-            && db.upper_value.getConst<Rational>().sgn() == -1))
+    if (!db.lower_value.isNull() && !db.upper_value.isNull())
     {
-      Rational bden = db.upper_value.isNull()
-                          ? db.lower_value.getConst<Rational>()
-                          : db.upper_value.getConst<Rational>().abs();
-      // if 0 <= UB(num) < LB(den) or 0 <= UB(num) < -UB(den)
-      arith::Bounds nb = binfer.get(num);
-      if (!nb.upper_value.isNull())
+      Rational bdenu = db.upper_value.getConst<Rational>();
+      Rational bdenl = db.lower_value.getConst<Rational>();
+      if (bdenl.sgn() == bdenu.sgn())
       {
-        Rational bnum = nb.upper_value.getConst<Rational>();
-        if (bnum.sgn() != -1 && bnum < bden)
+        // if the sign of LB(num) is the sign of UB(num),
+        // the sign of LB(den) is the sign of UB(den), and
+        // abs(LB(num)) and abs(UB(num)) is less than abs(LB(den)) and
+        // abs(UB(den)), then the mod can be eliminated.
+        arith::Bounds nb = binfer.get(num);
+        if (!nb.upper_value.isNull() && !nb.lower_value.isNull())
         {
-          nr = returnRewriteLearned(nr, nr[0], LearnedRewriteId::INT_MOD_RANGE);
+          Rational bnuml = nb.lower_value.getConst<Rational>();
+          Rational bnumu = nb.upper_value.getConst<Rational>();
+          Rational bnum = bnumu.abs() > bnuml.abs() ? bnuml.abs() : bnumu.abs();
+          if (bnuml.sgn() == bnumu.sgn() && bdenl.abs() < bnum
+              && bdenu.abs() < bnum)
+          {
+            // if the numerator is negative, then (mod x y) ---> (+ x (abs y))
+            // otherwise, (mod x y) ---> x
+            Node ret = bnuml.sgn() == -1 ? nm->mkNode(
+                           kind::ADD, nr[0], nm->mkNode(kind::ABS, nr[1]))
+                                         : nr[0];
+            nr = returnRewriteLearned(nr, ret, LearnedRewriteId::INT_MOD_RANGE);
+          }
         }
       }
       // could also do num + k*den checks
index 67c6cc8d4f8b52793f060cdb134a64e399bd1d1a..4bfc3bc7100e45724d7739bfb9562b557a167b93 100644 (file)
@@ -830,6 +830,7 @@ set(regress_0_tests
   regress0/nl/issue8744-real-cov.smt2
   regress0/nl/issue8755-nl-logic-exception.smt2
   regress0/nl/issue8835-int-second.smt2
+  regress0/nl/issue8934-lr-int-mod-range.smt2
   regress0/nl/lazard-spurious-root.smt2
   regress0/nl/magnitude-wrong-1020-m.smt2
   regress0/nl/mult-po.smt2
@@ -2056,6 +2057,7 @@ set(regress_1_tests
   regress1/nl/issue8052-iand-rewrite.smt2
   regress1/nl/issue8118-elim-sin.smt2
   regress1/nl/issue8162-drop-pi-bound.smt2
+  regress1/nl/learned-rewrite-int-mod-range.smt2
   regress1/nl/metitarski-1025.smt2
   regress1/nl/metitarski-3-4.smt2
   regress1/nl/metitarski_3_4_2e.smt2
diff --git a/test/regress/cli/regress0/nl/issue8934-lr-int-mod-range.smt2 b/test/regress/cli/regress0/nl/issue8934-lr-int-mod-range.smt2
new file mode 100644 (file)
index 0000000..0afa542
--- /dev/null
@@ -0,0 +1,9 @@
+; COMMAND-LINE: --learned-rewrite
+; EXPECT: unsat
+; DISABLE-TESTER: unsat-core
+; DISABLE-TESTER: proof
+(set-logic QF_NIA)
+(declare-const x Int)
+(declare-const y Int)
+(assert (and (<= x 0) (< 0 y) (or (= 0 y) (> 0 (mod x y)))))
+(check-sat)
diff --git a/test/regress/cli/regress1/nl/learned-rewrite-int-mod-range.smt2 b/test/regress/cli/regress1/nl/learned-rewrite-int-mod-range.smt2
new file mode 100644 (file)
index 0000000..a8a00e5
--- /dev/null
@@ -0,0 +1,24 @@
+(set-logic QF_NIA)
+(set-info :status unsat)
+(declare-fun ld () Int)
+(declare-fun d () Int)
+(declare-fun ud () Int)
+(declare-fun ln () Int)
+(declare-fun n () Int)
+(declare-fun un () Int)
+(define-fun sgn ((x Int)) Int (ite (< x 0) (- 1) (ite (> x 0) 1 0)))
+
+(assert (<= ld d ud))
+(assert (<= ln n un))
+
+(assert (< (abs ln) (abs ld)))
+(assert (< (abs ln) (abs ud)))
+(assert (< (abs un) (abs ld)))
+(assert (< (abs un) (abs ud)))
+
+(assert (= (sgn ld) (sgn ud)))
+(assert (= (sgn ln) (sgn un)))
+
+(assert (not (= (mod n d) (ite (< n 0) (+ (abs d) n) n))))
+
+(check-sat)