26 #ifndef CASADI_ORACLE_FUNCTION_HPP
27 #define CASADI_ORACLE_FUNCTION_HPP
29 #include "function_internal.hpp"
36 class CASADI_EXPORT OracleCallback {
39 OracleFunction* oracle_;
40 OracleCallback(
const std::string& name, OracleFunction* oracle);
50 struct CASADI_EXPORT LocalOracleMemory :
public FunctionMemory {
61 struct CASADI_EXPORT OracleMemory :
public FunctionMemory {
70 std::vector<LocalOracleMemory*> thread_local_mem;
80 class CASADI_EXPORT OracleFunction :
public FunctionInternal {
87 Dict specific_options_;
90 bool show_eval_warnings_;
100 bool monitored =
false;
104 std::map<std::string, RegFun> all_functions_;
107 std::vector<std::string> monitor_;
110 size_t stride_arg_, stride_res_, stride_iw_, stride_w_;
120 OracleFunction(
const std::string& name,
const Function& oracle);
125 ~OracleFunction()
override = 0;
131 static const Options options_;
132 const Options& get_options()
const override {
return options_;}
136 void init(
const Dict& opts)
override;
139 void finalize()
override;
142 void join_results(OracleMemory* m)
const;
147 const Function& oracle()
const override {
return oracle_;}
155 Function create_function(
const Function& oracle,
const std::string& fname,
156 const std::vector<std::string>& s_in,
157 const std::vector<std::string>& s_out,
162 Function create_function(
const std::string& fname,
163 const std::vector<std::string>& s_in,
164 const std::vector<std::string>& s_out,
169 Function create_function(
const std::string& fname,
170 const std::vector<MX>& e_in,
171 const std::vector<MX>& e_out,
172 const std::vector<std::string>& s_in,
173 const std::vector<std::string>& s_out,
177 Function create_forward(
const std::string& fname, casadi_int nfwd);
180 void set_function(
const Function& fcn,
const std::string& fname,
bool jit=
false);
183 void set_function(
const Function& fcn) { set_function(fcn, fcn.name()); }
186 int calc_function(OracleMemory* m,
const std::string& fcn,
187 const double*
const* arg=
nullptr,
int thread_id=0)
const;
190 int calc_sp_forward(
const std::string& fcn,
const bvec_t** arg, bvec_t** res,
191 casadi_int* iw, bvec_t* w)
const;
194 int calc_sp_reverse(
const std::string& fcn, bvec_t** arg, bvec_t** res,
195 casadi_int* iw, bvec_t* w)
const;
202 std::vector<std::string> get_function()
const override;
205 const Function& get_function(
const std::string &name)
const override;
208 virtual bool monitored(
const std::string &name)
const;
211 bool has_function(
const std::string& fname)
const override;
216 std::string generate_dependencies(
const std::string& fname,
const Dict& opts)
const override;
221 void jit_dependencies(
const std::string& fname)
override;
226 void* alloc_mem()
const override {
return new OracleMemory();}
231 int local_init_mem(
void* mem)
const;
236 int init_mem(
void* mem)
const override;
241 void free_mem(
void *mem)
const override {
delete static_cast<OracleMemory*
>(mem);}
246 void set_temp(
void* mem,
const double** arg,
double** res,
247 casadi_int* iw,
double* w)
const override;
250 Dict get_stats(
void* mem)
const override;
255 virtual void codegen_body_enter(CodeGenerator& g)
const;
260 virtual void codegen_body_exit(CodeGenerator& g)
const;
265 void serialize_body(SerializingStream &s)
const override;
271 explicit OracleFunction(DeserializingStream& s);
275 template<
typename T1>
277 OracleMemory* m =
static_cast<OracleMemory*
>(d->
m);
279 return cb->oracle_->calc_function(m, cb->name);
281 catch (
const std::exception& e) {
282 uerr() << e.what() << std::endl;
std::map< std::string, std::vector< std::string > > AuxOut
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
CASADI_EXPORT std::ostream & uerr()