26 #ifndef CASADI_SX_FUNCTION_HPP
27 #define CASADI_SX_FUNCTION_HPP
29 #include "x_function.hpp"
42 struct {
int i1, i2; };
53 class CASADI_EXPORT SXFunction :
54 public XFunction<SXFunction, Matrix<SXElem>, SXNode>{
59 SXFunction(
const std::string& name,
60 const std::vector<Matrix<SXElem> >& inputv,
61 const std::vector<Matrix<SXElem> >& outputv,
62 const std::vector<std::string>& name_in,
63 const std::vector<std::string>& name_out);
68 ~SXFunction()
override;
73 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
78 int eval_sx(
const SXElem** arg, SXElem** res,
79 casadi_int* iw, SXElem* w,
void* mem,
80 bool always_inline,
bool never_inline)
const override;
86 bool always_inline,
bool never_inline)
const override;
89 bool should_inline(
bool with_sx,
bool always_inline,
bool never_inline)
const override;
94 void ad_forward(
const std::vector<std::vector<SX> >& fseed,
95 std::vector<std::vector<SX> >& fsens)
const;
100 void ad_reverse(
const std::vector<std::vector<SX> >& aseed,
101 std::vector<std::vector<SX> >& asens)
const;
106 bool is_smooth()
const;
109 std::string print(
const ScalarAtomic& a)
const;
112 void print_arg(std::ostream &stream, casadi_int k,
const ScalarAtomic& el,
113 const double* w)
const;
116 void print_arg(CodeGenerator& g, casadi_int k,
const ScalarAtomic& el)
const;
119 void print_res(std::ostream &stream, casadi_int k,
const ScalarAtomic& el,
120 const double* w)
const;
123 void print_res(CodeGenerator& g, casadi_int k,
const ScalarAtomic& el)
const;
128 void disp_more(std::ostream& stream)
const override;
133 std::string class_name()
const override {
return "SXFunction";}
138 bool is_a(
const std::string& type,
bool recursive)
const override;
144 const SX sx_in(casadi_int ind)
const override;
145 const std::vector<SX> sx_in()
const override;
149 std::vector<SX> free_sx()
const override {
150 std::vector<SX> ret(free_vars_.size());
151 std::copy(free_vars_.begin(), free_vars_.end(), ret.begin());
158 bool has_free()
const override {
return !free_vars_.empty();}
163 std::vector<std::string> get_free()
const override {
164 std::vector<std::string> ret;
165 for (
auto&& e : free_vars_) ret.push_back(e.name());
172 SX hess(casadi_int iind=0, casadi_int oind=0);
177 casadi_int n_instructions()
const override {
return algorithm_.size();}
182 casadi_int instruction_id(casadi_int k)
const override {
return algorithm_.at(k).op;}
187 std::vector<casadi_int> instruction_input(casadi_int k)
const override {
188 auto e = algorithm_.at(k);
190 const ExtendedAlgEl& m = call_.el[e.i1];
191 return vector_static_cast<casadi_int>(m.dep);
192 }
else if (casadi_math<double>::ndeps(e.op)==2 || e.op==OP_INPUT) {
194 }
else if (casadi_math<double>::ndeps(e.op)==1) {
204 double instruction_constant(casadi_int k)
const override {
205 return algorithm_.at(k).d;
211 std::vector<casadi_int> instruction_output(casadi_int k)
const override {
212 auto e = algorithm_.at(k);
214 const ExtendedAlgEl& m = call_.el[e.i1];
215 return vector_static_cast<casadi_int>(m.res);
216 }
else if (e.op==OP_OUTPUT) {
226 casadi_int n_nodes()
const override {
return algorithm_.size() - nnz_out();}
235 typedef ScalarAtomic AlgEl;
248 std::vector<AlgEl> algorithm_;
254 std::vector<SXElem> free_vars_;
257 std::vector<SXElem> operations_;
260 std::vector<SXElem> constants_;
263 std::vector<double> default_in_;
266 std::vector<bool> copy_elision_;
269 bool print_instructions_;
274 void serialize_body(SerializingStream &s)
const override;
277 struct ExtendedAlgEl {
278 ExtendedAlgEl(
const Function& fun);
281 std::vector<int> dep;
283 std::vector<int> res;
285 std::vector<int> copy_elision_arg;
286 std::vector<int> copy_elision_offset;
293 std::vector<int> f_nnz_in;
294 std::vector<int> f_nnz_out;
300 size_t sz_arg = 0, sz_res = 0, sz_iw = 0, sz_w = 0;
301 size_t sz_w_arg = 0, sz_w_res = 0;
302 std::vector<ExtendedAlgEl> el;
308 static ProtoFunction* deserialize(DeserializingStream& s);
310 static std::vector<SX> order(
const std::vector<SX>& expr);
316 static const Options options_;
317 const Options& get_options()
const override {
return options_;}
321 Dict generate_options(
const std::string& target=
"clone")
const override;
326 void init(
const Dict& opts)
override;
331 void init_copy_elision();
336 size_t codegen_sz_w(
const CodeGenerator& g)
const override;
341 void codegen_declarations(CodeGenerator& g)
const override;
346 void codegen_body(CodeGenerator& g)
const override;
351 int sp_forward(
const bvec_t** arg, bvec_t** res,
352 casadi_int* iw, bvec_t* w,
void* mem)
const override;
357 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w,
void* mem)
const override;
362 SX instructions_sx()
const override;
365 void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth)
const override;
370 double get_default_in(casadi_int ind)
const override {
return default_in_.at(ind);}
375 void export_code_body(
const std::string& lang,
376 std::ostream &stream,
const Dict& options)
const override;
379 bool just_in_time_opencl_;
382 bool just_in_time_sparsity_;
385 bool live_variables_;
389 void call_fwd(
const AlgEl& e,
const T** arg, T** res, casadi_int* iw, T* w)
const;
392 void call_rev(
const AlgEl& e, T** arg, T** res, casadi_int* iw, T* w)
const;
394 template<
typename T,
typename CT>
395 void call_setup(
const ExtendedAlgEl& m,
396 CT*** call_arg, T*** call_res, casadi_int** call_iw, T** call_w, T** nz_in, T** nz_out)
const;
401 explicit SXFunction(DeserializingStream& s);
std::vector< MX > MXVector
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.