26 #ifndef CASADI_FACTORY_HPP
27 #define CASADI_FACTORY_HPP
30 #include "function.hpp"
// Fragment of the Factory<MatType> class template. The class header and the
// remaining members are outside this view; only some declarations are shown.
template<
typename MatType>
// Parse a Jacobian/gradient request. s1 is the text after the "jac:"/"grad:"
// prefix ("<output>:<input>"), s is the full requested output name.
78 Block block(
const std::string& s1,
const std::string& s)
const;
// Parse a Hessian request. s1 is the text after the "hess:" prefix
// ("<output>:<input1>:<input2>"), s is the full requested output name.
81 HBlock hblock(
const std::string& s1,
const std::string& s)
const;
// Register a named input expression; the name must not be taken yet.
84 void add_input(
const std::string& s,
const MatType& e,
bool is_diff);
// Register a named output expression; the name must not be taken yet.
87 void add_output(
const std::string& s,
const MatType& e,
bool is_diff);
// Look up the requested Jacobian block d(output f)/d(input x). Callers
// compare the result against jac_.end() to detect "not requested".
105 std::vector<Block>::iterator
find_jac(
size_t f,
size_t x);
// Look up the requested Hessian block of output f w.r.t. (x1, x2). The order
// of x1/x2 matters; callers try both orders where needed.
108 std::vector<HBlock>::iterator
find_hess(
size_t f,
size_t x1,
size_t x2);
// Index of the input named s (asserts that it exists).
126 size_t imap(
const std::string& s)
const;
// Index of the output named s (asserts that it exists).
129 size_t omap(
const std::string& s)
const;
// Split s at its first ':' into (prefix, remainder); asserts a ':' exists.
141 static std::pair<std::string, std::string>
split_prefix(
const std::string& s);
// Names of the inputs with the given indices.
151 std::vector<std::string>
iname(
const std::vector<size_t>& ind)
const;
// Names of the outputs with the given indices.
155 std::vector<std::string>
oname(
const std::vector<size_t>& ind)
const;
// Register input expression e under the name s. The new input's index is the
// current number of inputs. Registration asserts that s is not already used.
// is_diff is appended to is_diff_in_, which the derivative generators consult
// to decide whether to differentiate w.r.t. this input.
template<
typename MatType>
160 add_input(
const std::string& s,
const MatType& e,
bool is_diff) {
161 size_t ind = in_.size();
// std::map::insert yields {iterator, false} when the key already exists.
162 auto it = imap_.insert(std::make_pair(s, ind));
163 casadi_assert(it.second,
"Duplicate input expression \"" + s +
"\"");
164 is_diff_in_.push_back(is_diff);
// Register output expression e under the name s. The new output's index is
// the current number of outputs. Registration asserts that s is not already
// used. is_diff is appended to is_diff_out_, which the derivative generators
// consult to decide whether this output is differentiated.
template<
typename MatType>
171 add_output(
const std::string& s,
const MatType& e,
bool is_diff) {
172 size_t ind = out_.size();
// std::map::insert yields {iterator, false} when the key already exists.
173 auto it = omap_.insert(std::make_pair(s, ind));
174 casadi_assert(it.second,
"Duplicate output expression \"" + s +
"\"");
175 is_diff_out_.push_back(is_diff);
// Resolve a requested input name. The lines before the assertion are not
// visible here. The visible part requires s to carry a "<prefix>:" and lists
// the available inputs otherwise. By prefix:
//   "fwd:<input>"  -> that input's index is queued in fwd_in_ (forward seed),
//   "adj:<output>" -> that output's index is queued in adj_in_ (adjoint seed).
// The tail replaces every ':' in ret with '_'. How ret is built is not
// visible in this fragment.
template<
typename MatType>
186 casadi_assert(
has_prefix(s),
"Cannot process \"" + s +
"\" as input."
187 " Available: " +
join(iname()) +
".");
188 std::pair<std::string, std::string> ss = split_prefix(s);
190 if (ss.first==
"fwd") {
192 fwd_in_.push_back(imap(ss.second));
193 }
else if (ss.first==
"adj") {
// Adjoint seeds are attached to outputs, hence omap rather than imap.
195 adj_in_.push_back(omap(ss.second));
200 std::replace(ret.begin(), ret.end(),
':',
'_');
// Resolve a requested output name. A name that is already a registered output
// is returned unchanged. Otherwise s must carry a "<prefix>:" (the error lists
// the available outputs). By prefix:
//   "fwd:<output>" -> queue that output in fwd_out_ (forward sensitivity),
//   "adj:<input>"  -> queue that input in adj_out_ (adjoint sensitivity),
//   "jac:..."      -> record a Jacobian block parsed by block(),
//   "grad:..."     -> record a gradient block parsed by block(),
//   "hess:..."     -> record a Hessian block parsed by hblock().
// request_output(ss.second) is also invoked recursively on the remainder; its
// enclosing branch is not visible in this fragment. On return every ':' in
// ret is replaced by '_'.
template<
typename MatType>
208 if (has_out(s))
return s;
211 casadi_assert(
has_prefix(s),
"Cannot process \"" + s +
"\" as output."
212 " Available: " +
join(oname()) +
".");
213 std::pair<std::string, std::string> ss = split_prefix(s);
215 if (ss.first==
"fwd") {
216 fwd_out_.push_back(omap(ss.second));
217 }
else if (ss.first==
"adj") {
// Adjoint sensitivities are w.r.t. inputs, hence imap rather than omap.
218 adj_out_.push_back(imap(ss.second));
219 }
else if (ss.first==
"jac") {
220 jac_.push_back(block(ss.second, s));
221 }
else if (ss.first==
"grad") {
222 grad_.push_back(block(ss.second, s));
223 }
else if (ss.first==
"hess") {
224 hess_.push_back(hblock(ss.second, s));
227 request_output(ss.second);
// Qualified as in request_input: the unqualified call resolved only via ADL,
// since casadi::replace(string, string, string) cannot take iterators. It
// breaks if std::string::iterator is a plain pointer.
232 std::replace(ret.begin(), ret.end(),
':',
'_');
// Generate the outputs requested with "fwd:" (forward-mode sensitivities).
// It returns immediately if no forward outputs were requested. Otherwise at
// least one "fwd:" input must have been requested.
// - For every seeded input, a symbolic seed "fwd_<name>" is created and
//   registered as the new input "fwd:<name>".
// - forward() is evaluated once, with always_inline forced to true.
// - Each sensitivity is projected and registered as output "fwd:<name>".
template<
typename MatType>
238 if (fwd_out_.empty())
return;
239 casadi_assert_dev(!fwd_in_.empty());
241 std::vector<MatType> arg, res;
// One forward direction: seed[0] / sens[0] hold one entry per arg / res.
242 std::vector<std::vector<MatType>> seed(1), sens(1);
244 for (
size_t iind : fwd_in_) {
245 arg.push_back(in_[iind]);
// Non-differentiable inputs get Sparsity(size) instead of their own pattern.
// Presumably that is structurally zero; TODO confirm Sparsity(pair) semantics.
246 Sparsity sp = is_diff_in_.at(iind) ? arg.back().sparsity() :
Sparsity(arg.back().size());
247 seed[0].push_back(MatType::sym(
"fwd_" + iname_[iind], sp));
248 add_input(
"fwd:" + iname_[iind], seed[0].back(),
true);
251 for (
size_t oind : fwd_out_) res.push_back(out_.at(oind));
// Local copy so the caller's options are not modified.
253 Dict local_opts = opts;
254 local_opts[
"always_inline"] =
true;
255 sens = forward(res, arg, seed, local_opts);
258 for (
size_t i = 0; i < fwd_out_.size(); ++i) {
259 std::string s = oname_.at(fwd_out_[i]);
260 Sparsity sp = is_diff_out_.at(fwd_out_[i]) ? res.at(i).sparsity()
262 add_output(
"fwd:" + s, project(sens[0].at(i), sp), is_diff_out_.at(fwd_out_[i]));
// Generate the outputs requested with "adj:" (reverse-mode sensitivities).
// It returns immediately if no adjoint outputs were requested. Otherwise at
// least one "adj:" input must have been requested.
// - For every seeded output, a symbolic seed "adj_<name>" is created and
//   registered as the new input "adj:<name>".
// - reverse() is evaluated once, with always_inline forced to true.
// - Each sensitivity w.r.t. an input is projected and registered as output
//   "adj:<input name>".
template<
typename MatType>
268 if (adj_out_.empty())
return;
269 casadi_assert_dev(!adj_in_.empty());
270 std::vector<MatType> arg, res;
// One adjoint direction: seed[0] / sens[0] hold one entry per res / arg.
271 std::vector<std::vector<MatType>> seed(1), sens(1);
273 for (
size_t ind : adj_out_) arg.push_back(in_[ind]);
275 for (
size_t ind : adj_in_) {
276 res.push_back(out_.at(ind));
// Non-differentiable outputs get Sparsity(size) instead of their own pattern.
// Presumably that is structurally zero; TODO confirm Sparsity(pair) semantics.
277 Sparsity sp = is_diff_out_.at(ind) ? res.back().sparsity() :
Sparsity(res.back().size());
278 seed[0].push_back(MatType::sym(
"adj_" + oname_[ind], sp));
279 add_input(
"adj:" + oname_[ind], seed[0].back(),
true);
// local_opts is declared on a line not visible in this fragment.
283 local_opts[
"always_inline"] =
true;
284 sens =
reverse(res, arg, seed, local_opts);
287 for (
size_t i=0; i < adj_out_.size(); ++i) {
288 std::string s = iname_[adj_out_[i]];
289 Sparsity sp = is_diff_in_.at(adj_out_[i]) ? arg.at(i).sparsity() :
Sparsity(arg.at(i).size());
290 add_output(
"adj:" + s, project(sens[0].at(i), sp), is_diff_in_.at(adj_out_[i]));
// Linear search of the requested Jacobian blocks for d(output f)/d(input x).
// It returns an iterator to the first match. The not-found return is not
// visible here; callers compare the result against jac_.end().
template<
typename MatType>
296 for (std::vector<Block>::iterator it = jac_.begin(); it != jac_.end(); ++it) {
297 if (it->f == f && it->x == x)
return it;
// Linear search of the requested Hessian blocks for output f w.r.t. inputs
// (x1, x2), in exactly that order. It returns an iterator to the first match.
// The not-found return is not visible here; callers compare against
// hess_.end().
template<
typename MatType>
305 for (std::vector<HBlock>::iterator it = hess_.begin(); it != hess_.end(); ++it) {
306 if (it->f == f && it->x1 == x1 && it->x2 == x2)
return it;
// Generate all requested Jacobian blocks. Blocks that can share a single
// jacobian() call are grouped: several outputs w.r.t. several inputs are
// computed in one call and split afterwards. Errors are re-raised with the
// outputs/inputs involved.
template<
typename MatType>
// First pass: differentiable (f, x) blocks are marked as still pending.
// The MatType(numel_f, numel_x) output with is_diff = false presumably
// belongs to the non-differentiable branch; line 318 is not visible.
315 for (
auto &&b : jac_) {
316 if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x)) {
317 b.calculated =
false;
319 add_output(b.s, MatType(out_[b.f].numel(), in_[b.x].numel()),
false);
// Second pass: compute every pending block, possibly together with others.
324 for (
auto &&b : jac_) {
326 if (b.calculated)
continue;
// all_f: every output that has a pending block w.r.t. the same input b.x.
328 std::vector<size_t> all_f;
329 for (
auto &&b1 : jac_) {
330 if (b1.x == b.x && !b1.calculated) all_f.push_back(b1.f);
// all_x: starts with b.x. Another input is considered only if it is not in
// all_x yet and passes the per-output check over all_f below. Part of that
// check is not visible here.
333 std::vector<size_t> all_x{b.x};
334 for (
auto &&b1 : jac_) {
336 if (std::count(all_x.begin(), all_x.end(), b1.x))
continue;
339 for (
size_t f1 : all_f) {
340 auto it = find_jac(f1, b1.x);
341 if (it == jac_.end() || it->calculated) {
348 all_x.push_back(b1.x);
// A single block is computed directly.
352 if (all_f.size() == 1 && all_x.size() == 1) {
354 add_output(b.s, MatType::jacobian(out_[b.f], in_[b.x], opts),
true);
// Grouped case: one Jacobian of veccat(f) w.r.t. veccat(x), split back into
// (output, input) pieces. Only pieces that were requested are registered.
358 std::sort(all_x.begin(), all_x.end());
359 std::sort(all_f.begin(), all_f.end());
361 std::vector<MatType> x(all_x.size()), f(all_f.size());
362 for (
size_t i = 0; i < x.size(); ++i) x[i] = in_.at(all_x[i]);
363 for (
size_t i = 0; i < f.size(); ++i) f[i] = out_.at(all_f[i]);
365 MatType J = MatType::jacobian(veccat(f), veccat(x), opts);
367 std::vector<std::vector<MatType>> J_all = blocksplit(J, offset(f), offset(x));
369 for (
size_t i = 0; i < all_f.size(); ++i) {
370 for (
size_t j = 0; j < all_x.size(); ++j) {
371 auto J_it = find_jac(all_f[i], all_x[j]);
372 if (J_it != jac_.end()) {
373 add_output(J_it->s, J_all.at(i).at(j),
true);
374 J_it->calculated =
true;
// Re-raise with context (the matching try is not visible in this fragment).
379 }
catch (std::exception& e) {
380 std::stringstream ss;
381 ss <<
"Calculating Jacobian of " << oname(all_f) <<
" w.r.t. " << iname(all_x)
383 casadi_error(ss.str());
// Generate all requested gradient outputs, named "grad:<output>:<input>".
// - If both output and input are differentiable, gradient(ex, arg, opts) is
//   projected onto the input's sparsity.
// - Otherwise the block first asserts that ex is scalar. It then registers a
//   1-by-numel(arg) MatType output with is_diff = false.
template<
typename MatType>
390 for (
auto &&b : grad_) {
391 const MatType& ex = out_.at(b.f);
392 const MatType& arg = in_[b.x];
393 if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x)) {
394 add_output(
"grad:" + oname_[b.f] +
":" + iname_[b.x],
395 project(gradient(ex, arg, opts), arg.sparsity()),
true);
397 casadi_assert(ex.is_scalar(),
"Can only take gradient of scalar expression.");
398 add_output(
"grad:" + oname_[b.f] +
":" + iname_[b.x], MatType(1, arg.numel()),
false);
// Compute the pending Hessian blocks of the single output f. Blocks that can
// share one differentiation call are grouped:
// - a single block uses hessian() or jacobian(gradient()),
// - a group uses one call over vertcat'ed inputs, split with blocksplit.
// Errors are re-raised with the output and input sets involved.
template<
typename MatType>
406 for (
auto &&b : hess_) {
407 if (b.f != f)
continue;
409 if (b.calculated)
continue;
// all_x1: inputs collected from pending blocks related to b. The full
// selection condition is partly on lines not visible here.
411 std::vector<size_t> all_x1;
412 for (
auto &&b1 : hess_) {
413 if (b1.f == b.f && !b1.calculated) {
416 all_x1.push_back(b1.x2);
417 }
else if (b1.x2 == b.x1) {
419 all_x1.push_back(b1.x1);
// all_x2: candidates come from either slot (x1 or x2) of pending blocks of f.
// A candidate is accepted only if cand_ok and other_ok both hold; cand_ok's
// declaration and updates are not visible here. The search tries both
// (a, cand) and (cand, a) orderings.
424 std::vector<size_t> all_x2;
425 for (
auto &&b1 : hess_) {
426 if (b1.f != f || b1.calculated)
continue;
428 for (
bool test_x1 : {
false,
true}) {
429 size_t cand = test_x1 ? b1.x1 : b1.x2;
430 size_t other = test_x1 ? b1.x2 : b1.x1;
432 bool other_ok =
false;
434 if (std::count(all_x2.begin(), all_x2.end(), cand))
continue;
436 for (
size_t a : all_x1) {
438 if (other == a) other_ok =
true;
440 auto it = find_hess(f, a, cand);
441 if (it == hess_.end() || it->calculated) {
444 it = find_hess(f, cand, a);
445 if (it != hess_.end() && !it->calculated)
continue;
453 if (cand_ok && other_ok) all_x2.push_back(cand);
// A single block is computed directly: hessian() on the diagonal,
// jacobian(gradient()) otherwise.
// NOTE(review): opts is passed to hessian() but not to gradient()/jacobian()
// here. Confirm this is intentional.
458 if (all_x1.size() == 1 && all_x2.size() == 1) {
460 MatType H = b.x1 == b.x2 ? hessian(out_.at(f), in_[b.x1], opts)
461 : jacobian(gradient(out_.at(f), in_[b.x1]), in_[b.x2]);
462 add_output(b.s, H,
true);
// Grouped case. If both sorted index sets coincide, the combined matrix is a
// (symmetric) Hessian over the same inputs.
466 std::sort(all_x1.begin(), all_x1.end());
467 std::sort(all_x2.begin(), all_x2.end());
469 bool symmetric = all_x1 == all_x2;
471 std::vector<MatType> x1(all_x1.size()), x2(all_x2.size());
472 for (
size_t i = 0; i < x1.size(); ++i) x1[i] = in_.at(all_x1[i]);
473 for (
size_t i = 0; i < x2.size(); ++i) x2[i] = in_.at(all_x2[i]);
// H's declaration and the branch choosing between these assignments are not
// visible.
// NOTE(review): opts is not forwarded to hessian()/gradient()/jacobian() here.
477 H = hessian(out_.at(f), vertcat(x1));
479 H = jacobian(gradient(out_.at(f), vertcat(x1)), vertcat(x2));
482 std::vector<std::vector<MatType>> H_all = blocksplit(H, offset(x1), offset(x2));
// Distribute the pieces. A pending block (x1, x2) found in the grid takes its
// piece directly. If the grid is not symmetric, a block found as (x2, x1)
// takes the transposed piece.
484 for (
auto &&b1 : hess_) {
485 if (b1.f == f && !b1.calculated) {
487 auto it_x1 = std::find(all_x1.begin(), all_x1.end(), b1.x1);
488 auto it_x2 = std::find(all_x2.begin(), all_x2.end(), b1.x2);
489 if (it_x1 != all_x1.end() && it_x2 != all_x2.end()) {
491 const MatType& Hb = H_all.at(it_x1 - all_x1.begin()).at(it_x2 - all_x2.begin());
492 add_output(b1.s, Hb,
true);
493 b1.calculated =
true;
494 }
else if (!symmetric) {
496 it_x1 = std::find(all_x1.begin(), all_x1.end(), b1.x2);
497 it_x2 = std::find(all_x2.begin(), all_x2.end(), b1.x1);
498 if (it_x1 != all_x1.end() && it_x2 != all_x2.end()) {
500 const MatType& Hb = H_all.at(it_x1 - all_x1.begin()).at(it_x2 - all_x2.begin());
501 add_output(b1.s, Hb.T(),
true);
502 b1.calculated =
true;
// Re-raise with context (the matching try is not visible in this fragment).
508 }
catch (std::exception& e) {
509 std::stringstream ss;
510 ss <<
"Calculating Hessian of " << oname_.at(f) <<
" w.r.t. " << iname(all_x1) <<
" and "
511 << iname(all_x2) <<
": " << e.what();
512 casadi_error(ss.str());
// Generate all requested Hessian blocks.
// - First pass: blocks whose output and both inputs are differentiable are
//   marked pending. A numel(x1)-by-numel(x2) MatType output with
//   is_diff = false is registered on the non-differentiable path; the exact
//   branch structure is not visible.
// - The output is asserted to be scalar; the enclosing lines are not visible.
// - Second pass: every still-pending block triggers calculate_hess(opts, f)
//   for its output. That call also resolves the other blocks of that output.
template<
typename MatType>
520 for (
auto &&b : hess_) {
521 if (is_diff_out_.at(b.f) && is_diff_in_.at(b.x1) && is_diff_in_.at(b.x2)) {
522 b.calculated =
false;
524 add_output(b.s, MatType(in_[b.x1].numel(), in_[b.x2].numel()),
false);
528 casadi_assert(out_.at(b.f).is_scalar(),
529 "Can only take Hessian of scalar expression.");
532 for (
auto &&b : hess_) {
534 if (b.calculated)
continue;
536 calculate_hess(opts, b.f);
// Add dual (multiplier) inputs and aggregated outputs.
// - For every output k, a symbolic multiplier "lam_<name>" is registered as
//   the input "lam:<name>".
// - For each aggregate entry i, presumably an element of the Function::AuxOut
//   map (new output name -> list of output names; the loop header is not
//   visible, TODO confirm), lc accumulates dot(lam:j, j) over the listed
//   outputs j. The result is registered as output i.first.
template<
typename MatType>
543 for (
size_t k = 0; k < out_.size(); ++k) {
// Non-differentiable outputs get Sparsity(size) instead of their own pattern.
544 Sparsity sp = is_diff_out_[k] ? out_.at(k).sparsity() :
Sparsity(out_.at(k).size());
545 add_input(
"lam:" + oname_[k], MatType::sym(
"lam_" + oname_[k], sp),
true);
// NOTE(review): `auto j` copies each name string; `const auto& j` would avoid it.
550 for (
auto j : i.second) {
551 lc +=
dot(in_.at(imap_.at(
"lam:" + j)), out_.at(omap_.at(j)));
553 add_output(i.first, lc,
true);
// Generate all requested derivative outputs, stage by stage.
// Visible stages, in order: forward AD, reverse AD, Jacobians, gradients,
// Hessians. Each stage's std::exception is caught and re-raised via
// casadi_error, prefixed with a stage-specific message. The stage calls for
// forward/reverse/Jacobian and the try lines are not visible in this fragment.
template<
typename MatType>
562 }
catch (std::exception& e) {
563 casadi_error(
"Forward mode AD failed:\n" +
str(e.what()));
569 }
catch (std::exception& e) {
570 casadi_error(
"Reverse mode AD failed:\n" +
str(e.what()));
576 }
catch (std::exception& e) {
577 casadi_error(
"Jacobian generation failed:\n" +
str(e.what()));
582 calculate_grad(opts);
583 }
catch (std::exception& e) {
584 casadi_error(
"Gradient generation failed:\n" +
str(e.what()));
589 calculate_hess(opts);
590 }
catch (std::exception& e) {
591 casadi_error(
"Hessian generation failed:\n" +
str(e.what()));
// Return (by value) the registered input expression named s; asserts it exists.
template<
typename MatType>
597 auto it = imap_.find(s);
598 casadi_assert(it!=imap_.end(),
"Cannot retrieve \"" + s +
"\"");
599 return in_.at(it->second);
// Return the output expression named s.
// A registered output is returned directly. Otherwise s must have the form
// "<attribute>:<name>": output <name> is fetched recursively into r and the
// attribute is dispatched on. Recognised attributes: transpose, triu, tril,
// densify, sym (deprecated, emits a warning) and withdiag. The branch bodies
// are not visible here. The final casadi_error reports an unrecognised
// attribute; its enclosing branch is not visible.
template<
typename MatType>
605 auto it = omap_.find(s);
606 if (it!=omap_.end())
return out_.at(it->second);
609 casadi_assert(
has_prefix(s),
"Cannot process \"" + s +
"\"");
610 std::pair<std::string, std::string> ss = split_prefix(s);
611 std::string a = ss.first;
612 MatType r = get_output(ss.second);
615 if (a==
"transpose") {
617 }
else if (a==
"triu") {
619 }
else if (a==
"tril") {
621 }
else if (a==
"densify") {
623 }
else if (a==
"sym") {
624 casadi_warning(
"Attribute 'sym' has been deprecated. Hessians are symmetric by default.");
626 }
else if (a==
"withdiag") {
629 casadi_error(
"Cannot process attribute \"" + a +
"\"");
// True iff s contains a ':'. find() returns npos when the character is
// absent, and npos is never < size().
template<
typename MatType>
637 return s.find(
':') < s.size();
// Split s at its first ':' into (text before, text after). Asserts that s is
// non-empty and contains a ':'.
template<
typename MatType>
644 casadi_assert_dev(!s.empty());
645 size_t pos = s.find(
':');
646 casadi_assert(pos<s.size(),
"Cannot process \"" + s +
"\"");
647 return std::make_pair(s.substr(0, pos), s.substr(pos+1, std::string::npos));
// Map input indices to input names; at() throws on an out-of-range index.
template<
typename MatType>
652 std::vector<std::string> ret;
653 for (
size_t i : ind) ret.push_back(iname_.at(i));
// Map output indices to output names; at() throws on an out-of-range index.
template<
typename MatType>
659 std::vector<std::string> ret;
660 for (
size_t i : ind) ret.push_back(oname_.at(i));
// Index of the input named s. Asserts that s is a registered input. The
// error message lists the available *input* names (the list was previously
// taken from oname(), which showed outputs in an input error).
template<
typename MatType>
666 auto iind = imap_.find(s);
667 casadi_assert(iind != imap_.end(),
668 "Cannot process \"" + s +
"\" as input. Available: " +
join(iname()) +
".");
// Index of the output named s. Asserts that s is a registered output; the
// error message lists the available output names.
template<
typename MatType>
674 auto oind = omap_.find(s);
675 casadi_assert(oind != omap_.end(),
676 "Cannot process \"" + s +
"\" as output. Available: " +
join(oname()) +
".");
// Parse the "<output>:<input>" part of a Jacobian/gradient request (named s2
// in the visible lines) into the block's output index f and input index x.
// The text before the first ':' is resolved with omap and the rest with imap;
// both assert that the name exists. Handling of a missing ':' is not visible.
template<
typename MatType>
684 size_t pos = s2.find(
':');
685 if (pos < s2.size()) {
686 b.
f = omap(s2.substr(0, pos));
687 b.
x = imap(s2.substr(pos+1, std::string::npos));
// Parse the "<output>:<input1>:<input2>" part of a Hessian request (named s2
// in the visible lines):
// - the text before the first ':' is the output f (via omap),
// - the text between the first and second ':' is x1 (via imap),
// - the rest is x2 (via imap).
// Handling of missing separators is not visible in this fragment.
template<
typename MatType>
696 size_t pos1 = s2.find(
':');
697 if (pos1 < s2.size()) {
698 size_t pos2 = s2.find(
':', pos1 + 1);
699 if (pos2 < s2.size()) {
700 b.
f = omap(s2.substr(0, pos1));
701 b.
x1 = imap(s2.substr(pos1 + 1, pos2 - pos1 - 1));
702 b.
x2 = imap(s2.substr(pos2 + 1, std::string::npos));
bool has_in(const std::string &s) const
std::vector< HBlock >::iterator find_hess(size_t f, size_t x1, size_t x2)
void calculate_fwd(const Dict &opts)
std::vector< std::string > iname_
std::map< std::string, size_t > imap_
std::vector< size_t > fwd_out_
std::vector< MatType > out_
std::vector< bool > is_diff_out_
std::vector< Block >::iterator find_jac(size_t f, size_t x)
static std::pair< std::string, std::string > split_prefix(const std::string &s)
MatType get_output(const std::string &s)
std::vector< Block > jac_
std::string request_output(const std::string &s)
std::vector< bool > is_diff_in_
bool has_out(const std::string &s) const
size_t imap(const std::string &s) const
void calculate_adj(const Dict &opts)
void add_dual(const Function::AuxOut &aux)
std::vector< size_t > fwd_in_
const std::vector< std::string > & iname() const
std::string request_input(const std::string &s)
std::vector< size_t > adj_in_
std::vector< MatType > in_
void calculate_grad(const Dict &opts)
void calculate(const Dict &opts=Dict())
void calculate_jac(const Dict &opts)
void calculate_hess(const Dict &opts, size_t f)
Block block(const std::string &s1, const std::string &s) const
MatType get_input(const std::string &s)
static bool has_prefix(const std::string &s)
HBlock hblock(const std::string &s1, const std::string &s) const
std::vector< Block > grad_
std::vector< std::string > oname_
std::map< std::string, size_t > omap_
void add_input(const std::string &s, const MatType &e, bool is_diff)
size_t omap(const std::string &s) const
void add_output(const std::string &s, const MatType &e, bool is_diff)
std::vector< size_t > adj_out_
std::vector< HBlock > hess_
const std::vector< std::string > & oname() const
std::map< std::string, std::vector< std::string > > AuxOut
static Sparsity diag(casadi_int nrow)
Create diagonal sparsity pattern.
std::string join(const std::vector< std::string > &l, const std::string &delim)
bool has_prefix(const std::string &s)
CASADI_EXPORT std::string replace(const std::string &s, const std::string &p, const std::string &r)
Replace all occurrences of p with r in s.
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
T dot(const std::vector< T > &a, const std::vector< T > &b)
std::vector< T > reverse(const std::vector< T > &v)
Reverse a list.