Add set.aggr operator to sets (#8878)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 30 Jun 2022 02:03:36 +0000 (21:03 -0500)
committerGitHub <noreply@github.com>
Thu, 30 Jun 2022 02:03:36 +0000 (02:03 +0000)
This PR depends on #8876

17 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/Smt2.g
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/sets/kinds
src/theory/sets/rels_utils.cpp
src/theory/sets/rels_utils.h
src/theory/sets/set_reduction.cpp
src/theory/sets/set_reduction.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets_rewriter.cpp
src/theory/sets/theory_sets_rewriter.h
src/theory/sets/theory_sets_type_rules.cpp
src/theory/sets/theory_sets_type_rules.h
test/regress/cli/CMakeLists.txt
test/regress/cli/regress1/sets/relation_aggregate1.smt2 [new file with mode: 0644]

index 1c45d2d4d0735cb45e08f77b680b4e0404c760b3..692eff021ca00438a64b4d458e72cb13361fa290 100644 (file)
@@ -310,6 +310,7 @@ const static std::unordered_map<Kind, std::pair<internal::Kind, std::string>>
         KIND_ENUM(RELATION_JOIN_IMAGE, internal::Kind::RELATION_JOIN_IMAGE),
         KIND_ENUM(RELATION_IDEN, internal::Kind::RELATION_IDEN),
         KIND_ENUM(RELATION_GROUP, internal::Kind::RELATION_GROUP),
+        KIND_ENUM(RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE),
         /* Bags ------------------------------------------------------------- */
         KIND_ENUM(BAG_UNION_MAX, internal::Kind::BAG_UNION_MAX),
         KIND_ENUM(BAG_UNION_DISJOINT, internal::Kind::BAG_UNION_DISJOINT),
@@ -635,6 +636,8 @@ const static std::unordered_map<internal::Kind,
         {internal::Kind::RELATION_JOIN_IMAGE, RELATION_JOIN_IMAGE},
         {internal::Kind::RELATION_IDEN, RELATION_IDEN},
         {internal::Kind::RELATION_GROUP, RELATION_GROUP},
+        {internal::Kind::RELATION_AGGREGATE_OP, RELATION_AGGREGATE},
+        {internal::Kind::RELATION_AGGREGATE, RELATION_AGGREGATE},
         /* Bags ------------------------------------------------------------ */
         {internal::Kind::BAG_UNION_MAX, BAG_UNION_MAX},
         {internal::Kind::BAG_UNION_DISJOINT, BAG_UNION_DISJOINT},
@@ -774,6 +777,7 @@ const static std::unordered_map<Kind, internal::Kind> s_op_kinds{
     {REGEXP_REPEAT, internal::Kind::REGEXP_REPEAT_OP},
     {REGEXP_LOOP, internal::Kind::REGEXP_LOOP_OP},
     {TUPLE_PROJECT, internal::Kind::TUPLE_PROJECT_OP},
+    {RELATION_AGGREGATE, internal::Kind::RELATION_AGGREGATE_OP},
     {RELATION_GROUP, internal::Kind::RELATION_GROUP_OP},
     {TABLE_PROJECT, internal::Kind::TABLE_PROJECT_OP},
     {TABLE_AGGREGATE, internal::Kind::TABLE_AGGREGATE_OP},
@@ -1953,6 +1957,7 @@ size_t Op::getNumIndicesHelper() const
     case FLOATINGPOINT_TO_FP_FROM_UBV: size = 2; break;
     case REGEXP_LOOP: size = 2; break;
     case TUPLE_PROJECT:
+    case RELATION_AGGREGATE:
     case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
@@ -2114,6 +2119,7 @@ Term Op::getIndexHelper(size_t index) const
       break;
     }
     case TUPLE_PROJECT:
+    case RELATION_AGGREGATE:
     case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
@@ -6156,6 +6162,7 @@ Op Solver::mkOp(Kind kind, const std::vector<uint32_t>& args) const
       res = mkOpHelper(kind, internal::RegExpLoop(args[0], args[1]));
       break;
     case TUPLE_PROJECT:
+    case RELATION_AGGREGATE:
     case RELATION_GROUP:
     case TABLE_AGGREGATE:
     case TABLE_GROUP:
index f71d6a447932443ec43be23c12c0a59825a1b414..4d4234f03f1913a65d313a4e1be3641219bb59ba 100644 (file)
@@ -3361,6 +3361,43 @@ enum Kind : int32_t
    * \endrst
    */
   RELATION_GROUP,
+  /**
+   * \rst
+   *
+   * Relation aggregate operator has the form
+   * :math:`((\_ \; rel.aggr \; n_1 ... n_k) \; f \; i \; A)`
+   * where :math:`n_1, ..., n_k` are natural numbers,
+   * :math:`f` is a function of type
+   * :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`,
+   * :math:`i` has the type :math:`T`,
+   * and :math:`A` has type :math:`(Relation \;  T_1 \; ... \; T_j)`.
+   * The returned type is :math:`(Set \; T)`.
+   *
+   * This operator aggregates elements in A that have the same tuple projection
+   * with indices n_1, ..., n_k using the combining function :math:`f`,
+   * and initial value :math:`i`.
+   *
+   * - Arity: ``3``
+   *
+   *   - ``1:`` Term of sort :math:`(\rightarrow (Tuple \;  T_1 \; ... \; T_j)\; T \; T)`
+   *   - ``2:`` Term of Sort :math:`T`
+   *   - ``3:`` Term of relation sort :math:`Relation T_1 ... T_j`
+   *
+   * - Indices: ``n``
+   *   - ``1..n:`` Indices of the projection
+   * \endrst
+   * - Create Term of this Kind with:
+   *   - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+   *
+   * - Create Op of this kind with:
+   *   - Solver::mkOp(Kind, const std::vector<uint32_t>&) const
+   *
+   * \rst
+   * .. warning:: This kind is experimental and may be changed or removed in
+   *              future versions.
+   * \endrst
+   */
+  RELATION_AGGREGATE,
 
   /* Bags ------------------------------------------------------------------ */
 
index 875e41ee43b37001c522aee4b12114217ac3f12a..96421e6ce00731d6bc5f4d68311674689a88e8de 100644 (file)
@@ -1435,6 +1435,12 @@ termNonVariable[cvc5::Term& expr, cvc5::Term& expr2]
     cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_GROUP, indices);
     expr = SOLVER->mkTerm(op, {expr});
   }
