26 #ifndef CASADI_SETNONZEROS_PARAM_HPP
27 #define CASADI_SETNONZEROS_PARAM_HPP
29 #include "mx_node.hpp"
/** \brief Abstract base for "set/add nonzeros" MX nodes whose index arguments
    are symbolic (MX) rather than fixed integer lists.

    op() below selects OP_ADDNONZEROS_PARAM when `Add` is true and
    OP_SETNONZEROS_PARAM otherwise. Judging by those op codes, the node either
    assigns or accumulates values into selected nonzeros. Presumably `y` is the
    target expression and `x` the values; confirm against the implementation.

    NOTE(review): `Add` is used in op() but its declaration (presumably
    `template<bool Add>`) is not visible in this chunk.
*/
44 class CASADI_EXPORT SetNonzerosParam :
public MXNode {
/** \brief Factory: a single symbolic index expression `nz`. */
53 static MX create(
const MX& y,
const MX& x,
const MX& nz);
/** \brief Factory: symbolic `inner` index combined with a fixed Slice `outer`. */
54 static MX create(
const MX& y,
const MX& x,
const MX& inner,
const Slice& outer);
/** \brief Factory: fixed Slice `inner` combined with a symbolic `outer` index. */
55 static MX create(
const MX& y,
const MX& x,
const Slice& inner,
const MX& outer);
/** \brief Factory: both `inner` and `outer` indices symbolic. */
56 static MX create(
const MX& y,
const MX& x,
const MX& inner,
const MX& outer);
/** \brief Construct with one symbolic index expression. */
60 SetNonzerosParam(
const MX& y,
const MX& x,
const MX& nz);
/** \brief Construct with two symbolic index expressions. */
61 SetNonzerosParam(
const MX& y,
const MX& x,
const MX& nz,
const MX& nz2);
/** \brief Pure virtual destructor, which makes this base abstract.
    C++ still requires an out-of-line definition, because derived
    destructors call it.
*/
64 ~SetNonzerosParam()
override = 0;
/** \brief Forward propagation over bvec_t arrays (presumably sparsity
    propagation, per the name). `arg` is const here.
*/
69 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/** \brief Reverse propagation over bvec_t arrays.
    Unlike sp_forward, `arg` is non-const here.
*/
74 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/** \brief Operation code: OP_ADDNONZEROS_PARAM if Add, else OP_SETNONZEROS_PARAM. */
79 casadi_int op()
const override {
return Add ? OP_ADDNONZEROS_PARAM : OP_SETNONZEROS_PARAM;}
/** \brief Always returns 1.
    Presumably this is the number of inputs whose storage the output may
    reuse in place; confirm against MXNode::n_inplace.
*/
82 casadi_int n_inplace()
const override {
return 1;}
/** \brief Code generation hook. */
87 void generate(CodeGenerator& g,
88 const std::vector<casadi_int>& arg,
89 const std::vector<casadi_int>& res,
90 const std::vector<bool>& arg_is_ref,
91 std::vector<bool>& res_is_ref)
const override;
/** \brief Reconstruct a node of this family from a DeserializingStream. */
96 static MXNode* deserialize(DeserializingStream& s);
/** \brief Deserializing constructor; forwards the stream to MXNode. */
102 explicit SetNonzerosParam(DeserializingStream& s) : MXNode(s) {}
/** \brief Variant whose nonzero selection is one symbolic index expression `nz`.
    Derives from SetNonzerosParam<Add>.
*/
113 class CASADI_EXPORT SetNonzerosParamVector :
public SetNonzerosParam<Add>{
/** \brief Construct from target `y`, values `x` and symbolic index `nz`
    (argument roles assumed from naming; confirm).
*/
117 SetNonzerosParamVector(
const MX& y,
const MX& x,
const MX& nz);
/** \brief Destructor (no extra cleanup). */
120 ~SetNonzerosParamVector()
override {}
/** \brief Symbolic evaluation on MX arguments. */
125 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/** \brief Numeric evaluation on double buffers.
    `iw` and `w` are presumably integer and real work arrays.
*/
128 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/** \brief Forward-mode AD: map forward seeds to forward sensitivities. */
133 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
134 std::vector<std::vector<MX> >& fsens)
const override;
/** \brief Reverse-mode AD: map adjoint seeds to adjoint sensitivities. */
139 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
140 std::vector<std::vector<MX> >& asens)
const override;
/** \brief Printable representation built from the argument strings. */
145 std::string disp(
const std::vector<std::string>& arg)
const override;
/** \brief Code generation for this variant. */
150 void generate(CodeGenerator& g,
151 const std::vector<casadi_int>& arg,
152 const std::vector<casadi_int>& res,
153 const std::vector<bool>& arg_is_ref,
154 std::vector<bool>& res_is_ref)
const override;
/** \brief Serialize the node's data to a SerializingStream. */
159 void serialize_body(SerializingStream& s)
const override;
/** \brief Serialize type information (presumably the variant tag). */
163 void serialize_type(SerializingStream& s)
const override;
/** \brief Deserializing constructor. */
168 explicit SetNonzerosParamVector(DeserializingStream& s);
/** \brief Variant with a symbolic (MX) `inner` index and a fixed Slice `outer`.
    Derives from SetNonzerosParam<Add>.
*/
173 class CASADI_EXPORT SetNonzerosParamSlice :
public SetNonzerosParam<Add>{
/** \brief Integer work-array size requirement (overrides the base default). */
179 size_t sz_iw()
const override;
/** \brief Pass the symbolic `inner` index to the three-argument base
    constructor and keep `outer` locally in outer_.
    NOTE(review): the declaration of outer_ is not visible in this chunk.
*/
182 SetNonzerosParamSlice(
const MX& y,
const MX& x,
const MX& inner,
const Slice& outer) :
183 SetNonzerosParam<Add>(y, x, inner), outer_(outer) {}
/** \brief Destructor (no extra cleanup). */
186 ~SetNonzerosParamSlice()
override {}
/** \brief Symbolic evaluation on MX arguments. */
191 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/** \brief Forward-mode AD: map forward seeds to forward sensitivities. */
196 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
197 std::vector<std::vector<MX> >& fsens)
const override;
/** \brief Reverse-mode AD: map adjoint seeds to adjoint sensitivities. */
202 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
203 std::vector<std::vector<MX> >& asens)
const override;
/** \brief Numeric evaluation on double buffers.
    `iw` and `w` are presumably work arrays; `iw` is sized via sz_iw().
*/
206 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/** \brief Printable representation built from the argument strings. */
211 std::string disp(
const std::vector<std::string>& arg)
const override;
/** \brief Code generation for this variant. */
216 void generate(CodeGenerator& g,
217 const std::vector<casadi_int>& arg,
218 const std::vector<casadi_int>& res,
219 const std::vector<bool>& arg_is_ref,
220 std::vector<bool>& res_is_ref)
const override;
/** \brief Serialize the node's data (presumably including outer_). */
228 void serialize_body(SerializingStream& s)
const override;
/** \brief Serialize type information (presumably the variant tag). */
232 void serialize_type(SerializingStream& s)
const override;
/** \brief Deserializing constructor. */
237 explicit SetNonzerosParamSlice(DeserializingStream& s);
/** \brief Variant with a fixed Slice `inner` and a symbolic (MX) `outer` index.
    Derives from SetNonzerosParam<Add>.
    NOTE(review): unlike SetNonzerosParamSlice and SetNonzerosParamParam, no
    sz_iw() override is visible here. Confirm that the base work-array size
    suffices for eval().
*/
243 class CASADI_EXPORT SetNonzerosSliceParam :
public SetNonzerosParam<Add>{
/** \brief Pass the symbolic `outer` index to the three-argument base
    constructor and keep `inner` locally in inner_.
    NOTE(review): the declaration of inner_ is not visible in this chunk.
*/
247 SetNonzerosSliceParam(
const MX& y,
const MX& x,
const Slice& inner,
const MX& outer) :
248 SetNonzerosParam<Add>(y, x, outer), inner_(inner) {}
/** \brief Destructor (no extra cleanup). */
251 ~SetNonzerosSliceParam()
override {}
/** \brief Symbolic evaluation on MX arguments. */
256 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/** \brief Numeric evaluation on double buffers.
    `iw` and `w` are presumably work arrays.
*/
259 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/** \brief Forward-mode AD: map forward seeds to forward sensitivities. */
264 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
265 std::vector<std::vector<MX> >& fsens)
const override;
/** \brief Reverse-mode AD: map adjoint seeds to adjoint sensitivities. */
270 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
271 std::vector<std::vector<MX> >& asens)
const override;
/** \brief Printable representation built from the argument strings. */
276 std::string disp(
const std::vector<std::string>& arg)
const override;
/** \brief Code generation for this variant. */
281 void generate(CodeGenerator& g,
282 const std::vector<casadi_int>& arg,
283 const std::vector<casadi_int>& res,
284 const std::vector<bool>& arg_is_ref,
285 std::vector<bool>& res_is_ref)
const override;
/** \brief Serialize the node's data (presumably including inner_). */
293 void serialize_body(SerializingStream& s)
const override;
/** \brief Serialize type information (presumably the variant tag). */
297 void serialize_type(SerializingStream& s)
const override;
/** \brief Deserializing constructor. */
302 explicit SetNonzerosSliceParam(DeserializingStream& s);
/** \brief Variant where both `inner` and `outer` indices are symbolic (MX).
    Derives from SetNonzerosParam<Add>.
    NOTE(review): unlike the other variants, no serialize_body() override is
    visible here. The constructor passes both indices to the base, so the base
    serialization may suffice; confirm.
*/
307 class CASADI_EXPORT SetNonzerosParamParam :
public SetNonzerosParam<Add>{
/** \brief Integer work-array size requirement (overrides the base default). */
313 size_t sz_iw()
const override;
/** \brief Forward both symbolic indices to the four-argument base constructor. */
316 SetNonzerosParamParam(
const MX& y,
const MX& x,
const MX& inner,
const MX& outer) :
317 SetNonzerosParam<Add>(y, x, inner, outer) {}
/** \brief Destructor (no extra cleanup). */
320 ~SetNonzerosParamParam()
override {}
/** \brief Symbolic evaluation on MX arguments. */
325 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/** \brief Numeric evaluation on double buffers.
    `iw` and `w` are presumably work arrays; `iw` is sized via sz_iw().
*/
328 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/** \brief Forward-mode AD: map forward seeds to forward sensitivities. */
333 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
334 std::vector<std::vector<MX> >& fsens)
const override;
/** \brief Reverse-mode AD: map adjoint seeds to adjoint sensitivities. */
339 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
340 std::vector<std::vector<MX> >& asens)
const override;
/** \brief Printable representation built from the argument strings. */
345 std::string disp(
const std::vector<std::string>& arg)
const override;
/** \brief Code generation for this variant. */
350 void generate(CodeGenerator& g,
351 const std::vector<casadi_int>& arg,
352 const std::vector<casadi_int>& res,
353 const std::vector<bool>& arg_is_ref,
354 std::vector<bool>& res_is_ref)
const override;
/** \brief Serialize type information (presumably the variant tag). */
359 void serialize_type(SerializingStream& s)
const override;
/** \brief Deserializing constructor. */
364 explicit SetNonzerosParamParam(DeserializingStream& s);