Add rel.group operator to sets (#8876)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 28 Jun 2022 17:50:41 +0000 (12:50 -0500)
committerGitHub <noreply@github.com>
Tue, 28 Jun 2022 17:50:41 +0000 (17:50 +0000)
This PR is dependent on #8819 and #8867.

28 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bags_utils.cpp
src/theory/bags/bags_utils.h
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/sets/kinds
src/theory/sets/rels_utils.cpp
src/theory/sets/rels_utils.h
src/theory/sets/solver_state.cpp
src/theory/sets/solver_state.h
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/sets/theory_sets_rewriter.cpp
src/theory/sets/theory_sets_rewriter.h
src/theory/sets/theory_sets_type_rules.cpp
src/theory/sets/theory_sets_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/sets/relation_group1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_group2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_group3.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_group4.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_group5.smt2 [new file with mode: 0644]

index 52c2045b0f8de20aca26195c2d25bc29aa1cf756..1c45d2d4d0735cb45e08f77b680b4e0404c760b3 100644 (file)
@@ -309,6 +309,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(RELATION_TCLOSURE, internal::Kind::RELATION_TCLOSURE),
         KIND_ENUM(RELATION_JOIN_IMAGE, internal::Kind::RELATION_JOIN_IMAGE),
         KIND_ENUM(RELATION_IDEN, internal::Kind::RELATION_IDEN),
+        KIND_ENUM(RELATION_GROUP, internal::Kind::RELATION_GROUP),
         /* Bags ------------------------------------------------------------- */
         KIND_ENUM(BAG_UNION_MAX, internal::Kind::BAG_UNION_MAX),
         KIND_ENUM(BAG_UNION_DISJOINT, internal::Kind::BAG_UNION_DISJOINT),
@@ -633,6 +634,7 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::RELATION_TCLOSURE, RELATION_TCLOSURE},
         {internal::Kind::RELATION_JOIN_IMAGE, RELATION_JOIN_IMAGE},
         {internal::Kind::RELATION_IDEN, RELATION_IDEN},
+        {internal::Kind::RELATION_GROUP, RELATION_GROUP},
         /* Bags ------------------------------------------------------------ */
         {internal::Kind::BAG_UNION_MAX, BAG_UNION_MAX},
         {internal::Kind::BAG_UNION_DISJOINT, BAG_UNION_DISJOINT},
@@ -772,6 +774,7 @@ const static std::unordered_map<Kind, internal::Kind> s_op_kinds{
     {REGEXP_REPEAT, internal::Kind::REGEXP_REPEAT_OP},
     {REGEXP_LOOP, internal::Kind::REGEXP_LOOP_OP},
     {TUPLE_PROJECT, internal::Kind::TUPLE_PROJECT_OP},
+    {RELATION_GROUP, internal::Kind::RELATION_GROUP_OP},
     {TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP},
     {TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP},
     {TABLE_JOIN, internal::Kind::TABLE_JOIN_OP},
@@ -1950,6 +1953,7 @@ size_t Op::getNumIndicesHelper() const
     case FLOATINGPOINT_TO_FP_FROM_UBV: size = 2; break;
     case REGEXP_LOOP: size = 2; break;
     case TUPLE_PROJECT:
+    case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
@@ -2110,6 +2114,7 @@ Term Op::getIndexHelper(size_t index) const
       break;
     }
     case TUPLE_PROJECT:
+    case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
@@ -6151,6 +6156,7 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
       res = mkOpHelper(kind, internal::RegExpLoop(args[0], args[1]));
       break;
     case TUPLE_PROJECT:
+    case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
index 9aa4824d7ac30f24194bb5458069dafc68247b22..f71d6a447932443ec43be23c12c0a59825a1b414 100644 (file)
@@ -3330,6 +3330,37 @@ enum Kind : int32_t
    * \endrst
    */
   RELATION_IDEN,
+  /**
+   * Relation group
+   *
+   * \rst
+   * :math:`((\_ \; rel.group \; n_1 \; \dots \; n_k) \; A)` partitions tuples
+   * of relation :math:`A` such that tuples that have the same projection
+   * with indices :math:`n_1 \; \dots \; n_k` are in the same part.
+   * It returns a set of relations of type :math:`(Set \; T)` where
+   * :math:`T` is the type of :math:`A`.
+   *
+   * - Arity: ``1``
+   *
+   *   - ``1:`` Term of relation sort
+   *
+   * - Indices: ``n``
+   *
+   *   - ``1..n:``  Indices of the projection
+   *
+   * \endrst
+   *
+   * - Create Term of this Kind with:
+   *
+   *   - Solver::mkTerm(Kind, const std::vector<Term>&) const
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  RELATION_GROUP,
 
   /* Bags ------------------------------------------------------------------ */
 
index 6b8e32f55e95c04f0acf3ddce32ddcb454e12308..64420280524f23dfa4cc13a0827203d11d42f6fd 100644 (file)
@@ -98,6 +98,9 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::TABLES_GROUP_PART: return "TABLES_GROUP_PART";
     case SkolemFunId::TABLES_GROUP_PART_ELEMENT:
       return "TABLES_GROUP_PART_ELEMENT";
+    case SkolemFunId::RELATIONS_GROUP_PART: return "RELATIONS_GROUP_PART";
+    case SkolemFunId::RELATIONS_GROUP_PART_ELEMENT:
+      return "RELATIONS_GROUP_PART_ELEMENT";
     case SkolemFunId::SETS_CHOOSE: return "SETS_CHOOSE";
     case SkolemFunId::SETS_DEQ_DIFF: return "SETS_DEQ_DIFF";
     case SkolemFunId::SETS_FOLD_CARD: return "SETS_FOLD_CARD";
index da729e4a1e1859ade01714ac9f9fba28e9806a6c..5b398280df8fb6a06dc12eec93c916c1fba36374 100644 (file)
@@ -180,6 +180,17 @@ enum class SkolemFunId
    * that is a member of B if B is not empty.
    */
   TABLES_GROUP_PART_ELEMENT,
