Add rel.project operator to sets (#8929)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Wed, 6 Jul 2022 21:02:23 +0000 (16:02 -0500)
committerGitHub <noreply@github.com>
Wed, 6 Jul 2022 21:02:23 +0000 (21:02 +0000)
25 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_reduction.cpp
src/theory/bags/bag_reduction.h
src/theory/bags/bags_utils.cpp
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/sets/kinds
src/theory/sets/set_reduction.cpp
src/theory/sets/set_reduction.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets_rewriter.cpp
src/theory/sets/theory_sets_rewriter.h
src/theory/sets/theory_sets_type_rules.cpp
src/theory/sets/theory_sets_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/bags/table_project2.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_project1.smt2 [new file with mode: 0644]
test/regress/cli/regress1/sets/relation_project2.smt2 [new file with mode: 0644]
test/unit/api/cpp/op_black.cpp
test/unit/api/java/OpTest.java
test/unit/api/python/test_op.py

index 692eff021ca00438a64b4d458e72cb13361fa290..b72f85c44a15a7d7909095573852de111c276edc 100644 (file)
@@ -311,6 +311,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(RELATION_IDEN, internal::Kind::RELATION_IDEN),
         KIND_ENUM(RELATION_GROUP, internal::Kind::RELATION_GROUP),
         KIND_ENUM(RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE),
+        KIND_ENUM(RELATION_PROJECT, internal::Kind::RELATION_PROJECT),
         /* Bags ------------------------------------------------------------- */
         KIND_ENUM(BAG_UNION_MAX, internal::Kind::BAG_UNION_MAX),
         KIND_ENUM(BAG_UNION_DISJOINT, internal::Kind::BAG_UNION_DISJOINT),
@@ -638,6 +639,8 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::RELATION_GROUP, RELATION_GROUP},
         {internal::Kind::RELATION_AGGREGATE_OP, RELATION_AGGREGATE},
         {internal::Kind::RELATION_AGGREGATE, RELATION_AGGREGATE},
+        {internal::Kind::RELATION_PROJECT_OP, RELATION_PROJECT},
+        {internal::Kind::RELATION_PROJECT, RELATION_PROJECT},
         /* Bags ------------------------------------------------------------ */
         {internal::Kind::BAG_UNION_MAX, BAG_UNION_MAX},
         {internal::Kind::BAG_UNION_DISJOINT, BAG_UNION_DISJOINT},
@@ -779,6 +782,7 @@ const static std::unordered_map<Kind, internal::Kind> s_op_kinds{
     {TUPLE_PROJECT, internal::Kind::TUPLE_PROJECT_OP},
     {RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE_OP},
     {RELATION_GROUP, internal::Kind::RELATION_GROUP_OP},
+    {RELATION_PROJECT, internal::Kind::RELATION_PROJECT_OP},
     {TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP},
     {TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP},
     {TABLE_JOIN, internal::Kind::TABLE_JOIN_OP},
@@ -1959,6 +1963,7 @@ size_t Op::getNumIndicesHelper() const
     case TUPLE_PROJECT:
     case RELATION_AGGREGATE:
     case RELATION_GROUP:
+    case RELATION_PROJECT:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
@@ -2121,6 +2126,7 @@ Term Op::getIndexHelper(size_t index) const
     case TUPLE_PROJECT:
     case RELATION_AGGREGATE:
     case RELATION_GROUP:
+    case RELATION_PROJECT:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
@@ -6164,6 +6170,7 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
     case TUPLE_PROJECT:
     case RELATION_AGGREGATE:
     case RELATION_GROUP:
+    case RELATION_PROJECT:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
     case TABLE_JOIN:
index 4d4234f03f1913a65d313a4e1be3641219bb59ba..0220d48afd1918f158e31516fedf5e02ad0bde5b 100644 (file)
@@ -3398,6 +3398,26 @@ enum Kind : int32_t
    * \endrst
    */
   RELATION_AGGREGATE,
+  /**
+   * Relation projection operator extends tuple projection operator to sets.
+   *
+   * - Arity: ``1``
+   *   - ``1:`` Term of relation Sort
+   *
+   * - Indices: ``n``
+   *   - ``1..n:`` Indices of the projection
+   *
+   * - Create Term of this Kind with:
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * - Create Op of this kind with:
+   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  RELATION_PROJECT,
 
   /* Bags ------------------------------------------------------------------ */
 
index 96421e6ce00731d6bc5f4d68311674689a88e8de..ebeabe3f6fb4d921e9d647eaf3a635c1340ada32 100644 (file)
@@ -1441,6 +1441,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK RELATION_PROJECT_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_PROJECT, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1624,6 +1630,13 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::RELATION_AGGREGATE;
         p.d_op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, numerals);
       }
