26 #ifndef CASADI_SX_FUNCTION_HPP
27 #define CASADI_SX_FUNCTION_HPP
29 #include "x_function.hpp"
// NOTE(review): fragment of a larger definition whose surrounding lines are
// not visible here. This anonymous struct (two int fields i1, i2) presumably
// sits inside a union of the enclosing instruction type; instructions below
// read `.op`, `.d` and `.i1` from algorithm_ elements. Confirm against the
// full header.
42 struct {
int i1, i2; };
/** \brief Function whose body is a symbolic expression graph of SXElem scalars.
 *
 * Derives from XFunction<SXFunction, Matrix<SXElem>, SXNode>, passing itself
 * as the first template argument (CRTP). The expression graph is stored as a
 * flat list of instructions (`algorithm_`), which the instruction_*() accessors
 * expose.
 */
53 class CASADI_EXPORT SXFunction :
54 public XFunction<SXFunction, Matrix<SXElem>, SXNode>{
59 SXFunction(
const std::string& name,
60 const std::vector<Matrix<SXElem> >& inputv,
61 const std::vector<Matrix<SXElem> >& outputv,
62 const std::vector<std::string>& name_in,
63 const std::vector<std::string>& name_out);
68 ~SXFunction()
override;
73 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
78 int eval_sx(
const SXElem** arg, SXElem** res,
79 casadi_int* iw, SXElem* w,
void* mem,
80 bool always_inline,
bool never_inline)
const override;
// NOTE(review): the beginning of this declaration (name and leading
// parameters) is not visible in this view. The trailing always_inline /
// never_inline parameters match eval_sx(), so this is presumably another
// symbolic-evaluation override. Confirm against the full header.
86 bool always_inline,
bool never_inline)
const override;
89 bool should_inline(
bool with_sx,
bool always_inline,
bool never_inline)
const override;
94 void ad_forward(
const std::vector<std::vector<SX> >& fseed,
95 std::vector<std::vector<SX> >& fsens)
const;
100 void ad_reverse(
const std::vector<std::vector<SX> >& aseed,
101 std::vector<std::vector<SX> >& asens)
const;
106 bool is_smooth()
const;
111 void disp_more(std::ostream& stream)
const override;
116 std::string class_name()
const override {
return "SXFunction";}
121 bool is_a(
const std::string& type,
bool recursive)
const override;
127 const SX sx_in(casadi_int ind)
const override;
128 const std::vector<SX> sx_in()
const override;
/** \brief Get the free variables of the expression graph, one SX per element of free_vars_.
 *
 * NOTE(review): the end of this body (presumably `return ret;` and the
 * closing brace) is not visible in this view.
 */
