binary_mx_impl.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
28 
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
33 #include <sstream>
34 #include <vector>
35 
36 namespace casadi {
37 
38  template<bool ScX, bool ScY>
39  BinaryMX<ScX, ScY>::BinaryMX(Operation op, const MX& x, const MX& y) : op_(op) {
40  set_dep(x, y);
41  if (ScX) {
43  } else {
45  }
46  }
47 
48  template<bool ScX, bool ScY>
50  }
51 
52  template<bool ScX, bool ScY>
53  std::string BinaryMX<ScX, ScY>::disp(const std::vector<std::string>& arg) const { // Textual form of this node; arg[0]/arg[1] are the operand strings (at() throws if fewer than two are given)
54  return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
55  }
56 
57  template<bool ScX, bool ScY>
58  void BinaryMX<ScX, ScY>::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const { // Symbolic MX evaluation: res[0] = op_ applied to (arg[0], arg[1])
59  casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
60  }
61 
62  template<bool ScX, bool ScY>
63  void BinaryMX<ScX, ScY>::eval_linear(const std::vector<std::array<MX, 3> >& arg,
64  std::vector<std::array<MX, 3> >& res) const { // Each operand/result is a 3-way const/linear/nonlinear partition; casadi_math<MX>::fun_linear combines the parts according to op_
65  casadi_math<MX>::fun_linear(op_, arg[0].data(), arg[1].data(), res[0].data());
66  }
67 
68  template<bool ScX, bool ScY>
69  void BinaryMX<ScX, ScY>::ad_forward(const std::vector<std::vector<MX> >& fseed,
70  std::vector<std::vector<MX> >& fsens) const { // Forward-mode AD: for every direction d, fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1] (chain rule)
71  // Get partial derivatives: pd[0] w.r.t. dep(0), pd[1] w.r.t. dep(1), given this node's value as f
72  MX pd[2];
73  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
74
75  // Propagate forward seeds
76  for (casadi_int d=0; d<fsens.size(); ++d) {
77  if (op_ == OP_IF_ELSE_ZERO) {
78  fsens[d][0] = if_else_zero(pd[1], fseed[d][1]); // Only the y-seed contributes; gated with if_else_zero rather than multiplied (presumably to avoid NaN*0 propagation -- confirm)
79  } else {
80  fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
81  }
82  }
83  }
84 
85  template<bool ScX, bool ScY>
86  void BinaryMX<ScX, ScY>::ad_reverse(const std::vector<std::vector<MX> >& aseed,
87  std::vector<std::vector<MX> >& asens) const { // Reverse-mode AD: accumulates (+=) adjoint contributions into asens[d][0] and asens[d][1] for every direction d
88  // Get partial derivatives: pd[0] w.r.t. dep(0), pd[1] w.r.t. dep(1), given this node's value as f
89  MX pd[2];
90  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
91
92  // Propagate adjoint seeds
93  for (casadi_int d=0; d<aseed.size(); ++d) {
94  MX s = aseed[d][0];
95  if (op_ == OP_IF_ELSE_ZERO) {
96  // Special case to avoid NaN propagation: only dep(1) receives a sensitivity, gated by the condition dep(0); dep(0) gets nothing
97  if (!s.is_scalar() && dep(1).is_scalar()) {
98  asens[d][1] += dot(dep(0), s); // Scalar y broadcast over a matrix seed: reduce to a scalar. NOTE(review): dot() multiplies by dep(0) instead of masking, so this matches the masked sum only for 0/1-valued conditions, and 0*NaN in s can still propagate; consider summing if_else_zero(dep(0), s) instead -- confirm
99  } else {
100  asens[d][1] += if_else_zero(dep(0), s);
101  }
102  } else {
103  // General case
104  for (casadi_int c=0; c<2; ++c) {
105  // Get increment of sensitivity c
106  MX t = pd[c]*s;
107
108  // If dimension mismatch (i.e. one argument is scalar), then sum all the entries
109  if (!t.is_scalar() && t.size() != dep(c).size()) {
110  if (pd[c].size()!=s.size()) pd[c] = MX(s.sparsity(), pd[c]); // Expand a scalar partial onto the seed's sparsity so dot() sees matching shapes
111  t = dot(pd[c], s);
112  }
113
114  // Propagate the seeds
115  asens[d][c] += t;
116  }
117  }
118  }
119  }
120 
121  template<bool ScX, bool ScY>
124  const std::vector<casadi_int>& arg, const std::vector<casadi_int>& res,
125  const std::vector<bool>& arg_is_ref, std::vector<bool>& res_is_ref) const {
126  // Quick return if nothing to do
127  if (nnz()==0) return;
128 
129  // Check if inplace
130  bool inplace;
131  switch (op_) {
132  case OP_ADD:
133  case OP_SUB:
134  case OP_MUL:
135  case OP_DIV:
136  inplace = res[0]==arg[0] && !arg_is_ref[0];
137  break;
138  default:
139  inplace = false;
140  break;
141  }
142 
143  // Scalar names of arguments (start assuming all scalars)
144  std::string r = g.workel(res[0]);
145  std::string x = g.workel(arg[0]);
146  std::string y = g.workel(arg[1]);
147 
148  // Avoid emitting '/*' which will be mistaken for a comment
149  if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
150  y = "(" + y + ")";
151  }
152 
153  // Codegen loop, if needed
154  if (nnz()>1) {
155  // Iterate over result
156  g.local("rr", "casadi_real", "*");
157  g.local("i", "casadi_int");
158  g << "for (i=0, " << "rr=" << g.work(res[0], nnz(), false);
159  r = "(*rr++)";
160 
161  // Iterate over first argument?
162  if (!ScX && !inplace) {
163  g.local("cr", "const casadi_real", "*");
164  g << ", cr=" << g.work(arg[0], dep(0).nnz(), arg_is_ref[0]);
165  if (op_==OP_OR || op_==OP_AND) {
166  // Avoid short-circuiting with side effects
167  x = "cr[i]";
168  } else {
169  x = "(*cr++)";
170  }
171 
172  }
173 
174  // Iterate over second argument?
175  if (!ScY) {
176  g.local("cs", "const casadi_real", "*");
177  g << ", cs=" << g.work(arg[1], dep(1).nnz(), arg_is_ref[1]);
178  if (op_==OP_OR || op_==OP_AND || op_==OP_IF_ELSE_ZERO) {
179  // Avoid short-circuiting with side effects
180  y = "cs[i]";
181  } else {
182  y = "(*cs++)";
183  }
184  }
185 
186  // Close loop
187  g << "; i<" << nnz() << "; ++i) ";
188  }
189 
190  // Perform operation
191  g << r << " ";
192  if (inplace) {
193  g << casadi_math<double>::sep(op_) << "= " << y;
194  } else {
195  g << " = " << g.print_op(op_, x, y);
196  }
197  g << ";\n";
198  }
199 
200  template<bool ScX, bool ScY>
202  eval(const double** arg, double** res, casadi_int* iw, double* w) const {
203  return eval_gen<double>(arg, res, iw, w);
204  }
205 
206  template<bool ScX, bool ScY>
208  eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
209  return eval_gen<SXElem>(arg, res, iw, w);
210  }
211 
212  template<bool ScX, bool ScY>
213  template<typename T>
215  eval_gen(const T* const* arg, T* const* res, casadi_int* iw, T* w) const {
216  // Get data
217  T* output0 = res[0];
218  const T* input0 = arg[0];
219  const T* input1 = arg[1];
220 
221  if (!ScX && !ScY) {
222  casadi_math<T>::fun(op_, input0, input1, output0, nnz());
223  } else if (ScX) {
224  casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
225  } else {
226  casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
227  }
228  return 0;
229  }
230 
231  template<bool ScX, bool ScY>
233  sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
234  const bvec_t *a0=arg[0], *a1=arg[1];
235  bvec_t *r=res[0];
236  casadi_int n=nnz();
237  for (casadi_int i=0; i<n; ++i) {
238  if (ScX && ScY)
239  *r++ = *a0 | *a1;
240  else if (ScX && !ScY)
241  *r++ = *a0 | *a1++;
242  else if (!ScX && ScY)
243  *r++ = *a0++ | *a1;
244  else
245  *r++ = *a0++ | *a1++;
246  }
247  return 0;
248  }
249 
250  template<bool ScX, bool ScY>
252  sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
253  bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
254  casadi_int n=nnz();
255  for (casadi_int i=0; i<n; ++i) {
256  bvec_t s = *r;
257  *r++ = 0;
258  if (ScX)
259  *a0 |= s;
260  else
261  *a0++ |= s;
262  if (ScY)
263  *a1 |= s;
264  else
265  *a1++ |= s;
266  }
267  return 0;
268  }
269 
270  template<bool ScX, bool ScY>
271  MX BinaryMX<ScX, ScY>::get_unary(casadi_int op) const { // Apply unary op on top of this binary node; no BinaryMX-specific rewrite rules are currently active
272  //switch (op_) {
273  //default: break; // no rule
274  //}
275
276  // Fallback to default implementation (the commented-out switch above is a placeholder for future rules)
277  return MXNode::get_unary(op);
278  }
279 
280  template<bool ScX, bool ScY>
281  MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op, const MX& y, bool scX, bool scY) const {
283 
284  switch (op_) {
285  case OP_ADD:
286  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return dep(1);
287  if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
288  break;
289  case OP_SUB:
290  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return -dep(1);
291  if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
292  break;
293  default: break; // no rule
294  }
295 
296  // Fallback to default implementation
297  return MXNode::_get_binary(op, y, scX, scY);
298  }
299 
300  template<bool ScX, bool ScY>
303  s.pack("BinaryMX::op", static_cast<int>(op_));
304  }
305 
306  template<bool ScX, bool ScY>
309  char type_x = ScX;
310  char type_y = ScY;
311  char type = type_x | (type_y << 1);
312  s.pack("BinaryMX::scalar_flags", type);
313  }
314 
315  template<bool ScX, bool ScY>
317  char t;
318  s.unpack("BinaryMX::scalar_flags", t);
319  bool scX = t & 1;
320  bool scY = t & 2;
321 
322  if (scX) {
323  if (scY) return new BinaryMX<true, true>(s);
324  return new BinaryMX<true, false>(s);
325  } else {
326  if (scY) return new BinaryMX<false, true>(s);
327  return new BinaryMX<false, false>(s);
328  }
329  }
330 
331  template<bool ScX, bool ScY>
333  int op;
334  s.unpack("BinaryMX::op", op);
335  op_ = Operation(op);
336  }
337 
338  template<bool ScX, bool ScY>
339  MX BinaryMX<ScX, ScY>::get_solve_triu(const MX& r, bool tr) const { // Upper-triangular solve with this node as A; specialized when A has the form (I - R)
340  // Identify systems with the structure (I - R)
341  if (!ScX && !ScY && op_ == OP_SUB) {
342  // Is the first term a projected unity matrix?
343  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
344  // Is the second term strictly upper triangular?
345  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_triu(true)) {
346  return dep(1).dep(0)->get_solve_triu_unity(r, tr); // Delegate to R's unity-diagonal upper-triangular solver
347  }
348  }
349  }
350  // Fall back to default routine
351  return MXNode::get_solve_triu(r, tr);
352  }
353 
354  template<bool ScX, bool ScY>
355  MX BinaryMX<ScX, ScY>::get_solve_tril(const MX& r, bool tr) const { // Lower-triangular solve with this node as A; specialized when A has the form (I - L)
356  // Identify systems with the structure (I - L)
357  if (!ScX && !ScY && op_ == OP_SUB) {
358  // Is the first term a projected unity matrix?
359  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
360  // Is the second term strictly lower triangular?
361  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_tril(true)) {
362  return dep(1).dep(0)->get_solve_tril_unity(r, tr); // Delegate to L's unity-diagonal lower-triangular solver
363  }
364  }
365  }
366  // Fall back to default routine
367  return MXNode::get_solve_tril(r, tr);
368  }
369 
370 } // namespace casadi
371 
372 #endif // CASADI_BINARY_MX_IMPL_HPP
Represents any binary operation that involves two matrices.
Definition: binary_mx.hpp:41
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
~BinaryMX() override
Destructor.
casadi_int op() const override
Get the operation.
Definition: binary_mx.hpp:61
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
MX get_solve_tril(const MX &r, bool tr) const override
Solve a system of linear equations, lower triangular A.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
MX get_unary(casadi_int op) const override
Get a unary operation.
MX _get_binary(casadi_int op, const MX &y, bool scX, bool scY) const override
Get a binary operation.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
MX get_solve_triu(const MX &r, bool tr) const override
Solve a system of linear equations, upper triangular A.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
BinaryMX(Operation op, const MX &x, const MX &y)
Constructor.
int eval_gen(const T *const *arg, T *const *res, casadi_int *iw, T *w) const
Evaluate the function (template)
void eval_linear(const std::vector< std::array< MX, 3 > > &arg, std::vector< std::array< MX, 3 > > &res) const override
Evaluate the MX node on a const/linear/nonlinear partition.
Operation op_
Operation.
Definition: binary_mx.hpp:171
void serialize_type(SerializingStream &s) const override
Serialize type information.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
Helper class for C code generation.
bool codegen_scalars
Codegen scalar.
std::string work(casadi_int n, casadi_int sz, bool is_ref) const
std::string print_op(casadi_int op, const std::string &a0)
Print an operation to a C file.
void local(const std::string &name, const std::string &type, const std::string &ref="")
Declare a local variable.
std::string workel(casadi_int n) const
Helper class for deserialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
bool is_scalar(bool scalar_and_dense=false) const
Check if the matrix expression is scalar.
static bool simplification_on_the_fly
Indicates whether simplifications should be made on the fly.
Node class for MX objects.
Definition: mx_node.hpp:51
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
Definition: mx_node.cpp:528
virtual MX get_solve_triu(const MX &r, bool tr) const
Solve a system of linear equations, upper triangular A.
Definition: mx_node.cpp:617
virtual MX get_solve_tril_unity(const MX &r, bool tr) const
Solve a system of linear equations, lower triangular A, unity diagonal.
Definition: mx_node.cpp:641
virtual MX get_solve_tril(const MX &r, bool tr) const
Solve a system of linear equations, lower triangular A.
Definition: mx_node.cpp:625
virtual MX get_unary(casadi_int op) const
Get a unary operation.
Definition: mx_node.cpp:778
virtual MX get_solve_triu_unity(const MX &r, bool tr) const
Solve a system of linear equations, upper triangular A, unity diagonal.
Definition: mx_node.cpp:633
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
Definition: mx_node.cpp:523
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
Definition: mx_node.cpp:222
void set_dep(const MX &dep)
Set unary dependency.
Definition: mx_node.cpp:226
virtual MX _get_binary(casadi_int op, const MX &y, bool scX, bool scY) const
Get a binary operation (matrix-matrix).
Definition: mx_node.cpp:843
MX - Matrix expression.
Definition: mx.hpp:92
const Sparsity & sparsity() const
Get the sparsity pattern.
Definition: mx.cpp:592
static bool is_equal(const MX &x, const MX &y, casadi_int depth=0)
Definition: mx.cpp:838
MX dep(casadi_int ch=0) const
Get the nth dependency as MX.
Definition: mx.cpp:754
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
The casadi namespace.
Definition: archiver.cpp:28
double if_else_zero(double x, double y)
Conditional assignment.
Definition: calculus.hpp:289
unsigned long long bvec_t
T dot(const std::vector< T > &a, const std::vector< T > &b)
Operation
Enum for quick access to any node.
Definition: calculus.hpp:60
@ OP_IF_ELSE_ZERO
Definition: calculus.hpp:71
@ OP_AND
Definition: calculus.hpp:70
@ OP_OR
Definition: calculus.hpp:70
@ OP_SUB
Definition: calculus.hpp:65
@ OP_PROJECT
Definition: calculus.hpp:169
@ OP_ADD
Definition: calculus.hpp:65
@ OP_DIV
Definition: calculus.hpp:65
@ OP_MUL
Definition: calculus.hpp:65
static void der(unsigned char op, const T &x, const T &y, const T &f, T *d)
Evaluate a built in derivative function.
Definition: calculus.hpp:1375
static void fun_linear(unsigned char op, const T *x, const T *y, T *f)
Evaluate function on a const/linear/nonlinear partition.
Definition: calculus.hpp:1552
static std::string print(unsigned char op, const std::string &x, const std::string &y)
Print.
Definition: calculus.hpp:1641
static void fun(unsigned char op, const T &x, const T &y, T &f)
Evaluate a built in function (scalar-scalar)
Definition: calculus.hpp:1289