Add inference rules for table.group (#8819)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 23 Jun 2022 21:21:56 +0000 (16:21 -0500)
committerGitHub <noreply@github.com>
Thu, 23 Jun 2022 21:21:56 +0000 (21:21 +0000)
16 files changed:
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_utils.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/inference_id.cpp
src/theory/inference_id.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/table_group2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_group3.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_group4.smt2 [new file with mode: 0644]
test/regress/cli/regress1/bags/table_group5.smt2 [new file with mode: 0644]

index 2171954ba5da89481bbb5671f7fc22dcd342b2df..6b8e32f55e95c04f0acf3ddce32ddcb454e12308 100644 (file)
@@ -95,6 +95,9 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::BAGS_MAP_PREIMAGE_INDEX: return "BAGS_MAP_PREIMAGE_INDEX";
     case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
     case SkolemFunId::BAGS_DEQ_DIFF: return "BAGS_DEQ_DIFF";
+    case SkolemFunId::TABLES_GROUP_PART: return "TABLES_GROUP_PART";
+    case SkolemFunId::TABLES_GROUP_PART_ELEMENT:
+      return "TABLES_GROUP_PART_ELEMENT";
     case SkolemFunId::SETS_CHOOSE: return "SETS_CHOOSE";
     case SkolemFunId::SETS_DEQ_DIFF: return "SETS_DEQ_DIFF";
     case SkolemFunId::SETS_FOLD_CARD: return "SETS_FOLD_CARD";
index 1b2f3df3bb3fad6453b7984cf1d2c4f3ad048183..da729e4a1e1859ade01714ac9f9fba28e9806a6c 100644 (file)
@@ -169,6 +169,17 @@ enum class SkolemFunId
   BAGS_MAP_SUM,
   /** bag diff to witness (not (= A B)) */
   BAGS_DEQ_DIFF,
+  /** Given a group term ((_ table.group n1 ... nk) A) of type (Bag (Table T))
+   * this uninterpreted function maps elements of A to their parts in the
+   * resulting partition. It has type (-> T (Table T))
+   */
+  TABLES_GROUP_PART,
+  /**
+   * Given a group term ((_ table.group n1 ... nk) A) of type (Bag (Table T))
+   * and a part B of type (Table T), this function returns a skolem element
+   * that is a member of B if B is not empty.
+   */
+  TABLES_GROUP_PART_ELEMENT,
   /** An interpreted function for bag.choose operator:
    * (choose A) is expanded as
    * (witness ((x elementType))
index 5f52c6215d13bd21eac05f50bc4cfcb7350d5c78..6aea2f5efeab84dbd8fbd40a6dd0647e38d7da99 100644 (file)
@@ -80,6 +80,7 @@ void BagSolver::checkBasicOperations()
         case kind::BAG_MAP: checkMap(n); break;
         case kind::TABLE_PRODUCT: checkProduct(n); break;
         case kind::TABLE_JOIN: checkJoin(n); break;
+        case kind::TABLE_GROUP: checkGroup(n); break;
         default: break;
       }
       it++;
@@ -360,6 +361,95 @@ void BagSolver::checkJoin(Node n)
   }
 }
 
+void BagSolver::checkGroup(Node n)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+
+  InferInfo notEmpty = d_ig.groupNotEmpty(n);
+  d_im.lemmaTheoryInference(&notEmpty);
+
+  Node part = d_ig.defineSkolemPartFunction(n);
+
+  const set<Node>& elementsA = d_state.getElements(n[0]);
+  std::shared_ptr<context::CDHashSet<Node>> skolems =
+      d_state.getPartElementSkolems(n);
+  for (const Node& a : elementsA)
+  {
+    if (skolems->contains(a))
+    {
+      // skip skolem elements that were introduced by groupPartCount below.
+      continue;
+    }
+    Node aRep = d_state.getRepresentative(a);
+    InferInfo i = d_ig.groupUp1(n, aRep, part);
+    d_im.lemmaTheoryInference(&i);
+    i = d_ig.groupUp2(n, aRep, part);
+    d_im.lemmaTheoryInference(&i);
+  }
+
+  std::set<Node> parts = d_state.getElements(n);
+  for (std::set<Node>::iterator partIt1 = parts.begin(); partIt1 != parts.end();
+       ++partIt1)
+  {
+    Node part1 = d_state.getRepresentative(*partIt1);
+    std::vector<Node> partEqc;
+    d_state.getEquivalenceClass(part1, partEqc);
+    bool newPart = true;
+    for (Node p : partEqc)
+    {
+      if (p.getKind() == APPLY_UF && p.getOperator() == part)
+      {
+        newPart = false;
+      }
+    }
+    if (newPart)
+    {
+      // only apply the groupPartCount rule for a part that does not have
+      // nodes of the form (part x) introduced by the group up rule above.
+      InferInfo partCardinality = d_ig.groupPartCount(n, part1, part);
+      d_im.lemmaTheoryInference(&partCardinality);
+    }
+
+    std::set<Node> partElements = d_state.getElements(part1);
+    for (std::set<Node>::iterator i = partElements.begin();
+         i != partElements.end();
+         ++i)
+    {
+      Node x = d_state.getRepresentative(*i);
+      if (!skolems->contains(x))
+      {
+        // only apply down rules for elements not generated by groupPartCount
+        // rule above
+        InferInfo down = d_ig.groupDown(n, part1, x, part);
+        d_im.lemmaTheoryInference(&down);
+      }
+
+      std::set<Node>::iterator j = i;
+      ++j;
+      while (j != partElements.end())
+      {
+        Node y = d_state.getRepresentative(*j);
+        // x, y should have the same projection
+        InferInfo sameProjection =
+            d_ig.groupSameProjection(n, part1, x, y, part);
+        d_im.lemmaTheoryInference(&sameProjection);
+        ++j;
+      }
+
+      for (const Node& a : elementsA)
+      {
+        Node y = d_state.getRepresentative(a);
+        if (x != y)
+        {
+          // x, y should have the same projection
+          InferInfo samePart = d_ig.groupSamePart(n, part1, x, y, part);
+          d_im.lemmaTheoryInference(&samePart);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index 2d6e890df900bfac604dade7fe33465ca8267430..d7ed894abf21cf3f21c8dfecee0ec7f9867aac99 100644 (file)
@@ -102,6 +102,8 @@ class BagSolver : protected EnvObj
   void checkProduct(Node n);
   /** apply inference rules for join operator */
   void checkJoin(Node n);
+  /** apply inference rules for group operator */
+  void checkGroup(Node n);
 
   /** The solver state object */
   SolverState& d_state;
index 83be57a5e5dfd4a54db7430a1166f0e4f2bce8f4..17f1cf5d37c5016b52a3a3cbb210818bfcb7d829 100644 (file)
@@ -985,12 +985,6 @@ Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n)
   TypeNode bagType = A.getType();
   TypeNode partitionType = n.getType();
 
