binary_mx_impl.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
28 
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
33 #include <sstream>
34 #include <vector>
35 
36 namespace casadi {
37 
38  template<bool ScX, bool ScY>
39  BinaryMX<ScX, ScY>::BinaryMX(Operation op, const MX& x, const MX& y) : op_(op) {
40  set_dep(x, y);
41  if (ScX) {
42  set_sparsity(y.sparsity());
43  } else {
44  set_sparsity(x.sparsity());
45  }
46  }
47 
48  template<bool ScX, bool ScY>
49  BinaryMX<ScX, ScY>::~BinaryMX() {
50  }
51 
52  template<bool ScX, bool ScY>
53  std::string BinaryMX<ScX, ScY>::disp(const std::vector<std::string>& arg) const {
54  return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
55  }
56 
57  template<bool ScX, bool ScY>
58  void BinaryMX<ScX, ScY>::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const {
59  casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
60  }
61 
62  template<bool ScX, bool ScY>
63  void BinaryMX<ScX, ScY>::eval_linear(const std::vector<std::array<MX, 3> >& arg,
64  std::vector<std::array<MX, 3> >& res) const {
65  casadi_math<MX>::fun_linear(op_, arg[0].data(), arg[1].data(), res[0].data());
66  }
67 
68  template<bool ScX, bool ScY>
69  void BinaryMX<ScX, ScY>::ad_forward(const std::vector<std::vector<MX> >& fseed,
70  std::vector<std::vector<MX> >& fsens) const {
71  // Get partial derivatives
72  MX pd[2];
73  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
74 
75  // Propagate forward seeds
76  for (casadi_int d=0; d<fsens.size(); ++d) {
77  if (op_ == OP_IF_ELSE_ZERO) {
78  fsens[d][0] = if_else_zero(pd[1], fseed[d][1]);
79  } else {
80  fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
81  }
82  }
83  }
84 
85  template<bool ScX, bool ScY>
86  void BinaryMX<ScX, ScY>::ad_reverse(const std::vector<std::vector<MX> >& aseed,
87  std::vector<std::vector<MX> >& asens) const {
88  // Get partial derivatives
89  MX pd[2];
90  casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
91 
92  // Propagate adjoint seeds
93  for (casadi_int d=0; d<aseed.size(); ++d) {
94  MX s = aseed[d][0];
95  if (op_ == OP_IF_ELSE_ZERO) {
96  // Special case to avoid NaN propagation
97  if (!s.is_scalar() && dep(1).is_scalar()) {
98  asens[d][1] += dot(dep(0), s);
99  } else {
100  asens[d][1] += if_else_zero(dep(0), s);
101  }
102  } else {
103  // General case
104  for (casadi_int c=0; c<2; ++c) {
105  // Get increment of sensitivity c
106  MX t = pd[c]*s;
107 
108  // If dimension mismatch (i.e. one argument is scalar), then sum all the entries
109  if (!t.is_scalar() && t.size() != dep(c).size()) {
110  if (pd[c].size()!=s.size()) pd[c] = MX(s.sparsity(), pd[c]);
111  t = dot(pd[c], s);
112  }
113 
114  // Propagate the seeds
115  asens[d][c] += t;
116  }
117  }
118  }
119  }
120 
121  template<bool ScX, bool ScY>
122  void BinaryMX<ScX, ScY>::
123  generate(CodeGenerator& g,
124  const std::vector<casadi_int>& arg, const std::vector<casadi_int>& res,
125  const std::vector<bool>& arg_is_ref, std::vector<bool>& res_is_ref) const {
126  // Quick return if nothing to do
127  if (nnz()==0) return;
128 
129  // Check if inplace
130  bool inplace;
131  switch (op_) {
132  case OP_ADD:
133  case OP_SUB:
134  case OP_MUL:
135  case OP_DIV:
136  inplace = res[0]==arg[0] && !arg_is_ref[0];
137  break;
138  default:
139  inplace = false;
140  break;
141  }
142 
143  // Scalar names of arguments (start assuming all scalars)
144  std::string r = g.workel(res[0]);
145  std::string x = g.workel(arg[0]);
146  std::string y = g.workel(arg[1]);
147 
148  // Avoid emitting '/*' which will be mistaken for a comment
149  if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
150  y = "(" + y + ")";
151  }
152 
153  // Codegen loop, if needed
154  if (nnz()>1) {
155  // Iterate over result
156  g.local("rr", "casadi_real", "*");
157  g.local("i", "casadi_int");
158  g << "for (i=0, " << "rr=" << g.work(res[0], nnz(), false);
159  r = "(*rr++)";
160 
161  // Iterate over first argument?
162  if (!ScX && !inplace) {
163  g.local("cr", "const casadi_real", "*");
164  g << ", cr=" << g.work(arg[0], dep(0).nnz(), arg_is_ref[0]);
165  if (op_==OP_OR || op_==OP_AND) {
166  // Avoid short-circuiting with side effects
167  x = "cr[i]";
168  } else {
169  x = "(*cr++)";
170  }
171 
172  }
173 
174  // Iterate over second argument?
175  if (!ScY) {
176  g.local("cs", "const casadi_real", "*");
177  g << ", cs=" << g.work(arg[1], dep(1).nnz(), arg_is_ref[1]);
178  if (op_==OP_OR || op_==OP_AND || op_==OP_IF_ELSE_ZERO) {
179  // Avoid short-circuiting with side effects
180  y = "cs[i]";
181  } else {
182  y = "(*cs++)";
183  }
184  }
185 
186  // Close loop
187  g << "; i<" << nnz() << "; ++i) ";
188  }
189 
190  // Perform operation
191  g << r << " ";
192  if (inplace) {
193  g << casadi_math<double>::sep(op_) << "= " << y;
194  } else {
195  g << " = " << g.print_op(op_, x, y);
196  }
197  g << ";\n";
198  }
199 
200  template<bool ScX, bool ScY>
201  int BinaryMX<ScX, ScY>::
202  eval(const double** arg, double** res, casadi_int* iw, double* w) const {
203  return eval_gen<double>(arg, res, iw, w);
204  }
205 
206  template<bool ScX, bool ScY>
207  int BinaryMX<ScX, ScY>::
208  eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
209  return eval_gen<SXElem>(arg, res, iw, w);
210  }
211 
212  template<bool ScX, bool ScY>
213  template<typename T>
214  int BinaryMX<ScX, ScY>::
215  eval_gen(const T* const* arg, T* const* res, casadi_int* iw, T* w) const {
216  // Get data
217  T* output0 = res[0];
218  const T* input0 = arg[0];
219  const T* input1 = arg[1];
220 
221  if (!ScX && !ScY) {
222  casadi_math<T>::fun(op_, input0, input1, output0, nnz());
223  } else if (ScX) {
224  casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
225  } else {
226  casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
227  }
228  return 0;
229  }
230 
231  template<bool ScX, bool ScY>
232  int BinaryMX<ScX, ScY>::
233  sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
234  const bvec_t *a0=arg[0], *a1=arg[1];
235  bvec_t *r=res[0];
236  casadi_int n=nnz();
237  for (casadi_int i=0; i<n; ++i) {
238  if (ScX && ScY)
239  *r++ = *a0 | *a1;
240  else if (ScX && !ScY)
241  *r++ = *a0 | *a1++;
242  else if (!ScX && ScY)
243  *r++ = *a0++ | *a1;
244  else
245  *r++ = *a0++ | *a1++;
246  }
247  return 0;
248  }
249 
250  template<bool ScX, bool ScY>
251  int BinaryMX<ScX, ScY>::
252  sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
253  bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
254  casadi_int n=nnz();
255  for (casadi_int i=0; i<n; ++i) {
256  bvec_t s = *r;
257  *r++ = 0;
258  if (ScX)
259  *a0 |= s;
260  else
261  *a0++ |= s;
262  if (ScY)
263  *a1 |= s;
264  else
265  *a1++ |= s;
266  }
267  return 0;
268  }
269 
270  template<bool ScX, bool ScY>
271  MX BinaryMX<ScX, ScY>::get_unary(casadi_int op) const {
272  //switch (op_) {
273  //default: break; // no rule
274  //}
275 
276  // Fallback to default implementation
277  return MXNode::get_unary(op);
278  }
279 
280  template<bool ScX, bool ScY>
281  MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op, const MX& y, bool scX, bool scY) const {
282  if (!GlobalOptions::simplification_on_the_fly) return MXNode::_get_binary(op, y, scX, scY);
283 
284  switch (op_) {
285  case OP_ADD:
286  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return dep(1);
287  if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
288  break;
289  case OP_SUB:
290  if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return -dep(1);
291  if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
292  break;
293  default: break; // no rule
294  }
295 
296  // Fallback to default implementation
297  return MXNode::_get_binary(op, y, scX, scY);
298  }
299 
300  template<bool ScX, bool ScY>
301  void BinaryMX<ScX, ScY>::serialize_body(SerializingStream& s) const {
302  MXNode::serialize_body(s);
303  s.pack("BinaryMX::op", static_cast<int>(op_));
304  }
305 
306  template<bool ScX, bool ScY>
307  void BinaryMX<ScX, ScY>::serialize_type(SerializingStream& s) const {
308  MXNode::serialize_type(s);
309  char type_x = ScX;
310  char type_y = ScY;
311  char type = type_x | (type_y << 1);
312  s.pack("BinaryMX::scalar_flags", type);
313  }
314 
315  template<bool ScX, bool ScY>
316  MXNode* BinaryMX<ScX, ScY>::deserialize(DeserializingStream& s) {
317  char t;
318  s.unpack("BinaryMX::scalar_flags", t);
319  bool scX = t & 1;
320  bool scY = t & 2;
321 
322  if (scX) {
323  if (scY) return new BinaryMX<true, true>(s);
324  return new BinaryMX<true, false>(s);
325  } else {
326  if (scY) return new BinaryMX<false, true>(s);
327  return new BinaryMX<false, false>(s);
328  }
329  }
330 
331  template<bool ScX, bool ScY>
332  BinaryMX<ScX, ScY>::BinaryMX(DeserializingStream& s) : MXNode(s) {
333  int op;
334  s.unpack("BinaryMX::op", op);
335  op_ = Operation(op);
336  }
337 
338  template<bool ScX, bool ScY>
339  MX BinaryMX<ScX, ScY>::get_solve_triu(const MX& r, bool tr) const {
340  // Identify systems with the structure (I - R)
341  if (!ScX && !ScY && op_ == OP_SUB) {
342  // Is the first term a projected unity matrix?
343  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
344  // Is the second term strictly lower triangular?
345  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_triu(true)) {
346  return dep(1).dep(0)->get_solve_triu_unity(r, tr);
347  }
348  }
349  }
350  // Fall back to default routine
351  return MXNode::get_solve_triu(r, tr);
352  }
353 
354  template<bool ScX, bool ScY>
355  MX BinaryMX<ScX, ScY>::get_solve_tril(const MX& r, bool tr) const {
356  // Identify systems with the structure (I - L)
357  if (!ScX && !ScY && op_ == OP_SUB) {
358  // Is the first term a projected unity matrix?
359  if (dep(0).is_op(OP_PROJECT) && dep(0).dep(0).is_eye()) {
360  // Is the second term strictly lower triangular?
361  if (dep(1).is_op(OP_PROJECT) && dep(1).dep(0).sparsity().is_tril(true)) {
362  return dep(1).dep(0)->get_solve_tril_unity(r, tr);
363  }
364  }
365  }
366  // Fall back to default routine
367  return MXNode::get_solve_tril(r, tr);
368  }
369 
370 } // namespace casadi
371 
372 #endif // CASADI_BINARY_MX_IMPL_HPP
The casadi namespace.
Definition: archiver.hpp:32