mx_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_MX_FUNCTION_HPP
27 #define CASADI_MX_FUNCTION_HPP
28 
29 #include <iostream>
30 #include <map>
31 #include <set>
32 #include <vector>
33 
34 #include "x_function.hpp"
35 #include "mx_node.hpp"
36 
38 
39 namespace casadi {
40 
41 #ifndef SWIG
45  struct MXAlgEl {
47  casadi_int op;
48 
51 
53  std::vector<casadi_int> arg;
54 
56  std::vector<casadi_int> res;
57  };
58 #endif // SWIG
59 
66  class CASADI_EXPORT MXFunction :
67  public XFunction<MXFunction, MX, MXNode>{
68  public:
72  typedef MXAlgEl AlgEl;
73 
77  std::vector<AlgEl> algorithm_;
78 
82  std::vector<casadi_int> workloc_;
83 
84 
85  std::vector<bool> workstate_;
86 
88  std::vector<MX> free_vars_;
89 
91  std::vector<double> default_in_;
92 
95 
98 
102  MXFunction(const std::string& name,
103  const std::vector<MX>& input, const std::vector<MX>& output,
104  const std::vector<std::string>& name_in,
105  const std::vector<std::string>& name_out);
106 
110  ~MXFunction() override;
111 
115  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
116 
120  void disp_more(std::ostream& stream) const override;
121 
125  std::string class_name() const override {return "MXFunction";}
126 
130  bool is_a(const std::string& type, bool recursive) const override;
131 
133 
136  static const Options options_;
137  const Options& get_options() const override { return options_;}
139 
141  Dict get_stats(void* mem) const override;
142 
144  Dict generate_options(const std::string& target="clone") const override;
145 
149  void init(const Dict& opts) override;
150 
154  void codegen_declarations(CodeGenerator& g) const override;
155 
159  void codegen_incref(CodeGenerator& g) const override;
160 
164  void codegen_decref(CodeGenerator& g) const override;
165 
169  void codegen_body(CodeGenerator& g) const override;
170 
174  void serialize_body(SerializingStream &s) const override;
175 
179  static ProtoFunction* deserialize(DeserializingStream& s);
180 
181  static std::vector<MX> order(const std::vector<MX>& expr);
182 
188  void generate_lifted(Function& vdef_fcn, Function& vinit_fcn) const override;
189 
191  bool should_inline(bool with_sx, bool always_inline, bool never_inline) const override;
192 
196  int eval_sx(const SXElem** arg, SXElem** res,
197  casadi_int* iw, SXElem* w, void* mem,
198  bool always_inline, bool never_inline) const override;
199 
203  void eval_mx(const MXVector& arg, MXVector& res,
204  bool always_inline, bool never_inline) const override;
205 
209  void ad_forward(const std::vector<std::vector<MX> >& fwdSeed,
210  std::vector<std::vector<MX> >& fwdSens) const;
211 
215  void ad_reverse(const std::vector<std::vector<MX> >& adjSeed,
216  std::vector<std::vector<MX> >& adjSens) const;
217 
219  std::vector<MX> symbolic_output(const std::vector<MX>& arg) const override;
220 
224  int sp_forward(const bvec_t** arg, bvec_t** res,
225  casadi_int* iw, bvec_t* w, void* mem) const override;
226 
230  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
231 
232  // print an element of an algorithm
233  std::string print(const AlgEl& el) const;
234 
235  // Print the input arguments of an instruction
236  void print_arg(std::ostream &stream, casadi_int k, const AlgEl& el, const double** arg) const;
237 
238  // Print the output arguments of an instruction
239  void print_res(std::ostream &stream, casadi_int k, const AlgEl& el, double** res) const;
240 
242 
245  const MX mx_in(casadi_int ind) const override;
246  const std::vector<MX> mx_in() const override;
248 
250  std::vector<MX> free_mx() const override {return free_vars_;}
251 
255  bool has_free() const override { return !free_vars_.empty();}
256 
260  std::vector<std::string> get_free() const override {
261  std::vector<std::string> ret;
262  for (auto&& e : free_vars_) ret.push_back(e.name());
263  return ret;
264  }
265 
269  casadi_int n_nodes() const override { return algorithm_.size();}
270 
271  casadi_int n_instructions() const override { return algorithm_.size();}
272 
276  MX instruction_MX(casadi_int k) const override;
277 
281  casadi_int instruction_id(casadi_int k) const override { return algorithm_.at(k).op;}
282 
286  double get_default_in(casadi_int ind) const override { return default_in_.at(ind);}
287 
291  std::vector<casadi_int> instruction_input(casadi_int k) const override;
292 
296  std::vector<casadi_int> instruction_output(casadi_int k) const override;
297 
301  void export_code_body(const std::string& lang,
302  std::ostream &stream, const Dict& options) const override;
303 
305  void substitute_inplace(std::vector<MX>& vdef, std::vector<MX>& ex) const;
306 
307  // Get all embedded functions, recursively
308  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
309 
313  void change_option(const std::string& option_name, const GenericType& option_value) override;
314 
315  protected:
319  explicit MXFunction(DeserializingStream& s);
320  };
321 
322 } // namespace casadi
324 
325 #endif // CASADI_MX_FUNCTION_HPP
Helper class for C code generation.
Helper class for Serialization.
Function object.
Definition: function.hpp:60
Generic data type, can hold different types such as bool, casadi_int, std::string etc.
Internal node class for MXFunction.
Definition: mx_function.hpp:67
static const Options options_
Options.
std::vector< casadi_int > workloc_
Offsets for elements in the w_ vector.
Definition: mx_function.hpp:82
bool live_variables_
Live variables?
Definition: mx_function.hpp:94
std::vector< double > default_in_
Default input values.
Definition: mx_function.hpp:91
casadi_int instruction_id(casadi_int k) const override
Get an atomic operation operator index.
double get_default_in(casadi_int ind) const override
Get default input value.
casadi_int n_nodes() const override
Number of nodes in the algorithm.
std::vector< bool > workstate_
Definition: mx_function.hpp:85
std::vector< std::string > get_free() const override
Get the names of the free variables.
bool has_free() const override
Does the function have free variables.
bool print_instructions_
Print instructions during evaluation.
Definition: mx_function.hpp:97
MXAlgEl AlgEl
An element of the algorithm, namely an MX node.
Definition: mx_function.hpp:72
casadi_int n_instructions() const override
Get the number of atomic operations.
std::vector< AlgEl > algorithm_
All the runtime elements in the order of evaluation.
Definition: mx_function.hpp:77
const Options & get_options() const override
Options.
std::vector< MX > free_mx() const override
Get free variables (MX)
std::vector< MX > free_vars_
Free variables.
Definition: mx_function.hpp:88
std::string class_name() const override
Get type name.
MX - Matrix expression.
Definition: mx.hpp:92
Base class for FunctionInternal and LinsolInternal.
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
Internal node class for the base class of SXFunction and MXFunction.
Definition: x_function.hpp:57
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
std::vector< casadi_int > find(const std::vector< T > &v)
find nonzeros
std::vector< MX > MXVector
Definition: mx.hpp:1006
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
An element of the algorithm, namely an MX node.
Definition: mx_function.hpp:45
MX data
Data associated with the operation.
Definition: mx_function.hpp:50
std::vector< casadi_int > arg
Work vector indices of the arguments.
Definition: mx_function.hpp:53
casadi_int op
Operator index.
Definition: mx_function.hpp:47
std::vector< casadi_int > res
Work vector indices of the results.
Definition: mx_function.hpp:56
Options metadata for a class.
Definition: options.hpp:40