+  | LPAREN_TOK RELATION_AGGREGATE_TOK term[expr,expr2] RPAREN_TOK
+  {
+    std::vector<uint32_t> indices;
+    cvc5::Op op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, indices);
+    expr = SOLVER->mkTerm(op, {expr});
+  }
   | /* an atomic term (a term with no subterms) */
     termAtomic[atomTerm] { expr = atomTerm; }
   ;
@@ -1611,6 +1617,13 @@ identifier[cvc5::ParseOp& p]
         p.d_kind = cvc5::RELATION_GROUP;
         p.d_op = SOLVER->mkOp(cvc5::RELATION_GROUP, numerals);
       }
+     | RELATION_AGGREGATE_TOK nonemptyNumeralList[numerals]
+      {
+        // we adopt a special syntax (_ rel.aggr i_1 ... i_n) where
+        // i_1, ..., i_n are numerals
+        p.d_kind = cvc5::RELATION_AGGREGATE;
+        p.d_op = SOLVER->mkOp(cvc5::RELATION_AGGREGATE, numerals);
+      }
     | functionName[opName, CHECK_NONE] nonemptyNumeralList[numerals]
       {
         cvc5::Kind k = PARSER_STATE->getIndexedOpKind(opName);
@@ -2224,6 +2237,7 @@ TABLE_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BA
 TABLE_JOIN_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.join';
 TABLE_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_BAGS) }? 'table.group';
 RELATION_GROUP_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.group';
