26 #ifndef CASADI_MX_NODE_HPP
27 #define CASADI_MX_NODE_HPP
30 #include "shared_object.hpp"
31 #include "sx_elem.hpp"
32 #include "calculus.hpp"
33 #include "code_generator.hpp"
41 class SerializingStream;
42 class DeserializingStream;
66 virtual bool __nonzero__()
const;
71 virtual bool is_zero()
const {
return false;}
76 virtual bool is_one()
const {
return false;}
86 virtual bool is_value(
double val)
const {
return false;}
91 virtual bool is_eye()
const {
return false;}
106 void can_inline(std::map<const MXNode*, casadi_int>& nodeind)
const;
111 std::string print_compact(std::map<const MXNode*, casadi_int>& nodeind,
112 std::vector<std::string>& intermed)
const;
117 virtual std::string
disp(
const std::vector<std::string>& arg)
const = 0;
143 const std::vector<casadi_int>& arg,
144 const std::vector<casadi_int>& res,
145 const std::vector<bool>& arg_is_ref,
146 std::vector<bool>& res_is_ref)
const;
149 const std::vector<casadi_int>& arg,
150 const std::vector<casadi_int>& res,
151 const std::vector<bool>& arg_is_ref,
152 std::vector<bool>& res_is_ref,
158 virtual int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const;
163 virtual int eval_sx(
const SXElem** arg,
SXElem** res, casadi_int* iw,
SXElem* w)
const;
168 virtual void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const;
173 virtual void eval_linear(
const std::vector<std::array<MX, 3> >& arg,
174 std::vector<std::array<MX, 3> >& res)
const;
180 std::vector<std::array<MX, 3> >& res)
const;
188 void eval_linear_rearrange(
const std::vector<std::array<MX, 3> >& arg,
189 std::vector<std::array<MX, 3> >& res)
const;
194 virtual void ad_forward(
const std::vector<std::vector<MX> >& fseed,
195 std::vector<std::vector<MX> >& fsens)
const;
200 virtual void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
201 std::vector<std::vector<MX> >& asens)
const;
206 virtual int sp_forward(
const bvec_t** arg,
bvec_t** res, casadi_int* iw,
bvec_t* w)
const;
216 virtual const std::string& name()
const;
221 std::string class_name()
const override;
226 void disp(std::ostream& stream,
bool more)
const override;
236 virtual casadi_int n_primitives()
const;
241 virtual void primitives(std::vector<MX>::iterator& it)
const;
247 virtual void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const;
248 virtual void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const;
249 virtual void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const;
254 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
260 virtual MX join_primitives(std::vector<MX>::const_iterator& it)
const;
261 virtual SX join_primitives(std::vector<SX>::const_iterator& it)
const;
262 virtual DM join_primitives(std::vector<DM>::const_iterator& it)
const;
270 virtual bool has_duplicates()
const;
277 virtual void reset_input()
const;
292 virtual casadi_int which_output()
const;
297 virtual const Function& which_function()
const;
302 virtual casadi_int
op()
const = 0;
305 virtual Dict info()
const;
337 virtual bool is_equal(
const MXNode* node, casadi_int depth)
const {
return false;}
349 bool sameOpAndDeps(
const MXNode* node, casadi_int depth)
const;
354 const MX&
dep(casadi_int ind=0)
const {
return dep_.at(ind);}
359 casadi_int n_dep()
const;
364 virtual casadi_int
nout()
const {
return 1;}
369 virtual MX get_output(casadi_int oind)
const;
375 virtual const Sparsity& sparsity(casadi_int oind)
const;
379 for (casadi_int i=0;i<dep_.size();++i) {
380 if (dep_[i].sparsity()!=arg[i].sparsity()) {
388 casadi_int
numel()
const {
return sparsity().numel(); }
389 casadi_int
nnz(casadi_int i=0)
const {
return sparsity(i).nnz(); }
390 casadi_int
size1()
const {
return sparsity().size1(); }
391 casadi_int
size2()
const {
return sparsity().size2(); }
392 std::pair<casadi_int, casadi_int>
size()
const {
return sparsity().size();}
395 virtual casadi_int ind()
const;
398 virtual casadi_int segment()
const;
401 virtual casadi_int offset()
const;
404 void set_sparsity(
const Sparsity& sparsity);
409 virtual size_t sz_arg()
const {
return n_dep();}
414 virtual size_t sz_res()
const {
return nout();}
419 virtual size_t sz_iw()
const {
return 0;}
424 virtual size_t sz_w()
const {
return 0;}
427 void set_dep(
const MX& dep);
430 void set_dep(
const MX& dep1,
const MX& dep2);
433 void set_dep(
const MX& dep1,
const MX& dep2,
const MX& dep3);
436 void set_dep(
const std::vector<MX>& dep);
439 void check_dep()
const;
451 virtual double to_double()
const;
454 virtual DM get_DM()
const;
463 virtual MX get_horzcat(
const std::vector<MX>& x)
const;
466 virtual std::vector<MX> get_horzsplit(
const std::vector<casadi_int>& output_offset)
const;
469 virtual MX get_repmat(casadi_int m, casadi_int n)
const;
472 virtual MX get_repsum(casadi_int m, casadi_int n)
const;
475 virtual MX get_vertcat(
const std::vector<MX>& x)
const;
478 virtual std::vector<MX> get_vertsplit(
const std::vector<casadi_int>& output_offset)
const;
481 virtual MX get_diagcat(
const std::vector<MX>& x)
const;
484 virtual std::vector<MX> get_diagsplit(
const std::vector<casadi_int>& offset1,
485 const std::vector<casadi_int>& offset2)
const;
488 virtual MX get_transpose()
const;
491 virtual MX get_reshape(
const Sparsity& sp)
const;
494 virtual MX get_sparsity_cast(
const Sparsity& sp)
const;
499 virtual MX get_mac(
const MX& y,
const MX& z)
const;
504 virtual MX get_einstein(
const MX& A,
const MX& B,
505 const std::vector<casadi_int>& dim_c,
const std::vector<casadi_int>& dim_a,
506 const std::vector<casadi_int>& dim_b,
507 const std::vector<casadi_int>& c,
const std::vector<casadi_int>& a,
508 const std::vector<casadi_int>& b)
const;
513 virtual MX get_bilin(
const MX& x,
const MX& y)
const;
518 virtual MX get_rank1(
const MX& alpha,
const MX& x,
const MX& y)
const;
523 virtual MX get_logsumexp()
const;
532 virtual MX get_solve(
const MX& r,
bool tr,
const Linsol& linear_solver)
const;
541 virtual MX get_solve_triu(
const MX& r,
bool tr)
const;
550 virtual MX get_solve_tril(
const MX& r,
bool tr)
const;
559 virtual MX get_solve_triu_unity(
const MX& r,
bool tr)
const;
568 virtual MX get_solve_tril_unity(
const MX& r,
bool tr)
const;
577 virtual MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const;
582 virtual MX get_nz_ref(
const MX& nz)
const;
587 virtual MX get_nz_ref(
const MX& inner,
const Slice& outer)
const;
592 virtual MX get_nz_ref(
const Slice& inner,
const MX& outer)
const;
597 virtual MX get_nz_ref(
const MX& inner,
const MX& outer)
const;
605 virtual MX get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const;
613 virtual MX get_nzadd(
const MX& y,
const std::vector<casadi_int>& nz)
const;
621 virtual MX get_nzassign(
const MX& y,
const MX& nz)
const;
629 virtual MX get_nzassign(
const MX& y,
const MX& inner,
const Slice& outer)
const;
637 virtual MX get_nzassign(
const MX& y,
const Slice& inner,
const MX& outer)
const;
645 virtual MX get_nzassign(
const MX& y,
const MX& inner,
const MX& outer)
const;
653 virtual MX get_nzadd(
const MX& y,
const MX& nz)
const;
661 virtual MX get_nzadd(
const MX& y,
const MX& inner,
const Slice& outer)
const;
669 virtual MX get_nzadd(
const MX& y,
const Slice& inner,
const MX& outer)
const;
677 virtual MX get_nzadd(
const MX& y,
const MX& inner,
const MX& outer)
const;
680 virtual MX get_subref(
const Slice& i,
const Slice& j)
const;
683 virtual MX get_subassign(
const MX& y,
const Slice& i,
const Slice& j)
const;
686 virtual MX get_project(
const Sparsity& sp)
const;
689 virtual MX get_unary(casadi_int op)
const;
692 MX get_binary(casadi_int op,
const MX& y)
const;
695 virtual MX _get_binary(casadi_int op,
const MX& y,
bool scX,
bool scY)
const;
698 virtual MX get_det()
const;
701 virtual MX get_inv()
const;
704 virtual MX get_dot(
const MX& y)
const;
707 virtual MX get_norm_fro()
const;
710 virtual MX get_norm_2()
const;
713 virtual MX get_norm_inf()
const;
716 virtual MX get_norm_1()
const;
719 virtual MX get_mmin()
const;
722 virtual MX get_mmax()
const;
725 MX get_assert(
const MX& y,
const std::string& fail_message)
const;
728 MX get_monitor(
const std::string& comment)
const;
734 MX get_low(
const MX& v,
const Dict& options)
const;
737 MX get_bspline(
const std::vector<double>& knots,
738 const std::vector<casadi_int>& offset,
739 const std::vector<double>& coeffs,
740 const std::vector<casadi_int>& degree,
742 const std::vector<casadi_int>& lookup_mode)
const;
744 MX get_bspline(
const MX& coeffs,
const std::vector<double>& knots,
745 const std::vector<casadi_int>& offset,
746 const std::vector<casadi_int>& degree,
748 const std::vector<casadi_int>& lookup_mode)
const;
751 MX get_convexify(
const Dict& opts)
const;
772 static void copy_fwd(
const bvec_t* arg,
bvec_t* res, casadi_int len);
777 static void copy_rev(
bvec_t* arg,
bvec_t* res, casadi_int len);
Helper class for C code generation.
Helper class for Serialization.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
Node class for MX objects.
virtual bool has_output() const
Check if a multiple output node.
void eval_linear_unary(const std::vector< std::array< MX, 3 > > &arg, std::vector< std::array< MX, 3 > > &res) const
Evaluate the MX node on a const/linear/nonlinear partition.
virtual bool is_zero() const
Check if identically zero.
virtual size_t sz_arg() const
Get required length of arg field.
virtual bool is_valid_input() const
Check if valid function input.
static bool maxDepth()
Get equality checking depth.
virtual size_t sz_w() const
Get required length of w field.
virtual bool is_one() const
Check if identically one.
virtual bool is_binary() const
Check if binary operation.
virtual void add_dependency(CodeGenerator &g) const
Add a dependent function.
virtual casadi_int n_inplace() const
Can the operation be performed inplace (i.e. overwrite the result)
std::pair< casadi_int, casadi_int > size() const
Sparsity sparsity_
The sparsity pattern.
casadi_int numel() const
Get shape.
static std::map< casadi_int, MXNode *(*)(DeserializingStream &)> deserialize_map
const Sparsity & sparsity() const
Get the sparsity.
virtual size_t sz_res() const
Get required length of res field.
casadi_int nnz(casadi_int i=0) const
bool matches_sparsity(const std::vector< T > &arg) const
virtual void codegen_incref(CodeGenerator &g, std::set< void * > &added) const
Codegen incref.
virtual casadi_int nout() const
Number of outputs.
virtual bool is_value(double val) const
Check if a certain value.
const MX & dep(casadi_int ind=0) const
dependencies - functions that have to be evaluated before this one
std::vector< MX > dep_
dependencies - functions that have to be evaluated before this one
virtual bool is_unary() const
Check if unary operation.
virtual void codegen_decref(CodeGenerator &g, std::set< void * > &added) const
Codegen decref.
virtual casadi_int op() const =0
Get the operation.
virtual bool has_refcount() const
Is reference counting needed in codegen?
virtual bool is_minus_one() const
Check if identically minus one.
virtual std::string disp(const std::vector< std::string > &arg) const =0
Print expression.
virtual bool is_output() const
Check if evaluation output.
virtual bool is_equal(const MXNode *node, casadi_int depth) const
virtual size_t sz_iw() const
Get required length of iw field.
virtual bool is_eye() const
Check if identity matrix.
static MX to_matrix(const MX &x, const Sparsity &sp)
Convert scalar to matrix.
static casadi_int get_max_depth()
Get the depth to which equalities are being checked for simplifications.
Sparse matrix class. SX and DM are specializations.
The basic scalar symbolic class of CasADi.
Helper class for Serialization.
Class representing a Slice.
std::pair< casadi_int, casadi_int > size() const
Get the shape.
bool is_equal(double x, double y, casadi_int depth=0)
unsigned long long bvec_t
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.