26 #ifndef CASADI_ORACLE_FUNCTION_HPP
27 #define CASADI_ORACLE_FUNCTION_HPP
29 #include "function_internal.hpp"
// Callback helper that holds a pointer to an OracleFunction.
// NOTE(review): this listing looks like an extraction with lines dropped (the
// remaining members and the closing brace of the class are not visible here);
// verify against the original header before relying on these notes.
36 class CASADI_EXPORT OracleCallback {
// Pointer to the OracleFunction; it is dereferenced further down in this file as
// cb->oracle_->calc_function(...). Ownership is not shown in this excerpt.
39 OracleFunction* oracle_;
// Constructor taking a name and the OracleFunction pointer; declaration only.
40 OracleCallback(
const std::string& name, OracleFunction* oracle);
// Memory struct derived from FunctionMemory. Its members are not visible in this
// excerpt; OracleMemory stores a vector of pointers to this type.
50 struct CASADI_EXPORT LocalOracleMemory :
public FunctionMemory {
// Memory struct for OracleFunction, derived from FunctionMemory.
// OracleFunction::alloc_mem() creates instances of this type and
// OracleFunction::free_mem() deletes them.
61 struct CASADI_EXPORT OracleMemory :
public FunctionMemory {
// Pointers to LocalOracleMemory objects.
// NOTE(review): presumably one entry per thread, judging by the name - confirm.
70 std::vector<LocalOracleMemory*> thread_local_mem;
// Abstract FunctionInternal subclass (its destructor is declared pure virtual
// below) that keeps a name-indexed set of functions (all_functions_).
80 class CASADI_EXPORT OracleFunction :
public FunctionInternal {
// Dictionary of options.
// NOTE(review): presumably options specific to individual functions - confirm in init().
87 Dict specific_options_;
// NOTE(review): presumably toggles warnings issued on evaluation problems - confirm.
90 bool show_eval_warnings_;
// NOTE(review): this field defaults to false and shares its name with the
// monitored() method declared later in this class, so it most likely belongs to a
// nested struct (the RegFun type used by all_functions_) whose header is missing
// from this excerpt - verify against the original header.
100 bool monitored =
false;
// Map from function name to its RegFun entry.
104 std::map<std::string, RegFun> all_functions_;
// List of names.
// NOTE(review): presumably the functions to be monitored - confirm.
107 std::vector<std::string> monitor_;
// Four size_t stride values.
// NOTE(review): the names suggest strides into the arg/res/iw/w work arrays - confirm.
110 size_t stride_arg_, stride_res_, stride_iw_, stride_w_;
// Constructor taking the name and the oracle Function; declaration only.
116 OracleFunction(
const std::string& name,
const Function& oracle);
// Pure virtual destructor: makes OracleFunction abstract. C++ still requires a
// definition for a pure virtual destructor, which must live elsewhere.
121 ~OracleFunction()
override = 0;
// Class-wide options object (static); defined elsewhere.
127 static const Options options_;
// Returns a reference to the static options_ object of this class.
128 const Options& get_options()
const override {
return options_;}
// Override taking the options dictionary; declaration only.
132 void init(
const Dict& opts)
override;
// Override of finalize(); declaration only.
135 void finalize()
override;
// Const member operating on an OracleMemory; declaration only.
// NOTE(review): presumably merges per-thread results (see
// OracleMemory::thread_local_mem) back into m - confirm in the implementation.
138 void join_results(OracleMemory* m)
const;
// Returns a const reference to the oracle_ member (its declaration is not
// visible in this excerpt).
143 const Function& oracle()
const override {
return oracle_;}
// NOTE(review): the three create_function declarations below are truncated in
// this excerpt (each parameter list ends on a trailing comma), so any further
// parameters and qualifiers cannot be documented from here.
// Overload taking an explicit oracle Function, a function name, and lists of
// input and output names.
151 Function create_function(
const Function& oracle,
const std::string& fname,
152 const std::vector<std::string>& s_in,
153 const std::vector<std::string>& s_out,
// Overload without an explicit oracle argument: function name plus lists of
// input and output names.
158 Function create_function(
const std::string& fname,
159 const std::vector<std::string>& s_in,
160 const std::vector<std::string>& s_out,
// Overload taking MX expressions (e_in, e_out) together with name lists
// (s_in, s_out).
165 Function create_function(
const std::string& fname,
166 const std::vector<MX>& e_in,
167 const std::vector<MX>& e_out,
168 const std::vector<std::string>& s_in,
169 const std::vector<std::string>& s_out,
// Takes a function name and a count nfwd; declaration only.
// NOTE(review): presumably builds a forward-mode derivative function - confirm.
173 Function create_forward(
const std::string& fname, casadi_int nfwd);
// Takes a Function, the name fname, and a jit flag that defaults to false;
// declaration only.
176 void set_function(
const Function& fcn,
const std::string& fname,
bool jit=
false);
// Convenience overload: forwards to the overload above using the function's own
// name (fcn.name()) and the default jit value (false).
179 void set_function(
const Function& fcn) { set_function(fcn, fcn.name()); }
// Takes the memory object m and a function name fcn; arg defaults to nullptr and
// thread_id to 0. Returns int. Declaration only.
// NOTE(review): return value is presumably a status code (0 on success) - confirm.
182 int calc_function(OracleMemory* m,
const std::string& fcn,
183 const double*
const* arg=
nullptr,
int thread_id=0)
const;
// Takes a function name fcn and bvec_t buffers (const arg, res) plus iw/w work
// arrays. Returns int. Declaration only.
// NOTE(review): presumably forward sparsity propagation, judging by the name - confirm.
186 int calc_sp_forward(
const std::string& fcn,
const bvec_t** arg, bvec_t** res,
187 casadi_int* iw, bvec_t* w)
const;
// Same parameter layout as calc_sp_forward, except that arg is non-const here.
// NOTE(review): presumably reverse sparsity propagation, judging by the name - confirm.
190 int calc_sp_reverse(
const std::string& fcn, bvec_t** arg, bvec_t** res,
191 casadi_int* iw, bvec_t* w)
const;
// Override returning a list of strings; declaration only.
// NOTE(review): presumably the names of the registered functions - confirm.
198 std::vector<std::string> get_function()
const override;
// Override returning a const reference to a Function for the given name;
// declaration only.
201 const Function& get_function(
const std::string &name)
const override;
// Virtual predicate on a function name; declaration only.
// NOTE(review): presumably reports whether the named function is monitored - confirm.
204 virtual bool monitored(
const std::string &name)
const;
// Override taking a function name and returning bool; declaration only.
207 bool has_function(
const std::string& fname)
const override;
// Override taking a name and an options dictionary, returning a string;
// declaration only.
212 std::string generate_dependencies(
const std::string& fname,
const Dict& opts)
const override;
// Override taking a name; non-const. Declaration only.
217 void jit_dependencies(
const std::string& fname)
override;
// Allocates a new OracleMemory with `new` and returns it as void*.
// free_mem() below is the matching release.
222 void* alloc_mem()
const override {
return new OracleMemory();}
// Takes an untyped memory pointer and returns int; not marked override.
// Declaration only.
227 int local_init_mem(
void* mem)
const;
// Override taking an untyped memory pointer and returning int; declaration only.
232 int init_mem(
void* mem)
const override;
// Casts mem back to OracleMemory* and deletes it, mirroring the `new` in
// alloc_mem() above.
237 void free_mem(
void *mem)
const override {
delete static_cast<OracleMemory*
>(mem);}
// Override receiving the memory object and the arg/res/iw/w buffers;
// declaration only.
242 void set_temp(
void* mem,
const double** arg,
double** res,
243 casadi_int* iw,
double* w)
const override;
// Override returning a Dict for the given memory object; declaration only.
246 Dict get_stats(
void* mem)
const override;
// Virtual hook taking the CodeGenerator; declaration only.
// NOTE(review): presumably emits code at the start of a generated body - confirm.
251 virtual void codegen_body_enter(CodeGenerator& g)
const;
// Virtual hook taking the CodeGenerator; declaration only.
// NOTE(review): presumably emits code at the end of a generated body - confirm.
256 virtual void codegen_body_exit(CodeGenerator& g)
const;
// Override writing to a SerializingStream; declaration only.
261 void serialize_body(SerializingStream &s)
const override;
// Constructor from a DeserializingStream; `explicit` prevents implicit
// conversion from a stream to an OracleFunction.
267 explicit OracleFunction(DeserializingStream& s);
// NOTE(review): fragment of a function template. Its signature and the opening
// `try` are missing from this excerpt, so the types of `d` and `cb` cannot be
// determined from here - verify against the original header.
271 template<
typename T1>
// Recover the OracleMemory pointer from d->m.
273 OracleMemory* m =
static_cast<OracleMemory*
>(d->
m);
// Delegate to the oracle's calc_function using the callback's name and return
// its result.
275 return cb->oracle_->calc_function(m, cb->name);
// Exceptions derived from std::exception are caught here and their message is
// written to uerr().
277 catch (
const std::exception& e) {
278 uerr() << e.what() << std::endl;
std::map< std::string, std::vector< std::string > > AuxOut
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
CASADI_EXPORT std::ostream & uerr()