+     | RELATION_PROJECT_TOK nonemptyNumeralList[numerals]
+      {
+       // we adopt a special syntax (_ rel.project i_1 ... i_n) where
+       // i_1, ..., i_n are numerals
+       p.d_kind = cvc5::RELATION_PROJECT;
+       p.d_op = SOLVER->mkOp(cvc5::RELATION_PROJECT, numerals);
+      }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2238,6 +2251,7 @@ TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }
 TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group';
 RELATION_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.group';
 RELATION_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.aggr';
+RELATION_PROJECT_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.project';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index 5f4ae42e447ae1b35d5e064fad02e4718ad04c94..fb08a0565b69f9880e97a74a57d3dba868b9c79f 100644 (file)
@@ -258,7 +258,9 @@ void Smt2::addSepOperators() {
 void Smt2::addCoreSymbols()
 {
   defineType("Bool", d_solver->getBooleanSort(), true);
-  defineType("Table", d_solver->mkBagSort(d_solver->mkTupleSort({})), true);
+  Sort tupleSort = d_solver->mkTupleSort({});
+  defineType("Relation", d_solver->mkSetSort(tupleSort), true);
+  defineType("Table", d_solver->mkBagSort(tupleSort), true);
   defineVar("true", d_solver->mkTrue(), true);
   defineVar("false", d_solver->mkFalse(), true);
   addOperator(cvc5::AND, "and");
@@ -1131,7 +1133,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
   else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
            || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN
            || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP
-           || p.d_kind == cvc5::RELATION_AGGREGATE)
+           || p.d_kind == cvc5::RELATION_AGGREGATE
+           || p.d_kind == cvc5::RELATION_PROJECT)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index 3ffe7641ceb301291d49d64d62ca2b0982f5c67a..ccb93ff3c4ea172d5862c1c463cc366425d44f30 100644 (file)
@@ -851,6 +851,21 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::RELATION_PROJECT:
+  {
+    ProjectOp op = n.getOperator().getConst<ProjectOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (rel.project A)
+      out << "rel.project " << n[0] << ")";
+    }
+    else
+    {
+      // e.g. ((_ rel.project 2 4 4) A)
+      out << "(_ rel.project" << op << ") " << n[0] << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -1192,6 +1207,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::RELATION_JOIN_IMAGE: return "rel.join_image";
   case kind::RELATION_GROUP: return "rel.group";
   case kind::RELATION_AGGREGATE: return "rel.aggr";
+  case kind::RELATION_PROJECT: return "rel.project";
 
   // bag theory
   case kind::BAG_TYPE: return "Bag";
index b2713f882b3a8efefc77cd05c2585644bb536b1b..8099c134286b334d1fa6ea6a183cd90f33815aaa 100644 (file)
@@ -226,6 +226,21 @@ Node BagReduction::reduceAggregateOperator(Node node)
   return map;
 }
 
+Node BagReduction::reduceProjectOperator(Node n)
+{
+  Assert(n.getKind() == TABLE_PROJECT);
+  NodeManager* nm = NodeManager::currentNM();
+  Node A = n[0];
+  TypeNode elementType = A.getType().getBagElementType();
+  ProjectOp projectOp = n.getOperator().getConst<ProjectOp>();
+  Node op = nm->mkConst(TUPLE_PROJECT_OP, projectOp);
+  Node t = nm->mkBoundVar("t", elementType);
+  Node projection = nm->mkNode(TUPLE_PROJECT, op, t);
+  Node lambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, t), projection);
+  Node setMap = nm->mkNode(BAG_MAP, lambda, A);
+  return setMap;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5::internal
index cf391a120a3b640e4dce5c79112258dd63016937..6c7122afe14ca85144c9d78eaca7b0b5164888a9 100644 (file)
@@ -105,6 +105,12 @@ class BagReduction
    *   ((_ table.group n1 ... nk) A))
    */
   static Node reduceAggregateOperator(Node node);
