casadi_call.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "casadi_call.hpp"
27 #include "function_internal.hpp"
28 #include "casadi_misc.hpp"
29 #include "serializing_stream.hpp"
30 
// Rethrows an error caught inside a Call member function as a CasadiException.
// The message names the failing member (FNAME, which must be a string literal
// since it is pasted next to other literals), the called function's name and
// class, the source location, and the original message text WHAT.
// Only usable inside Call members, because it refers to fcn_.
#define CASADI_THROW_ERROR(FNAME, WHAT) \
  throw CasadiException( \
    "Error in Call::" FNAME " for '" + fcn_.name() + "' " \
    "[" + fcn_.class_name() + "] at " + CASADI_WHERE + ":\n" \
    + std::string(WHAT));
34 
35 namespace casadi {
36 
37  MX Call::projectArg(const MX& x, const Sparsity& sp, casadi_int i) {
38  if (x.size()==sp.size()) {
39  // Insert sparsity projection nodes if needed
40  return project(x, sp);
41  } else {
42  // Different dimensions
43  if (x.is_empty() || sp.is_empty()) { // NOTE: To permissive?
44  // Replace nulls with zeros of the right dimension
45  return MX::zeros(sp);
46  } else if (x.is_scalar()) {
47  // Scalar argument means set all
48  return MX(sp, x);
49  } else if (x.size1()==sp.size2() && x.size2()==sp.size1() && sp.is_vector()) {
50  // Transposed vector
51  return projectArg(x.T(), sp, i);
52  } else {
53  // Mismatching dimensions
54  casadi_error("Cannot create function call node: Dimension mismatch for argument "
55  + str(i) + ". Argument has shape " + str(x.size())
56  + " but function input has shape " + str(sp.size()));
57  }
58  }
59  }
60 
61  MX Call::get_output(casadi_int oind) const {
62  MX this_ = shared_from_this<MX>();
63  // No need for an OutputNode if sparsity is fully sparse
64  if (this_->sparsity(oind).nnz()==0) return MX(this_->sparsity(oind));
65  MX ret;
66  if (!cache_.incache(oind, ret)) {
67  ret = MX::create(new OutputNode(this_, oind));
68  cache_.tocache_if_missing(oind, ret);
69  }
70  return ret;
71  }
72 
73  Call::Call(const Function& fcn, const std::vector<MX>& arg) : fcn_(fcn) {
74 
75  // Number inputs and outputs
76  casadi_int num_in = fcn.n_in();
77  casadi_assert(arg.size()==num_in, "Argument list length (" + str(arg.size())
78  + ") does not match number of inputs (" + str(num_in)
79  + ") for function " + fcn.name());
80 
81  // Create arguments of the right dimensions and sparsity
82  std::vector<MX> arg1(num_in);
83  for (casadi_int i=0; i<num_in; ++i) {
84  arg1[i] = projectArg(arg[i], fcn_.sparsity_in(i), i);
85  }
86  set_dep(arg1);
88  }
89 
90  std::string Call::disp(const std::vector<std::string>& arg) const {
91  std::stringstream ss;
92  ss << fcn_.name() << "(";
93  for (casadi_int i=0; i<n_dep(); ++i) {
94  if (i!=0) ss << ", ";
95  ss << arg.at(i);
96  }
97  ss << ")";
98  return ss.str();
99  }
100 
101  int Call::eval(const double** arg, double** res, casadi_int* iw, double* w) const {
102  return fcn_(arg, res, iw, w);
103  }
104 
105  casadi_int Call::nout() const {
106  return fcn_.n_out();
107  }
108 
109  const Sparsity& Call::sparsity(casadi_int oind) const {
110  return fcn_.sparsity_out(oind);
111  }
112 
113  int Call::eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
114  return fcn_(arg, res, iw, w);
115  }
116 
117  void Call::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const {
118  res = create(fcn_, arg);
119  }
120 
121  void Call::ad_forward(const std::vector<std::vector<MX>>& fseed,
122  std::vector<std::vector<MX>>& fsens) const {
123  try {
124  // Nondifferentiated inputs and outputs
125  std::vector<MX> arg(n_dep());
126  for (casadi_int i=0; i<arg.size(); ++i) arg[i] = dep(i);
127  std::vector<MX> res(nout());
128  for (casadi_int i=0; i<res.size(); ++i) res[i] = get_output(i);
129 
130  // Call the cached functions
131  fcn_->call_forward(arg, res, fseed, fsens, false, false);
132  } catch (std::exception& e) {
133  CASADI_THROW_ERROR("ad_forward", e.what());
134  }
135  }
136 
137  void Call::ad_reverse(const std::vector<std::vector<MX>>& aseed,
138  std::vector<std::vector<MX>>& asens) const {
139  try {
140  // Find a common conditional argument among the seeds, if any
141  MX cond = common_cond(aseed);
142  // Nondifferentiated inputs and outputs
143  std::vector<MX> arg(n_dep());
144  for (casadi_int i=0; i<arg.size(); ++i) arg[i] = dep(i);
145  std::vector<MX> res(nout());
146  for (casadi_int i=0; i<res.size(); ++i) res[i] = get_output(i);
147  // Call the cached functions
148  std::vector<std::vector<MX>> v;
149  fcn_->call_reverse(arg, res, aseed, v, false, false);
150  for (casadi_int i=0; i<v.size(); ++i) {
151  for (casadi_int j=0; j<v[i].size(); ++j) {
152  // Skip structurally zero contributions (necessary?)
153  if (v[i][j].is_empty()) continue;
154  // Prevent propagation of NaNs through if/else
155  if (!cond.is_empty()) v[i][j] = if_else(cond, v[i][j], 0);
156  // Add seeds
157  asens[i][j] += v[i][j];
158  }
159  }
160  } catch (std::exception& e) {
161  CASADI_THROW_ERROR("ad_reverse", e.what());
162  }
163  }
164 
165  int Call::sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
166  return fcn_(arg, res, iw, w);
167  }
168 
169  int Call::sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
170  return fcn_.rev(arg, res, iw, w);
171  }
172 
174  g.add_dependency(fcn_);
175  }
176 
177  bool Call::has_refcount() const {
178  return fcn_->has_refcount_;
179  }
180 
182  const std::vector<casadi_int>& arg,
183  const std::vector<casadi_int>& res,
184  const std::vector<bool>& arg_is_ref,
185  std::vector<bool>& res_is_ref) const {
186  // Collect input arguments
187  g.local("arg1", "const casadi_real", "**");
188  for (casadi_int i=0; i<arg.size(); ++i) {
189  g << "arg1[" << i << "]=" << g.work(arg[i], fcn_.nnz_in(i), arg_is_ref[i]) << ";\n";
190  }
191 
192  // Collect output arguments
193  g.local("res1", "casadi_real", "**");
194  for (casadi_int i=0; i<res.size(); ++i) {
195  g << "res1[" << i << "]=" << g.work(res[i], fcn_.nnz_out(i), false) << ";\n";
196  }
197 
198  // Call function
199  std::string flag = g(fcn_, "arg1", "res1", "iw", "w");
200  g << "if (" << flag << ") return 1;\n";
201  }
202 
203  void Call::codegen_incref(CodeGenerator& g, std::set<void*>& added) const {
204  if (has_refcount()) {
205  auto i = added.insert(fcn_.get());
206  if (i.second) { // prevent duplicate calls
207  g << fcn_->codegen_name(g) << "_incref();\n";
208  }
209  }
210  }
211 
212  void Call::codegen_decref(CodeGenerator& g, std::set<void*>& added) const {
213  if (has_refcount()) {
214  auto i = added.insert(fcn_.get());
215  if (i.second) { // prevent duplicate calls
216  g << fcn_->codegen_name(g) << "_decref();\n";
217  }
218  }
219  }
220 
221  size_t Call::sz_arg() const {
222  return fcn_.sz_arg();
223  }
224 
225  size_t Call::sz_res() const {
226  return fcn_.sz_res();
227  }
228 
229  size_t Call::sz_iw() const {
230  return fcn_.sz_iw();
231  }
232 
233  size_t Call::sz_w() const {
234  return fcn_.sz_w();
235  }
236 
237  std::vector<MX> Call::create(const Function& fcn, const std::vector<MX>& arg) {
238  return MX::createMultipleOutput(new Call(fcn, arg));
239  }
240 
241  MX Call::create_call(const Function& fcn, const std::vector<MX>& arg) {
242  return MX::create(new Call(fcn, arg));
243  }
244 
247  s.pack("Call::fcn", fcn_);
248  }
249 
251  s.unpack("Call::fcn", fcn_);
252  }
253 
254  MX Call::common_cond(const std::vector<std::vector<MX> >& seed) {
255  // Check if all seeds are conditional with the same seed
256  MX c;
257  for (const std::vector<MX>& seed_dir : seed) {
258  for (const MX& s : seed_dir) {
259  // Skip zero seeds
260  if (s.is_zero()) continue;
261  // If not a conditional, no common condition
262  if (!s.is_op(OP_IF_ELSE_ZERO)) return MX();
263  // Get conditional
264  MX c1 = s.dep(0);
265  // Has c already been set
266  if (c.is_empty(true)) {
267  // First time encountered
268  c = c1;
269  } else if (!MX::is_equal(c, c1)) {
270  // Different conditionals
271  return MX();
272  }
273  }
274  }
275  return c;
276  }
277 
278 } // namespace casadi
void codegen_incref(CodeGenerator &g, std::set< void * > &added) const override
Codegen incref.
static std::vector< MX > create(const Function &fcn, const std::vector< MX > &arg)
Create function call node.
Call(const Function &fcn, const std::vector< MX > &arg)
Constructor (should not be used directly)
Definition: casadi_call.cpp:73
WeakCache< casadi_int, MX > cache_
Output node cache.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
size_t sz_w() const override
Get required length of w field.
MX get_output(casadi_int oind) const override
Get an output.
Definition: casadi_call.cpp:61
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
Function fcn_
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
void codegen_decref(CodeGenerator &g, std::set< void * > &added) const override
Codegen decref.
bool has_refcount() const override
Is reference counting needed in codegen?
size_t sz_res() const override
Get required length of res field.
static MX create_call(const Function &fcn, const std::vector< MX > &arg)
Create function call node.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
size_t sz_iw() const override
Get required length of iw field.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
size_t sz_arg() const override
Get required length of arg field.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
casadi_int nout() const override
Number of outputs.
void add_dependency(CodeGenerator &g) const override
Add a dependent function.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
static MX projectArg(const MX &x, const Sparsity &sp, casadi_int i)
Project a function input to a particular sparsity.
Definition: casadi_call.cpp:37
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
Definition: casadi_call.cpp:90
static MX common_cond(const std::vector< std::vector< MX >> &seed)
Find a common conditional argument for all seeds.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
Helper class for C code generation.
std::string add_dependency(const Function &f)
Add a function dependency.
std::string work(casadi_int n, casadi_int sz, bool is_ref) const
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
bool has_refcount_
Reference counting in codegen?
virtual void call_forward(const std::vector< MX > &arg, const std::vector< MX > &res, const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens, bool always_inline, bool never_inline) const
Forward mode AD, virtual functions overloaded in derived classes.
virtual void call_reverse(const std::vector< MX > &arg, const std::vector< MX > &res, const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens, bool always_inline, bool never_inline) const
Reverse mode, virtual functions overloaded in derived classes.
virtual std::string codegen_name(const CodeGenerator &g, bool ns=true) const
Get name in codegen.
Function object.
Definition: function.hpp:60
casadi_int nnz_out() const
Get number of output nonzeros.
Definition: function.cpp:855
size_t sz_res() const
Get required length of res field.
Definition: function.cpp:1085
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
Definition: function.cpp:1031
FunctionInternal * get() const
Definition: function.cpp:353
const std::string & name() const
Name of the function.
Definition: function.cpp:1307
int rev(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, int mem=0) const
Propagate sparsity backward.
Definition: function.cpp:1100
const Sparsity & sparsity_in(casadi_int ind) const
Get sparsity of a given input.
Definition: function.cpp:1015
size_t sz_iw() const
Get required length of iw field.
Definition: function.cpp:1087
casadi_int n_out() const
Get the number of function outputs.
Definition: function.cpp:823
casadi_int n_in() const
Get the number of function inputs.
Definition: function.cpp:819
size_t sz_w() const
Get required length of w field.
Definition: function.cpp:1089
size_t sz_arg() const
Get required length of arg field.
Definition: function.cpp:1083
casadi_int nnz_in() const
Get number of input nonzeros.
Definition: function.cpp:851
bool is_empty(bool both=false) const
Check if the sparsity is empty, i.e. if one of the dimensions is zero.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
static MX zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
bool is_scalar(bool scalar_and_dense=false) const
Check if the matrix expression is scalar.
friend class MX
Definition: mx_node.hpp:52
const Sparsity & sparsity() const
Get the sparsity.
Definition: mx_node.hpp:372
const MX & dep(casadi_int ind=0) const
dependencies - functions that have to be evaluated before this one
Definition: mx_node.hpp:354
casadi_int n_dep() const
Number of dependencies.
Definition: mx_node.cpp:206
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
Definition: mx_node.cpp:523
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
Definition: mx_node.cpp:222
void set_dep(const MX &dep)
Set unary dependency.
Definition: mx_node.cpp:226
MX - Matrix expression.
Definition: mx.hpp:92
static MX create(MXNode *node)
Create from node.
Definition: mx.cpp:67
static bool is_equal(const MX &x, const MX &y, casadi_int depth=0)
Definition: mx.cpp:838
MX T() const
Transpose the matrix.
Definition: mx.cpp:1029
static std::vector< MX > createMultipleOutput(MXNode *node)
Create from node (multiple-outputs)
Definition: mx.cpp:128
MX dep(casadi_int ch=0) const
Get the nth dependency as MX.
Definition: mx.cpp:754
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
General sparsity class.
Definition: sparsity.hpp:106
bool is_vector() const
Check if the pattern is a row or column vector.
Definition: sparsity.cpp:289
casadi_int size1() const
Get the number of rows.
Definition: sparsity.cpp:124
casadi_int nnz() const
Get the number of (structural) non-zeros.
Definition: sparsity.cpp:148
casadi_int size2() const
Get the number of columns.
Definition: sparsity.cpp:128
std::pair< casadi_int, casadi_int > size() const
Get the shape.
Definition: sparsity.cpp:152
static Sparsity scalar(bool dense_scalar=true)
Create a scalar sparsity pattern.
Definition: sparsity.hpp:153
bool is_empty(bool both=false) const
Check if the sparsity is empty.
Definition: sparsity.cpp:144
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
double if_else(double x, double y, double z)
Definition: calculus.hpp:290
std::string str(const T &v)
String representation, any type.
@ OP_IF_ELSE_ZERO
Definition: calculus.hpp:71