sx_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SX_FUNCTION_HPP
27 #define CASADI_SX_FUNCTION_HPP
28 
29 #include "x_function.hpp"
30 
32 
33 namespace casadi {
/** \brief One instruction of the SX algorithm.

    `op` identifies the operation. `i0` is the location that
    SXFunction::instruction_output reports for ordinary instructions.
    The trailing payload is a union: either a floating-point constant `d`,
    or the integer pair `i1`/`i2`, which occupies the same bytes as `d`.

    NOTE: an anonymous struct nested in a union is a compiler extension
    (GCC/Clang/MSVC), not ISO C++; field order and layout are kept as-is.
*/
struct ScalarAtomic {
  int op;
  int i0;
  union {
    double d;
    struct {
      int i1;
      int i2;
    };
  };
};
45 
53 class CASADI_EXPORT SXFunction :
54  public XFunction<SXFunction, Matrix<SXElem>, SXNode>{
55  public:
59  SXFunction(const std::string& name,
60  const std::vector<Matrix<SXElem> >& inputv,
61  const std::vector<Matrix<SXElem> >& outputv,
62  const std::vector<std::string>& name_in,
63  const std::vector<std::string>& name_out);
64 
68  ~SXFunction() override;
69 
73  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
74 
78  int eval_sx(const SXElem** arg, SXElem** res,
79  casadi_int* iw, SXElem* w, void* mem,
80  bool always_inline, bool never_inline) const override;
81 
85  void eval_mx(const MXVector& arg, MXVector& res,
86  bool always_inline, bool never_inline) const override;
87 
89  bool should_inline(bool with_sx, bool always_inline, bool never_inline) const override;
90 
94  void ad_forward(const std::vector<std::vector<SX> >& fseed,
95  std::vector<std::vector<SX> >& fsens) const;
96 
100  void ad_reverse(const std::vector<std::vector<SX> >& aseed,
101  std::vector<std::vector<SX> >& asens) const;
102 
106  bool is_smooth() const;
107 
108  // print an element of an algorithm
109  std::string print(const ScalarAtomic& a) const;
110 
111  // Print the input arguments of an instruction
112  void print_arg(std::ostream &stream, casadi_int k, const ScalarAtomic& el,
113  const double* w) const;
114 
115  // Print the input arguments of an instruction
116  void print_arg(CodeGenerator& g, casadi_int k, const ScalarAtomic& el) const;
117 
118  // Print the output arguments of an instruction
119  void print_res(std::ostream &stream, casadi_int k, const ScalarAtomic& el,
120  const double* w) const;
121 
122  // Print the output arguments of an instruction
123  void print_res(CodeGenerator& g, casadi_int k, const ScalarAtomic& el) const;
124 
128  void disp_more(std::ostream& stream) const override;
129 
133  std::string class_name() const override {return "SXFunction";}
134 
138  bool is_a(const std::string& type, bool recursive) const override;
139 
141 
144  const SX sx_in(casadi_int ind) const override;
145  const std::vector<SX> sx_in() const override;
147 
149  std::vector<SX> free_sx() const override {
150  std::vector<SX> ret(free_vars_.size());
151  std::copy(free_vars_.begin(), free_vars_.end(), ret.begin());
152  return ret;
153  }
154 
158  bool has_free() const override { return !free_vars_.empty();}
159 
163  std::vector<std::string> get_free() const override {
164  std::vector<std::string> ret;
165  for (auto&& e : free_vars_) ret.push_back(e.name());
166  return ret;
167  }
168 
172  SX hess(casadi_int iind=0, casadi_int oind=0);
173 
177  casadi_int n_instructions() const override { return algorithm_.size();}
178 
182  casadi_int instruction_id(casadi_int k) const override { return algorithm_.at(k).op;}
183 
187  std::vector<casadi_int> instruction_input(casadi_int k) const override {
188  auto e = algorithm_.at(k);
189  if (e.op==OP_CALL) {
190  const ExtendedAlgEl& m = call_.el[e.i1];
191  return vector_static_cast<casadi_int>(m.dep);
192  } else if (casadi_math<double>::ndeps(e.op)==2 || e.op==OP_INPUT) {
193  return {e.i1, e.i2};
194  } else if (casadi_math<double>::ndeps(e.op)==1) {
195  return {e.i1};
196  } else {
197  return {};
198  }
199  }
200 
204  double instruction_constant(casadi_int k) const override {
205  return algorithm_.at(k).d;
206  }
207 
211  std::vector<casadi_int> instruction_output(casadi_int k) const override {
212  auto e = algorithm_.at(k);
213  if (e.op==OP_CALL) {
214  const ExtendedAlgEl& m = call_.el[e.i1];
215  return vector_static_cast<casadi_int>(m.res);
216  } else if (e.op==OP_OUTPUT) {
217  return {e.i0, e.i2};
218  } else {
219  return {e.i0};
220  }
221  }
222 
226  casadi_int n_nodes() const override { return algorithm_.size() - nnz_out();}
227 
235  typedef ScalarAtomic AlgEl;
236 
240  template<typename T>
241  struct TapeEl {
242  T d[2];
243  };
244 
248  std::vector<AlgEl> algorithm_;
249 
250  // Work vector size
251  size_t worksize_;
252 
254  std::vector<SXElem> free_vars_;
255 
257  std::vector<SXElem> operations_;
258 
260  std::vector<SXElem> constants_;
261 
263  std::vector<double> default_in_;
264 
266  std::vector<bool> copy_elision_;
267 
269  bool print_instructions_;
270 
274  void serialize_body(SerializingStream &s) const override;
275 
276  // call node information that won't fit into AlgEl
277  struct ExtendedAlgEl {
278  ExtendedAlgEl(const Function& fun);
279  Function f;
280  // Work vector indices of the arguments (cfr AlgEl::arg)
281  std::vector<int> dep;
282  // Work vector indices of the results (cfr AlgEl::res)
283  std::vector<int> res;
284 
285  std::vector<int> copy_elision_arg;
286  std::vector<int> copy_elision_offset;
287 
288  // Following fields are redundant but will increase eval speed
289  casadi_int n_dep;
290  casadi_int n_res;
291  casadi_int f_n_in;
292  casadi_int f_n_out;
293  std::vector<int> f_nnz_in;
294  std::vector<int> f_nnz_out;
295  };
296 
298  struct CallInfo {
299  // Maximum memory requirements across all call nodes
300  size_t sz_arg = 0, sz_res = 0, sz_iw = 0, sz_w = 0;
301  size_t sz_w_arg = 0, sz_w_res = 0;
302  std::vector<ExtendedAlgEl> el;
303  } call_;
304 
308  static ProtoFunction* deserialize(DeserializingStream& s);
309 
310  static std::vector<SX> order(const std::vector<SX>& expr);
311 
313 
316  static const Options options_;
317  const Options& get_options() const override { return options_;}
319 
321  Dict generate_options(const std::string& target="clone") const override;
322 
326  void init(const Dict& opts) override;
327 
331  void init_copy_elision();
332 
336  size_t codegen_sz_w(const CodeGenerator& g) const override;
337 
341  void codegen_declarations(CodeGenerator& g) const override;
342 
346  void codegen_body(CodeGenerator& g) const override;
347 
351  int sp_forward(const bvec_t** arg, bvec_t** res,
352  casadi_int* iw, bvec_t* w, void* mem) const override;
353 
357  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
358 
362  SX instructions_sx() const override;
363 
364  // Get all embedded functions, recursively
365  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
366 
370  double get_default_in(casadi_int ind) const override { return default_in_.at(ind);}
371 
375  void export_code_body(const std::string& lang,
376  std::ostream &stream, const Dict& options) const override;
377 
379  bool just_in_time_opencl_;
380 
382  bool just_in_time_sparsity_;
383 
385  bool live_variables_;
386 
387 protected:
388  template<typename T>
389  void call_fwd(const AlgEl& e, const T** arg, T** res, casadi_int* iw, T* w) const;
390 
391  template<typename T>
392  void call_rev(const AlgEl& e, T** arg, T** res, casadi_int* iw, T* w) const;
393 
394  template<typename T, typename CT>
395  void call_setup(const ExtendedAlgEl& m,
396  CT*** call_arg, T*** call_res, casadi_int** call_iw, T** call_w, T** nz_in, T** nz_out) const;
397 
401  explicit SXFunction(DeserializingStream& s);
402 };
403 
404 
405 } // namespace casadi
406 
408 #endif // CASADI_SX_FUNCTION_HPP
The casadi namespace.
Definition: archiver.hpp:32
std::vector< MX > MXVector
Definition: mx.hpp:1006
Matrix< SXElem > SX
Definition: sx_fwd.hpp:32
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.