Use function array constants in HO solver (#8818)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 26 May 2022 20:16:15 +0000 (15:16 -0500)
committerGitHub <noreply@github.com>
Thu, 26 May 2022 20:16:15 +0000 (20:16 +0000)
This makes lambdas rewrite to function array constants when possible. This extends our HO solver and utilities to be robust to check whether a node represents a lambda (uf::FunctionConst::toLambda).

This furthermore removes the isConst rule for LAMBDA; lambdas are never constant.

The PR also improves our check-model so that warnings are not thrown if rewriting can show that the model value of a term is equivalent modulo rewriting to its representative in the model equality engine.

This eliminates the last remaining static calls to rewrite. This is work towards eliminating SmtEngineScope.

19 files changed:
src/expr/array_store_all.h
src/preprocessing/passes/ho_elim.cpp
src/printer/smt2/smt2_printer.cpp
src/proof/lfsc/lfsc_node_converter.cpp
src/theory/evaluator.cpp
src/theory/quantifiers/oracle_checker.cpp
src/theory/strings/term_registry.cpp
src/theory/theory_model.cpp
src/theory/theory_model_builder.cpp
src/theory/uf/function_const.cpp
src/theory/uf/function_const.h
src/theory/uf/ho_extension.cpp
src/theory/uf/kinds
src/theory/uf/lambda_lift.cpp
src/theory/uf/theory_uf_rewriter.cpp
src/theory/uf/theory_uf_type_rules.cpp
src/theory/uf/theory_uf_type_rules.h
src/theory/uf/type_enumerator.cpp
src/theory/uf/type_enumerator.h

index ddcdaa1704b93a78eb38b605df0809e0b4a24e42..24e1257b1b316ec542dd0ba8de164fdfcd41e5c8 100644 (file)
@@ -16,8 +16,8 @@
 
 #include "cvc5_public.h"
 
-#ifndef CVC5__ARRAY_STORE_ALL_H
-#define CVC5__ARRAY_STORE_ALL_H
+#ifndef CVC5__EXPR__ARRAY_STORE_ALL_H
+#define CVC5__EXPR__ARRAY_STORE_ALL_H
 
 #include <iosfwd>
 #include <memory>
index 232e0a2c7c714f34efa82eb01c4c69b69dce836c..e83315d81577ef96cff09f1e2daab14796b8dac8 100644 (file)
@@ -24,6 +24,7 @@
 #include "options/quantifiers_options.h"
 #include "preprocessing/assertion_pipeline.h"
 #include "theory/rewriter.h"
+#include "theory/uf/function_const.h"
 #include "theory/uf/theory_uf_rewriter.h"
 
 using namespace cvc5::internal::kind;
@@ -51,17 +52,18 @@ Node HoElim::eliminateLambdaComplete(Node n, std::map<Node, Node>& newLambda)
 
     if (it == d_visited.end())
     {
-      if (cur.getKind() == LAMBDA)
+      Node lam = theory::uf::FunctionConst::toLambda(cur);
+      if (!lam.isNull())
       {
-        Trace("ho-elim-ll") << "Lambda lift: " << cur << std::endl;
+        Trace("ho-elim-ll") << "Lambda lift: " << lam << std::endl;
         // must also get free variables in lambda
         std::vector<Node> lvars;
         std::vector<TypeNode> ftypes;
         std::unordered_set<Node> fvs;
-        expr::getFreeVariables(cur, fvs);
+        expr::getFreeVariables(lam, fvs);
         std::vector<Node> nvars;
         std::vector<Node> vars;
-        Node sbd = cur[1];
+        Node sbd = lam[1];
         if (!fvs.empty())
         {
           Trace("ho-elim-ll")
@@ -78,20 +80,20 @@ Node HoElim::eliminateLambdaComplete(Node n, std::map<Node, Node>& newLambda)
           sbd = sbd.substitute(
               vars.begin(), vars.end(), nvars.begin(), nvars.end());
         }
-        for (const Node& bv : cur[0])
+        for (const Node& bv : lam[0])
         {
           TypeNode bvt = bv.getType();
           ftypes.push_back(bvt);
           lvars.push_back(bv);
         }
-        Node nlambda = cur;
+        Node nlambda = lam;
         if (!fvs.empty())
         {
           nlambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, lvars), sbd);
           Trace("ho-elim-ll")
               << "...new lambda definition: " << nlambda << std::endl;
         }
-        TypeNode rangeType = cur.getType().getRangeType();
+        TypeNode rangeType = lam.getType().getRangeType();
         TypeNode nft = nm->mkFunctionType(ftypes, rangeType);
         Node nf = sm->mkDummySkolem("ll", nft);
         Trace("ho-elim-ll")
index 2790463003b01f31cf4d349a6f0b7bf99f7617bb..ff13b0600e2ced0c3bfd7b90f9ada365ea893244 100644 (file)
@@ -29,6 +29,7 @@
 #include "expr/dtype_cons.h"
 #include "expr/emptybag.h"
 #include "expr/emptyset.h"
+#include "expr/function_array_const.h"
 #include "expr/node_manager_attributes.h"
 #include "expr/node_visitor.h"
 #include "expr/sequence.h"
 #include "printer/let_binding.h"
 #include "proof/unsat_core.h"
 #include "smt/command.h"
-#include "theory/bags/table_project_op.h"
 #include "theory/arrays/theory_arrays_rewriter.h"
+#include "theory/bags/table_project_op.h"
 #include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/datatypes/tuple_project_op.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/theory_model.h"
+#include "theory/uf/function_const.h"
 #include "util/bitvector.h"
 #include "util/divisible.h"
 #include "util/floatingpoint.h"
@@ -309,6 +311,13 @@ void Smt2Printer::toStream(std::ostream& out,
       out << ")";
       break;
     }
+    case kind::FUNCTION_ARRAY_CONST:
+    {
+      // prints as the equivalent lambda
+      Node lam = theory::uf::FunctionConst::toLambda(n);
+      toStream(out, lam, toDepth);
+      break;
+    }
 
     case kind::UNINTERPRETED_SORT_VALUE:
     {
index ee644e2d4dcd92bbbda25fdae701636c498e4355..738dac6b87f7151c5a1017af55c72c6d813cc60e 100644 (file)
@@ -31,6 +31,7 @@
 #include "theory/bv/theory_bv_utils.h"
 #include "theory/datatypes/datatypes_rewriter.h"
 #include "theory/strings/word.h"
+#include "theory/uf/function_const.h"
 #include "theory/uf/theory_uf_rewriter.h"
 #include "util/bitvector.h"
 #include "util/floatingpoint.h"
@@ -369,6 +370,13 @@ Node LfscNodeConverter::postConvert(Node n)
     // notice that intentionally we drop annotations here
     return ret;
   }
+  else if (k == FUNCTION_ARRAY_CONST)
+  {
+    // must convert to lambda and then run the conversion
+    Node lam = theory::uf::FunctionConst::toLambda(n);
+    Assert(!lam.isNull());
+    return convert(lam);
+  }
   else if (k == REGEXP_LOOP)
   {
     // ((_ re.loop n1 n2) t) is ((re.loop n1 n2) t)
index 7caca8427345d64805b97fbb630b93119e064931..efe985d78743364919a0f220880c5173483be1b9 100644 (file)
@@ -19,6 +19,7 @@
 #include "theory/rewriter.h"
 #include "theory/strings/theory_strings_utils.h"
 #include "theory/theory.h"
+#include "theory/uf/function_const.h"
 #include "util/integer.h"
 
 using namespace cvc5::internal::kind;
@@ -330,9 +331,18 @@ EvalResult Evaluator::evalInternal(
         {
           Trace("evaluator") << "Evaluate " << currNode << std::endl;
           TNode op = currNode.getOperator();
-          Assert(evalAsNode.find(op) != evalAsNode.end());
-          // no function can be a valid EvalResult
-          op = evalAsNode[op];
+          if (op.getKind() == kind::FUNCTION_ARRAY_CONST)
+          {
+            // If we have a function constant as the operator, it was not
+            // processed. We require converting to a lambda now.
+            op = uf::FunctionConst::toLambda(op);
+          }
+          else
+          {
+            Assert(evalAsNode.find(op) != evalAsNode.end());
+            // no function can be a valid EvalResult
+            op = evalAsNode[op];
+          }
           Trace("evaluator") << "Operator evaluated to " << op << std::endl;
           if (op.getKind() != kind::LAMBDA)
           {
@@ -362,9 +372,10 @@ EvalResult Evaluator::evalInternal(
 
           // Lambdas are evaluated in a recursive fashion because each
           // evaluation requires different substitutions. We use a fresh cache
-          // since the evaluation of op[1] is under a new substitution and thus
-          // should not be cached. We could alternatively copy evalAsNode to
-          // evalAsNodeC but favor avoiding this copy for performance reasons.
+          // since the evaluation of op[1] is under a new substitution and
+          // thus should not be cached. We could alternatively copy evalAsNode
+          // to evalAsNodeC but favor avoiding this copy for performance
+          // reasons.
           std::unordered_map<TNode, Node> evalAsNodeC;
           std::unordered_map<TNode, EvalResult> resultsC;
           results[currNode] = evalInternal(
index 13a2e7630f6f7f1aae28ffac32e3a9fda63396e2..b1fb3aec54f1712995fa309856e99e057546854e 100644 (file)
@@ -109,7 +109,7 @@ Node OracleChecker::postConvert(Node n)
     }
   }
   // otherwise, always rewrite
-  return Rewriter::rewrite(n);
+  return rewrite(n);
 }
 bool OracleChecker::hasOracles() const { return !d_callers.empty(); }
 bool OracleChecker::hasOracleCalls(Node f) const
index 728f4b047ddf8b6cfea9dac4f2172bd05362f179..1e7da4c70b4a2df7aff96dbdce2f38930cb6a5db 100644 (file)
@@ -15,7 +15,6 @@
 
 #include "theory/strings/term_registry.h"
 
-#include "expr/attribute.h"
 #include "options/smt_options.h"
 #include "options/strings_options.h"
 #include "printer/smt2/smt2_printer.h"
index 52c78151c76a1c4f8164f39b008d15ba752f7270..d33f81fe76358fc27e7c216e4661bb7d7ac2226b 100644 (file)
@@ -24,6 +24,7 @@
 #include "smt/env.h"
 #include "smt/solver_engine.h"
 #include "theory/trust_substitutions.h"
+#include "theory/uf/function_const.h"
 #include "util/rational.h"
 
 using namespace std;
@@ -138,7 +139,12 @@ Node TheoryModel::getValue(TNode n) const
   {
     return nn;
   }
-  else if (nn.getKind() == kind::LAMBDA)
+  if (nn.getKind() == kind::FUNCTION_ARRAY_CONST)
+  {
+    // return the lambda instead
+    nn = uf::FunctionConst::toLambda(nn);
+  }
+  if (nn.getKind() == kind::LAMBDA)
   {
     if (options().theory.condenseFunctionValues)
     {
index c29d89fdef312416488888555ef1e4d2959655a2..c83ce2b63f30043dc59c893c3e506dd15a6073e4 100644 (file)
@@ -23,6 +23,7 @@
 #include "options/uf_options.h"
 #include "smt/env.h"
 #include "theory/rewriter.h"
+#include "theory/uf/function_const.h"
 #include "theory/uf/theory_uf_model.h"
 #include "util/uninterpreted_sort_value.h"
 
@@ -1171,23 +1172,24 @@ void TheoryEngineModelBuilder::debugCheckModel(TheoryModel* tm)
           << "Representative " << rep << " of " << n
           << " violates type constraints (" << rep.getType() << " and "
           << n.getType() << ")";
-      Node val = tm->getValue(*eqc_i);
+      Node val = tm->getValue(n);
       if (val != rep)
       {
         std::stringstream err;
         err << "Failed representative check:" << std::endl
             << "( " << repCheckInstance << ") "
-            << "n: " << n << endl
-            << "getValue(n): " << tm->getValue(n) << std::endl
+            << "n: " << n << std::endl
+            << "getValue(n): " << val << std::endl
             << "rep: " << rep << std::endl;
         if (val.isConst() && rep.isConst())
         {
           AlwaysAssert(val == rep) << err.str();
         }
-        else
+        else if (rewrite(val) != rewrite(rep))
         {
           // if it does not evaluate, it is just a warning, which may be the
-          // case for non-constant values, e.g. lambdas.
+          // case for non-constant values, e.g. lambdas. Furthermore we only
+          // throw this warning if rewriting cannot show they are equal.
           warning() << err.str();
         }
       }
@@ -1359,7 +1361,10 @@ void TheoryEngineModelBuilder::assignHoFunction(TheoryModel* m, Node f)
       Assert(hnv.isConst());
       if (!apply_args.empty())
       {
-        Assert(hnv.getKind() == kind::LAMBDA
+        // Convert to lambda, which is necessary if hnv is a function array
+        // constant.
+        hnv = uf::FunctionConst::toLambda(hnv);
+        Assert(!hnv.isNull() && hnv.getKind() == kind::LAMBDA
                && hnv[0].getNumChildren() + 1 == args.size());
         std::vector<TNode> largs;
         for (unsigned j = 0; j < hnv[0].getNumChildren(); j++)
index 27b39018a9ea1d423edbabfa0f2a77f2a4fad316..b2bde1306aadd19916f802c760714e5726647589 100644 (file)
 #include "theory/uf/function_const.h"
 
 #include "expr/array_store_all.h"
+#include "expr/attribute.h"
+#include "expr/bound_var_manager.h"
+#include "expr/function_array_const.h"
 #include "theory/arrays/theory_arrays_rewriter.h"
 #include "theory/rewriter.h"
+#include "util/rational.h"
 
 namespace cvc5::internal {
 namespace theory {
 namespace uf {
 
+/**
+ * Attribute for constructing a unique bound variable list for the lambda
+ * corresponding to an array constant.
+ */
+struct FunctionBoundVarListTag
+{
+};
+using FunctionBoundVarListAttribute =
+    expr::Attribute<FunctionBoundVarListTag, Node>;
+/**
+ * An attribute to cache the conversion between array constants and lambdas.
+ */
+struct ArrayToLambdaTag
+{
+};
+using ArrayToLambdaAttribute = expr::Attribute<ArrayToLambdaTag, Node>;
+
+Node FunctionConst::toLambda(TNode n)
+{
+  Kind nk = n.getKind();
+  if (nk == kind::LAMBDA)
+  {
+    return n;
+  }
+  else if (nk == kind::FUNCTION_ARRAY_CONST)
+  {
+    ArrayToLambdaAttribute atla;
+    if (n.hasAttribute(atla))
+    {
+      return n.getAttribute(atla);
+    }
+    const FunctionArrayConst& fc = n.getConst<FunctionArrayConst>();
+    Node avalue = fc.getArrayValue();
+    TypeNode tn = fc.getType();
+    Assert(tn.isFunction());
+    std::vector<TypeNode> argTypes = tn.getArgTypes();
+    std::vector<Node> bvs;
+    NodeManager* nm = NodeManager::currentNM();
+    BoundVarManager* bvm = nm->getBoundVarManager();
+    // associate a unique bound variable list with the value
+    for (size_t i = 0, nargs = argTypes.size(); i < nargs; i++)
+    {
+      Node cacheVal =
+          BoundVarManager::getCacheValue(n, nm->mkConstInt(Rational(i)));
+      Node v =
+          bvm->mkBoundVar<FunctionBoundVarListAttribute>(cacheVal, argTypes[i]);
+      bvs.push_back(v);
+    }
+    Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, bvs);
+    Node lam = getLambdaForArrayRepresentation(avalue, bvl);
+    n.setAttribute(atla, lam);
+    return lam;
+  }
+  return Node::null();
+}
+
 TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl)
 {
   std::vector<TypeNode> children;
@@ -112,7 +172,6 @@ Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl)
   Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited);
   if (!body.isNull())
   {
-    body = Rewriter::rewrite(body);
     Trace("builtin-rewrite-debug")
         << "...got lambda body " << body << std::endl;
     return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, body);
@@ -269,13 +328,6 @@ Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n,
         return Node::null();
       }
     }
-    else if (Rewriter::rewrite(index_eq) != index_eq)
-    {
-      // equality must be oriented correctly based on rewriter
-      Trace("builtin-rewrite-debug2")
-          << "  ...equality not oriented properly." << std::endl;
-      return Node::null();
-    }
 
     // [3] We ensure that "index_eq" is an equality that is equivalent to
     // "first_arg" = "curr_index", where curr_index is a constant, and
@@ -395,13 +447,22 @@ Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n,
   return Node::null();
 }
 
-Node FunctionConst::getArrayRepresentationForLambda(TNode n)
+Node FunctionConst::toArrayConst(TNode n)
 {
-  Assert(n.getKind() == kind::LAMBDA);
-  // must carry the overall return type to deal with cases like (lambda ((x Int)
-  // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else
-  // case above should be (arraystoreall (Array Int Real) 0.0)
-  return getArrayRepresentationForLambdaRec(n, n[1].getType());
+  Kind nk = n.getKind();
+  if (nk == kind::FUNCTION_ARRAY_CONST)
+  {
+    const FunctionArrayConst& fc = n.getConst<FunctionArrayConst>();
+    return fc.getArrayValue();
+  }
+  else if (nk == kind::LAMBDA)
+  {
+    // must carry the overall return type to deal with cases like (lambda ((x
+    // Int) (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for
+    // the else case above should be (arraystoreall (Array Int Real) 0.0)
+    return getArrayRepresentationForLambdaRec(n, n[1].getType());
+  }
+  return Node::null();
 }
 
 }  // namespace uf
index b74503cafa95a4b654416aa8b553d2e9106da6c7..02324194e18b1cc5145b4a220875874b1a327c81 100644 (file)
@@ -53,6 +53,35 @@ class FunctionConst
    * getArrayRepresentationForLambda( t ), where t.getType()=ftn.
    */
   static TypeNode getArrayTypeForFunctionType(TypeNode ftn);
+  /**
+   * Returns a node of kind LAMBDA that is equivalent to n, or null otherwise.
+   *
+   * This is the identity function for lambda terms and runs the conversion
+   * for constant array functions, and null for all other nodes. For details,
+   * see the method getLambdaForArrayRepresentation.
+   */
+  static Node toLambda(TNode n);
+  /**
+   * Extracts the array constant from the payload of a a function array constant
+   *
+   *
+   * Given a lambda expression n, returns an array term that corresponds to n.
+   * This does the opposite direction of the examples described above the
+   * method getLambdaForArrayRepresentation.
+   *
+   * We limit the return values of this method to be almost constant functions,
+   * that is, arrays of the form:
+   *   (store ... (store (storeall _ b) i1 e1) ... in en)
+   * where b, i1, e1, ..., in, en are constants.
+   * Notice however that the return value of this form need not be an
+   * array such that isConst is true.
+   *
+   * If it is not possible to construct an array of this form that corresponds
+   * to n, this method returns null.
+   */
+  static Node toArrayConst(TNode n);
+
+ private:
   /**
    * Given an array constant a, returns a lambda expression that it corresponds
    * to, with bound variable list bvl.
@@ -76,30 +105,13 @@ class FunctionConst
    * (lambda x. (ite (= x 1) true (= x 2)))
    */
   static Node getLambdaForArrayRepresentation(TNode a, TNode bvl);
-  /**
-   * Given a lambda expression n, returns an array term that corresponds to n.
-   * This does the opposite direction of the examples described above.
-   *
-   * We limit the return values of this method to be almost constant functions,
-   * that is, arrays of the form:
-   *   (store ... (store (storeall _ b) i1 e1) ... in en)
-   * where b, i1, e1, ..., in, en are constants.
-   * Notice however that the return value of this form need not be a (canonical)
-   * array constant.
-   *
-   * If it is not possible to construct an array of this form that corresponds
-   * to n, this method returns null.
-   */
-  static Node getArrayRepresentationForLambda(TNode n);
-
- private:
   /** recursive helper for getLambdaForArrayRepresentation */
   static Node getLambdaForArrayRepresentationRec(
       TNode a,
       TNode bvl,
       unsigned bvlIndex,
       std::unordered_map<TNode, Node>& visited);
-  /** recursive helper for getArrayRepresentationForLambda */
+  /** recursive helper for toArrayConst */
   static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType);
 };
 
index 521ab4f9c7fbad4e2b36df258c7f15c057e93a15..b0359fa4821d222dda5e5908995f3299f9dea390 100644 (file)
@@ -19,6 +19,7 @@
 #include "expr/skolem_manager.h"
 #include "options/uf_options.h"
 #include "theory/theory_model.h"
+#include "theory/uf/function_const.h"
 #include "theory/uf/lambda_lift.h"
 #include "theory/uf/theory_uf_rewriter.h"
 
@@ -100,7 +101,7 @@ TrustNode HoExtension::ppRewrite(Node node, std::vector<SkolemLemma>& lems)
       }
     }
   }
