26 #include "oracle_function.hpp"
27 #include "external.hpp"
28 #include "serializing_stream.hpp"
54 "Replace MX with SX expressions in problem formulation [false] "
55 "This happens before creating derivatives unless indicated by postpone_expand"}},
58 "When expand is active, postpone it until after creation of derivatives. Default: False"}},
61 "Set of user problem functions to be monitored"}},
62 {
"show_eval_warnings",
64 "Show warnings generated from function evaluations [true]"}},
67 "Options for auto-generated functions"}},
70 "Options for specific auto-generated functions,"
71 " overwriting the defaults from common_options. Nested dictionary."}}
81 bool postpone_expand =
false;
89 for (
auto&& op : opts) {
90 if (op.first==
"expand") {
92 }
else if (op.first==
"postpone_expand") {
93 postpone_expand = op.second;
94 }
else if (op.first==
"common_options") {
96 }
else if (op.first==
"specific_options") {
99 casadi_assert(i.second.is_dict(),
100 "specific_option must be a nested dictionary."
101 " Type mismatch for entry '" + i.first+
"': "
102 " got type " + i.second.get_description() +
".");
104 }
else if (op.first==
"monitor") {
106 }
else if (op.first==
"show_eval_warnings") {
141 bool persistent =
false;
146 for (
const std::string& fname :
monitor_) {
149 casadi_warning(
"Ignoring monitor '" + fname +
"'."
152 if (it->second.monitored) casadi_warning(
"Duplicate monitor " + fname);
153 it->second.monitored =
true;
160 casadi_warning(
"Ignoring specific_options entry '" + i.first+
"'."
173 for (
auto&& s : ml->fstats) {
174 m->
fstats.at(s.first).join(s.second);
180 const std::vector<std::string>& s_in,
181 const std::vector<std::string>& s_out,
188 const std::vector<MX>& e_in,
189 const std::vector<MX>& e_out,
190 const std::vector<std::string>& s_in,
191 const std::vector<std::string>& s_out,
196 casadi_message(
name_ +
"::create_function " + fname +
":" +
str(s_in) +
"->" +
str(s_out));
203 casadi_assert(ret.
n_in() == s_in.size(), fname +
" has wrong number of inputs");
204 casadi_assert(ret.
n_out() == s_out.size(), fname +
" has wrong number of outputs");
207 Dict specific_options;
216 ret =
Function(fname, e_in, e_out, s_in, s_out, opt);
220 casadi_error(
"Cannot create '" + fname +
"' since " +
str(ret.
get_free()) +
" are free.");
236 const std::vector<std::string>& s_in,
237 const std::vector<std::string>& s_out,
242 casadi_message(
name_ +
"::create_function " + fname +
":" +
str(s_in) +
"->" +
str(s_out));
249 casadi_assert(ret.
n_in() == s_in.size(), fname +
" has wrong number of inputs");
250 casadi_assert(ret.
n_out() == s_out.size(), fname +
" has wrong number of outputs");
253 Dict specific_options;
266 casadi_error(
"Cannot create '" + fname +
"' since " +
str(ret.
get_free()) +
" are free.");
288 casadi_assert(!
has_function(fname),
"Duplicate function " + fname);
296 g.
local(
"d_oracle",
"struct casadi_oracle_data");
304 const double*
const* arg,
int thread_id)
const {
310 if (
monitored) casadi_message(
"Calling \"" + fcn +
"\"");
322 FStats& fstats = ml->fstats.at(fcn);
325 casadi_int n_in = f.
n_in(), n_out = f.
n_out();
332 std::fill_n(ml->arg, n_in,
nullptr);
333 for (casadi_int i=0; i<n_in; ++i) ml->arg[i] = *arg++;
339 s << fcn <<
" input nonzeros:\n";
340 for (casadi_int i=0; i<n_in; ++i) {
341 s <<
" " << i <<
" (" << f.
name_in(i) <<
"): ";
345 for (casadi_int k=0; k<f.
nnz_in(i); ++k) {
355 casadi_message(s.str());
360 if (f(ml->arg, ml->res, ml->iw, ml->w)) {
365 }
catch(std::exception& ex) {
367 casadi_error(
"Error in " +
name_ +
":" + fcn +
":" + std::string(ex.what()));
373 s << fcn <<
" output nonzeros:\n";
374 for (casadi_int i=0; i<n_out; ++i) {
375 s <<
" " << i <<
" (" << f.
name_out(i) <<
"): ";
379 for (casadi_int k=0; k<f.
nnz_out(i); ++k) {
389 casadi_message(s.str());
393 for (casadi_int i=0; i<n_out; ++i) {
394 if (!ml->res[i])
continue;
395 if (!std::all_of(ml->res[i], ml->res[i]+f.
nnz_out(i), [](
double v) { return isfinite(v);})) {
396 std::stringstream ss;
398 auto it = std::find_if(ml->res[i], ml->res[i] + f.
nnz_out(i),
399 [](
double v) { return !isfinite(v);});
400 casadi_int k = std::distance(ml->res[i], it);
401 bool is_nan = isnan(ml->res[i][k]);
402 ss <<
name_ <<
":" << fcn <<
" failed: " << (is_nan?
"NaN" :
"Inf") <<
406 casadi_error(ss.str());
419 casadi_int* iw,
bvec_t* w)
const {
424 casadi_int* iw,
bvec_t* w)
const {
429 const Dict& opts)
const {
433 if (e.second.jit) gen.
add(e.second.f);
440 if (
verbose_) casadi_message(
"compiling to "+ fname+
"'.");
447 if (
verbose_) casadi_message(
"loading '" + e.second.f.name() +
"' from '" + fname +
"'.");
449 e.second.f_original = e.second.f;
488 casadi_assert_dev(m->thread_local_mem.empty());
506 casadi_int* iw,
double* w)
const {
513 m->d_oracle.arg = arg;
514 m->d_oracle.res = res;
518 auto* ml = m->thread_local_mem[i];
519 for (
auto&& s : ml->fstats) s.second.reset();
532 std::vector<std::string> ret;
535 ret.push_back(e.first);
543 "No function \"" + name +
"\" in " +
name_ +
". " +
551 "No function \"" + name +
"\" in " +
name_+
". " +
553 return it->second.monitored;
564 s.
version(
"OracleFunction", 3);
572 s.
pack(
"OracleFunction::all_functions::key", e.first);
573 s.
pack(
"OracleFunction::all_functions::value::jit", e.second.jit);
574 if (
jit_ && e.second.jit) {
577 s.
pack(
"OracleFunction::all_functions::value::f", e.second.f_original);
579 std::string f_name = e.second.f.name();
580 s.
pack(
"OracleFunction::all_functions::value::f_name", f_name);
585 s.
pack(
"OracleFunction::all_functions::value::f", e.second.f);
587 s.
pack(
"OracleFunction::all_functions::value::monitored", e.second.monitored);
599 int version = s.
version(
"OracleFunction", 1, 3);
613 s.
unpack(
"OracleFunction::all_functions::size", size);
614 for (casadi_int i=0;i<size;++i) {
616 s.
unpack(
"OracleFunction::all_functions::key", key);
619 s.
unpack(
"OracleFunction::all_functions::value::f", r.
f);
620 s.
unpack(
"OracleFunction::all_functions::value::jit", r.
jit);
622 s.
unpack(
"OracleFunction::all_functions::value::jit", r.
jit);
625 s.
unpack(
"OracleFunction::all_functions::value::f", r.
f);
628 s.
unpack(
"OracleFunction::all_functions::value::f_name", f_name);
629 r.
f =
Function(f_name, std::vector<MX>{}, std::vector<MX>{});
633 s.
unpack(
"OracleFunction::all_functions::value::f", r.
f);
636 s.
unpack(
"OracleFunction::all_functions::value::monitored", r.
monitored);
Helper class for C code generation.
void add(const Function &f, bool with_jac_sparsity=false)
Add a function (name generated)
std::string generate(const std::string &prefix="")
Generate file(s)
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
std::string jit_serialize_
Serialize behaviour.
Dict get_stats(void *mem) const override
Get all statistics.
void init(const Dict &opts) override
Initialize.
void finalize() override
Finalize the object creation.
void tocache_if_missing(Function &f, const std::string &suffix="") const
Save function to cache, only if missing.
static std::string forward_name(const std::string &fcn, casadi_int nfwd)
Helper function: Get name of forward derivative function.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
std::string compiler_plugin_
Just-in-time compiler.
bool jit_
Use just-in-time compiler.
bool incache(const std::string &fname, Function &f, const std::string &suffix="") const
Get function in cache.
size_t sz_res() const
Get required length of res field.
static const Options options_
Options.
size_t sz_w() const
Get required length of w field.
size_t sz_arg() const
Get required length of arg field.
void alloc(const Function &f, bool persistent=false, int num_threads=1)
Ensure work vectors long enough to evaluate function.
size_t sz_iw() const
Get required length of iw field.
casadi_int nnz_out() const
Get number of output nonzeros.
void sz_work(size_t &sz_arg, size_t &sz_res, size_t &sz_iw, size_t &sz_w) const
Get number of temporary variables needed.
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
Function expand() const
Expand a function to SX.
const std::vector< std::string > & name_in() const
Get input scheme.
casadi_int n_out() const
Get the number of function outputs.
casadi_int n_in() const
Get the number of function inputs.
std::vector< std::string > get_free() const
Get free variables as a string.
bool has_free() const
Does the function have free variables.
casadi_int nnz_in() const
Get number of input nonzeros.
std::map< std::string, std::vector< std::string > > AuxOut
Function factory(const std::string &name, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const AuxOut &aux=AuxOut(), const Dict &opts=Dict()) const
const std::vector< std::string > & name_out() const
Get output scheme.
bool is_null() const
Is a null pointer?
static void check()
Raises an error if an interrupt was captured.
void print_scalar(std::ostream &stream) const
Print scalar.
Base class for functions that perform calculation with an oracle.
void set_function(const Function &fcn, const std::string &fname, bool jit=false)
Function oracle_
Oracle: Used to generate other functions.
int calc_sp_forward(const std::string &fcn, const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const
Function create_function(const Function &oracle, const std::string &fname, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const Function::AuxOut &aux=Function::AuxOut(), const Dict &opts=Dict())
std::map< std::string, RegFun > all_functions_
void join_results(OracleMemory *m) const
Combine results from different threads.
void init(const Dict &opts) override
~OracleFunction() override=0
Destructor.
void jit_dependencies(const std::string &fname) override
JIT for dependencies.
OracleFunction(const std::string &name, const Function &oracle)
Constructor.
Function create_forward(const std::string &fname, casadi_int nfwd)
int init_mem(void *mem) const override
Initalize memory block.
virtual void codegen_body_enter(CodeGenerator &g) const
Generate code for the function body.
int calc_function(OracleMemory *m, const std::string &fcn, const double *const *arg=nullptr, int thread_id=0) const
std::vector< std::string > get_function() const override
Get list of dependency functions.
std::vector< std::string > monitor_
bool has_function(const std::string &fname) const override
virtual bool monitored(const std::string &name) const
Dict common_options_
Options for creating functions.
void set_temp(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const override
Set the work vectors.
int local_init_mem(void *mem) const
Initalize memory block.
static const Options options_
Options.
const Function & oracle() const override
Get oracle.
bool show_eval_warnings_
Show evaluation warnings.
Dict get_stats(void *mem) const override
Get all statistics.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int calc_sp_reverse(const std::string &fcn, bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const
std::string generate_dependencies(const std::string &fname, const Dict &opts) const override
Export / Generate C code for the generated functions.
void finalize() override
Finalize initialization.
virtual void codegen_body_exit(CodeGenerator &g) const
Generate code for the function body.
virtual int init_mem(void *mem) const
Initalize memory block.
bool regularity_check_
Errors are thrown when NaN is produced.
bool verbose_
Verbose printout.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
std::string repr_el(casadi_int k) const
Describe the nonzero location k as a string.
std::string join(const std::vector< std::string > &l, const std::string &delim)
unsigned long long bvec_t
Dict combine(const Dict &first, const Dict &second, bool recurse)
Combine two dicts. First has priority.
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
Function external(const std::string &name, const Importer &li, const Dict &opts)
Load a just-in-time compiled external function.
Function memory with temporary work vectors.
Options metadata for a class.
std::vector< LocalOracleMemory * > thread_local_mem
std::map< std::string, FStats > fstats
void add_stat(const std::string &s)