26 #ifndef CASADI_RESHAPE_HPP
27 #define CASADI_RESHAPE_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT Reshape :
public MXNode {
/// \brief Construct a Reshape node from the expression `x` and the Sparsity `sp`.
/// NOTE(review): the body is not in this header; presumably `sp` becomes this
/// node's sparsity and `x` its single dependency — confirm in the .cpp.
/// `sp` is taken by value.
46 Reshape(
const MX& x, Sparsity sp);
/// \brief Destructor. Overrides the base-class destructor and has an empty body,
/// so this class does no cleanup of its own here.
49 ~Reshape()
override {}
/// \brief Generic (non-virtual) evaluation helper over element type `T`.
/// Its argument layout (arg, res, iw, w) mirrors the `double` and `SXElem`
/// `eval`/`eval_sx` overrides; presumably those forward here — confirm in the .cpp.
/// NOTE(review): `T` is not introduced by any visible `template<typename T>`
/// line; one appears to have been lost before this declaration. Restore it
/// from upstream.
53 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
/// \brief Numeric evaluation override for `double` data.
/// Inputs are passed as `const double**` (read-only); outputs as `double**`.
/// Parameter names suggest `iw`/`w` are integer/real work space — not
/// verifiable from this header.
56 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/// \brief Evaluation override for `SXElem` data, using the same argument layout
/// as the `double` overload (read-only `arg`, writable `res`, work arrays `iw`/`w`).
59 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
/// \brief Evaluation override on `MX` expressions.
/// `arg` is read-only (const reference); results are written into `res`
/// (non-const reference).
64 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/// \brief Override that forwards directly to `eval_linear_rearrange(arg, res)`.
/// Each input/output entry is a `std::array` of three `MX`; the meaning of the
/// three slots is not shown in this header.
/// NOTE(review): the closing `}` of this inline body is not present in the
/// visible text (the next declaration follows immediately) — restore from upstream.
69 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
70 std::vector<std::array<MX, 3> >& res)
const override {
71 eval_linear_rearrange(arg, res);
/// \brief Overrides the base-class `ad_forward`.
/// The parameter names suggest forward-mode AD: seeds are read from `fseed`
/// (const) and sensitivities are written to `fsens`, one inner vector per
/// direction (presumably). The implementation is not in this header.
77 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
78 std::vector<std::vector<MX> >& fsens)
const override;
/// \brief Overrides the base-class `ad_reverse`.
/// The parameter names suggest reverse-mode AD: adjoint seeds are read from
/// `aseed` (const) and adjoint sensitivities are written to `asens`
/// (non-const). Whether `asens` is overwritten or accumulated is not visible
/// here — check the .cpp.
83 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
84 std::vector<std::vector<MX> >& asens)
const override;
/// \brief Forward propagation override over `bvec_t` bit vectors.
/// By name, this propagates sparsity/dependency bits. Inputs are read-only
/// (`const bvec_t** arg`); outputs are written through `res`.
89 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// \brief Reverse counterpart of `sp_forward` over `bvec_t` bit vectors.
/// Unlike `sp_forward`, `arg` is non-const here, presumably because reverse
/// propagation writes bits back into the inputs — confirm in the .cpp.
94 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// \brief Return a string representation of this node.
/// Takes the string forms of the operands in `arg` (presumably one per
/// dependency).
99 std::string disp(
const std::vector<std::string>& arg)
const override;
/// \brief Code-generation override, emitting through the CodeGenerator `g`.
/// `arg`/`res` are integer indices (presumably work-vector slots — confirm).
/// `arg_is_ref` is read-only. `res_is_ref` is an output (non-const reference)
/// that this method may set for each result.
104 void generate(CodeGenerator& g,
105 const std::vector<casadi_int>& arg,
106 const std::vector<casadi_int>& res,
107 const std::vector<bool>& arg_is_ref,
108 std::vector<bool>& res_is_ref)
const override;
/// \brief Operation code identifying this node type; always `OP_RESHAPE`.
113 casadi_int op()
const override {
return OP_RESHAPE;}
/// \brief Always returns 1.
/// By name, this is the number of inputs whose storage may be reused in place
/// for the output — confirm against the MXNode base documentation.
116 casadi_int n_inplace()
const override {
return 1;}
/// \brief Override for reshaping this node to the Sparsity `sp`; returns an MX.
/// NOTE(review): presumably this lets a reshape of a Reshape collapse into a
/// single node; the body is not in this header, so verify in the .cpp.
119 MX get_reshape(
const Sparsity& sp)
const override;
/// \brief Structural equality check against another node.
/// True only when `sameOpAndDeps(node, depth)` holds AND both nodes have the
/// same sparsity. The sparsity test is needed because two reshapes of the same
/// dependency can differ only in their target shape.
124 bool is_equal(
const MXNode* node, casadi_int depth)
const override
125 {
return sameOpAndDeps(node, depth) && sparsity()==node->sparsity();}
/// \brief Override returning the transpose of this node as an MX.
/// The implementation is not in this header.
128 MX get_transpose()
const override;
/// \brief Override reporting (bool) whether this node is acceptable as an
/// input; the exact criterion is defined in the .cpp (not shown).
133 bool is_valid_input()
const override;
/// \brief Override returning a count of primitives.
/// Presumably this matches how many entries `primitives()` emits — confirm.
138 casadi_int n_primitives()
const override;
/// \brief Override that writes primitives through the iterator `it`.
/// `it` is taken by reference, so the caller's position can be advanced.
143 void primitives(std::vector<MX>::iterator& it)
const override;
/// \brief Generic (non-virtual) helper that splits `x` of type `T` into
/// entries written through `it`, an iterator into a `std::vector<T>` taken
/// by reference.
/// NOTE(review): no `template<typename T>` line is visible for this
/// declaration; it appears to have been lost — restore from upstream.
147 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
/// \brief `split_primitives` overrides for `MX`, `SX` and `DM`.
/// Each signature matches `split_primitives_gen` with `T` = MX/SX/DM;
/// presumably each simply forwards to it — confirm in the .cpp.
153 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
154 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
155 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
/// \brief Generic (non-virtual) helper that builds a `T` from entries read via
/// `it`, a `std::vector<T>` const_iterator taken by reference (so the caller's
/// position can advance).
/// NOTE(review): no `template<typename T>` line is visible for this
/// declaration; it appears to have been lost — restore from upstream.
160 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
/// \brief `join_primitives` overrides for `MX`, `SX` and `DM`.
/// Each signature matches `join_primitives_gen` with `T` = MX/SX/DM;
/// presumably each forwards to it — confirm in the .cpp.
166 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
167 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
168 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
/// \brief Override returning whether duplicates were detected (bool).
/// What counts as a duplicate, and any side effects of the check, are defined
/// in the .cpp (not shown).
174 bool has_duplicates()
const override;
/// \brief Override declared `const` with no return value.
/// By name, it resets per-input state touched by `has_duplicates()`;
/// NOTE(review): confirm in the .cpp.
179 void reset_input()
const override;
/// \brief Static factory: heap-allocates a Reshape from the stream `s` via the
/// deserializing constructor and returns it as a raw `MXNode*`.
/// The caller receives an owning pointer; presumably the serialization
/// framework takes over its lifetime — confirm at the call site.
184 static MXNode* deserialize(DeserializingStream& s) {
return new Reshape(s); }
/// \brief Deserializing constructor. It is `explicit` (no implicit conversion
/// from a stream) and forwards `s` to `MXNode(s)`. The body is empty, so no
/// Reshape-specific fields are read here.
189 explicit Reshape(DeserializingStream& s) : MXNode(s) {}