setnonzeros_param.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SETNONZEROS_PARAM_HPP
27 #define CASADI_SETNONZEROS_PARAM_HPP
28 
29 #include "mx_node.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
36 
43  template<bool Add>
44  class CASADI_EXPORT SetNonzerosParam : public MXNode {
45  public:
47 
53  static MX create(const MX& y, const MX& x, const MX& nz);
54  static MX create(const MX& y, const MX& x, const MX& inner, const Slice& outer);
55  static MX create(const MX& y, const MX& x, const Slice& inner, const MX& outer);
56  static MX create(const MX& y, const MX& x, const MX& inner, const MX& outer);
58 
60  SetNonzerosParam(const MX& y, const MX& x, const MX& nz);
61  SetNonzerosParam(const MX& y, const MX& x, const MX& nz, const MX& nz2);
62 
64  ~SetNonzerosParam() override = 0;
65 
69  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
70 
74  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
75 
79  casadi_int op() const override { return Add ? OP_ADDNONZEROS_PARAM : OP_SETNONZEROS_PARAM;}
80 
82  casadi_int n_inplace() const override { return 1;}
83 
87  void generate(CodeGenerator& g,
88  const std::vector<casadi_int>& arg,
89  const std::vector<casadi_int>& res) const override;
90 
94  static MXNode* deserialize(DeserializingStream& s);
95 
96  protected:
100  explicit SetNonzerosParam(DeserializingStream& s) : MXNode(s) {}
101  };
102 
103 
110  template<bool Add>
111  class CASADI_EXPORT SetNonzerosParamVector : public SetNonzerosParam<Add>{
112  public:
113 
115  SetNonzerosParamVector(const MX& y, const MX& x, const MX& nz);
116 
118  ~SetNonzerosParamVector() override {}
119 
123  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
124 
126  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
127 
131  void ad_forward(const std::vector<std::vector<MX> >& fseed,
132  std::vector<std::vector<MX> >& fsens) const override;
133 
137  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
138  std::vector<std::vector<MX> >& asens) const override;
139 
143  std::string disp(const std::vector<std::string>& arg) const override;
144 
148  void generate(CodeGenerator& g,
149  const std::vector<casadi_int>& arg,
150  const std::vector<casadi_int>& res) const override;
151 
155  void serialize_body(SerializingStream& s) const override;
159  void serialize_type(SerializingStream& s) const override;
160 
164  explicit SetNonzerosParamVector(DeserializingStream& s);
165  };
166 
// Specialization of SetNonzerosParam for an MX inner index and a Slice outer index
168  template<bool Add>
169  class CASADI_EXPORT SetNonzerosParamSlice : public SetNonzerosParam<Add>{
170  public:
171 
175  size_t sz_iw() const override;
176 
178  SetNonzerosParamSlice(const MX& y, const MX& x, const MX& inner, const Slice& outer) :
179  SetNonzerosParam<Add>(y, x, inner), outer_(outer) {}
180 
182  ~SetNonzerosParamSlice() override {}
183 
187  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
188 
192  void ad_forward(const std::vector<std::vector<MX> >& fseed,
193  std::vector<std::vector<MX> >& fsens) const override;
194 
198  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
199  std::vector<std::vector<MX> >& asens) const override;
200 
202  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
203 
207  std::string disp(const std::vector<std::string>& arg) const override;
208 
212  void generate(CodeGenerator& g,
213  const std::vector<casadi_int>& arg,
214  const std::vector<casadi_int>& res) const override;
215 
216  // Data member
217  Slice outer_;
218 
222  void serialize_body(SerializingStream& s) const override;
226  void serialize_type(SerializingStream& s) const override;
227 
231  explicit SetNonzerosParamSlice(DeserializingStream& s);
232  };
233 
234 
// Specialization of SetNonzerosParam for a Slice inner index and an MX outer index
236  template<bool Add>
237  class CASADI_EXPORT SetNonzerosSliceParam : public SetNonzerosParam<Add>{
238  public:
239 
241  SetNonzerosSliceParam(const MX& y, const MX& x, const Slice& inner, const MX& outer) :
242  SetNonzerosParam<Add>(y, x, outer), inner_(inner) {}
243 
245  ~SetNonzerosSliceParam() override {}
246 
250  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
251 
253  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
254 
258  void ad_forward(const std::vector<std::vector<MX> >& fseed,
259  std::vector<std::vector<MX> >& fsens) const override;
260 
264  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
265  std::vector<std::vector<MX> >& asens) const override;
266 
270  std::string disp(const std::vector<std::string>& arg) const override;
271 
275  void generate(CodeGenerator& g,
276  const std::vector<casadi_int>& arg,
277  const std::vector<casadi_int>& res) const override;
278 
279  // Data member
280  Slice inner_;
281 
285  void serialize_body(SerializingStream& s) const override;
289  void serialize_type(SerializingStream& s) const override;
290 
294  explicit SetNonzerosSliceParam(DeserializingStream& s);
295  };
296 
// Specialization of SetNonzerosParam when both the inner and the outer index are MX expressions
298  template<bool Add>
299  class CASADI_EXPORT SetNonzerosParamParam : public SetNonzerosParam<Add>{
300  public:
301 
305  size_t sz_iw() const override;
306 
308  SetNonzerosParamParam(const MX& y, const MX& x, const MX& inner, const MX& outer) :
309  SetNonzerosParam<Add>(y, x, inner, outer) {}
310 
312  ~SetNonzerosParamParam() override {}
313 
317  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
318 
320  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
321 
325  void ad_forward(const std::vector<std::vector<MX> >& fseed,
326  std::vector<std::vector<MX> >& fsens) const override;
327 
331  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
332  std::vector<std::vector<MX> >& asens) const override;
333 
337  std::string disp(const std::vector<std::string>& arg) const override;
338 
342  void generate(CodeGenerator& g,
343  const std::vector<casadi_int>& arg,
344  const std::vector<casadi_int>& res) const override;
345 
349  void serialize_type(SerializingStream& s) const override;
350 
354  explicit SetNonzerosParamParam(DeserializingStream& s);
355  };
356 
357 } // namespace casadi
359 
360 #endif // CASADI_SETNONZEROS_PARAM_HPP
The casadi namespace.