26 #ifndef CASADI_X_FUNCTION_HPP
27 #define CASADI_X_FUNCTION_HPP
30 #include "function_internal.hpp"
31 #include "factory.hpp"
32 #include "serializing_stream.hpp"
35 #include <unordered_map>
36 #define SPARSITY_MAP std::unordered_map
39 #define CASADI_THROW_ERROR(FNAME, WHAT) \
40 throw CasadiException("Error in XFunction::" FNAME " for '" + this->name_ + "' "\
41 "[" + this->class_name() + "] at " + CASADI_WHERE + ":\n"\
56 template<
typename DerivedType,
typename MatType,
typename NodeType>
64 const std::vector<MatType>& ex_in,
65 const std::vector<MatType>& ex_out,
66 const std::vector<std::string>& name_in,
67 const std::vector<std::string>& name_out);
98 std::vector<MatType>
jac(
const Dict& opts)
const;
103 bool is_a(
const std::string& type,
bool recursive)
const override {
109 const std::vector<std::string>& s_in,
110 const std::vector<std::string>& s_out,
112 const Dict& opts)
const override;
121 const std::vector<std::string>& s_out,
122 casadi_int order,
bool tr=
false)
const override;
130 const std::vector<std::string>& inames,
131 const std::vector<std::string>& onames,
132 const Dict& opts)
const override;
141 const std::vector<std::string>& inames,
142 const std::vector<std::string>& onames,
143 const Dict& opts)
const override;
152 const std::vector<std::string>& inames,
153 const std::vector<std::string>& onames,
154 const Dict& opts)
const override;
160 Function slice(
const std::string& name,
const std::vector<casadi_int>& order_in,
161 const std::vector<casadi_int>& order_out,
const Dict& opts)
const override;
177 std::ostream &stream,
const Dict& options)
const override;
183 std::ostream &stream,
const Dict& options)
const = 0;
193 virtual bool isInput(
const std::vector<MatType>& arg)
const;
196 virtual bool should_inline(
bool with_sx,
bool always_inline,
bool never_inline)
const = 0;
202 const std::vector<MatType>& res,
203 const std::vector<std::vector<MatType> >& fseed,
204 std::vector<std::vector<MatType> >& fsens,
205 bool always_inline,
bool never_inline)
const override;
211 const std::vector<MatType>& res,
212 const std::vector<std::vector<MatType> >& aseed,
213 std::vector<std::vector<MatType> >& asens,
214 bool always_inline,
bool never_inline)
const override;
267 template<
typename DerivedType,
typename MatType,
typename NodeType>
270 const std::vector<MatType>& ex_in,
271 const std::vector<MatType>& ex_out,
272 const std::vector<std::string>& name_in,
273 const std::vector<std::string>& name_out)
276 if (!name_in.empty()) {
277 casadi_assert(ex_in.size()==name_in.size(),
278 "Mismatching number of input names");
282 if (!name_out.empty()) {
283 casadi_assert(ex_out.size()==name_out.size(),
284 "Mismatching number of output names");
289 template<
typename DerivedType,
typename MatType,
typename NodeType>
297 template<
typename DerivedType,
typename MatType,
typename NodeType>
300 s.
unpack(
"XFunction::out", out_);
303 template<
typename DerivedType,
typename MatType,
typename NodeType>
306 s.
pack(
"XFunction::out", out_);
309 template<
typename DerivedType,
typename MatType,
typename NodeType>
314 s.
pack(
"XFunction::in", in_);
318 template<
typename DerivedType,
typename MatType,
typename NodeType>
323 bool allow_duplicate_io_names =
false;
325 for (
auto&& op : opts) {
326 if (op.first==
"allow_duplicate_io_names") {
327 allow_duplicate_io_names = op.second;
331 if (verbose_) casadi_message(name_ +
"::init");
333 for (casadi_int i=0; i<n_in_; ++i) {
334 if (in_.at(i).nnz()>0 && !in_.at(i).is_valid_input()) {
335 casadi_error(
"For " + this->name_ +
": Xfunction input arguments must be purely symbolic."
336 "\nArgument " +
str(i) +
"(" + name_in_[i] +
") is not symbolic.");
339 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
340 std::lock_guard<std::mutex> lock(MatType::get_mutex_temp());
344 bool has_duplicates =
false;
345 for (
auto&& i : in_) {
346 if (i.has_duplicates()) {
347 has_duplicates =
true;
352 for (
auto&& i : in_) i.reset_input();
354 if (has_duplicates) {
356 s <<
"The input expressions are not independent:\n";
357 for (casadi_int iind=0; iind<in_.size(); ++iind) {
358 s << iind <<
": " << in_[iind] <<
"\n";
360 casadi_error(s.str());
363 if (!allow_duplicate_io_names) {
365 std::hash<std::string> hasher;
366 std::vector<size_t> iohash;
367 iohash.reserve(name_in_.size() + name_out_.size());
368 for (
const std::string& s : name_in_) iohash.push_back(hasher(s));
369 for (
const std::string& s : name_out_) iohash.push_back(hasher(s));
370 std::sort(iohash.begin(), iohash.end());
373 for (
size_t h : iohash) {
376 std::vector<std::string> io_names;
377 io_names.reserve(iohash.size());
378 for (
const std::string& s : name_in_) io_names.push_back(s);
379 for (
const std::string& s : name_out_) io_names.push_back(s);
380 std::sort(io_names.begin(), io_names.end());
383 for (std::string h : io_names) {
384 if (h == prev) casadi_error(
"Duplicate IO name: " + h +
". "
385 "To ignore this error, set 'allow_duplicate_io_names' option.");
394 template<
typename DerivedType,
typename MatType,
typename NodeType>
396 std::stack<NodeType*>& s, std::vector<NodeType*>& nodes) {
399 NodeType* t = s.top();
401 if (t && t->temp>=0) {
403 casadi_int next_dep = t->temp++;
405 if (next_dep < t->n_dep()) {
407 s.push(
static_cast<NodeType*
>(t->dep(next_dep).get()));
423 template<
typename DerivedType,
typename MatType,
typename NodeType>
428 bool compact =
false;
429 bool symmetric =
false;
430 bool allow_forward =
true;
431 bool allow_reverse =
true;
432 for (
auto&& op : opts) {
433 if (op.first==
"compact") {
435 }
else if (op.first==
"symmetric") {
436 symmetric = op.second;
437 }
else if (op.first==
"allow_forward") {
438 allow_forward = op.second;
439 }
else if (op.first==
"allow_reverse") {
440 allow_reverse = op.second;
441 }
else if (op.first==
"verbose") {
444 casadi_error(
"No such Jacobian option: " + std::string(op.first));
449 std::vector<MatType> ret(n_in_ * n_out_);
452 if (nnz_in() == 0 || nnz_out() == 0) {
453 for (casadi_int i = 0; i < n_out_; ++i) {
454 for (casadi_int j = 0; j < n_in_; ++j) {
456 ret[i * n_in_ + j] = MatType(nnz_out(i), nnz_in(j));
458 ret[i * n_in_ + j] = MatType(numel_out(i), numel_in(j));
466 casadi_int iind = 0, oind = 0;
467 casadi_assert(n_in_ == 1,
"Not implemented");
468 casadi_assert(n_out_ == 1,
"Not implemented");
471 ret.at(0) = MatType::zeros(jac_sparsity(0, 0,
false, symmetric).
T());
472 if (verbose_) casadi_message(
"Allocated return value");
475 if (ret.at(0).nnz()==0) {
476 ret.at(0) = ret.at(0).T();
482 get_partition(iind, oind, D1, D2,
true, symmetric, allow_forward, allow_reverse);
483 if (verbose_) casadi_message(
"Graph coloring completed");
490 casadi_int max_nfdir = max_num_dir_;
491 casadi_int max_nadir = max_num_dir_;
494 casadi_int offset_nfdir = 0, offset_nadir = 0;
497 std::vector<MatType> res(out_);
500 std::vector<std::vector<MatType> > fseed, aseed, fsens, asens;
503 Sparsity jsp = jac_sparsity(0, 0,
true, symmetric).
T();
504 const casadi_int* jsp_colind = jsp.
colind();
505 const casadi_int* jsp_row = jsp.
row();
508 std::vector<casadi_int> input_col = sparsity_in_.at(iind).get_col();
509 const casadi_int* input_row = sparsity_in_.at(iind).row();
512 std::vector<casadi_int> output_col = sparsity_out_.at(oind).get_col();
513 const casadi_int* output_row = sparsity_out_.at(oind).row();
516 if (verbose_) casadi_message(
"jac transposes and mapping");
517 std::vector<casadi_int> mapping;
524 std::vector<casadi_int> nzmap, nzmap2;
527 std::vector<casadi_int> adds, adds2;
530 std::vector<casadi_int> tmp;
533 casadi_int progress = -10;
536 casadi_int nsweep_fwd = nfdir/max_nfdir;
537 if (nfdir%max_nfdir>0) nsweep_fwd++;
538 casadi_int nsweep_adj = nadir/max_nadir;
539 if (nadir%max_nadir>0) nsweep_adj++;
540 casadi_int nsweep = std::max(nsweep_fwd, nsweep_adj);
542 casadi_message(
str(nsweep) +
" sweeps needed for " +
str(nfdir) +
" forward and "
543 +
str(nadir) +
" reverse directions");
547 std::vector<casadi_int> seed_col, seed_row;
550 for (casadi_int s=0; s<nsweep; ++s) {
553 casadi_int progress_new = (s*100)/nsweep;
555 if (progress_new / 10 > progress / 10) {
556 progress = progress_new;
557 casadi_message(
str(progress) +
" %");
562 casadi_int nfdir_batch = std::min(nfdir - offset_nfdir, max_nfdir);
563 casadi_int nadir_batch = std::min(nadir - offset_nadir, max_nadir);
566 fseed.resize(nfdir_batch);
567 for (casadi_int d=0; d<nfdir_batch; ++d) {
573 for (casadi_int el = D1.
colind(offset_nfdir+d); el<D1.
colind(offset_nfdir+d+1); ++el) {
576 casadi_int c = D1.
row(el);
579 seed_col.push_back(input_col[c]);
580 seed_row.push_back(input_row[c]);
584 fseed[d].resize(n_in_);
585 for (casadi_int ind=0; ind<fseed[d].size(); ++ind) {
586 casadi_int nrow = size1_in(ind), ncol = size2_in(ind);
588 fseed[d][ind] = MatType::ones(
Sparsity::triplet(nrow, ncol, seed_row, seed_col));
590 fseed[d][ind] = MatType(nrow, ncol);
596 aseed.resize(nadir_batch);
597 for (casadi_int d=0; d<nadir_batch; ++d) {
603 for (casadi_int el = D2.
colind(offset_nadir+d); el<D2.
colind(offset_nadir+d+1); ++el) {
606 casadi_int c = D2.
row(el);
609 seed_col.push_back(output_col[c]);
610 seed_row.push_back(output_row[c]);
614 aseed[d].resize(n_out_);
615 for (casadi_int ind=0; ind<aseed[d].size(); ++ind) {
616 casadi_int nrow = size1_out(ind), ncol = size2_out(ind);
618 aseed[d][ind] = MatType::ones(
Sparsity::triplet(nrow, ncol, seed_row, seed_col));
620 aseed[d][ind] = MatType(nrow, ncol);
626 fsens.resize(nfdir_batch);
627 for (casadi_int d=0; d<nfdir_batch; ++d) {
629 fsens[d].resize(n_out_);
630 for (casadi_int oind=0; oind<fsens[d].size(); ++oind) {
631 fsens[d][oind] = MatType::zeros(sparsity_out_.at(oind));
636 asens.resize(nadir_batch);
637 for (casadi_int d=0; d<nadir_batch; ++d) {
639 asens[d].resize(n_in_);
640 for (casadi_int ind=0; ind<asens[d].size(); ++ind) {
641 asens[d][ind] = MatType::zeros(sparsity_in_.at(ind));
646 if (!fseed.empty()) {
647 casadi_assert_dev(aseed.empty());
648 if (verbose_) casadi_message(
"Calling 'ad_forward'");
649 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
650 if (verbose_) casadi_message(
"Back from 'ad_forward'");
651 }
else if (!aseed.empty()) {
652 casadi_assert_dev(fseed.empty());
653 if (verbose_) casadi_message(
"Calling 'ad_reverse'");
654 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
655 if (verbose_) casadi_message(
"Back from 'ad_reverse'");
659 for (casadi_int d=0; d<nfdir_batch; ++d) {
661 if (fsens[d][oind].nnz()==0) {
668 tmp.resize(nnz_out(oind));
669 std::fill(tmp.begin(), tmp.end(), 0);
672 for (casadi_int el = D1.
colind(offset_nfdir+d); el<D1.
colind(offset_nfdir+d+1); ++el) {
675 casadi_int c = D1.
row(el);
678 for (casadi_int el_jsp=jsp_colind[c]; el_jsp<jsp_colind[c+1]; ++el_jsp) {
679 tmp[jsp_row[el_jsp]]++;
685 sparsity_out_.at(oind).find(nzmap);
686 fsens[d][oind].sparsity().get_nz(nzmap);
689 sparsity_in_.at(iind).find(nzmap2);
690 fsens[d][oind].sparsity().get_nz(nzmap2);
694 adds.resize(fsens[d][oind].nnz());
695 std::fill(adds.begin(), adds.end(), -1);
697 adds2.resize(adds.size());
698 std::fill(adds2.begin(), adds2.end(), -1);
702 for (casadi_int el = D1.
colind(offset_nfdir+d); el<D1.
colind(offset_nfdir+d+1); ++el) {
705 casadi_int c = D1.
row(el);
712 for (casadi_int el_out = jsp_trans.
colind(c); el_out<jsp_trans.
colind(c+1); ++el_out) {
715 casadi_int r_out = jsp_trans.
row(el_out);
718 casadi_int f_out = nzmap[r_out];
719 if (f_out<0)
continue;
722 casadi_int elJ = mapping[el_out];
726 adds[f_out] = el_out;
737 tmp.resize(adds.size());
739 for (casadi_int i=0; i<adds.size(); ++i) {
749 ret.at(0).nz(adds) = fsens[d][oind].nz(tmp);
753 tmp.resize(adds2.size());
755 for (casadi_int i=0; i<adds2.size(); ++i) {
757 adds2[sz] = adds2[i];
765 ret.at(0).nz(adds2) = fsens[d][oind].nz(tmp);
770 for (casadi_int d=0; d<nadir_batch; ++d) {
772 if (asens[d][iind].nnz()==0) {
777 sparsity_in_.at(iind).find(nzmap);
778 asens[d][iind].sparsity().get_nz(nzmap);
781 for (casadi_int el = D2.
colind(offset_nadir+d); el<D2.
colind(offset_nadir+d+1); ++el) {
784 casadi_int r = D2.
row(el);
787 for (casadi_int elJ = jsp.
colind(r); elJ<jsp.
colind(r+1); ++elJ) {
790 casadi_int inz = jsp.
row(elJ);
793 casadi_int anz = nzmap[inz];
797 ret.at(0).nz(elJ) = asens[d][iind].nz(anz);
803 offset_nfdir += nfdir_batch;
804 offset_nadir += nadir_batch;
808 for (MatType& Jb : ret) Jb = Jb.T();
811 }
catch (std::exception& e) {
812 CASADI_THROW_ERROR(
"jac", e.what());
816 template<
typename DerivedType,
typename MatType,
typename NodeType>
819 const std::vector<std::string>& inames,
820 const std::vector<std::string>& onames,
821 const Dict& opts)
const {
824 std::vector<std::vector<MatType> > fseed = fwd_seed<MatType>(nfwd), fsens;
827 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
828 casadi_assert_dev(fsens.size()==fseed.size());
831 std::vector<MatType> ret_in(inames.size());
832 std::copy(in_.begin(), in_.end(), ret_in.begin());
833 for (casadi_int i=0; i<n_out_; ++i) {
834 ret_in.at(n_in_+i) = MatType::sym(inames[n_in_+i],
Sparsity(out_.at(i).size()));
836 std::vector<MatType> v(nfwd);
837 for (casadi_int i=0; i<n_in_; ++i) {
838 for (casadi_int d=0; d<nfwd; ++d) v[d] = fseed[d][i];
839 ret_in.at(n_in_ + n_out_ + i) = horzcat(v);
843 std::vector<MatType> ret_out(onames.size());
844 for (casadi_int i=0; i<n_out_; ++i) {
845 if (is_diff_out_[i]) {
847 for (casadi_int d=0; d<nfwd; ++d) v[d] = fsens[d][i];
848 ret_out.at(i) = ensure_stacked(horzcat(v), sparsity_out(i), nfwd);
851 ret_out.at(i) = MatType(size1_out(i), size2_out(i) * nfwd);
856 if (opts.find(
"is_diff_in")==opts.end())
857 options[
"is_diff_in"] =
join(is_diff_in_, is_diff_out_, is_diff_in_);
858 if (opts.find(
"is_diff_out")==opts.end())
859 options[
"is_diff_out"] = is_diff_out_;
860 options[
"allow_duplicate_io_names"] =
true;
862 return Function(name, ret_in, ret_out, inames, onames, options);
863 }
catch (std::exception& e) {
864 CASADI_THROW_ERROR(
"get_forward", e.what());
868 template<
typename DerivedType,
typename MatType,
typename NodeType>
871 const std::vector<std::string>& inames,
872 const std::vector<std::string>& onames,
873 const Dict& opts)
const {
876 std::vector<std::vector<MatType> > aseed = symbolicAdjSeed(nadj, out_), asens;
879 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
882 std::vector<MatType> ret_in(inames.size());
883 std::copy(in_.begin(), in_.end(), ret_in.begin());
884 for (casadi_int i=0; i<n_out_; ++i) {
885 ret_in.at(n_in_ + i) = MatType::sym(inames[n_in_+i],
Sparsity(out_.at(i).size()));
887 std::vector<MatType> v(nadj);
888 for (casadi_int i=0; i<n_out_; ++i) {
889 for (casadi_int d=0; d<nadj; ++d) v[d] = aseed[d][i];
890 ret_in.at(n_in_ + n_out_ + i) = horzcat(v);
894 std::vector<MatType> ret_out(onames.size());
895 for (casadi_int i=0; i<n_in_; ++i) {
896 if (is_diff_in_[i]) {
898 for (casadi_int d=0; d<nadj; ++d) v[d] = asens[d][i];
899 ret_out.at(i) = ensure_stacked(horzcat(v), sparsity_in(i), nadj);
902 ret_out.at(i) = MatType(size1_in(i), size2_in(i) * nadj);
907 if (opts.find(
"is_diff_in")==opts.end())
908 options[
"is_diff_in"] =
join(is_diff_in_, is_diff_out_, is_diff_out_);
909 if (opts.find(
"is_diff_out")==opts.end())
910 options[
"is_diff_out"] = is_diff_in_;
912 options[
"allow_duplicate_io_names"] =
true;
914 return Function(name, ret_in, ret_out, inames, onames, options);
915 }
catch (std::exception& e) {
916 CASADI_THROW_ERROR(
"get_reverse", e.what());
920 template<
typename DerivedType,
typename MatType,
typename NodeType>
923 const std::vector<std::string>& inames,
924 const std::vector<std::string>& onames,
925 const Dict& opts)
const {
927 Dict tmp_options = generate_options(
"tmp");
928 tmp_options[
"allow_free"] =
true;
929 tmp_options[
"allow_duplicate_io_names"] =
true;
931 Function tmp(
"flattened_" + name, {veccat(in_)}, {veccat(out_)}, tmp_options);
934 MatType J = tmp.
get<DerivedType>()->jac(
Dict()).at(0);
937 std::vector<casadi_int> r_offset = {0}, c_offset = {0};
938 for (
auto& e : out_) r_offset.push_back(r_offset.back() + e.numel());
939 for (
auto& e : in_) c_offset.push_back(c_offset.back() + e.numel());
940 auto Jblocks = MatType::blocksplit(J, r_offset, c_offset);
943 std::vector<MatType> ret_out;
944 ret_out.reserve(onames.size());
945 for (casadi_int i = 0; i < n_out_; ++i) {
946 for (casadi_int j = 0; j < n_in_; ++j) {
947 MatType b = Jblocks.at(i).at(j);
948 if (!is_diff_out_.at(i) || !is_diff_in_.at(j)) {
949 b = MatType(b.size());
951 ret_out.push_back(b);
956 std::vector<MatType> ret_in(inames.size());
957 std::copy(in_.begin(), in_.end(), ret_in.begin());
958 for (casadi_int i=0; i<n_out_; ++i) {
959 ret_in.at(n_in_+i) = MatType::sym(inames[n_in_+i],
Sparsity(out_.at(i).size()));
963 options[
"allow_free"] =
true;
964 options[
"allow_duplicate_io_names"] =
true;
967 return Function(name, ret_in, ret_out, inames, onames, options);
968 }
catch (std::exception& e) {
969 CASADI_THROW_ERROR(
"get_jacobian", e.what());
973 template<
typename DerivedType,
typename MatType,
typename NodeType>
975 ::slice(
const std::string& name,
const std::vector<casadi_int>& order_in,
976 const std::vector<casadi_int>& order_out,
const Dict& opts)
const {
978 std::vector<MatType> ret_in, ret_out;
979 std::vector<std::string> ret_in_name, ret_out_name;
982 for (casadi_int k : order_in) {
983 ret_in.push_back(in_.at(k));
984 ret_in_name.push_back(name_in_.at(k));
988 for (casadi_int k : order_out) {
989 ret_out.push_back(out_.at(k));
990 ret_out_name.push_back(name_out_.at(k));
994 return Function(name, ret_in, ret_out,
995 ret_in_name, ret_out_name, opts);
998 template<
typename DerivedType,
typename MatType,
typename NodeType>
1002 casadi_assert(!has_free(),
"export_code needs a Function without free variables");
1004 casadi_assert(lang==
"matlab",
"Only matlab language supported for now.");
1007 stream <<
"function [varargout] = " << name_ <<
"(varargin)" << std::endl;
1010 for (casadi_int i=0;i<n_out_;++i) {
1011 stream <<
" argout_" << i <<
" = cell(" << nnz_out(i) <<
",1);" << std::endl;
1015 opts[
"indent_level"] = 1;
1016 export_code_body(lang, stream, opts);
1019 for (casadi_int i=0;i<n_out_;++i) {
1020 const Sparsity& out = sparsity_out_.at(i);
1023 stream <<
" varargout{" << i+1 <<
"} = reshape(vertcat(argout_" << i <<
"{:}), ";
1024 stream << out.
size1() <<
", " << out.
size2() <<
");" << std::endl;
1028 opts[
"name"] =
"sp";
1029 opts[
"indent_level"] = 1;
1030 opts[
"as_matrix"] =
false;
1032 stream <<
" varargout{" << i+1 <<
"} = ";
1033 stream <<
"sparse(sp_i, sp_j, vertcat(argout_" << i <<
"{:}), sp_m, sp_n);" << std::endl;
1038 stream <<
"end" << std::endl;
1039 stream <<
"function y=nonzeros_gen(x)" << std::endl;
1040 stream <<
" if isa(x,'casadi.SX') || isa(x,'casadi.MX') || isa(x,'casadi.DM')" << std::endl;
1041 stream <<
" y = x{:};" << std::endl;
1042 stream <<
" elseif isa(x,'sdpvar')" << std::endl;
1043 stream <<
" b = getbase(x);" << std::endl;
1044 stream <<
" f = find(sum(b~=0,2));" << std::endl;
1045 stream <<
" y = sdpvar(length(f),1,[],getvariables(x),b(f,:));" << std::endl;
1046 stream <<
" else" << std::endl;
1047 stream <<
" y = nonzeros(x);" << std::endl;
1048 stream <<
" end" << std::endl;
1049 stream <<
"end" << std::endl;
1050 stream <<
"function y=if_else_zero_gen(c,e)" << std::endl;
1051 stream <<
" if isa(c+e,'casadi.SX') || isa(c+e,'casadi.MX') "
1052 "|| isa(c+e,'casadi.DM')" << std::endl;
1053 stream <<
" y = if_else(c, e, 0);" << std::endl;
1054 stream <<
" else" << std::endl;
1055 stream <<
" if c" << std::endl;
1056 stream <<
" y = x;" << std::endl;
1057 stream <<
" else" << std::endl;
1058 stream <<
" y = 0;" << std::endl;
1059 stream <<
" end" << std::endl;
1060 stream <<
" end" << std::endl;
1061 stream <<
"end" << std::endl;
1066 template<
typename DerivedType,
typename MatType,
typename NodeType>
1071 const casadi_int checking_depth = 2;
1072 for (casadi_int i=0; i<arg.size(); ++i) {
1073 if (!
is_equal(arg[i], in_[i], checking_depth)) {
1080 template<
typename DerivedType,
typename MatType,
typename NodeType>
1083 const std::vector<MatType>& res,
1084 const std::vector<std::vector<MatType> >& fseed,
1085 std::vector<std::vector<MatType> >& fsens,
1086 bool always_inline,
bool never_inline)
const {
1087 casadi_assert(!(always_inline && never_inline),
"Inconsistent options");
1088 if (!should_inline(MatType::type_name()==
"SX", always_inline, never_inline)) {
1091 always_inline, never_inline);
1095 if (fseed.empty()) {
1103 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
1106 Function f(
"tmp_call_forward", arg, res);
1107 static_cast<DerivedType *
>(f.
get())->ad_forward(fseed, fsens);
1111 template<
typename DerivedType,
typename MatType,
typename NodeType>
1114 const std::vector<MatType>& res,
1115 const std::vector<std::vector<MatType> >& aseed,
1116 std::vector<std::vector<MatType> >& asens,
1117 bool always_inline,
bool never_inline)
const {
1118 casadi_assert(!(always_inline && never_inline),
"Inconsistent options");
1119 if (!should_inline(MatType::type_name()==
"SX", always_inline, never_inline)) {
1122 always_inline, never_inline);
1126 if (aseed.empty()) {
1134 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
1137 Function f(
"tmp_call_reverse", arg, res);
1138 static_cast<DerivedType *
>(f.
get())->ad_reverse(aseed, asens);
1142 template<
typename DerivedType,
typename MatType,
typename NodeType>
1144 factory(
const std::string& name,
1145 const std::vector<std::string>& s_in,
1146 const std::vector<std::string>& s_out,
1148 const Dict& opts)
const {
1150 Dict g_ops = generate_options(
"clone");
1152 f_options[
"helper_options"] = g_ops;
1153 f_options[
"final_options"] = g_ops;
1158 final_options[
"allow_duplicate_io_names"] =
true;
1162 for (casadi_int i=0; i<in_.size(); ++i) f.
add_input(name_in_[i], in_[i], is_diff_in_[i]);
1163 for (casadi_int i=0; i<out_.size(); ++i) f.
add_output(name_out_[i], out_[i], is_diff_out_[i]);
1167 std::vector<std::string> ret_iname;
1168 for (
const std::string& s : s_in) {
1172 casadi_error(
"Cannot process factory input \"" + s +
"\":" + ex.
what());
1177 std::vector<std::string> ret_oname;
1178 for (
const std::string& s : s_out) {
1182 casadi_error(
"Cannot process factory output \"" + s +
"\":" + ex.
what());
1190 std::vector<MatType> ret_in;
1191 ret_in.reserve(s_in.size());
1192 for (
const std::string& s : s_in) ret_in.push_back(f.
get_input(s));
1195 std::vector<MatType> ret_out;
1196 ret_out.reserve(s_out.size());
1197 for (
const std::string& s : s_out) ret_out.push_back(f.
get_output(s));
1200 Dict final_options_allow_free = final_options;
1201 final_options_allow_free[
"allow_free"] =
true;
1202 final_options_allow_free[
"allow_duplicate_io_names"] =
true;
1203 Function ret(name, ret_in, ret_out, ret_iname, ret_oname, final_options_allow_free);
1207 std::vector<MatType> free_in = MatType::get_free(ret);
1208 std::vector<MatType> free_sub = free_in;
1209 for (
auto&& e : free_sub) e = MatType::zeros(e.sparsity());
1210 ret_out = substitute(ret_out, free_in, free_sub);
1211 ret =
Function(name, ret_in, ret_out, ret_iname, ret_oname, final_options);
1216 template<
typename DerivedType,
typename MatType,
typename NodeType>
1218 which_depends(
const std::string& s_in,
const std::vector<std::string>& s_out,
1219 casadi_int order,
bool tr)
const {
1222 auto it = std::find(name_in_.begin(), name_in_.end(), s_in);
1223 casadi_assert_dev(it!=name_in_.end());
1224 MatType arg = in_.at(it-name_in_.begin());
1227 std::vector<MatType> res;
1228 for (
auto&& s : s_out) {
1229 it = std::find(name_out_.begin(), name_out_.end(), s);
1230 casadi_assert_dev(it!=name_out_.end());
1231 res.push_back(out_.at(it-name_out_.begin()));
1235 return MatType::which_depends(veccat(res), arg, order, tr);
1238 template<
typename MatType>
1240 Dict opts{{
"max_io", 0}, {
"allow_free",
true}};
1245 template<
typename MatType>
1247 casadi_int order,
bool tr) {
1249 if (expr.is_empty() || var.is_empty()) {
1250 return std::vector<bool>(tr? expr.numel() : var.numel(),
false);
1256 casadi_assert(order==1 || order==2,
1257 "which_depends: order argument must be 1 or 2, got " +
str(order) +
" instead.");
1259 MatType v = MatType::sym(
"v", var.sparsity());
1260 for (casadi_int i=1;i<order;++i) {
1261 e = jtimes(e, var, v);
1264 Dict opts{{
"max_io", 0}, {
"allow_free",
true}};
1267 std::vector<bvec_t> seed(tr? f.
nnz_in(0) : f.
nnz_out(0), 1);
1268 std::vector<bvec_t> sens(tr? f.
nnz_out(0) : f.
nnz_in(0), 0);
1275 std::vector<bool> ret(sens.size());
1276 std::copy(sens.begin(), sens.end(), ret.begin());
1279 if (tr && e.sparsity()!=expr.sparsity()) {
1282 std::vector<casadi_int> source(sens.size());
1283 std::copy(ret.begin(), ret.end(), source.begin());
1284 std::vector<casadi_int> target(expr.nnz());
1287 std::vector<casadi_int> scratch(expr.size1());
1292 ret.resize(expr.nnz());
1293 std::copy(target.begin(), target.end(), ret.begin());
1301 #undef CASADI_THROW_ERROR
const char * what() const override
Display error.
Helper class for C code generation.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
MatType get_output(const std::string &s)
std::string request_output(const std::string &s)
void add_dual(const Function::AuxOut &aux)
std::string request_input(const std::string &s)
void calculate(const Dict &opts=Dict())
MatType get_input(const std::string &s)
void add_input(const std::string &s, const MatType &e, bool is_diff)
void add_output(const std::string &s, const MatType &e, bool is_diff)
Internal class for Function.
void init(const Dict &opts) override
Initialize.
virtual void call_forward(const std::vector< MX > &arg, const std::vector< MX > &res, const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens, bool always_inline, bool never_inline) const
Forward mode AD, virtual functions overloaded in derived classes.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
virtual void call_reverse(const std::vector< MX > &arg, const std::vector< MX > &res, const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens, bool always_inline, bool never_inline) const
Reverse mode, virtual functions overloaded in derived classes.
virtual bool is_a(const std::string &type, bool recursive) const
Check if the function is of a particular type.
std::vector< std::string > name_out_
std::vector< std::string > name_in_
Input and output scheme.
casadi_int nnz_out() const
Get number of output nonzeros.
FunctionInternal * get() const
int rev(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, int mem=0) const
Propagate sparsity backward.
bool has_free() const
Does the function have free variables.
casadi_int nnz_in() const
Get number of input nonzeros.
const std::vector< Sparsity > & jac_sparsity(bool compact=false) const
Get, if necessary generate, the sparsity of all Jacobian blocks.
std::map< std::string, std::vector< std::string > > AuxOut
bool is_null() const
Is a null pointer?
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
casadi_int size1() const
Get the number of rows.
Sparsity transpose(std::vector< casadi_int > &mapping, bool invert_mapping=false) const
Transpose the matrix and get the reordering of the non-zero entries.
Sparsity T() const
Transpose the matrix.
casadi_int size2() const
Get the number of columns.
const casadi_int * row() const
Get a reference to row-vector,.
void export_code(const std::string &lang, std::ostream &stream=casadi::uout(), const Dict &options=Dict()) const
Export matrix in specific language.
const casadi_int * colind() const
Get a reference to the colindex of all column element (see class description)
bool is_dense() const
Is dense?
static Sparsity triplet(casadi_int nrow, casadi_int ncol, const std::vector< casadi_int > &row, const std::vector< casadi_int > &col, std::vector< casadi_int > &mapping, bool invert_mapping)
Create a sparsity pattern given the nonzeros in sparse triplet form *.
Internal node class for the base class of SXFunction and MXFunction.
std::vector< bool > which_depends(const std::string &s_in, const std::vector< std::string > &s_out, casadi_int order, bool tr=false) const override
Which variables enter with some order.
std::vector< MatType > out_
Outputs of the function (needed for symbolic calculations)
void export_code(const std::string &lang, std::ostream &stream, const Dict &options) const override
Export function in a specific language.
void delayed_deserialize_members(DeserializingStream &s)
bool has_jacobian() const override
Return Jacobian of all input elements with respect to all output elements.
virtual bool should_inline(bool with_sx, bool always_inline, bool never_inline) const =0
void codegen_declarations(CodeGenerator &g) const override=0
Generate code for the declarations of the C function.
bool has_codegen() const override
Is codegen supported?
Function get_jacobian(const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Return Jacobian of all input elements with respect to all output elements.
size_t get_n_out() override
Number of function inputs and outputs.
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nadj adjoint derivatives.
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
void init(const Dict &opts) override
Initialize.
bool has_forward(casadi_int nfwd) const override
Generate a function that calculates nfwd forward derivatives.
void codegen_body(CodeGenerator &g) const override=0
Generate code for the body of the C function.
bool has_reverse(casadi_int nadj) const override
Generate a function that calculates nadj adjoint derivatives.
XFunction(const std::string &name, const std::vector< MatType > &ex_in, const std::vector< MatType > &ex_out, const std::vector< std::string > &name_in, const std::vector< std::string > &name_out)
Constructor.
virtual bool isInput(const std::vector< MatType > &arg) const
Helper function: Check if a vector equals ex_in.
bool has_spfwd() const override
Function slice(const std::string &name, const std::vector< casadi_int > &order_in, const std::vector< casadi_int > &order_out, const Dict &opts) const override
returns a new function with a selection of inputs/outputs of the original
std::vector< MatType > jac(const Dict &opts) const
Construct a complete Jacobian by compression.
std::vector< MatType > in_
Inputs of the function (needed for symbolic calculations)
~XFunction() override
Destructor.
size_t get_n_in() override
Number of function inputs and outputs.
Function factory(const std::string &name, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const Function::AuxOut &aux, const Dict &opts) const override
void delayed_serialize_members(SerializingStream &s) const
Helper functions to avoid recursion limit.
void call_reverse(const std::vector< MatType > &arg, const std::vector< MatType > &res, const std::vector< std::vector< MatType > > &aseed, std::vector< std::vector< MatType > > &asens, bool always_inline, bool never_inline) const override
Create call to (cached) derivative function, reverse mode.
virtual void export_code_body(const std::string &lang, std::ostream &stream, const Dict &options) const =0
Export function body in a specific language.
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nfwd forward derivatives.
XFunction(DeserializingStream &s)
Deserializing constructor.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
bool has_sprev() const override
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
void call_forward(const std::vector< MatType > &arg, const std::vector< MatType > &res, const std::vector< std::vector< MatType > > &fseed, std::vector< std::vector< MatType > > &fsens, bool always_inline, bool never_inline) const override
Create call to (cached) derivative function, forward mode.
static void sort_depth_first(std::stack< NodeType * > &s, std::vector< NodeType * > &nodes)
Topological sorting of the nodes based on Depth-First Search (DFS)
bool is_a(const std::string &type, bool recursive) const override
Check if the function is of a particular type.
bool is_equal(double x, double y, casadi_int depth=0)
std::string join(const std::vector< std::string > &l, const std::string &delim)
std::vector< bool > _which_depends(const MatType &expr, const MatType &var, casadi_int order, bool tr)
void casadi_project(const T1 *x, const casadi_int *sp_x, T1 *y, const casadi_int *sp_y, T1 *w)
Sparse copy: y <- x, w work vector (length >= number of rows)
Sparsity _jacobian_sparsity(const MatType &expr, const MatType &var)
void extract_from_dict_inplace(Dict &d, const std::string &key, T &value)
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
void update_dict(Dict &target, const Dict &source, bool recurse)
Update the target dictionary in place with source elements.
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.