sx_function.hpp
/*
 * This file is part of CasADi.
 *
 * CasADi -- A symbolic framework for dynamic optimization.
 * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
 * KU Leuven. All rights reserved.
 * Copyright (C) 2011-2014 Greg Horn
 *
 * CasADi is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 3 of the License, or (at your option) any later version.
 *
 * CasADi is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with CasADi; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 *
 */
24 
25 
26 #ifndef CASADI_SX_FUNCTION_HPP
27 #define CASADI_SX_FUNCTION_HPP
28 
29 #include "x_function.hpp"
30 
32 
33 namespace casadi {
/** \brief  An atomic operation for the SXElem virtual machine
 *
 * One instruction of an SXFunction algorithm. How the fields are read depends
 * on \c op (see SXFunction::instruction_input / SXFunction::instruction_output
 * for the exact per-operation interpretation):
 *  - \c i0 is the output (result) index of the instruction;
 *  - \c i1 / \c i2 are the operand indices (for OP_CALL, \c i1 instead indexes
 *    the call-node table; OP_INPUT and OP_OUTPUT also use them specially);
 *  - \c d holds a floating-point constant. It shares storage with \c i1 / \c i2,
 *    so it is only meaningful for instructions that carry a constant.
 *
 * NOTE: the anonymous struct inside the union is a compiler extension
 * (accepted by GCC, Clang and MSVC), not standard C++. It is kept so that the
 * operand indices stay accessible directly as \c e.i1 / \c e.i2.
 */
struct ScalarAtomic {
  int op;  ///< Operator index (an OP_* code)
  int i0;  ///< Output (result) index of the instruction
  union {
    double d;  ///< Constant value; aliases i1/i2
    struct {
      int i1;  ///< First operand index
      int i2;  ///< Second operand index
    };
  };
};
45 
53 class CASADI_EXPORT SXFunction :
54  public XFunction<SXFunction, Matrix<SXElem>, SXNode>{
55  public:
59  SXFunction(const std::string& name,
60  const std::vector<Matrix<SXElem> >& inputv,
61  const std::vector<Matrix<SXElem> >& outputv,
62  const std::vector<std::string>& name_in,
63  const std::vector<std::string>& name_out);
64 
68  ~SXFunction() override;
69 
73  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
74 
78  int eval_sx(const SXElem** arg, SXElem** res,
79  casadi_int* iw, SXElem* w, void* mem,
80  bool always_inline, bool never_inline) const override;
81 
85  void eval_mx(const MXVector& arg, MXVector& res,
86  bool always_inline, bool never_inline) const override;
87 
89  bool should_inline(bool with_sx, bool always_inline, bool never_inline) const override;
90 
94  void ad_forward(const std::vector<std::vector<SX> >& fseed,
95  std::vector<std::vector<SX> >& fsens) const;
96 
100  void ad_reverse(const std::vector<std::vector<SX> >& aseed,
101  std::vector<std::vector<SX> >& asens) const;
102 
106  bool is_smooth() const;
107 
111  void disp_more(std::ostream& stream) const override;
112 
116  std::string class_name() const override {return "SXFunction";}
117 
121  bool is_a(const std::string& type, bool recursive) const override;
122 
124 
127  const SX sx_in(casadi_int ind) const override;
128  const std::vector<SX> sx_in() const override;
130 
132  std::vector<SX> free_sx() const override {
133  std::vector<SX> ret(free_vars_.size());
134  std::copy(free_vars_.begin(), free_vars_.end(), ret.begin());
135  return ret;
136  }
137 
141  bool has_free() const override { return !free_vars_.empty();}
142 
146  std::vector<std::string> get_free() const override {
147  std::vector<std::string> ret;
148  for (auto&& e : free_vars_) ret.push_back(e.name());
149  return ret;
150  }
151 
155  SX hess(casadi_int iind=0, casadi_int oind=0);
156 
160  casadi_int n_instructions() const override { return algorithm_.size();}
161 
165  casadi_int instruction_id(casadi_int k) const override { return algorithm_.at(k).op;}
166 
170  std::vector<casadi_int> instruction_input(casadi_int k) const override {
171  auto e = algorithm_.at(k);
172  if (e.op==OP_CALL) {
173  const ExtendedAlgEl& m = call_.el[e.i1];
174  return vector_static_cast<casadi_int>(m.dep);
175  } else if (casadi_math<double>::ndeps(e.op)==2 || e.op==OP_INPUT) {
176  return {e.i1, e.i2};
177  } else if (casadi_math<double>::ndeps(e.op)==1) {
178  return {e.i1};
179  } else {
180  return {};
181  }
182  }
183 
187  double instruction_constant(casadi_int k) const override {
188  return algorithm_.at(k).d;
189  }
190 
194  std::vector<casadi_int> instruction_output(casadi_int k) const override {
195  auto e = algorithm_.at(k);
196  if (e.op==OP_CALL) {
197  const ExtendedAlgEl& m = call_.el[e.i1];
198  return vector_static_cast<casadi_int>(m.res);
199  } else if (e.op==OP_OUTPUT) {
200  return {e.i0, e.i2};
201  } else {
202  return {e.i0};
203  }
204  }
205 
209  casadi_int n_nodes() const override { return algorithm_.size() - nnz_out();}
210 
219 
223  template<typename T>
224  struct TapeEl {
225  T d[2];
226  };
227 
231  std::vector<AlgEl> algorithm_;
232 
233  // Work vector size
234  size_t worksize_;
235 
237  std::vector<SXElem> free_vars_;
238 
240  std::vector<SXElem> operations_;
241 
243  std::vector<SXElem> constants_;
244 
246  std::vector<double> default_in_;
247 
249  std::vector<bool> copy_elision_;
250 
254  void serialize_body(SerializingStream &s) const override;
255 
256  // call node information that won't fit into AlgEl
257  struct ExtendedAlgEl {
258  ExtendedAlgEl(const Function& fun);
260  // Work vector indices of the arguments (cfr AlgEl::arg)
261  std::vector<int> dep;
262  // Work vector indices of the results (cfr AlgEl::res)
263  std::vector<int> res;
264 
265  std::vector<int> copy_elision_arg;
266  std::vector<int> copy_elision_offset;
267 
268  // Following fields are redundant but will increase eval speed
269  casadi_int n_dep;
270  casadi_int n_res;
271  casadi_int f_n_in;
272  casadi_int f_n_out;
273  std::vector<int> f_nnz_in;
274  std::vector<int> f_nnz_out;
275  };
276 
278  struct CallInfo {
279  // Maximum memory requirements across all call nodes
280  size_t sz_arg = 0, sz_res = 0, sz_iw = 0, sz_w = 0;
281  size_t sz_w_arg = 0, sz_w_res = 0;
282  std::vector<ExtendedAlgEl> el;
283  } call_;
284 
288  static ProtoFunction* deserialize(DeserializingStream& s);
289 
290  static std::vector<SX> order(const std::vector<SX>& expr);
291 
293 
296  static const Options options_;
297  const Options& get_options() const override { return options_;}
299 
301  Dict generate_options(const std::string& target="clone") const override;
302 
306  void init(const Dict& opts) override;
307 
311  void init_copy_elision();
312 
316  size_t codegen_sz_w(const CodeGenerator& g) const override;
317 
321  void codegen_declarations(CodeGenerator& g) const override;
322 
326  void codegen_body(CodeGenerator& g) const override;
327 
331  int sp_forward(const bvec_t** arg, bvec_t** res,
332  casadi_int* iw, bvec_t* w, void* mem) const override;
333 
337  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
338 
342  SX instructions_sx() const override;
343 
344  // Get all embedded functions, recursively
345  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
346 
350  double get_default_in(casadi_int ind) const override { return default_in_.at(ind);}
351 
355  void export_code_body(const std::string& lang,
356  std::ostream &stream, const Dict& options) const override;
357 
360 
363 
366 
367 protected:
368  template<typename T>
369  void call_fwd(const AlgEl& e, const T** arg, T** res, casadi_int* iw, T* w) const;
370 
371  template<typename T>
372  void call_rev(const AlgEl& e, T** arg, T** res, casadi_int* iw, T* w) const;
373 
374  template<typename T, typename CT>
375  void call_setup(const ExtendedAlgEl& m,
376  CT*** call_arg, T*** call_res, casadi_int** call_iw, T** call_w, T** nz_in, T** nz_out) const;
377 
381  explicit SXFunction(DeserializingStream& s);
382 };
383 
384 
385 } // namespace casadi
386 
388 #endif // CASADI_SX_FUNCTION_HPP
Helper class for C code generation.
Helper class for Serialization.
Function object.
Definition: function.hpp:60
Sparse matrix class. SX and DM are specializations.
Definition: matrix_decl.hpp:99
Base class for FunctionInternal and LinsolInternal.
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Internal node class for SXFunction.
Definition: sx_function.hpp:54
std::vector< SXElem > operations_
The expressions corresponding to each binary operation.
std::string class_name() const override
Get type name.
SXFunction(const std::string &name, const std::vector< Matrix< SXElem > > &inputv, const std::vector< Matrix< SXElem > > &outputv, const std::vector< std::string > &name_in, const std::vector< std::string > &name_out)
Constructor.
casadi_int n_nodes() const override
Number of nodes in the algorithm.
casadi_int instruction_id(casadi_int k) const override
Get an atomic operation operator index.
std::vector< SX > free_sx() const override
Get free variables (SX)
SX hess(casadi_int iind=0, casadi_int oind=0)
Hessian (forward over adjoint) via source code transformation.
static const Options options_
Options.
std::vector< bool > copy_elision_
Copy elision per algel.
ScalarAtomic AlgEl
DATA MEMBERS.
std::vector< casadi_int > instruction_output(casadi_int k) const override
Get the (integer) output argument of an atomic operation.
bool has_free() const override
Does the function have free variables.
double get_default_in(casadi_int ind) const override
Get default input value.
std::vector< SXElem > constants_
The expressions corresponding to each constant.
const Options & get_options() const override
Options.
bool just_in_time_opencl_
With just-in-time compilation using OpenCL.
std::vector< std::string > get_free() const override
Get the names of the free variables.
std::vector< AlgEl > algorithm_
all binary nodes of the tree in the order of execution
std::vector< SXElem > free_vars_
Free variables.
std::vector< double > default_in_
Default input values.
std::vector< casadi_int > instruction_input(casadi_int k) const override
Get the (integer) input arguments of an atomic operation.
double instruction_constant(casadi_int k) const override
Get the floating point output argument of an atomic operation.
bool live_variables_
Live variables?
casadi_int n_instructions() const override
Get the number of atomic operations.
bool just_in_time_sparsity_
With just-in-time compilation for the sparsity propagation.
Helper class for Serialization.
Internal node class for the base class of SXFunction and MXFunction.
Definition: x_function.hpp:57
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
std::vector< casadi_int > find(const std::vector< T > &v)
find nonzeros
std::vector< MX > MXVector
Definition: mx.hpp:1006
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
@ OP_OUTPUT
Definition: calculus.hpp:82
@ OP_INPUT
Definition: calculus.hpp:82
@ OP_CALL
Definition: calculus.hpp:88
Options metadata for a class.
Definition: options.hpp:40
Metadata for call nodes.
std::vector< ExtendedAlgEl > el
std::vector< int > copy_elision_offset
std::vector< int > copy_elision_arg
An element of the tape.
An atomic operation for the SXElem virtual machine.
Definition: sx_function.hpp:37
int i0
Output (result) index of the atomic operation.
Definition: sx_function.hpp:39
Easy access to all the functions for a particular type.
Definition: calculus.hpp:1125