Add set.filter operator and its inference rules (#8856)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 7 Jun 2022 18:37:07 +0000 (13:37 -0500)
committerGitHub <noreply@github.com>
Tue, 7 Jun 2022 18:37:07 +0000 (18:37 +0000)
26 files changed:
proofs/lfsc/signatures/theory_def.plf
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_solver.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/kinds
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/sets/kinds
src/theory/sets/solver_state.cpp
src/theory/sets/solver_state.h
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/sets/theory_sets_rewriter.cpp
src/theory/sets/theory_sets_rewriter.h
src/theory/sets/theory_sets_type_rules.cpp
src/theory/sets/theory_sets_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/filter3.smt2
test/regress/cli/regress1/sets/set_filter1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/set_filter2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/set_filter3.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/set_filter4.smt2 [new file with mode: 0644]

index 04ac320aeb7843f479bda0bd22c2c9c3d3788f18..30f00dc2957174dddd369c2a68264f41e913cac6 100644 (file)
 (define rel.join_image (# x term (# y term (apply (apply f_rel.join_image x) y))))
 (declare f_set.insert term)
 (define set.insert (# x term (# y term (apply (apply f_set.insert x) y))))
+(declare f_set.filter term)
+(define set.filter (# x term (# y term (apply (apply f_set.filter x) y))))
 (declare f_set.map term)
 (define set.map (# x term (# y term (apply (apply f_set.map x) y))))
 
index 9705757b9e6a041f7c2bd2ca7ed8b773ae7a1e50..baa94039aeb46ffde39008a865257dedab84e249 100644 (file)
@@ -301,6 +301,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(SET_CHOOSE, internal::Kind::SET_CHOOSE),
         KIND_ENUM(SET_IS_SINGLETON, internal::Kind::SET_IS_SINGLETON),
         KIND_ENUM(SET_MAP, internal::Kind::SET_MAP),
+        KIND_ENUM(SET_FILTER, internal::Kind::SET_FILTER),
         /* Relations -------------------------------------------------------- */
         KIND_ENUM(RELATION_JOIN, internal::Kind::RELATION_JOIN),
         KIND_ENUM(RELATION_PRODUCT, internal::Kind::RELATION_PRODUCT),
@@ -623,6 +624,7 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::SET_CHOOSE, SET_CHOOSE},
         {internal::Kind::SET_IS_SINGLETON, SET_IS_SINGLETON},
         {internal::Kind::SET_MAP, SET_MAP},
+        {internal::Kind::SET_FILTER, SET_FILTER},
         /* Relations ------------------------------------------------------- */
         {internal::Kind::RELATION_JOIN, RELATION_JOIN},
         {internal::Kind::RELATION_PRODUCT, RELATION_PRODUCT},
index 40797e027718fa78fe37710c71287117775685ff..453e2b2528b0417cb1e3362bda7b306dc01c26ff 100644 (file)
@@ -3162,7 +3162,34 @@ enum Kind : int32_t
    * \endrst
    */
    SET_MAP,
-
+  /**
+   * Set filter.
+   *
+   * \rst
+   * This operator filters the elements of a set.
+   * (set.filter :math:`p \; A`) takes a predicate :math:`p` of Sort
+   * :math:`(\rightarrow T \; Bool)` as a first argument, and a set :math:`A`
+   * of Sort (Set :math:`T`) as a second argument, and returns a subset of Sort
+   * (Set :math:`T`) that includes all elements of :math:`A` that satisfy
+   * :math:`p`.
+   *
+   * - Arity: ``2``
+   *
+   *   - ``1:`` Term of function Sort :math:`(\rightarrow T \; Bool)`
+   *   - ``2:`` Term of bag Sort (Set :math:`T`)
+   * \endrst
+   *
+   * - Create Term of this Kind with:
+   *
+   *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+   SET_FILTER,
   /* Relations ------------------------------------------------------------- */
 
   /**
@@ -3624,15 +3651,15 @@ enum Kind : int32_t
    * \rst
    * This operator filters the elements of a bag.
    * (bag.filter :math:`p \; B`) takes a predicate :math:`p` of Sort
-   * :math:`(\rightarrow S_1 \; S_2)` as a first argument, and a bag :math:`B`
-   * of Sort (Bag :math:`S`) as a second argument, and returns a subbag of Sort
+   * :math:`(\rightarrow T \; Bool)` as a first argument, and a bag :math:`B`
+   * of Sort (Bag :math:`T`) as a second argument, and returns a subbag of Sort
    * (Bag :math:`T`) that includes all elements of :math:`B` that satisfy
    * :math:`p` with the same multiplicity.
    *
    * - Arity: ``2``
    *
-   *   - ``1:`` Term of function Sort :math:`(\rightarrow S_1 \; S_2)`
-   *   - ``2:`` Term of bag Sort (Bag :math:`S_1`)
+   *   - ``1:`` Term of function Sort :math:`(\rightarrow T \; Bool)`
+   *   - ``2:`` Term of bag Sort (Bag :math:`T`)
    * \endrst
    *
    * - Create Term of this Kind with:
@@ -3640,10 +3667,6 @@ enum Kind : int32_t
    *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
    *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
    *
-   * - Create Op of this kind with:
-   *
-   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
-   *
    * \rst
    * .. warning:: This kind is experimental and may be changed or removed in
    *              future versions.
index 66f2db214a2db5538e41b248130b6f4358bd6c77..cbf699f6b3ac787703f2f933b4160caf85483663 100644 (file)
@@ -605,6 +605,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(cvc5::SET_CHOOSE, "set.choose");
     addOperator(cvc5::SET_IS_SINGLETON, "set.is_singleton");
     addOperator(cvc5::SET_MAP, "set.map");
+    addOperator(cvc5::SET_FILTER, "set.filter");
     addOperator(cvc5::RELATION_JOIN, "rel.join");
     addOperator(cvc5::RELATION_PRODUCT, "rel.product");
     addOperator(cvc5::RELATION_TRANSPOSE, "rel.transpose");
index 6a7b8ccb63c1f83bca67b8eee462f4d83d07a136..5908e82e7834a9cf73d14cec00f9bb496ac54836 100644 (file)
@@ -1152,6 +1152,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::SET_CHOOSE: return "set.choose";
   case kind::SET_IS_SINGLETON: return "set.is_singleton";
   case kind::SET_MAP: return "set.map";
+  case kind::SET_FILTER: return "set.filter";
   case kind::RELATION_JOIN: return "rel.join";
   case kind::RELATION_PRODUCT: return "rel.product";
   case kind::RELATION_TRANSPOSE: return "rel.transpose";
index 5118df4621785b9672f4ff2a9de952d08c19e9e0..5f52c6215d13bd21eac05f50bc4cfcb7350d5c78 100644 (file)
@@ -302,12 +302,12 @@ void BagSolver::checkFilter(Node n)
 
   for (const Node& e : elements)
   {
-    InferInfo i = d_ig.filterDownwards(n, d_state.getRepresentative(e));
+    InferInfo i = d_ig.filterDown(n, d_state.getRepresentative(e));
     d_im.lemmaTheoryInference(&i);
   }
   for (const Node& e : elements)
   {
-    InferInfo i = d_ig.filterUpwards(n, d_state.getRepresentative(e));
+    InferInfo i = d_ig.filterUp(n, d_state.getRepresentative(e));
     d_im.lemmaTheoryInference(&i);
   }
 }
index 8d8ee22c59b3c51369e1589491bc15de51c65076..31509fa9c423e603c9aa5898a1d10c085c608c65 100644 (file)
@@ -533,7 +533,7 @@ InferInfo InferenceGenerator::mapUp(
   return inferInfo;
 }
 
-InferInfo InferenceGenerator::filterDownwards(Node n, Node e)
+InferInfo InferenceGenerator::filterDown(Node n, Node e)
 {
   Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
   Assert(e.getType() == n[1].getType().getBagElementType());
@@ -555,7 +555,7 @@ InferInfo InferenceGenerator::filterDownwards(Node n, Node e)
   return inferInfo;
 }
 
-InferInfo InferenceGenerator::filterUpwards(Node n, Node e)
+InferInfo InferenceGenerator::filterUp(Node n, Node e)
 {
   Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
   Assert(e.getType() == n[1].getType().getBagElementType());
index 7d92ea4b028c70e263cc344ed2f97de18d39b052..91a3ea08a59bc495bbdce074ffaf7e02d6192cac 100644 (file)
@@ -289,7 +289,7 @@ class InferenceGenerator
    *     (= (bag.count e skolem) (bag.count e A)))
    * where skolem is a variable equals (bag.filter p A)
    */
-  InferInfo filterDownwards(Node n, Node e);
+  InferInfo filterDown(Node n, Node e);
 
   /**
    * @param n is (bag.filter p A) where p is a function (-> E Bool),
@@ -303,7 +303,7 @@ class InferenceGenerator
    *     (and (not (p e)) (= (bag.count e skolem) 0)))
    * where skolem is a variable equals (bag.filter p A)
    */
-  InferInfo filterUpwards(Node n, Node e);
+  InferInfo filterUp(Node n, Node e);
 
   /**
    * @param n is a (table.product A B) where A, B are tables
index 2522585bdea937b9266ff5dfd4137a6bf9d2c675..9961b586492656d0f9f6e006f7cdddcedcfbb007 100644 (file)
@@ -71,7 +71,7 @@ operator BAG_CHOOSE        1  "return an element in the bag given as a parameter
 operator BAG_MAP           2  "bag map function"
 
 # The bag.filter operator takes a predicate of type (-> T Bool) and a bag of type (Bag T)
-# and return the same bag excluding those elements that do not satisfy the predicate
+# and returns the same bag excluding those elements that do not satisfy the predicate
 operator BAG_FILTER        2  "bag filter operator"
 
 # bag.fold operator combines elements of a bag into a single value.
index 3637a5e0e3d0a81c84bda855e359ac473c1a1b61..55cb72451486e96339aacfc0338fd8214fbb7bc9 100644 (file)
@@ -333,6 +333,8 @@ const char* toString(InferenceId i)
     case InferenceId::SETS_EQ_CONFLICT: return "SETS_EQ_CONFLICT";
     case InferenceId::SETS_EQ_MEM: return "SETS_EQ_MEM";
     case InferenceId::SETS_EQ_MEM_CONFLICT: return "SETS_EQ_MEM_CONFLICT";
+    case InferenceId::SETS_FILTER_DOWN: return "SETS_FILTER_DOWN";
+    case InferenceId::SETS_FILTER_UP: return "SETS_FILTER_UP";
     case InferenceId::SETS_MAP_DOWN_POSITIVE: return "SETS_MAP_DOWN_POSITIVE";
     case InferenceId::SETS_MAP_UP: return "SETS_MAP_UP";
     case InferenceId::SETS_MEM_EQ: return "SETS_MEM_EQ";
index bd0e2fc7e6c6c6fd4583b1149e47083e121a45de..3a6452e45743f53f106e4f183e117189bca25c2e 100644 (file)
@@ -486,6 +486,8 @@ enum class InferenceId
   SETS_EQ_CONFLICT,
   SETS_EQ_MEM,
   SETS_EQ_MEM_CONFLICT,
+  SETS_FILTER_DOWN,
+  SETS_FILTER_UP,
   SETS_MAP_DOWN_POSITIVE,
   SETS_MAP_UP,
   SETS_MEM_EQ,
index da7c2b930a69a15930c53c2c17c4c46b8ce37fb7..d1e22cab103fca6ad247dff154b53e640b8bc2f3 100644 (file)
@@ -77,6 +77,10 @@ operator SET_IS_SINGLETON  1  "return whether the given set is a singleton"
 # of the second argument, a set of type (Set T1), and returns a set of type (Set T2).
 operator SET_MAP           2  "set map function"
 
+# The set.filter operator takes a predicate of type (-> T Bool) and a set of type (Set T)
+# and returns the same set excluding those elements that do not satisfy the predicate
+operator SET_FILTER        2  "set filter operator"
+
 operator RELATION_JOIN                    2  "relation join"
 operator RELATION_PRODUCT         2  "relation cartesian product"
 operator RELATION_TRANSPOSE    1  "relation transpose"
@@ -99,6 +103,7 @@ typerule SET_COMPREHENSION  ::cvc5::internal::theory::sets::ComprehensionTypeRul
 typerule SET_CHOOSE         ::cvc5::internal::theory::sets::ChooseTypeRule
 typerule SET_IS_SINGLETON   ::cvc5::internal::theory::sets::IsSingletonTypeRule
 typerule SET_MAP            ::cvc5::internal::theory::sets::SetMapTypeRule
+typerule SET_FILTER         ::cvc5::internal::theory::sets::SetFilterTypeRule
 
 typerule RELATION_JOIN                         ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
 typerule RELATION_PRODUCT              ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
index 5c3874936f3af68cbc6d96c0a834c71f4205cbf4..cb02f4b0625623e0b97d67da785c06ba5e23432d 100644 (file)
@@ -54,6 +54,7 @@ void SolverState::reset()
   d_bop_index.clear();
   d_op_list.clear();
   d_allCompSets.clear();
+  d_filterTerms.clear();
 }
 
 void SolverState::registerEqc(TypeNode tn, Node r)
@@ -139,6 +140,10 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
     d_nvar_sets[r].push_back(n);
     Trace("sets-debug2") << "Non-var-set[" << r << "] : " << n << std::endl;
   }
+  else if (nk == SET_FILTER)
+  {
+    d_filterTerms.push_back(n);
+  }
   else if (nk == SET_MAP)
   {
     d_mapTerms.insert(n);
@@ -472,6 +477,8 @@ const std::map<Kind, std::vector<Node> >& SolverState::getOperatorList() const
   return d_op_list;
 }
 
+const std::vector<Node>& SolverState::getFilterTerms() const { return d_filterTerms; }
+
 const context::CDHashSet<Node>& SolverState::getMapTerms() const { return d_mapTerms; }
 
 std::shared_ptr<context::CDHashSet<Node>> SolverState::getMapSkolemElements(
index 307d37c0783577fc702c7d0ae282c786563655ed..ab240ab228b82747c7929345d6872140087da75e 100644 (file)
@@ -161,6 +161,8 @@ class SolverState : public TheoryState
    * map is a representative of its congruence class.
    */
   const std::map<Kind, std::vector<Node> >& getOperatorList() const;
+  /** Get the list of all set.filter terms */
+  const std::vector<Node>& getFilterTerms() const;
   /** Get the list of all set.map terms in the current user context */
   const context::CDHashSet<Node>& getMapTerms() const;
   /** Get the list of all skolem elements generated for map terms down rules in
@@ -218,6 +220,8 @@ class SolverState : public TheoryState
   std::map<Node, Node> d_congruent;
   /** Map from equivalence classes to the list of non-variable sets in it */
   std::map<Node, std::vector<Node> > d_nvar_sets;
+  /** A list of filter terms. It is initialized during full effort check */
+  std::vector<Node> d_filterTerms;
   /** User context collection of set.map terms */
   context::CDHashSet<Node> d_mapTerms;
   /** User context collection of skolem elements generated for set.map terms */
index 4aae37ecfcf3440be83a5be82a0f7e3a14b8fcfa..8c2665eb8eb58947e1e4d7b977cbf9ac903398ac 100644 (file)
@@ -364,6 +364,20 @@ void TheorySetsPrivate::fullEffortCheck()
     {
       continue;
     }
+    // check filter up rule
+    checkFilterUp();
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
+    {
+      continue;
+    }
+    // check filter down rules
+    checkFilterDown();
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
+    {
+      continue;
+    }
     // check map up rules
     checkMapUp();
     d_im.doPendingLemmas();
@@ -679,6 +693,70 @@ void TheorySetsPrivate::checkUpwardsClosure()
   }
 }
 
+void TheorySetsPrivate::checkFilterUp()
+{
+  NodeManager* nm = NodeManager::currentNM();
+  const std::vector<Node>& filterTerms = d_state.getFilterTerms();
+
+  for (const Node& term : filterTerms)
+  {
+    Node p = term[0];
+    Node A = term[1];
+    const std::map<Node, Node>& positiveMembers =
+        d_state.getMembers(d_state.getRepresentative(A));
+    for (const std::pair<const Node, Node>& pair : positiveMembers)
+    {
+      Node x = pair.first;
+      std::vector<Node> exp;
+      exp.push_back(pair.second);
+      Node B = pair.second[1];
+      d_state.addEqualityToExp(A, B, exp);
+      Node p_x = nm->mkNode(APPLY_UF, p, x);
+      Node skolem = d_treg.getProxy(term);
+      Node memberFilter = nm->mkNode(kind::SET_MEMBER, x, skolem);
+      Node not_p_x = p_x.notNode();
+      Node not_memberFilter = memberFilter.notNode();
+      Node orNode =
+          p_x.andNode(memberFilter).orNode(not_p_x.andNode(not_memberFilter));
+      d_im.assertInference(orNode, InferenceId::SETS_FILTER_UP, exp);
+      if (d_state.isInConflict())
+      {
+        return;
+      }
+    }
+  }
+}
+
+void TheorySetsPrivate::checkFilterDown()
+{
+  NodeManager* nm = NodeManager::currentNM();
+  const std::vector<Node>& filterTerms = d_state.getFilterTerms();
+  for (const Node& term : filterTerms)
+  {
+    Node p = term[0];
+    Node A = term[1];
+
+    const std::map<Node, Node>& positiveMembers =
+        d_state.getMembers(d_state.getRepresentative(term));
+    for (const std::pair<const Node, Node>& pair : positiveMembers)
+    {
+      std::vector<Node> exp;
+      Node B = pair.second[1];
+      exp.push_back(pair.second);
+      d_state.addEqualityToExp(B, term, exp);
+      Node x = pair.first;
+      Node memberA = nm->mkNode(kind::SET_MEMBER, x, A);
+      Node p_x = nm->mkNode(APPLY_UF, p, x);
+      Node fact = memberA.andNode(p_x);
+      d_im.assertInference(fact, InferenceId::SETS_FILTER_DOWN, exp);
+      if (d_state.isInConflict())
+      {
+        return;
+      }
+    }
+  }
+}
+
 void TheorySetsPrivate::checkMapUp()
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -1164,8 +1242,10 @@ void TheorySetsPrivate::processCarePairArgs(TNode a, TNode b)
   }
 }
 
-/** returns whether the given kind is a higher order kind for sets. */
-bool TheorySetsPrivate::isHigherOrderKind(Kind k) { return k == SET_MAP; }
+bool TheorySetsPrivate::isHigherOrderKind(Kind k)
+{
+  return k == SET_MAP || k == SET_FILTER;
+}
 
 Node TheorySetsPrivate::explain(TNode literal)
 {
index cda1fae373f907232eb974c07657c2e374688fe9..87ad4678e01be5db5b306df64ec4fcda1f7d2c7d 100644 (file)
@@ -76,6 +76,29 @@ class TheorySetsPrivate : protected EnvObj
    */
   void checkUpwardsClosure();
 
+  /**
+   * Apply the following rule for filter terms (set.filter p A):
+   * (=>
+   *   (and (set.member x B) (= A B))
+   *   (or
+   *    (and (p x) (set.member x (set.filter p A)))
+   *    (and (not (p x)) (not (set.member x (set.filter p A))))
+   *   )
+   * )
+   */
+  void checkFilterUp();
+  /**
+   * Apply the following rule for filter terms (set.filter p A):
+   * (=>
+   *   (bag.member x (set.filter p A))
+   *   (and
+   *    (p x)
+   *    (set.member x A)
+   *   )
+   * )
+   */
+  void checkFilterDown();
+
   /**
    * Apply the following rule for map terms (set.map f A):
    * Positive member rule:
index d356ddbe73b9e9070ca568a4bed865ba598fcddb..9a0b5a875fb4b2415c4781cb27a977121ef0d08a 100644 (file)
@@ -333,6 +333,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
   }  // kind::SET_IS_SINGLETON
 
   case SET_MAP: return postRewriteMap(node);
+  case SET_FILTER: return postRewriteFilter(node);
 
   case kind::RELATION_TRANSPOSE:
   {
@@ -669,6 +670,41 @@ RewriteResponse TheorySetsRewriter::postRewriteMap(TNode n)
   }
 }
 
+RewriteResponse TheorySetsRewriter::postRewriteFilter(TNode n)
+{
+  Assert(n.getKind() == kind::SET_FILTER);
+  NodeManager* nm = NodeManager::currentNM();
+  Kind k = n[1].getKind();
+  switch (k)
+  {
+    case SET_EMPTY:
+    {
+      // (set.filter p (as set.empty (Set T)) = (as set.empty (Set T))
+      return RewriteResponse(REWRITE_DONE, n[1]);
+    }
+    case SET_SINGLETON:
+    {
+      // (set.filter p (set.singleton x)) =
+      //       (ite (p x) (set.singleton x) (as set.empty (Set T)))
+      Node empty = nm->mkConst(EmptySet(n.getType()));
+      Node condition = nm->mkNode(APPLY_UF, n[0], n[1][0]);
+      Node ret = nm->mkNode(ITE, condition, n[1], empty);
+      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+    }
+    case SET_UNION:
+    {
+      // (set.filter p (set.union A B)) =
+      //   (set.union (set.filter p A) (set.filter p B))
+      Node a = nm->mkNode(SET_FILTER, n[0], n[1][0]);
+      Node b = nm->mkNode(SET_FILTER, n[0], n[1][1]);
+      Node ret = nm->mkNode(SET_UNION, a, b);
+      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+    }
+
+    default: return RewriteResponse(REWRITE_DONE, n);
+  }
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index ee9c7f8edef6a2546143892f58cd2511d07f29fd..74735a878b61282afa25d838fe881da80e81e39e 100644 (file)
@@ -83,6 +83,17 @@ private:
   *  where f: T1 -> T2
   */
  RewriteResponse postRewriteMap(TNode n);
+
+ /**
+  *  rewrites for n include:
+  *  - (set.filter p (as set.empty (Set T)) = (as set.empty (Set T))
+  *  - (set.filter p (set.singleton x)) =
+  *       (ite (p x) (set.singleton x) (as set.empty (Set T)))
+  *  - (set.filter p (set.union A B)) =
+  *       (set.union (set.filter p A) (set.filter p B))
+  *  where p: T -> Bool
+  */
+ RewriteResponse postRewriteFilter(TNode n);
 }; /* class TheorySetsRewriter */
 
 }  // namespace sets
index 27583e71f6ae52b75920fe6171d04ab081a95781..49bd24e171fe8d0e4615258cb2b4f36c7beb0c30 100644 (file)
@@ -314,6 +314,48 @@ TypeNode SetMapTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode SetFilterTypeRule::computeType(NodeManager* nodeManager,
+                                        TNode n,
+                                        bool check)
+{
+  Assert(n.getKind() == kind::SET_FILTER);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode setType = n[1].getType(check);
+  if (check)
+  {
+    if (!setType.isSet())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "set.filter operator expects a set in the second argument, "
+          "a non-set is found");
+    }
+
+    TypeNode elementType = setType.getSetElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " Bool) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    NodeManager* nm = NodeManager::currentNM();
+    if (!(argTypes.size() == 1 && argTypes[0] == elementType
+          && functionType.getRangeType() == nm->booleanType()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " Bool). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  return setType;
+}
+
 TypeNode RelBinaryOperatorTypeRule::computeType(NodeManager* nodeManager,
                                                 TNode n,
                                                 bool check)
index ee1fc06c7b01344e5741b6f567b5743380237f3b..7461ec4cc0bc0c1b31728544f6e9c137e9770cc5 100644 (file)
@@ -140,6 +140,15 @@ struct SetMapTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct SetMapTypeRule */
 
+/**
+ * Type rule for (set.filter p A) to make sure p is a unary predicate of type
+ * (-> T Bool) where A is a set of type (Set T)
+ */
+struct SetFilterTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct SetFilterTypeRule */
+
 /**
  * Type rule for binary operators (rel.join, rel.product) to check
  * if the two arguments are relations (set of tuples).
index ec6772707b3ccf6d8258e14934666607de204fa0..763270a6f9776b1340b213877b2ad623f12e77c1 100644 (file)
@@ -2503,6 +2503,10 @@ set(regress_1_tests
   regress1/sets/sets-tuple-poly.cvc.smt2
   regress1/sets/sets-uc-wrong.smt2
   regress1/sets/set-comp-sat.smt2
+  regress1/sets/set_filter1.smt2
+  regress1/sets/set_filter2.smt2
+  regress1/sets/set_filter3.smt2
+  regress1/sets/set_filter4.smt2
   regress1/sets/set_map_card_incomplete.smt2
   regress1/sets/set_map_negative_members.smt2
   regress1/sets/set_map_positive_members.smt2
index 10f6370153ea2f1862588e4f5de3e4d7650d0453..7453cc00d41c0026d3b0ae6808aa11b3515a6857 100644 (file)
@@ -5,6 +5,6 @@
 (declare-fun B () (Bag Int))
 (define-fun p ((x Int)) Bool (> x 1))
 (assert (= B (bag.filter p A)))
-(assert (= (bag.count 3 B) 57))
+(assert (= (bag.count 3 A) 57))
 (assert (= (bag.count 3 B) 58))
 (check-sat)
diff --git a/test/regress/cli/regress1/sets/set_filter1.smt2 b/test/regress/cli/regress1/sets/set_filter1.smt2
new file mode 100644 (file)
index 0000000..32e3251
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Set Int))
+(declare-fun B () (Set Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(declare-fun p (Int) Bool)
+(assert (= A (set.union (set.singleton x) (set.singleton y))))
+(assert (= B (set.filter p A)))
+(assert (distinct (p x) (p y)))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/set_filter2.smt2 b/test/regress/cli/regress1/sets/set_filter2.smt2
new file mode 100644 (file)
index 0000000..d7c08ae
--- /dev/null
@@ -0,0 +1,9 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(declare-fun A () (Set Int))
+(declare-fun B () (Set Int))
+(declare-fun p (Int) Bool)
+(assert (= B (set.filter p A)))
+(assert (set.member (- 2) B))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/set_filter3.smt2 b/test/regress/cli/regress1/sets/set_filter3.smt2
new file mode 100644 (file)
index 0000000..bf96f92
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Set Int))
+(declare-fun B () (Set Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (set.filter p A)))
+(assert (p element))
+(assert (not (set.member element B)))
+(assert (set.member element A))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/set_filter4.smt2 b/test/regress/cli/regress1/sets/set_filter4.smt2
new file mode 100644 (file)
index 0000000..858cdc0
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Set Int))
+(declare-fun B () (Set Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (set.filter p A)))
+(assert (p element))
+(assert (not (set.member element A)))
+(assert (set.member element B))
+(check-sat)