binary_mx_impl.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
28 
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
33 #include <sstream>
34 #include <vector>
35 
36 namespace casadi {
37 
38  template<bool ScX, bool ScY>
39  BinaryMX<ScX, ScY>::BinaryMX(Operation op, const MX& x, const MX& y) : op_(op) {
40  set_dep(x, y);
41  if (ScX) {
42  set_sparsity(y.sparsity());
43  } else {
44  set_sparsity(x.sparsity());
45  }
46  }
47 
48  template<bool ScX, bool ScY>
49  BinaryMX<ScX, ScY>::~BinaryMX() {
50  }
51 
52  template<bool ScX, bool ScY>
53  std::string BinaryMX<ScX, ScY>::disp(const std::vector<std::string>& arg) const {
54  return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
55  }
56 
57  template<bool ScX, bool ScY>
58  void BinaryMX<ScX, ScY>::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const {
59  casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
60  }
61 
62  template<bool ScX, bool ScY>
63  void BinaryMX<ScX, ScY>::ad_forward(const std::vector<std::vector<MX> >& fseed,
64  std::vector<std::vector<MX> >& fsens) const {
65  // Get partial derivatives
66  MX pd[2];
67  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
68 
69  // Propagate forward seeds
70  for (casadi_int d=0; d<fsens.size(); ++d) {
71  if (op_ == OP_IF_ELSE_ZERO) {
72  fsens[d][0] = if_else_zero(pd[1], fseed[d][1]);
73  } else {
74  fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
75  }
76  }
77  }
78 
79  template<bool ScX, bool ScY>
80  void BinaryMX<ScX, ScY>::ad_reverse(const std::vector<std::vector<MX> >& aseed,
81  std::vector<std::vector<MX> >& asens) const {
82  // Get partial derivatives
83  MX pd[2];
84  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
85 
86  // Propagate adjoint seeds
87  for (casadi_int d=0; d<aseed.size(); ++d) {
88  MX s = aseed[d][0];
89  if (op_ == OP_IF_ELSE_ZERO) {
90  // Special case to avoid NaN propagation
91  if (!s.is_scalar() && dep(1).is_scalar()) {
92  asens[d][1] += dot(dep(0), s);
93  } else {
94  asens[d][1] += if_else_zero(dep(0), s);
95  }
96  } else {
97  // General case
98  for (casadi_int c=0; c<2; ++c) {
99  // Get increment of sensitivity c
100  MX t = pd[c]*s;
101 
102  // If dimension mismatch (i.e. one argument is scalar), then sum all the entries
103  if (!t.is_scalar() && t.size() != dep(c).size()) {
104  if (pd[c].size()!=s.size()) pd[c] = MX(s.sparsity(), pd[c]);
105  t = dot(pd[c], s);
106  }
107 
108  // Propagate the seeds
109  asens[d][c] += t;
110  }
111  }
112  }
113  }
114 
115  template<bool ScX, bool ScY>
116  void BinaryMX<ScX, ScY>::
117  generate(CodeGenerator& g,
118  const std::vector<casadi_int>& arg, const std::vector<casadi_int>& res) const {
119  // Quick return if nothing to do
120  if (nnz()==0) return;
121 
122  // Check if inplace
123  bool inplace;
124  switch (op_) {
125  case OP_ADD:
126  case OP_SUB:
127  case OP_MUL:
128  case OP_DIV:
129  inplace = res[0]==arg[0];
130  break;
131  default:
132  inplace = false;
133  break;
134  }
135 
136  // Scalar names of arguments (start assuming all scalars)
137  std::string r = g.workel(res[0]);
138  std::string x = g.workel(arg[0]);
139  std::string y = g.workel(arg[1]);
140 
141  // Avoid emitting '/*' which will be mistaken for a comment
142  if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
143  y = "(" + y + ")";
144  }
145 
146  // Codegen loop, if needed
147  if (nnz()>1) {
148  // Iterate over result
149  g.local("rr", "casadi_real", "*");
150  g.local("i", "casadi_int");
151  g << "for (i=0, " << "rr=" << g.work(res[0], nnz());
152  r = "(*rr++)";
153 
154  // Iterate over first argument?
155  if (!ScX && !inplace) {
156  g.local("cr", "const casadi_real", "*");
157  g << ", cr=" << g.work(arg[0], dep(0).nnz());
158  if (op_==OP_OR || op_==OP_AND) {
159  // Avoid short-circuiting with side effects
160  x = "cr[i]";
161  } else {
162  x = "(*cr++)";
163  }
164 
165  }
166 
167  // Iterate over second argument?
168  if (!ScY) {
169  g.local("cs", "const casadi_real", "*");
170  g << ", cs=" << g.work(arg[1], dep(1).nnz());
171  if (op_==OP_OR || op_==OP_AND || op_==OP_IF_ELSE_ZERO) {
172  // Avoid short-circuiting with side effects
173  y = "cs[i]";
174  } else {
175  y = "(*cs++)";
176  }
177  }
178 
179  // Close loop
180  g << "; i<" << nnz() << "; ++i) ";
181  }
182 
183  // Perform operation
184  g << r << " ";
185  if (inplace) {
186  g << casadi_math<double>::sep(op_) << "= " << y;
187  } else {
188  g << " = " << g.print_op(op_, x, y);
189  }
190  g << ";\n";
191  }
192 
193  template<bool ScX, bool ScY>
194  int BinaryMX<ScX, ScY>::
195  eval(const double** arg, double** res, casadi_int* iw, double* w) const {
196  return eval_gen<double>(arg, res, iw, w);
197  }
198 
199  template<bool ScX, bool ScY>
200  int BinaryMX<ScX, ScY>::
201  eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
202  return eval_gen<SXElem>(arg, res, iw, w);
203  }
204 
205  template<bool ScX, bool ScY>
206  template<typename T>
207  int BinaryMX<ScX, ScY>::
208  eval_gen(const T* const* arg, T* const* res, casadi_int* iw, T* w) const {
209  // Get data
210  T* output0 = res[0];
211  const T* input0 = arg[0];
212  const T* input1 = arg[1];
213 
214  if (!ScX && !ScY) {
215  casadi_math<T>::fun(op_, input0, input1, output0, nnz());
216  } else if (ScX) {
217  casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
218  } else {
219  casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
220  }
221  return 0;
222  }
223 
224  template<bool ScX, bool ScY>
225  int BinaryMX<ScX, ScY>::
226  sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
227  const bvec_t *a0=arg[0], *a1=arg[1];
228  bvec_t *r=res[0];
229  casadi_int n=nnz();
230  for (casadi_int i=0; i<n; ++i) {
231  if (ScX && ScY)
232  *r++ = *a0 | *a1;
233  else if (ScX && !ScY)
234  *r++ = *a0 | *a1++;
235  else if (!ScX && ScY)
236  *r++ = *a0++ | *a1;
237  else
238  *r++ = *a0++ | *a1++;
239  }
240  return 0;
241  }
242 
243  template<bool ScX, bool ScY>
244  int BinaryMX<ScX, ScY>::
245  sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
246  bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
247  casadi_int n=nnz();
248  for (casadi_int i=0; i<n; ++i) {
249  bvec_t s = *r;
250  *r++ = 0;
251  if (ScX)
252  *a0 |= s;
253  else
254  *a0++ |= s;
255  if (ScY)
256  *a1 |= s;
257  else
258  *a1++ |= s;
259  }
260  return 0;
261  }
262 
263  template<bool ScX, bool ScY>
264  MX BinaryMX<ScX, ScY>::get_unary(casadi_int op) const {
265  //switch (op_) {
266  //default: break; // no rule
267  //}
268 
269  // Fallback to default implementation
270  return MXNode::get_unary(op);
271  }
272 
273  template<bool ScX, bool ScY>
274  MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op, const MX& y, bool scX, bool scY) const {
275  if (!GlobalOptions::simplification_on_the_fly) return MXNode::_get_binary(op, y, scX, scY);
276 
277  switch (op_) {
278  case OP_ADD:
279  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return dep(1);
280  if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
281  break;
282  case OP_SUB:
283  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return -dep(1);
284  if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
285  break;
286  default: break; // no rule
287  }
288 
289  // Fallback to default implementation
290  return MXNode::_get_binary(op, y, scX, scY);
291  }
292 
293  template<bool ScX, bool ScY>
294  void BinaryMX<ScX, ScY>::serialize_body(SerializingStream& s) const {
296  s.pack("BinaryMX::op", static_cast<int>(op_));
297  }
298 
299  template<bool ScX, bool ScY>
300  void BinaryMX<ScX, ScY>::serialize_type(SerializingStream& s) const {
302  char type_x = ScX;
303  char type_y = ScY;
304  char type = type_x | (type_y << 1);
305  s.pack("BinaryMX::scalar_flags", type);
306  }
307 
308  template<bool ScX, bool ScY>
309  MXNode* BinaryMX<ScX, ScY>::deserialize(DeserializingStream& s) {
310  char t;
311  s.unpack("BinaryMX::scalar_flags", t);
312  bool scX = t & 1;
313  bool scY = t & 2;
314 
315  if (scX) {
316  if (scY) return new BinaryMX<true, true>(s);
317  return new BinaryMX<true, false>(s);
318  } else {
319  if (scY) return new BinaryMX<false, true>(s);
320  return new BinaryMX<false, false>(s);
321  }
322  }
323 
324  template<bool ScX, bool ScY>
325  BinaryMX<ScX, ScY>::BinaryMX(DeserializingStream& s) : MXNode(s) {
326  int op;
327  s.unpack("BinaryMX::op", op);
328  op_ = Operation(op);
329  }
330 
331  template<bool ScX, bool ScY>
332  MX BinaryMX<ScX, ScY>::get_solve_triu(const MX& r, bool tr) const {
333  // Identify systems with the structure (I - R)
334  if (!ScX && !ScY && op_ == OP_SUB) {
335  // Is the first term a projected unity matrix?
336  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
337  // Is the second term strictly lower triangular?
338  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_triu(true)) {
339  return dep(1).dep(0)->get_solve_triu_unity(r, tr);
340  }
341  }
342  }
343  // Fall back to default routine
344  return MXNode::get_solve_triu(r, tr);
345  }
346 
347  template<bool ScX, bool ScY>
348  MX BinaryMX<ScX, ScY>::get_solve_tril(const MX& r, bool tr) const {
349  // Identify systems with the structure (I - L)
350  if (!ScX && !ScY && op_ == OP_SUB) {
351  // Is the first term a projected unity matrix?
352  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
353  // Is the second term strictly lower triangular?
354  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_tril(true)) {
355  return dep(1).dep(0)->get_solve_tril_unity(r, tr);
356  }
357  }
358  }
359  // Fall back to default routine
360  return MXNode::get_solve_tril(r, tr);
361  }
362 
363 } // namespace casadi
364 
365 #endif // CASADI_BINARY_MX_IMPL_HPP
Helper class for C code generation.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Sparsity sparsity() const
Get the sparsity pattern.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
bool is_scalar(bool scalar_and_dense=false) const
Check if the matrix expression is scalar.
Node class for MX objects.
Definition: mx_node.hpp:50
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
virtual MX _get_binary(casadi_int op, const MX &y, bool scX, bool scY) const
Get a binary operation (matrix-matrix)
virtual MX get_unary(casadi_int op) const
Get a unary operation.
virtual MX get_solve_triu(const MX &r, bool tr) const
Solve a system of linear equations, upper triangular A.
virtual MX get_solve_tril(const MX &r, bool tr) const
Solve a system of linear equations, lower triangular A.
MX - Matrix expression.
Definition: mx.hpp:84
MX dep(casadi_int ch=0) const
Get the nth dependency as MX.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
The casadi namespace.