+RELATION_AGGREGATE_TOK: { PARSER_STATE->isTheoryEnabled(internal::theory::THEORY_SETS) }? 'rel.aggr';
 FMF_CARD_TOK: { !PARSER_STATE->strictModeEnabled() && PARSER_STATE->hasCardinalityConstraints() }? 'fmf.card';
 
 HO_ARROW_TOK : { PARSER_STATE->isHoEnabled() }? '->';
index 900ff1016a39ad710d500b3e2e6332d7db600558..5f4ae42e447ae1b35d5e064fad02e4718ad04c94 100644 (file)
@@ -1130,7 +1130,8 @@ cvc5::Term Smt2::applyParseOp(ParseOp& p, std::vector<cvc5::Term>& args)
   }
   else if (p.d_kind == cvc5::TUPLE_PROJECT || p.d_kind == cvc5::TABLE_PROJECT
            || p.d_kind == cvc5::TABLE_AGGREGATE || p.d_kind == cvc5::TABLE_JOIN
-           || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP)
+           || p.d_kind == cvc5::TABLE_GROUP || p.d_kind == cvc5::RELATION_GROUP
+           || p.d_kind == cvc5::RELATION_AGGREGATE)
   {
     cvc5::Term ret = d_solver->mkTerm(p.d_op, args);
     Trace("parser") << "applyParseOp: return projection " << ret << std::endl;
index 5dc9252a97b33a2b5ae4d8682d02808484cf003c..3ffe7641ceb301291d49d64d62ca2b0982f5c67a 100644 (file)
@@ -779,7 +779,7 @@ void Smt2Printer::toStream(std::ostream& out,
     ProjectOp op = n.getOperator().getConst<ProjectOp>();
     if (op.getIndices().empty())
     {
-      // e.g. (table.project function initial_value bag)
+      // e.g. (table.aggr function initial_value bag)
       out << "table.aggr " << n[0] << " " << n[1] << " " << n[2] << ")";
     }
     else
@@ -835,6 +835,22 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
   }
+  case kind::RELATION_AGGREGATE:
+  {
+    ProjectOp op = n.getOperator().getConst<ProjectOp>();
+    if (op.getIndices().empty())
+    {
+      // e.g. (rel.aggr function initial_value bag)
+      out << "rel.aggr " << n[0] << " " << n[1] << " " << n[2] << ")";
+    }
+    else
+    {
+      // e.g.  ((_ rel.aggr 0) function initial_value bag)
+      out << "(_ rel.aggr" << op << ") " << n[0] << " " << n[1] << " " << n[2]
+          << ")";
+    }
+    return;
+  }
   case kind::CONSTRUCTOR_TYPE:
   {
     out << n[n.getNumChildren()-1];
@@ -1175,6 +1191,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::RELATION_IDEN: return "rel.iden";
   case kind::RELATION_JOIN_IMAGE: return "rel.join_image";
   case kind::RELATION_GROUP: return "rel.group";
+  case kind::RELATION_AGGREGATE: return "rel.aggr";
 
   // bag theory
   case kind::BAG_TYPE: return "Bag";
index 0c0e17dc13053181e761e95a874cd1f25df88c5d..4c5899e96d3fdc909d21942e66c0991b6d1c581b 100644 (file)
@@ -95,10 +95,20 @@ constant RELATION_GROUP_OP \
   ProjectOp+ \
   ::cvc5::internal::ProjectOpHashFunction \
   "theory/datatypes/project_op.h" \
-  "operator for RELATION_GROUP; payload is an instance of the cvc5::internal::RelationGroupOp class"
+  "operator for RELATION_GROUP; payload is an instance of the cvc5::internal::ProjectOp class"
 
 parameterized RELATION_GROUP RELATION_GROUP_OP 1 "relation group"
 
+# relation aggregate operator
+constant RELATION_AGGREGATE_OP \
+  class \
+  ProjectOp+ \
+  ::cvc5::internal::ProjectOpHashFunction \
+  "theory/datatypes/project_op.h" \
+  "operator for RELATION_AGGREGATE; payload is an instance of the cvc5::internal::ProjectOp class"
+
+parameterized RELATION_AGGREGATE RELATION_AGGREGATE_OP 3 "relation aggregate"
+
 operator RELATION_JOIN                    2  "relation join"
 operator RELATION_PRODUCT         2  "relation cartesian product"
 operator RELATION_TRANSPOSE    1  "relation transpose"
@@ -132,6 +142,8 @@ typerule RELATION_JOIN_IMAGE            ::cvc5::internal::theory::sets::JoinImageTypeRu
 typerule RELATION_IDEN                         ::cvc5::internal::theory::sets::RelIdenTypeRule
 typerule RELATION_GROUP_OP      "SimpleTypeRule<RBuiltinOperator>"
 typerule RELATION_GROUP         ::cvc5::internal::theory::sets::RelationGroupTypeRule
+typerule RELATION_AGGREGATE_OP  "SimpleTypeRule<RBuiltinOperator>"
+typerule RELATION_AGGREGATE     ::cvc5::internal::theory::sets::RelationAggregateTypeRule
 
 construle SET_UNION         ::cvc5::internal::theory::sets::SetsBinaryOperatorTypeRule
 construle SET_SINGLETON     ::cvc5::internal::theory::sets::SingletonTypeRule
index 08d4feb3608e8b2c23a54c68945f37a89ab41917..6f6c013dc767e156afdc471ed51bb421989c254e 100644 (file)
@@ -20,6 +20,7 @@
 #include "theory/datatypes/project_op.h"
 #include "theory/datatypes/tuple_utils.h"
 #include "theory/sets/normal_form.h"
+#include "theory/sets/set_reduction.h"
 
 using namespace cvc5::internal::kind;
 using namespace cvc5::internal::theory::datatypes;
@@ -151,6 +152,19 @@ Node RelsUtils::evaluateGroup(TNode n)
   return ret;
 }
 
+Node RelsUtils::evaluateRelationAggregate(TNode n)
+{
+  Assert(n.getKind() == RELATION_AGGREGATE);
+  if (!(n[1].isConst() && n[2].isConst()))
+  {
+    // we can't proceed further.
+    return n;
+  }
+
+  Node reduction = SetReduction::reduceAggregateOperator(n);
+  return reduction;
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index 559ef52817df32f1e914af25b9c630ab6537f45b..3df5a28795909ded42521f70cbe8e581584d1279 100644 (file)
@@ -72,6 +72,13 @@ class RelsUtils
    * projection with indices n_1 ... n_k
    */
   static Node evaluateGroup(TNode n);
+
+  /**
+   * @param n has the form ((_ rel.aggr n1 ... n_k) f initial A)
+   * where initial and A are constants
+   * @return the aggregation result.
+   */
+  static Node evaluateRelationAggregate(TNode n);
 };
 }  // namespace sets
 }  // namespace theory