+  /** Given a group term ((_ rel.group n1 ... nk) A) of type (Set (Relation T))
+   * this uninterpreted function maps elements of A to their parts in the
+   * resulting partition. It has type (-> T (Relation T))
+   */
+  RELATIONS_GROUP_PART,
+  /**
+   * Given a group term ((_ rel.group n1 ... nk) A) of type (Set (Relation T))
+   * and a part B of type (Relation T), this function returns a skolem element
+   * that is a member of B if B is not empty.
+   */
+  RELATIONS_GROUP_PART_ELEMENT,
   /** An interpreted function for bag.choose operator:
    * (choose A) is expanded as
    * (witness ((x elementType))
index bc5ee3147a14d423ef6a74e097c1e6cc95fdd69d..875e41ee43b37001c522aee4b12114217ac3f12a 100644 (file)
@@ -1429,6 +1429,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::TABLE_GROUP, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK RELATION_GROUP_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_GROUP, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1598,6 +1604,13 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::TABLE_GROUP;
         p.d_op = SOLVER->mkOp(cvc5::TABLE_GROUP, numerals);
       }
+     | RELATION_GROUP_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ rel.group i_1 ... i_n) where
+        // i_1, ..., j_n are numerals
+        p.d_kind = cvc5::RELATION_GROUP;
+        p.d_op = SOLVER->mkOp(cvc5::RELATION_GROUP, numerals);
+      }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2210,6 +2223,7 @@ TABLE_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS
 TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.aggr';
 TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join';
 TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group';
+RELATION_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.group';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index 8c5be875a305330200a58394223b9519608f683a..900ff1016a39ad710d500b3e2e6332d7db600558 100644 (file)
@@ -1130,7 +1130,7 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
   }
   else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
            || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN
-           || p.d_kind == cvc5::TABLE_GROUP)
+           || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index 58cdd77392fb8e57d3104b9fd8cdae7ae034cda3..5dc9252a97b33a2b5ae4d8682d02808484cf003c 100644 (file)
@@ -820,6 +820,21 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::RELATION_GROUP:
+  {
+    ProjectOp op = n.getOperator().getConst<ProjectOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (rel.group A)
+      out << "rel.group " << n[0] << ")";
+    }
+    else
+    {
+      // e.g. ((_ rel.group 0 1 2 3) A)
+      out << "(_ rel.group" << op << ") " << n[0] << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -1159,6 +1174,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::RELATION_TCLOSURE: return "rel.tclosure";
   case kind::RELATION_IDEN: return "rel.iden";
   case kind::RELATION_JOIN_IMAGE: return "rel.join_image";
+  case kind::RELATION_GROUP: return "rel.group";
 
   // bag theory
   case kind::BAG_TYPE: return "Bag";
index 17f1cf5d37c5016b52a3a3cbb210818bfcb7d829..63112bee64214230e4d9f69e5faef6895ba59835 100644 (file)
@@ -146,7 +146,7 @@ Node BagsUtils::evaluate(Rewriter* rewriter, TNode n)
     case BAG_FOLD: return evaluateBagFold(n);
     case TABLE_PRODUCT: return evaluateProduct(n);
     case TABLE_JOIN: return evaluateJoin(rewriter, n);
-    case TABLE_GROUP: return evaluateGroup(rewriter, n);
+    case TABLE_GROUP: return evaluateGroup(n);
     case TABLE_PROJECT: return evaluateTableProject(n);
     default: break;
   }
@@ -975,7 +975,7 @@ Node BagsUtils::evaluateJoin(Rewriter* rewriter, TNode n)
   return ret;
 }
 
-Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n)
+Node BagsUtils::evaluateGroup(TNode n)
 {
   Assert(n.getKind() == TABLE_GROUP);
 
@@ -1049,7 +1049,7 @@ Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n)
     parts[emptyPart] = Rational(1);
   }
   Node ret = constructConstantBagFromElements(partitionType, parts);
-  Trace("bags-partition") << "ret: " << ret << std::endl;
+  Trace("bags-group") << "ret: " << ret << std::endl;
   return ret;
 }
 
index 4da592d5ae69d6928cd0611c4f6121c49dcf4510..8267a0c66db3fa76bc123368848e70406664abee 100644 (file)
@@ -138,7 +138,7 @@ class BagsUtils
    * @return a partition of A such that each part contains tuples with the same
    * projection with indices n_1 ... n_k
    */
