// mx_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_MX_FUNCTION_HPP
27 #define CASADI_MX_FUNCTION_HPP
28 
29 #include <iostream>
30 #include <map>
31 #include <set>
32 #include <vector>
33 
34 #include "x_function.hpp"
35 #include "mx_node.hpp"
36 
38 
39 namespace casadi {
40 
41 #ifndef SWIG
45  struct MXAlgEl {
47  casadi_int op;
48 
50  MX data;
51 
53  std::vector<casadi_int> arg;
54 
56  std::vector<casadi_int> res;
57  };
58 #endif // SWIG
59 
66  class CASADI_EXPORT MXFunction :
67  public XFunction<MXFunction, MX, MXNode>{
68  public:
72  typedef MXAlgEl AlgEl;
73 
77  std::vector<AlgEl> algorithm_;
78 
82  std::vector<casadi_int> workloc_;
83 
85  std::vector<MX> free_vars_;
86 
88  std::vector<double> default_in_;
89 
91  bool live_variables_;
92 
94  bool print_instructions_;
95 
99  MXFunction(const std::string& name,
100  const std::vector<MX>& input, const std::vector<MX>& output,
101  const std::vector<std::string>& name_in,
102  const std::vector<std::string>& name_out);
103 
107  ~MXFunction() override;
108 
112  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
113 
117  void disp_more(std::ostream& stream) const override;
118 
122  std::string class_name() const override {return "MXFunction";}
123 
127  bool is_a(const std::string& type, bool recursive) const override;
128 
130 
133  static const Options options_;
134  const Options& get_options() const override { return options_;}
136 
138  Dict get_stats(void* mem) const override;
139 
141  Dict generate_options(const std::string& target="clone") const override;
142 
146  void init(const Dict& opts) override;
147 
151  void codegen_declarations(CodeGenerator& g) const override;
152 
156  void codegen_incref(CodeGenerator& g) const override;
157 
161  void codegen_decref(CodeGenerator& g) const override;
162 
166  void codegen_body(CodeGenerator& g) const override;
167 
171  void serialize_body(SerializingStream &s) const override;
172 
176  static ProtoFunction* deserialize(DeserializingStream& s);
177 
183  void generate_lifted(Function& vdef_fcn, Function& vinit_fcn) const override;
184 
186  bool should_inline(bool always_inline, bool never_inline) const override;
187 
191  int eval_sx(const SXElem** arg, SXElem** res,
192  casadi_int* iw, SXElem* w, void* mem) const override;
193 
197  void eval_mx(const MXVector& arg, MXVector& res,
198  bool always_inline, bool never_inline) const override;
199 
203  void ad_forward(const std::vector<std::vector<MX> >& fwdSeed,
204  std::vector<std::vector<MX> >& fwdSens) const;
205 
209  void ad_reverse(const std::vector<std::vector<MX> >& adjSeed,
210  std::vector<std::vector<MX> >& adjSens) const;
211 
213  std::vector<MX> symbolic_output(const std::vector<MX>& arg) const override;
214 
218  int sp_forward(const bvec_t** arg, bvec_t** res,
219  casadi_int* iw, bvec_t* w, void* mem) const override;
220 
224  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
225 
226  // print an element of an algorithm
227  std::string print(const AlgEl& el) const;
228 
229  // Print the input arguments of an instruction
230  void print_arg(std::ostream &stream, casadi_int k, const AlgEl& el, const double** arg) const;
231 
232  // Print the output arguments of an instruction
233  void print_res(std::ostream &stream, casadi_int k, const AlgEl& el, double** res) const;
234 
236 
239  const MX mx_in(casadi_int ind) const override;
240  const std::vector<MX> mx_in() const override;
242 
244  std::vector<MX> free_mx() const override {return free_vars_;}
245 
249  bool has_free() const override { return !free_vars_.empty();}
250 
254  std::vector<std::string> get_free() const override {
255  std::vector<std::string> ret;
256  for (auto&& e : free_vars_) ret.push_back(e.name());
257  return ret;
258  }
259 
263  casadi_int n_nodes() const override { return algorithm_.size();}
264 
265  casadi_int n_instructions() const override { return algorithm_.size();}
266 
270  MX instruction_MX(casadi_int k) const override;
271 
275  casadi_int instruction_id(casadi_int k) const override { return algorithm_.at(k).op;}
276 
280  double get_default_in(casadi_int ind) const override { return default_in_.at(ind);}
281 
285  std::vector<casadi_int> instruction_input(casadi_int k) const override;
286 
290  std::vector<casadi_int> instruction_output(casadi_int k) const override;
291 
295  void export_code_body(const std::string& lang,
296  std::ostream &stream, const Dict& options) const override;
297 
299  void substitute_inplace(std::vector<MX>& vdef, std::vector<MX>& ex) const;
300 
301  // Get all embedded functions, recursively
302  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
303 
307  void change_option(const std::string& option_name, const GenericType& option_value) override;
308 
309  protected:
313  explicit MXFunction(DeserializingStream& s);
314  };
315 
316 } // namespace casadi
318 
319 #endif // CASADI_MX_FUNCTION_HPP
/* Doxygen cross-reference notes:
 *   casadi    -- The casadi namespace.
 *   MXVector  -- std::vector< MX > MXVector (Definition: mx.hpp:940)
 *   Dict      -- GenericType::Dict Dict: C++ equivalent of Python's dict or MATLAB's struct.
 */