26 #ifndef CASADI_GETNONZEROS_HPP
27 #define CASADI_GETNONZEROS_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT GetNonzeros :
public MXNode {
47 static MX create(
const Sparsity& sp,
const MX& x,
const std::vector<casadi_int>& nz);
48 static MX create(
const Sparsity& sp,
const MX& x,
const Slice& s);
49 static MX create(
const Sparsity& sp,
const MX& x,
const Slice& inner,
const Slice& outer);
53 GetNonzeros(
const Sparsity& sp,
const MX& y);
56 ~GetNonzeros()
override {}
61 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
66 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
67 std::vector<std::vector<MX> >& fsens)
const override;
72 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
73 std::vector<std::vector<MX> >& asens)
const override;
76 Matrix<casadi_int> mapping()
const override;
79 virtual std::vector<casadi_int> all()
const = 0;
84 casadi_int op()
const override {
return OP_GETNONZEROS;}
87 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
92 static MXNode* deserialize(DeserializingStream& s);
98 explicit GetNonzeros(DeserializingStream& s) : MXNode(s) {}
101 class CASADI_EXPORT GetNonzerosVector :
public GetNonzeros {
104 GetNonzerosVector(
const Sparsity& sp,
const MX& x,
105 const std::vector<casadi_int>& nz) : GetNonzeros(sp, x), nz_(nz) {}
108 ~GetNonzerosVector()
override {}
111 std::vector<casadi_int> all()
const override {
return nz_;}
116 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
121 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
126 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
130 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
133 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
136 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
141 std::string disp(
const std::vector<std::string>& arg)
const override;
146 void generate(CodeGenerator& g,
147 const std::vector<casadi_int>& arg,
148 const std::vector<casadi_int>& res)
const override;
153 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
156 Dict info()
const override {
return {{
"nz", nz_}}; }
159 std::vector<casadi_int> nz_;
164 void serialize_body(SerializingStream& s)
const override;
168 void serialize_type(SerializingStream& s)
const override;
173 explicit GetNonzerosVector(DeserializingStream& s);
177 class CASADI_EXPORT GetNonzerosSlice :
public GetNonzeros {
181 GetNonzerosSlice(
const Sparsity& sp,
const MX& x,
const Slice& s) : GetNonzeros(sp, x), s_(s) {}
184 ~GetNonzerosSlice()
override {}
187 std::vector<casadi_int> all()
const override {
return s_.all(s_.stop);}
192 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
197 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
201 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
204 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
207 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
212 std::string disp(
const std::vector<std::string>& arg)
const override;
217 void generate(CodeGenerator& g,
218 const std::vector<casadi_int>& arg,
219 const std::vector<casadi_int>& res)
const override;
224 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
227 Dict info()
const override {
return {{
"slice", s_.info()}}; }
235 void serialize_body(SerializingStream& s)
const override;
239 void serialize_type(SerializingStream& s)
const override;
244 explicit GetNonzerosSlice(DeserializingStream& s);
248 class CASADI_EXPORT GetNonzerosSlice2 :
public GetNonzeros {
252 GetNonzerosSlice2(
const Sparsity& sp,
const MX& x,
const Slice& inner,
253 const Slice& outer) : GetNonzeros(sp, x), inner_(inner), outer_(outer) {}
256 ~GetNonzerosSlice2()
override {}
259 std::vector<casadi_int> all()
const override {
return inner_.all(outer_, outer_.stop);}
264 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
269 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
273 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
276 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
279 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
284 std::string disp(
const std::vector<std::string>& arg)
const override;
289 void generate(CodeGenerator& g,
290 const std::vector<casadi_int>& arg,
291 const std::vector<casadi_int>& res)
const override;
296 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
299 Dict info()
const override {
return {{
"inner", inner_.info()}, {
"outer", outer_.info()}}; }
302 Slice inner_, outer_;
307 void serialize_body(SerializingStream& s)
const override;
311 void serialize_type(SerializingStream& s)
const override;
316 explicit GetNonzerosSlice2(DeserializingStream& s);
// Dict (GenericType::Dict): C++ equivalent of Python's dict or MATLAB's struct.