index 78b5b8036017b0af3a9078c419bb7208e67b6643..f8c4ed4ae8f9441b0dadee0ace38d5daffc6d8ae 100644 (file)
@@ -18,7 +18,7 @@
 #include "expr/bound_var_manager.h"
 #include "expr/emptyset.h"
 #include "expr/skolem_manager.h"
-#include "theory/datatypes/tuple_utils.h"
+#include "theory/datatypes//project_op.h"
 #include "theory/quantifiers/fmf/bounded_integers.h"
 #include "util/rational.h"
 
@@ -120,6 +120,30 @@ Node SetReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
   return combine_n;
 }
 
+Node SetReduction::reduceAggregateOperator(Node node)
+{
+  Assert(node.getKind() == RELATION_AGGREGATE);
+  NodeManager* nm = NodeManager::currentNM();
+  BoundVarManager* bvm = nm->getBoundVarManager();
+  Node function = node[0];
+  TypeNode elementType = function.getType().getArgTypes()[0];
+  Node initialValue = node[1];
+  Node A = node[2];
+
+  ProjectOp op = node.getOperator().getConst<ProjectOp>();
+  Node groupOp = nm->mkConst(RELATION_GROUP_OP, op);
+  Node group = nm->mkNode(RELATION_GROUP, {groupOp, A});
+
+  Node set = bvm->mkBoundVar<FirstIndexVarAttribute>(
+      group, "set", nm->mkSetType(elementType));
+  Node foldList = nm->mkNode(BOUND_VAR_LIST, set);
+  Node foldBody = nm->mkNode(SET_FOLD, function, initialValue, set);
+
+  Node fold = nm->mkNode(LAMBDA, foldList, foldBody);
+  Node map = nm->mkNode(SET_MAP, fold, group);
+  return map;
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index 43172012b2ac92358eaebae81f25421625c4803d..3b0c4526e4ffa8431927528717c8456f4aeaa9df 100644 (file)
@@ -64,6 +64,16 @@ class SetReduction
    * unionFn: Int -> (Set T1) is an uninterpreted function
    */
   static Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+
+  /**
+   * @param node of the form ((_ rel.aggr n1 ... nk) f initial A))
+   * @return reduction term that uses map, fold, and group operators
+   * as follows:
+   * (set.map
+   *   (lambda ((B Table)) (set.fold f initial B))
+   *   ((_ rel.group n1 ... nk) A))
+   */
+  static Node reduceAggregateOperator(Node node);
 };
 
 }  // namespace sets
