split.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SPLIT_HPP
27 #define CASADI_SPLIT_HPP
28 
29 #include "multiple_output.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
36 
43  class CASADI_EXPORT Split : public MultipleOutput {
44  public:
46  Split(const MX& x, const std::vector<casadi_int>& offset);
47 
49  ~Split() override = 0;
50 
54  casadi_int nout() const override { return output_sparsity_.size(); }
55 
59  const Sparsity& sparsity(casadi_int oind) const override { return output_sparsity_.at(oind);}
60 
62  template<typename T>
63  int eval_gen(const T** arg, T** res, casadi_int* iw, T* w) const;
64 
66  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
67 
69  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
70 
74  void eval_linear(const std::vector<std::array<MX, 3> >& arg,
75  std::vector<std::array<MX, 3> >& res) const override {
76  eval_linear_rearrange(arg, res);
77  }
78 
82  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
83 
87  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
88 
92  void generate(CodeGenerator& g,
93  const std::vector<casadi_int>& arg,
94  const std::vector<casadi_int>& res,
95  const std::vector<bool>& arg_is_ref,
96  std::vector<bool>& res_is_ref) const override;
97 
99  Dict info() const override;
100 
101  // Sparsity pattern of the outputs
102  std::vector<casadi_int> offset_;
103  std::vector<Sparsity> output_sparsity_;
104 
108  void serialize_body(SerializingStream& s) const override;
109 
110  protected:
114  explicit Split(DeserializingStream& s);
115  };
116 
123  class CASADI_EXPORT Horzsplit : public Split {
124  public:
125 
127  Horzsplit(const MX& x, const std::vector<casadi_int>& offset);
128 
130  ~Horzsplit() override {}
131 
135  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
136 
140  void ad_forward(const std::vector<std::vector<MX> >& fseed,
141  std::vector<std::vector<MX> >& fsens) const override;
142 
146  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
147  std::vector<std::vector<MX> >& asens) const override;
148 
152  std::string disp(const std::vector<std::string>& arg) const override;
153 
157  casadi_int op() const override { return OP_HORZSPLIT;}
158 
160  MX get_horzcat(const std::vector<MX>& x) const override;
161 
165  static MXNode* deserialize(DeserializingStream& s) { return new Horzsplit(s); }
166 
167  protected:
171  explicit Horzsplit(DeserializingStream& s) : Split(s) {}
172  };
173 
180  class CASADI_EXPORT Diagsplit : public Split {
181  public:
182 
184  Diagsplit(const MX& x,
185  const std::vector<casadi_int>& offset1, const std::vector<casadi_int>& offset2);
186 
188  ~Diagsplit() override {}
189 
193  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
194 
198  void ad_forward(const std::vector<std::vector<MX> >& fseed,
199  std::vector<std::vector<MX> >& fsens) const override;
200 
204  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
205  std::vector<std::vector<MX> >& asens) const override;
206 
210  std::string disp(const std::vector<std::string>& arg) const override;
211 
215  casadi_int op() const override { return OP_DIAGSPLIT;}
216 
218  MX get_diagcat(const std::vector<MX>& x) const override;
219 
223  static MXNode* deserialize(DeserializingStream& s) { return new Diagsplit(s); }
224 
225  protected:
229  explicit Diagsplit(DeserializingStream& s) : Split(s) {}
230  };
231 
238  class CASADI_EXPORT Vertsplit : public Split {
239  public:
240 
242  Vertsplit(const MX& x, const std::vector<casadi_int>& offset);
243 
245  ~Vertsplit() override {}
246 
250  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
251 
255  void ad_forward(const std::vector<std::vector<MX> >& fseed,
256  std::vector<std::vector<MX> >& fsens) const override;
257 
261  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
262  std::vector<std::vector<MX> >& asens) const override;
263 
267  std::string disp(const std::vector<std::string>& arg) const override;
268 
272  casadi_int op() const override { return OP_VERTSPLIT;}
273 
275  MX get_vertcat(const std::vector<MX>& x) const override;
276 
280  static MXNode* deserialize(DeserializingStream& s) { return new Vertsplit(s); }
281 
282  protected:
286  explicit Vertsplit(DeserializingStream& s) : Split(s) {}
287  };
288 
289 } // namespace casadi
290 
292 
293 #endif // CASADI_SPLIT_HPP
Helper class for C code generation.
Helper class for Serialization.
Diag split, x -> x0, x1, ...
Definition: split.hpp:180
casadi_int op() const override
Get the operation.
Definition: split.hpp:215
Diagsplit(DeserializingStream &s)
Deserializing constructor.
Definition: split.hpp:229
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: split.hpp:223
~Diagsplit() override
Destructor.
Definition: split.hpp:188
Horizontal split, x -> x0, x1, ...
Definition: split.hpp:123
Horzsplit(DeserializingStream &s)
Deserializing constructor.
Definition: split.hpp:171
~Horzsplit() override
Destructor.
Definition: split.hpp:130
casadi_int op() const override
Get the operation.
Definition: split.hpp:157
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: split.hpp:165
Node class for MX objects.
Definition: mx_node.hpp:51
MX - Matrix expression.
Definition: mx.hpp:92
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
General sparsity class.
Definition: sparsity.hpp:106
Split: Split into multiple expressions splitting the nonzeros.
Definition: split.hpp:43
std::vector< Sparsity > output_sparsity_
Definition: split.hpp:103
void eval_linear(const std::vector< std::array< MX, 3 > > &arg, std::vector< std::array< MX, 3 > > &res) const override
Evaluate the MX node on a const/linear/nonlinear partition.
Definition: split.hpp:74
std::vector< casadi_int > offset_
Definition: split.hpp:102
casadi_int nout() const override
Number of outputs.
Definition: split.hpp:54
const Sparsity & sparsity(casadi_int oind) const override
Get the sparsity of output oind.
Definition: split.hpp:59
Vertical split of vectors, x -> x0, x1, ...
Definition: split.hpp:238
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: split.hpp:280
~Vertsplit() override
Destructor.
Definition: split.hpp:245
Vertsplit(DeserializingStream &s)
Deserializing constructor.
Definition: split.hpp:286
casadi_int op() const override
Get the operation.
Definition: split.hpp:272
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
@ OP_DIAGSPLIT
Definition: calculus.hpp:139
@ OP_VERTSPLIT
Definition: calculus.hpp:136
@ OP_HORZSPLIT
Definition: calculus.hpp:133