27 #include "serializing_stream.hpp"
29 #ifdef CASADI_WITH_THREAD
30 #ifdef CASADI_WITH_THREAD_MINGW
31 #include <mingw.thread.h>
41 std::string suffix =
str(n) +
"_" + f.
name();
57 bool Map::is_a(
const std::string& type,
bool recursive)
const {
64 || (recursive &&
Map::is_a(type, recursive));
68 return type==
"ThreadMap"
69 || (recursive &&
Map::is_a(type, recursive));
78 "No function \"" + name +
"\" in " +
name_ +
". " +
113 casadi_error(
"class name '" +
class_name +
"' unknown.");
135 int Map::eval_gen(
const T** arg, T** res, casadi_int* iw, T* w,
int mem)
const {
136 const T** arg1 = arg+
n_in_;
137 std::copy_n(arg,
n_in_, arg1);
139 std::copy_n(res,
n_out_, res1);
140 for (casadi_int i=0; i<
n_; ++i) {
141 if (
f_(arg1, res1, iw, w, mem))
return 1;
142 for (casadi_int j=0; j<
n_in_; ++j) {
143 if (arg1[j]) arg1[j] +=
f_.
nnz_in(j);
145 for (casadi_int j=0; j<
n_out_; ++j) {
153 bool always_inline,
bool never_inline)
const {
158 casadi_int* iw,
bvec_t* w,
void* mem)
const {
164 std::copy_n(arg,
n_in_, arg1);
166 std::copy_n(res,
n_out_, res1);
167 for (casadi_int i=0; i<
n_; ++i) {
168 if (
f_.
rev(arg1, res1, iw, w))
return 1;
169 for (casadi_int j=0; j<
n_in_; ++j) {
170 if (arg1[j]) arg1[j] +=
f_.
nnz_in(j);
172 for (casadi_int j=0; j<
n_out_; ++j) {
184 g.
local(
"i",
"casadi_int");
185 g.
local(
"arg1",
"const casadi_real*",
"*");
186 g.
local(
"res1",
"casadi_real*",
"*");
189 g <<
"arg1 = arg+" <<
n_in_ <<
";\n"
190 <<
"for (i=0; i<" <<
n_in_ <<
"; ++i) arg1[i]=arg[i];\n";
192 g <<
"res1 = res+" <<
n_out_ <<
";\n"
193 <<
"for (i=0; i<" <<
n_out_ <<
"; ++i) res1[i]=res[i];\n"
194 <<
"for (i=0; i<" <<
n_ <<
"; ++i) {\n";
196 g <<
"if (" << g(
f_,
"arg1",
"res1",
"iw",
"w") <<
") return 1;\n";
198 for (casadi_int j=0; j<
n_in_; ++j) {
200 g <<
"if (arg1[" << j <<
"]) arg1[" << j <<
"]+=" <<
f_.
nnz_in(j) <<
";\n";
203 for (casadi_int j=0; j<
n_out_; ++j) {
205 g <<
"if (res1[" << j <<
"]) res1[" << j <<
"]+=" <<
f_.
nnz_out(j) <<
";\n";
212 const std::vector<std::string>& inames,
213 const std::vector<std::string>& onames,
214 const Dict& opts)
const {
220 std::vector<MX> arg = dm.
mx_in();
223 std::vector<MX> res = arg;
224 std::vector<MX>::iterator it=res.begin()+n_in_+n_out_;
225 std::vector<casadi_int> ind;
226 for (casadi_int i=0; i<n_in_; ++i, ++it) {
227 casadi_int sz = f_.size2_in(i);
229 for (casadi_int k=0; k<n_; ++k) {
230 for (casadi_int d=0; d<nfwd; ++d) {
231 for (casadi_int j=0; j<sz; ++j) {
232 ind.push_back((d*n_ + k)*sz + j);
236 *it = (*it)(
Slice(), ind);
244 for (casadi_int i=0; i<n_out_; ++i, ++it) {
245 casadi_int sz = f_.size2_out(i);
247 for (casadi_int d=0; d<nfwd; ++d) {
248 for (casadi_int k=0; k<n_; ++k) {
249 for (casadi_int j=0; j<sz; ++j) {
250 ind.push_back((k*nfwd + d)*sz + j);
254 *it = (*it)(
Slice(), ind);
258 options[
"allow_duplicate_io_names"] =
true;
261 return Function(name, arg, res, inames, onames, options);
266 const std::vector<std::string>& inames,
267 const std::vector<std::string>& onames,
268 const Dict& opts)
const {
274 std::vector<MX> arg = dm.
mx_in();
277 std::vector<MX> res = arg;
278 std::vector<MX>::iterator it=res.begin()+n_in_+n_out_;
279 std::vector<casadi_int> ind;
280 for (casadi_int i=0; i<n_out_; ++i, ++it) {
281 casadi_int sz = f_.size2_out(i);
283 for (casadi_int k=0; k<n_; ++k) {
284 for (casadi_int d=0; d<nadj; ++d) {
285 for (casadi_int j=0; j<sz; ++j) {
286 ind.push_back((d*n_ + k)*sz + j);
290 *it = (*it)(
Slice(), ind);
298 for (casadi_int i=0; i<n_in_; ++i, ++it) {
299 casadi_int sz = f_.size2_in(i);
301 for (casadi_int d=0; d<nadj; ++d) {
302 for (casadi_int k=0; k<n_; ++k) {
303 for (casadi_int j=0; j<sz; ++j) {
304 ind.push_back((k*nadj + d)*sz + j);
308 *it = (*it)(
Slice(), ind);
312 options[
"allow_duplicate_io_names"] =
true;
315 return Function(name, arg, res, inames, onames, options);
318 int Map::eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const {
322 setup(mem, arg, res, iw, w);
324 return eval_gen(arg, res, iw, w, m);
331 int OmpMap::eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const {
335 setup(mem, arg, res, iw, w);
343 std::vector< scoped_checkout<Function> > ind; ind.reserve(
n_);
344 for (casadi_int i=0; i<
n_; ++i) ind.emplace_back(
f_);
347 #pragma omp parallel for reduction(||:flag)
348 for (casadi_int i=0; i<
n_; ++i) {
351 for (casadi_int j=0; j<
n_in_; ++j) {
352 arg1[j] = arg[j] ? arg[j] + i*
f_.
nnz_in(j) : 0;
357 for (casadi_int j=0; j<
n_out_; ++j) {
358 res1[j] = res[j] ? res[j] + i*
f_.
nnz_out(j) : 0;
363 flag =
f_(arg1, res1, iw + i*
sz_iw, w + i*
sz_w, ind[i]) || flag;
364 }
catch (std::exception& e) {
366 casadi_warning(
"Exception raised: " + std::string(e.what()));
369 casadi_warning(
"Uncaught exception.");
382 g <<
"casadi_int i;\n"
383 <<
"const double** arg1;\n"
384 <<
"double** res1;\n"
385 <<
"casadi_int flag = 0;\n"
386 <<
"#pragma omp parallel for private(i,arg1,res1) reduction(||:flag)\n"
387 <<
"for (i=0; i<" <<
n_ <<
"; ++i) {\n"
388 <<
"arg1 = arg + " <<
n_in_ <<
"+i*" <<
sz_arg <<
";\n";
389 for (casadi_int j=0; j<
n_in_; ++j) {
390 g <<
"arg1[" << j <<
"] = arg[" << j <<
"] ? "
393 g <<
"res1 = res + " <<
n_out_ <<
"+i*" <<
sz_res <<
";\n";
394 for (casadi_int j=0; j<
n_out_; ++j) {
395 g <<
"res1[" << j <<
"] = res[" << j <<
"] ?"
399 << g(
f_,
"arg1",
"res1",
"iw+i*" +
str(
sz_iw),
"w+i*" +
str(
sz_w)) <<
" || flag;\n"
401 <<
"if (flag) return 1;\n";
406 casadi_warning(
"CasADi was not compiled with WITH_OPENMP=ON. "
407 "Falling back to serial evaluation.");
428 const double** arg,
double** res,
429 casadi_int* iw,
double* w,
430 casadi_int ind,
int& ret) {
433 casadi_int n_in = f.
n_in();
434 casadi_int n_out = f.
n_out();
437 size_t sz_arg, sz_res, sz_iw, sz_w;
438 f.
sz_work(sz_arg, sz_res, sz_iw, sz_w);
441 const double** arg1 = arg + n_in + i*sz_arg;
442 for (casadi_int j=0; j<n_in; ++j) {
443 arg1[j] = arg[j] ? arg[j] + i*f.
nnz_in(j) :
nullptr;
447 double** res1 = res + n_out + i*sz_res;
448 for (casadi_int j=0; j<n_out; ++j) {
449 res1[j] = res[j] ? res[j] + i*f.
nnz_out(j) :
nullptr;
453 ret = f(arg1, res1, iw + i*sz_iw, w + i*sz_w, ind);
454 }
catch (std::exception& e) {
456 casadi_warning(
"Exception raised: " + std::string(e.what()));
459 casadi_warning(
"Uncaught exception.");
465 #ifndef CASADI_WITH_THREAD
468 setup(mem, arg, res, iw, w);
470 std::vector< scoped_checkout<Function> > ind; ind.reserve(
n_);
471 for (casadi_int i=0; i<
n_; ++i) ind.emplace_back(
f_);
474 std::vector<int> ret_values(
n_);
477 std::vector<std::thread> threads;
478 for (casadi_int i=0; i<
n_; ++i) {
482 threads.emplace_back(
483 [i](
const Function& f,
const double** arg,
double** res,
484 casadi_int* iw,
double* w, casadi_int ind,
int& ret) {
487 std::ref(
f_), arg, res, iw, w, casadi_int(ind[i]), std::ref(ret_values[i]));
491 for (
auto && th : threads) th.join();
497 for (
int e : ret_values) ret = ret || e;
508 #ifndef CASADI_WITH_THREAD
509 casadi_warning(
"CasADi was not compiled with WITH_THREAD=ON. "
510 "Falling back to serial evaluation.");
Helper class for C code generation.
std::string add_dependency(const Function &f)
Add a function dependency.
std::string arg(casadi_int i) const
Refer to argument.
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
std::string res(casadi_int i) const
Refer to resuly.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Internal class for Function.
void alloc_iw(size_t sz_iw, bool persistent=false)
Ensure required length of iw field.
void init(const Dict &opts) override
Initialize.
std::vector< bool > is_diff_out_
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
void alloc_res(size_t sz_res, bool persistent=false)
Ensure required length of res field.
void alloc_arg(size_t sz_arg, bool persistent=false)
Ensure required length of arg field.
virtual bool is_a(const std::string &type, bool recursive) const
Check if the function is of a particular type.
size_t n_in_
Number of inputs and outputs.
size_t sz_res() const
Get required length of res field.
void serialize_type(SerializingStream &s) const override
Serialize type information.
size_t sz_w() const
Get required length of w field.
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
size_t sz_arg() const
Get required length of arg field.
void setup(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const
Set the (persistent and temporary) work vectors.
std::vector< bool > is_diff_in_
Are inputs and outputs differentiable?
size_t sz_iw() const
Get required length of iw field.
Function forward(casadi_int nfwd) const
Get a function that calculates nfwd forward derivatives.
casadi_int nnz_out() const
Get number of output nonzeros.
void sz_work(size_t &sz_arg, size_t &sz_res, size_t &sz_iw, size_t &sz_w) const
Get number of temporary variables needed.
size_t sz_res() const
Get required length of res field.
const MX mx_in(casadi_int ind) const
Get symbolic primitives equivalent to the input expressions.
const std::string & name() const
Name of the function.
Function reverse(casadi_int nadj) const
Get a function that calculates nadj adjoint derivatives.
static Function create(FunctionInternal *node)
Create from node.
bool is_diff_out(casadi_int ind) const
Get differentiability of inputs/output.
int rev(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, int mem=0) const
Propagate sparsity backward.
size_t sz_iw() const
Get required length of iw field.
casadi_int n_out() const
Get the number of function outputs.
casadi_int n_in() const
Get the number of function inputs.
bool is_diff_in(casadi_int ind) const
Get differentiability of inputs/output.
Function map(casadi_int n, const std::string ¶llelization="serial") const
Create a mapped version of this function.
size_t sz_w() const
Get required length of w field.
size_t sz_arg() const
Get required length of arg field.
casadi_int nnz_in() const
Get number of input nonzeros.
int eval_gen(const T **arg, T **res, casadi_int *iw, T *w, int mem=0) const
Evaluate or propagate sparsities.
void serialize_type(SerializingStream &s) const override
Serialize type information.
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nadj adjoint derivatives.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const override
Propagate sparsity forward.
void init(const Dict &opts) override
Initialize.
~Map() override
Destructor.
bool is_a(const std::string &type, bool recursive) const override
Check if the function is of a particular type.
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate the function numerically.
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
bool has_function(const std::string &fname) const override
void codegen_declarations(CodeGenerator &g) const override
Generate code for the declarations of the C function.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const override
Propagate sparsity backwards.
std::string class_name() const override
Get type name.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w, void *mem, bool always_inline, bool never_inline) const override
evaluate symbolically while also propagating directional derivatives
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
virtual std::vector< std::string > get_function() const override
static Function create(const std::string ¶llelization, const Function &f, casadi_int n)
Map(DeserializingStream &s)
Deserializing constructor.
virtual std::string parallelization() const
Type of parallellization.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nfwd forward derivatives.
void init(const Dict &opts) override
Initialize.
bool is_a(const std::string &type, bool recursive) const override
Check if the function is of a particular type.
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate the function numerically.
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
~OmpMap() override
Destructor.
Base class for FunctionInternal and LinsolInternal.
void clear_mem()
Clear all memory (called from destructor)
The basic scalar symbolic class of CasADi.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
Class representing a Slice.
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
~ThreadMap() override
Destructor.
void init(const Dict &opts) override
Initialize.
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate the function numerically.
bool is_a(const std::string &type, bool recursive) const override
Check if the function is of a particular type.
void ThreadsWork(const Function &f, casadi_int i, const double **arg, double **res, casadi_int *iw, double *w, casadi_int ind, int &ret)
std::string join(const std::vector< std::string > &l, const std::string &delim)
unsigned long long bvec_t
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.