mx_node.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_MX_NODE_HPP
27 #define CASADI_MX_NODE_HPP
28 
29 #include "mx.hpp"
30 #include "shared_object.hpp"
31 #include "sx_elem.hpp"
32 #include "calculus.hpp"
33 #include "code_generator.hpp"
34 #include "linsol.hpp"
35 #include <vector>
36 #include <stack>
37 #include <array>
38 
39 namespace casadi {
40 
41  class SerializingStream;
42  class DeserializingStream;
43 
51  class CASADI_EXPORT MXNode : public SharedObjectInternal {
52  friend class MX;
53 
54  public:
56  MXNode();
57 
61  ~MXNode() override=0;
62 
66  virtual bool __nonzero__() const;
67 
71  virtual bool is_zero() const { return false;}
72 
76  virtual bool is_one() const { return false;}
77 
81  virtual bool is_minus_one() const { return false;}
82 
86  virtual bool is_value(double val) const { return false;}
87 
91  virtual bool is_eye() const { return false;}
92 
96  virtual bool is_unary() const { return false;}
97 
101  virtual bool is_binary() const { return false;}
102 
106  void can_inline(std::map<const MXNode*, casadi_int>& nodeind) const;
107 
111  std::string print_compact(std::map<const MXNode*, casadi_int>& nodeind,
112  std::vector<std::string>& intermed) const;
113 
117  virtual std::string disp(const std::vector<std::string>& arg) const = 0;
118 
122  virtual void add_dependency(CodeGenerator& g) const {}
123 
127  virtual bool has_refcount() const { return false;}
128 
132  virtual void codegen_incref(CodeGenerator& g, std::set<void*>& added) const {}
133 
137  virtual void codegen_decref(CodeGenerator& g, std::set<void*>& added) const {}
138 
142  virtual void generate(CodeGenerator& g,
143  const std::vector<casadi_int>& arg,
144  const std::vector<casadi_int>& res,
145  const std::vector<bool>& arg_is_ref,
146  std::vector<bool>& res_is_ref) const;
147 
148  void generate_copy(CodeGenerator& g,
149  const std::vector<casadi_int>& arg,
150  const std::vector<casadi_int>& res,
151  const std::vector<bool>& arg_is_ref,
152  std::vector<bool>& res_is_ref,
153  casadi_int i) const;
154 
158  virtual int eval(const double** arg, double** res, casadi_int* iw, double* w) const;
159 
163  virtual int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const;
164 
168  virtual void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const;
169 
173  virtual void eval_linear(const std::vector<std::array<MX, 3> >& arg,
174  std::vector<std::array<MX, 3> >& res) const;
175 
179  void eval_linear_unary(const std::vector<std::array<MX, 3> >& arg,
180  std::vector<std::array<MX, 3> >& res) const;
181 
188  void eval_linear_rearrange(const std::vector<std::array<MX, 3> >& arg,
189  std::vector<std::array<MX, 3> >& res) const;
190 
194  virtual void ad_forward(const std::vector<std::vector<MX> >& fseed,
195  std::vector<std::vector<MX> >& fsens) const;
196 
200  virtual void ad_reverse(const std::vector<std::vector<MX> >& aseed,
201  std::vector<std::vector<MX> >& asens) const;
202 
206  virtual int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const;
207 
211  virtual int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const;
212 
216  virtual const std::string& name() const;
217 
221  std::string class_name() const override;
222 
226  void disp(std::ostream& stream, bool more) const override;
227 
231  virtual bool is_valid_input() const { return false;}
232 
236  virtual casadi_int n_primitives() const;
237 
241  virtual void primitives(std::vector<MX>::iterator& it) const;
242 
244 
247  virtual void split_primitives(const MX& x, std::vector<MX>::iterator& it) const;
248  virtual void split_primitives(const SX& x, std::vector<SX>::iterator& it) const;
249  virtual void split_primitives(const DM& x, std::vector<DM>::iterator& it) const;
251 
253  template<typename T>
254  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
255 
257 
260  virtual MX join_primitives(std::vector<MX>::const_iterator& it) const;
261  virtual SX join_primitives(std::vector<SX>::const_iterator& it) const;
262  virtual DM join_primitives(std::vector<DM>::const_iterator& it) const;
264 
270  virtual bool has_duplicates() const;
271 
277  virtual void reset_input() const;
278 
282  virtual bool is_output() const {return false;}
283 
287  virtual bool has_output() const {return false;}
288 
292  virtual casadi_int which_output() const;
293 
297  virtual const Function& which_function() const;
298 
302  virtual casadi_int op() const = 0;
303 
305  virtual Dict info() const;
306 
310  void serialize(SerializingStream& s) const;
311 
315  virtual void serialize_body(SerializingStream& s) const;
316 
323  virtual void serialize_type(SerializingStream& s) const;
324 
331  static MXNode* deserialize(DeserializingStream& s);
332 
336  static bool is_equal(const MXNode* x, const MXNode* y, casadi_int depth);
337  virtual bool is_equal(const MXNode* node, casadi_int depth) const { return false;}
338 
342  inline static bool maxDepth() { return MX::get_max_depth();}
343 
349  bool sameOpAndDeps(const MXNode* node, casadi_int depth) const;
350 
354  const MX& dep(casadi_int ind=0) const { return dep_.at(ind);}
355 
359  casadi_int n_dep() const;
360 
364  virtual casadi_int nout() const { return 1;}
365 
369  virtual MX get_output(casadi_int oind) const;
370 
372  const Sparsity& sparsity() const { return sparsity_;}
373 
375  virtual const Sparsity& sparsity(casadi_int oind) const;
376 
377  template<class T>
378  bool matches_sparsity(const std::vector<T>& arg) const {
379  for (casadi_int i=0;i<dep_.size();++i) {
380  if (dep_[i].sparsity()!=arg[i].sparsity()) {
381  return false;
382  }
383  }
384  return true;
385  }
386 
388  casadi_int numel() const { return sparsity().numel(); }
389  casadi_int nnz(casadi_int i=0) const { return sparsity(i).nnz(); }
390  casadi_int size1() const { return sparsity().size1(); }
391  casadi_int size2() const { return sparsity().size2(); }
392  std::pair<casadi_int, casadi_int> size() const { return sparsity().size();}
393 
394  // Get IO index
395  virtual casadi_int ind() const;
396 
397  // Get IO segment
398  virtual casadi_int segment() const;
399 
400  // Get IO offset
401  virtual casadi_int offset() const;
402 
404  void set_sparsity(const Sparsity& sparsity);
405 
409  virtual size_t sz_arg() const { return n_dep();}
410 
414  virtual size_t sz_res() const { return nout();}
415 
419  virtual size_t sz_iw() const { return 0;}
420 
424  virtual size_t sz_w() const { return 0;}
425 
427  void set_dep(const MX& dep);
428 
430  void set_dep(const MX& dep1, const MX& dep2);
431 
433  void set_dep(const MX& dep1, const MX& dep2, const MX& dep3);
434 
436  void set_dep(const std::vector<MX>& dep);
437 
439  void check_dep() const;
440 
442  inline static MX to_matrix(const MX& x, const Sparsity& sp) {
443  if (x.size()==sp.size()) {
444  return x;
445  } else {
446  return MX(sp, x);
447  }
448  }
449 
451  virtual double to_double() const;
452 
454  virtual DM get_DM() const;
455 
457  virtual casadi_int n_inplace() const { return 0;}
458 
460  virtual Matrix<casadi_int> mapping() const;
461 
463  virtual MX get_horzcat(const std::vector<MX>& x) const;
464 
466  virtual std::vector<MX> get_horzsplit(const std::vector<casadi_int>& output_offset) const;
467 
469  virtual MX get_repmat(casadi_int m, casadi_int n) const;
470 
472  virtual MX get_repsum(casadi_int m, casadi_int n) const;
473 
475  virtual MX get_vertcat(const std::vector<MX>& x) const;
476 
478  virtual std::vector<MX> get_vertsplit(const std::vector<casadi_int>& output_offset) const;
479 
481  virtual MX get_diagcat(const std::vector<MX>& x) const;
482 
484  virtual std::vector<MX> get_diagsplit(const std::vector<casadi_int>& offset1,
485  const std::vector<casadi_int>& offset2) const;
486 
488  virtual MX get_transpose() const;
489 
491  virtual MX get_reshape(const Sparsity& sp) const;
492 
494  virtual MX get_sparsity_cast(const Sparsity& sp) const;
495 
499  virtual MX get_mac(const MX& y, const MX& z) const;
500 
504  virtual MX get_einstein(const MX& A, const MX& B,
505  const std::vector<casadi_int>& dim_c, const std::vector<casadi_int>& dim_a,
506  const std::vector<casadi_int>& dim_b,
507  const std::vector<casadi_int>& c, const std::vector<casadi_int>& a,
508  const std::vector<casadi_int>& b) const;
509 
513  virtual MX get_bilin(const MX& x, const MX& y) const;
514 
518  virtual MX get_rank1(const MX& alpha, const MX& x, const MX& y) const;
519 
523  virtual MX get_logsumexp() const;
524 
532  virtual MX get_solve(const MX& r, bool tr, const Linsol& linear_solver) const;
533 
541  virtual MX get_solve_triu(const MX& r, bool tr) const;
542 
550  virtual MX get_solve_tril(const MX& r, bool tr) const;
551 
559  virtual MX get_solve_triu_unity(const MX& r, bool tr) const;
560 
568  virtual MX get_solve_tril_unity(const MX& r, bool tr) const;
569 
577  virtual MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const;
578 
582  virtual MX get_nz_ref(const MX& nz) const;
583 
587  virtual MX get_nz_ref(const MX& inner, const Slice& outer) const;
588 
592  virtual MX get_nz_ref(const Slice& inner, const MX& outer) const;
593 
597  virtual MX get_nz_ref(const MX& inner, const MX& outer) const;
598 
605  virtual MX get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const;
606 
613  virtual MX get_nzadd(const MX& y, const std::vector<casadi_int>& nz) const;
614 
621  virtual MX get_nzassign(const MX& y, const MX& nz) const;
622 
629  virtual MX get_nzassign(const MX& y, const MX& inner, const Slice& outer) const;
630 
637  virtual MX get_nzassign(const MX& y, const Slice& inner, const MX& outer) const;
638 
645  virtual MX get_nzassign(const MX& y, const MX& inner, const MX& outer) const;
646 
653  virtual MX get_nzadd(const MX& y, const MX& nz) const;
654 
661  virtual MX get_nzadd(const MX& y, const MX& inner, const Slice& outer) const;
662 
669  virtual MX get_nzadd(const MX& y, const Slice& inner, const MX& outer) const;
670 
677  virtual MX get_nzadd(const MX& y, const MX& inner, const MX& outer) const;
678 
680  virtual MX get_subref(const Slice& i, const Slice& j) const;
681 
683  virtual MX get_subassign(const MX& y, const Slice& i, const Slice& j) const;
684 
686  virtual MX get_project(const Sparsity& sp) const;
687 
689  virtual MX get_unary(casadi_int op) const;
690 
692  MX get_binary(casadi_int op, const MX& y) const;
693 
695  virtual MX _get_binary(casadi_int op, const MX& y, bool scX, bool scY) const;
696 
698  virtual MX get_det() const;
699 
701  virtual MX get_inv() const;
702 
704  virtual MX get_dot(const MX& y) const;
705 
707  virtual MX get_norm_fro() const;
708 
710  virtual MX get_norm_2() const;
711 
713  virtual MX get_norm_inf() const;
714 
716  virtual MX get_norm_1() const;
717 
719  virtual MX get_mmin() const;
720 
722  virtual MX get_mmax() const;
723 
725  MX get_assert(const MX& y, const std::string& fail_message) const;
726 
728  MX get_monitor(const std::string& comment) const;
729 
731  MX get_find() const;
732 
734  MX get_low(const MX& v, const Dict& options) const;
735 
737  MX get_bspline(const std::vector<double>& knots,
738  const std::vector<casadi_int>& offset,
739  const std::vector<double>& coeffs,
740  const std::vector<casadi_int>& degree,
741  casadi_int m,
742  const std::vector<casadi_int>& lookup_mode) const;
744  MX get_bspline(const MX& coeffs, const std::vector<double>& knots,
745  const std::vector<casadi_int>& offset,
746  const std::vector<casadi_int>& degree,
747  casadi_int m,
748  const std::vector<casadi_int>& lookup_mode) const;
749 
751  MX get_convexify(const Dict& opts) const;
752 
757  mutable casadi_int temp;
758 
762  std::vector<MX> dep_;
763 
768 
772  static void copy_fwd(const bvec_t* arg, bvec_t* res, casadi_int len);
773 
777  static void copy_rev(bvec_t* arg, bvec_t* res, casadi_int len);
778 
779  static std::map<casadi_int, MXNode* (*)(DeserializingStream&)> deserialize_map;
780 
781  protected:
785  explicit MXNode(DeserializingStream& s);
786  };
787 
789 } // namespace casadi
790 
791 #endif // CASADI_MX_NODE_HPP
Helper class for C code generation.
Helper class for Serialization.
Function object.
Definition: function.hpp:60
std::pair< casadi_int, casadi_int > size() const
Get the shape.
Linear solver.
Definition: linsol.hpp:55
Node class for MX objects.
Definition: mx_node.hpp:51
virtual bool has_output() const
Check if a multiple output node.
Definition: mx_node.hpp:287
void eval_linear_unary(const std::vector< std::array< MX, 3 > > &arg, std::vector< std::array< MX, 3 > > &res) const
Evaluate the MX node on a const/linear/nonlinear partition.
virtual bool is_zero() const
Check if identically zero.
Definition: mx_node.hpp:71
virtual size_t sz_arg() const
Get required length of arg field.
Definition: mx_node.hpp:409
virtual bool is_valid_input() const
Check if valid function input.
Definition: mx_node.hpp:231
static bool maxDepth()
Get equality checking depth.
Definition: mx_node.hpp:342
virtual size_t sz_w() const
Get required length of w field.
Definition: mx_node.hpp:424
virtual bool is_one() const
Check if identically one.
Definition: mx_node.hpp:76
virtual bool is_binary() const
Check if binary operation.
Definition: mx_node.hpp:101
virtual void add_dependency(CodeGenerator &g) const
Add a dependent function.
Definition: mx_node.hpp:122
virtual casadi_int n_inplace() const
Can the operation be performed in place (i.e., overwrite the result)?
Definition: mx_node.hpp:457
std::pair< casadi_int, casadi_int > size() const
Definition: mx_node.hpp:392
Sparsity sparsity_
The sparsity pattern.
Definition: mx_node.hpp:767
casadi_int temp
Definition: mx_node.hpp:757
casadi_int numel() const
Get shape.
Definition: mx_node.hpp:388
static std::map< casadi_int, MXNode *(*)(DeserializingStream &)> deserialize_map
Definition: mx_node.hpp:779
const Sparsity & sparsity() const
Get the sparsity.
Definition: mx_node.hpp:372
virtual size_t sz_res() const
Get required length of res field.
Definition: mx_node.hpp:414
casadi_int size2() const
Definition: mx_node.hpp:391
casadi_int nnz(casadi_int i=0) const
Definition: mx_node.hpp:389
bool matches_sparsity(const std::vector< T > &arg) const
Definition: mx_node.hpp:378
virtual void codegen_incref(CodeGenerator &g, std::set< void * > &added) const
Codegen incref.
Definition: mx_node.hpp:132
virtual casadi_int nout() const
Number of outputs.
Definition: mx_node.hpp:364
virtual bool is_value(double val) const
Check if a certain value.
Definition: mx_node.hpp:86
const MX & dep(casadi_int ind=0) const
Dependencies: expressions that must be evaluated before this one.
Definition: mx_node.hpp:354
std::vector< MX > dep_
Dependencies: expressions that must be evaluated before this one.
Definition: mx_node.hpp:762
virtual bool is_unary() const
Check if unary operation.
Definition: mx_node.hpp:96
virtual void codegen_decref(CodeGenerator &g, std::set< void * > &added) const
Codegen decref.
Definition: mx_node.hpp:137
virtual casadi_int op() const =0
Get the operation.
casadi_int size1() const
Definition: mx_node.hpp:390
virtual bool has_refcount() const
Is reference counting needed in codegen?
Definition: mx_node.hpp:127
virtual bool is_minus_one() const
Check if identically minus one.
Definition: mx_node.hpp:81
virtual std::string disp(const std::vector< std::string > &arg) const =0
Print expression.
virtual bool is_output() const
Check if evaluation output.
Definition: mx_node.hpp:282
virtual bool is_equal(const MXNode *node, casadi_int depth) const
Definition: mx_node.hpp:337
virtual size_t sz_iw() const
Get required length of iw field.
Definition: mx_node.hpp:419
virtual bool is_eye() const
Check if identity matrix.
Definition: mx_node.hpp:91
static MX to_matrix(const MX &x, const Sparsity &sp)
Convert scalar to matrix.
Definition: mx_node.hpp:442
MX - Matrix expression.
Definition: mx.hpp:92
static casadi_int get_max_depth()
Get the depth to which equalities are being checked for simplifications.
Definition: mx.cpp:913
Sparse matrix class. SX and DM are specializations.
Definition: matrix_decl.hpp:99
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
Class representing a Slice.
Definition: slice.hpp:48
General sparsity class.
Definition: sparsity.hpp:106
std::pair< casadi_int, casadi_int > size() const
Get the shape.
Definition: sparsity.cpp:152
The casadi namespace.
Definition: archiver.cpp:28
bool is_equal(double x, double y, casadi_int depth=0)
Definition: calculus.hpp:281
unsigned long long bvec_t
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.