concat.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_CONCAT_HPP
27 #define CASADI_CONCAT_HPP
28 
29 #include "mx_node.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
42  class CASADI_EXPORT Concat : public MXNode {
43  public:
44 
46  Concat(const std::vector<MX>& x);
47 
49  ~Concat() override = 0;
50 
52  template<typename T>
53  int eval_gen(const T* const* arg, T* const* res, casadi_int* iw, T* w) const;
54 
56  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
57 
59  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
60 
64  void eval_linear(const std::vector<std::array<MX, 3> >& arg,
65  std::vector<std::array<MX, 3> >& res) const override {
66  eval_linear_rearrange(arg, res);
67  }
68 
72  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
73 
77  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
78 
82  void generate(CodeGenerator& g,
83  const std::vector<casadi_int>& arg,
84  const std::vector<casadi_int>& res,
85  const std::vector<bool>& arg_is_ref,
86  std::vector<bool>& res_is_ref) const override;
87 
89  MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const override;
90 
94  bool is_equal(const MXNode* node, casadi_int depth) const override {
95  return sameOpAndDeps(node, depth);
96  }
97 
101  bool is_valid_input() const override;
102 
106  casadi_int n_primitives() const override;
107 
111  void primitives(std::vector<MX>::iterator& it) const override;
112 
116  bool has_duplicates() const override;
117 
121  void reset_input() const override;
122 
123  protected:
127  explicit Concat(DeserializingStream& s) : MXNode(s) {}
128  };
129 
130 
137  class CASADI_EXPORT Horzcat : public Concat {
138  public:
139 
141  Horzcat(const std::vector<MX>& x);
142 
144  ~Horzcat() override {}
145 
149  std::string disp(const std::vector<std::string>& arg) const override;
150 
154  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
155 
159  void ad_forward(const std::vector<std::vector<MX> >& fseed,
160  std::vector<std::vector<MX> >& fsens) const override;
161 
165  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
166  std::vector<std::vector<MX> >& asens) const override;
167 
171  casadi_int op() const override { return OP_HORZCAT;}
172 
174  template<typename T>
175  void split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const;
176 
178 
181  void split_primitives(const MX& x, std::vector<MX>::iterator& it) const override;
182  void split_primitives(const SX& x, std::vector<SX>::iterator& it) const override;
183  void split_primitives(const DM& x, std::vector<DM>::iterator& it) const override;
185 
187  template<typename T>
188  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
189 
191 
194  MX join_primitives(std::vector<MX>::const_iterator& it) const override;
195  SX join_primitives(std::vector<SX>::const_iterator& it) const override;
196  DM join_primitives(std::vector<DM>::const_iterator& it) const override;
198 
202  std::vector<casadi_int> off() const;
203 
207  static MXNode* deserialize(DeserializingStream& s) { return new Horzcat(s); }
208  protected:
212  explicit Horzcat(DeserializingStream& s) : Concat(s) {}
213  };
214 
221  class CASADI_EXPORT Vertcat : public Concat {
222  public:
223 
225  Vertcat(const std::vector<MX>& x);
226 
228  ~Vertcat() override {}
229 
233  std::string disp(const std::vector<std::string>& arg) const override;
234 
238  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
239 
243  void ad_forward(const std::vector<std::vector<MX> >& fseed,
244  std::vector<std::vector<MX> >& fsens) const override;
245 
249  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
250  std::vector<std::vector<MX> >& asens) const override;
251 
255  casadi_int op() const override { return OP_VERTCAT;}
256 
258  template<typename T>
259  void split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const;
260 
262 
265  void split_primitives(const MX& x, std::vector<MX>::iterator& it) const override;
266  void split_primitives(const SX& x, std::vector<SX>::iterator& it) const override;
267  void split_primitives(const DM& x, std::vector<DM>::iterator& it) const override;
269 
271  template<typename T>
272  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
273 
275 
278  MX join_primitives(std::vector<MX>::const_iterator& it) const override;
279  SX join_primitives(std::vector<SX>::const_iterator& it) const override;
280  DM join_primitives(std::vector<DM>::const_iterator& it) const override;
282 
286  std::vector<casadi_int> off() const;
287 
291  static MXNode* deserialize(DeserializingStream& s) { return new Vertcat(s); }
292 
293  protected:
297  explicit Vertcat(DeserializingStream& s) : Concat(s) {}
298  };
299 
306  class CASADI_EXPORT Diagcat : public Concat {
307  public:
308 
310  Diagcat(const std::vector<MX>& x);
311 
313  ~Diagcat() override {}
314 
318  std::string disp(const std::vector<std::string>& arg) const override;
319 
323  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
324 
328  void ad_forward(const std::vector<std::vector<MX> >& fseed,
329  std::vector<std::vector<MX> >& fsens) const override;
330 
334  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
335  std::vector<std::vector<MX> >& asens) const override;
336 
340  casadi_int op() const override { return OP_DIAGCAT;}
341 
343  template<typename T>
344  void split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const;
345 
347 
350  void split_primitives(const MX& x, std::vector<MX>::iterator& it) const override;
351  void split_primitives(const SX& x, std::vector<SX>::iterator& it) const override;
352  void split_primitives(const DM& x, std::vector<DM>::iterator& it) const override;
354 
356  template<typename T>
357  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
358 
360 
363  MX join_primitives(std::vector<MX>::const_iterator& it) const override;
364  SX join_primitives(std::vector<SX>::const_iterator& it) const override;
365  DM join_primitives(std::vector<DM>::const_iterator& it) const override;
367 
371  std::pair<std::vector<casadi_int>, std::vector<casadi_int> > off() const;
372 
376  static MXNode* deserialize(DeserializingStream& s) { return new Diagcat(s); }
377 
378  protected:
382  explicit Diagcat(DeserializingStream& s) : Concat(s) {}
383  };
384 
385 } // namespace casadi
387 
388 #endif // CASADI_CONCAT_HPP
Helper class for C code generation.
Concatenation: Join multiple expressions stacking the nonzeros.
Definition: concat.hpp:42
bool is_equal(const MXNode *node, casadi_int depth) const override
Check if two nodes are equivalent up to a given depth.
Definition: concat.hpp:94
void eval_linear(const std::vector< std::array< MX, 3 > > &arg, std::vector< std::array< MX, 3 > > &res) const override
Evaluate the MX node on a const/linear/nonlinear partition.
Definition: concat.hpp:64
Concat(DeserializingStream &s)
Deserializing constructor.
Definition: concat.hpp:127
Helper class for Serialization.
Diagonal concatenation of matrices.
Definition: concat.hpp:306
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: concat.hpp:376
Diagcat(DeserializingStream &s)
Deserializing constructor.
Definition: concat.hpp:382
~Diagcat() override
Destructor.
Definition: concat.hpp:313
casadi_int op() const override
Get the operation.
Definition: concat.hpp:340
Horizontal concatenation.
Definition: concat.hpp:137
casadi_int op() const override
Get the operation.
Definition: concat.hpp:171
~Horzcat() override
Destructor.
Definition: concat.hpp:144
Horzcat(DeserializingStream &s)
Deserializing constructor.
Definition: concat.hpp:212
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: concat.hpp:207
Node class for MX objects.
Definition: mx_node.hpp:51
MX - Matrix expression.
Definition: mx.hpp:92
Sparse matrix class. SX and DM are specializations.
Definition: matrix_decl.hpp:99
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
General sparsity class.
Definition: sparsity.hpp:106
Vertical concatenation of matrices.
Definition: concat.hpp:221
casadi_int op() const override
Get the operation.
Definition: concat.hpp:255
static MXNode * deserialize(DeserializingStream &s)
Deserialize without type information.
Definition: concat.hpp:291
~Vertcat() override
Destructor.
Definition: concat.hpp:228
Vertcat(DeserializingStream &s)
Deserializing constructor.
Definition: concat.hpp:297
The casadi namespace.
Definition: archiver.cpp:28
unsigned long long bvec_t
@ OP_DIAGCAT
Definition: calculus.hpp:130
@ OP_HORZCAT
Definition: calculus.hpp:124
@ OP_VERTCAT
Definition: calculus.hpp:127