index 820f33e3bcf7d418ae77e6709979ff25e8de1ca5..b09080186d0417d6c288104d7b0bb4d4d375344c 100644 (file)
@@ -168,6 +168,11 @@ TrustNode TheorySets::ppRewrite(TNode n, std::vector<SkolemLemma>& lems)
     d_im.lemma(andNode, InferenceId::BAGS_FOLD);
     return TrustNode::mkTrustRewrite(n, ret, nullptr);
   }
+  if (nk == TABLE_AGGREGATE)
+  {
+    Node ret = SetReduction::reduceAggregateOperator(n);
+    return TrustNode::mkTrustRewrite(ret, ret, nullptr);
+  }
   return d_internal->ppRewrite(n, lems);
 }
 
index 5fcd15dd1e44713e267b1a742134d83bccaa5ba8..e9b3b31b05a2e68e17e004be232753e8046e8580 100644 (file)
@@ -590,6 +590,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
   }
 
   case RELATION_GROUP: return postRewriteGroup(node);
+  case RELATION_AGGREGATE: return postRewriteAggregate(node);
   default: break;
   }
 
@@ -765,6 +766,21 @@ RewriteResponse TheorySetsRewriter::postRewriteGroup(TNode n)
   return RewriteResponse(REWRITE_DONE, n);
 }
 
+RewriteResponse TheorySetsRewriter::postRewriteAggregate(TNode n)
+{
+  Assert(n.getKind() == kind::RELATION_AGGREGATE);
+  if (n[1].isConst() && n[2].isConst())
+  {
+    Node ret = RelsUtils::evaluateRelationAggregate(n);
+    if (ret != n)
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+    }
+  }
+
+  return RewriteResponse(REWRITE_DONE, n);
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace cvc5::internal
index dc4f64762e979f10fecf2f9d29735014c262bc1e..2d321f11c0c544facecd66fc993b3f99dc8285a4 100644 (file)
@@ -69,48 +69,55 @@ class TheorySetsRewriter : public TheoryRewriter
     // often this will suffice
     return postRewrite(equality).d_node;
   }
-private:
+
+ private:
   /**
    * Returns true if elementTerm is in setTerm, where both terms are constants.
    */
   bool checkConstantMembership(TNode elementTerm, TNode setTerm);
- /**
-  *  rewrites for n include:
-  *  - (set.map f (as set.empty (Set T1)) = (as set.empty (Set T2))
-  *  - (set.map f (set.singleton x)) = (set.singleton (apply f x))
-  *  - (set.map f (set.union A B)) =
-  *       (set.union (set.map f A) (set.map f B))
-  *  where f: T1 -> T2
-  */
- RewriteResponse postRewriteMap(TNode n);
 /**
+   *  rewrites for n include:
+   *  - (set.map f (as set.empty (Set T1)) = (as set.empty (Set T2))
+   *  - (set.map f (set.singleton x)) = (set.singleton (apply f x))
+   *  - (set.map f (set.union A B)) =
+   *       (set.union (set.map f A) (set.map f B))
+   *  where f: T1 -> T2
+   */
 RewriteResponse postRewriteMap(TNode n);
 