+  /**
+   * @param n has the form ((table.project n1 ... nk) A) where A has type
+   *          (Bag T)
+   * @return (bag.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A)
+   */
+  static Node reduceProjectOperator(Node n);
 };
 
 }  // namespace bags
index 63112bee64214230e4d9f69e5faef6895ba59835..183f96a80997957be097d114fd66376a1dd31ab1 100644 (file)
@@ -1056,30 +1056,8 @@ Node BagsUtils::evaluateGroup(TNode n)
 Node BagsUtils::evaluateTableProject(TNode n)
 {
   Assert(n.getKind() == TABLE_PROJECT);
-  // Examples
-  // --------
-  // - ((_ table.project 1) (bag (tuple true "a") 4)) = (bag (tuple "a") 4)
-  // - (table.project (bag.union_disjoint
-  //                    (bag (tuple "a") 4)
-  //                    (bag (tuple "b") 3))) = (bag tuple 7)
-
-  Node A = n[0];
-
-  std::map<Node, Rational> elementsA = BagsUtils::getBagElements(A);
-
-  std::map<Node, Rational> elements;
-  std::vector<uint32_t> indices =
-      n.getOperator().getConst<ProjectOp>().getIndices();
-
-  for (const auto& [a, countA] : elementsA)
-  {
-    Node element = TupleUtils::getTupleProjection(indices, a);
-    // multiple elements could be projected to the same tuple.
-    // Zero is the default value for Rational values.
-    elements[element] += countA;
-  }
-
-  Node ret = BagsUtils::constructConstantBagFromElements(n.getType(), elements);
+  Node bagMap = BagReduction::reduceProjectOperator(n);
+  Node ret = evaluateBagMap(bagMap);
   return ret;
 }
 
index 1581a091bddabbb5ee699facc0ced0dd56447894..e150199eedfae0e3be1e0d9da361eacf4465a651 100644 (file)
@@ -112,6 +112,12 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
       Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl;
       return TrustNode::mkTrustRewrite(atom, ret, nullptr);
     }
+    case kind::TABLE_PROJECT:
+    {
+      Node ret = BagReduction::reduceProjectOperator(atom);
+      Trace("bags::ppr") << "reduce(" << atom << ") = " << ret << std::endl;
+      return TrustNode::mkTrustRewrite(atom, ret, nullptr);
+    }
     default: return TrustNode::null();
   }
 }
@@ -465,7 +471,6 @@ void TheoryBags::preRegisterTerm(TNode n)
     case BAG_TO_SET:
     case BAG_IS_SINGLETON:
     case BAG_PARTITION:
-    case TABLE_PROJECT:
     {
       std::stringstream ss;
       ss << "Term of kind " << n.getKind() << " is not supported yet";
index 86fa4282297902e3c3dac23e8ad3191409e4443e..a498c944330d3580f69536d15eae64d3690142a5 100644 (file)
@@ -181,13 +181,13 @@ struct TableProductTypeRule
 /**
  * Table project is indexed by a list of indices (n_1, ..., n_m). It ensures
  * that the argument is a bag of tuples whose arity k is greater than each n_i
- * for i = 1, ..., m. If the argument is of type (Bag (Tuple T_1 ... T_k)), then
- * the returned type is (Bag (Tuple T_{n_1} ... T_{n_m})).
+ * for i = 1, ..., m. If the argument is of type (Table T_1 ... T_k), then
+ * the returned type is (Table T_{n_1} ... T_{n_m}).
  */
 struct TableProjectTypeRule
 {
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
-}; /* struct BagFoldTypeRule */
+}; /* struct TableProjectTypeRule */
 
 /**
  * Table aggregate operator is indexed by a list of indices (n_1, ..., n_k).
index 4c5899e96d3fdc909d21942e66c0991b6d1c581b..07db0da7ebcca80022d3bf850ee5c86b9a450976 100644 (file)
@@ -109,6 +109,16 @@ constant RELATION_AGGREGATE_OP \
 
 parameterized RELATION_AGGREGATE RELATION_AGGREGATE_OP 3 "relation aggregate"
 
+# rel.project operator extends datatypes tuple_project operator to a set of tuples
+constant RELATION_PROJECT_OP \
+  class \
+  ProjectOp+ \
+  ::cvc5::internal::ProjectOpHashFunction \
+  "theory/datatypes/project_op.h" \
+  "operator for RELATION_PROJECT; payload is an instance of the cvc5::internal::ProjectOp class"
+
+parameterized RELATION_PROJECT RELATION_PROJECT_OP 1 "relation projection"
+
 operator RELATION_JOIN                    2  "relation join"
 operator RELATION_PRODUCT         2  "relation cartesian product"
 operator RELATION_TRANSPOSE    1  "relation transpose"
@@ -136,14 +146,16 @@ typerule SET_FOLD           ::cvc5::internal::theory::sets::SetFoldTypeRule
 
 typerule RELATION_JOIN                         ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
 typerule RELATION_PRODUCT              ::cvc5::internal::theory::sets::RelBinaryOperatorTypeRule
-typerule RELATION_TRANSPOSE            ::cvc5::internal::theory::sets::RelTransposeTypeRule
+typerule RELATION_TRANSPOSE    ::cvc5::internal::theory::sets::RelTransposeTypeRule
 typerule RELATION_TCLOSURE         ::cvc5::internal::theory::sets::RelTransClosureTypeRule
-typerule RELATION_JOIN_IMAGE       ::cvc5::internal::theory::sets::JoinImageTypeRule
+typerule RELATION_JOIN_IMAGE   ::cvc5::internal::theory::sets::JoinImageTypeRule
 typerule RELATION_IDEN                         ::cvc5::internal::theory::sets::RelIdenTypeRule
 typerule RELATION_GROUP_OP      "SimpleTypeRule<RBuiltinOperator>"
 typerule RELATION_GROUP         ::cvc5::internal::theory::sets::RelationGroupTypeRule
 typerule RELATION_AGGREGATE_OP  "SimpleTypeRule<RBuiltinOperator>"
 typerule RELATION_AGGREGATE     ::cvc5::internal::theory::sets::RelationAggregateTypeRule
+typerule RELATION_PROJECT_OP    "SimpleTypeRule<RBuiltinOperator>"
+typerule RELATION_PROJECT       ::cvc5::internal::theory::sets::RelationProjectTypeRule
 
 construle SET_UNION         ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule
 construle SET_SINGLETON     ::cvc5::internal::theory::sets::SingletonTypeRule
index f8c4ed4ae8f9441b0dadee0ace38d5daffc6d8ae..988374d6a7a62e0c5afe4cdbf3f9a160724257c2 100644 (file)
@@ -144,6 +144,21 @@ Node SetReduction::reduceAggregateOperator(Node node)
   return map;
 }
 
+Node SetReduction::reduceProjectOperator(Node n)
+{
+  Assert(n.getKind() == RELATION_PROJECT);
+  NodeManager* nm = NodeManager::currentNM();
+  Node A = n[0];
+  TypeNode elementType = A.getType().getSetElementType();
+  ProjectOp projectOp = n.getOperator().getConst<ProjectOp>();
+  Node op = nm->mkConst(TUPLE_PROJECT_OP, projectOp);
+  Node t = nm->mkBoundVar("t", elementType);
+  Node projection = nm->mkNode(TUPLE_PROJECT, op, t);
+  Node lambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, t), projection);
+  Node setMap = nm->mkNode(SET_MAP, lambda, A);
+  return setMap;
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index 3b0c4526e4ffa8431927528717c8456f4aeaa9df..38a5740c112dce6c668344c07ca5333ae038fcbb 100644 (file)
@@ -74,6 +74,11 @@ class SetReduction
    *   ((_ rel.group n1 ... nk) A))
    */
   static Node reduceAggregateOperator(Node node);