-  if (A.getKind() == BAG_EMPTY)
-  {
-    // return a nonempty partition
-    return nm->mkNode(BAG_MAKE, A, nm->mkConstInt(Rational(1)));
-  }
-
   std::vector<uint32_t> indices =
       n.getOperator().getConst<ProjectOp>().getIndices();
 
@@ -1048,6 +1042,12 @@ Node BagsUtils::evaluateGroup(Rewriter* rewriter, TNode n)
     // each part in the partitions has multiplicity one
     parts[part] = Rational(1);
   }
+  if (parts.empty())
+  {
+    // add an empty part
+    Node emptyPart = nm->mkConst(EmptyBag(bagType));
+    parts[emptyPart] = Rational(1);
+  }
   Node ret = constructConstantBagFromElements(partitionType, parts);
   Trace("bags-partition") << "ret: " << ret << std::endl;
   return ret;
index dd560b8f3bd6132f679018b39ab2c00208fc1e8b..152d9ce6516b293d53c2a80e511d42f403db903c 100644 (file)
@@ -718,6 +718,239 @@ InferInfo InferenceGenerator::joinDown(Node n, Node e)
   return inferInfo;
 }
 
+InferInfo InferenceGenerator::groupNotEmpty(Node n)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+
+  TypeNode bagType = n.getType();
+  Node A = n[0];
+  Node emptyPart = d_nm->mkConst(EmptyBag(A.getType()));
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_NOT_EMPTY);
+  Node A_isEmpty = A.eqNode(emptyPart);
+  inferInfo.d_premises.push_back(A_isEmpty);
+  Node singleton = d_nm->mkNode(BAG_MAKE, emptyPart, d_one);
+  Node groupIsSingleton = skolem.eqNode(singleton);
+
+  inferInfo.d_conclusion = groupIsSingleton;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupUp1(Node n, Node x, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(x.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_UP1);
+  Node count_x_A = getMultiplicityTerm(x, A);
+  Node x_member_A = d_nm->mkNode(GEQ, count_x_A, d_one);
+  inferInfo.d_premises.push_back(x_member_A);
+
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+
+  Node count_x_part_x = getMultiplicityTerm(x, part_x);
+
+  Node sameMultiplicity = count_x_part_x.eqNode(count_x_A);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count_part_x = getMultiplicityTerm(part_x, skolem);
+  Node part_x_member = d_nm->mkNode(EQUAL, count_part_x, d_one);
+
+  Node emptyPart = d_nm->mkConst(EmptyBag(bagType));
+  Node count_emptyPart = getMultiplicityTerm(emptyPart, skolem);
+  Node emptyPart_not_member = count_emptyPart.eqNode(d_zero);
+
+  inferInfo.d_conclusion = d_nm->mkNode(
+      AND, {sameMultiplicity, part_x_member, emptyPart_not_member});
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupUp2(Node n, Node x, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(x.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_UP2);
+  Node count_x_A = getMultiplicityTerm(x, A);
+  Node x_not_in_A = d_nm->mkNode(EQUAL, count_x_A, d_zero);
+  inferInfo.d_premises.push_back(x_not_in_A);
+
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_x_is_empty = part_x.eqNode(d_nm->mkConst(EmptyBag(bagType)));
+  inferInfo.d_conclusion = part_x_is_empty;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupDown(Node n, Node B, Node x, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(B.getType() == n.getType().getBagElementType());
+  Assert(x.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_DOWN);
+  Node count_x_B = getMultiplicityTerm(x, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count_B_n = getMultiplicityTerm(B, skolem);
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_B_n, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_x_B, d_one));
+  Node count_x_A = getMultiplicityTerm(x, A);
+  Node sameMultiplicity = count_x_B.eqNode(count_x_A);
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_x_is_B = part_x.eqNode(B);
+  inferInfo.d_conclusion = d_nm->mkNode(AND, sameMultiplicity, part_x_is_B);
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupPartCount(Node n, Node B, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(B.getType() == n.getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+  Node empty = d_nm->mkConst(EmptyBag(bagType));
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_PART_COUNT);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count_B_n = getMultiplicityTerm(B, skolem);
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_B_n, d_one));
+  Node A_notEmpty = A.eqNode(empty).notNode();
+  inferInfo.d_premises.push_back(A_notEmpty);
+
+  Node x = d_sm->mkSkolemFunction(SkolemFunId::TABLES_GROUP_PART_ELEMENT,
+                                  bagType.getBagElementType(),
+                                  {n, B});
+  d_state->registerPartElementSkolem(n, x);
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node B_is_part_x = B.eqNode(part_x);
+  Node count_x_A = getMultiplicityTerm(x, A);
+  Node count_x_B = getMultiplicityTerm(x, B);
+  Node sameMultiplicity = count_x_A.eqNode(count_x_B);
+  Node x_in_B = d_nm->mkNode(GEQ, count_x_B, d_one);
+  Node count_B_n_is_one = count_B_n.eqNode(d_one);
+  inferInfo.d_conclusion = d_nm->mkNode(AND,
+                                        {
+                                            count_B_n_is_one,
+                                            B_is_part_x,
+                                            x_in_B,
+                                            sameMultiplicity,
+                                        });
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupSameProjection(
+    Node n, Node B, Node x, Node y, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(B.getType() == n.getType().getBagElementType());
+  Assert(x.getType() == n[0].getType().getBagElementType());
+  Assert(y.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_SAME_PROJECTION);
+  Node count_x_B = getMultiplicityTerm(x, B);
+  Node count_y_B = getMultiplicityTerm(y, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count_B_n = getMultiplicityTerm(B, skolem);
+
+  // premises
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_B_n, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_x_B, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_y_B, d_one));
+  inferInfo.d_premises.push_back(x.eqNode(y).notNode());
+
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<ProjectOp>().getIndices();
+
+  Node xProjection = TupleUtils::getTupleProjection(indices, x);
+  Node yProjection = TupleUtils::getTupleProjection(indices, y);
+  Node sameProjection = xProjection.eqNode(yProjection);
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_y = d_nm->mkNode(APPLY_UF, part, y);
+  part_y = registerAndAssertSkolemLemma(part_y, "part_y");
+  Node samePart = part_x.eqNode(part_y);
+  Node part_x_is_B = part_x.eqNode(B);
+  inferInfo.d_conclusion =
+      d_nm->mkNode(AND, sameProjection, samePart, part_x_is_B);
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::groupSamePart(
+    Node n, Node B, Node x, Node y, Node part)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Assert(B.getType() == n.getType().getBagElementType());
+  Assert(x.getType() == n[0].getType().getBagElementType());
+  Assert(y.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  TypeNode bagType = A.getType();
+
+  InferInfo inferInfo(d_im, InferenceId::TABLES_GROUP_SAME_PART);
+  Node count_x_B = getMultiplicityTerm(x, B);
+  Node count_y_A = getMultiplicityTerm(y, A);
+  Node count_y_B = getMultiplicityTerm(y, B);
+
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
+  Node count_B_n = getMultiplicityTerm(B, skolem);
+  const std::vector<uint32_t>& indices =
+      n.getOperator().getConst<ProjectOp>().getIndices();
+
+  Node xProjection = TupleUtils::getTupleProjection(indices, x);
+  Node yProjection = TupleUtils::getTupleProjection(indices, y);
+
+  // premises
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_B_n, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_x_B, d_one));
+  inferInfo.d_premises.push_back(d_nm->mkNode(GEQ, count_y_A, d_one));
+  inferInfo.d_premises.push_back(x.eqNode(y).notNode());
+  inferInfo.d_premises.push_back(xProjection.eqNode(yProjection));
+
+  Node sameMultiplicity = count_y_B.eqNode(count_y_A);
+  Node part_x = d_nm->mkNode(APPLY_UF, part, x);
+  part_x = registerAndAssertSkolemLemma(part_x, "part_x");
+  Node part_y = d_nm->mkNode(APPLY_UF, part, y);
+  part_y = registerAndAssertSkolemLemma(part_y, "part_y");
+  Node samePart = part_x.eqNode(part_y);
+  Node part_x_is_B = part_x.eqNode(B);
+  inferInfo.d_conclusion =
+      d_nm->mkNode(AND, sameMultiplicity, samePart, part_x_is_B);
+
+  return inferInfo;
+}
+
+Node InferenceGenerator::defineSkolemPartFunction(Node n)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  Node A = n[0];
+  TypeNode tableType = A.getType();
+  TypeNode elementType = tableType.getBagElementType();
+
+  // declare an uninterpreted function part: T -> (Table T)
+  TypeNode partType = d_nm->mkFunctionType(elementType, tableType);
+  Node part =
+      d_sm->mkSkolemFunction(SkolemFunId::TABLES_GROUP_PART, partType, {n});
+  return part;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index 91a3ea08a59bc495bbdce074ffaf7e02d6192cac..a0fbc0e3a01f392add3c0e15e1bf5287566da283 100644 (file)
@@ -367,6 +367,147 @@ class InferenceGenerator
    */
   Node getMultiplicityTerm(Node element, Node bag);
 
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type T
+   * @return an inference that represents:
+   * (=>
+   *  (= A (as bag.empty T))
+   *  (= skolem (bag (as bag.empty T) 1))
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A)
+   */
+  InferInfo groupNotEmpty(Node n);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param e an element of type T
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (bag.member x A)
+   *   (and
+   *     (= (bag.count (part x) skolem) 1)
+   *     (= (bag.count x (part x)) (bag.count x A))
+   *     (= (bag.count (as bag.empty (Table T)) skolem) 0)
+   *   )
+   * )
+   *
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A)
+   */
+  InferInfo groupUp1(Node n, Node x, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param e an element of type T
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (= (bag.count x A) 0)
+   *   (= (part x) (as bag.empty (Table T)))
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A)
+   */
+  InferInfo groupUp2(Node n, Node x, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param B an element of type (Table T)
+   * @param x an element of type T
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (bag.member B skolem)
+   *     (bag.member x B)
+   *   )
+   *   (and
+   *     (= (bag.count x B) (bag.count x A))
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A).
+   */
+  InferInfo groupDown(Node n, Node B, Node x, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param B an element of type (Table T) and B is not of the form (part x)
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (bag.member B skolem)
+   *     (not (= A (as bag.empty (Table T)))
+   *   )
+   *   (and
+   *     (= (bag.count B skolem) 1)
+   *     (= B (part k_{n, B}))
+   *     (>= (bag.count k_{n,B} B) 1)
+   *     (= (bag.count k_{n,B} B) (bag.count k_{n,B} A))
+   *   )
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A), and
+   * k_{n, B} is a fresh skolem of type T.
+   */
+  InferInfo groupPartCount(Node n, Node B, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param B an element of type (Table T)
+   * @param x an element of type T
+   * @param y an element of type T
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (bag.member B skolem)
+   *     (bag.member x B)
+   *     (bag.member y B)
+   *     (distinct x y)
+   *   )
+   *   (and
+   *     (= ((_ tuple.project n1 ... nk) x)
+   *        ((_ tuple.project n1 ... nk) y))
+   *     (= (part x) (part y))
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A).
+   */
+  InferInfo groupSameProjection(Node n, Node B, Node x, Node y, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @param B an element of type (Table T)
+   * @param x an element of type T
+   * @param y an element of type T
+   * @param part a skolem function of type T -> (Table T) created uniquely for n
+   * by defineSkolemPartFunction function below
+   * @return an inference that represents:
+   * (=>
+   *   (and
+   *     (bag.member B skolem)
+   *     (bag.member x B)
+   *     (bag.member y A)
+   *     (distinct x y)
+   *     (= ((_ tuple.project n1 ... nk) x)
+   *        ((_ tuple.project n1 ... nk) y))
+   *   )
+   *   (and
+   *     (= (bag.count y B) (bag.count y A))
+   *     (= (part x) (part y))
+   *     (= (part x) B)
+   *   )
+   * )
+   * where skolem is a variable equals ((_ table.group n1 ... nk) A).
+   */
+  InferInfo groupSamePart(Node n, Node B, Node x, Node y, Node part);
+  /**
+   * @param n has form ((_ table.group n1 ... nk) A) where A has type (Table T)
+   * @return a function of type T -> (Table T) that maps elements T to a part in
+   * the partition
+   */
+  Node defineSkolemPartFunction(Node n);
+
  private:
   /**
    * generate skolem variable for node n and add pending lemma for the equality
index 5fe8bae13e08a1e8fbc65ff913c7628517492f5a..ef4f5124686044e954c6bc02b3cb6f1e2c892418 100644 (file)
@@ -27,7 +27,8 @@ namespace cvc5::internal {
 namespace theory {
 namespace bags {
 
-SolverState::SolverState(Env& env, Valuation val) : TheoryState(env, val)
+SolverState::SolverState(Env& env, Valuation val)
+    : TheoryState(env, val), d_partElementSkolems(env.getUserContext())
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
@@ -38,6 +39,12 @@ void SolverState::registerBag(TNode n)
 {
   Assert(n.getType().isBag());
   d_bags.insert(n);
+  if (n.getKind() == TABLE_GROUP)
+  {
+    std::shared_ptr<context::CDHashSet<Node>> set =
+        std::make_shared<context::CDHashSet<Node>>(d_env.getUserContext());
+    d_partElementSkolems[n] = set;
+  }
 }
 
 void SolverState::registerCountTerm(Node bag, Node element, Node skolem)
@@ -129,6 +136,20 @@ void SolverState::collectDisequalBagTerms()
 
 const std::map<Node, Node>& SolverState::getDisequalBagTerms() { return d_deq; }
 
+void SolverState::registerPartElementSkolem(Node group, Node skolemElement)
+{
+  Assert(group.getKind() == TABLE_GROUP);
+  Assert(skolemElement.getType() == group[0].getType().getBagElementType());
+  d_partElementSkolems[group].get()->insert(skolemElement);
+}
+
+std::shared_ptr<context::CDHashSet<Node>> SolverState::getPartElementSkolems(
+    Node n)
+{
+  Assert(n.getKind() == TABLE_GROUP);
+  return d_partElementSkolems[n];
+}
+
 void SolverState::reset()
 {
   d_bagElements.clear();
index c94728d0025ceb1ef66d4978e1c912d8b2b714d0..8bb0e3f15158e19b794bb09922cb870f4a1058d9 100644 (file)
@@ -20,6 +20,8 @@
 
 #include <map>
 
+#include "context/cdhashmap.h"
+#include "context/cdhashset.h"
 #include "theory/theory_state.h"
 
 namespace cvc5::internal {
@@ -88,6 +90,10 @@ class SolverState : public TheoryState
    * skolems that witness the negation of these equalities
    */
   const std::map<Node, Node>& getDisequalBagTerms();
+  /** register skolem element generated by grup count rule */
+  void registerPartElementSkolem(Node group, Node skolemElement);
+  /** return skolem elements generated by group part count rule. */
+  std::shared_ptr<context::CDHashSet<Node>> getPartElementSkolems(Node n);
   /**
    * return a list of bag elements and their skolem counts
    */
@@ -123,6 +129,16 @@ class SolverState : public TheoryState
   std::map<Node, Node> d_deq;
   /** a map from card terms to their skolem variables */
   std::map<Node, Node> d_cardTerms;
+
+  /**
+   * A cache that stores skolem elements generated by inference rule
+   * InferenceId::TABLES_GROUP_PART_COUNT.
+   * It maps table.group nodes to generated skolem elements.
+   * The skolem elements need to persist during checking, and should only change
+   * when the user context changes.
+   */
+  context::CDHashMap<Node, std::shared_ptr<context::CDHashSet<Node>>>
+      d_partElementSkolems;
 }; /* class SolverState */
 
 }  // namespace bags
