26 #ifndef CASADI_RESHAPE_HPP
27 #define CASADI_RESHAPE_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT Reshape :
public MXNode {
46 Reshape(
const MX& x, Sparsity sp);
49 ~Reshape()
override {}
53 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
56 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
59 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
64 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
69 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
70 std::vector<std::vector<MX> >& fsens)
const override;
75 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
76 std::vector<std::vector<MX> >& asens)
const override;
81 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
86 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
91 std::string disp(
const std::vector<std::string>& arg)
const override;
96 void generate(CodeGenerator& g,
97 const std::vector<casadi_int>& arg,
98 const std::vector<casadi_int>& res)
const override;
103 casadi_int op()
const override {
return OP_RESHAPE;}
106 casadi_int n_inplace()
const override {
return 1;}
109 MX get_reshape(
const Sparsity& sp)
const override;
114 bool is_equal(
const MXNode* node, casadi_int depth)
const override
115 {
return sameOpAndDeps(node, depth) && sparsity()==node->sparsity();}
118 MX get_transpose()
const override;
123 bool is_valid_input()
const override;
128 casadi_int n_primitives()
const override;
133 void primitives(std::vector<MX>::iterator& it)
const override;
137 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
143 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
144 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
145 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
150 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
156 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
157 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
158 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
164 bool has_duplicates()
const override;
169 void reset_input()
const override;
174 static MXNode* deserialize(DeserializingStream& s) {
return new Reshape(s); }
179 explicit Reshape(DeserializingStream& s) : MXNode(s) {}