+  /**
+   * @param n has the form ((rel.project n1 ... nk) A) where A has type (Set T)
+   * @return (set.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A)
+   */
+  static Node reduceProjectOperator(Node n);
 };
 
 }  // namespace sets
index b09080186d0417d6c288104d7b0bb4d4d375344c..b1ae7a8eb7a7d97f1288f4062e03580dc186546c 100644 (file)
@@ -168,11 +168,16 @@ TrustNode TheorySets::ppRewrite(TNode n, std::vector<SkolemLemma>& lems)
     d_im.lemma(andNode, InferenceId::BAGS_FOLD);
     return TrustNode::mkTrustRewrite(n, ret, nullptr);
   }
-  if (nk == TABLE_AGGREGATE)
+  if (nk == RELATION_AGGREGATE)
   {
     Node ret = SetReduction::reduceAggregateOperator(n);
     return TrustNode::mkTrustRewrite(ret, ret, nullptr);
   }
+  if (nk == RELATION_PROJECT)
+  {
+    Node ret = SetReduction::reduceProjectOperator(n);
+    return TrustNode::mkTrustRewrite(ret, ret, nullptr);
+  }
   return d_internal->ppRewrite(n, lems);
 }
 
index e9b3b31b05a2e68e17e004be232753e8046e8580..cbb65b8769846fbc3c45dab81d82a116ebf32423 100644 (file)
@@ -22,6 +22,7 @@
 #include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
 #include "theory/sets/rels_utils.h"
+#include "theory/sets/set_reduction.h"
 #include "util/rational.h"
 
 using namespace cvc5::internal::kind;
@@ -591,6 +592,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
 
   case RELATION_GROUP: return postRewriteGroup(node);
   case RELATION_AGGREGATE: return postRewriteAggregate(node);
+  case RELATION_PROJECT: return postRewriteProject(node);
   default: break;
   }
 
@@ -781,6 +783,17 @@ RewriteResponse TheorySetsRewriter::postRewriteAggregate(TNode n)
   return RewriteResponse(REWRITE_DONE, n);
 }
 
+RewriteResponse TheorySetsRewriter::postRewriteProject(TNode n)
+{
+  Assert(n.getKind() == RELATION_PROJECT);
+  Node ret = SetReduction::reduceProjectOperator(n);
+  if (ret != n)
+  {
+    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+  }
+  return RewriteResponse(REWRITE_DONE, n);
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index 2d321f11c0c544facecd66fc993b3f99dc8285a4..b98e3a531ee0512c00b93703455a14c3a5b0c46c 100644 (file)
@@ -118,6 +118,11 @@ class TheorySetsRewriter : public TheoryRewriter
    * @return the aggregation result.
    */
   RewriteResponse postRewriteAggregate(TNode n);
+  /**
+   * If A has type (Set T), then rewrite ((rel.project n1 ... nk) A) as
+   * (set.map (lambda ((t T)) ((_ tuple.project n1 ... nk) t)) A)
+   */
+  RewriteResponse postRewriteProject(TNode n);
 }; /* class TheorySetsRewriter */
 
 }  // namespace sets
index f9dd7d390e4aabf032449f4a80eaae4343ac39b7..83f56b90a67bf53acb53e7ce9d70cb647594dea1 100644 (file)
 
 #include "theory/sets/theory_sets_type_rules.h"
 
-#include <climits>
 #include <sstream>
 
+#include "expr/dtype.h"
+#include "expr/dtype_cons.h"
 #include "theory/sets/normal_form.h"
 #include "util/cardinality.h"
 #include "theory/datatypes/project_op.h"
@@ -632,7 +633,7 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm,
     if (!setType.isSet())
     {
       std::stringstream ss;
-      ss << "RELATION_PROJECT operator expects a table. Found '" << n[2]
+      ss << "RELATION_AGGREGATE operator expects a set. Found '" << n[2]
          << "' of type '" << setType << "'.";
       throw TypeCheckingExceptionPrivate(n, ss.str());
     }