index 854070b0cdc55857b86fec89db2f17f2f7bf9e1a..ad3dfbeef9f3cea3876d1d6b7173d34995f75591 100644 (file)
@@ -143,6 +143,14 @@ const char* toString(InferenceId i)
     case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN";
     case InferenceId::TABLES_JOIN_UP: return "TABLES_JOIN_UP";
     case InferenceId::TABLES_JOIN_DOWN: return "TABLES_JOIN_DOWN";
+    case InferenceId::TABLES_GROUP_NOT_EMPTY: return "TABLES_GROUP_NOT_EMPTY";
+    case InferenceId::TABLES_GROUP_UP1: return "TABLES_GROUP_UP1";
+    case InferenceId::TABLES_GROUP_UP2: return "TABLES_GROUP_UP2";
+    case InferenceId::TABLES_GROUP_DOWN: return "TABLES_GROUP_DOWN";
+    case InferenceId::TABLES_GROUP_PART_COUNT: return "TABLES_GROUP_PART_COUNT";
+    case InferenceId::TABLES_GROUP_SAME_PROJECTION:
+      return "TABLES_GROUP_SAME_PROJECTION";
+    case InferenceId::TABLES_GROUP_SAME_PART: return "TABLES_GROUP_SAME_PART";
 
     case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
     case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