-  static Node evaluateGroup(Rewriter* rewriter, TNode n);
+  static Node evaluateGroup(TNode n);
 
   /**
    * @param n of the form ((_ table.project i_1 ... i_n) A) where A is a
index ad3dfbeef9f3cea3876d1d6b7173d34995f75591..d828d6e5ddf3fb73d0387097caef309c47627136 100644 (file)
@@ -334,6 +334,7 @@ const char* toString(InferenceId i)
     case InferenceId::SEP_DISTINCT_REF: return "SEP_DISTINCT_REF";
     case InferenceId::SEP_REF_BOUND: return "SEP_REF_BOUND";
 
+    case InferenceId::SETS_SKOLEM: return "SETS_SKOLEM";
     case InferenceId::SETS_CG_SPLIT: return "SETS_CG_SPLIT";
     case InferenceId::SETS_COMPREHENSION: return "SETS_COMPREHENSION";
     case InferenceId::SETS_DEQ: return "SETS_DEQ";
@@ -390,6 +391,15 @@ const char* toString(InferenceId i)
     case InferenceId::SETS_RELS_TRANSPOSE_REV: return "SETS_RELS_TRANSPOSE_REV";
     case InferenceId::SETS_RELS_TUPLE_REDUCTION:
       return "SETS_RELS_TUPLE_REDUCTION";
+    case InferenceId::SETS_RELS_GROUP_UP1: return "SETS_RELS_GROUP_UP1";
+    case InferenceId::SETS_RELS_GROUP_UP2: return "SETS_RELS_GROUP_UP2";
+    case InferenceId::SETS_RELS_GROUP_DOWN: return "SETS_RELS_GROUP_DOWN";
+    case InferenceId::SETS_RELS_GROUP_PART_MEMBER:
+      return "SETS_RELS_GROUP_PART_MEMBER";
+    case InferenceId::SETS_RELS_GROUP_SAME_PROJECTION:
+      return "SETS_RELS_GROUP_SAME_PROJECTION";
+    case InferenceId::SETS_RELS_GROUP_SAME_PART:
+      return "SETS_RELS_GROUP_SAME_PART";
 
     case InferenceId::STRINGS_I_NORM_S: return "STRINGS_I_NORM_S";
     case InferenceId::STRINGS_I_CONST_MERGE: return "STRINGS_I_CONST_MERGE";
index 4f5e3d15605bb63fcbf7a1f32727e9cd87f7015b..ee3cb6004e37f88f7973ca990de011834df1e5c6 100644 (file)
@@ -485,6 +485,7 @@ enum class InferenceId
   // ---------------------------------- sets theory
   //-------------------- sets core solver
   // split when computing care graph
+  SETS_SKOLEM,
   SETS_CG_SPLIT,
   SETS_COMPREHENSION,
   SETS_DEQ,
@@ -545,6 +546,13 @@ enum class InferenceId
   SETS_RELS_TRANSPOSE_EQ,
   SETS_RELS_TRANSPOSE_REV,
   SETS_RELS_TUPLE_REDUCTION,
+  SETS_RELS_GROUP_NOT_EMPTY,
+  SETS_RELS_GROUP_UP1,
+  SETS_RELS_GROUP_UP2,
+  SETS_RELS_GROUP_DOWN,
+  SETS_RELS_GROUP_PART_MEMBER,
+  SETS_RELS_GROUP_SAME_PROJECTION,
+  SETS_RELS_GROUP_SAME_PART,
   //-------------------------------------- end sets theory
 
   //-------------------------------------- strings theory
index e3e5e06b656eb30e1745c0a7a0aca9622a4d610b..0c0e17dc13053181e761e95a874cd1f25df88c5d 100644 (file)
@@ -89,6 +89,16 @@ operator SET_FILTER        2  "set filter operator"
 #  A: a bag of type (Set T1)
 operator SET_FOLD          3  "set fold operator"
 
+# rel.group operator
+constant RELATION_GROUP_OP \
+  class \
+  ProjectOp+ \
+  ::cvc5::internal::ProjectOpHashFunction \
+  "theory/datatypes/project_op.h" \
+  "operator for RELATION_GROUP; payload is an instance of the cvc5::internal::RelationGroupOp class"
+
+parameterized RELATION_GROUP RELATION_GROUP_OP 1 "relation group"
+
 operator RELATION_JOIN                    2  "relation join"
 operator RELATION_PRODUCT         2  "relation cartesian product"
 operator RELATION_TRANSPOSE    1  "relation transpose"
@@ -120,6 +130,8 @@ typerule RELATION_TRANSPOSE                 ::cvc5::internal::theory::sets::RelTransposeTypeRu
 typerule RELATION_TCLOSURE         ::cvc5::internal::theory::sets::RelTransClosureTypeRule
 typerule RELATION_JOIN_IMAGE       ::cvc5::internal::theory::sets::JoinImageTypeRule
 typerule RELATION_IDEN                         ::cvc5::internal::theory::sets::RelIdenTypeRule
+typerule RELATION_GROUP_OP      "SimpleTypeRule<RBuiltinOperator>"
+typerule RELATION_GROUP         ::cvc5::internal::theory::sets::RelationGroupTypeRule
 
 construle SET_UNION         ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule
 construle SET_SINGLETON     ::cvc5::internal::theory::sets::SingletonTypeRule
index 3b8e9ad328863e215d821a84746f767f0fc62bbd..08d4feb3608e8b2c23a54c68945f37a89ab41917 100644 (file)
 
 #include "expr/dtype.h"
 #include "expr/dtype_cons.h"
+#include "theory/datatypes/project_op.h"
 #include "theory/datatypes/tuple_utils.h"
+#include "theory/sets/normal_form.h"
 
+using namespace cvc5::internal::kind;
 using namespace cvc5::internal::theory::datatypes;
+using namespace cvc5::internal::theory::sets;
 
 namespace cvc5::internal {
 namespace theory {
@@ -73,7 +77,78 @@ Node RelsUtils::constructPair(Node rel, Node a, Node b)
 {
   const DType& dt = rel.getType().getSetElementType().getDType();
   return NodeManager::currentNM()->mkNode(
-      kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b);
+      APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b);
+}
+
+Node RelsUtils::evaluateGroup(TNode n)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  Node A = n[0];
+  TypeNode setType = A.getType();
+  TypeNode partitionType = n.getType();
+
+  if (A.getKind() == SET_EMPTY)
+  {
+    // return a nonempty partition
+    return nm->mkNode(SET_SINGLETON, A);
+  }
+
+  std::vector<uint32_t> indices =
+      n.getOperator().getConst<ProjectOp>().getIndices();
+
+  std::set<Node> elements = NormalForm::getElementsFromNormalConstant(A);
+  Trace("sets-group") << "elements: " << elements << std::endl;
+  // a simple map from elements to equivalent classes with this invariant:
+  // each key element must appear exactly once in one of the values.
+  std::map<Node, std::set<Node>> sets;
+  std::set<Node> emptyClass;
+  for (const Node& element : elements)
+  {
+    // initially each singleton element is an equivalence class
+    sets[element] = {element};
+  }
+  for (std::set<Node>::iterator i = elements.begin(); i != elements.end(); ++i)
+  {
+    Node iElement = *i;
+    if (sets[iElement].empty())
+    {
+      // skip this element since its equivalent class has already been processed
+      continue;
+    }
+    std::set<Node>::iterator j = i;
+    ++j;
+    while (j != elements.end())
+    {
+      Node jElement = *j;
+      if (TupleUtils::sameProjection(indices, iElement, jElement))
+      {
+        // add element j to the equivalent class
+        sets[iElement].insert(jElement);
+        // mark the equivalent class of j as processed
+        sets[jElement] = emptyClass;
+      }
+      ++j;
+    }
+  }
+
+  // construct the partition parts
+  std::set<Node> parts;
+  for (std::pair<Node, std::set<Node>> pair : sets)
+  {
+    const std::set<Node>& eqc = pair.second;
+    if (eqc.empty())
+    {
+      continue;
+    }
+    Node part = NormalForm::elementsToSet(eqc, setType);
+    parts.insert(part);
+  }
+  Node ret = NormalForm::elementsToSet(parts, partitionType);
+  Trace("sets-group") << "ret: " << ret << std::endl;
+  return ret;
 }
 
 }  // namespace sets
index 61d3010741241960e034e3d38ed78a6d8d469b10..559ef52817df32f1e914af25b9c630ab6537f45b 100644 (file)
@@ -64,6 +64,14 @@ class RelsUtils
    * @return  a tuple (tuple a b)
    */
   static Node constructPair(Node rel, Node a, Node b);
+
+  /**
+   * @param n of the form ((_ rel.group (n_1 ... n_k) ) A) where A is a
+   * constant relation
+   * @return a partition of A such that each part contains tuples with the same
+   * projection with indices n_1 ... n_k
+   */
+  static Node evaluateGroup(TNode n);
 };
 }  // namespace sets
 }  // namespace theory
index cb02f4b0625623e0b97d67da785c06ba5e23432d..11b16a0bb077dc51312a8c76325dfa485e9b89ce 100644 (file)
@@ -30,8 +30,10 @@ SolverState::SolverState(Env& env, Valuation val, SkolemCache& skc)
     : TheoryState(env, val),
       d_skCache(skc),
       d_mapTerms(env.getUserContext()),
+      d_groupTerms(env.getUserContext()),
       d_mapSkolemElements(env.getUserContext()),
-      d_members(env.getContext())
+      d_members(env.getContext()),
+      d_partElementSkolems(env.getUserContext())
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
@@ -154,6 +156,13 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
       d_mapSkolemElements[n] = set;
     }
   }