-  else if (k == kind::LAMBDA)
+  else if (k == kind::LAMBDA || k == kind::FUNCTION_ARRAY_CONST)
   {
     Trace("uf-lazy-ll") << "Preprocess lambda: " << node << std::endl;
     TrustNode skTrn = d_ll.ppRewrite(node, lems);
index 304679df25162db40feccf17bafb7e91dbfc4d73..0837f29027c9a2ff46392e5bbf536b516b2ad18a 100644 (file)
@@ -33,9 +33,6 @@ typerule LAMBDA ::cvc5::internal::theory::uf::LambdaTypeRule
 
 variable BOOLEAN_TERM_VARIABLE "Boolean term variable"
 
-# lambda expressions that are isomorphic to array constants can be considered constants
-construle LAMBDA ::cvc5::internal::theory::uf::LambdaTypeRule
-
 operator HO_APPLY 2 "higher-order (partial) function application"
 typerule HO_APPLY ::cvc5::internal::theory::uf::HoApplyTypeRule
 
index e9313278ced0dd817b09fa069146e971130c220e..7e1823dfc93c330f01b68df735e6004df8e0b128 100644 (file)
@@ -19,6 +19,7 @@
 #include "expr/skolem_manager.h"
 #include "options/uf_options.h"
 #include "smt/env.h"
+#include "theory/uf/function_const.h"
 
 using namespace cvc5::internal::kind;
 
@@ -59,15 +60,16 @@ TrustNode LambdaLift::lift(Node node)
 
 TrustNode LambdaLift::ppRewrite(Node node, std::vector<SkolemLemma>& lems)
 {
-  TNode skolem = getSkolemFor(node);
+  Node lam = FunctionConst::toLambda(node);
+  TNode skolem = getSkolemFor(lam);
   if (skolem.isNull())
   {
     return TrustNode::null();
   }
-  d_lambdaMap[skolem] = node;
+  d_lambdaMap[skolem] = lam;
   if (!options().uf.ufHoLazyLambdaLift)
   {
-    TrustNode trn = lift(node);
+    TrustNode trn = lift(lam);
     lems.push_back(SkolemLemma(trn, skolem));
   }
   // if no proofs, return lemma with no generator
@@ -102,21 +104,21 @@ Node LambdaLift::getAssertionFor(TNode node)
   {
     return Node::null();
   }
-  Kind k = node.getKind();
   Node assertion;
-  if (k == LAMBDA)
+  Node lambda = FunctionConst::toLambda(node);
+  if (!lambda.isNull())
   {
     NodeManager* nm = NodeManager::currentNM();
     // The new assertion
     std::vector<Node> children;
     // bound variable list
-    children.push_back(node[0]);
+    children.push_back(lambda[0]);
     // body
     std::vector<Node> skolem_app_c;
     skolem_app_c.push_back(skolem);
-    skolem_app_c.insert(skolem_app_c.end(), node[0].begin(), node[0].end());
+    skolem_app_c.insert(skolem_app_c.end(), lambda[0].begin(), lambda[0].end());
     Node skolem_app = nm->mkNode(APPLY_UF, skolem_app_c);
-    skolem_app_c[0] = node;
+    skolem_app_c[0] = lambda;
     Node rhs = nm->mkNode(APPLY_UF, skolem_app_c);
     // For the sake of proofs, we use
     // (= (k t1 ... tn) ((lambda (x1 ... xn) s) t1 ... tn)) here. This is instead of
index 0f326bdd0f8aacf708cd8301fdbc193b354c7564..73c36ade757676b2743acbb81dbc536fa7c3d0ef 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/uf/theory_uf_rewriter.h"
 
+#include "expr/function_array_const.h"
 #include "expr/node_algorithm.h"
 #include "theory/rewriter.h"
 #include "theory/substitutions.h"
@@ -53,11 +54,11 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node)
   }
   if (node.getKind() == kind::APPLY_UF)
   {
-    if (node.getOperator().getKind() == kind::LAMBDA)
+    Node lambda = FunctionConst::toLambda(node.getOperator());
+    if (!lambda.isNull())
     {
-      Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing all args of : " << node
-                          << "\n";
-      TNode lambda = node.getOperator();
+      Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing all args of : "
+                          << lambda << " for " << node << "\n";
       Node ret;
       // build capture-avoiding substitution since in HOL shadowing may have
       // been introduced
@@ -102,17 +103,18 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node)
   }
   else if (node.getKind() == kind::HO_APPLY)
   {
-    if (node[0].getKind() == kind::LAMBDA)
+    Node lambda = FunctionConst::toLambda(node[0]);
+    if (!lambda.isNull())
     {
       // resolve one argument of the lambda
       Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing one argument of : "
-                          << node[0] << " with " << node[1] << "\n";
+                          << lambda << " with " << node[1] << "\n";
 
       // reconstruct the lambda first to avoid variable shadowing
-      Node new_body = node[0][1];
-      if (node[0][0].getNumChildren() > 1)
+      Node new_body = lambda[1];
+      if (lambda[0].getNumChildren() > 1)
       {
-        std::vector<Node> new_vars(node[0][0].begin() + 1, node[0][0].end());
+        std::vector<Node> new_vars(lambda[0].begin() + 1, lambda[0].end());
         std::vector<Node> largs;
         largs.push_back(
             NodeManager::currentNM()->mkNode(kind::BOUND_VAR_LIST, new_vars));
@@ -127,13 +129,13 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node)
       if (d_isHigherOrder)
       {
         Node arg = node[1];
-        Node var = node[0][0][0];
+        Node var = lambda[0][0];
         new_body = expr::substituteCaptureAvoiding(new_body, var, arg);
       }
       else
       {
         TNode arg = node[1];
-        TNode var = node[0][0][0];
+        TNode var = lambda[0][0];
         new_body = new_body.substitute(var, arg);
       }
       Trace("uf-ho-beta") << "uf-ho-beta : ..new body : " << new_body << "\n";
@@ -221,7 +223,7 @@ Node TheoryUfRewriter::rewriteLambda(Node node)
   // normalization on array constants, and then converting the array constant
   // back to a lambda.
   Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
-  Node anode = FunctionConst::getArrayRepresentationForLambda(node);
+  Node anode = FunctionConst::toArrayConst(node);
   // Only rewrite constant array nodes, since these are the only cases
   // where we require canonicalization of lambdas. Moreover, applying the
   // below code is not correct if the arguments to the lambda occur
@@ -231,26 +233,12 @@ Node TheoryUfRewriter::rewriteLambda(Node node)
   if (!anode.isNull() && anode.isConst())
   {
     Assert(anode.getType().isArray());
-    // must get the standard bound variable list
-    Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType(
-        node.getType());
-    Node retNode =
-        FunctionConst::getLambdaForArrayRepresentation(anode, varList);
-    if (!retNode.isNull() && retNode != node)
-    {
-      Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
-      Trace("builtin-rewrite") << "     input  : " << node << std::endl;
-      Trace("builtin-rewrite")
-          << "     output : " << retNode << ", constant = " << retNode.isConst()
-          << std::endl;
-      Trace("builtin-rewrite")
-          << "  array rep : " << anode << ", constant = " << anode.isConst()
-          << std::endl;
-      Assert(anode.isConst() == retNode.isConst());
-      Assert(retNode.getType() == node.getType());
-      Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
-      return retNode;
-    }
+    Node retNode = NodeManager::currentNM()->mkConst(
+        FunctionArrayConst(node.getType(), anode));
+    Assert(anode.isConst() == retNode.isConst());
+    Assert(retNode.getType() == node.getType());
+    Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
+    return retNode;
   }
   else
   {
index 180504da256c193f2369fb470e4ff70e5735b105..1f9bc7b1437111d27e48f8c7fd48b107cbea4d87 100644 (file)
@@ -176,53 +176,6 @@ TypeNode LambdaTypeRule::computeType(NodeManager* nodeManager,
   return nodeManager->mkFunctionType(argTypes, rangeType);
 }
 
-bool LambdaTypeRule::computeIsConst(NodeManager* nodeManager, TNode n)
-{
-  Assert(n.getKind() == kind::LAMBDA);
-  // get array representation of this function, if possible
-  Node na = FunctionConst::getArrayRepresentationForLambda(n);
-  if (!na.isNull())
-  {
-    Assert(na.getType().isArray());
-    Trace("lambda-const") << "Array representation for " << n << " is " << na
-                          << " " << na.getType() << std::endl;
-    // must have the standard bound variable list
-    Node bvl =
-        NodeManager::currentNM()->getBoundVarListForFunctionType(n.getType());
-    if (bvl == n[0])
-    {
-      // array must be constant
-      if (na.isConst())
-      {
-        Trace("lambda-const") << "*** Constant lambda : " << n;
-        Trace("lambda-const") << " since its array representation : " << na
-                              << " is constant." << std::endl;
-        return true;
-      }
-      else
-      {
-        Trace("lambda-const") << "Non-constant lambda : " << n
-                              << " since array is not constant." << std::endl;
-      }
-    }
-    else
-    {
-      Trace("lambda-const")
-          << "Non-constant lambda : " << n
-          << " since its varlist is not standard." << std::endl;
-      Trace("lambda-const") << "  standard : " << bvl << std::endl;
-      Trace("lambda-const") << "   current : " << n[0] << std::endl;
-    }
-  }
-  else
-  {
-    Trace("lambda-const") << "Non-constant lambda : " << n
-                          << " since it has no array representation."
-                          << std::endl;
-  }
-  return false;
-}
-
 TypeNode FunctionArrayConstTypeRule::computeType(NodeManager* nodeManager,
                                                  TNode n,
                                                  bool check)
index 12fc2d679833ec084a2961fa4cef963ab3ccf647..c75d8c1692b32f77a8799450877cfd481a05379d 100644 (file)
@@ -98,9 +98,6 @@ class LambdaTypeRule
 {
  public:
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
-  // computes whether a lambda is a constant value, via conversion to array
-  // representation
-  static bool computeIsConst(NodeManager* nodeManager, TNode n);
 }; /* class LambdaTypeRule */
 
 /**
@@ -111,7 +108,7 @@ class FunctionArrayConstTypeRule
 {
  public:
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
-}; /* class LambdaTypeRule */
+};
 
 class FunctionProperties
 {
index 84fafa6b8dfd4edcca08575fac85e362879b8e9e..a0151e7779c3fb63444a2a03a3e26b4e35e166bf 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/uf/type_enumerator.h"
 
+#include "expr/function_array_const.h"
 #include "theory/uf/function_const.h"
 
 namespace cvc5::internal {
@@ -27,7 +28,6 @@ FunctionEnumerator::FunctionEnumerator(TypeNode type,
       d_arrayEnum(FunctionConst::getArrayTypeForFunctionType(type), tep)
 {
   Assert(type.getKind() == kind::FUNCTION_TYPE);
-  d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type);
 }
 
 Node FunctionEnumerator::operator*()
@@ -37,7 +37,7 @@ Node FunctionEnumerator::operator*()
     throw NoMoreValuesException(getType());
   }
   Node a = *d_arrayEnum;
-  return FunctionConst::getLambdaForArrayRepresentation(a, d_bvl);
+  return NodeManager::currentNM()->mkConst(FunctionArrayConst(getType(), a));
 }
 
 FunctionEnumerator& FunctionEnumerator::operator++()
index 75ea631de02d9334e3a45c17065765364fdaa583..66f4ba0b8ed2da6119e3105a8076d3924d775c7d 100644 (file)
@@ -45,11 +45,6 @@ class FunctionEnumerator : public TypeEnumeratorBase<FunctionEnumerator>
  private:
   /** Enumerates arrays, which we convert to functions. */
   TypeEnumerator d_arrayEnum;
-  /** The bound variable list for the function type we are enumerating.
-   * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for
-   * some term t.
-   */
-  Node d_bvl;
 }; /* class FunctionEnumerator */
 
 }  // namespace uf