switch.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "switch.hpp"
27 #include "serializing_stream.hpp"
28 
29 namespace casadi {
30 
31  Switch::Switch(const std::string& name,
32  const std::vector<Function>& f, const Function& f_def)
33  : FunctionInternal(name), f_(f), f_def_(f_def) {
34 
35  // Consitency check
36  casadi_assert_dev(!f_.empty());
37  }
38 
41  s.version("Switch", 1);
42  s.pack("Switch::f", f_);
43  s.pack("Switch::f_def", f_def_);
44  s.pack("Switch::project_in", project_in_);
45  s.pack("Switch::project_out", project_out_);
46  }
47 
49  s.version("Switch", 1);
50  s.unpack("Switch::f", f_);
51  s.unpack("Switch::f_def", f_def_);
52  s.unpack("Switch::project_in", project_in_);
53  s.unpack("Switch::project_out", project_out_);
54  }
55 
57  clear_mem();
58  }
59 
60  size_t Switch::get_n_in() {
61  for (auto&& i : f_) if (!i.is_null()) return 1+i.n_in();
62  casadi_assert_dev(!f_def_.is_null());
63  return 1+f_def_.n_in();
64  }
65 
66  size_t Switch::get_n_out() {
67  for (auto&& i : f_) if (!i.is_null()) return i.n_out();
68  casadi_assert_dev(!f_def_.is_null());
69  return f_def_.n_out();
70  }
71 
73  if (i==0) {
74  return Sparsity::scalar();
75  } else {
76  Sparsity ret;
77  for (auto&& fk : f_) {
78  if (!fk.is_null()) {
79  const Sparsity& s = fk.sparsity_in(i-1);
80  ret = ret.is_null() ? s : ret.unite(s);
81  }
82  }
83  casadi_assert_dev(!f_def_.is_null());
84  const Sparsity& s = f_def_.sparsity_in(i-1);
85  ret = ret.is_null() ? s : ret.unite(s);
86  return ret;
87  }
88  }
89 
91  Sparsity ret;
92  for (auto&& fk : f_) {
93  if (!fk.is_null()) {
94  const Sparsity& s = fk.sparsity_out(i);
95  ret = ret.is_null() ? s : ret.unite(s);
96  }
97  }
98  casadi_assert_dev(!f_def_.is_null());
99  const Sparsity& s = f_def_.sparsity_out(i);
100  ret = ret.is_null() ? s : ret.unite(s);
101  return ret;
102  }
103 
104  void Switch::init(const Dict& opts) {
105  // Call the initialization method of the base class
107 
108  // Buffer for mismatching sparsities
109  size_t sz_buf=0;
110 
111  // Keep track of sparsity projections
112  project_in_ = project_out_ = false;
113 
114  // Get required work
115  for (casadi_int k=0; k<=f_.size(); ++k) {
116  const Function& fk = k<f_.size() ? f_[k] : f_def_;
117  if (fk.is_null()) continue;
118 
119  // Memory for evaluation
120  alloc(fk);
121 
122  // Required work vectors
123  size_t sz_buf_k=0;
124 
125  // Add size for input buffers
126  for (casadi_int i=1; i<n_in_; ++i) {
127  const Sparsity& s = fk.sparsity_in(i-1);
128  if (s!=sparsity_in_[i]) {
129  project_in_ = true;
130  alloc_w(s.size1()); // for casadi_project
131  sz_buf_k += s.nnz();
132  }
133  }
134 
135  // Add size for output buffers
136  for (casadi_int i=0; i<n_out_; ++i) {
137  const Sparsity& s = fk.sparsity_out(i);
138  if (s!=sparsity_out_[i]) {
139  project_out_ = true;
140  alloc_w(s.size1()); // for casadi_project
141  sz_buf_k += s.nnz();
142  }
143  }
144 
145  // Only need the largest of these work vectors
146  sz_buf = std::max(sz_buf, sz_buf_k);
147  }
148 
149  // Memory for the work vectors
150  alloc_w(sz_buf, true);
151  }
152 
153  int Switch::eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const {
154  setup(mem, arg, res, iw, w);
155  // Get the function to be evaluated
156  casadi_int k = arg[0] ? static_cast<casadi_int>(*arg[0]) : 0;
157  const Function& fk = k>=0 && k<f_.size() ? f_[k] : f_def_;
158 
159  // Project arguments with different sparsity
160  const double** arg1;
161  if (project_in_) {
162  // Project one or more argument
163  arg1 = arg + n_in_;
164  for (casadi_int i=0; i<n_in_-1; ++i) {
165  const Sparsity& f_sp = fk.sparsity_in(i);
166  const Sparsity& sp = sparsity_in_[i+1];
167  arg1[i] = arg[i+1];
168  if (arg1[i] && f_sp!=sp) {
169  casadi_project(arg1[i], sp, w, f_sp, w + f_sp.nnz());
170  arg1[i] = w; w += f_sp.nnz();
171  }
172  }
173  } else {
174  // No inputs projected
175  arg1 = arg + 1;
176  }
177 
178  // Temporary memory for results with different sparsity
179  double** res1;
180  if (project_out_) {
181  // Project one or more results
182  res1 = res + n_out_;
183  for (casadi_int i=0; i<n_out_; ++i) {
184  const Sparsity& f_sp = fk.sparsity_out(i);
185  const Sparsity& sp = sparsity_out_[i];
186  res1[i] = res[i];
187  if (res1[i] && f_sp!=sp) {
188  res1[i] = w;
189  w += f_sp.nnz();
190  }
191  }
192  } else {
193  // No outputs projected
194  res1 = res;
195  }
196 
197  // Evaluate the corresponding function
198  if (fk(arg1, res1, iw, w, 0)) return 1;
199 
200  // Project results with different sparsity
201  if (project_out_) {
202  for (casadi_int i=0; i<n_out_; ++i) {
203  const Sparsity& f_sp = fk.sparsity_out(i);
204  const Sparsity& sp = sparsity_out_[i];
205  if (res[i] && f_sp!=sp) {
206  casadi_project(res1[i], f_sp, res[i], sp, w);
207  }
208  }
209  }
210  return 0;
211  }
212 
214  ::get_forward(casadi_int nfwd, const std::string& name,
215  const std::vector<std::string>& inames,
216  const std::vector<std::string>& onames,
217  const Dict& opts) const {
218  // Derivative of each case
219  std::vector<Function> der(f_.size());
220  for (casadi_int k=0; k<f_.size(); ++k) {
221  if (!f_[k].is_null()) der[k] = f_[k].forward(nfwd);
222  }
223 
224  // Default case
225  Function der_def;
226  if (!f_def_.is_null()) der_def = f_def_.forward(nfwd);
227 
228  // New Switch for derivatives
229  Function sw = Function::conditional("switch_" + name, der, der_def);
230 
231  // Get expressions for the derivative switch
232  std::vector<MX> arg = sw.mx_in();
233  std::vector<MX> res = sw(arg);
234 
235  // Ignore seed for ind
236  arg.insert(arg.begin() + n_in_ + n_out_, MX(1, nfwd));
237 
238  Dict options = opts;
239  options["allow_duplicate_io_names"] = true;
240  // Create wrapper
241  return Function(name, arg, res, inames, onames, options);
242  }
243 
245  ::get_reverse(casadi_int nadj, const std::string& name,
246  const std::vector<std::string>& inames,
247  const std::vector<std::string>& onames,
248  const Dict& opts) const {
249  // Derivative of each case
250  std::vector<Function> der(f_.size());
251  for (casadi_int k=0; k<f_.size(); ++k) {
252  if (!f_[k].is_null()) der[k] = f_[k].reverse(nadj);
253  }
254 
255  // Default case
256  Function der_def;
257  if (!f_def_.is_null()) der_def = f_def_.reverse(nadj);
258 
259  // New Switch for derivatives
260  Function sw = Function::conditional("switch_" + name, der, der_def);
261 
262  // Get expressions for the derivative switch
263  std::vector<MX> arg = sw.mx_in();
264  std::vector<MX> res = sw(arg);
265 
266  // No derivatives with respect to index
267  res.insert(res.begin(), MX(1, nadj));
268 
269  Dict options = opts;
270  options["allow_duplicate_io_names"] = true;
271 
272  // Create wrapper
273  return Function(name, arg, res, inames, onames, options);
274  }
275 
276  void Switch::disp_more(std::ostream &stream) const {
277  // Print more
278  if (f_.size()==1) {
279  // Print as if-then-else
280  stream << f_def_.name() << ", " << f_[0].name();
281  } else {
282  // Print generic
283  stream << "[";
284  for (casadi_int k=0; k<f_.size(); ++k) {
285  if (k!=0) stream << ", ";
286  stream << f_[k].name();
287  }
288  stream << "], " << f_def_.name();
289  }
290  }
291 
293  for (casadi_int k=0; k<=f_.size(); ++k) {
294  const Function& fk = k<f_.size() ? f_[k] : f_def_;
295  g.add_dependency(fk);
296  }
297  }
298 
299  int Switch::eval_sx(const SXElem** arg, SXElem** res,
300  casadi_int* iw, SXElem* w, void* mem,
301  bool always_inline, bool never_inline) const {
302  // Input and output buffers
303  const SXElem** arg1 = arg + n_in_;
304  SXElem** res1 = res + n_out_;
305 
306  // Extra memory needed for chaining if_else calls
307  std::vector<SXElem> w_extra(nnz_out());
308  std::vector<SXElem*> res_tempv(n_out_);
309  SXElem** res_temp = get_ptr(res_tempv);
310 
311  for (casadi_int k=0; k<f_.size()+1; ++k) {
312 
313  // Local work vector
314  SXElem* wl = w;
315 
316  // Local work vector
317  SXElem* wll = get_ptr(w_extra);
318 
319  if (k==0) {
320  // For the default case, redirect the temporary results to res
321  std::copy_n(res, n_out_, res_temp);
322  } else {
323  // For the other cases, store the temporary results
324  for (casadi_int i=0; i<n_out_; ++i) {
325  res_temp[i] = wll;
326  wll += nnz_out(i);
327  }
328  }
329 
330  std::copy_n(arg+1, n_in_-1, arg1);
331  std::copy_n(res_temp, n_out_, res1);
332 
333  const Function& fk = k==0 ? f_def_ : f_[k-1];
334 
335  // Project arguments with different sparsity
336  for (casadi_int i=0; i<n_in_-1; ++i) {
337  if (arg1[i]) {
338  const Sparsity& f_sp = fk.sparsity_in(i);
339  const Sparsity& sp = sparsity_in_[i+1];
340  if (f_sp!=sp) {
341  SXElem *t = wl; wl += f_sp.nnz(); // t is non-const
342  casadi_project(arg1[i], sp, t, f_sp, wl);
343  arg1[i] = t;
344  }
345  }
346  }
347 
348  // Temporary memory for results with different sparsity
349  for (casadi_int i=0; i<n_out_; ++i) {
350  if (res1[i]) {
351  const Sparsity& f_sp = fk.sparsity_out(i);
352  const Sparsity& sp = sparsity_out_[i];
353  if (f_sp!=sp) { res1[i] = wl; wl += f_sp.nnz();}
354  }
355  }
356 
357  // Evaluate the corresponding function
358  if (fk(arg1, res1, iw, wl, 0)) return 1;
359 
360  // Project results with different sparsity
361  for (casadi_int i=0; i<n_out_; ++i) {
362  if (res1[i]) {
363  const Sparsity& f_sp = fk.sparsity_out(i);
364  const Sparsity& sp = sparsity_out_[i];
365  if (f_sp!=sp) casadi_project(res1[i], f_sp, res_temp[i], sp, wl);
366  }
367  }
368 
369  if (k>0) { // output the temporary results via an if_else
370  SXElem cond = k-1==arg[0][0];
371  for (casadi_int i=0; i<n_out_; ++i) {
372  if (res[i]) {
373  for (casadi_int j=0; j<nnz_out(i); ++j) {
374  res[i][j] = if_else(cond, res_temp[i][j], res[i][j]);
375  }
376  }
377  }
378  }
379 
380  }
381  return 0;
382  }
383 
385  // Project arguments with different sparsity
386  if (project_in_) {
387  // Project one or more argument
388  g.local("i", "casadi_int");
389  g << "const casadi_real** arg1 = arg + " << n_in_ << ";\n";
390  }
391 
392  // Temporary memory for results with different sparsity
393  if (project_out_) {
394  // Project one or more results
395  g.local("i", "casadi_int");
396  g << "casadi_real** res1 = res + " << n_out_ << ";\n";
397  }
398 
399  if (project_in_)
400  g << "for (i=0; i<" << n_in_-1 << "; ++i) arg1[i]=arg[i+1];\n";
401 
402  if (project_out_)
403  g << "for (i=0; i<" << n_out_ << "; ++i) res1[i]=res[i];\n";
404 
405  // Codegen condition
406  bool if_else = f_.size()==1;
408  g << (if_else ? "if" : "switch") << " (arg[0] ? casadi_to_int(*arg[0]) : 0) {\n";
409 
410  // Loop over cases/functions
411  for (casadi_int k=0; k<=f_.size(); ++k) {
412 
413  // For if, reverse order
414  casadi_int k1 = if_else ? 1-k : k;
415 
416  if (!if_else) {
417  // Codegen cases
418  if (k1<f_.size()) {
419  g << "case " << k1 << ":\n";
420  } else {
421  g << "default:\n";
422  }
423  } else if (k1==0) {
424  // Else
425  g << "} else {\n";
426  }
427 
428  // Get the function:
429  const Function& fk = k1<f_.size() ? f_[k1] : f_def_;
430  if (fk.is_null()) {
431  g << "return 1;\n";
432  } else {
433  // Project arguments with different sparsity
434  for (casadi_int i=0; i<n_in_-1; ++i) {
435  const Sparsity& f_sp = fk.sparsity_in(i);
436  const Sparsity& sp = sparsity_in_[i+1];
437  if (f_sp!=sp) {
438  if (f_sp.nnz()==0) {
439  g << "arg1[" << i << "]=0;\n";
440  } else {
441  g.local("t", "casadi_real", "*");
442  g << "t=w, w+=" << f_sp.nnz() << ";\n"
443  << g.project("arg1[" + str(i) + "]", sp, "t", f_sp, "w") << "\n"
444  << "arg1[" << i << "]=t;\n";
445  }
446  }
447  }
448 
449  // Temporary memory for results with different sparsity
450  for (casadi_int i=0; i<n_out_; ++i) {
451  const Sparsity& f_sp = fk.sparsity_out(i);
452  const Sparsity& sp = sparsity_out_[i];
453  if (f_sp!=sp) {
454  if (f_sp.nnz()==0) {
455  g << "res1[" << i << "]=0;\n";
456  } else {
457  g << "res1[" << i << "]=w, w+=" << f_sp.nnz() << ";\n";
458  }
459  }
460  }
461 
462  // Function call
463  g << "if (" << g(fk, project_in_ ? "arg1" : "arg+1",
464  project_out_ ? "res1" : "res",
465  "iw", "w") << ") return 1;\n";
466 
467  // Project results with different sparsity
468  for (casadi_int i=0; i<n_out_; ++i) {
469  const Sparsity& f_sp = fk.sparsity_out(i);
470  const Sparsity& sp = sparsity_out_[i];
471  if (f_sp!=sp) {
472  g << g.project("res1[" + str(i) + "]", f_sp,
473  g.res(i), sp, "w") << "\n";
474  }
475  }
476 
477  // Break (if switch)
478  if (!if_else) g << "break;\n";
479  }
480  }
481 
482  // End switch/else
483  g << "}\n";
484  }
485 
486  Dict Switch::info() const {
487  return {{"project_in", project_in_}, {"project_out", project_out_},
488  {"f_def", f_def_}, {"f", f_}};
489  }
490 
491  void Switch::find(std::map<FunctionInternal*, Function>& all_fun,
492  casadi_int max_depth) const {
493  for (const Function& f_k : f_) {
494  if (!f_k.is_null()) add_embedded(all_fun, f_k, max_depth);
495  }
496  if (!f_def_.is_null()) add_embedded(all_fun, f_def_, max_depth);
497  }
498 
499 } // namespace casadi
Helper class for C code generation.
std::string project(const std::string &arg, const Sparsity &sp_arg, const std::string &res, const Sparsity &sp_res, const std::string &w)
Sparse assignment.
std::string add_dependency(const Function &f)
Add a function dependency.
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
std::string res(casadi_int i) const
Refer to result.
void add_auxiliary(Auxiliary f, const std::vector< std::string > &inst={"casadi_real"})
Add a built-in auxiliary function.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
void init(const Dict &opts) override
Initialize.
std::vector< Sparsity > sparsity_in_
Input and output sparsity.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
size_t n_in_
Number of inputs and outputs.
std::vector< Sparsity > sparsity_out_
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
casadi_int nnz_out() const
Number of input/output nonzeros.
void setup(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const
Set the (persistent and temporary) work vectors.
void alloc(const Function &f, bool persistent=false, int num_threads=1)
Ensure work vectors long enough to evaluate function.
void add_embedded(std::map< FunctionInternal *, Function > &all_fun, const Function &dep, casadi_int max_depth) const
Function object.
Definition: function.hpp:60
Function forward(casadi_int nfwd) const
Get a function that calculates nfwd forward derivatives.
Definition: function.cpp:1135
static Function conditional(const std::string &name, const std::vector< Function > &f, const Function &f_def, const Dict &opts=Dict())
Construct a switch function.
Definition: function.cpp:765
const MX mx_in(casadi_int ind) const
Get symbolic primitives equivalent to the input expressions.
Definition: function.cpp:1584
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
Definition: function.cpp:1031
const std::string & name() const
Name of the function.
Definition: function.cpp:1307
Function reverse(casadi_int nadj) const
Get a function that calculates nadj adjoint derivatives.
Definition: function.cpp:1143
const Sparsity & sparsity_in(casadi_int ind) const
Get sparsity of a given input.
Definition: function.cpp:1015
casadi_int n_out() const
Get the number of function outputs.
Definition: function.cpp:823
casadi_int n_in() const
Get the number of function inputs.
Definition: function.cpp:819
bool is_null() const
Is a null pointer?
MX - Matrix expression.
Definition: mx.hpp:92
void clear_mem()
Clear all memory (called from destructor)
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
General sparsity class.
Definition: sparsity.hpp:106
casadi_int size1() const
Get the number of rows.
Definition: sparsity.cpp:124
Sparsity unite(const Sparsity &y, std::vector< unsigned char > &mapping) const
Union of two sparsity patterns.
Definition: sparsity.cpp:409
casadi_int nnz() const
Get the number of (structural) non-zeros.
Definition: sparsity.cpp:148
static Sparsity scalar(bool dense_scalar=true)
Create a scalar sparsity pattern.
Definition: sparsity.hpp:153
size_t get_n_in() override
Number of function inputs and outputs.
Definition: switch.cpp:60
void disp_more(std::ostream &stream) const override
Print description.
Definition: switch.cpp:276
bool project_out_
Definition: switch.hpp:140
Dict info() const override
Definition: switch.cpp:486
Switch(const std::string &name, const std::vector< Function > &f, const Function &f_def)
Constructor (generic switch)
Definition: switch.cpp:31
void codegen_declarations(CodeGenerator &g) const override
Generate code for the declarations of the C function.
Definition: switch.cpp:292
Function f_def_
Definition: switch.hpp:137
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w, void *mem, bool always_inline, bool never_inline) const override
Evaluate symbolically with SX scalar expressions.
Definition: switch.cpp:299
bool project_in_
Definition: switch.hpp:140
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate numerically, work vectors given.
Definition: switch.cpp:153
void init(const Dict &opts) override
Initialize.
Definition: switch.cpp:104
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
Definition: switch.cpp:39
std::vector< Function > f_
Definition: switch.hpp:134
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nadj adjoint derivatives.
Definition: switch.cpp:245
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
Definition: switch.cpp:90
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nfwd forward derivatives.
Definition: switch.cpp:214
void find(std::map< FunctionInternal *, Function > &all_fun, casadi_int max_depth) const override
Definition: switch.cpp:491
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
Definition: switch.cpp:72
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
Definition: switch.cpp:384
~Switch() override
Destructor.
Definition: switch.cpp:56
size_t get_n_out() override
Number of function inputs and outputs.
Definition: switch.cpp:66
The casadi namespace.
Definition: archiver.cpp:28
double if_else(double x, double y, double z)
Definition: calculus.hpp:290
void casadi_project(const T1 *x, const casadi_int *sp_x, T1 *y, const casadi_int *sp_y, T1 *w)
Sparse copy: y <- x, w work vector (length >= number of rows)
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.