Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/pyoptinterface/cppad_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ ADFunDouble sparse_hessian(const ADFunDouble &f, const sparsity_pattern_t &patte
const std::vector<double> &p_values);

// Transform ExpressionGraph to CppAD function
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph);
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true);
// selected_outputs: indices of outputs to trace, empty means all outputs
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
const std::vector<size_t> &selected_outputs = {});
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true,
const std::vector<size_t> &selected_outputs = {});

struct CppADAutodiffGraph
{
Expand Down
33 changes: 12 additions & 21 deletions include/pyoptinterface/knitro_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,6 @@ struct CallbackEvaluator
std::vector<KNINT> indexCons;

CppAD::ADFun<V> fun;

std::vector<size_t> fun_rows;
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_;
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> jac_;
CppAD::sparse_jac_work jac_work_;
Expand All @@ -163,11 +161,10 @@ struct CallbackEvaluator
void setup()
{
fun.optimize();
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun_rows.size(),
fun_rows.size());
for (size_t k = 0; k < fun_rows.size(); k++)
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_in(fun.Range(), fun.Range(), fun.Range());
for (size_t k = 0; k < fun.Range(); k++)
{
jac_pattern_in.set(k, fun_rows[k], fun_rows[k]);
jac_pattern_in.set(k, k, k);
}
fun.rev_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_);
jac_pattern_in.resize(fun.Domain(), fun.Domain(), fun.Domain());
Expand All @@ -177,11 +174,7 @@ struct CallbackEvaluator
}
CppAD::sparse_rc<std::vector<size_t>> jac_pattern_out;
fun.for_jac_sparsity(jac_pattern_in, false, false, true, jac_pattern_out);
std::vector<bool> select_rows(fun.Range(), false);
for (size_t k = 0; k < fun_rows.size(); k++)
{
select_rows[fun_rows[k]] = true;
}
std::vector<bool> select_rows(fun.Range(), true);
fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_);
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
{
Expand All @@ -205,15 +198,15 @@ struct CallbackEvaluator
x[i] = req_x[indexVars[i]];
}
auto y = fun.Forward(0, x);
for (size_t k = 0; k < fun_rows.size(); k++)
for (size_t k = 0; k < fun.Range(); k++)
{
if (aggregate)
{
res_y[0] += y[fun_rows[k]];
res_y[0] += y[k];
}
else
{
res_y[k] = y[fun_rows[k]];
res_y[k] = y[k];
}
}
}
Expand All @@ -238,15 +231,15 @@ struct CallbackEvaluator
{
x[i] = req_x[indexVars[i]];
}
for (size_t k = 0; k < fun_rows.size(); k++)
for (size_t k = 0; k < fun.Range(); k++)
{
if (aggregate)
{
w[fun_rows[k]] = req_w[0];
w[k] = req_w[0];
}
else
{
w[fun_rows[k]] = req_w[indexCons[k]];
w[k] = req_w[indexCons[k]];
}
}
fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
Expand Down Expand Up @@ -696,14 +689,12 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
}

template <typename T, typename F, typename G, typename H>
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<size_t> &rows,
const std::vector<ConstraintIndex> cons, const T &trace, const F f,
const G g, const H h)
void _add_callback_impl(const ExpressionGraph &graph, const std::vector<ConstraintIndex> cons,
const T &trace, const F f, const G g, const H h)
{
auto evaluator_ptr = std::make_unique<CallbackEvaluator<double>>();
auto *evaluator = evaluator_ptr.get();
evaluator->fun = trace(graph);
evaluator->fun_rows = rows;
evaluator->indexVars.resize(graph.n_variables());
for (size_t i = 0; i < graph.n_variables(); i++)
{
Expand Down
49 changes: 41 additions & 8 deletions lib/cppad_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ CppAD::AD<double> cppad_trace_expression(
return result;
}

ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph,
const std::vector<size_t> &selected_outputs)
{
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;

Expand All @@ -453,13 +454,29 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
}

auto &outputs = graph.m_constraint_outputs;
auto N_outputs = outputs.size();

std::vector<size_t> indices;
if (selected_outputs.empty())
{
indices.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
indices.push_back(i);
}
}
else
{
indices = selected_outputs;
}

auto N_outputs = indices.size();
std::vector<CppAD::AD<double>> y(N_outputs);

// Trace the outputs
// Trace the selected outputs
for (size_t i = 0; i < N_outputs; i++)
{
auto &output = outputs[i];
auto idx = indices[i];
auto &output = outputs[idx];
y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions);
}

Expand All @@ -469,7 +486,8 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
return f;
}

ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate)
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate,
const std::vector<size_t> &selected_outputs)
{
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;

Expand All @@ -493,13 +511,28 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggre
}

auto &outputs = graph.m_objective_outputs;
auto N_outputs = outputs.size();

std::vector<size_t> indices;
if (selected_outputs.empty())
{
indices.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
indices.push_back(i);
}
}
else
{
indices = selected_outputs;
}

auto N_outputs = indices.size();
std::vector<CppAD::AD<double>> y(N_outputs);

// Trace the outputs
for (size_t i = 0; i < N_outputs; i++)
{
auto &output = outputs[i];
auto idx = indices[i];
auto &output = outputs[idx];
y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions);
}

Expand Down
5 changes: 3 additions & 2 deletions lib/cppad_interface_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ NB_MODULE(cppad_interface_ext, m)
.def_ro("jacobian", &CppADAutodiffGraph::jacobian_graph)
.def_ro("hessian", &CppADAutodiffGraph::hessian_graph);

m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints);
m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints, nb::arg("graph"),
nb::arg("selected_outputs") = std::vector<size_t>{});
m.def("cppad_trace_graph_objective", cppad_trace_graph_objective, nb::arg("graph"),
nb::arg("aggregate") = true);
nb::arg("aggregate") = true, nb::arg("selected_outputs") = std::vector<size_t>{});
m.def("cppad_autodiff", &cppad_autodiff);
}
12 changes: 7 additions & 5 deletions lib/knitro_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,10 @@ void KNITROModel::_add_constraint_callback(ExpressionGraph *graph, const Outputs
evaluator->eval_hess(req->x, req->lambda, res->hess);
return 0;
};
auto trace = cppad_trace_graph_constraints;
_add_callback_impl(*graph, outputs.con_idxs, outputs.cons, trace, f, g, h);
auto trace = [outputs](const ExpressionGraph &graph) {
return cppad_trace_graph_constraints(graph, outputs.con_idxs);
};
_add_callback_impl(*graph, outputs.cons, trace, f, g, h);
}

void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs &outputs)
Expand All @@ -881,10 +883,10 @@ void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs
evaluator->eval_hess(req->x, req->sigma, res->hess, true);
return 0;
};
auto trace = [](const ExpressionGraph &graph) {
return cppad_trace_graph_objective(graph, false);
auto trace = [outputs](const ExpressionGraph &graph) {
return cppad_trace_graph_objective(graph, true, outputs.obj_idxs);
};
_add_callback_impl(*graph, outputs.obj_idxs, {}, trace, f, g, h);
_add_callback_impl(*graph, {}, trace, f, g, h);
}

void KNITROModel::_add_callbacks()
Expand Down