26 #ifndef CASADI_X_FUNCTION_HPP
27 #define CASADI_X_FUNCTION_HPP
30 #include "function_internal.hpp"
31 #include "factory.hpp"
32 #include "serializing_stream.hpp"
35 #include <unordered_map>
// Map template alias used by this header; expands to std::unordered_map.
// NOTE(review): presumably intended for sparsity-related lookups (per its
// name); the uses are not visible in this listing.
36 #define SPARSITY_MAP std::unordered_map
// Throws a CasadiException whose message starts with
// "Error in XFunction::<FNAME> for '<name_>' [<class_name()>] at <CASADI_WHERE>:".
// It dereferences `this`, so it is only usable inside XFunction member
// functions. The rest of the macro (where WHAT is appended) continues past
// the lines visible here.
39 #define CASADI_THROW_ERROR(FNAME, WHAT) \
40 throw CasadiException("Error in XFunction::" FNAME " for '" + this->name_ + "' "\
41 "[" + this->class_name() + "] at " + CASADI_WHERE + ":\n"\
/// XFunction: shared base class template for expression-graph functions.
/// DerivedType is the concrete subclass: member definitions in this file call
/// ad_forward/ad_reverse on it through static_cast<const DerivedType*>(this).
/// MatType is the symbolic matrix type stored in in_/out_, and NodeType is
/// the graph node type handled by sort_depth_first.
56 template<
typename DerivedType,
typename MatType,
typename NodeType>
57 class CASADI_EXPORT XFunction :
public FunctionInternal {
/// Construct from input expressions ex_in and output expressions ex_out.
/// name_in / name_out may be empty. When non-empty, their lengths must match
/// ex_in / ex_out (asserted in the definition).
63 XFunction(
const std::string& name,
64 const std::vector<MatType>& ex_in,
65 const std::vector<MatType>& ex_out,
66 const std::vector<std::string>& name_in,
67 const std::vector<std::string>& name_out);
/// Destructor (body not visible in this listing).
72 ~XFunction()
override {
/// Initialize. Checks that the inputs are purely symbolic and independent.
/// Unless option allow_duplicate_io_names is set, also checks that IO names
/// are unique.
78 void init(
const Dict& opts)
override;
/// Forward sparsity propagation is always available.
82 bool has_spfwd()
const override {
return true;}
/// Reverse sparsity propagation is always available.
83 bool has_sprev()
const override {
return true;}
/// Depth-first ordering helper driven by the explicit stack s. It uses each
/// node's temp field as a cursor over that node's dependencies.
/// nodes is the output list (the appending code is not visible here).
89 static void sort_depth_first(std::stack<NodeType*>& s, std::vector<NodeType*>& nodes);
/// Jacobian blocks computed with graph coloring plus batched forward/reverse AD.
/// Recognized options: compact, symmetric, allow_forward, allow_reverse, verbose.
94 std::vector<MatType> jac(
const Dict& opts)
const;
/// True for "xfunction". When recursive, also true for anything
/// FunctionInternal::is_a accepts.
99 bool is_a(
const std::string& type,
bool recursive)
const override {
100 return type==
"xfunction" || (recursive && FunctionInternal::is_a(type, recursive));
/// Build a new Function from named inputs s_in / outputs s_out, processed by
/// the Factory helper.
104 Function factory(
const std::string& name,
105 const std::vector<std::string>& s_in,
106 const std::vector<std::string>& s_out,
108 const Dict& opts)
const override;
/// Dependency pattern of outputs s_out on input s_in; forwarded to
/// MatType::which_depends in the definition.
116 std::vector<bool> which_depends(
const std::string& s_in,
117 const std::vector<std::string>& s_out,
118 casadi_int order,
bool tr=
false)
const override;
/// Forward derivatives are always available.
124 bool has_forward(casadi_int nfwd)
const override {
return true;}
/// Function that computes nfwd forward directional derivatives.
125 Function get_forward(casadi_int nfwd,
const std::string& name,
126 const std::vector<std::string>& inames,
127 const std::vector<std::string>& onames,
128 const Dict& opts)
const override;
/// Reverse derivatives are always available.
135 bool has_reverse(casadi_int nadj)
const override {
return true;}
/// Function that computes nadj adjoint (reverse) directional derivatives.
136 Function get_reverse(casadi_int nadj,
const std::string& name,
137 const std::vector<std::string>& inames,
138 const std::vector<std::string>& onames,
139 const Dict& opts)
const override;
/// Jacobians are always available.
146 bool has_jacobian()
const override {
return true;}
/// Function returning all output/input Jacobian blocks.
147 Function get_jacobian(
const std::string& name,
148 const std::vector<std::string>& inames,
149 const std::vector<std::string>& onames,
150 const Dict& opts)
const override;
/// New Function whose inputs/outputs are selected and reordered by the index
/// lists order_in / order_out.
156 Function slice(
const std::string& name,
const std::vector<casadi_int>& order_in,
157 const std::vector<casadi_int>& order_out,
const Dict& opts)
const override;
/// C code generation: declarations (must be implemented by DerivedType).
162 void codegen_declarations(CodeGenerator& g)
const override = 0;
/// C code generation: function body (must be implemented by DerivedType).
167 void codegen_body(CodeGenerator& g)
const override = 0;
/// Export the function as source code. Only "matlab" is accepted.
172 void export_code(
const std::string& lang,
173 std::ostream &stream,
const Dict& options)
const override;
/// Language-specific body of export_code (implemented by DerivedType).
178 virtual void export_code_body(
const std::string& lang,
179 std::ostream &stream,
const Dict& options)
const = 0;
/// C code generation is always available.
184 bool has_codegen()
const override {
return true;}
/// Whether arg matches the function's input expressions in_.
189 virtual bool isInput(
const std::vector<MatType>& arg)
const;
/// Decide whether derivative calls should expand the expression graph in
/// place (true) or delegate to FunctionInternal (false).
192 virtual bool should_inline(
bool always_inline,
bool never_inline)
const = 0;
/// Symbolic forward-mode derivative of a call with arguments arg and results res.
197 void call_forward(
const std::vector<MatType>& arg,
198 const std::vector<MatType>& res,
199 const std::vector<std::vector<MatType> >& fseed,
200 std::vector<std::vector<MatType> >& fsens,
201 bool always_inline,
bool never_inline)
const override;
/// Symbolic reverse-mode derivative of a call with arguments arg and results res.
206 void call_reverse(
const std::vector<MatType>& arg,
207 const std::vector<MatType>& res,
208 const std::vector<std::vector<MatType> >& aseed,
209 std::vector<std::vector<MatType> >& asens,
210 bool always_inline,
bool never_inline)
const override;
/// Number of inputs, i.e. the number of stored input expressions.
216 size_t get_n_in()
override {
return in_.size(); }
/// Number of outputs, i.e. the number of stored output expressions.
217 size_t get_n_out()
override {
return out_.size(); }
/// Sparsity of input i, taken from its expression (bounds-checked with at()).
224 Sparsity get_sparsity_in(casadi_int i)
override {
return in_.at(i).sparsity();}
/// Sparsity of output i, taken from its expression (bounds-checked with at()).
225 Sparsity get_sparsity_out(casadi_int i)
override {
return out_.at(i).sparsity();}
/// Deserializing constructor.
231 explicit XFunction(DeserializingStream& s);
/// Serialize base-class state followed by the XFunction part (in_).
235 void serialize_body(SerializingStream &s)
const override;
/// Serialize out_ separately from serialize_body.
244 void delayed_serialize_members(SerializingStream &s)
const;
/// Restore out_ separately from the deserializing constructor.
245 void delayed_deserialize_members(DeserializingStream &s);
/// Input expressions.
253 std::vector<MatType> in_;
/// Output expressions.
258 std::vector<MatType> out_;
// Constructor: stores the input and output expressions in in_ / out_.
// When input or output names are supplied, their count must equal the number
// of expressions; otherwise a "Mismatching number of ..." error is raised.
263 template<
typename DerivedType,
typename MatType,
typename NodeType>
264 XFunction<DerivedType, MatType, NodeType>::
265 XFunction(
const std::string& name,
266 const std::vector<MatType>& ex_in,
267 const std::vector<MatType>& ex_out,
268 const std::vector<std::string>& name_in,
269 const std::vector<std::string>& name_out)
270 : FunctionInternal(name), in_(ex_in), out_(ex_out) {
272 if (!name_in.empty()) {
273 casadi_assert(ex_in.size()==name_in.size(),
274 "Mismatching number of input names");
// NOTE(review): the name_in_ assignment for this branch is not visible in
// this listing; presumably it mirrors the name_out_ branch below.
278 if (!name_out.empty()) {
279 casadi_assert(ex_out.size()==name_out.size(),
280 "Mismatching number of output names");
281 name_out_ = name_out;
// Deserializing constructor. The base class reads its own state first.
// This constructor then checks the "XFunction" version tag (1) and unpacks
// the input expressions in_. out_ is restored by delayed_deserialize_members.
285 template<
typename DerivedType,
typename MatType,
typename NodeType>
286 XFunction<DerivedType, MatType, NodeType>::
287 XFunction(DeserializingStream& s) : FunctionInternal(s) {
288 s.version(
"XFunction", 1);
289 s.unpack(
"XFunction::in", in_);
// Restore the output expressions out_ from key "XFunction::out".
// NOTE(review): this is kept out of the deserializing constructor,
// presumably because out_ can only be rebuilt once other state is available.
// Confirm against the caller.
293 template<
typename DerivedType,
typename MatType,
typename NodeType>
294 void XFunction<DerivedType, MatType, NodeType>::
295 delayed_deserialize_members(DeserializingStream& s) {
296 s.unpack(
"XFunction::out", out_);
// Write the output expressions out_ under key "XFunction::out".
// Counterpart of delayed_deserialize_members.
299 template<
typename DerivedType,
typename MatType,
typename NodeType>
300 void XFunction<DerivedType, MatType, NodeType>::
301 delayed_serialize_members(SerializingStream& s)
const {
302 s.pack(
"XFunction::out", out_);
// Serialize: base-class state, then the "XFunction" version tag (1), then
// in_ under key "XFunction::in". This mirrors the order read by the
// deserializing constructor.
305 template<
typename DerivedType,
typename MatType,
typename NodeType>
306 void XFunction<DerivedType, MatType, NodeType>::
307 serialize_body(SerializingStream& s)
const {
308 FunctionInternal::serialize_body(s);
309 s.version(
"XFunction", 1);
310 s.pack(
"XFunction::in", in_);
// Initialize the function after construction.
//  - Runs FunctionInternal::init and reads the "allow_duplicate_io_names" option.
//  - Every input with nonzeros must satisfy is_valid_input(); otherwise an
//    error naming the offending argument is raised.
//  - Inputs must be independent (no has_duplicates()); otherwise an error
//    listing all input expressions is raised.
//  - Unless allow_duplicate_io_names is set, input/output names must be
//    unique across both lists.
314 template<
typename DerivedType,
typename MatType,
typename NodeType>
315 void XFunction<DerivedType, MatType, NodeType>::init(
const Dict& opts) {
317 FunctionInternal::init(opts);
319 bool allow_duplicate_io_names =
false;
321 for (
auto&& op : opts) {
322 if (op.first==
"allow_duplicate_io_names") {
323 allow_duplicate_io_names = op.second;
327 if (verbose_) casadi_message(name_ +
"::init");
// Inputs with nonzeros must be purely symbolic.
329 for (casadi_int i=0; i<n_in_; ++i) {
330 if (in_.at(i).nnz()>0 && !in_.at(i).is_valid_input()) {
331 casadi_error(
"For " + this->name_ +
": Xfunction input arguments must be purely symbolic."
332 "\nArgument " + str(i) +
"(" + name_in_[i] +
") is not symbolic.");
// Detect symbols that appear in more than one input.
336 bool has_duplicates =
false;
337 for (
auto&& i : in_) {
338 if (i.has_duplicates()) {
339 has_duplicates =
true;
// reset_input is called on every input after the scan.
// NOTE(review): presumably this clears markers set by has_duplicates(); confirm.
344 for (
auto&& i : in_) i.reset_input();
346 if (has_duplicates) {
348 s <<
"The input expressions are not independent:\n";
349 for (casadi_int iind=0; iind<in_.size(); ++iind) {
350 s << iind <<
": " << in_[iind] <<
"\n";
352 casadi_error(s.str());
355 if (!allow_duplicate_io_names) {
// Cheap pre-check: hash all IO names and sort the hashes. The branch
// that decides whether the exact string check below runs is not visible here.
357 std::hash<std::string> hasher;
358 std::vector<size_t> iohash;
359 iohash.reserve(name_in_.size() + name_out_.size());
360 for (
const std::string& s : name_in_) iohash.push_back(hasher(s));
361 for (
const std::string& s : name_out_) iohash.push_back(hasher(s));
362 std::sort(iohash.begin(), iohash.end());
365 for (
size_t h : iohash) {
// Exact check: sort all names so duplicates become adjacent, then compare
// each name with its predecessor.
368 std::vector<std::string> io_names;
369 io_names.reserve(iohash.size());
370 for (
const std::string& s : name_in_) io_names.push_back(s);
371 for (
const std::string& s : name_out_) io_names.push_back(s);
372 std::sort(io_names.begin(), io_names.end());
// Bind by const reference: iterating by value copied every name.
375 for (const std::string& h : io_names) {
376 if (h == prev) casadi_error(
"Duplicate IO name: " + h +
". "
377 "To ignore this error, set 'allow_duplicate_io_names' option.");
// Depth-first traversal helper driven by the explicit stack s.
// Each node's temp field acts as a cursor over its dependencies. While
// temp >= 0, the next unvisited dependency (index temp, post-incremented) is
// pushed onto the stack. The code that pops finished nodes and appends them
// to `nodes` is not visible in this listing.
386 template<
typename DerivedType,
typename MatType,
typename NodeType>
387 void XFunction<DerivedType, MatType, NodeType>::sort_depth_first(
388 std::stack<NodeType*>& s, std::vector<NodeType*>& nodes) {
391 NodeType* t = s.top();
393 if (t && t->temp>=0) {
395 casadi_int next_dep = t->temp++;
397 if (next_dep < t->n_dep()) {
// Dependencies are stored as shared handles; downcast the raw pointer to NodeType.
399 s.push(
static_cast<NodeType*
>(t->dep(next_dep).get()));
// Compute the Jacobian as n_in_ * n_out_ blocks (block (i,j) at index
// i * n_in_ + j).
// Options: "compact", "symmetric", "allow_forward", "allow_reverse" and
// "verbose" are accepted. Any other key raises "No such Jacobian option".
// The general case handles one input and one output only. It colors the
// Jacobian sparsity (get_partition) and evaluates the colored directions in
// batches of at most max_num_dir_ forward/adjoint directions per sweep.
// Any std::exception is rethrown via CASADI_THROW_ERROR("jac", ...).
415 template<
typename DerivedType,
typename MatType,
typename NodeType>
416 std::vector<MatType> XFunction<DerivedType, MatType, NodeType>
417 ::jac(
const Dict& opts)
const {
420 bool compact =
false;
421 bool symmetric =
false;
422 bool allow_forward =
true;
423 bool allow_reverse =
true;
424 for (
auto&& op : opts) {
// NOTE(review): the statement that sets `compact` is not visible in this listing.
425 if (op.first==
"compact") {
427 }
else if (op.first==
"symmetric") {
428 symmetric = op.second;
429 }
else if (op.first==
"allow_forward") {
430 allow_forward = op.second;
431 }
else if (op.first==
"allow_reverse") {
432 allow_reverse = op.second;
433 }
else if (op.first==
"verbose") {
436 casadi_error(
"No such Jacobian option: " + std::string(op.first));
441 std::vector<MatType> ret(n_in_ * n_out_);
// Trivial case: no input or no output nonzeros, so every block is structurally
// zero. Blocks are nnz-sized or numel-sized; the selecting condition is not
// visible (presumably `compact`).
444 if (nnz_in() == 0 || nnz_out() == 0) {
445 for (casadi_int i = 0; i < n_out_; ++i) {
446 for (casadi_int j = 0; j < n_in_; ++j) {
448 ret[i * n_in_ + j] = MatType(nnz_out(i), nnz_in(j));
450 ret[i * n_in_ + j] = MatType(numel_out(i), numel_in(j));
458 casadi_int iind = 0, oind = 0;
459 casadi_assert(n_in_ == 1,
"Not implemented");
460 casadi_assert(n_out_ == 1,
"Not implemented");
// Preallocate the (transposed) Jacobian with the structural sparsity pattern.
// All work below fills ret[0] in this transposed layout.
463 ret.at(0) = MatType::zeros(jac_sparsity(0, 0,
false, symmetric).T());
464 if (verbose_) casadi_message(
"Allocated return value");
467 if (ret.at(0).nnz()==0) {
468 ret.at(0) = ret.at(0).T();
// Graph coloring: D1 holds the forward seed pattern, D2 the adjoint seed
// pattern. Their column counts give the number of directions.
474 get_partition(iind, oind, D1, D2,
true, symmetric, allow_forward, allow_reverse);
475 if (verbose_) casadi_message(
"Graph coloring completed");
478 casadi_int nfdir = D1.is_null() ? 0 : D1.size2();
479 casadi_int nadir = D2.is_null() ? 0 : D2.size2();
// Maximum directions per sweep.
// NOTE(review): the sweep counts below divide by these, so max_num_dir_ is
// assumed to be > 0.
482 casadi_int max_nfdir = max_num_dir_;
483 casadi_int max_nadir = max_num_dir_;
486 casadi_int offset_nfdir = 0, offset_nadir = 0;
489 std::vector<MatType> res(out_);
492 std::vector<std::vector<MatType> > fseed, aseed, fsens, asens;
// Compressed-column access to the Jacobian sparsity (transposed, like ret[0]).
495 Sparsity jsp = jac_sparsity(0, 0,
true, symmetric).T();
496 const casadi_int* jsp_colind = jsp.colind();
497 const casadi_int* jsp_row = jsp.row();
// Row/column of every input and output nonzero, used to place seed entries.
500 std::vector<casadi_int> input_col = sparsity_in_.at(iind).get_col();
501 const casadi_int* input_row = sparsity_in_.at(iind).row();
504 std::vector<casadi_int> output_col = sparsity_out_.at(oind).get_col();
505 const casadi_int* output_row = sparsity_out_.at(oind).row();
508 if (verbose_) casadi_message(
"jac transposes and mapping");
509 std::vector<casadi_int> mapping;
// jsp_trans is the transpose of jsp. mapping records, per nonzero of the
// transpose, the index of the corresponding nonzero in jsp.
512 jsp_trans = jsp.transpose(mapping);
// Scratch buffers reused across sweeps.
516 std::vector<casadi_int> nzmap, nzmap2;
519 std::vector<casadi_int> adds, adds2;
522 std::vector<casadi_int> tmp;
525 casadi_int progress = -10;
// Number of sweeps = ceil(directions / max directions per sweep), taking the
// larger of the forward and adjoint counts.
528 casadi_int nsweep_fwd = nfdir/max_nfdir;
529 if (nfdir%max_nfdir>0) nsweep_fwd++;
530 casadi_int nsweep_adj = nadir/max_nadir;
531 if (nadir%max_nadir>0) nsweep_adj++;
532 casadi_int nsweep = std::max(nsweep_fwd, nsweep_adj);
534 casadi_message(str(nsweep) +
" sweeps needed for " + str(nfdir) +
" forward and "
535 + str(nadir) +
" reverse directions");
539 std::vector<casadi_int> seed_col, seed_row;
542 for (casadi_int s=0; s<nsweep; ++s) {
// Progress report each time the percentage crosses a new multiple of 10.
545 casadi_int progress_new = (s*100)/nsweep;
547 if (progress_new / 10 > progress / 10) {
548 progress = progress_new;
549 casadi_message(str(progress) +
" %");
// Directions handled in this sweep.
554 casadi_int nfdir_batch = std::min(nfdir - offset_nfdir, max_nfdir);
555 casadi_int nadir_batch = std::min(nadir - offset_nadir, max_nadir);
// Forward seeds: for direction d, place ones at the input nonzeros listed in
// column (offset_nfdir + d) of D1.
558 fseed.resize(nfdir_batch);
559 for (casadi_int d=0; d<nfdir_batch; ++d) {
565 for (casadi_int el = D1.colind(offset_nfdir+d); el<D1.colind(offset_nfdir+d+1); ++el) {
568 casadi_int c = D1.row(el);
571 seed_col.push_back(input_col[c]);
572 seed_row.push_back(input_row[c]);
576 fseed[d].resize(n_in_);
577 for (casadi_int ind=0; ind<fseed[d].size(); ++ind) {
578 casadi_int nrow = size1_in(ind), ncol = size2_in(ind);
580 fseed[d][ind] = MatType::ones(Sparsity::triplet(nrow, ncol, seed_row, seed_col));
582 fseed[d][ind] = MatType(nrow, ncol);
// Adjoint seeds: same construction from the columns of D2, at output nonzeros.
588 aseed.resize(nadir_batch);
589 for (casadi_int d=0; d<nadir_batch; ++d) {
595 for (casadi_int el = D2.colind(offset_nadir+d); el<D2.colind(offset_nadir+d+1); ++el) {
598 casadi_int c = D2.row(el);
601 seed_col.push_back(output_col[c]);
602 seed_row.push_back(output_row[c]);
606 aseed[d].resize(n_out_);
607 for (casadi_int ind=0; ind<aseed[d].size(); ++ind) {
608 casadi_int nrow = size1_out(ind), ncol = size2_out(ind);
610 aseed[d][ind] = MatType::ones(Sparsity::triplet(nrow, ncol, seed_row, seed_col));
612 aseed[d][ind] = MatType(nrow, ncol);
// Zero-initialized sensitivity containers shaped like the outputs (forward)
// and the inputs (adjoint).
618 fsens.resize(nfdir_batch);
619 for (casadi_int d=0; d<nfdir_batch; ++d) {
621 fsens[d].resize(n_out_);
622 for (casadi_int oind=0; oind<fsens[d].size(); ++oind) {
623 fsens[d][oind] = MatType::zeros(sparsity_out_.at(oind));
628 asens.resize(nadir_batch);
629 for (casadi_int d=0; d<nadir_batch; ++d) {
631 asens[d].resize(n_in_);
632 for (casadi_int ind=0; ind<asens[d].size(); ++ind) {
633 asens[d][ind] = MatType::zeros(sparsity_in_.at(ind));
// A sweep is either purely forward or purely adjoint (asserted).
// Dispatch to the derived class's AD implementation.
638 if (!fseed.empty()) {
639 casadi_assert_dev(aseed.empty());
640 if (verbose_) casadi_message(
"Calling 'ad_forward'");
641 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
642 if (verbose_) casadi_message(
"Back from 'ad_forward'");
643 }
else if (!aseed.empty()) {
644 casadi_assert_dev(fseed.empty());
645 if (verbose_) casadi_message(
"Calling 'ad_reverse'");
646 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
647 if (verbose_) casadi_message(
"Back from 'ad_reverse'");
// Scatter forward sensitivities into the nonzeros of ret[0].
651 for (casadi_int d=0; d<nfdir_batch; ++d) {
653 if (fsens[d][oind].nnz()==0) {
// tmp counts, per Jacobian row, how many seeded columns of this direction touch it.
660 tmp.resize(nnz_out(oind));
661 std::fill(tmp.begin(), tmp.end(), 0);
664 for (casadi_int el = D1.colind(offset_nfdir+d); el<D1.colind(offset_nfdir+d+1); ++el) {
667 casadi_int c = D1.row(el);
670 for (casadi_int el_jsp=jsp_colind[c]; el_jsp<jsp_colind[c+1]; ++el_jsp) {
671 tmp[jsp_row[el_jsp]]++;
// nzmap / nzmap2: map pattern positions to nonzero indices of fsens[d][oind]
// (a negative entry means the position is not present).
677 sparsity_out_.at(oind).find(nzmap);
678 fsens[d][oind].sparsity().get_nz(nzmap);
681 sparsity_in_.at(iind).find(nzmap2);
682 fsens[d][oind].sparsity().get_nz(nzmap2);
// adds / adds2: per sensitivity nonzero, the Jacobian nonzero it is written to
// (-1 means unassigned).
686 adds.resize(fsens[d][oind].nnz());
687 std::fill(adds.begin(), adds.end(), -1);
689 adds2.resize(adds.size());
690 std::fill(adds2.begin(), adds2.end(), -1);
694 for (casadi_int el = D1.colind(offset_nfdir+d); el<D1.colind(offset_nfdir+d+1); ++el) {
697 casadi_int c = D1.row(el);
704 for (casadi_int el_out = jsp_trans.colind(c); el_out<jsp_trans.colind(c+1); ++el_out) {
707 casadi_int r_out = jsp_trans.row(el_out);
710 casadi_int f_out = nzmap[r_out];
711 if (f_out<0)
continue;
714 casadi_int elJ = mapping[el_out];
718 adds[f_out] = el_out;
// Compact the index lists and copy the matching sensitivity nonzeros into ret[0].
729 tmp.resize(adds.size());
731 for (casadi_int i=0; i<adds.size(); ++i) {
741 ret.at(0).nz(adds) = fsens[d][oind].nz(tmp);
745 tmp.resize(adds2.size());
747 for (casadi_int i=0; i<adds2.size(); ++i) {
749 adds2[sz] = adds2[i];
757 ret.at(0).nz(adds2) = fsens[d][oind].nz(tmp);
// Scatter adjoint sensitivities. Each Jacobian nonzero elJ in a seeded column
// r of jsp takes the matching input-sensitivity nonzero.
762 for (casadi_int d=0; d<nadir_batch; ++d) {
764 if (asens[d][iind].nnz()==0) {
769 sparsity_in_.at(iind).find(nzmap);
770 asens[d][iind].sparsity().get_nz(nzmap);
773 for (casadi_int el = D2.colind(offset_nadir+d); el<D2.colind(offset_nadir+d+1); ++el) {
776 casadi_int r = D2.row(el);
779 for (casadi_int elJ = jsp.colind(r); elJ<jsp.colind(r+1); ++elJ) {
782 casadi_int inz = jsp.row(elJ);
785 casadi_int anz = nzmap[inz];
789 ret.at(0).nz(elJ) = asens[d][iind].nz(anz);
795 offset_nfdir += nfdir_batch;
796 offset_nadir += nadir_batch;
// Undo the transposed working layout.
800 for (MatType& Jb : ret) Jb = Jb.T();
803 }
catch (std::exception& e) {
804 CASADI_THROW_ERROR(
"jac", e.what());
// Build a Function that evaluates nfwd forward directional derivatives.
// Inputs:  [nominal inputs in_, one placeholder symbol per output (named
//           inames[n_in_+i], built with Sparsity(size) — presumably structurally
//           zero), forward seeds horzcat'ed over the nfwd directions per input].
// Outputs: forward sensitivities horzcat'ed per output. Outputs with
//          is_diff_out_[i] == false get an all-zero size1 x (size2*nfwd) matrix.
// Defaults for is_diff_in / is_diff_out are filled in only when absent from opts.
808 template<
typename DerivedType,
typename MatType,
typename NodeType>
809 Function XFunction<DerivedType, MatType, NodeType>
810 ::get_forward(casadi_int nfwd,
const std::string& name,
811 const std::vector<std::string>& inames,
812 const std::vector<std::string>& onames,
813 const Dict& opts)
const {
// Symbolic seeds for every direction, then derived-class forward AD.
816 std::vector<std::vector<MatType> > fseed = fwd_seed<MatType>(nfwd), fsens;
819 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
820 casadi_assert_dev(fsens.size()==fseed.size());
823 std::vector<MatType> ret_in(inames.size());
824 std::copy(in_.begin(), in_.end(), ret_in.begin());
825 for (casadi_int i=0; i<n_out_; ++i) {
826 ret_in.at(n_in_+i) = MatType::sym(inames[n_in_+i], Sparsity(out_.at(i).size()));
828 std::vector<MatType> v(nfwd);
829 for (casadi_int i=0; i<n_in_; ++i) {
830 for (casadi_int d=0; d<nfwd; ++d) v[d] = fseed[d][i];
831 ret_in.at(n_in_ + n_out_ + i) = horzcat(v);
835 std::vector<MatType> ret_out(onames.size());
836 for (casadi_int i=0; i<n_out_; ++i) {
837 if (is_diff_out_[i]) {
839 for (casadi_int d=0; d<nfwd; ++d) v[d] = fsens[d][i];
840 ret_out.at(i) = ensure_stacked(horzcat(v), sparsity_out(i), nfwd);
843 ret_out.at(i) = MatType(size1_out(i), size2_out(i) * nfwd);
// Differentiability flags follow the input layout: (inputs, outputs, seeds).
848 if (opts.find(
"is_diff_in")==opts.end())
849 options[
"is_diff_in"] = join(is_diff_in_, is_diff_out_, is_diff_in_);
850 if (opts.find(
"is_diff_out")==opts.end())
851 options[
"is_diff_out"] = is_diff_out_;
852 options[
"allow_duplicate_io_names"] =
true;
854 return Function(name, ret_in, ret_out, inames, onames, options);
855 }
catch (std::exception& e) {
856 CASADI_THROW_ERROR(
"get_forward", e.what());
// Build a Function that evaluates nadj adjoint (reverse) directional derivatives.
// Inputs:  [nominal inputs in_, one placeholder symbol per output, adjoint
//           seeds horzcat'ed over the nadj directions per output].
// Outputs: adjoint sensitivities horzcat'ed per input. Inputs with
//          is_diff_in_[i] == false get an all-zero size1 x (size2*nadj) matrix.
// Defaults for is_diff_in / is_diff_out are filled in only when absent from opts.
860 template<
typename DerivedType,
typename MatType,
typename NodeType>
861 Function XFunction<DerivedType, MatType, NodeType>
862 ::get_reverse(casadi_int nadj,
const std::string& name,
863 const std::vector<std::string>& inames,
864 const std::vector<std::string>& onames,
865 const Dict& opts)
const {
// Symbolic adjoint seeds shaped like the outputs, then derived-class reverse AD.
868 std::vector<std::vector<MatType> > aseed = symbolicAdjSeed(nadj, out_), asens;
871 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
874 std::vector<MatType> ret_in(inames.size());
875 std::copy(in_.begin(), in_.end(), ret_in.begin());
876 for (casadi_int i=0; i<n_out_; ++i) {
877 ret_in.at(n_in_ + i) = MatType::sym(inames[n_in_+i], Sparsity(out_.at(i).size()));
879 std::vector<MatType> v(nadj);
880 for (casadi_int i=0; i<n_out_; ++i) {
881 for (casadi_int d=0; d<nadj; ++d) v[d] = aseed[d][i];
882 ret_in.at(n_in_ + n_out_ + i) = horzcat(v);
886 std::vector<MatType> ret_out(onames.size());
887 for (casadi_int i=0; i<n_in_; ++i) {
888 if (is_diff_in_[i]) {
890 for (casadi_int d=0; d<nadj; ++d) v[d] = asens[d][i];
891 ret_out.at(i) = ensure_stacked(horzcat(v), sparsity_in(i), nadj);
894 ret_out.at(i) = MatType(size1_in(i), size2_in(i) * nadj);
// Differentiability flags follow the input layout: (inputs, outputs, seeds per output).
899 if (opts.find(
"is_diff_in")==opts.end())
900 options[
"is_diff_in"] = join(is_diff_in_, is_diff_out_, is_diff_out_);
901 if (opts.find(
"is_diff_out")==opts.end())
902 options[
"is_diff_out"] = is_diff_in_;
904 options[
"allow_duplicate_io_names"] =
true;
906 return Function(name, ret_in, ret_out, inames, onames, options);
907 }
catch (std::exception& e) {
908 CASADI_THROW_ERROR(
"get_reverse", e.what());
// Build a Function returning every Jacobian block d(out_i)/d(in_j), ordered
// output-major (for each output i, for each input j).
// Approach: flatten all inputs and outputs into single vectors (veccat),
// compute one Jacobian with jac(), then split it by numel offsets. Blocks
// involving a non-differentiable input or output are replaced by all-zero
// matrices of the same size.
912 template<
typename DerivedType,
typename MatType,
typename NodeType>
913 Function XFunction<DerivedType, MatType, NodeType>
914 ::get_jacobian(
const std::string& name,
915 const std::vector<std::string>& inames,
916 const std::vector<std::string>& onames,
917 const Dict& opts)
const {
919 Dict tmp_options = generate_options(
"tmp");
920 tmp_options[
"allow_free"] =
true;
921 tmp_options[
"allow_duplicate_io_names"] =
true;
// Temporary single-input/single-output function over the flattened expressions.
923 Function tmp(
"flattened_" + name, {veccat(in_)}, {veccat(out_)}, tmp_options);
926 MatType J = tmp.get<DerivedType>()->jac(
Dict()).at(0);
// Row offsets follow the outputs and column offsets follow the inputs.
929 std::vector<casadi_int> r_offset = {0}, c_offset = {0};
930 for (
auto& e : out_) r_offset.push_back(r_offset.back() + e.numel());
931 for (
auto& e : in_) c_offset.push_back(c_offset.back() + e.numel());
932 auto Jblocks = MatType::blocksplit(J, r_offset, c_offset);
935 std::vector<MatType> ret_out;
936 ret_out.reserve(onames.size());
937 for (casadi_int i = 0; i < n_out_; ++i) {
938 for (casadi_int j = 0; j < n_in_; ++j) {
939 MatType b = Jblocks.at(i).at(j);
940 if (!is_diff_out_.at(i) || !is_diff_in_.at(j)) {
941 b = MatType(b.size());
943 ret_out.push_back(b);
// Inputs: nominal inputs followed by one placeholder symbol per output.
948 std::vector<MatType> ret_in(inames.size());
949 std::copy(in_.begin(), in_.end(), ret_in.begin());
950 for (casadi_int i=0; i<n_out_; ++i) {
951 ret_in.at(n_in_+i) = MatType::sym(inames[n_in_+i], Sparsity(out_.at(i).size()));
955 options[
"allow_free"] =
true;
956 options[
"allow_duplicate_io_names"] =
true;
959 return Function(name, ret_in, ret_out, inames, onames, options);
960 }
catch (std::exception& e) {
961 CASADI_THROW_ERROR(
"get_jacobian", e.what());
// Build a Function from a subset/reordering of this function's inputs and
// outputs. Entry k of order_in (order_out) selects in_[k] / name_in_[k]
// (out_[k] / name_out_[k]). Out-of-range indices throw via std::vector::at.
965 template<
typename DerivedType,
typename MatType,
typename NodeType>
966 Function XFunction<DerivedType, MatType, NodeType>
967 ::slice(
const std::string& name,
const std::vector<casadi_int>& order_in,
968 const std::vector<casadi_int>& order_out,
const Dict& opts)
const {
970 std::vector<MatType> ret_in, ret_out;
971 std::vector<std::string> ret_in_name, ret_out_name;
974 for (casadi_int k : order_in) {
975 ret_in.push_back(in_.at(k));
976 ret_in_name.push_back(name_in_.at(k));
980 for (casadi_int k : order_out) {
981 ret_out.push_back(out_.at(k));
982 ret_out_name.push_back(name_out_.at(k));
986 return Function(name, ret_in, ret_out,
987 ret_in_name, ret_out_name, opts);
// Export the function as MATLAB source code written to `stream`.
// Requirements: the function has no free variables and lang == "matlab"
// (both asserted).
// Emitted code:
//   - a MATLAB function `name_` taking varargin and returning varargout;
//   - one cell array argout_<i> per output, sized by that output's nonzeros,
//     filled by the derived class's export_code_body;
//   - dense outputs reshaped to size1 x size2; sparse outputs rebuilt with
//     sparse(sp_i, sp_j, ...) from the exported sparsity pattern;
//   - helper functions nonzeros_gen and if_else_zero_gen.
// Fix: in if_else_zero_gen(c,e), the plain-MATLAB branch used to emit
// "y = x;", but x is undefined inside that function, so the generated code
// failed at run time. It now emits "y = e;", matching the
// if_else(c, e, 0) branch.
990 template<
typename DerivedType,
typename MatType,
typename NodeType>
991 void XFunction<DerivedType, MatType, NodeType>
992 ::export_code(
const std::string& lang, std::ostream &stream,
const Dict& options)
const {
994 casadi_assert(!has_free(),
"export_code needs a Function without free variables");
996 casadi_assert(lang==
"matlab",
"Only matlab language supported for now.");
999 stream <<
"function [varargout] = " << name_ <<
"(varargin)" << std::endl;
// One cell per output nonzero; the body fills these.
1002 for (casadi_int i=0;i<n_out_;++i) {
1003 stream <<
" argout_" << i <<
" = cell(" << nnz_out(i) <<
",1);" << std::endl;
1007 opts[
"indent_level"] = 1;
1008 export_code_body(lang, stream, opts);
// Assemble each MATLAB output from its collected nonzeros.
1011 for (casadi_int i=0;i<n_out_;++i) {
1012 const Sparsity& out = sparsity_out_.at(i);
1013 if (out.is_dense()) {
1015 stream <<
" varargout{" << i+1 <<
"} = reshape(vertcat(argout_" << i <<
"{:}), ";
1016 stream << out.size1() <<
", " << out.size2() <<
");" << std::endl;
// Sparse output: export the pattern as sp_i/sp_j/sp_m/sp_n and build with sparse().
1020 opts[
"name"] =
"sp";
1021 opts[
"indent_level"] = 1;
1022 opts[
"as_matrix"] =
false;
1023 out.export_code(
"matlab", stream, opts);
1024 stream <<
" varargout{" << i+1 <<
"} = ";
1025 stream <<
"sparse(sp_i, sp_j, vertcat(argout_" << i <<
"{:}), sp_m, sp_n);" << std::endl;
1030 stream <<
"end" << std::endl;
// Helper: nonzeros of x, handling CasADi types and YALMIP sdpvar specially.
1031 stream <<
"function y=nonzeros_gen(x)" << std::endl;
1032 stream <<
" if isa(x,'casadi.SX') || isa(x,'casadi.MX') || isa(x,'casadi.DM')" << std::endl;
1033 stream <<
" y = x{:};" << std::endl;
1034 stream <<
" elseif isa(x,'sdpvar')" << std::endl;
1035 stream <<
" b = getbase(x);" << std::endl;
1036 stream <<
" f = find(sum(b~=0,2));" << std::endl;
1037 stream <<
" y = sdpvar(length(f),1,[],getvariables(x),b(f,:));" << std::endl;
1038 stream <<
" else" << std::endl;
1039 stream <<
" y = nonzeros(x);" << std::endl;
1040 stream <<
" end" << std::endl;
1041 stream <<
"end" << std::endl;
// Helper: y = c ? e : 0. Uses CasADi if_else for symbolic/DM operands and a
// plain MATLAB branch otherwise.
1042 stream <<
"function y=if_else_zero_gen(c,e)" << std::endl;
1043 stream <<
" if isa(c+e,'casadi.SX') || isa(c+e,'casadi.MX') "
1044 "|| isa(c+e,'casadi.DM')" << std::endl;
1045 stream <<
" y = if_else(c, e, 0);" << std::endl;
1046 stream <<
" else" << std::endl;
1047 stream <<
" if c" << std::endl;
1048 stream <<
" y = e;" << std::endl;
1049 stream <<
" else" << std::endl;
1050 stream <<
" y = 0;" << std::endl;
1051 stream <<
" end" << std::endl;
1052 stream <<
" end" << std::endl;
1053 stream <<
"end" << std::endl;
// Check whether arg matches the function's input expressions in_. Each pair
// is compared with is_equal, limited to a comparison depth of 2.
// NOTE(review): in_[i] is indexed for every i < arg.size(). A size check is
// not visible in this listing; confirm that arg.size() <= in_.size() is
// guaranteed by the callers.
1058 template<
typename DerivedType,
typename MatType,
typename NodeType>
1059 bool XFunction<DerivedType, MatType, NodeType>
1060 ::isInput(
const std::vector<MatType>& arg)
const {
1063 const casadi_int checking_depth = 2;
1064 for (casadi_int i=0; i<arg.size(); ++i) {
1065 if (!is_equal(arg[i], in_[i], checking_depth)) {
// Forward-mode derivative of a symbolic call with arguments arg and results res.
// always_inline and never_inline must not both be set (asserted). If
// should_inline() declines, the work is delegated to
// FunctionInternal::call_forward. Otherwise the graph is differentiated
// directly, in one of two ways:
//   - via this object's ad_forward;
//   - via ad_forward on a temporary Function built from arg/res.
// NOTE(review): the condition choosing between the two paths is not visible;
// presumably it is isInput(arg).
1072 template<
typename DerivedType,
typename MatType,
typename NodeType>
1073 void XFunction<DerivedType, MatType, NodeType>::
1074 call_forward(
const std::vector<MatType>& arg,
1075 const std::vector<MatType>& res,
1076 const std::vector<std::vector<MatType> >& fseed,
1077 std::vector<std::vector<MatType> >& fsens,
1078 bool always_inline,
bool never_inline)
const {
1079 casadi_assert(!(always_inline && never_inline),
"Inconsistent options");
1080 if (!should_inline(always_inline, never_inline)) {
1082 return FunctionInternal::call_forward(arg, res, fseed, fsens,
1083 always_inline, never_inline);
1087 if (fseed.empty()) {
1095 static_cast<const DerivedType*
>(
this)->ad_forward(fseed, fsens);
1098 Function f(
"tmp_call_forward", arg, res);
1099 static_cast<DerivedType *
>(f.get())->ad_forward(fseed, fsens);
// Reverse-mode derivative of a symbolic call with arguments arg and results res.
// always_inline and never_inline must not both be set (asserted). If
// should_inline() declines, the work is delegated to
// FunctionInternal::call_reverse. Otherwise the graph is differentiated
// directly, in one of two ways:
//   - via this object's ad_reverse;
//   - via ad_reverse on a temporary Function built from arg/res.
// NOTE(review): the condition choosing between the two paths is not visible;
// presumably it is isInput(arg).
1103 template<
typename DerivedType,
typename MatType,
typename NodeType>
1104 void XFunction<DerivedType, MatType, NodeType>::
1105 call_reverse(
const std::vector<MatType>& arg,
1106 const std::vector<MatType>& res,
1107 const std::vector<std::vector<MatType> >& aseed,
1108 std::vector<std::vector<MatType> >& asens,
1109 bool always_inline,
bool never_inline)
const {
1110 casadi_assert(!(always_inline && never_inline),
"Inconsistent options");
1111 if (!should_inline(always_inline, never_inline)) {
1113 return FunctionInternal::call_reverse(arg, res, aseed, asens,
1114 always_inline, never_inline);
1118 if (aseed.empty()) {
1126 static_cast<const DerivedType*
>(
this)->ad_reverse(aseed, asens);
1129 Function f(
"tmp_call_reverse", arg, res);
1130 static_cast<DerivedType *
>(f.get())->ad_reverse(aseed, asens);
// Build a new Function whose inputs/outputs are the named expressions s_in /
// s_out, resolved and computed by the Factory helper `f`.
// Steps:
//   1. Set helper/final options to this function's "clone" options, then
//      overlay the user's opts.
//   2. Register all existing inputs/outputs with f, request the names in
//      s_in / s_out, and run f.calculate. A failed request raises
//      "Cannot process factory input/output".
//   3. Create the Function with free variables allowed. If it still has free
//      variables, substitute zeros of matching sparsity for them and rebuild
//      with final_options.
// NOTE(review): the declarations of f_options, final_options and the Factory
// object `f` (and any use of `aux`) are not visible in this listing.
1134 template<
typename DerivedType,
typename MatType,
typename NodeType>
1135 Function XFunction<DerivedType, MatType, NodeType>::
1136 factory(
const std::string& name,
1137 const std::vector<std::string>& s_in,
1138 const std::vector<std::string>& s_out,
1139 const Function::AuxOut& aux,
1140 const Dict& opts)
const {
1142 Dict g_ops = generate_options(
"clone");
1144 f_options[
"helper_options"] = g_ops;
1145 f_options[
"final_options"] = g_ops;
1146 update_dict(f_options, opts,
true);
1149 extract_from_dict_inplace(f_options,
"final_options", final_options);
1150 final_options[
"allow_duplicate_io_names"] =
true;
1154 for (casadi_int i=0; i<in_.size(); ++i) f.add_input(name_in_[i], in_[i], is_diff_in_[i]);
1155 for (casadi_int i=0; i<out_.size(); ++i) f.add_output(name_out_[i], out_[i], is_diff_out_[i]);
1159 std::vector<std::string> ret_iname;
1160 for (
const std::string& s : s_in) {
1162 ret_iname.push_back(f.request_input(s));
1163 }
catch (CasadiException& ex) {
1164 casadi_error(
"Cannot process factory input \"" + s +
"\":" + ex.what());
1169 std::vector<std::string> ret_oname;
1170 for (
const std::string& s : s_out) {
1172 ret_oname.push_back(f.request_output(s));
1173 }
catch (CasadiException& ex) {
1174 casadi_error(
"Cannot process factory output \"" + s +
"\":" + ex.what());
1179 f.calculate(f_options);
1182 std::vector<MatType> ret_in;
1183 ret_in.reserve(s_in.size());
1184 for (
const std::string& s : s_in) ret_in.push_back(f.get_input(s));
1187 std::vector<MatType> ret_out;
1188 ret_out.reserve(s_out.size());
1189 for (
const std::string& s : s_out) ret_out.push_back(f.get_output(s));
// First attempt tolerates free variables so they can be detected below.
1192 Dict final_options_allow_free = final_options;
1193 final_options_allow_free[
"allow_free"] =
true;
1194 final_options_allow_free[
"allow_duplicate_io_names"] =
true;
1195 Function ret(name, ret_in, ret_out, ret_iname, ret_oname, final_options_allow_free);
1196 if (ret.has_free()) {
// Replace each free variable by zeros of the same sparsity and rebuild.
1199 std::vector<MatType> free_in = MatType::get_free(ret);
1200 std::vector<MatType> free_sub = free_in;
1201 for (
auto&& e : free_sub) e = MatType::zeros(e.sparsity());
1202 ret_out = substitute(ret_out, free_in, free_sub);
1203 ret = Function(name, ret_in, ret_out, ret_iname, ret_oname, final_options);
// Dependency pattern of the named outputs s_out on the named input s_in.
// Names are looked up in name_in_ / name_out_; an unknown name fails
// casadi_assert_dev. The selected outputs are concatenated with veccat and
// passed, together with order and tr, to MatType::which_depends.
1208 template<
typename DerivedType,
typename MatType,
typename NodeType>
1209 std::vector<bool> XFunction<DerivedType, MatType, NodeType>::
1210 which_depends(
const std::string& s_in,
const std::vector<std::string>& s_out,
1211 casadi_int order,
bool tr)
const {
1214 auto it = std::find(name_in_.begin(), name_in_.end(), s_in);
1215 casadi_assert_dev(it!=name_in_.end());
1216 MatType arg = in_.at(it-name_in_.begin());
1219 std::vector<MatType> res;
1220 for (
auto&& s : s_out) {
1221 it = std::find(name_out_.begin(), name_out_.end(), s);
1222 casadi_assert_dev(it!=name_out_.end());
1223 res.push_back(out_.at(it-name_out_.begin()));
1227 return MatType::which_depends(veccat(res), arg, order, tr);
// Structural sparsity of d(expr)/d(var).
// Wraps expr and var in a temporary Function with options max_io = 0 and
// allow_free = true (so free symbols in expr are tolerated) and returns its
// jac_sparsity for input 0 / output 0, with the third argument set to false.
1230 template<
typename MatType>
1231 Sparsity _jacobian_sparsity(
const MatType &expr,
const MatType &var) {
1232 Dict opts{{
"max_io", 0}, {
"allow_free",
true}};
1233 Function f = Function(
"tmp_jacobian_sparsity", {var}, {expr}, opts);
1234 return f.jac_sparsity(0, 0,
false);
// For each nonzero of var (or, when tr is true, of expr), report whether the
// expression depends on it, to first or second order.
// If expr or var is empty, it returns an all-false vector.
// NOTE(review): that vector is sized by numel(), while the normal path sizes
// the result by nonzero counts (f.nnz_in / expr.nnz). Confirm the mismatch for
// sparse operands is intended.
// order must be 1 or 2 (asserted). For order 2, the expression is replaced by
// the directional derivative jtimes(e, var, v) with a fresh symbol v.
// Dependencies are found by bitwise sparsity propagation with all-ones seeds.
// NOTE(review): the branch selecting forward (f(...)) vs reverse (f.rev(...))
// propagation is not visible; presumably it is controlled by tr.
1237 template<
typename MatType>
1238 std::vector<bool> _which_depends(
const MatType &expr,
const MatType &var,
1239 casadi_int order,
bool tr) {
1241 if (expr.is_empty() || var.is_empty()) {
1242 return std::vector<bool>(tr? expr.numel() : var.numel(),
false);
1248 casadi_assert(order==1 || order==2,
1249 "which_depends: order argument must be 1 or 2, got " + str(order) +
" instead.");
1251 MatType v = MatType::sym(
"v", var.sparsity());
1252 for (casadi_int i=1;i<order;++i) {
1253 e = jtimes(e, var, v);
1256 Dict opts{{
"max_io", 0}, {
"allow_free",
true}};
1257 Function f = Function(
"tmp_which_depends", {var}, {e}, opts);
// Seeds are all ones. Sensitivities start at zero and become nonzero where a
// dependency exists.
1259 std::vector<bvec_t> seed(tr? f.nnz_in(0) : f.nnz_out(0), 1);
1260 std::vector<bvec_t> sens(tr? f.nnz_out(0) : f.nnz_in(0), 0);
1263 f({get_ptr(seed)}, {get_ptr(sens)});
1265 f.rev({get_ptr(sens)}, {get_ptr(seed)});
1267 std::vector<bool> ret(sens.size());
1268 std::copy(sens.begin(), sens.end(), ret.begin());
// With tr, the result is indexed by nonzeros of e. If e's sparsity differs
// from expr's (e.g. after jtimes), project the flags back onto expr's pattern.
1271 if (tr && e.sparsity()!=expr.sparsity()) {
1274 std::vector<casadi_int> source(sens.size());
1275 std::copy(ret.begin(), ret.end(), source.begin());
1276 std::vector<casadi_int> target(expr.nnz());
1279 std::vector<casadi_int> scratch(expr.size1());
1280 casadi_project(get_ptr(source), e.sparsity(), get_ptr(target), expr.sparsity(),
1284 ret.resize(expr.nnz());
1285 std::copy(target.begin(), target.end(), ret.begin());
1293 #undef CASADI_THROW_ERROR
std::map< std::string, std::vector< std::string > > AuxOut
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.