26 #ifndef CASADI_SPARSITY_CAST_HPP
27 #define CASADI_SPARSITY_CAST_HPP
29 #include "mx_node.hpp"
37 class CASADI_EXPORT SparsityCast :
public MXNode {
41 SparsityCast(
const MX& x, Sparsity sp);
44 ~SparsityCast()
override {}
48 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
51 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
54 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
59 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
64 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
65 std::vector<std::vector<MX> >& fsens)
const override;
70 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
71 std::vector<std::vector<MX> >& asens)
const override;
76 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
81 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
86 std::string disp(
const std::vector<std::string>& arg)
const override;
91 void generate(CodeGenerator& g,
92 const std::vector<casadi_int>& arg,
93 const std::vector<casadi_int>& res)
const override;
98 casadi_int op()
const override {
return OP_SPARSITY_CAST;}
101 casadi_int n_inplace()
const override {
return 1;}
104 MX get_reshape(
const Sparsity& sp)
const override;
109 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
112 MX get_sparsity_cast(
const Sparsity& sp)
const override;
117 bool is_equal(
const MXNode* node, casadi_int depth)
const override
118 {
return sameOpAndDeps(node, depth) && sparsity()==node->sparsity();}
121 MX get_transpose()
const override;
126 bool is_valid_input()
const override;
131 casadi_int n_primitives()
const override;
136 void primitives(std::vector<MX>::iterator& it)
const override;
140 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
146 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
147 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
148 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
153 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
159 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
160 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
161 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
167 bool has_duplicates()
const override;
172 void reset_input()
const override;
177 static MXNode* deserialize(DeserializingStream& s) {
return new SparsityCast(s); }
182 explicit SparsityCast(DeserializingStream& s) : MXNode(s) {}