#ifndef CASADI_SETNONZEROS_HPP
#define CASADI_SETNONZEROS_HPP

#include "mx_node.hpp"

#include <string>
#include <vector>
44 class CASADI_EXPORT SetNonzeros :
public MXNode {
52 static MX create(
const MX& y,
const MX& x,
const std::vector<casadi_int>& nz);
53 static MX create(
const MX& y,
const MX& x,
const Slice& s);
54 static MX create(
const MX& y,
const MX& x,
const Slice& inner,
const Slice& outer);
58 SetNonzeros(
const MX& y,
const MX& x);
61 ~SetNonzeros()
override = 0;
64 virtual std::vector<casadi_int> all()
const = 0;
69 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
74 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
75 std::vector<std::vector<MX> >& fsens)
const override;
80 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
81 std::vector<std::vector<MX> >& asens)
const override;
86 casadi_int op()
const override {
return Add ? OP_ADDNONZEROS : OP_SETNONZEROS;}
89 Matrix<casadi_int> mapping()
const override;
92 casadi_int n_inplace()
const override {
return 1;}
97 static MXNode* deserialize(DeserializingStream& s);
103 explicit SetNonzeros(DeserializingStream& s) : MXNode(s) {}
114 class CASADI_EXPORT SetNonzerosVector :
public SetNonzeros<Add>{
118 SetNonzerosVector(
const MX& y,
const MX& x,
const std::vector<casadi_int>& nz);
121 ~SetNonzerosVector()
override {}
124 std::vector<casadi_int> all()
const override {
return nz_;}
129 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
133 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
136 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
139 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
144 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
149 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
154 std::string disp(
const std::vector<std::string>& arg)
const override;
159 void generate(CodeGenerator& g,
160 const std::vector<casadi_int>& arg,
161 const std::vector<casadi_int>& res)
const override;
166 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
169 Dict info()
const override {
return {{
"nz", nz_}, {
"add", Add}}; }
172 std::vector<casadi_int> nz_;
177 void serialize_body(SerializingStream& s)
const override;
181 void serialize_type(SerializingStream& s)
const override;
186 explicit SetNonzerosVector(DeserializingStream& s);
191 class CASADI_EXPORT SetNonzerosSlice :
public SetNonzeros<Add>{
195 SetNonzerosSlice(
const MX& y,
const MX& x,
const Slice& s) : SetNonzeros<Add>(y, x), s_(s) {}
198 ~SetNonzerosSlice()
override {}
201 std::vector<casadi_int> all()
const override {
return s_.all(s_.stop);}
206 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
211 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
216 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
220 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
223 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
226 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
231 std::string disp(
const std::vector<std::string>& arg)
const override;
236 void generate(CodeGenerator& g,
237 const std::vector<casadi_int>& arg,
238 const std::vector<casadi_int>& res)
const override;
243 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
246 Dict info()
const override {
return {{
"slice", s_.info()}, {
"add", Add}}; }
254 void serialize_body(SerializingStream& s)
const override;
258 void serialize_type(SerializingStream& s)
const override;
263 explicit SetNonzerosSlice(DeserializingStream& s);
268 class CASADI_EXPORT SetNonzerosSlice2 :
public SetNonzeros<Add>{
272 SetNonzerosSlice2(
const MX& y,
const MX& x,
const Slice& inner,
const Slice& outer) :
273 SetNonzeros<Add>(y, x), inner_(inner), outer_(outer) {}
276 ~SetNonzerosSlice2()
override {}
279 std::vector<casadi_int> all()
const override {
return inner_.all(outer_, outer_.stop);}
284 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
289 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
294 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
298 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
301 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
304 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
309 std::string disp(
const std::vector<std::string>& arg)
const override;
314 void generate(CodeGenerator& g,
315 const std::vector<casadi_int>& arg,
316 const std::vector<casadi_int>& res)
const override;
321 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
324 Dict info()
const override {
return {{
"inner", inner_.info()}, {
"outer", outer_.info()},
328 Slice inner_, outer_;
333 void serialize_body(SerializingStream& s)
const override;
337 void serialize_type(SerializingStream& s)
const override;
342 explicit SetNonzerosSlice2(DeserializingStream& s);
// Note: Dict (an alias of GenericType::Dict), used by the info() methods above,
// is the C++ equivalent of Python's dict or MATLAB's struct.