26 #ifndef CASADI_SETNONZEROS_HPP
27 #define CASADI_SETNONZEROS_HPP
29 #include "mx_node.hpp"
// SetNonzeros: abstract MXNode subclass that declares the interface shared by
// the SetNonzerosVector / SetNonzerosSlice / SetNonzerosSlice2 classes below.
// op() reports OP_ADDNONZEROS or OP_SETNONZEROS depending on `Add`.
// NOTE(review): `Add` is used in this class but never declared in the visible
// text (presumably a `template<bool Add>` header was dropped) -- confirm.
// NOTE(review): the integer at the start of many lines looks like a line number
// carried over from a rendered listing; it is not valid C++ -- confirm and strip.
44 class CASADI_EXPORT SetNonzeros :
public MXNode {
// Factory overload: operands y and x plus an explicit vector of indices (nz).
52 static MX create(
const MX& y,
const MX& x,
const std::vector<casadi_int>& nz);
// Factory overload: operands y and x plus a single Slice.
53 static MX create(
const MX& y,
const MX& x,
const Slice& s);
// Factory overload: operands y and x plus an inner/outer pair of Slices.
54 static MX create(
const MX& y,
const MX& x,
const Slice& inner,
const Slice& outer);
// Constructor taking the two MX operands; defined out of line.
58 SetNonzeros(
const MX& y,
const MX& x);
// Pure virtual destructor: makes the class abstract. The language still
// requires a definition for it, which is not visible in this listing.
61 ~SetNonzeros()
override = 0;
// Returns the indices as a vector; every subclass below overrides this.
64 virtual std::vector<casadi_int> all()
const = 0;
// MX evaluation: reads from arg and writes to res; defined out of line.
69 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
// Forwards directly to eval_linear_rearrange(arg, res).
// NOTE(review): the closing brace of this inline body is missing here.
74 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
75 std::vector<std::array<MX, 3> >& res)
const override {
76 eval_linear_rearrange(arg, res);
// Takes seeds (fseed) and fills sensitivities (fsens); presumably forward-mode
// derivative propagation, judging by the name -- defined out of line.
82 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
83 std::vector<std::vector<MX> >& fsens)
const override;
// Takes seeds (aseed) and fills sensitivities (asens); presumably reverse-mode
// (adjoint) derivative propagation, judging by the name -- defined out of line.
88 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
89 std::vector<std::vector<MX> >& asens)
const override;
// Operation code: OP_ADDNONZEROS when Add is true, otherwise OP_SETNONZEROS.
94 casadi_int op()
const override {
return Add ? OP_ADDNONZEROS : OP_SETNONZEROS;}
// Returns a Matrix<casadi_int>; defined out of line.
97 Matrix<casadi_int> mapping()
const override;
// Always reports 1.
100 casadi_int n_inplace()
const override {
return 1;}
// Static entry point that builds an MXNode from a DeserializingStream.
105 static MXNode* deserialize(DeserializingStream& s);
// Deserializing constructor: hands the stream straight to the MXNode base.
111 explicit SetNonzeros(DeserializingStream& s) : MXNode(s) {}
// SetNonzerosVector: SetNonzeros<Add> subclass whose indices are held
// explicitly in the std::vector<casadi_int> member nz_.
// NOTE(review): `Add` (and `T` in eval_gen) are not declared in the visible
// text; the template headers appear to have been dropped -- confirm.
122 class CASADI_EXPORT SetNonzerosVector :
public SetNonzeros<Add>{
// Constructor taking operands y and x and the index vector nz; defined out of line.
126 SetNonzerosVector(
const MX& y,
const MX& x,
const std::vector<casadi_int>& nz);
// Empty destructor body.
129 ~SetNonzerosVector()
override {}
// Returns a copy of the stored index vector nz_.
132 std::vector<casadi_int> all()
const override {
return nz_;}
// MX evaluation: reads from arg and writes to res; defined out of line.
137 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
// Evaluation routine written against an element type T; not marked override.
// Presumably the shared implementation behind eval and eval_sx -- confirm.
141 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
// Evaluation on double buffers, with iw and w work arrays.
144 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
// Evaluation on SXElem buffers.
147 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
// Operates on bvec_t buffers with const inputs; presumably forward sparsity
// propagation, judging by the name.
152 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// Operates on bvec_t buffers; arg is non-const here, so it may be written to.
// Presumably reverse sparsity propagation, judging by the name.
157 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// Builds a string from the string forms of the arguments.
162 std::string disp(
const std::vector<std::string>& arg)
const override;
// Code-generation hook: receives the CodeGenerator, integer vectors arg and
// res, and per-entry reference flags (res_is_ref is an output).
167 void generate(CodeGenerator& g,
168 const std::vector<casadi_int>& arg,
169 const std::vector<casadi_int>& res,
170 const std::vector<bool>& arg_is_ref,
171 std::vector<bool>& res_is_ref)
const override;
// Compares this node with another MXNode, given a depth argument.
176 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
// Returns a Dict with two entries: "nz" -> nz_ and "add" -> Add.
179 Dict info()
const override {
return {{
"nz", nz_}, {
"add", Add}}; }
// Index vector returned by all() and reported by info().
182 std::vector<casadi_int> nz_;
// Serialization hooks writing to a SerializingStream; defined out of line.
187 void serialize_body(SerializingStream& s)
const override;
191 void serialize_type(SerializingStream& s)
const override;
// Deserializing constructor; defined out of line.
196 explicit SetNonzerosVector(DeserializingStream& s);
// SetNonzerosSlice: SetNonzeros<Add> subclass whose indices are described by a
// single Slice kept in the member s_.
// NOTE(review): the declaration of s_ is not visible in this listing, and
// `Add` / `T` are undeclared here (template headers apparently dropped).
201 class CASADI_EXPORT SetNonzerosSlice :
public SetNonzeros<Add>{
// Inline constructor: forwards y and x to the base and stores the slice in s_.
205 SetNonzerosSlice(
const MX& y,
const MX& x,
const Slice& s) : SetNonzeros<Add>(y, x), s_(s) {}
// Empty destructor body.
208 ~SetNonzerosSlice()
override {}
// Expands the slice to a vector via s_.all, passing s_.stop as the argument.
211 std::vector<casadi_int> all()
const override {
return s_.all(s_.stop);}
// Operates on bvec_t buffers with const inputs; presumably forward sparsity
// propagation, judging by the name.
216 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// Operates on bvec_t buffers; arg is non-const here, so it may be written to.
// Presumably reverse sparsity propagation, judging by the name.
221 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// MX evaluation: reads from arg and writes to res; defined out of line.
226 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
// Evaluation routine written against an element type T; not marked override.
// Presumably the shared implementation behind eval and eval_sx -- confirm.
230 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
// Evaluation on double buffers, with iw and w work arrays.
233 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
// Evaluation on SXElem buffers.
236 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
// Builds a string from the string forms of the arguments.
241 std::string disp(
const std::vector<std::string>& arg)
const override;
// Code-generation hook: receives the CodeGenerator, integer vectors arg and
// res, and per-entry reference flags (res_is_ref is an output).
246 void generate(CodeGenerator& g,
247 const std::vector<casadi_int>& arg,
248 const std::vector<casadi_int>& res,
249 const std::vector<bool>& arg_is_ref,
250 std::vector<bool>& res_is_ref)
const override;
// Compares this node with another MXNode, given a depth argument.
255 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
// Returns a Dict with two entries: "slice" -> s_.info() and "add" -> Add.
258 Dict info()
const override {
return {{
"slice", s_.info()}, {
"add", Add}}; }
// Serialization hooks writing to a SerializingStream; defined out of line.
266 void serialize_body(SerializingStream& s)
const override;
270 void serialize_type(SerializingStream& s)
const override;
// Deserializing constructor; defined out of line.
275 explicit SetNonzerosSlice(DeserializingStream& s);
// SetNonzerosSlice2: SetNonzeros<Add> subclass whose indices are described by
// a pair of Slices, inner_ and outer_.
// NOTE(review): `Add` / `T` are undeclared in the visible text (template
// headers apparently dropped) -- confirm.
280 class CASADI_EXPORT SetNonzerosSlice2 :
public SetNonzeros<Add>{
// Inline constructor: forwards y and x to the base and stores both slices.
284 SetNonzerosSlice2(
const MX& y,
const MX& x,
const Slice& inner,
const Slice& outer) :
285 SetNonzeros<Add>(y, x), inner_(inner), outer_(outer) {}
// Empty destructor body.
288 ~SetNonzerosSlice2()
override {}
// Expands the slice pair to a vector via inner_.all(outer_, outer_.stop).
291 std::vector<casadi_int> all()
const override {
return inner_.all(outer_, outer_.stop);}
// Operates on bvec_t buffers with const inputs; presumably forward sparsity
// propagation, judging by the name.
296 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// Operates on bvec_t buffers; arg is non-const here, so it may be written to.
// Presumably reverse sparsity propagation, judging by the name.
301 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// MX evaluation: reads from arg and writes to res; defined out of line.
306 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
// Evaluation routine written against an element type T; not marked override.
// Presumably the shared implementation behind eval and eval_sx -- confirm.
310 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
// Evaluation on double buffers, with iw and w work arrays.
313 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
// Evaluation on SXElem buffers.
316 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
// Builds a string from the string forms of the arguments.
321 std::string disp(
const std::vector<std::string>& arg)
const override;
// Code-generation hook: receives the CodeGenerator, integer vectors arg and
// res, and per-entry reference flags (res_is_ref is an output).
326 void generate(CodeGenerator& g,
327 const std::vector<casadi_int>& arg,
328 const std::vector<casadi_int>& res,
329 const std::vector<bool>& arg_is_ref,
330 std::vector<bool>& res_is_ref)
const override;
// Compares this node with another MXNode, given a depth argument.
335 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
// Returns a Dict containing "inner" -> inner_.info() and "outer" -> outer_.info().
// NOTE(review): the initializer is cut off after the "outer" entry in this
// listing; the sibling classes also add an "add" entry -- confirm the rest.
338 Dict info()
const override {
return {{
"inner", inner_.info()}, {
"outer", outer_.info()},
// The two slices used by all() and info().
342 Slice inner_, outer_;
// Serialization hooks writing to a SerializingStream; defined out of line.
347 void serialize_body(SerializingStream& s)
const override;
351 void serialize_type(SerializingStream& s)
const override;
// Deserializing constructor; defined out of line.
356 explicit SetNonzerosSlice2(DeserializingStream& s);
Dict (GenericType::Dict): the C++ equivalent of Python's dict or MATLAB's struct.