@@ -641,7 +642,7 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm,
     if (!tupleType.isTuple())
     {
       std::stringstream ss;
-      ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+      ss << "RELATION_AGGREGATE operator expects a relation. Found '" << n[2]
          << "' of type '" << setType << "'.";
       throw TypeCheckingExceptionPrivate(n, ss.str());
     }
@@ -680,6 +681,63 @@ TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm,
   return nm->mkSetType(functionType.getRangeType());
 }
 
+TypeNode RelationProjectTypeRule::computeType(NodeManager* nm,
+                                              TNode n,
+                                              bool check)
+{
+  Assert(n.getKind() == kind::RELATION_PROJECT && n.hasOperator()
+         && n.getOperator().getKind() == kind::RELATION_PROJECT_OP);
+  ProjectOp op = n.getOperator().getConst<ProjectOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+  TypeNode setType = n[0].getType(check);
+  if (check)
+  {
+    if (n.getNumChildren() != 1)
+    {
+      std::stringstream ss;
+      ss << "operands in term " << n << " are " << n.getNumChildren()
+         << ", but RELATION_PROJECT expects 1 operand.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    if (!setType.isSet())
+    {
+      std::stringstream ss;
+      ss << "RELATION_PROJECT operator expects a set. Found '" << n[0]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = setType.getSetElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "RELATION_PROJECT operator expects a relation. Found '" << n[0]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    // make sure all indices are less than the length of the tuple type
+    DType dType = tupleType.getDType();
+    DTypeConstructor constructor = dType[0];
+    size_t numArgs = constructor.getNumArgs();
+    for (uint32_t index : indices)
+    {
+      std::stringstream ss;
+      if (index >= numArgs)
+      {
+        ss << "Index " << index << " in term " << n << " is >= " << numArgs
+           << " which is the number of columns in " << n[0] << ".";
+        throw TypeCheckingExceptionPrivate(n, ss.str());
+      }
+    }
+  }
+  TypeNode tupleType = setType.getSetElementType();
+  TypeNode retTupleType =
+      TupleUtils::getTupleProjectionType(indices, tupleType);
+  return nm->mkSetType(retTupleType);
+}
+
 Cardinality SetsProperties::computeCardinality(TypeNode type)
 {
   Assert(type.getKind() == kind::SET_TYPE);
index ed973669ebb0f813f842894fd1e640de0b3e6cf3..9755961c8d658ff51988b3a36cd7c88b43c9aa76 100644 (file)
@@ -223,6 +223,17 @@ struct RelationGroupTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct RelationGroupTypeRule */
 
+/**
+ * Relation project is indexed by a list of indices (n_1, ..., n_m). It ensures
+ * that the argument is a set of tuples whose arity k is greater than each n_i
+ * for i = 1, ..., m. If the argument is of type (Relation T_1 ... T_k), then
+ * the returned type is (Relation T_{n_1} ... T_{n_m}).
+ */
+struct RelationProjectTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct RelationProjectTypeRule */
+
 /**
  * Relation aggregate operator is indexed by a list of indices (n_1, ..., n_k).
  * It ensures that it has 3 arguments:
index bfb23057f48c2a57c5a8f315419b5a88efafece8..c8efaad52e4974326e052ffa75cc22d2a8a1ccf0 100644 (file)
@@ -1849,6 +1849,7 @@ set(regress_1_tests
   regress1/bags/table_join2.smt2
   regress1/bags/table_join3.smt2
   regress1/bags/table_project1.smt2
+  regress1/bags/table_project2.smt2
   regress1/bags/union_disjoint.smt2
   regress1/bags/union_max1.smt2
   regress1/bags/union_max2.smt2
@@ -2511,6 +2512,8 @@ set(regress_1_tests
   regress1/sets/relation_group3.smt2
   regress1/sets/relation_group4.smt2
   regress1/sets/relation_group5.smt2
+  regress1/sets/relation_project1.smt2
+  regress1/sets/relation_project2.smt2
   regress1/sets/remove_check_free_31_6.smt2
   regress1/sets/sets-disequal.smt2
   regress1/sets/sets-tuple-poly.cvc.smt2
diff --git a/test/regress/cli/regress1/bags/table_project2.smt2 b/test/regress/cli/regress1/bags/table_project2.smt2
new file mode 100644 (file)
index 0000000..1af4e1e
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+(declare-fun A () (Table String String))
+(declare-fun B () (Table String String))
+
+(assert (= B ((_ table.project 1 0) A)))
+(assert (bag.member (tuple "y" "x") B))
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_project1.smt2 b/test/regress/cli/regress1/sets/relation_project1.smt2
new file mode 100644 (file)
index 0000000..7f1dea3
--- /dev/null
@@ -0,0 +1,26 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(declare-fun A () (Relation String Int String Bool))
+(declare-fun B () (Relation Int Bool String String))
+(declare-fun C () (Relation String String))
+(declare-fun D () Relation)
+
+(assert
+ (= A
+    (set.union
+     (set.singleton (tuple "x" 0 "y" false))
+     (set.singleton (tuple "x" 1 "z" true)))))
+
+
+; (set.union (set.singleton (tuple 0 false "x" "y")) (set.singleton (tuple 1 true "x" "z")))
+(assert (= B ((_ rel.project 1 3 0 2) A)))
+
+; (set.singleton (tuple "x" "x"))
+(assert (= C ((_ rel.project 0 0) A)))
+
+; (set.singleton tuple)
+(assert (= D (rel.project A)))
+
+(check-sat)
diff --git a/test/regress/cli/regress1/sets/relation_project2.smt2 b/test/regress/cli/regress1/sets/relation_project2.smt2
new file mode 100644 (file)
index 0000000..ca39f22
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :uf-lazy-ll true)
+
+(declare-fun A () (Relation String String))
+(declare-fun B () (Relation String String))
+
+(assert (= B ((_ rel.project 1 0) A)))
+(assert (set.member (tuple "y" "x") B))
+(check-sat)
index 498c18c001180ac1c1d74e3162b2ab72a0964d9d..72e05acd973ef1e7d02390d87398a40d2a728060 100644 (file)
@@ -103,6 +103,9 @@ TEST_F(TestApiBlackOp, getNumIndices)
   Op tupleProject = d_solver.mkOp(TUPLE_PROJECT, indices);
   ASSERT_EQ(indices.size(), tupleProject.getNumIndices());
 
+  Op relationProject = d_solver.mkOp(RELATION_PROJECT, indices);
+  ASSERT_EQ(indices.size(), relationProject.getNumIndices());
+
   Op tableProject = d_solver.mkOp(TABLE_PROJECT, indices);
   ASSERT_EQ(indices.size(), tableProject.getNumIndices());
 }
index c14518a2f68c4a4b7a3c8901125cf551d28dfcbf..688c0aaea3c745c8f4354a2c8b1a93de7e0738d0 100644 (file)
@@ -120,6 +120,9 @@ class OpTest
     Op tupleProject = d_solver.mkOp(TUPLE_PROJECT, indices);
     assertEquals(6, tupleProject.getNumIndices());
 
+    Op relationProject = d_solver.mkOp(RELATION_PROJECT, indices);
+    assertEquals(6, relationProject.getNumIndices());
+
     Op tableProject = d_solver.mkOp(TABLE_PROJECT, indices);
     assertEquals(6, tableProject.getNumIndices());
   }
index 4ba60792654e63849d37825badba15fb129ec2f0..5959a53c96422747d83e2c712f3e17405c6a4349 100644 (file)
@@ -105,6 +105,9 @@ def test_get_num_indices(solver):
     tuple_project_op = solver.mkOp(Kind.TUPLE_PROJECT, *indices)
     assert len(indices) == tuple_project_op.getNumIndices()
 
+    relation_project_op = solver.mkOp(Kind.RELATION_PROJECT, *indices)
+    assert len(indices) == relation_project_op.getNumIndices()
+
     table_project_op = solver.mkOp(Kind.TABLE_PROJECT, *indices)
     assert len(indices) == table_project_op.getNumIndices()