27 #include "serializing_stream.hpp"
32 const std::vector<Function>& f,
const Function& f_def)
36 casadi_assert_dev(!
f_.empty());
61 for (
auto&& i :
f_)
if (!i.is_null())
return 1+i.n_in();
67 for (
auto&& i :
f_)
if (!i.is_null())
return i.n_out();
77 for (
auto&& fk :
f_) {
79 const Sparsity& s = fk.sparsity_in(i-1);
92 for (
auto&& fk :
f_) {
94 const Sparsity& s = fk.sparsity_out(i);
115 for (casadi_int k=0; k<=
f_.size(); ++k) {
126 for (casadi_int i=1; i<
n_in_; ++i) {
136 for (casadi_int i=0; i<
n_out_; ++i) {
146 sz_buf = std::max(sz_buf, sz_buf_k);
153 int Switch::eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const {
154 setup(mem, arg, res, iw, w);
156 casadi_int k = arg[0] ?
static_cast<casadi_int
>(*arg[0]) : 0;
164 for (casadi_int i=0; i<
n_in_-1; ++i) {
165 const Sparsity& f_sp = fk.sparsity_in(i);
168 if (arg1[i] && f_sp!=sp) {
170 arg1[i] = w; w += f_sp.
nnz();
183 for (casadi_int i=0; i<
n_out_; ++i) {
184 const Sparsity& f_sp = fk.sparsity_out(i);
187 if (res1[i] && f_sp!=sp) {
198 if (fk(arg1, res1, iw, w, 0))
return 1;
202 for (casadi_int i=0; i<
n_out_; ++i) {
203 const Sparsity& f_sp = fk.sparsity_out(i);
205 if (res[i] && f_sp!=sp) {
215 const std::vector<std::string>& inames,
216 const std::vector<std::string>& onames,
217 const Dict& opts)
const {
219 std::vector<Function> der(f_.size());
220 for (casadi_int k=0; k<f_.size(); ++k) {
221 if (!f_[k].is_null()) der[k] = f_[k].forward(nfwd);
226 if (!f_def_.is_null()) der_def = f_def_.
forward(nfwd);
232 std::vector<MX> arg = sw.
mx_in();
233 std::vector<MX> res = sw(arg);
236 arg.insert(arg.begin() + n_in_ + n_out_,
MX(1, nfwd));
239 options[
"allow_duplicate_io_names"] =
true;
241 return Function(name, arg, res, inames, onames, options);
246 const std::vector<std::string>& inames,
247 const std::vector<std::string>& onames,
248 const Dict& opts)
const {
250 std::vector<Function> der(f_.size());
251 for (casadi_int k=0; k<f_.size(); ++k) {
252 if (!f_[k].is_null()) der[k] = f_[k].reverse(nadj);
257 if (!f_def_.is_null()) der_def = f_def_.
reverse(nadj);
263 std::vector<MX> arg = sw.
mx_in();
264 std::vector<MX> res = sw(arg);
267 res.insert(res.begin(),
MX(1, nadj));
270 options[
"allow_duplicate_io_names"] =
true;
273 return Function(name, arg, res, inames, onames, options);
284 for (casadi_int k=0; k<
f_.size(); ++k) {
285 if (k!=0) stream <<
", ";
286 stream <<
f_[k].name();
293 for (casadi_int k=0; k<=
f_.size(); ++k) {
300 casadi_int* iw,
SXElem* w,
void* mem,
301 bool always_inline,
bool never_inline)
const {
307 std::vector<SXElem> w_extra(
nnz_out());
308 std::vector<SXElem*> res_tempv(
n_out_);
311 for (casadi_int k=0; k<
f_.size()+1; ++k) {
321 std::copy_n(res,
n_out_, res_temp);
324 for (casadi_int i=0; i<
n_out_; ++i) {
330 std::copy_n(arg+1,
n_in_-1, arg1);
331 std::copy_n(res_temp,
n_out_, res1);
336 for (casadi_int i=0; i<
n_in_-1; ++i) {
349 for (casadi_int i=0; i<
n_out_; ++i) {
353 if (f_sp!=sp) { res1[i] = wl; wl += f_sp.
nnz();}
358 if (fk(arg1, res1, iw, wl, 0))
return 1;
361 for (casadi_int i=0; i<
n_out_; ++i) {
370 SXElem cond = k-1==arg[0][0];
371 for (casadi_int i=0; i<
n_out_; ++i) {
373 for (casadi_int j=0; j<
nnz_out(i); ++j) {
374 res[i][j] =
if_else(cond, res_temp[i][j], res[i][j]);
388 g.
local(
"i",
"casadi_int");
389 g <<
"const casadi_real** arg1 = arg + " <<
n_in_ <<
";\n";
395 g.
local(
"i",
"casadi_int");
396 g <<
"casadi_real** res1 = res + " <<
n_out_ <<
";\n";
400 g <<
"for (i=0; i<" <<
n_in_-1 <<
"; ++i) arg1[i]=arg[i+1];\n";
403 g <<
"for (i=0; i<" <<
n_out_ <<
"; ++i) res1[i]=res[i];\n";
408 g << (
if_else ?
"if" :
"switch") <<
" (arg[0] ? casadi_to_int(*arg[0]) : 0) {\n";
411 for (casadi_int k=0; k<=
f_.size(); ++k) {
414 casadi_int k1 =
if_else ? 1-k : k;
419 g <<
"case " << k1 <<
":\n";
434 for (casadi_int i=0; i<
n_in_-1; ++i) {
439 g <<
"arg1[" << i <<
"]=0;\n";
441 g.
local(
"t",
"casadi_real",
"*");
442 g <<
"t=w, w+=" << f_sp.
nnz() <<
";\n"
443 << g.
project(
"arg1[" +
str(i) +
"]", sp,
"t", f_sp,
"w") <<
"\n"
444 <<
"arg1[" << i <<
"]=t;\n";
450 for (casadi_int i=0; i<
n_out_; ++i) {
455 g <<
"res1[" << i <<
"]=0;\n";
457 g <<
"res1[" << i <<
"]=w, w+=" << f_sp.
nnz() <<
";\n";
463 g <<
"if (" << g(fk,
project_in_ ?
"arg1" :
"arg+1",
465 "iw",
"w") <<
") return 1;\n";
468 for (casadi_int i=0; i<
n_out_; ++i) {
473 g.
res(i), sp,
"w") <<
"\n";
492 casadi_int max_depth)
const {
494 if (!f_k.is_null())
add_embedded(all_fun, f_k, max_depth);
Helper class for C code generation.
std::string project(const std::string &arg, const Sparsity &sp_arg, const std::string &res, const Sparsity &sp_res, const std::string &w)
Sparse assignment.
std::string add_dependency(const Function &f)
Add a function dependency.
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
std::string res(casadi_int i) const
Refer to result.
void add_auxiliary(Auxiliary f, const std::vector< std::string > &inst={"casadi_real"})
Add a built-in auxiliary function.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
void init(const Dict &opts) override
Initialize.
std::vector< Sparsity > sparsity_in_
Input and output sparsity.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
size_t n_in_
Number of inputs and outputs.
std::vector< Sparsity > sparsity_out_
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
casadi_int nnz_out() const
Number of input/output nonzeros.
void setup(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const
Set the (persistent and temporary) work vectors.
void alloc(const Function &f, bool persistent=false, int num_threads=1)
Ensure work vectors long enough to evaluate function.
void add_embedded(std::map< FunctionInternal *, Function > &all_fun, const Function &dep, casadi_int max_depth) const
Function forward(casadi_int nfwd) const
Get a function that calculates nfwd forward derivatives.
static Function conditional(const std::string &name, const std::vector< Function > &f, const Function &f_def, const Dict &opts=Dict())
Construct a switch function.
const MX mx_in(casadi_int ind) const
Get symbolic primitives equivalent to the input expressions.
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
const std::string & name() const
Name of the function.
Function reverse(casadi_int nadj) const
Get a function that calculates nadj adjoint derivatives.
const Sparsity & sparsity_in(casadi_int ind) const
Get sparsity of a given input.
casadi_int n_out() const
Get the number of function outputs.
casadi_int n_in() const
Get the number of function inputs.
bool is_null() const
Is a null pointer?
void clear_mem()
Clear all memory (called from destructor)
The basic scalar symbolic class of CasADi.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
casadi_int size1() const
Get the number of rows.
Sparsity unite(const Sparsity &y, std::vector< unsigned char > &mapping) const
Union of two sparsity patterns.
casadi_int nnz() const
Get the number of (structural) non-zeros.
static Sparsity scalar(bool dense_scalar=true)
Create a scalar sparsity pattern.
size_t get_n_in() override
Number of function inputs and outputs.
void disp_more(std::ostream &stream) const override
Print description.
Dict info() const override
Switch(const std::string &name, const std::vector< Function > &f, const Function &f_def)
Constructor (generic switch)
void codegen_declarations(CodeGenerator &g) const override
Generate code for the declarations of the C function.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w, void *mem, bool always_inline, bool never_inline) const override
Evaluate symbolically while also propagating directional derivatives.
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate numerically, work vectors given.
void init(const Dict &opts) override
Initialize.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
std::vector< Function > f_
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nadj adjoint derivatives.
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nfwd forward derivatives.
void find(std::map< FunctionInternal *, Function > &all_fun, casadi_int max_depth) const override
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
~Switch() override
Destructor.
size_t get_n_out() override
Number of function inputs and outputs.
double if_else(double x, double y, double z)
void casadi_project(const T1 *x, const casadi_int *sp_x, T1 *y, const casadi_int *sp_y, T1 *w)
Sparse copy: y <- x, w work vector (length >= number of rows)
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.