132 std::vector<SX> free_sx()
const override {
133 std::vector<SX> ret(free_vars_.size());
// Convert each SXElem into its own SX entry, preserving free_vars_ order.
134 std::copy(free_vars_.begin(), free_vars_.end(), ret.begin());
141 bool has_free()
const override {
return !free_vars_.empty();}
/** \brief Get the names of the free variables, in free_vars_ order.
 *
 * NOTE(review): the end of this body (presumably `return ret;` and the
 * closing brace) is not visible in this view.
 */
146 std::vector<std::string> get_free()
const override {
147 std::vector<std::string> ret;
// One name per free variable.
148 for (
auto&& e : free_vars_) ret.push_back(e.name());
155 SX hess(casadi_int iind=0, casadi_int oind=0);
160 casadi_int n_instructions()
const override {
return algorithm_.size();}
165 casadi_int instruction_id(casadi_int k)
const override {
return algorithm_.at(k).op;}
/** \brief Get the indices of the inputs (dependencies) of instruction k.
 *
 * Uses algorithm_.at(k), so an out-of-range k throws std::out_of_range.
 * NOTE(review): several lines of this body are not visible in this view,
 * including the guard of the first branch and the later branch bodies. The
 * comments below describe only what is visible.
 */
170 std::vector<casadi_int> instruction_input(casadi_int k)
const override {
171 auto e = algorithm_.at(k);
// Presumably the call-instruction branch (guard not visible): e.i1 indexes
// call_.el, and the dependencies are that entry's `dep` list.
173 const ExtendedAlgEl& m = call_.el[e.i1];
174 return vector_static_cast<casadi_int>(m.dep);
175 }
// Instructions with two dependencies, or OP_INPUT (branch body not visible).
else if (casadi_math<double>::ndeps(e.op)==2 || e.op==OP_INPUT) {
177 }
// Instructions with one dependency (branch body not visible).
else if (casadi_math<double>::ndeps(e.op)==1) {
/** \brief Get the constant value (`d` field) stored in instruction k.
 *
 * Uses algorithm_.at(k), so an out-of-range k throws std::out_of_range.
 * NOTE(review): presumably only meaningful for constant instructions; confirm.
 * The closing brace of this body is not visible in this view.
 */
187 double instruction_constant(casadi_int k)
const override {
188 return algorithm_.at(k).d;
/** \brief Get the indices of the outputs written by instruction k.
 *
 * Uses algorithm_.at(k), so an out-of-range k throws std::out_of_range.
 * NOTE(review): several lines of this body are not visible in this view,
 * including the guard of the first branch and the later branch bodies.
 */
194 std::vector<casadi_int> instruction_output(casadi_int k)
const override {
195 auto e = algorithm_.at(k);
// Presumably the call-instruction branch (guard not visible): e.i1 indexes
// call_.el, and the outputs are that entry's `res` list.
197 const ExtendedAlgEl& m = call_.el[e.i1];
198 return vector_static_cast<casadi_int>(m.res);
199 }
// OP_OUTPUT instructions (branch body not visible).
else if (e.op==OP_OUTPUT) {
209 casadi_int n_nodes()
const override {
return algorithm_.size() - nnz_out();}
218 typedef ScalarAtomic AlgEl;
231 std::vector<AlgEl> algorithm_;
237 std::vector<SXElem> free_vars_;
240 std::vector<SXElem> operations_;
243 std::vector<SXElem> constants_;
246 std::vector<double> default_in_;
249 std::vector<bool> copy_elision_;
254 void serialize_body(SerializingStream &s)
const override;
/** \brief Extra data attached to an instruction that calls another Function.
 *
 * NOTE(review): some lines of this struct, including its closing brace, are
 * not visible in this view.
 */
257 struct ExtendedAlgEl {
/// Construct from the called function \p fun.
258 ExtendedAlgEl(
const Function& fun);
/// Dependency (input) indices; instruction_input() returns these converted to casadi_int.
261 std::vector<int> dep;
/// Result (output) indices; instruction_output() returns these converted to casadi_int.
263 std::vector<int> res;
/// NOTE(review): copy-elision bookkeeping for the call's arguments; confirm semantics.
265 std::vector<int> copy_elision_arg;
266 std::vector<int> copy_elision_offset;
/// NOTE(review): presumably the nonzero counts of the called function's inputs/outputs.
273 std::vector<int> f_nnz_in;
274 std::vector<int> f_nnz_out;
// NOTE(review): these members belong to an enclosing struct whose header is
// not visible here. Since `call_.el[...]` is indexed elsewhere in this class,
// this is presumably the type behind `call_`; confirm.
// Work-vector sizes, zero-initialized. NOTE(review): presumably the maxima
// required over all call instructions.
280 size_t sz_arg = 0, sz_res = 0, sz_iw = 0, sz_w = 0;
281 size_t sz_w_arg = 0, sz_w_res = 0;
// Per-call data; instruction_input()/instruction_output() index it with an
// instruction's i1 field.
282 std::vector<ExtendedAlgEl> el;
288 static ProtoFunction* deserialize(DeserializingStream& s);
290 static std::vector<SX> order(
const std::vector<SX>& expr);
296 static const Options options_;
297 const Options& get_options()
const override {
return options_;}
301 Dict generate_options(
const std::string& target=
"clone")
const override;
306 void init(
const Dict& opts)
override;
311 void init_copy_elision();
316 size_t codegen_sz_w(
const CodeGenerator& g)
const override;
321 void codegen_declarations(CodeGenerator& g)
const override;
326 void codegen_body(CodeGenerator& g)
const override;
331 int sp_forward(
const bvec_t** arg, bvec_t** res,
332 casadi_int* iw, bvec_t* w,
void* mem)
const override;
337 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w,
void* mem)
const override;
342 SX instructions_sx()
const override;
345 void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth)
const override;
350 double get_default_in(casadi_int ind)
const override {
return default_in_.at(ind);}
355 void export_code_body(
const std::string& lang,
356 std::ostream &stream,
const Dict& options)
const override;
359 bool just_in_time_opencl_;
362 bool just_in_time_sparsity_;
365 bool live_variables_;
/* Evaluate call instruction `e` (the "fwd" variant, per the name) on buffers
 * of scalar type T.
 * NOTE(review): `T` is not declared on any visible line. A
 * `template<typename T>` header presumably precedes this declaration on a
 * line not shown; confirm.
 */
369 void call_fwd(
const AlgEl& e,
const T** arg, T** res, casadi_int* iw, T* w)
const;
/* Evaluate call instruction `e` (the "rev" variant, per the name) on buffers
 * of scalar type T. Unlike call_fwd, `arg` is non-const here.
 * NOTE(review): `T` is not declared on any visible line. A
 * `template<typename T>` header presumably precedes this declaration on a
 * line not shown; confirm.
 */
372 void call_rev(
const AlgEl& e, T** arg, T** res, casadi_int* iw, T* w)
const;
374 template<
typename T,
typename CT>
375 void call_setup(
const ExtendedAlgEl& m,
376 CT*** call_arg, T*** call_res, casadi_int** call_iw, T** call_w, T** nz_in, T** nz_out)
const;
381 explicit SXFunction(DeserializingStream& s);
// Typedefs referenced above (residue from the generated documentation):
//   typedef std::vector<MX> MXVector;
//   typedef GenericType::Dict Dict;  // C++ equivalent of Python's dict or MATLAB's struct.