oracle_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_ORACLE_FUNCTION_HPP
27 #define CASADI_ORACLE_FUNCTION_HPP
28 
29 #include "function_internal.hpp"
30 
32 namespace casadi {
33 
34  class OracleFunction;
35 
36  class CASADI_EXPORT OracleCallback {
37  public:
38  std::string name;
39  OracleFunction* oracle_;
40  OracleCallback(const std::string& name, OracleFunction* oracle);
41  OracleCallback();
42  };
43 
44  template<typename T1>
45  int calc_function(const OracleCallback* cb, casadi_oracle_data<T1>* d);
46 
50  struct CASADI_EXPORT LocalOracleMemory : public FunctionMemory {
51  // Work vectors
52  const double** arg;
53  double** res;
54  casadi_int* iw;
55  double* w;
56  };
57 
61  struct CASADI_EXPORT OracleMemory : public FunctionMemory {
62  // Work vector aliases for non-threaded convenience
63  const double** arg;
64  double** res;
65  casadi_int* iw;
66  double* w;
67 
69 
70  std::vector<LocalOracleMemory*> thread_local_mem;
71  ~OracleMemory();
72  };
73 
80  class CASADI_EXPORT OracleFunction : public FunctionInternal {
81  protected:
83  Function oracle_;
84 
86  Dict common_options_;
87  Dict specific_options_;
88 
90  bool show_eval_warnings_;
91 
92  // Maximum number of threads
93  int max_num_threads_;
94 
95  // Information about one function
96  struct RegFun {
97  Function f;
98  bool jit;
99  Function f_original; // Relevant for jit
100  bool monitored = false;
101  };
102 
103  // All NLP functions
104  std::map<std::string, RegFun> all_functions_;
105 
106  // Active monitors
107  std::vector<std::string> monitor_;
108 
109  // Memory stride in case of multipel threads
110  size_t stride_arg_, stride_res_, stride_iw_, stride_w_;
111 
112  public:
116  OracleFunction(const std::string& name, const Function& oracle);
117 
121  ~OracleFunction() override = 0;
122 
124 
127  static const Options options_;
128  const Options& get_options() const override { return options_;}
130 
132  void init(const Dict& opts) override;
133 
135  void finalize() override;
136 
138  void join_results(OracleMemory* m) const;
139 
143  const Function& oracle() const override { return oracle_;}
144 
145  // Replace MX oracle with SX oracle?
146  void expand();
147 
151  Function create_function(const Function& oracle, const std::string& fname,
152  const std::vector<std::string>& s_in,
153  const std::vector<std::string>& s_out,
154  const Function::AuxOut& aux=Function::AuxOut(),
155  const Dict& opts=Dict());
156 
158  Function create_function(const std::string& fname,
159  const std::vector<std::string>& s_in,
160  const std::vector<std::string>& s_out,
161  const Function::AuxOut& aux=Function::AuxOut(),
162  const Dict& opts=Dict());
163 
165  Function create_function(const std::string& fname,
166  const std::vector<MX>& e_in,
167  const std::vector<MX>& e_out,
168  const std::vector<std::string>& s_in,
169  const std::vector<std::string>& s_out,
170  const Dict& opts=Dict());
171 
173  Function create_forward(const std::string& fname, casadi_int nfwd);
174 
176  void set_function(const Function& fcn, const std::string& fname, bool jit=false);
177 
179  void set_function(const Function& fcn) { set_function(fcn, fcn.name()); }
180 
181  // Calculate an oracle function
182  int calc_function(OracleMemory* m, const std::string& fcn,
183  const double* const* arg=nullptr, int thread_id=0) const;
184 
185  // Forward sparsity propagation through a function
186  int calc_sp_forward(const std::string& fcn, const bvec_t** arg, bvec_t** res,
187  casadi_int* iw, bvec_t* w) const;
188 
189  // Reverse sparsity propagation through a function
190  int calc_sp_reverse(const std::string& fcn, bvec_t** arg, bvec_t** res,
191  casadi_int* iw, bvec_t* w) const;
192 
198  std::vector<std::string> get_function() const override;
199 
200  // Get a dependency function
201  const Function& get_function(const std::string &name) const override;
202 
203  // Is a function monitored?
204  virtual bool monitored(const std::string &name) const;
205 
206  // Check if a particular dependency exists
207  bool has_function(const std::string& fname) const override;
208 
212  std::string generate_dependencies(const std::string& fname, const Dict& opts) const override;
213 
217  void jit_dependencies(const std::string& fname) override;
218 
222  void* alloc_mem() const override { return new OracleMemory();}
223 
227  int local_init_mem(void* mem) const;
228 
232  int init_mem(void* mem) const override;
233 
237  void free_mem(void *mem) const override { delete static_cast<OracleMemory*>(mem);}
238 
242  void set_temp(void* mem, const double** arg, double** res,
243  casadi_int* iw, double* w) const override;
244 
246  Dict get_stats(void* mem) const override;
247 
251  virtual void codegen_body_enter(CodeGenerator& g) const;
252 
256  virtual void codegen_body_exit(CodeGenerator& g) const;
257 
261  void serialize_body(SerializingStream &s) const override;
262 
263  protected:
267  explicit OracleFunction(DeserializingStream& s);
268 
269  };
270 
271  template<typename T1>
272  int calc_function(const OracleCallback* cb, casadi_oracle_data<T1>* d) {
273  OracleMemory* m = static_cast<OracleMemory*>(d->m);
274  try {
275  return cb->oracle_->calc_function(m, cb->name);
276  }
277  catch (const std::exception& e) {
278  uerr() << e.what() << std::endl;
279  return 1;
280  }
281  }
282 
283 
284 } // namespace casadi
285 
287 
288 #endif // CASADI_ORACLE_FUNCTION_HPP
std::map< std::string, std::vector< std::string > > AuxOut
Definition: function.hpp:395
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
CASADI_EXPORT std::ostream & uerr()