- /**
-  *  rewrites for n include:
-  *  - (set.filter p (as set.empty (Set T)) = (as set.empty (Set T))
-  *  - (set.filter p (set.singleton x)) =
-  *       (ite (p x) (set.singleton x) (as set.empty (Set T)))
-  *  - (set.filter p (set.union A B)) =
-  *       (set.union (set.filter p A) (set.filter p B))
-  *  where p: T -> Bool
-  */
- RewriteResponse postRewriteFilter(TNode n);
- /**
-  *  rewrites for n include:
-  *  - (set.fold f t (as set.empty (Set T))) = t
-  *  - (set.fold f t (set.singleton x)) = (f t x)
-  *  - (set.fold f t (set.union A B)) = (set.fold f (set.fold f t A) B))
-  *  where f: T -> S -> S, and t : S
-  */
- RewriteResponse postRewriteFold(TNode n);
- /**
-  *  rewrites for n include:
-  *  - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) =
-  *          (rel.singleton (as set.empty (Relation T) ))
-  *  - ((_ rel.group n1 ... nk) (set.singleton x)) =
-  *          (set.singleton (set.singleton x))
-  *  - Evaluation of ((_ rel.group n1 ... nk) A) when A is a constant
-  */
- RewriteResponse postRewriteGroup(TNode n);
+  /**
+   *  rewrites for n include:
+   *  - (set.filter p (as set.empty (Set T)) = (as set.empty (Set T))
+   *  - (set.filter p (set.singleton x)) =
+   *       (ite (p x) (set.singleton x) (as set.empty (Set T)))
+   *  - (set.filter p (set.union A B)) =
+   *       (set.union (set.filter p A) (set.filter p B))
+   *  where p: T -> Bool
+   */
+  RewriteResponse postRewriteFilter(TNode n);
+  /**
+   *  rewrites for n include:
+   *  - (set.fold f t (as set.empty (Set T))) = t
+   *  - (set.fold f t (set.singleton x)) = (f t x)
+   *  - (set.fold f t (set.union A B)) = (set.fold f (set.fold f t A) B))
+   *  where f: T -> S -> S, and t : S
+   */
+  RewriteResponse postRewriteFold(TNode n);
+  /**
+   *  rewrites for n include:
+   *  - ((_ rel.group n1 ... nk) (as set.empty (Relation T))) =
+   *          (rel.singleton (as set.empty (Relation T) ))
+   *  - ((_ rel.group n1 ... nk) (set.singleton x)) =
+   *          (set.singleton (set.singleton x))
+   *  - Evaluation of ((_ rel.group n1 ... nk) A) when A is a constant
+   */
+  RewriteResponse postRewriteGroup(TNode n);
+  /**
+   * @param n has the form ((_ rel.aggr n1 ... n_k) f initial A)
+   * where initial and A are constants
+   * @return the aggregation result.
+   */
+  RewriteResponse postRewriteAggregate(TNode n);
 }; /* class TheorySetsRewriter */
 
 }  // namespace sets
