oracle_function.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_ORACLE_FUNCTION_HPP
27 #define CASADI_ORACLE_FUNCTION_HPP
28 
29 #include "function_internal.hpp"
30 
32 namespace casadi {
33 
// Forward declaration: OracleCallback stores a pointer to OracleFunction, which is only
// defined further down in this header.
34  class OracleFunction;
35 
// Pairs a function name with a pointer to the OracleFunction that evaluates it.
// The calc_function template in this header uses both members: it calls
// oracle_->calc_function(..., name).
36  class CASADI_EXPORT OracleCallback {
37  public:
     // Name handed to OracleFunction::calc_function to select the function to evaluate
38  std::string name;
     // Oracle on which the named function is evaluated.
     // NOTE(review): looks non-owning (no destructor is declared here) -- confirm in the .cpp
39  OracleFunction* oracle_;
     // Construct from a function name and an oracle pointer (defined out of line)
40  OracleCallback(const std::string& name, OracleFunction* oracle);
     // Default constructor (defined out of line; the resulting member values are not visible here)
41  OracleCallback();
42  };
43 
// Declaration of the callback entry point. The definition appears after the OracleFunction
// class because it calls a member function through OracleCallback::oracle_ and therefore
// needs the complete OracleFunction type.
44  template<typename T1>
45  int calc_function(const OracleCallback* cb, casadi_oracle_data<T1>* d);
46 
// Set of work-vector pointers, extending FunctionMemory.
// OracleMemory keeps a vector of pointers to objects of this type (thread_local_mem),
// presumably one per thread -- confirm against the implementation.
// None of the pointers has a default member initializer; they hold meaningful values only
// after being assigned elsewhere.
50  struct CASADI_EXPORT LocalOracleMemory : public FunctionMemory {
51  // Work vectors
     // Array of pointers to read-only double buffers (arguments, going by the name)
52  const double** arg;
     // Array of pointers to writable double buffers (results, going by the name)
53  double** res;
     // Integer work vector
54  casadi_int* iw;
     // Double work vector
55  double* w;
56  };
57 
// Memory object used by OracleFunction: alloc_mem() creates it with `new OracleMemory()` and
// free_mem() deletes it. Extends FunctionMemory with work-vector pointers and a list of
// LocalOracleMemory objects.
61  struct CASADI_EXPORT OracleMemory : public FunctionMemory {
62  // Work vector aliases for non-threaded convenience
     // Array of pointers to read-only double buffers (arguments, going by the name)
63  const double** arg;
     // Array of pointers to writable double buffers (results, going by the name)
64  double** res;
     // Integer work vector
65  casadi_int* iw;
     // Double work vector
66  double* w;
67 
69 
     // Pointers to LocalOracleMemory objects, presumably indexed by thread id -- TODO confirm.
     // NOTE(review): these are raw pointers and a destructor is declared below. If that
     // destructor deletes the entries, copying an OracleMemory would double-delete them;
     // confirm in the .cpp and consider deleting the copy operations.
70  std::vector<LocalOracleMemory*> thread_local_mem;
     // Destructor, defined out of line
71  ~OracleMemory();
72  };
73 
// Abstract base class (the destructor is pure virtual) for FunctionInternal subclasses that
// work with an "oracle" Function and a registry of named functions (all_functions_).
// It declares helpers to create and register such functions (create_function, set_function),
// to evaluate them by name (calc_function), and the memory-management overrides that
// allocate and free OracleMemory objects.
80  class CASADI_EXPORT OracleFunction : public FunctionInternal {
81  protected:
     // The oracle Function; exposed read-only through oracle()
83  Function oracle_;
84 
     // Option dictionaries.
     // NOTE(review): presumably options shared by all created functions and per-function
     // overrides, respectively -- confirm against init()/create_function in the .cpp
86  Dict common_options_;
87  Dict specific_options_;
88 
     // Flag; going by the name, controls whether evaluation warnings are shown (set elsewhere)
90  bool show_eval_warnings_;
91 
92  // Maximum number of threads
93  int max_num_threads_;
94 
95  // Information about one function
96  struct RegFun {
     // The function itself
97  Function f;
     // JIT flag; presumably the `jit` argument of set_function -- TODO confirm.
     // Unlike `monitored`, it has no default member initializer.
98  bool jit;
99  Function f_original; // Relevant for jit
     // Whether the function is monitored; defaults to false
100  bool monitored = false;
101  };
102 
103  // All NLP functions
     // Keyed by function name
104  std::map<std::string, RegFun> all_functions_;
105 
106  // Active monitors
107  std::vector<std::string> monitor_;
108 
109  // Memory stride in case of multiple threads
     // One stride per work vector (arg, res, iw, w)
110  size_t stride_arg_, stride_res_, stride_iw_, stride_w_;
111 
112  // Expand after construction?
113  // Only used in finalize -> no need to serialize
114  bool post_expand_;
115 
116  public:
     // Construct from an instance name and the oracle Function
120  OracleFunction(const std::string& name, const Function& oracle);
121 
     // Pure virtual destructor: makes the class abstract. It still needs an out-of-line
     // definition, because derived-class destructors call it.
125  ~OracleFunction() override = 0;
126 
128 
     // Option table of this class; returned by get_options()
131  static const Options options_;
     // Access the option table of this class
132  const Options& get_options() const override { return options_;}
134 
     // Initialization hook (overrides the base class); receives the options dictionary
136  void init(const Dict& opts) override;
137 
     // Finalization hook (overrides the base class). According to the note on post_expand_,
     // this is the only place where post_expand_ is used.
139  void finalize() override;
140 
     // NOTE(review): presumably merges data from m->thread_local_mem back into m after a
     // multi-threaded evaluation -- confirm in the .cpp
142  void join_results(OracleMemory* m) const;
143 
     // Read-only access to the oracle Function
147  const Function& oracle() const override { return oracle_;}
148 
149  // Replace MX oracle with SX oracle?
150  void expand();
151 
     // Create a function called fname from the given oracle, with inputs and outputs
     // selected by name (s_in, s_out). aux and opts are optional and default to empty.
     // NOTE(review): whether the result is also registered in all_functions_ is not visible
     // from this header.
155  Function create_function(const Function& oracle, const std::string& fname,
156  const std::vector<std::string>& s_in,
157  const std::vector<std::string>& s_out,
158  const Function::AuxOut& aux=Function::AuxOut(),
159  const Dict& opts=Dict());
160 
     // Same as above without an explicit oracle argument; presumably uses oracle_ -- TODO confirm
162  Function create_function(const std::string& fname,
163  const std::vector<std::string>& s_in,
164  const std::vector<std::string>& s_out,
165  const Function::AuxOut& aux=Function::AuxOut(),
166  const Dict& opts=Dict());
167 
     // Create a function called fname from MX expression lists (e_in, e_out) together with
     // the corresponding input/output names (s_in, s_out); opts is optional
169  Function create_function(const std::string& fname,
170  const std::vector<MX>& e_in,
171  const std::vector<MX>& e_out,
172  const std::vector<std::string>& s_in,
173  const std::vector<std::string>& s_out,
174  const Dict& opts=Dict());
175 
     // Create a function called fname for nfwd directions.
     // NOTE(review): presumably a forward-mode derivative of a registered function -- confirm
177  Function create_forward(const std::string& fname, casadi_int nfwd);
178 
     // Register fcn under the name fname; jit defaults to false
     // (registration target is presumably all_functions_ -- confirm in the .cpp)
180  void set_function(const Function& fcn, const std::string& fname, bool jit=false);
181 
     // Convenience overload: forwards to the overload above with fname = fcn.name()
     // and the default jit=false
183  void set_function(const Function& fcn) { set_function(fcn, fcn.name()); }
184 
185  // Calculate an oracle function
     // m: memory object; fcn: name of the function to evaluate; arg: optional argument
     // pointers (default nullptr); thread_id: defaults to 0.
     // Returns an int status. The calc_function callback in this header forwards that value
     // unchanged and itself returns 1 when an exception is caught.
186  int calc_function(OracleMemory* m, const std::string& fcn,
187  const double* const* arg=nullptr, int thread_id=0) const;
188 
189  // Forward sparsity propagation through a function
     // fcn names the function; arg/res/iw/w are the caller-provided work vectors
190  int calc_sp_forward(const std::string& fcn, const bvec_t** arg, bvec_t** res,
191  casadi_int* iw, bvec_t* w) const;
192 
193  // Reverse sparsity propagation through a function
     // fcn names the function; note that arg is not const-qualified here, unlike the forward case
194  int calc_sp_reverse(const std::string& fcn, bvec_t** arg, bvec_t** res,
195  casadi_int* iw, bvec_t* w) const;
196 
     // List of function names (overrides the base class);
     // presumably the keys of all_functions_ -- TODO confirm
202  std::vector<std::string> get_function() const override;
203 
204  // Get a dependency function
205  const Function& get_function(const std::string &name) const override;
206 
207  // Is a function monitored?
     // Virtual, so subclasses may override the answer
208  virtual bool monitored(const std::string &name) const;
209 
210  // Check if a particular dependency exists
211  bool has_function(const std::string& fname) const override;
212 
     // Overrides the base class; takes a name and options and returns a string.
     // NOTE(review): presumably generates code for the dependency functions and returns the
     // resulting file name -- confirm against FunctionInternal
216  std::string generate_dependencies(const std::string& fname, const Dict& opts) const override;
217 
     // Overrides the base class.
     // NOTE(review): presumably JIT-compiles the functions registered with jit=true -- confirm
221  void jit_dependencies(const std::string& fname) override;
222 
     // Allocate a new OracleMemory; the matching release is free_mem()
226  void* alloc_mem() const override { return new OracleMemory();}
227 
     // NOTE(review): presumably initializes one LocalOracleMemory entry -- confirm in the .cpp
231  int local_init_mem(void* mem) const;
232 
     // Initialize a memory object (overrides the base class); returns an int status
236  int init_mem(void* mem) const override;
237 
     // Destroy a memory object; mem is cast to OracleMemory* and deleted, so it must have been
     // created as an OracleMemory (or a type derived from it with a compatible destructor)
241  void free_mem(void *mem) const override { delete static_cast<OracleMemory*>(mem);}
242 
     // Receive the work vectors for a memory object (overrides the base class);
     // presumably stored in the OracleMemory members of the same names -- TODO confirm
246  void set_temp(void* mem, const double** arg, double** res,
247  casadi_int* iw, double* w) const override;
248 
     // Statistics for a memory object, returned as a Dict (overrides the base class)
250  Dict get_stats(void* mem) const override;
251 
     // Code-generation hook; virtual so subclasses can customize it.
     // NOTE(review): presumably emits code at the start of the generated function body -- confirm
255  virtual void codegen_body_enter(CodeGenerator& g) const;
256 
     // Code-generation hook; virtual so subclasses can customize it.
     // NOTE(review): presumably emits code at the end of the generated function body -- confirm
260  virtual void codegen_body_exit(CodeGenerator& g) const;
261 
     // Write this object's data to the serializing stream (overrides the base class).
     // Per the note on post_expand_, that member is not expected to be serialized.
265  void serialize_body(SerializingStream &s) const override;
266 
267  protected:
     // Construct from a deserializing stream; protected, so it is not callable by
     // unrelated code. explicit prevents implicit conversion from a stream.
271  explicit OracleFunction(DeserializingStream& s);
272 
273  };
274 
275  template<typename T1>
276  int calc_function(const OracleCallback* cb, casadi_oracle_data<T1>* d) {
277  OracleMemory* m = static_cast<OracleMemory*>(d->m);
278  try {
279  return cb->oracle_->calc_function(m, cb->name);
280  }
281  catch (const std::exception& e) {
282  uerr() << e.what() << std::endl;
283  return 1;
284  }
285  }
286 
287 
288 } // namespace casadi
289 
291 
292 #endif // CASADI_ORACLE_FUNCTION_HPP
std::map< std::string, std::vector< std::string > > AuxOut
Definition: function.hpp:404
The casadi namespace.
Definition: archiver.hpp:32
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
CASADI_EXPORT std::ostream & uerr()