26 #ifndef CASADI_CONCAT_HPP
27 #define CASADI_CONCAT_HPP
29 #include "mx_node.hpp"
42 class CASADI_EXPORT Concat :
public MXNode {
46 Concat(
const std::vector<MX>& x);
49 ~Concat()
override = 0;
53 int eval_gen(
const T*
const* arg, T*
const* res, casadi_int* iw, T* w)
const;
56 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
59 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
64 void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
65 std::vector<std::array<MX, 3> >& res)
const override {
66 eval_linear_rearrange(arg, res);
72 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
77 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
82 void generate(CodeGenerator& g,
83 const std::vector<casadi_int>& arg,
84 const std::vector<casadi_int>& res,
85 const std::vector<bool>& arg_is_ref,
86 std::vector<bool>& res_is_ref)
const override;
89 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
94 bool is_equal(
const MXNode* node, casadi_int depth)
const override {
95 return sameOpAndDeps(node, depth);
101 bool is_valid_input()
const override;
106 casadi_int n_primitives()
const override;
111 void primitives(std::vector<MX>::iterator& it)
const override;
116 bool has_duplicates()
const override;
121 void reset_input()
const override;
127 explicit Concat(DeserializingStream& s) : MXNode(s) {}
137 class CASADI_EXPORT Horzcat :
public Concat {
141 Horzcat(
const std::vector<MX>& x);
144 ~Horzcat()
override {}
149 std::string disp(
const std::vector<std::string>& arg)
const override;
154 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
159 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
160 std::vector<std::vector<MX> >& fsens)
const override;
165 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
166 std::vector<std::vector<MX> >& asens)
const override;
171 casadi_int op()
const override {
return OP_HORZCAT;}
175 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
181 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
182 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
183 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
188 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
194 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
195 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
196 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
202 std::vector<casadi_int> off()
const;
207 static MXNode* deserialize(DeserializingStream& s) {
return new Horzcat(s); }
212 explicit Horzcat(DeserializingStream& s) : Concat(s) {}
221 class CASADI_EXPORT Vertcat :
public Concat {
225 Vertcat(
const std::vector<MX>& x);
228 ~Vertcat()
override {}
233 std::string disp(
const std::vector<std::string>& arg)
const override;
238 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
243 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
244 std::vector<std::vector<MX> >& fsens)
const override;
249 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
250 std::vector<std::vector<MX> >& asens)
const override;
255 casadi_int op()
const override {
return OP_VERTCAT;}
259 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
265 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
266 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
267 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
272 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
278 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
279 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
280 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
286 std::vector<casadi_int> off()
const;
291 static MXNode* deserialize(DeserializingStream& s) {
return new Vertcat(s); }
297 explicit Vertcat(DeserializingStream& s) : Concat(s) {}
306 class CASADI_EXPORT Diagcat :
public Concat {
310 Diagcat(
const std::vector<MX>& x);
313 ~Diagcat()
override {}
318 std::string disp(
const std::vector<std::string>& arg)
const override;
323 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
328 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
329 std::vector<std::vector<MX> >& fsens)
const override;
334 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
335 std::vector<std::vector<MX> >& asens)
const override;
340 casadi_int op()
const override {
return OP_DIAGCAT;}
344 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
350 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
351 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
352 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
357 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
363 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
364 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
365 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
371 std::pair<std::vector<casadi_int>, std::vector<casadi_int> > off()
const;
376 static MXNode* deserialize(DeserializingStream& s) {
return new Diagcat(s); }
382 explicit Diagcat(DeserializingStream& s) : Concat(s) {}