26 #ifndef CASADI_SPARSITY_CAST_HPP
27 #define CASADI_SPARSITY_CAST_HPP
29 #include "mx_node.hpp"
37 class CASADI_EXPORT SparsityCast :
public MXNode {
41 SparsityCast(
const MX& x, Sparsity sp);
44 ~SparsityCast()
override {}
48 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
51 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
54 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
59 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
60 std::vector<std::array<MX, 3> >& res)
const override {
61 eval_linear_rearrange(arg, res);
67 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
72 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
73 std::vector<std::vector<MX> >& fsens)
const override;
78 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
79 std::vector<std::vector<MX> >& asens)
const override;
84 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
89 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
94 std::string disp(
const std::vector<std::string>& arg)
const override;
99 void generate(CodeGenerator& g,
100 const std::vector<casadi_int>& arg,
101 const std::vector<casadi_int>& res,
102 const std::vector<bool>& arg_is_ref,
103 std::vector<bool>& res_is_ref)
const override;
108 casadi_int op()
const override {
return OP_SPARSITY_CAST;}
111 casadi_int n_inplace()
const override {
return 1;}
114 MX get_reshape(
const Sparsity& sp)
const override;
119 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
122 MX get_sparsity_cast(
const Sparsity& sp)
const override;
127 bool is_equal(
const MXNode* node, casadi_int depth)
const override
128 {
return sameOpAndDeps(node, depth) && sparsity()==node->sparsity();}
131 MX get_transpose()
const override;
136 bool is_valid_input()
const override;
141 casadi_int n_primitives()
const override;
146 void primitives(std::vector<MX>::iterator& it)
const override;
150 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
156 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
157 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
158 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
163 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
169 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
170 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
171 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
177 bool has_duplicates()
const override;
182 void reset_input()
const override;
187 static MXNode* deserialize(DeserializingStream& s) {
return new SparsityCast(s); }
192 explicit SparsityCast(DeserializingStream& s) : MXNode(s) {}