+  else if (nk == RELATION_GROUP)
+  {
+    d_groupTerms.insert(n);
+    std::shared_ptr<context::CDHashSet<Node>> set =
+        std::make_shared<context::CDHashSet<Node>>(d_env.getUserContext());
+    d_partElementSkolems[n] = set;
+  }
   else if (nk == SET_COMPREHENSION)
   {
     d_compSets[r].push_back(n);
@@ -481,6 +490,11 @@ const std::vector<Node>& SolverState::getFilterTerms() const { return d_filterTe
 
 const context::CDHashSet<Node>& SolverState::getMapTerms() const { return d_mapTerms; }
 
+const context::CDHashSet<Node>& SolverState::getGroupTerms() const
+{
+  return d_groupTerms;
+}
+
 std::shared_ptr<context::CDHashSet<Node>> SolverState::getMapSkolemElements(
     Node n)
 {
@@ -632,6 +646,20 @@ void SolverState::registerMapSkolemElement(const Node& n, const Node& element)
   d_mapSkolemElements[n].get()->insert(element);
 }
 
+void SolverState::registerPartElementSkolem(Node group, Node skolemElement)
+{
+  Assert(group.getKind() == RELATION_GROUP);
+  Assert(skolemElement.getType() == group[0].getType().getSetElementType());
+  d_partElementSkolems[group].get()->insert(skolemElement);
+}
+
+std::shared_ptr<context::CDHashSet<Node>> SolverState::getPartElementSkolems(
+    Node n)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  return d_partElementSkolems[n];
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index ab240ab228b82747c7929345d6872140087da75e..3e3ee9df2136e282f7ba70335b1451ef324654db 100644 (file)
@@ -165,6 +165,8 @@ class SolverState : public TheoryState
   const std::vector<Node>& getFilterTerms() const;
   /** Get the list of all set.map terms in the current user context */
   const context::CDHashSet<Node>& getMapTerms() const;
+  /** Get the list of all rel.group terms in the current user context */
+  const context::CDHashSet<Node>& getGroupTerms() const;
   /** Get the list of all skolem elements generated for map terms down rules in
    * the current user context */
   std::shared_ptr<context::CDHashSet<Node>> getMapSkolemElements(Node n);
@@ -197,6 +199,10 @@ class SolverState : public TheoryState
 
   /** register the skolem element for the set.map term n */
   void registerMapSkolemElement(const Node& n, const Node& element);
+  /** register skolem element generated by grup count rule */
+  void registerPartElementSkolem(Node group, Node skolemElement);
+  /** return skolem elements generated by group part count rule. */
+  std::shared_ptr<context::CDHashSet<Node>> getPartElementSkolems(Node n);
 
  private:
   /** constants */
@@ -224,6 +230,8 @@ class SolverState : public TheoryState
   std::vector<Node> d_filterTerms;
   /** User context collection of set.map terms */
   context::CDHashSet<Node> d_mapTerms;
+  /** User context collection of rel.group terms */
+  context::CDHashSet<Node> d_groupTerms;
   /** User context collection of skolem elements generated for set.map terms */
   context::CDHashMap<Node, std::shared_ptr<context::CDHashSet<Node>>>
       d_mapSkolemElements;
@@ -283,6 +291,16 @@ class SolverState : public TheoryState
    * members if i=1.
    */
   const std::map<Node, Node>& getMembersInternal(Node r, unsigned i) const;
+
+  /**
+   * A cache that stores skolem elements generated by inference rule
+   * InferenceId::RELATIONS_GROUP_PART_MEMBER.
+   * It maps rel.group nodes to generated skolem elements.
+   * The skolem elements need to persist during checking, and should only change
+   * when the user context changes.
+   */
+  context::CDHashMap<Node, std::shared_ptr<context::CDHashSet<Node>>>
+      d_partElementSkolems;
 }; /* class TheorySetsPrivate */
 
 }  // namespace sets
index bf102204821848213a066f6ce74b4f0a11e31446..5794f8a904f155cd2d850bd9e73937ebf4b93f0b 100644 (file)
@@ -23,6 +23,8 @@
 #include "expr/skolem_manager.h"
 #include "options/sets_options.h"
 #include "smt/smt_statistics_registry.h"
+#include "theory/datatypes/project_op.h"
+#include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
 #include "theory/sets/theory_sets.h"
 #include "theory/theory_model.h"
@@ -31,6 +33,7 @@
 
 using namespace std;
 using namespace cvc5::internal::kind;
+using namespace cvc5::internal::theory::datatypes;
 
 namespace cvc5::internal {
 namespace theory {
@@ -392,6 +395,13 @@ void TheorySetsPrivate::fullEffortCheck()
     {
       continue;
     }
+    // check group up
+    checkGroups();
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
+    {
+      continue;
+    }
     // check disequalities
     checkDisequalities();
     d_im.doPendingLemmas();
@@ -865,6 +875,378 @@ void TheorySetsPrivate::checkMapDown()
   }
 }
 