index 8a163489a436e621be1241033ed3947bc746a049..f9dd7d390e4aabf032449f4a80eaae4343ac39b7 100644 (file)
@@ -27,6 +27,8 @@ namespace cvc5::internal {
 namespace theory {
 namespace sets {
 
+using namespace cvc5::internal::theory::datatypes;
+
 TypeNode SetsBinaryOperatorTypeRule::computeType(NodeManager* nodeManager,
                                                  TNode n,
                                                  bool check)
@@ -612,6 +614,72 @@ TypeNode RelationGroupTypeRule::computeType(NodeManager* nm, TNode n, bool check
   return nm->mkSetType(setType);
 }
 
+TypeNode RelationAggregateTypeRule::computeType(NodeManager* nm,
+                                                TNode n,
+                                                bool check)
+{
+  Assert(n.getKind() == kind::RELATION_AGGREGATE && n.hasOperator()
+         && n.getOperator().getKind() == kind::RELATION_AGGREGATE_OP);
+  ProjectOp op = n.getOperator().getConst<ProjectOp>();
+  const std::vector<uint32_t>& indices = op.getIndices();
+
+  TypeNode functionType = n[0].getType(check);
+  TypeNode initialValueType = n[1].getType(check);
+  TypeNode setType = n[2].getType(check);
+
+  if (check)
+  {
+    if (!setType.isSet())
+    {
+      std::stringstream ss;
+      ss << "RELATION_PROJECT operator expects a table. Found '" << n[2]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TypeNode tupleType = setType.getSetElementType();
+    if (!tupleType.isTuple())
+    {
+      std::stringstream ss;
+      ss << "TABLE_PROJECT operator expects a table. Found '" << n[2]
+         << "' of type '" << setType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+
+    TupleUtils::checkTypeIndices(n, tupleType, indices);
+
+    TypeNode elementType = setType.getSetElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T T) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    TypeNode rangeType = functionType.getRangeType();
+    if (!(argTypes.size() == 2 && argTypes[0] == elementType
+          && argTypes[1] == rangeType))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T T). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    if (rangeType != initialValueType)
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects an initial value of type "
+         << rangeType << ". Found a term of type '" << initialValueType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  return nm->mkSetType(functionType.getRangeType());
+}
+
 Cardinality SetsProperties::computeCardinality(TypeNode type)
 {
   Assert(type.getKind() == kind::SET_TYPE);
index 551513fe6fc3a3cf63f14aff1fc5ae6c63009b22..ed973669ebb0f813f842894fd1e640de0b3e6cf3 100644 (file)
@@ -223,6 +223,19 @@ struct RelationGroupTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct RelationGroupTypeRule */
 
+/**
+ * Relation aggregate operator is indexed by a list of indices (n_1, ..., n_k).
+ * It ensures that it has 3 arguments:
+ * - A combining function of type (-> (Tuple T_1 ... T_j) T T)
+ * - Initial value of type T
+ * - A relation of type (Relation T_1 ... T_j) where 0 <= n_1, ..., n_k < j
+ * the returned type is (Relation T).
+ */
+struct RelationAggregateTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct RelationAggregateTypeRule */
+
 struct SetsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index 2ecf50241601583bef7bcf8b2ac82f6014c70af1..75f6a29297147400696f6fee5d0b6bdd2040c66d 100644 (file)
@@ -2505,6 +2505,7 @@ set(regress_1_tests
   regress1/sets/proj-issue164.smt2
   regress1/sets/proj-issue178.smt2
   regress1/sets/proj-issue494-finite-leafof.smt2
+  regress1/sets/relation_aggregate1.smt2
   regress1/sets/relation_group1.smt2
   regress1/sets/relation_group2.smt2
   regress1/sets/relation_group3.smt2
diff --git a/test/regress/cli/regress1/sets/relation_aggregate1.smt2 b/test/regress/cli/regress1/sets/relation_aggregate1.smt2
new file mode 100644 (file)
index 0000000..c467790
--- /dev/null
@@ -0,0 +1,31 @@
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+(define-fun sumByCategory ((x (Tuple String String Int)) (y (Tuple String Int))) (Tuple String Int)
+  (tuple
+   ((_ tuple.select 0) x)
+   (+ ((_ tuple.select 2) x) ((_ tuple.select 1) y))))
+
+(declare-fun categorySales () (Set (Tuple String Int)))
+
+;(define-fun categorySales () (Set (Tuple String Int))
+;  (set.union
+;   (set.singleton (tuple "Software" 5))
+;   (set.singleton (tuple "Hardware" 4))))
+
+(assert
+ (= categorySales
+    ((_ rel.aggr 0)
+      sumByCategory
+      (tuple "" 0)
+      (set.union
+       (set.singleton (tuple "Software" "win" 1))
+       (set.singleton (tuple "Software" "mac" 4))
+       (set.singleton (tuple "Hardware" "cpu" 2))
+       (set.singleton (tuple "Hardware" "gpu" 2))))))
+
+(check-sat)