oracle_function.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "oracle_function.hpp"
27 #include "external.hpp"
28 #include "serializing_stream.hpp"
29 
30 #include <iomanip>
31 #include <iostream>
32 
33 namespace casadi {
34 
35 OracleCallback::OracleCallback(const std::string& name,
36  OracleFunction* oracle) : name(name), oracle_(oracle) {
37 }
38 
39 OracleCallback::OracleCallback() : name("undefined"), oracle_(0) {
40 }
41 
42 
43 OracleFunction::OracleFunction(const std::string& name, const Function& oracle)
44 : FunctionInternal(name), oracle_(oracle) {
45 }
46 
48 }
49 
52  {{"expand",
53  {OT_BOOL,
54  "Replace MX with SX expressions in problem formulation [false] "
55  "This happens before creating derivatives unless indicated by postpone_expand"}},
56  {"postpone_expand",
57  {OT_BOOL,
58  "When expand is active, postpone it until after creation of derivatives. Default: False"}},
59  {"monitor",
61  "Set of user problem functions to be monitored"}},
62  {"show_eval_warnings",
63  {OT_BOOL,
64  "Show warnings generated from function evaluations [true]"}},
65  {"common_options",
66  {OT_DICT,
67  "Options for auto-generated functions"}},
68  {"specific_options",
69  {OT_DICT,
70  "Options for specific auto-generated functions,"
71  " overwriting the defaults from common_options. Nested dictionary."}}
72  }
73 };
74 
75 void OracleFunction::init(const Dict& opts) {
76 
78 
79  // Default options
80  bool expand = false;
81  bool postpone_expand = false;
82 
83  show_eval_warnings_ = true;
84 
85  max_num_threads_ = 1;
86  post_expand_ = false;
87 
88  // Read options
89  for (auto&& op : opts) {
90  if (op.first=="expand") {
91  expand = op.second;
92  } else if (op.first=="postpone_expand") {
93  postpone_expand = op.second;
94  } else if (op.first=="common_options") {
95  common_options_ = op.second;
96  } else if (op.first=="specific_options") {
97  specific_options_ = op.second;
98  for (auto&& i : specific_options_) {
99  casadi_assert(i.second.is_dict(),
100  "specific_option must be a nested dictionary."
101  " Type mismatch for entry '" + i.first+ "': "
102  " got type " + i.second.get_description() + ".");
103  }
104  } else if (op.first=="monitor") {
105  monitor_ = op.second;
106  } else if (op.first=="show_eval_warnings") {
107  show_eval_warnings_ = op.second;
108  }
109  }
110 
111  // Replace MX oracle with SX oracle?
112  if (expand && !postpone_expand) oracle_ = oracle_.expand();
113  if (expand && postpone_expand) post_expand_ = true;
114 
115  stride_arg_ = 0;
116  stride_res_ = 0;
117  stride_iw_ = 0;
118  stride_w_ = 0;
119 
120 }
121 
123  if (post_expand_) {
124  for (auto&& e : all_functions_) {
125  Function& fcn = e.second.f;
126  fcn = fcn.expand();
127  }
128  }
129 
130  // Allocate space for (parallel) evaluations
131  // Lifted from set_function as max_num_threads_ is not known yet in that method
132  for (auto&& e : all_functions_) {
133  Function& fcn = e.second.f;
134  // Compute strides for multi threading
135  size_t sz_arg, sz_res, sz_iw, sz_w;
136  fcn.sz_work(sz_arg, sz_res, sz_iw, sz_w);
137  stride_arg_ = std::max(stride_arg_, sz_arg);
138  stride_res_ = std::max(stride_res_, sz_res);
139  stride_iw_ = std::max(stride_iw_, sz_iw);
140  stride_w_ = std::max(stride_w_, sz_w);
141  bool persistent = false;
142  alloc(fcn, persistent, max_num_threads_);
143  }
144 
145  // Set corresponding monitors
146  for (const std::string& fname : monitor_) {
147  auto it = all_functions_.find(fname);
148  if (it==all_functions_.end()) {
149  casadi_warning("Ignoring monitor '" + fname + "'."
150  " Available functions: " + join(get_function()) + ".");
151  } else {
152  if (it->second.monitored) casadi_warning("Duplicate monitor " + fname);
153  it->second.monitored = true;
154  }
155  }
156 
157  // Check specific options
158  for (auto&& i : specific_options_) {
159  if (all_functions_.find(i.first)==all_functions_.end())
160  casadi_warning("Ignoring specific_options entry '" + i.first+"'."
161  " Available functions: " + join(get_function()) + ".");
162  }
163 
164  // Recursive call
166 }
167 
169  // Combine runtime statistics
170  // Note: probably not correct to simply add wall times
171  for (int i = 0; i < max_num_threads_; ++i) {
172  auto* ml = m->thread_local_mem[i];
173  for (auto&& s : ml->fstats) {
174  m->fstats.at(s.first).join(s.second);
175  }
176  }
177 }
178 
179 Function OracleFunction::create_function(const std::string& fname,
180  const std::vector<std::string>& s_in,
181  const std::vector<std::string>& s_out,
182  const Function::AuxOut& aux,
183  const Dict& opts) {
184  return create_function(oracle_, fname, s_in, s_out, aux, opts);
185 }
186 
187 Function OracleFunction::create_function(const std::string& fname,
188  const std::vector<MX>& e_in,
189  const std::vector<MX>& e_out,
190  const std::vector<std::string>& s_in,
191  const std::vector<std::string>& s_out,
192  const Dict& opts) {
193 
194  // Print progress
195  if (verbose_) {
196  casadi_message(name_ + "::create_function " + fname + ":" + str(s_in) + "->" + str(s_out));
197  }
198 
199  // Check if function is already in cache
200  Function ret;
201  if (incache(fname, ret)) {
202  // Consistency checks
203  casadi_assert(ret.n_in() == s_in.size(), fname + " has wrong number of inputs");
204  casadi_assert(ret.n_out() == s_out.size(), fname + " has wrong number of outputs");
205  } else {
206  // Retrieve specific set of options if available
207  Dict specific_options;
208  auto it = specific_options_.find(fname);
209  if (it!=specific_options_.end()) specific_options = it->second;
210 
211  // Combine specific and common options
212  Dict opt = combine(specific_options, common_options_);
213  opt = combine(opts, opt);
214 
215  // Generate the function
216  ret = Function(fname, e_in, e_out, s_in, s_out, opt);
217 
218  // Make sure that it's sound
219  if (ret.has_free()) {
220  casadi_error("Cannot create '" + fname + "' since " + str(ret.get_free()) + " are free.");
221  }
222 
223  // TODO(jgillis) Conditionally convert to SX
224 
225  // Add to cache
226  tocache_if_missing(ret);
227  }
228 
229  // Save and return
230  set_function(ret, fname, true);
231  return ret;
232 
233 }
234 
235 Function OracleFunction::create_function(const Function& oracle, const std::string& fname,
236  const std::vector<std::string>& s_in,
237  const std::vector<std::string>& s_out,
238  const Function::AuxOut& aux,
239  const Dict& opts) {
240  // Print progress
241  if (verbose_) {
242  casadi_message(name_ + "::create_function " + fname + ":" + str(s_in) + "->" + str(s_out));
243  }
244 
245  // Check if function is already in cache
246  Function ret;
247  if (incache(fname, ret)) {
248  // Consistency checks
249  casadi_assert(ret.n_in() == s_in.size(), fname + " has wrong number of inputs");
250  casadi_assert(ret.n_out() == s_out.size(), fname + " has wrong number of outputs");
251  } else {
252  // Retrieve specific set of options if available
253  Dict specific_options;
254  auto it = specific_options_.find(fname);
255  if (it!=specific_options_.end()) specific_options = it->second;
256 
257  // Combine specific and common options
258  Dict opt = combine(specific_options, common_options_);
259  opt = combine(opts, opt);
260 
261  // Generate the function
262  ret = oracle.factory(fname, s_in, s_out, aux, opt);
263 
264  // Make sure that it's sound
265  if (ret.has_free()) {
266  casadi_error("Cannot create '" + fname + "' since " + str(ret.get_free()) + " are free.");
267  }
268 
269  // Add to cache
270  tocache_if_missing(ret);
271  }
272 
273  // Save and return
274  set_function(ret, fname, true);
275  return ret;
276 }
277 
278 Function OracleFunction::create_forward(const std::string& fname, casadi_int nfwd) {
279  // Create derivative
280  Function ret = get_function(fname).forward(nfwd);
281  std::string fwd_name = forward_name(fname, nfwd); // may be different from ret.name()
282  if (!has_function(fwd_name)) set_function(ret, fwd_name, true);
283  return ret;
284 }
285 
287 set_function(const Function& fcn, const std::string& fname, bool jit) {
288  casadi_assert(!has_function(fname), "Duplicate function " + fname);
289  RegFun& r = all_functions_[fname];
290  r.f = fcn;
291  r.jit = jit;
292 }
293 
294 
296  g.local("d_oracle", "struct casadi_oracle_data");
297  }
298 
300  }
301 
303 calc_function(OracleMemory* m, const std::string& fcn,
304  const double* const* arg, int thread_id) const {
305  auto ml = m->thread_local_mem.at(thread_id);
306  // Is the function monitored?
307  bool monitored = this->monitored(fcn);
308 
309  // Print progress
310  if (monitored) casadi_message("Calling \"" + fcn + "\"");
311 
312  // Respond to a possible Crl+C signals
313  // Python interrupt checker needs the GIL.
314  // We may not have access to it in a multi-threaded context
315  // See issue #2955
317 
318  // Get function
319  const Function& f = get_function(fcn);
320 
321  // Get statistics structure
322  FStats& fstats = ml->fstats.at(fcn);
323 
324  // Number of inputs and outputs
325  casadi_int n_in = f.n_in(), n_out = f.n_out();
326 
327  // Prepare stats, start timer
328  ScopedTiming tic(fstats);
329 
330  // Input buffers
331  if (arg) {
332  std::fill_n(ml->arg, n_in, nullptr);
333  for (casadi_int i=0; i<n_in; ++i) ml->arg[i] = *arg++;
334  }
335 
336  // Print inputs nonzeros
337  if (monitored) {
338  std::stringstream s;
339  s << fcn << " input nonzeros:\n";
340  for (casadi_int i=0; i<n_in; ++i) {
341  s << " " << i << " (" << f.name_in(i) << "): ";
342  if (ml->arg[i]) {
343  // Print nonzeros
344  s << "[";
345  for (casadi_int k=0; k<f.nnz_in(i); ++k) {
346  if (k!=0) s << ", ";
347  DM::print_scalar(s, ml->arg[i][k]);
348  }
349  s << "]\n";
350  } else {
351  // All zero input
352  s << "0\n";
353  }
354  }
355  casadi_message(s.str());
356  }
357 
358  // Evaluate memory-less
359  try {
360  if (f(ml->arg, ml->res, ml->iw, ml->w)) {
361  // Recoverable error
362  if (monitored) casadi_message(name_ + ":" + fcn + " failed");
363  return 1;
364  }
365  } catch(std::exception& ex) {
366  // Fatal error: Generate stack trace
367  casadi_error("Error in " + name_ + ":" + fcn + ":" + std::string(ex.what()));
368  }
369 
370  // Print output nonzeros
371  if (monitored) {
372  std::stringstream s;
373  s << fcn << " output nonzeros:\n";
374  for (casadi_int i=0; i<n_out; ++i) {
375  s << " " << i << " (" << f.name_out(i) << "): ";
376  if (ml->res[i]) {
377  // Print nonzeros
378  s << "[";
379  for (casadi_int k=0; k<f.nnz_out(i); ++k) {
380  if (k!=0) s << ", ";
381  DM::print_scalar(s, ml->res[i][k]);
382  }
383  s << "]\n";
384  } else {
385  // Ignored output
386  s << " N/A\n";
387  }
388  }
389  casadi_message(s.str());
390  }
391 
392  // Make sure not NaN or Inf
393  for (casadi_int i=0; i<n_out; ++i) {
394  if (!ml->res[i]) continue;
395  if (!std::all_of(ml->res[i], ml->res[i]+f.nnz_out(i), [](double v) { return isfinite(v);})) {
396  std::stringstream ss;
397 
398  auto it = std::find_if(ml->res[i], ml->res[i] + f.nnz_out(i),
399  [](double v) { return !isfinite(v);});
400  casadi_int k = std::distance(ml->res[i], it);
401  bool is_nan = isnan(ml->res[i][k]);
402  ss << name_ << ":" << fcn << " failed: " << (is_nan? "NaN" : "Inf") <<
403  " detected for output " << f.name_out(i) << ", at " << f.sparsity_out(i).repr_el(k) << ".";
404 
405  if (regularity_check_) {
406  casadi_error(ss.str());
407  } else {
408  if (show_eval_warnings_) casadi_warning(ss.str());
409  }
410  return -1;
411  }
412  }
413 
414  // Success
415  return 0;
416 }
417 
418 int OracleFunction::calc_sp_forward(const std::string& fcn, const bvec_t** arg, bvec_t** res,
419  casadi_int* iw, bvec_t* w) const {
420  return get_function(fcn)(arg, res, iw, w);
421 }
422 
423 int OracleFunction::calc_sp_reverse(const std::string& fcn, bvec_t** arg, bvec_t** res,
424  casadi_int* iw, bvec_t* w) const {
425  return get_function(fcn).rev(arg, res, iw, w);
426 }
427 
428 std::string OracleFunction::generate_dependencies(const std::string& fname,
429  const Dict& opts) const {
430  CodeGenerator gen(fname, opts);
431  gen.add(oracle_);
432  for (auto&& e : all_functions_) {
433  if (e.second.jit) gen.add(e.second.f);
434  }
435  return gen.generate();
436 }
437 
438 void OracleFunction::jit_dependencies(const std::string& fname) {
439  if (compiler_.is_null()) {
440  if (verbose_) casadi_message("compiling to "+ fname+"'.");
441  // JIT dependent functions
444  }
445  // Replace the Oracle functions with generated functions
446  for (auto&& e : all_functions_) {
447  if (verbose_) casadi_message("loading '" + e.second.f.name() + "' from '" + fname + "'.");
448  if (e.second.jit) {
449  e.second.f_original = e.second.f;
450  e.second.f = external(e.second.f.name(), compiler_);
451  }
452  }
453 }
454 
456  oracle_ = oracle_.expand();
457 }
458 
459 Dict OracleFunction::get_stats(void *mem) const {
460  Dict stats = FunctionInternal::get_stats(mem);
461  //auto m = static_cast<OracleMemory*>(mem);
462  return stats;
463 }
464 
465 int OracleFunction::local_init_mem(void* mem) const {
466  if (ProtoFunction::init_mem(mem)) return 1;
467  if (!mem) return 1;
468  auto m = static_cast<LocalOracleMemory*>(mem);
469 
470  // Create statistics
471  for (auto&& e : all_functions_) {
472  m->add_stat(e.first);
473  }
474 
475  return 0;
476 }
477 
478 int OracleFunction::init_mem(void* mem) const {
479  if (ProtoFunction::init_mem(mem)) return 1;
480  if (!mem) return 1;
481  auto m = static_cast<OracleMemory*>(mem);
482 
483  // Create statistics
484  for (auto&& e : all_functions_) {
485  m->add_stat(e.first);
486  }
487 
488  casadi_assert_dev(m->thread_local_mem.empty());
489 
490  // Allocate and initialize local memory for threads
491  for (int i = 0; i < max_num_threads_; ++i) {
492  m->thread_local_mem.push_back(new LocalOracleMemory());
493  if (OracleFunction::local_init_mem(m->thread_local_mem[i])) return 1;
494  }
495 
496  return 0;
497 }
498 
500  for (auto* ml : thread_local_mem) {
501  delete static_cast<LocalOracleMemory*>(ml);
502  }
503 }
504 
505 void OracleFunction::set_temp(void* mem, const double** arg, double** res,
506  casadi_int* iw, double* w) const {
507 
508  auto m = static_cast<OracleMemory*>(mem);
509  m->arg = arg;
510  m->res = res;
511  m->iw = iw;
512  m->w = w;
513  m->d_oracle.arg = arg;
514  m->d_oracle.res = res;
515  m->d_oracle.iw = iw;
516  m->d_oracle.w = w;
517  for (int i = 0; i < max_num_threads_; ++i) {
518  auto* ml = m->thread_local_mem[i];
519  for (auto&& s : ml->fstats) s.second.reset();
520  ml->arg = arg;
521  ml->res = res;
522  ml->iw = iw;
523  ml->w = w;
524  arg += stride_arg_;
525  res += stride_res_;
526  iw += stride_iw_;
527  w += stride_w_;
528  }
529 }
530 
531 std::vector<std::string> OracleFunction::get_function() const {
532  std::vector<std::string> ret;
533  ret.reserve(all_functions_.size());
534  for (auto&& e : all_functions_) {
535  ret.push_back(e.first);
536  }
537  return ret;
538 }
539 
540 const Function& OracleFunction::get_function(const std::string &name) const {
541  auto it = all_functions_.find(name);
542  casadi_assert(it!=all_functions_.end(),
543  "No function \"" + name + "\" in " + name_ + ". " +
544  "Available functions: " + join(get_function()) + ".");
545  return it->second.f;
546 }
547 
548 bool OracleFunction::monitored(const std::string &name) const {
549  auto it = all_functions_.find(name);
550  casadi_assert(it!=all_functions_.end(),
551  "No function \"" + name + "\" in " + name_+ ". " +
552  "Available functions: " + join(get_function()) + ".");
553  return it->second.monitored;
554 }
555 
556 bool OracleFunction::has_function(const std::string& fname) const {
557  return all_functions_.find(fname) != all_functions_.end();
558 }
559 
560 
563 
564  s.version("OracleFunction", 3);
565  s.pack("OracleFunction::oracle", oracle_);
566  s.pack("OracleFunction::common_options", common_options_);
567  s.pack("OracleFunction::specific_options", specific_options_);
568  s.pack("OracleFunction::show_eval_warnings", show_eval_warnings_);
569  s.pack("OracleFunction::max_num_threads", max_num_threads_);
570  s.pack("OracleFunction::all_functions::size", all_functions_.size());
571  for (auto &e : all_functions_) {
572  s.pack("OracleFunction::all_functions::key", e.first);
573  s.pack("OracleFunction::all_functions::value::jit", e.second.jit);
574  if (jit_ && e.second.jit) {
575  if (jit_serialize_=="source") {
576  // Save original f, such that it can be built
577  s.pack("OracleFunction::all_functions::value::f", e.second.f_original);
578  } else {
579  std::string f_name = e.second.f.name();
580  s.pack("OracleFunction::all_functions::value::f_name", f_name);
581  // FunctionInternal will set compiler_
582  }
583  } else {
584  // Save f
585  s.pack("OracleFunction::all_functions::value::f", e.second.f);
586  }
587  s.pack("OracleFunction::all_functions::value::monitored", e.second.monitored);
588  }
589  s.pack("OracleFunction::monitor", monitor_);
590  s.pack("OracleFunction::stride_arg", stride_arg_);
591  s.pack("OracleFunction::stride_res", stride_res_);
592  s.pack("OracleFunction::stride_iw", stride_iw_);
593  s.pack("OracleFunction::stride_w", stride_w_);
594 
595 }
596 
598 
599  int version = s.version("OracleFunction", 1, 3);
600  s.unpack("OracleFunction::oracle", oracle_);
601  s.unpack("OracleFunction::common_options", common_options_);
602  s.unpack("OracleFunction::specific_options", specific_options_);
603  s.unpack("OracleFunction::show_eval_warnings", show_eval_warnings_);
604 
605  if (version>=3) {
606  s.unpack("OracleFunction::max_num_threads", max_num_threads_);
607  } else {
608  max_num_threads_ = 1;
609  }
610 
611  size_t size;
612 
613  s.unpack("OracleFunction::all_functions::size", size);
614  for (casadi_int i=0;i<size;++i) {
615  std::string key;
616  s.unpack("OracleFunction::all_functions::key", key);
617  RegFun r;
618  if (version==1) {
619  s.unpack("OracleFunction::all_functions::value::f", r.f);
620  s.unpack("OracleFunction::all_functions::value::jit", r.jit);
621  } else {
622  s.unpack("OracleFunction::all_functions::value::jit", r.jit);
623  if (jit_ && r.jit) {
624  if (jit_serialize_=="source") {
625  s.unpack("OracleFunction::all_functions::value::f", r.f);
626  } else {
627  std::string f_name;
628  s.unpack("OracleFunction::all_functions::value::f_name", f_name);
629  r.f = Function(f_name, std::vector<MX>{}, std::vector<MX>{});
630  // FunctionInternal will set compiler_
631  }
632  } else {
633  s.unpack("OracleFunction::all_functions::value::f", r.f);
634  }
635  }
636  s.unpack("OracleFunction::all_functions::value::monitored", r.monitored);
637  all_functions_[key] = r;
638  }
639  s.unpack("OracleFunction::monitor", monitor_);
640  if (version>=3) {
641  s.unpack("OracleFunction::stride_arg", stride_arg_);
642  s.unpack("OracleFunction::stride_res", stride_res_);
643  s.unpack("OracleFunction::stride_iw", stride_iw_);
644  s.unpack("OracleFunction::stride_w", stride_w_);
645  } else {
646  stride_arg_ = 0;
647  stride_res_ = 0;
648  stride_iw_ = 0;
649  stride_w_ = 0;
650  }
651  post_expand_ = false;
652 }
653 
654 } // namespace casadi
Helper class for C code generation.
void add(const Function &f, bool with_jac_sparsity=false)
Add a function (name generated)
std::string generate(const std::string &prefix="")
Generate file(s)
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
std::string jit_serialize_
Serialize behaviour.
Dict get_stats(void *mem) const override
Get all statistics.
void init(const Dict &opts) override
Initialize.
void finalize() override
Finalize the object creation.
void tocache_if_missing(Function &f, const std::string &suffix="") const
Save function to cache, only if missing.
static std::string forward_name(const std::string &fcn, casadi_int nfwd)
Helper function: Get name of forward derivative function.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
std::string compiler_plugin_
Just-in-time compiler.
bool jit_
Use just-in-time compiler.
bool incache(const std::string &fname, Function &f, const std::string &suffix="") const
Get function in cache.
size_t sz_res() const
Get required length of res field.
static const Options options_
Options.
size_t sz_w() const
Get required length of w field.
size_t sz_arg() const
Get required length of arg field.
void alloc(const Function &f, bool persistent=false, int num_threads=1)
Ensure work vectors long enough to evaluate function.
size_t sz_iw() const
Get required length of iw field.
Function object.
Definition: function.hpp:60
casadi_int nnz_out() const
Get number of output nonzeros.
Definition: function.cpp:855
void sz_work(size_t &sz_arg, size_t &sz_res, size_t &sz_iw, size_t &sz_w) const
Get number of temporary variables needed.
Definition: function.cpp:1079
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
Definition: function.cpp:1031
Function expand() const
Expand a function to SX.
Definition: function.cpp:308
const std::vector< std::string > & name_in() const
Get input scheme.
Definition: function.cpp:961
casadi_int n_out() const
Get the number of function outputs.
Definition: function.cpp:823
casadi_int n_in() const
Get the number of function inputs.
Definition: function.cpp:819
std::vector< std::string > get_free() const
Get free variables as a string.
Definition: function.cpp:1184
bool has_free() const
Does the function have free variables.
Definition: function.cpp:1697
casadi_int nnz_in() const
Get number of input nonzeros.
Definition: function.cpp:851
std::map< std::string, std::vector< std::string > > AuxOut
Definition: function.hpp:404
Function factory(const std::string &name, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const AuxOut &aux=AuxOut(), const Dict &opts=Dict()) const
Definition: function.cpp:1812
const std::vector< std::string > & name_out() const
Get output scheme.
Definition: function.cpp:965
bool is_null() const
Is a null pointer?
Importer.
Definition: importer.hpp:86
static void check()
Raises an error if an interrupt was captured.
void print_scalar(std::ostream &stream) const
Print scalar.
Base class for functions that perform calculation with an oracle.
void set_function(const Function &fcn, const std::string &fname, bool jit=false)
Function oracle_
Oracle: Used to generate other functions.
int calc_sp_forward(const std::string &fcn, const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const
Function create_function(const Function &oracle, const std::string &fname, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const Function::AuxOut &aux=Function::AuxOut(), const Dict &opts=Dict())
std::map< std::string, RegFun > all_functions_
void join_results(OracleMemory *m) const
Combine results from different threads.
void init(const Dict &opts) override
~OracleFunction() override=0
Destructor.
void jit_dependencies(const std::string &fname) override
JIT for dependencies.
OracleFunction(const std::string &name, const Function &oracle)
Constructor.
Function create_forward(const std::string &fname, casadi_int nfwd)
int init_mem(void *mem) const override
Initialize memory block.
virtual void codegen_body_enter(CodeGenerator &g) const
Generate code for the function body.
int calc_function(OracleMemory *m, const std::string &fcn, const double *const *arg=nullptr, int thread_id=0) const
std::vector< std::string > get_function() const override
Get list of dependency functions.
std::vector< std::string > monitor_
bool has_function(const std::string &fname) const override
virtual bool monitored(const std::string &name) const
Dict common_options_
Options for creating functions.
void set_temp(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const override
Set the work vectors.
int local_init_mem(void *mem) const
Initialize memory block.
static const Options options_
Options.
const Function & oracle() const override
Get oracle.
bool show_eval_warnings_
Show evaluation warnings.
Dict get_stats(void *mem) const override
Get all statistics.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int calc_sp_reverse(const std::string &fcn, bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const
std::string generate_dependencies(const std::string &fname, const Dict &opts) const override
Export / Generate C code for the generated functions.
void finalize() override
Finalize initialization.
virtual void codegen_body_exit(CodeGenerator &g) const
Generate code for the function body.
virtual int init_mem(void *mem) const
Initialize memory block.
bool regularity_check_
Errors are thrown when NaN is produced.
bool verbose_
Verbose printout.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
std::string repr_el(casadi_int k) const
Describe the nonzero location k as a string.
Definition: sparsity.cpp:607
The casadi namespace.
Definition: archiver.cpp:28
std::string join(const std::vector< std::string > &l, const std::string &delim)
unsigned long long bvec_t
Dict combine(const Dict &first, const Dict &second, bool recurse)
Combine two dicts. First has priority.
@ OT_STRINGVECTOR
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
Function external(const std::string &name, const Importer &li, const Dict &opts)
Load a just-in-time compiled external function.
Definition: external.cpp:42
Function memory with temporary work vectors.
Options metadata for a class.
Definition: options.hpp:40
Function memory.
std::vector< LocalOracleMemory * > thread_local_mem
std::map< std::string, FStats > fstats
void add_stat(const std::string &s)