26 #ifndef CASADI_GETNONZEROS_HPP
27 #define CASADI_GETNONZEROS_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT GetNonzeros :
public MXNode {
47 static MX create(
const Sparsity& sp,
const MX& x,
const std::vector<casadi_int>& nz);
48 static MX create(
const Sparsity& sp,
const MX& x,
const Slice& s);
49 static MX create(
const Sparsity& sp,
const MX& x,
const Slice& inner,
const Slice& outer);
53 GetNonzeros(
const Sparsity& sp,
const MX& y);
56 ~GetNonzeros()
override {}
61 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
66 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
67 std::vector<std::array<MX, 3> >& res)
const override;
72 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
73 std::vector<std::vector<MX> >& fsens)
const override;
78 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
79 std::vector<std::vector<MX> >& asens)
const override;
82 Matrix<casadi_int> mapping()
const override;
85 virtual std::vector<casadi_int> all()
const = 0;
90 casadi_int op()
const override {
return OP_GETNONZEROS;}
93 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
98 static MXNode* deserialize(DeserializingStream& s);
104 explicit GetNonzeros(DeserializingStream& s) : MXNode(s) {}
107 class CASADI_EXPORT GetNonzerosVector :
public GetNonzeros {
110 GetNonzerosVector(
const Sparsity& sp,
const MX& x,
111 const std::vector<casadi_int>& nz) : GetNonzeros(sp, x), nz_(nz) {}
114 ~GetNonzerosVector()
override {}
117 std::vector<casadi_int> all()
const override {
return nz_;}
122 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
127 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
132 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
136 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
139 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
142 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
147 std::string disp(
const std::vector<std::string>& arg)
const override;
152 void generate(CodeGenerator& g,
153 const std::vector<casadi_int>& arg,
154 const std::vector<casadi_int>& res,
155 const std::vector<bool>& arg_is_ref,
156 std::vector<bool>& res_is_ref)
const override;
161 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
164 Dict info()
const override {
return {{
"nz", nz_}}; }
167 std::vector<casadi_int> nz_;
172 void serialize_body(SerializingStream& s)
const override;
176 void serialize_type(SerializingStream& s)
const override;
181 explicit GetNonzerosVector(DeserializingStream& s);
185 class CASADI_EXPORT GetNonzerosSlice :
public GetNonzeros {
189 GetNonzerosSlice(
const Sparsity& sp,
const MX& x,
const Slice& s) : GetNonzeros(sp, x), s_(s) {}
192 ~GetNonzerosSlice()
override {}
195 std::vector<casadi_int> all()
const override {
return s_.all(s_.stop);}
200 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
205 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
209 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
212 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
215 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
220 std::string disp(
const std::vector<std::string>& arg)
const override;
225 void generate(CodeGenerator& g,
226 const std::vector<casadi_int>& arg,
227 const std::vector<casadi_int>& res,
228 const std::vector<bool>& arg_is_ref,
229 std::vector<bool>& res_is_ref)
const override;
234 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
237 Dict info()
const override {
return {{
"slice", s_.info()}}; }
245 void serialize_body(SerializingStream& s)
const override;
249 void serialize_type(SerializingStream& s)
const override;
254 explicit GetNonzerosSlice(DeserializingStream& s);
258 class CASADI_EXPORT GetNonzerosSlice2 :
public GetNonzeros {
262 GetNonzerosSlice2(
const Sparsity& sp,
const MX& x,
const Slice& inner,
263 const Slice& outer) : GetNonzeros(sp, x), inner_(inner), outer_(outer) {}
266 ~GetNonzerosSlice2()
override {}
269 std::vector<casadi_int> all()
const override {
return inner_.all(outer_, outer_.stop);}
274 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
279 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
283 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
286 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
289 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
294 std::string disp(
const std::vector<std::string>& arg)
const override;
299 void generate(CodeGenerator& g,
300 const std::vector<casadi_int>& arg,
301 const std::vector<casadi_int>& res,
302 const std::vector<bool>& arg_is_ref,
303 std::vector<bool>& res_is_ref)
const override;
308 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
311 Dict info()
const override {
return {{
"inner", inner_.info()}, {
"outer", outer_.info()}}; }
314 Slice inner_, outer_;
319 void serialize_body(SerializingStream& s)
const override;
323 void serialize_type(SerializingStream& s)
const override;
328 explicit GetNonzerosSlice2(DeserializingStream& s);
/// Note: `Dict` (returned by info() above) is GenericType::Dict,
/// the C++ equivalent of Python's dict or MATLAB's struct.