index 723345a17c00da14cc4d2a54ad53f6c1bfdffe42..4f5e3d15605bb63fcbf7a1f32727e9cd87f7015b 100644 (file)
@@ -210,6 +210,13 @@ enum class InferenceId
   TABLES_PRODUCT_DOWN,
   TABLES_JOIN_UP,
   TABLES_JOIN_DOWN,
+  TABLES_GROUP_NOT_EMPTY,
+  TABLES_GROUP_UP1,
+  TABLES_GROUP_UP2,
+  TABLES_GROUP_DOWN,
+  TABLES_GROUP_PART_COUNT,
+  TABLES_GROUP_SAME_PROJECTION,
+  TABLES_GROUP_SAME_PART,
   // ---------------------------------- end bags theory
 
   // ---------------------------------- bitvector theory
index 1ced25e3cc895b23766bdcbd5268b272abaacd95..ee6df9221f6608af85eda29a42126d8649ea9afb 100644 (file)
@@ -1841,6 +1841,10 @@ set(regress_1_tests
   regress1/bags/subbag2.smt2
   regress1/bags/table_aggregate1.smt2
   regress1/bags/table_group1.smt2
+  regress1/bags/table_group2.smt2
+  regress1/bags/table_group3.smt2
+  regress1/bags/table_group4.smt2
+  regress1/bags/table_group5.smt2
   regress1/bags/table_join1.smt2
   regress1/bags/table_join2.smt2
   regress1/bags/table_join3.smt2
diff --git a/test/regress/cli/regress1/bags/table_group2.smt2 b/test/regress/cli/regress1/bags/table_group2.smt2
new file mode 100644 (file)
index 0000000..728988e
--- /dev/null
@@ -0,0 +1,16 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun data () (Table String String String))
+(declare-fun part1 () (Table String String String))
+(declare-fun part2 () (Table String String String))
+(declare-fun partition () (Bag (Table String String String)))
+
+(assert (distinct data part1 part2 (as bag.empty (Table String String String))))
+
+(assert (= partition ((_ table.group 0) data)))
+(assert (bag.member part1 partition))
+(assert (bag.member part2 partition))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_group3.smt2 b/test/regress/cli/regress1/bags/table_group3.smt2
new file mode 100644 (file)
index 0000000..68e4e8c
--- /dev/null
@@ -0,0 +1,21 @@
+; DISABLE-TESTER: lfsc
+; Disabled since table.group is not supported in LFSC
+(set-logic HO_ALL)
+
+(set-info :status unsat)
+
+(declare-fun data () (Table String String String))
+(declare-fun part1 () (Table String String String))
+(declare-fun part2 () (Table String String String))
+(declare-fun partition () (Bag (Table String String String)))
+
+(assert (distinct data part1 part2 (as bag.empty (Table String String String))))
+
+(assert (= partition ((_ table.group 0 1) data)))
+(assert (bag.member part1 partition))
+(assert (bag.member part2 partition))
+
+(assert (bag.member (tuple "A" "X" "0") part1))
+(assert (bag.member (tuple "A" "X" "1") part2))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_group4.smt2 b/test/regress/cli/regress1/bags/table_group4.smt2
new file mode 100644 (file)
index 0000000..01a114a
--- /dev/null
@@ -0,0 +1,13 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun data () (Table String String String))
+(declare-fun groupBy0 () (Bag (Table String String String)))
+(declare-fun groupBy1 () (Bag (Table String String String)))
+
+(assert (= groupBy0 ((_ table.group 0) data)))
+(assert (= groupBy1 ((_ table.group 1) data)))
+(assert (distinct groupBy0 groupBy1))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/bags/table_group5.smt2 b/test/regress/cli/regress1/bags/table_group5.smt2
new file mode 100644 (file)
index 0000000..e053a46
--- /dev/null
@@ -0,0 +1,15 @@
+; DISABLE-TESTER: lfsc
+; Disabled since table.group is not supported in LFSC
+(set-logic HO_ALL)
+
+(set-info :status unsat)
+
+(declare-fun data () (Table String String String))
+(declare-fun groupBy01 () (Bag (Table String String String)))
+(declare-fun groupBy10 () (Bag (Table String String String)))
+
+(assert (= groupBy01 ((_ table.group 0 1) data)))
+(assert (= groupBy10 ((_ table.group 1 0) data)))
+(assert (distinct groupBy01 groupBy10))
+
+(check-sat)