+void TheorySetsPrivate::checkGroups()
+{
+  const context::CDHashSet<Node>& groupTerms = d_state.getGroupTerms();
+  for (const Node& n : groupTerms)
+  {
+    checkGroup(n);
+  }
+}
+
+void TheorySetsPrivate::checkGroup(Node n)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  groupNotEmpty(n);
+  d_im.doPendingLemmas();
+  if (d_im.hasSent())
+  {
+    return;
+  }
+  Node part = defineSkolemPartFunction(n);
+  Node A = d_state.getRepresentative(n[0]);
+
+  const std::map<Node, Node>& membersA = d_state.getMembers(A);
+  const std::map<Node, Node>& negativeMembersA = d_state.getNegativeMembers(A);
+  std::shared_ptr<context::CDHashSet<Node>> skolems =
+      d_state.getPartElementSkolems(n);
+  for (const auto& aPair : membersA)
+  {
+    if (skolems->contains(aPair.first))
+    {
+      // skip skolem elements that were introduced by groupPartCount below.
+      continue;
+    }
+    Node aRep = d_state.getRepresentative(aPair.first);
+    groupUp1(n, aRep, part);
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
+    {
+      return;
+    }
+  }
+  for (const auto& aPair : negativeMembersA)
+  {
+    Node aRep = d_state.getRepresentative(aPair.first);
+    groupUp2(n, aRep, part);
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
+    {
+      return;
+    }
+  }
+  Node nRep = d_state.getRepresentative(n);
+  const std::map<Node, Node>& parts = d_state.getMembers(nRep);
+  for (std::map<Node, Node>::const_iterator partIt1 = parts.begin();
+       partIt1 != parts.end();
+       ++partIt1)
+  {
+    Node part1 = d_state.getRepresentative(partIt1->first);
+    std::vector<Node> partEqc;
+    d_state.getEquivalenceClass(part1, partEqc);
+    bool newPart = true;
+    for (Node p : partEqc)
+    {
+      if (p.getKind() == APPLY_UF && p.getOperator() == part)
+      {
+        newPart = false;
+      }
+    }
+    if (newPart)
+    {
+      // only apply the groupPartCount rule for a part that does not have
+      // nodes of the form (part x) introduced by the group up rule above.
+      groupPartMember(n, part1, part);
+      d_im.doPendingLemmas();
+      if (d_im.hasSent())
+      {
+        return;
+      }
+    }
+    Node part1Rep = d_state.getRepresentative(part1);
+    const std::map<Node, Node>& partElements = d_state.getMembers(part1Rep);
+    for (std::map<Node, Node>::const_iterator i = partElements.begin();
+         i != partElements.end();
+         ++i)
+    {
+      Node x = d_state.getRepresentative(i->first);
+      if (!skolems->contains(x))
+      {
+        // only apply down rules for elements not generated by groupPartCount
+        // rule above
+        groupDown(n, part1, x, part);
+        d_im.doPendingLemmas();
+        if (d_im.hasSent())
+        {
+          return;
+        }
+      }
+
+      std::map<Node, Node>::const_iterator j = i;
+      ++j;
+      while (j != partElements.end())
+      {
+        Node y = d_state.getRepresentative(j->first);
+        // x, y should have the same projection
+        groupSameProjection(n, part1, x, y, part);
+        d_im.doPendingLemmas();
+        if (d_im.hasSent())
+        {
+          return;
+        }
+        ++j;
+      }
+
+      for (const auto& aPair : membersA)
+      {
+        Node y = d_state.getRepresentative(aPair.first);
+        if (x != y)
+        {
+          // x, y should have the same projection
+          groupSamePart(n, part1, x, y, part);
+          d_im.doPendingLemmas();
+          if (d_im.hasSent())
+          {
+            return;
+          }
+        }
+      }
+    }
+  }
+}
+
+void TheorySetsPrivate::groupNotEmpty(Node n)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode bagType = n.getType();
+  Node A = n[0];
+  Node emptyPart = nm->mkConst(EmptySet(A.getType()));
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+
+  Node A_isEmpty = A.eqNode(emptyPart);
+  std::vector<Node> exp;
+  exp.push_back(A_isEmpty);
+  Node singleton = nm->mkNode(SET_SINGLETON, emptyPart);
+  Node groupIsSingleton = skolem.eqNode(singleton);
+
+  Node conclusion = groupIsSingleton;
+  d_im.assertInference(
+      conclusion, InferenceId::SETS_RELS_GROUP_NOT_EMPTY, exp, 1);
+}
+
+void TheorySetsPrivate::groupUp1(Node n, Node x, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(x.getType() == n[0].getType().getSetElementType());
+  NodeManager* nm = NodeManager::currentNM();
+
+  Node A = n[0];
+  TypeNode setType = A.getType();
+
+  Node member_x_A = nm->mkNode(SET_MEMBER, x, A);
+
+  std::vector<Node> exp;
+  exp.push_back(member_x_A);
+
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+
+  Node member_x_part_x = nm->mkNode(SET_MEMBER, x, part_x);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+  Node member_part_x_n = nm->mkNode(SET_MEMBER, part_x, skolem);
+
+  Node emptyPart = nm->mkConst(EmptySet(setType));
+  Node member_emptyPart = nm->mkNode(SET_MEMBER, emptyPart, skolem);
+  Node emptyPart_not_member = member_emptyPart.notNode();
+
+  Node conclusion =
+      nm->mkNode(AND, {member_part_x_n, member_x_part_x, emptyPart_not_member});
+  d_im.assertInference(conclusion, InferenceId::SETS_RELS_GROUP_UP1, exp, 1);
+}
+
+void TheorySetsPrivate::groupUp2(Node n, Node x, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(x.getType() == n[0].getType().getSetElementType());
+  NodeManager* nm = NodeManager::currentNM();
+  Node A = n[0];
+  TypeNode setType = A.getType();
+
+  Node member_x_A = nm->mkNode(SET_MEMBER, x, A);
+
+  std::vector<Node> exp;
+  exp.push_back(member_x_A.notNode());
+
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_x_is_empty = part_x.eqNode(nm->mkConst(EmptySet(setType)));
+  Node conclusion = part_x_is_empty;
+  d_im.assertInference(conclusion, InferenceId::SETS_RELS_GROUP_UP2, exp, 1);
+}
+
+void TheorySetsPrivate::groupDown(Node n, Node B, Node x, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(B.getType() == n.getType().getSetElementType());
+  Assert(x.getType() == n[0].getType().getSetElementType());
+  NodeManager* nm = NodeManager::currentNM();
+  Node A = n[0];
+  TypeNode setType = A.getType();
+
+  Node member_x_B = nm->mkNode(SET_MEMBER, x, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+  Node member_B_n = nm->mkNode(SET_MEMBER, B, skolem);
+  std::vector<Node> exp;
+  exp.push_back(member_B_n);
+  exp.push_back(member_x_B);
+  Node member_x_A = nm->mkNode(SET_MEMBER, x, A);
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_x_is_B = part_x.eqNode(B);
+  Node conclusion = nm->mkNode(AND, member_x_A, part_x_is_B);
+  d_im.assertInference(conclusion, InferenceId::SETS_RELS_GROUP_DOWN, exp, 1);
+}
+
+void TheorySetsPrivate::groupPartMember(Node n, Node B, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(B.getType() == n.getType().getSetElementType());
+
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+
+  Node A = n[0];
+  TypeNode setType = A.getType();
+  Node empty = nm->mkConst(EmptySet(setType));
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+  Node member_B_n = nm->mkNode(SET_MEMBER, B, skolem);
+  std::vector<Node> exp;
+  exp.push_back(member_B_n);
+  Node A_notEmpty = A.eqNode(empty).notNode();
+  exp.push_back(A_notEmpty);
+
+  Node x = sm->mkSkolemFunction(SkolemFunId::RELATIONS_GROUP_PART_ELEMENT,
+                                setType.getSetElementType(),
+                                {n, B});
+  d_state.registerPartElementSkolem(n, x);
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node B_is_part_x = B.eqNode(part_x);
+  Node member_x_A = nm->mkNode(SET_MEMBER, x, A);
+  Node member_x_B = nm->mkNode(SET_MEMBER, x, B);
+
+  Node conclusion = nm->mkNode(AND, {B_is_part_x, member_x_B, member_x_A});
+  d_im.assertInference(
+      conclusion, InferenceId::SETS_RELS_GROUP_PART_MEMBER, exp, 1);
+}
+
+void TheorySetsPrivate::groupSameProjection(
+    Node n, Node B, Node x, Node y, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(B.getType() == n.getType().getSetElementType());
+  Assert(x.getType() == n[0].getType().getSetElementType());
+  Assert(y.getType() == n[0].getType().getSetElementType());
+  NodeManager* nm = NodeManager::currentNM();
+
+  Node A = n[0];
+  TypeNode setType = A.getType();
+
+  Node member_x_B = nm->mkNode(SET_MEMBER, x, B);
+  Node member_y_B = nm->mkNode(SET_MEMBER, y, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+  Node member_B_n = nm->mkNode(SET_MEMBER, B, skolem);
+
+  // premises
+  std::vector<Node> exp;
+  exp.push_back(member_B_n);
+  exp.push_back(member_x_B);
+  exp.push_back(member_y_B);
+  exp.push_back(x.eqNode(y).notNode());
+
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<ProjectOp>().getIndices();
+
+  Node xProjection = TupleUtils::getTupleProjection(indices, x);
+  Node yProjection = TupleUtils::getTupleProjection(indices, y);
+  Node sameProjection = xProjection.eqNode(yProjection);
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_y = nm->mkNode(APPLY_UF, part, y);
+  part_y = registerAndAssertSkolemLemma(part_y, "part_y");
+  Node samePart = part_x.eqNode(part_y);
+  Node part_x_is_B = part_x.eqNode(B);
+  Node conclusion = nm->mkNode(AND, sameProjection, samePart, part_x_is_B);
+  d_im.assertInference(
+      conclusion, InferenceId::SETS_RELS_GROUP_SAME_PROJECTION, exp, 1);
+}
+
+void TheorySetsPrivate::groupSamePart(Node n, Node B, Node x, Node y, Node part)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Assert(B.getType() == n.getType().getSetElementType());
+  Assert(x.getType() == n[0].getType().getSetElementType());
+  Assert(y.getType() == n[0].getType().getSetElementType());
+  NodeManager* nm = NodeManager::currentNM();
+  Node A = n[0];
+  TypeNode setType = A.getType();
+
+  std::vector<Node> exp;
+  Node member_x_B = nm->mkNode(SET_MEMBER, x, B);
+  Node member_y_A = nm->mkNode(SET_MEMBER, y, A);
+  Node member_y_B = nm->mkNode(SET_MEMBER, y, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_group");
+  Node member_B_n = nm->mkNode(SET_MEMBER, B, skolem);
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<ProjectOp>().getIndices();
+
+  Node xProjection = TupleUtils::getTupleProjection(indices, x);
+  Node yProjection = TupleUtils::getTupleProjection(indices, y);
+
+  // premises
+  exp.push_back(member_B_n);
+  exp.push_back(member_x_B);
+  exp.push_back(member_y_A);
+  exp.push_back(x.eqNode(y).notNode());
+  exp.push_back(xProjection.eqNode(yProjection));
+
+  Node part_x = nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_y = nm->mkNode(APPLY_UF, part, y);
+  part_y = registerAndAssertSkolemLemma(part_y, "part_y");
+  Node samePart = part_x.eqNode(part_y);
+  Node part_x_is_B = part_x.eqNode(B);
+  Node conclusion = nm->mkNode(AND, member_y_B, samePart, part_x_is_B);
+
+  d_im.assertInference(
+      conclusion, InferenceId::SETS_RELS_GROUP_SAME_PART, exp, 1);
+}
+
+Node TheorySetsPrivate::defineSkolemPartFunction(Node n)
+{
+  Assert(n.getKind() == RELATION_GROUP);
+  Node A = n[0];
+  TypeNode relationType = A.getType();
+  TypeNode elementType = relationType.getSetElementType();
+
+  // declare an uninterpreted function part: T -> (Relation T)
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+  TypeNode partType = nm->mkFunctionType(elementType, relationType);
+  Node part =
+      sm->mkSkolemFunction(SkolemFunId::RELATIONS_GROUP_PART, partType, {n});
+  return part;
+}
+
+Node TheorySetsPrivate::registerAndAssertSkolemLemma(Node& n,
+                                                     const std::string& prefix)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+  Node skolem = sm->mkPurifySkolem(n, prefix);
+  Node lemma = n.eqNode(skolem);
+  d_im.addPendingLemma(lemma, InferenceId::SETS_SKOLEM);
+  Trace("sets-skolems") << "sets-skolems:  " << skolem << " = " << n
+                        << std::endl;
+  return skolem;
+}
+
 void TheorySetsPrivate::checkDisequalities()
 {
   // disequalities
index 87ad4678e01be5db5b306df64ec4fcda1f7d2c7d..e80aa7012a23ae6e9b69529f5f5ea148d38bc3af 100644 (file)
@@ -127,6 +127,151 @@ class TheorySetsPrivate : protected EnvObj
    *   )
    */
   void checkMapDown();
+  void checkGroups();
+  void checkGroup(Node n);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type T
+   * @return an inference that represents:
+   * (=>
+   *  (= A (as set.empty T))
+   *  (= skolem (set.singleton (as set.empty T)))
+   * )
+   */
+  void groupNotEmpty(Node n);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param e an element of type T
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (set.member x A)
+   *   (and
+   *     (set.member (part x) skolem)
+   *     (set.member x (part x))
+   *     (not (set.member (as set.empty (Relation T)) skolem))
+   *   )
+   * )
+   *
+   * where skolem is a variable equals ((_ rel.group n1 ... nk) A)
+   */
+  void groupUp1(Node n, Node x, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param e an element of type T
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (not (set.member x A))
+   *   (= (part x) (as set.empty (Relation T)))
+   * )
+   *
+   * where skolem is a variable equals ((_ rel.group n1 ... nk) A)
+   */
+  void groupUp2(Node n, Node x, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param B an element of type (Relation T)
+   * @param x an element of type T
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (set.member B skolem)
+   *     (set.member x B)
+   *   )
+   *   (and
+   *     (set.member x A)
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A).
+   */
+  void groupDown(Node n, Node B, Node x, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param B an element of type (Relation T) and B is not of the form (part x)
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (set.member B skolem)
+   *     (not (= A (as set.empty (Relation T)))
+   *   )
+   *   (and
+   *     (= B (part k_{n, B}))
+   *     (set.member k_{n,B} B)
+   *     (set.member k_{n,B} A)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ rel.group n1 ... nk) A), and
+   * k_{n, B} is a fresh skolem of type T.
+   */
+  void groupPartMember(Node n, Node B, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param B an element of type (Relation T)
+   * @param x an element of type T
+   * @param y an element of type T
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (set.member B skolem)
+   *     (set.member x B)
+   *     (set.member y B)
+   *     (distinct x y)
+   *   )
+   *   (and
+   *     (= ((_ tuple.project n1 ... nk) x)
+   *        ((_ tuple.project n1 ... nk) y))
+   *     (= (part x) (part y))
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ rel.group n1 ... nk) A).
+   */
+  void groupSameProjection(Node n, Node B, Node x, Node y, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @param B an element of type (Relation T)
+   * @param x an element of type T
+   * @param y an element of type T
+   * @param part a skolem function of type T -> (Relation T) created uniquely
+   * for n by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (set.member B skolem)
+   *     (set.member x B)
+   *     (set.member y A)
+   *     (distinct x y)
+   *     (= ((_ tuple.project n1 ... nk) x)
+   *        ((_ tuple.project n1 ... nk) y))
+   *   )
+   *   (and
+   *     (set.member y B)
+   *     (= (part x) (part y))
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ rel.group n1 ... nk) A).
+   */
+  void groupSamePart(Node n, Node B, Node x, Node y, Node part);
+  /**
+   * @param n has form ((_ rel.group n1 ... nk) A) where A has type (Relation T)
+   * @return a function of type T -> (Relation T) that maps elements T to a
+   * part in the partition
+   */
+  Node defineSkolemPartFunction(Node n);
+  /**
+   * generate skolem variable for node n and add pending lemma for the equality
+   */
+  Node registerAndAssertSkolemLemma(Node& n, const std::string& prefix);
   /**
    * This implements a strategy for splitting for set disequalities which
    * roughly corresponds the SET DISEQUALITY rule from Bansal et al IJCAR 2016.
index 8ee9a33d1d82719664f573c95a797a48285f9de1..5fcd15dd1e44713e267b1a742134d83bccaa5ba8 100644 (file)
@@ -589,8 +589,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
     break;
   }
 
-  default:
-    break;
+  case RELATION_GROUP: return postRewriteGroup(node);
+  default: break;
   }
 
   return RewriteResponse(REWRITE_DONE, node);
@@ -741,6 +741,30 @@ RewriteResponse TheorySetsRewriter::postRewriteFold(TNode n)
   }
 }
 
+RewriteResponse TheorySetsRewriter::postRewriteGroup(TNode n)
+{
+  Assert(n.getKind() == kind::RELATION_GROUP);
+  Node A = n[0];
+  Kind k = A.getKind();
+  if (k == SET_EMPTY || k == SET_SINGLETON)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    // - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) =
+    //    (rel.singleton (as set.empty (Relation T) ))
+    // - ((_ rel.group n1 ... nk) (set.singleton x)) =
+    //      (set.singleton (set.singleton x))
+    Node singleton = nm->mkNode(SET_SINGLETON, A);
+    return RewriteResponse(REWRITE_AGAIN_FULL, singleton);
+  }
+  if (A.isConst())
+  {
+    Node evaluation = RelsUtils::evaluateGroup(n);
+    return RewriteResponse(REWRITE_AGAIN_FULL, evaluation);
+  }
+
+  return RewriteResponse(REWRITE_DONE, n);
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index ba48e403018485a8e93aa03e77fd1cc6781ef53d..dc4f64762e979f10fecf2f9d29735014c262bc1e 100644 (file)
@@ -102,6 +102,15 @@ private:
   *  where f: T -> S -> S, and t : S
   */
  RewriteResponse postRewriteFold(TNode n);
