mx_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_MX_FUNCTION_HPP
27 #define CASADI_MX_FUNCTION_HPP
28 
29 #include <iostream>
30 #include <map>
31 #include <set>
32 #include <vector>
33 
34 #include "x_function.hpp"
35 #include "mx_node.hpp"
36 
38 
39 namespace casadi {
40 
41 #ifndef SWIG
45  struct MXAlgEl {
47  casadi_int op;
48 
50  MX data;
51 
53  std::vector<casadi_int> arg;
54 
56  std::vector<casadi_int> res;
57  };
58 #endif // SWIG
59 
66  class CASADI_EXPORT MXFunction :
67  public XFunction<MXFunction, MX, MXNode>{
68  public:
72  typedef MXAlgEl AlgEl;
73 
77  std::vector<AlgEl> algorithm_;
78 
82  std::vector<casadi_int> workloc_;
83 
84 
85  std::vector<bool> workstate_;
86 
88  std::vector<MX> free_vars_;
89 
91  std::vector<double> default_in_;
92 
94  bool live_variables_;
95 
97  bool print_instructions_;
98 
102  MXFunction(const std::string& name,
103  const std::vector<MX>& input, const std::vector<MX>& output,
104  const std::vector<std::string>& name_in,
105  const std::vector<std::string>& name_out);
106 
110  ~MXFunction() override;
111 
115  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
116 
120  void disp_more(std::ostream& stream) const override;
121 
125  std::string class_name() const override {return "MXFunction";}
126 
130  bool is_a(const std::string& type, bool recursive) const override;
131 
133 
136  static const Options options_;
137  const Options& get_options() const override { return options_;}
139 
141  Dict get_stats(void* mem) const override;
142 
144  Dict generate_options(const std::string& target="clone") const override;
145 
149  void init(const Dict& opts) override;
150 
154  void codegen_declarations(CodeGenerator& g) const override;
155 
159  void codegen_incref(CodeGenerator& g) const override;
160 
164  void codegen_decref(CodeGenerator& g) const override;
165 
169  void codegen_body(CodeGenerator& g) const override;
170 
174  void serialize_body(SerializingStream &s) const override;
175 
179  static ProtoFunction* deserialize(DeserializingStream& s);
180 
181  static std::vector<MX> order(const std::vector<MX>& expr);
182 
188  void generate_lifted(Function& vdef_fcn, Function& vinit_fcn) const override;
189 
191  bool should_inline(bool with_sx, bool always_inline, bool never_inline) const override;
192 
196  int eval_sx(const SXElem** arg, SXElem** res,
197  casadi_int* iw, SXElem* w, void* mem,
198  bool always_inline, bool never_inline) const override;
199 
203  void eval_mx(const MXVector& arg, MXVector& res,
204  bool always_inline, bool never_inline) const override;
205 
209  void ad_forward(const std::vector<std::vector<MX> >& fwdSeed,
210  std::vector<std::vector<MX> >& fwdSens) const;
211 
215  void ad_reverse(const std::vector<std::vector<MX> >& adjSeed,
216  std::vector<std::vector<MX> >& adjSens) const;
217 
219  std::vector<MX> symbolic_output(const std::vector<MX>& arg) const override;
220 
224  int sp_forward(const bvec_t** arg, bvec_t** res,
225  casadi_int* iw, bvec_t* w, void* mem) const override;
226 
230  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
231 
232  // print an element of an algorithm
233  std::string print(const AlgEl& el) const;
234 
235  // Print the input arguments of an instruction
236  void print_arg(std::ostream &stream, casadi_int k, const AlgEl& el, const double** arg) const;
237 
238  // Print the output arguments of an instruction
239  void print_res(std::ostream &stream, casadi_int k, const AlgEl& el, double** res) const;
240 
242 
245  const MX mx_in(casadi_int ind) const override;
246  const std::vector<MX> mx_in() const override;
248 
250  std::vector<MX> free_mx() const override {return free_vars_;}
251 
255  bool has_free() const override { return !free_vars_.empty();}
256 
260  std::vector<std::string> get_free() const override {
261  std::vector<std::string> ret;
262  for (auto&& e : free_vars_) ret.push_back(e.name());
263  return ret;
264  }
265 
269  casadi_int n_nodes() const override { return algorithm_.size();}
270 
271  casadi_int n_instructions() const override { return algorithm_.size();}
272 
276  MX instruction_MX(casadi_int k) const override;
277 
281  casadi_int instruction_id(casadi_int k) const override { return algorithm_.at(k).op;}
282 
286  double get_default_in(casadi_int ind) const override { return default_in_.at(ind);}
287 
291  std::vector<casadi_int> instruction_input(casadi_int k) const override;
292 
296  std::vector<casadi_int> instruction_output(casadi_int k) const override;
297 
301  void export_code_body(const std::string& lang,
302  std::ostream &stream, const Dict& options) const override;
303 
305  void substitute_inplace(std::vector<MX>& vdef, std::vector<MX>& ex) const;
306 
307  // Get all embedded functions, recursively
308  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
309 
313  void change_option(const std::string& option_name, const GenericType& option_value) override;
314 
315  protected:
319  explicit MXFunction(DeserializingStream& s);
320  };
321 
322 } // namespace casadi
324 
325 #endif // CASADI_MX_FUNCTION_HPP
The casadi namespace.
Definition: archiver.hpp:32
std::vector< MX > MXVector
Definition: mx.hpp:1006
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.