setnonzeros_param.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SETNONZEROS_PARAM_HPP
27 #define CASADI_SETNONZEROS_PARAM_HPP
28 
29 #include "mx_node.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
36 
/** \brief Assign or add entries to a matrix, parametrically.

    Abstract base of the parametric set/add-nonzeros MX nodes. The template
    parameter Add selects the operation code reported by op():
    OP_ADDNONZEROS_PARAM when true, OP_SETNONZEROS_PARAM when false.
    "Parametric" refers to at least one nonzero index being an MX expression
    rather than a fixed Slice (every create() overload takes one or two MX
    indices).
*/
43  template<bool Add>
44  class CASADI_EXPORT SetNonzerosParam : public MXNode {
45  public:
47 
/// Factory functions, one per combination of index kinds:
///   create(y, x, nz)            single MX index expression
///   create(y, x, inner, outer)  MX/Slice, Slice/MX or MX/MX index pair
/// NOTE(review): presumably y is the matrix written into, x supplies the
/// values, and each overload builds the matching derived node declared
/// below -- confirm in the implementation.
53  static MX create(const MX& y, const MX& x, const MX& nz);
54  static MX create(const MX& y, const MX& x, const MX& inner, const Slice& outer);
55  static MX create(const MX& y, const MX& x, const Slice& inner, const MX& outer);
56  static MX create(const MX& y, const MX& x, const MX& inner, const MX& outer);
58 
/// Constructor. Takes either one MX index expression (nz) or two (nz, nz2);
/// the derived classes forward their MX index/indices here.
60  SetNonzerosParam(const MX& y, const MX& x, const MX& nz);
61  SetNonzerosParam(const MX& y, const MX& x, const MX& nz, const MX& nz2);
62 
/// Destructor. Declared pure virtual (= 0), which makes this class abstract.
/// C++ still requires an out-of-line definition; it is not visible here.
64  ~SetNonzerosParam() override = 0;
65 
/// Sparsity (bvec_t bit-vector) propagation, forward direction.
69  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
70 
/// Sparsity (bvec_t bit-vector) propagation, reverse direction.
/// Unlike sp_forward, arg is non-const -- presumably because dependencies are
/// propagated back into the argument bit vectors.
74  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
75 
/// Get the operation: OP_ADDNONZEROS_PARAM if Add, else OP_SETNONZEROS_PARAM.
79  casadi_int op() const override { return Add ? OP_ADDNONZEROS_PARAM : OP_SETNONZEROS_PARAM;}
80 
/// Whether the operation can be performed in place (i.e. overwrite the
/// result); returns 1 here.
82  casadi_int n_inplace() const override { return 1;}
83 
/// Generate C code for the operation using the CodeGenerator helper.
87  void generate(CodeGenerator& g,
88  const std::vector<casadi_int>& arg,
89  const std::vector<casadi_int>& res,
90  const std::vector<bool>& arg_is_ref,
91  std::vector<bool>& res_is_ref) const override;
92 
/// Deserialize a node from a DeserializingStream (static factory returning
/// an MXNode*).
96  static MXNode* deserialize(DeserializingStream& s);
97 
98  protected:
// NOTE(review): the protected members (source lines 99-102) are not shown in
// this excerpt; the symbol index lists SetNonzerosParam(DeserializingStream&)
// ("Deserializing constructor") for this class.
103  };
104 
105 
/** \brief Add the nonzeros of a matrix to another matrix, parametrically.

    Variant whose nonzero positions are given by a single MX index
    expression nz, which is forwarded to the one-index SetNonzerosParam
    constructor. Whether entries are assigned or added follows Add
    (see SetNonzerosParam::op()).
*/
112  template<bool Add>
113  class CASADI_EXPORT SetNonzerosParamVector : public SetNonzerosParam<Add>{
114  public:
115 
/// Constructor: nz is the MX index expression.
117  SetNonzerosParamVector(const MX& y, const MX& x, const MX& nz);
118 
// NOTE(review): source lines 119-120 are not shown in this excerpt; the
// symbol index lists ~SetNonzerosParamVector() override ("Destructor").
121 
/// Evaluate symbolically: MX arguments to MX results.
125  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
126 
/// Evaluate numerically on double arguments, with work arrays iw and w.
128  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
129 
/// Forward-mode AD: map forward seeds fseed to forward sensitivities fsens.
133  void ad_forward(const std::vector<std::vector<MX> >& fseed,
134  std::vector<std::vector<MX> >& fsens) const override;
135 
/// Reverse-mode AD: map adjoint seeds aseed to adjoint sensitivities asens.
139  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
140  std::vector<std::vector<MX> >& asens) const override;
141 
/// Textual representation of the node, given the display strings of its
/// arguments.
145  std::string disp(const std::vector<std::string>& arg) const override;
146 
/// Generate C code for the operation.
150  void generate(CodeGenerator& g,
151  const std::vector<casadi_int>& arg,
152  const std::vector<casadi_int>& res,
153  const std::vector<bool>& arg_is_ref,
154  std::vector<bool>& res_is_ref) const override;
155 
/// Write the node's contents to the serialization stream.
159  void serialize_body(SerializingStream& s) const override;
/// Write the information identifying this node type to the stream.
163  void serialize_type(SerializingStream& s) const override;
164 
// NOTE(review): source lines 165-168 are not shown in this excerpt.
169  };
170 
171  // Variant of SetNonzerosParam with an MX inner index and a fixed Slice outer index.
// The MX `inner` is forwarded to the one-index base constructor; the Slice
// is stored in the data member outer_.
172  template<bool Add>
173  class CASADI_EXPORT SetNonzerosParamSlice : public SetNonzerosParam<Add>{
174  public:
175 
/// Required size of the integer work vector (assumed to size the iw
/// argument of eval -- confirm).
179  size_t sz_iw() const override;
180 
/// Constructor: inner is an MX index expression, outer a fixed Slice.
182  SetNonzerosParamSlice(const MX& y, const MX& x, const MX& inner, const Slice& outer) :
183  SetNonzerosParam<Add>(y, x, inner), outer_(outer) {}
184 
// NOTE(review): source lines 185-186 are not shown in this excerpt; the
// symbol index lists ~SetNonzerosParamSlice() override ("Destructor").
187 
/// Evaluate symbolically: MX arguments to MX results.
191  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
192 
/// Forward-mode AD: map forward seeds fseed to forward sensitivities fsens.
196  void ad_forward(const std::vector<std::vector<MX> >& fseed,
197  std::vector<std::vector<MX> >& fsens) const override;
198 
/// Reverse-mode AD: map adjoint seeds aseed to adjoint sensitivities asens.
202  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
203  std::vector<std::vector<MX> >& asens) const override;
204 
/// Evaluate numerically on double arguments, with work arrays iw and w.
206  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
207 
/// Textual representation of the node, given the display strings of its
/// arguments.
211  std::string disp(const std::vector<std::string>& arg) const override;
212 
/// Generate C code for the operation.
216  void generate(CodeGenerator& g,
217  const std::vector<casadi_int>& arg,
218  const std::vector<casadi_int>& res,
219  const std::vector<bool>& arg_is_ref,
220  std::vector<bool>& res_is_ref) const override;
221 
222  // Data member: the fixed outer Slice, initialized by the constructor via outer_(outer)
// NOTE(review): its declaration (source line 223, presumably `Slice outer_;`)
// is not shown in this excerpt.
224 
/// Write the node's contents (including the outer Slice, presumably) to the
/// serialization stream.
228  void serialize_body(SerializingStream& s) const override;
/// Write the information identifying this node type to the stream.
232  void serialize_type(SerializingStream& s) const override;
233 
238  };
239 
240 
241  // Variant of SetNonzerosParam with a fixed Slice inner index and an MX outer index.
// The MX `outer` is forwarded to the one-index base constructor; the Slice
// is stored in the data member inner_.
242  template<bool Add>
243  class CASADI_EXPORT SetNonzerosSliceParam : public SetNonzerosParam<Add>{
244  public:
245 
/// Constructor: inner is a fixed Slice, outer an MX index expression.
247  SetNonzerosSliceParam(const MX& y, const MX& x, const Slice& inner, const MX& outer) :
248  SetNonzerosParam<Add>(y, x, outer), inner_(inner) {}
249 
// NOTE(review): source lines 250-251 are not shown in this excerpt; the
// symbol index lists ~SetNonzerosSliceParam() override ("Destructor").
252 
/// Evaluate symbolically: MX arguments to MX results.
256  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
257 
/// Evaluate numerically on double arguments, with work arrays iw and w.
259  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
260 
/// Forward-mode AD: map forward seeds fseed to forward sensitivities fsens.
264  void ad_forward(const std::vector<std::vector<MX> >& fseed,
265  std::vector<std::vector<MX> >& fsens) const override;
266 
/// Reverse-mode AD: map adjoint seeds aseed to adjoint sensitivities asens.
270  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
271  std::vector<std::vector<MX> >& asens) const override;
272 
/// Textual representation of the node, given the display strings of its
/// arguments.
276  std::string disp(const std::vector<std::string>& arg) const override;
277 
/// Generate C code for the operation.
281  void generate(CodeGenerator& g,
282  const std::vector<casadi_int>& arg,
283  const std::vector<casadi_int>& res,
284  const std::vector<bool>& arg_is_ref,
285  std::vector<bool>& res_is_ref) const override;
286 
287  // Data member: the fixed inner Slice, initialized by the constructor via inner_(inner)
// NOTE(review): its declaration (source line 288, presumably `Slice inner_;`)
// is not shown in this excerpt.
289 
/// Write the node's contents (including the inner Slice, presumably) to the
/// serialization stream.
293  void serialize_body(SerializingStream& s) const override;
/// Write the information identifying this node type to the stream.
297  void serialize_type(SerializingStream& s) const override;
298 
303  };
304 
305  // Variant of SetNonzerosParam in which both the inner and the outer index are MX expressions (no Slice).
// Both indices are forwarded to the two-index base constructor, so this
// class declares no data members of its own in the visible lines.
306  template<bool Add>
307  class CASADI_EXPORT SetNonzerosParamParam : public SetNonzerosParam<Add>{
308  public:
309 
/// Required size of the integer work vector (assumed to size the iw
/// argument of eval -- confirm).
313  size_t sz_iw() const override;
314 
/// Constructor: inner and outer are both MX index expressions.
316  SetNonzerosParamParam(const MX& y, const MX& x, const MX& inner, const MX& outer) :
317  SetNonzerosParam<Add>(y, x, inner, outer) {}
318 
// NOTE(review): source lines 319-320 are not shown in this excerpt; the
// symbol index lists ~SetNonzerosParamParam() override ("Destructor").
321 
/// Evaluate symbolically: MX arguments to MX results.
325  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
326 
/// Evaluate numerically on double arguments, with work arrays iw and w.
328  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
329 
/// Forward-mode AD: map forward seeds fseed to forward sensitivities fsens.
333  void ad_forward(const std::vector<std::vector<MX> >& fseed,
334  std::vector<std::vector<MX> >& fsens) const override;
335 
/// Reverse-mode AD: map adjoint seeds aseed to adjoint sensitivities asens.
339  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
340  std::vector<std::vector<MX> >& asens) const override;
341 
/// Textual representation of the node, given the display strings of its
/// arguments.
345  std::string disp(const std::vector<std::string>& arg) const override;
346 
/// Generate C code for the operation.
350  void generate(CodeGenerator& g,
351  const std::vector<casadi_int>& arg,
352  const std::vector<casadi_int>& res,
353  const std::vector<bool>& arg_is_ref,
354  std::vector<bool>& res_is_ref) const override;
355 
/// Write the information identifying this node type to the stream.
/// NOTE(review): unlike the two Slice variants, no serialize_body override
/// is visible here -- presumably because both indices live in the base
/// class; confirm against the full header.
359  void serialize_type(SerializingStream& s) const override;
360 
365  };
366 
367 } // namespace casadi
369 
370 #endif // CASADI_SETNONZEROS_PARAM_HPP
Helper class for C code generation.
Helper class for Serialization.
Node class for MX objects.
Definition: mx_node.hpp:51
MX - Matrix expression.
Definition: mx.hpp:92
Helper class for Serialization.
SetNonzerosParamParam(const MX &y, const MX &x, const MX &inner, const MX &outer)
Constructor.
~SetNonzerosParamParam() override
Destructor.
~SetNonzerosParamSlice() override
Destructor.
SetNonzerosParamSlice(const MX &y, const MX &x, const MX &inner, const Slice &outer)
Constructor.
Add the nonzeros of a matrix to another matrix, parametrically.
~SetNonzerosParamVector() override
Destructor.
Assign or add entries to a matrix, parametrically.
casadi_int op() const override
Get the operation.
casadi_int n_inplace() const override
Whether the operation can be performed in place (i.e., overwriting the result).
SetNonzerosParam(DeserializingStream &s)
Deserializing constructor.
~SetNonzerosSliceParam() override
Destructor.
SetNonzerosSliceParam(const MX &y, const MX &x, const Slice &inner, const MX &outer)
Constructor.
Class representing a Slice.
Definition: slice.hpp:48
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
@ OP_ADDNONZEROS_PARAM
Definition: calculus.hpp:160
@ OP_SETNONZEROS_PARAM
Definition: calculus.hpp:166