reshape.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "reshape.hpp"
27 #include "casadi_misc.hpp"
28 
29 namespace casadi {
30 
31  Reshape::Reshape(const MX& x, Sparsity sp) {
32  casadi_assert_dev(x.nnz()==sp.nnz());
33  set_dep(x);
34  set_sparsity(sp);
35  }
36 
37  int Reshape::eval(const double** arg, double** res, casadi_int* iw, double* w) const {
38  return eval_gen<double>(arg, res, iw, w);
39  }
40 
41  int Reshape::eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
42  return eval_gen<SXElem>(arg, res, iw, w);
43  }
44 
45  template<typename T>
46  int Reshape::eval_gen(const T** arg, T** res, casadi_int* iw, T* w) const {
47  if (arg[0]!=res[0]) std::copy(arg[0], arg[0]+nnz(), res[0]);
48  return 0;
49  }
50 
51  int Reshape::sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
52  copy_fwd(arg[0], res[0], nnz());
53  return 0;
54  }
55 
56  int Reshape::sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
57  copy_rev(arg[0], res[0], nnz());
58  return 0;
59  }
60 
61  std::string Reshape::disp(const std::vector<std::string>& arg) const {
62  // For vectors, reshape is also a transpose
63  if (dep().is_vector() && sparsity().is_vector()) {
64  // Print as transpose: X'
65  return arg.at(0) + "'";
66  } else {
67  // Print as reshape(X) or vec(X)
68  if (sparsity().is_column()) {
69  return "vec(" + arg.at(0) + ")";
70  } else {
71  return "reshape(" + arg.at(0) + ")";
72  }
73  }
74  }
75 
76  void Reshape::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const {
77  res[0] = reshape(arg[0], size());
78  }
79 
80  void Reshape::ad_forward(const std::vector<std::vector<MX> >& fseed,
81  std::vector<std::vector<MX> >& fsens) const {
82  for (casadi_int d = 0; d<fsens.size(); ++d) {
83  fsens[d][0] = reshape(fseed[d][0], size());
84  }
85  }
86 
87  void Reshape::ad_reverse(const std::vector<std::vector<MX> >& aseed,
88  std::vector<std::vector<MX> >& asens) const {
89  for (casadi_int d=0; d<aseed.size(); ++d) {
90  asens[d][0] += reshape(aseed[d][0], dep().size());
91  }
92  }
93 
95  const std::vector<casadi_int>& arg,
96  const std::vector<casadi_int>& res,
97  const std::vector<bool>& arg_is_ref,
98  std::vector<bool>& res_is_ref) const {
99  generate_copy(g, arg, res, arg_is_ref, res_is_ref, 0);
100  }
101 
102  MX Reshape::get_reshape(const Sparsity& sp) const {
103  return reshape(dep(0), sp);
104  }
105 
107  // For vectors, reshape is also a transpose
108  if (dep().is_vector() && sparsity().is_vector()) {
109  return dep();
110  } else {
111  return MXNode::get_transpose();
112  }
113  }
114 
115  bool Reshape::is_valid_input() const {
116  return dep()->is_valid_input();
117  }
118 
119  casadi_int Reshape::n_primitives() const {
120  return dep()->n_primitives();
121  }
122 
123  void Reshape::primitives(std::vector<MX>::iterator& it) const {
124  dep()->primitives(it);
125  }
126 
127  template<typename T>
128  void Reshape::split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const {
129  dep()->split_primitives(reshape(x, dep().size()), it);
130  }
131 
132  void Reshape::split_primitives(const MX& x, std::vector<MX>::iterator& it) const {
133  split_primitives_gen<MX>(x, it);
134  }
135 
136  void Reshape::split_primitives(const SX& x, std::vector<SX>::iterator& it) const {
137  split_primitives_gen<SX>(x, it);
138  }
139 
140  void Reshape::split_primitives(const DM& x, std::vector<DM>::iterator& it) const {
141  split_primitives_gen<DM>(x, it);
142  }
143 
144  template<typename T>
145  T Reshape::join_primitives_gen(typename std::vector<T>::const_iterator& it) const {
146  return reshape(dep()->join_primitives(it), size());
147  }
148 
149  MX Reshape::join_primitives(std::vector<MX>::const_iterator& it) const {
150  return join_primitives_gen<MX>(it);
151  }
152 
153  SX Reshape::join_primitives(std::vector<SX>::const_iterator& it) const {
154  return join_primitives_gen<SX>(it);
155  }
156 
157  DM Reshape::join_primitives(std::vector<DM>::const_iterator& it) const {
158  return join_primitives_gen<DM>(it);
159  }
160 
161  bool Reshape::has_duplicates() const {
162  return dep()->has_duplicates();
163  }
164 
165  void Reshape::reset_input() const {
166  dep()->reset_input();
167  }
168 
169 } // namespace casadi
Helper class for C code generation.
casadi_int nnz() const
Get the number of (structural) non-zero elements.
virtual void reset_input() const
Reset the marker for an input expression.
Definition: mx_node.cpp:150
virtual casadi_int n_primitives() const
Get the number of symbolic primitives.
Definition: mx_node.cpp:142
static void copy_fwd(const bvec_t *arg, bvec_t *res, casadi_int len)
Propagate sparsities forward through a copy operation.
Definition: mx_node.cpp:1241
virtual bool has_duplicates() const
Detect duplicate symbolic expressions.
Definition: mx_node.cpp:146
virtual bool is_valid_input() const
Check if valid function input.
Definition: mx_node.hpp:231
static void copy_rev(bvec_t *arg, bvec_t *res, casadi_int len)
Propagate sparsities backwards through a copy operation.
Definition: mx_node.cpp:1247
std::pair< casadi_int, casadi_int > size() const
Definition: mx_node.hpp:392
void generate_copy(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref, casadi_int i) const
Definition: mx_node.cpp:457
const Sparsity & sparsity() const
Get the sparsity.
Definition: mx_node.hpp:372
casadi_int nnz(casadi_int i=0) const
Definition: mx_node.hpp:389
const MX & dep(casadi_int ind=0) const
dependencies - functions that have to be evaluated before this one
Definition: mx_node.hpp:354
virtual void primitives(std::vector< MX >::iterator &it) const
Get symbolic primitives.
Definition: mx_node.cpp:154
void set_sparsity(const Sparsity &sparsity)
Set the sparsity.
Definition: mx_node.cpp:222
virtual MX get_transpose() const
Transpose.
Definition: mx_node.cpp:484
void set_dep(const MX &dep)
Set unary dependency.
Definition: mx_node.cpp:226
virtual void split_primitives(const MX &x, std::vector< MX >::iterator &it) const
Split up an expression along symbolic primitives.
Definition: mx_node.cpp:158
MX - Matrix expression.
Definition: mx.hpp:92
Sparse matrix class. SX and DM are specializations.
Definition: matrix_decl.hpp:99
MX get_reshape(const Sparsity &sp) const override
Reshape.
Definition: reshape.cpp:102
int eval_gen(const T **arg, T **res, casadi_int *iw, T *w) const
Evaluate the function (template)
Definition: reshape.cpp:46
bool has_duplicates() const override
Detect duplicate symbolic expressions.
Definition: reshape.cpp:161
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
Definition: reshape.cpp:51
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
Definition: reshape.cpp:37
T join_primitives_gen(typename std::vector< T >::const_iterator &it) const
Join an expression along symbolic primitives (template)
Definition: reshape.cpp:145
MX join_primitives(std::vector< MX >::const_iterator &it) const override
Join an expression along symbolic primitives.
Definition: reshape.cpp:149
void split_primitives_gen(const T &x, typename std::vector< T >::iterator &it) const
Split up an expression along primitives (template)
Definition: reshape.cpp:128
void primitives(std::vector< MX >::iterator &it) const override
Get symbolic primitives.
Definition: reshape.cpp:123
void reset_input() const override
Reset the marker for an input expression.
Definition: reshape.cpp:165
Reshape(const MX &x, Sparsity sp)
Constructor.
Definition: reshape.cpp:31
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
Definition: reshape.cpp:80
MX get_transpose() const override
Transpose (if a dimension is one)
Definition: reshape.cpp:106
bool is_valid_input() const override
Check if valid function input.
Definition: reshape.cpp:115
casadi_int n_primitives() const override
Get the number of symbolic primitives.
Definition: reshape.cpp:119
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
Definition: reshape.cpp:41
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
Definition: reshape.cpp:61
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
Definition: reshape.cpp:56
void split_primitives(const MX &x, std::vector< MX >::iterator &it) const override
Split up an expression along symbolic primitives.
Definition: reshape.cpp:132
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
Definition: reshape.cpp:94
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
Definition: reshape.cpp:76
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
Definition: reshape.cpp:87
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
General sparsity class.
Definition: sparsity.hpp:106
casadi_int nnz() const
Get the number of (structural) non-zeros.
Definition: sparsity.cpp:148
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t