+ /**
+  *  rewrites for n include:
+  *  - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) =
+  *          (rel.singleton (as set.empty (Relation T) ))
+  *  - ((_ rel.group n1 ... nk) (set.singleton x)) =
+  *          (set.singleton (set.singleton x))
+  *  - Evaluation of ((_ rel.group n1 ... nk) A) when A is a constant
+  */
+ RewriteResponse postRewriteGroup(TNode n);
 }; /* class TheorySetsRewriter */
 
 }  // namespace sets
index b2eb7987ae6df2ddd994bb38f0ab2abbf6ab5a44..8a163489a436e621be1241033ed3947bc746a049 100644 (file)
@@ -20,6 +20,8 @@
 
 #include "theory/sets/normal_form.h"
 #include "util/cardinality.h"
+#include "theory/datatypes/project_op.h"
+#include "theory/datatypes/tuple_utils.h"
 
 namespace cvc5::internal {
 namespace theory {
@@ -577,6 +579,39 @@ TypeNode RelIdenTypeRule::computeType(NodeManager* nodeManager,
   return nodeManager->mkSetType(nodeManager->mkTupleType(tupleTypes));
 }
 
+TypeNode RelationGroupTypeRule::computeType(NodeManager* nm, TNode n, bool check)
+{
+  Assert(n.getKind() == kind::RELATION_GROUP && n.hasOperator()
+         && n.getOperator().getKind() == kind::RELATION_GROUP_OP);
+  ProjectOp op = n.getOperator().getConst<ProjectOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+
+  TypeNode setType = n[0].getType(check);
+
+  if (check)
+  {
+    if (!setType.isSet())
+    {
+      std::stringstream ss;
+      ss << "RELATION_GROUP operator expects a relation. Found '" << n[0]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = setType.getSetElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "RELATION_GROUP operator expects a relation. Found '" << n[0]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    datatypes::TupleUtils::checkTypeIndices(n, tupleType, indices);
+  }
+  return nm->mkSetType(setType);
+}
+
 Cardinality SetsProperties::computeCardinality(TypeNode type)
 {
   Assert(type.getKind() == kind::SET_TYPE);
index 59ea661580372179fc9a5172c1caddd7e70e68c6..551513fe6fc3a3cf63f14aff1fc5ae6c63009b22 100644 (file)
@@ -212,6 +212,17 @@ struct RelIdenTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 };
 
+/**
+ * Relation group operator is indexed by a list of indices (n_1, ..., n_k). It
+ * ensures that the argument is a relation whose arity is greater than each n_i
+ * for i = 1, ..., k. If the passed relation is of type T, then the returned
+ * type is (Set T), i.e., set of relations.
+ */
+struct RelationGroupTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct RelationGroupTypeRule */
+
 struct SetsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index 1a159a942404af4b4b8cf28b8e0ea68823ee5164..0697ccb53129eb3470515f2761d6a19a82175399 100644 (file)
@@ -2505,6 +2505,11 @@ set(regress_1_tests
   regress1/sets/proj-issue164.smt2
   regress1/sets/proj-issue178.smt2
   regress1/sets/proj-issue494-finite-leafof.smt2
+  regress1/sets/relation_group1.smt2
+  regress1/sets/relation_group2.smt2
+  regress1/sets/relation_group3.smt2
+  regress1/sets/relation_group4.smt2
+  regress1/sets/relation_group5.smt2
   regress1/sets/remove_check_free_31_6.smt2
   regress1/sets/sets-disequal.smt2
   regress1/sets/sets-tuple-poly.cvc.smt2
diff --git a/test/regress/cli/regress1/sets/relation_group1.smt2 b/test/regress/cli/regress1/sets/relation_group1.smt2
new file mode 100644 (file)
index 0000000..867c1d6
--- /dev/null
@@ -0,0 +1,110 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(define-fun truthRelation () (Relation String String String)
+  (set.union
+   (set.singleton (tuple "A" "X" "0"))
+   (set.singleton (tuple "A" "X" "1"))
+   (set.singleton (tuple "A" "Y" "0"))
+   (set.singleton (tuple "A" "Y" "1"))
+   (set.singleton (tuple "B" "X" "0"))
+   (set.singleton (tuple "B" "X" "1"))
+   (set.singleton (tuple "B" "Y" "0"))
+   (set.singleton (tuple "B" "Y" "1"))))
+
+; parition by first column
+(assert
+ (= ((_ rel.group 0) truthRelation)
+    (set.union
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "X" "0"))
+                 (set.singleton (tuple "A" "X" "1"))
+                 (set.singleton (tuple "A" "Y" "0"))
+                 (set.singleton (tuple "A" "Y" "1"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "B" "X" "0"))
+                 (set.singleton (tuple "B" "X" "1"))
+                 (set.singleton (tuple "B" "Y" "0"))
+                 (set.singleton (tuple "B" "Y" "1")))))))
+
+; parition by second column
+(assert
+ (= ((_ rel.group 1) truthRelation)
+    (set.union
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "X" "0"))
+                 (set.singleton (tuple "A" "X" "1"))
+                 (set.singleton (tuple "B" "X" "0"))
+                 (set.singleton (tuple "B" "X" "1"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "Y" "0"))
+                 (set.singleton (tuple "A" "Y" "1"))
+                 (set.singleton (tuple "B" "Y" "0"))
+                 (set.singleton (tuple "B" "Y" "1")))))))
+
+; parition by third column
+(assert
+ (= ((_ rel.group 2) truthRelation)
+    (set.union
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "X" "0"))
+                 (set.singleton (tuple "A" "Y" "0"))
+                 (set.singleton (tuple "B" "X" "0"))
+                 (set.singleton (tuple "B" "Y" "0"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "X" "1"))
+                 (set.singleton (tuple "A" "Y" "1"))
+                 (set.singleton (tuple "B" "X" "1"))
+                 (set.singleton (tuple "B" "Y" "1")))))))
+
+; parition by first,second columns
+(assert
+ (= ((_ rel.group 0 1) truthRelation)
+    (set.union
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "X" "0"))
+                 (set.singleton (tuple "A" "X" "1"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "A" "Y" "0"))
+                 (set.singleton (tuple "A" "Y" "1"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "B" "X" "0"))
+                 (set.singleton (tuple "B" "X" "1"))))
+     (set.singleton
+      (set.union (set.singleton (tuple "B" "Y" "0"))
+                 (set.singleton (tuple "B" "Y" "1")))))))
+
+; parition by no column
+(assert
+ (= (rel.group truthRelation)
+    (set.singleton
+     (set.union
+      (set.singleton (tuple "A" "X" "0"))
+      (set.singleton (tuple "A" "X" "1"))
+      (set.singleton (tuple "A" "Y" "0"))
+      (set.singleton (tuple "A" "Y" "1"))
+      (set.singleton (tuple "B" "X" "0"))
+      (set.singleton (tuple "B" "X" "1"))
+      (set.singleton (tuple "B" "Y" "0"))
+      (set.singleton (tuple "B" "Y" "1"))))))
+
+; parition by all columns
+(assert
+ (= ((_ rel.group 0 1 2) truthRelation)
+    (set.union
+     (set.singleton (set.singleton (tuple "A" "X" "0")))
+     (set.singleton (set.singleton (tuple "A" "X" "1")))
+     (set.singleton (set.singleton (tuple "A" "Y" "0")))
+     (set.singleton (set.singleton (tuple "A" "Y" "1")))
+     (set.singleton (set.singleton (tuple "B" "X" "0")))
+     (set.singleton (set.singleton (tuple "B" "X" "1")))
+     (set.singleton (set.singleton (tuple "B" "Y" "0")))
+     (set.singleton (set.singleton (tuple "B" "Y" "1"))))))
+
+; group of set.empty
+(assert
+ (= ((_ rel.group 0) (as set.empty (Relation String String String)))
+    (set.singleton (as set.empty (Relation String String String)))))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_group2.smt2 b/test/regress/cli/regress1/sets/relation_group2.smt2
new file mode 100644 (file)
index 0000000..992e6a2
--- /dev/null
@@ -0,0 +1,16 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun data () (Relation String String String))
+(declare-fun part1 () (Relation String String String))
+(declare-fun part2 () (Relation String String String))
+(declare-fun partition () (Set (Relation String String String)))
+
+(assert (distinct data part1 part2 (as set.empty (Relation String String String))))
+
+(assert (= partition ((_ rel.group 0) data)))
+(assert (set.member part1 partition))
+(assert (set.member part2 partition))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_group3.smt2 b/test/regress/cli/regress1/sets/relation_group3.smt2
new file mode 100644 (file)
index 0000000..2af8f0c
--- /dev/null
@@ -0,0 +1,21 @@
+; DISABLE-TESTER: lfsc
+; Disabled since rel.group is not supported in LFSC
+(set-logic HO_ALL)
+
+(set-info :status unsat)
+
+(declare-fun data () (Relation String String String))
+(declare-fun part1 () (Relation String String String))
+(declare-fun part2 () (Relation String String String))
+(declare-fun partition () (Set (Relation String String String)))
+
+(assert (distinct data part1 part2 (as set.empty (Relation String String String))))
+
+(assert (= partition ((_ rel.group 0 1) data)))
+(assert (set.member part1 partition))
+(assert (set.member part2 partition))
+
+(assert (set.member (tuple "A" "X" "0") part1))
+(assert (set.member (tuple "A" "X" "1") part2))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_group4.smt2 b/test/regress/cli/regress1/sets/relation_group4.smt2
new file mode 100644 (file)
index 0000000..33136c7
--- /dev/null
@@ -0,0 +1,13 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun data () (Relation String String String))
+(declare-fun groupBy0 () (Set (Relation String String String)))
+(declare-fun groupBy1 () (Set (Relation String String String)))
+
+(assert (= groupBy0 ((_ rel.group 0) data)))
+(assert (= groupBy1 ((_ rel.group 1) data)))
+(assert (distinct groupBy0 groupBy1))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_group5.smt2 b/test/regress/cli/regress1/sets/relation_group5.smt2
new file mode 100644 (file)
index 0000000..d9dc1bc
--- /dev/null
@@ -0,0 +1,15 @@
+; DISABLE-TESTER: lfsc
+; Disabled since table.group is not supported in LFSC
+(set-logic HO_ALL)
+
+(set-info :status unsat)
+
+(declare-fun data () (Relation String String String))
+(declare-fun groupBy01 () (Set (Relation String String String)))
+(declare-fun groupBy10 () (Set (Relation String String String)))
+
+(assert (= groupBy01 ((_ rel.group 0 1) data)))
+(assert (= groupBy10 ((_ rel.group 1 0) data)))
+(assert (distinct groupBy01 groupBy10))
+
+(check-sat)