split.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SPLIT_HPP
27 #define CASADI_SPLIT_HPP
28 
29 #include "multiple_output.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
36 
43  class CASADI_EXPORT Split : public MultipleOutput {
44  public:
46  Split(const MX& x, const std::vector<casadi_int>& offset);
47 
49  ~Split() override = 0;
50 
54  casadi_int nout() const override { return output_sparsity_.size(); }
55 
59  const Sparsity& sparsity(casadi_int oind) const override { return output_sparsity_.at(oind);}
60 
62  template<typename T>
63  int eval_gen(const T** arg, T** res, casadi_int* iw, T* w) const;
64 
66  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
67 
69  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
70 
74  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
75 
79  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
80 
84  void generate(CodeGenerator& g,
85  const std::vector<casadi_int>& arg,
86  const std::vector<casadi_int>& res) const override;
87 
89  Dict info() const override;
90 
91  // Sparsity pattern of the outputs
92  std::vector<casadi_int> offset_;
93  std::vector<Sparsity> output_sparsity_;
94 
98  void serialize_body(SerializingStream& s) const override;
99 
100  protected:
104  explicit Split(DeserializingStream& s);
105  };
106 
113  class CASADI_EXPORT Horzsplit : public Split {
114  public:
115 
117  Horzsplit(const MX& x, const std::vector<casadi_int>& offset);
118 
120  ~Horzsplit() override {}
121 
125  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
126 
130  void ad_forward(const std::vector<std::vector<MX> >& fseed,
131  std::vector<std::vector<MX> >& fsens) const override;
132 
136  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
137  std::vector<std::vector<MX> >& asens) const override;
138 
142  std::string disp(const std::vector<std::string>& arg) const override;
143 
147  casadi_int op() const override { return OP_HORZSPLIT;}
148 
150  MX get_horzcat(const std::vector<MX>& x) const override;
151 
155  static MXNode* deserialize(DeserializingStream& s) { return new Horzsplit(s); }
156 
157  protected:
161  explicit Horzsplit(DeserializingStream& s) : Split(s) {}
162  };
163 
170  class CASADI_EXPORT Diagsplit : public Split {
171  public:
172 
174  Diagsplit(const MX& x,
175  const std::vector<casadi_int>& offset1, const std::vector<casadi_int>& offset2);
176 
178  ~Diagsplit() override {}
179 
183  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
184 
188  void ad_forward(const std::vector<std::vector<MX> >& fseed,
189  std::vector<std::vector<MX> >& fsens) const override;
190 
194  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
195  std::vector<std::vector<MX> >& asens) const override;
196 
200  std::string disp(const std::vector<std::string>& arg) const override;
201 
205  casadi_int op() const override { return OP_DIAGSPLIT;}
206 
208  MX get_diagcat(const std::vector<MX>& x) const override;
209 
213  static MXNode* deserialize(DeserializingStream& s) { return new Diagsplit(s); }
214 
215  protected:
219  explicit Diagsplit(DeserializingStream& s) : Split(s) {}
220  };
221 
228  class CASADI_EXPORT Vertsplit : public Split {
229  public:
230 
232  Vertsplit(const MX& x, const std::vector<casadi_int>& offset);
233 
235  ~Vertsplit() override {}
236 
240  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
241 
245  void ad_forward(const std::vector<std::vector<MX> >& fseed,
246  std::vector<std::vector<MX> >& fsens) const override;
247 
251  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
252  std::vector<std::vector<MX> >& asens) const override;
253 
257  std::string disp(const std::vector<std::string>& arg) const override;
258 
262  casadi_int op() const override { return OP_VERTSPLIT;}
263 
265  MX get_vertcat(const std::vector<MX>& x) const override;
266 
270  static MXNode* deserialize(DeserializingStream& s) { return new Vertsplit(s); }
271 
272  protected:
276  explicit Vertsplit(DeserializingStream& s) : Split(s) {}
277  };
278 
279 } // namespace casadi
280 
282 
283 #endif // CASADI_SPLIT_HPP
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.