26 #ifndef CASADI_CONCAT_HPP
27 #define CASADI_CONCAT_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT Concat :
public MXNode {
46 Concat(
const std::vector<MX>& x);
49 ~Concat()
override = 0;
53 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
56 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
59 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
64 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
69 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
74 void generate(CodeGenerator& g,
75 const std::vector<casadi_int>& arg,
76 const std::vector<casadi_int>& res)
const override;
79 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
84 bool is_equal(
const MXNode* node, casadi_int depth)
const override {
85 return sameOpAndDeps(node, depth);
91 bool is_valid_input()
const override;
96 casadi_int n_primitives()
const override;
101 void primitives(std::vector<MX>::iterator& it)
const override;
106 bool has_duplicates()
const override;
111 void reset_input()
const override;
117 explicit Concat(DeserializingStream& s) : MXNode(s) {}
127 class CASADI_EXPORT Horzcat :
public Concat {
131 Horzcat(
const std::vector<MX>& x);
134 ~Horzcat()
override {}
139 std::string disp(
const std::vector<std::string>& arg)
const override;
144 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
149 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
150 std::vector<std::vector<MX> >& fsens)
const override;
155 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
156 std::vector<std::vector<MX> >& asens)
const override;
161 casadi_int op()
const override {
return OP_HORZCAT;}
165 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
171 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
172 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
173 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
178 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
184 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
185 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
186 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
192 std::vector<casadi_int> off()
const;
197 static MXNode* deserialize(DeserializingStream& s) {
return new Horzcat(s); }
202 explicit Horzcat(DeserializingStream& s) : Concat(s) {}
211 class CASADI_EXPORT Vertcat :
public Concat {
215 Vertcat(
const std::vector<MX>& x);
218 ~Vertcat()
override {}
223 std::string disp(
const std::vector<std::string>& arg)
const override;
228 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
233 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
234 std::vector<std::vector<MX> >& fsens)
const override;
239 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
240 std::vector<std::vector<MX> >& asens)
const override;
245 casadi_int op()
const override {
return OP_VERTCAT;}
249 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
255 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
256 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
257 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
262 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
268 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
269 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
270 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
276 std::vector<casadi_int> off()
const;
281 static MXNode* deserialize(DeserializingStream& s) {
return new Vertcat(s); }
287 explicit Vertcat(DeserializingStream& s) : Concat(s) {}
296 class CASADI_EXPORT Diagcat :
public Concat {
300 Diagcat(
const std::vector<MX>& x);
303 ~Diagcat()
override {}
308 std::string disp(
const std::vector<std::string>& arg)
const override;
313 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
318 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
319 std::vector<std::vector<MX> >& fsens)
const override;
324 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
325 std::vector<std::vector<MX> >& asens)
const override;
330 casadi_int op()
const override {
return OP_DIAGCAT;}
334 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
340 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
341 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
342 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
347 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
353 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
354 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
355 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
361 std::pair<std::vector<casadi_int>, std::vector<casadi_int> > off()
const;
366 static MXNode* deserialize(DeserializingStream& s) {
return new Diagcat(s); }
372 explicit Diagcat(DeserializingStream& s) : Concat(s) {}