getnonzeros_param.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_GETNONZEROS_PARAM_HPP
27 #define CASADI_GETNONZEROS_PARAM_HPP
28 
29 #include "mx_node.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
42  class CASADI_EXPORT GetNonzerosParam : public MXNode {
43  public:
44 
47  static MX create(const MX& x, const MX& nz);
48  static MX create(const MX& x, const MX& inner, const Slice& outer);
49  static MX create(const MX& x, const Slice& inner, const MX& outer);
50  static MX create(const MX& x, const MX& inner, const MX& outer);
52 
54  GetNonzerosParam(const Sparsity& sp, const MX& y, const MX& nz);
55  GetNonzerosParam(const Sparsity& sp, const MX& y, const MX& nz, const MX& nz_extra);
56 
58  ~GetNonzerosParam() override {}
59 
63  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
64 
68  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
69 
73  casadi_int op() const override { return OP_GETNONZEROS_PARAM;}
74 
78  static MXNode* deserialize(DeserializingStream& s);
79 
80  protected:
84  explicit GetNonzerosParam(DeserializingStream& s) : MXNode(s) {}
85  };
86 
87 
88  class CASADI_EXPORT GetNonzerosParamVector : public GetNonzerosParam {
89  public:
91  GetNonzerosParamVector(const MX& x,
92  const MX& nz) : GetNonzerosParam(nz.sparsity(), x, nz) {}
93 
95  ~GetNonzerosParamVector() override {}
96 
100  void ad_forward(const std::vector<std::vector<MX> >& fseed,
101  std::vector<std::vector<MX> >& fsens) const override;
102 
106  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
107  std::vector<std::vector<MX> >& asens) const override;
108 
110  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
111 
115  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
116 
120  std::string disp(const std::vector<std::string>& arg) const override;
121 
125  void generate(CodeGenerator& g,
126  const std::vector<casadi_int>& arg,
127  const std::vector<casadi_int>& res) const override;
128 
132  void serialize_body(SerializingStream& s) const override;
136  void serialize_type(SerializingStream& s) const override;
137 
141  explicit GetNonzerosParamVector(DeserializingStream& s);
142  };
143 
144  // Specialization of the above when nz_ is a nested Slice
145  class CASADI_EXPORT GetNonzerosSliceParam : public GetNonzerosParam {
146  public:
147 
149  GetNonzerosSliceParam(const Sparsity& sp, const MX& x, const Slice& inner,
150  const MX& outer) :
151  GetNonzerosParam(sp, x, outer), inner_(inner) {}
152 
154  ~GetNonzerosSliceParam() override {}
155 
157  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
158 
162  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
163 
167  void ad_forward(const std::vector<std::vector<MX> >& fseed,
168  std::vector<std::vector<MX> >& fsens) const override;
169 
173  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
174  std::vector<std::vector<MX> >& asens) const override;
175 
179  std::string disp(const std::vector<std::string>& arg) const override;
180 
184  void generate(CodeGenerator& g,
185  const std::vector<casadi_int>& arg,
186  const std::vector<casadi_int>& res) const override;
187 
189  Dict info() const override { return {{"inner", inner_.info()}}; }
190 
191  // Data members
192  Slice inner_;
193 
197  void serialize_body(SerializingStream& s) const override;
201  void serialize_type(SerializingStream& s) const override;
202 
206  explicit GetNonzerosSliceParam(DeserializingStream& s);
207  };
208 
209  // Specialization of the above when nz_ is a nested Slice
210  class CASADI_EXPORT GetNonzerosParamSlice : public GetNonzerosParam {
211  public:
212 
214  GetNonzerosParamSlice(const Sparsity& sp, const MX& x, const MX& inner,
215  const Slice& outer) :
216  GetNonzerosParam(sp, x, inner), outer_(outer) {}
217 
219  ~GetNonzerosParamSlice() override {}
220 
224  size_t sz_iw() const override;
225 
227  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
228 
232  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
233 
237  void ad_forward(const std::vector<std::vector<MX> >& fseed,
238  std::vector<std::vector<MX> >& fsens) const override;
239 
243  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
244  std::vector<std::vector<MX> >& asens) const override;
245 
249  std::string disp(const std::vector<std::string>& arg) const override;
250 
254  void generate(CodeGenerator& g,
255  const std::vector<casadi_int>& arg,
256  const std::vector<casadi_int>& res) const override;
257 
259  Dict info() const override { return {{"outer", outer_.info()}}; }
260 
261  // Data members
262  Slice outer_;
263 
267  void serialize_body(SerializingStream& s) const override;
271  void serialize_type(SerializingStream& s) const override;
272 
276  explicit GetNonzerosParamSlice(DeserializingStream& s);
277  };
278 
279 
280  // Specialization of the above when nz_ is a nested Slice
281  class CASADI_EXPORT GetNonzerosParamParam : public GetNonzerosParam {
282  public:
283 
285  GetNonzerosParamParam(const Sparsity& sp, const MX& x, const MX& inner,
286  const MX& outer) :
287  GetNonzerosParam(sp, x, inner, outer) {}
288 
290  ~GetNonzerosParamParam() override {}
291 
295  size_t sz_iw() const override;
296 
298  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
299 
303  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
304 
308  void ad_forward(const std::vector<std::vector<MX> >& fseed,
309  std::vector<std::vector<MX> >& fsens) const override;
310 
314  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
315  std::vector<std::vector<MX> >& asens) const override;
316 
320  std::string disp(const std::vector<std::string>& arg) const override;
321 
325  void generate(CodeGenerator& g,
326  const std::vector<casadi_int>& arg,
327  const std::vector<casadi_int>& res) const override;
328 
330  Dict info() const override { return {}; }
331 
332 
336  void serialize_type(SerializingStream& s) const override;
337 
341  explicit GetNonzerosParamParam(DeserializingStream& s);
342  };
343 
344 } // namespace casadi
346 
347 #endif // CASADI_GETNONZEROS_PARAM_HPP
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.