26 #ifndef CASADI_CONSTANT_MX_HPP
27 #define CASADI_CONSTANT_MX_HPP
29 #include "mx_node.hpp"
32 #include "serializing_stream.hpp"
// ConstantMX: abstract base for MX nodes whose value is fixed at construction.
// Subclasses visible in this header: ConstantDM, ConstantFile, ConstantPool,
// ZeroByZero and the Constant<Value> template.
48 class CASADI_EXPORT ConstantMX :
public MXNode {
// Stores only the sparsity pattern; subclasses hold the actual values.
51 explicit ConstantMX(
const Sparsity& sp);
// Pure virtual destructor: ConstantMX itself cannot be instantiated.
54 ~ConstantMX()
override = 0;
// Factory overloads. Only the int forwarder is defined in this header.
57 static ConstantMX* create(
const Sparsity& sp, casadi_int val);
// Widens an int value and forwards to the casadi_int overload.
58 static ConstantMX* create(
const Sparsity& sp,
int val) {
59 return create(sp,
static_cast<casadi_int
>(val));
63 static ConstantMX* create(
const Sparsity& sp,
double val);
66 static ConstantMX* create(
const Matrix<double>& val);
// Takes (sparsity, file name), the same arguments as the ConstantFile
// constructor; presumably builds one (definition not visible here).
69 static ConstantMX* create(
const Sparsity& sp,
const std::string& fname);
// Takes (DM, name), the same arguments as the ConstantPool constructor;
// presumably builds one (definition not visible here).
72 static ConstantMX* create(
const Matrix<double>& val,
const std::string& name);
// Numeric and SXElem evaluation: each concrete subclass must implement these.
75 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override = 0;
78 int eval_sx(
const SXElem** arg, SXElem** res,
79 casadi_int* iw, SXElem* w)
const override = 0;
// The overrides below are declared here and implemented outside this header.
84 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
89 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
90 std::vector<std::vector<MX> >& fsens)
const override;
95 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
96 std::vector<std::vector<MX> >& asens)
const override;
101 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
106 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
// Every ConstantMX subclass reports the OP_CONST operation code.
111 casadi_int op()
const override {
return OP_CONST;}
// Value accessors; each subclass must implement these.
114 double to_double()
const override = 0;
117 Matrix<double> get_DM()
const override = 0;
123 MX get_dot(
const MX& y)
const override;
126 bool __nonzero__()
const override;
131 bool is_valid_input()
const override;
136 casadi_int n_primitives()
const override;
141 void primitives(std::vector<MX>::iterator& it)
const override;
// Generic helper behind the split_primitives overloads. Its template
// parameter list is not visible in this excerpt.
145 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
151 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
152 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
153 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
// Generic helper behind the join_primitives overloads.
158 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
164 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
165 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
166 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
// Always reports no duplicates.
172 bool has_duplicates()
const override {
return false;}
// Intentionally a no-op for constants.
177 void reset_input()
const override {}
// Deserialization entry point and deserializing constructor for subclasses.
182 static MXNode* deserialize(DeserializingStream& s);
187 explicit ConstantMX(DeserializingStream& s) : MXNode(s) {}
// ConstantDM: a constant whose value is an explicit Matrix<double> stored in x_.
191 class CASADI_EXPORT ConstantDM :
public ConstantMX {
// Takes the sparsity from x and keeps a copy of x in x_.
197 explicit ConstantDM(
const Matrix<double>& x) : ConstantMX(x.sparsity()), x_(x) {}
200 ~ConstantDM()
override {}
205 std::string disp(
const std::vector<std::string>& arg)
const override {
// Copies the stored nonzeros (iterated through x_->) into the output buffer res[0].
212 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
213 std::copy(x_->begin(), x_->end(), res[0]);
// Same copy for SXElem outputs; each double is converted to an SXElem.
220 int eval_sx(
const SXElem** arg, SXElem** res,
221 casadi_int* iw, SXElem* w)
const override {
222 std::copy(x_->begin(), x_->end(), res[0]);
229 void generate(CodeGenerator& g,
230 const std::vector<casadi_int>& arg,
231 const std::vector<casadi_int>& res,
232 const std::vector<bool>& arg_is_ref,
233 std::vector<bool>& res_is_ref)
const override;
// Value predicates; implemented outside this header.
239 bool is_one()
const override;
240 bool is_minus_one()
const override;
241 bool is_eye()
const override;
// Delegates to Matrix<double>::scalar() on the stored value.
244 double to_double()
const override {
return x_.scalar();}
// Returns a copy of the stored matrix.
247 Matrix<double> get_DM()
const override {
return x_;}
252 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
262 void serialize_body(SerializingStream& s)
const override;
266 void serialize_type(SerializingStream& s)
const override;
271 explicit ConstantDM(DeserializingStream& s);
// ConstantFile: a constant built from a sparsity pattern and a file name.
// Its nonzero values are held in x_ (a std::vector<double>).
275 class CASADI_EXPORT ConstantFile :
public ConstantMX {
281 explicit ConstantFile(
const Sparsity& x,
const std::string& fname);
284 ~ConstantFile()
override {}
// Code-generation reference handling; implemented outside this header.
289 void codegen_incref(CodeGenerator& g, std::set<void*>& added)
const override;
294 std::string disp(
const std::vector<std::string>& arg)
const override;
297 double to_double()
const override;
300 Matrix<double> get_DM()
const override;
// Copies x_ into res[0].
// NOTE(review): unlike ConstantPool::eval, there is no null check on res[0]
// here. Confirm callers always pass a non-null output.
305 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
306 std::copy(x_.begin(), x_.end(), res[0]);
// Same copy for SXElem outputs; each double is converted to an SXElem.
313 int eval_sx(
const SXElem** arg, SXElem** res,
314 casadi_int* iw, SXElem* w)
const override {
315 std::copy(x_.begin(), x_.end(), res[0]);
322 void generate(CodeGenerator& g,
323 const std::vector<casadi_int>& arg,
324 const std::vector<casadi_int>& res,
325 const std::vector<bool>& arg_is_ref,
326 std::vector<bool>& res_is_ref)
const override;
331 void add_dependency(CodeGenerator& g)
const override;
// Nonzero values copied by eval / eval_sx.
341 std::vector<double> x_;
346 void serialize_body(SerializingStream& s)
const override;
350 void serialize_type(SerializingStream& s)
const override;
355 explicit ConstantFile(DeserializingStream& s);
// ConstantPool: a named constant built from a DM. Its nonzero values are held
// in x_ (a std::vector<double>).
359 class CASADI_EXPORT ConstantPool :
public ConstantMX {
365 explicit ConstantPool(
const DM& x,
const std::string& name);
368 ~ConstantPool()
override {}
373 std::string disp(
const std::vector<std::string>& arg)
const override;
376 double to_double()
const override;
379 Matrix<double> get_DM()
const override;
// Copies x_ into res[0], skipping the copy when no output buffer is given.
384 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
385 if (res[0]) std::copy(x_.begin(), x_.end(), res[0]);
// SXElem evaluation is rejected with an error for pooled constants.
392 int eval_sx(
const SXElem** arg, SXElem** res,
393 casadi_int* iw, SXElem* w)
const override {
394 casadi_error(
"eval_sx not supported");
401 void generate(CodeGenerator& g,
402 const std::vector<casadi_int>& arg,
403 const std::vector<casadi_int>& res,
404 const std::vector<bool>& arg_is_ref,
405 std::vector<bool>& res_is_ref)
const override;
410 void add_dependency(CodeGenerator& g)
const override;
// Nonzero values copied by eval.
420 std::vector<double> x_;
425 void serialize_body(SerializingStream& s)
const override;
430 void serialize_type(SerializingStream& s)
const override;
435 explicit ConstantPool(DeserializingStream& s);
// ZeroByZero: the constant with a 0x0 sparsity pattern, exposed as a
// function-local static singleton through getInstance().
439 class CASADI_EXPORT ZeroByZero :
public ConstantMX {
// Always uses Sparsity(0, 0).
444 explicit ZeroByZero() : ConstantMX(Sparsity(0, 0)) {
// The single instance is a function-local static.
452 static ZeroByZero* getInstance() {
453 static ZeroByZero instance;
458 ~ZeroByZero()
override {
465 std::string disp(
const std::vector<std::string>& arg)
const override;
// Evaluation bodies are not visible in this excerpt.
471 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
476 int eval_sx(
const SXElem** arg, SXElem** res,
477 casadi_int* iw, SXElem* w)
const override {
// Generates no code.
484 void generate(CodeGenerator& g,
485 const std::vector<casadi_int>& arg,
486 const std::vector<casadi_int>& res,
487 const std::vector<bool>& arg_is_ref,
488 std::vector<bool>& res_is_ref)
const override {}
// Scalar value is reported as 0; the DM value is a default-constructed DM.
491 double to_double()
const override {
return 0;}
494 DM get_DM()
const override {
return DM(); }
// Structural operation overrides; implemented outside this header.
497 MX get_project(
const Sparsity& sp)
const override;
500 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
503 MX get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const override;
506 MX get_transpose()
const override;
509 MX get_unary(casadi_int op)
const override;
512 MX _get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const override;
515 MX get_reshape(
const Sparsity& sp)
const override;
// Overrides the base-class declaration to always accept this node as an input.
520 bool is_valid_input()
const override {
return true;}
// Uses a function-local static string; presumably returned by reference
// (the return line is not visible here).
525 const std::string& name()
const override {
526 static std::string dummyname;
533 void serialize_type(SerializingStream& s)
const override;
537 void serialize_body(SerializingStream& s)
const override;
// RuntimeConst: value wrapper for Constant<Value> whose constant is stored in
// the object (the `value` member) and therefore serialized with it.
545 struct RuntimeConst {
548 RuntimeConst(T v) : value(v) {}
// One-character serialization tag; specializations follow below.
549 static char type_char();
// Writes the stored value under the key "Constant::value".
550 void serialize_type(SerializingStream& s)
const {
551 s.pack(
"Constant::value", value);
// Reads the value back from the same key and wraps it.
553 static RuntimeConst deserialize(DeserializingStream& s) {
555 s.unpack(
"Constant::value", v);
556 return RuntimeConst(v);
// Tags: 'u' for the generic case, 'I' for casadi_int, 'D' for double.
561 inline char RuntimeConst<T>::type_char() {
return 'u'; }
564 inline char RuntimeConst<casadi_int>::type_char() {
return 'I'; }
567 inline char RuntimeConst<double>::type_char() {
return 'D'; }
// CompiletimeConst: value wrapper whose constant is the template argument v,
// so nothing beyond the type tag needs to be serialized.
570 struct CompiletimeConst {
571 static const int value = v;
572 static char type_char();
// Writes nothing; the value is implied by the type tag.
573 void serialize_type(SerializingStream& s)
const {}
574 static CompiletimeConst deserialize(DeserializingStream& s) {
575 return CompiletimeConst();
// Tags: 'u' for the generic case, '0' for 0, 'm' for -1, '1' for 1.
580 inline char CompiletimeConst<v>::type_char() {
return 'u'; }
583 inline char CompiletimeConst<0>::type_char() {
return '0'; }
585 inline char CompiletimeConst<(-1)>::type_char() {
return 'm'; }
587 inline char CompiletimeConst<1>::type_char() {
return '1'; }
// Constant<Value>: a constant in which every structural nonzero equals the
// single value v_.value. Value is a wrapper such as RuntimeConst or
// CompiletimeConst.
590 template<
typename Value>
591 class CASADI_EXPORT Constant :
public ConstantMX {
597 explicit Constant(
const Sparsity& sp, Value v = Value()) : ConstantMX(sp), v_(v) {}
602 explicit Constant(DeserializingStream& s,
const Value& v);
605 ~Constant()
override {}
610 std::string disp(
const std::vector<std::string>& arg)
const override;
// Out-of-class definitions appear further down in this header.
616 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
619 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
624 void generate(CodeGenerator& g,
625 const std::vector<casadi_int>& arg,
626 const std::vector<casadi_int>& res,
627 const std::vector<bool>& arg_is_ref,
628 std::vector<bool>& res_is_ref)
const override;
// Value predicates compare the single stored value.
// is_eye additionally requires a diagonal sparsity pattern.
633 bool is_zero()
const override {
return v_.value==0;}
634 bool is_one()
const override {
return v_.value==1;}
635 bool is_eye()
const override {
return v_.value==1 && sparsity().is_diag();}
636 bool is_value(
double val)
const override {
return v_.value==val;}
639 double to_double()
const override {
640 return static_cast<double>(v_.value);
// Materializes the value as a Matrix<double> with this node's sparsity,
// every nonzero set to to_double().
644 Matrix<double> get_DM()
const override {
645 return Matrix<double>(sparsity(), to_double(),
false);
// Structural and arithmetic overrides; defined below in this header.
649 MX get_project(
const Sparsity& sp)
const override;
652 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
655 MX get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const override;
658 MX get_transpose()
const override;
661 MX get_unary(casadi_int op)
const override;
664 MX _get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const override;
667 MX get_reshape(
const Sparsity& sp)
const override;
670 MX get_horzcat(
const std::vector<MX>& x)
const override;
673 MX get_vertcat(
const std::vector<MX>& x)
const override;
678 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
683 void serialize_body(SerializingStream& s)
const override;
687 void serialize_type(SerializingStream& s)
const override;
// Writes the Value wrapper's type tag under "ConstantMX::type", then lets the
// wrapper write any payload of its own (RuntimeConst writes its value;
// CompiletimeConst writes nothing).
692 template<
typename Value>
693 void Constant<Value>::serialize_type(SerializingStream& s)
const {
695 s.pack(
"ConstantMX::type", Value::type_char());
696 v_.serialize_type(s);
// Body not visible in this excerpt.
699 template<
typename Value>
700 void Constant<Value>::serialize_body(SerializingStream& s)
const {
// Deserializing constructor: the base reads the common node data from s, and
// the already-deserialized value wrapper is stored in v_.
704 template<
typename Value>
705 Constant<Value>::Constant(DeserializingStream& s,
const Value& v) : ConstantMX(s), v_(v) {
// Horizontal concatenation. If any element does not equal this constant's
// value, defer to the generic ConstantMX implementation. Otherwise return a
// single constant over the horizontally concatenated sparsity patterns.
708 template<
typename Value>
709 MX Constant<Value>::get_horzcat(
const std::vector<MX>& x)
const {
712 if (!i->is_value(to_double())) {
714 return ConstantMX::get_horzcat(x);
719 std::vector<Sparsity> sp;
720 for (
auto&& i : x) sp.push_back(i.sparsity());
721 return MX(horzcat(sp), v_.value,
false);
// Vertical counterpart of get_horzcat: the same value check, with vertcat of
// the sparsity patterns.
724 template<
typename Value>
725 MX Constant<Value>::get_vertcat(
const std::vector<MX>& x)
const {
728 if (!i->is_value(to_double())) {
730 return ConstantMX::get_vertcat(x);
735 std::vector<Sparsity> sp;
736 for (
auto&& i : x) sp.push_back(i.sparsity());
737 return MX(vertcat(sp), v_.value,
false);
// Reshape: a new Constant with the requested sparsity and the same value.
740 template<
typename Value>
741 MX Constant<Value>::get_reshape(
const Sparsity& sp)
const {
742 return MX::create(
new Constant<Value>(sp, v_));
// Transpose: a new Constant with the transposed sparsity and the same value.
745 template<
typename Value>
746 MX Constant<Value>::get_transpose()
const {
747 return MX::create(
new Constant<Value>(sparsity().
T(), v_));
// Unary operation on a constant, folded at construction time.
// ret holds op applied to the constant's value. When F0XChecker holds for op,
// or the pattern is dense, the result is simply ret on this sparsity.
// Otherwise the structural zeros must be accounted for separately.
750 template<
typename Value>
751 MX Constant<Value>::get_unary(casadi_int op)
const {
754 casadi_math<double>::fun(op, to_double(), 0.0, ret);
755 if (operation_checker<F0XChecker>(op) || sparsity().is_dense()) {
756 return MX(sparsity(), ret);
// NOTE(review): if this branch is reached only when the test above failed,
// operation_checker<F0XChecker>(op) is false here and the condition can never
// hold. The intervening lines are not visible; confirm.
759 if (
is_zero() && operation_checker<F0XChecker>(op)) {
760 return MX(sparsity(), ret,
false);
762 return repmat(MX(ret), size1(), size2());
// General case: ret on the structural nonzeros plus op(0) (ret2) on the
// complementary pattern.
766 casadi_math<double>::fun(op, 0, 0.0, ret2);
767 return DM(sparsity(), ret,
false)
768 +
DM(sparsity().pattern_inverse(), ret2,
false);
// Binary operation with this constant as the left operand.
// Sparsity patterns must match unless ScX or ScY is set (presumably flags
// marking a scalar operand; confirm).
772 template<
typename Value>
773 MX Constant<Value>::_get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const {
774 casadi_assert_dev(sparsity()==y.sparsity() || ScX || ScY);
// ScX branch, taken when FX0Checker does not hold for op: evaluate
// op(x, 0), then (the guard in between is not visible) redo the operation
// with both operands projected onto a dense pattern the size of y.
776 if (ScX && !operation_checker<FX0Checker>(op)) {
778 casadi_math<double>::fun(op, nnz()> 0 ? to_double(): 0.0, 0, ret);
781 Sparsity f = Sparsity::dense(y.size1(), y.size2());
782 MX yy = project(y, f);
783 return MX(f, shared_from_this<MX>())->_get_binary(op, yy,
false,
false);
785 }
else if (ScY && !operation_checker<F0XChecker>(op)) {
// Mirror case for ScY: evaluate op(0, y) using y's value when y is a
// constant node other than ConstantDM; the dense fallback follows.
787 if (y->op()==OP_CONST &&
dynamic_cast<const ConstantDM*
>(y.get())==
nullptr) {
789 casadi_math<double>::fun(op, 0, y.nnz()>0 ? y->to_double() : 0, ret);
793 Sparsity f = Sparsity::dense(size1(), size2());
794 MX xx = project(shared_from_this<MX>(), f);
795 return xx->_get_binary(op, MX(f, y),
false,
false);
// Value-based simplifications. The case labels are not visible here; judging
// by the returned expressions they correspond to 0+y, 0-y, c*y (c = 1, -1, 2),
// c/y (c = 1, -1) and c^y (c = 1, e).
801 if (v_.value==0)
return ScY && !y->is_zero() ? repmat(y, size1(), size2()) : y;
804 if (v_.value==0)
return ScY && !y->is_zero() ? repmat(-y, size1(), size2()) : -y;
807 if (v_.value==1)
return y;
808 if (v_.value==-1)
return -y;
809 if (v_.value==2)
return y->get_unary(OP_TWICE);
812 if (v_.value==1)
return y->get_unary(OP_INV);
813 if (v_.value==-1)
return -y->get_unary(OP_INV);
817 if (v_.value==1)
return MX::ones(y.sparsity());
818 if (v_.value==std::exp(1.0))
return y->get_unary(OP_EXP);
// Both operands are constant (y not a ConstantDM): fold to one value on y's
// sparsity.
// NOTE(review): `nnz()> 0.0` compares an integer count with a double. It is
// equivalent to `nnz()> 0` (used in the ScX branch above), but inconsistent.
825 if (y->op()==OP_CONST &&
dynamic_cast<const ConstantDM*
>(y.get())==
nullptr) {
826 double y_value = y.nnz()>0 ? y->to_double() : 0;
828 casadi_math<double>::fun(op, nnz()> 0.0 ? to_double(): 0, y_value, ret);
830 return MX(y.sparsity(), ret,
false);
// No simplification applied: use the generic MXNode implementation.
834 return MXNode::_get_binary(op, y, ScX, ScY);
// Numeric evaluation: fills all nnz() output entries with the constant's value.
837 template<
typename Value>
838 int Constant<Value>::eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
839 std::fill(res[0], res[0]+nnz(), to_double());
// SXElem evaluation: fills all nnz() outputs with SXElem(v_.value).
843 template<
typename Value>
844 int Constant<Value>::
845 eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const {
846 std::fill(res[0], res[0]+nnz(), SXElem(v_.value));
// C code generation. The first branch is not visible in this excerpt.
// A single nonzero gets a direct assignment. Several nonzeros are cleared
// when the value is 0, and filled with the constant otherwise.
850 template<
typename Value>
851 void Constant<Value>::generate(CodeGenerator& g,
852 const std::vector<casadi_int>& arg,
853 const std::vector<casadi_int>& res,
854 const std::vector<bool>& arg_is_ref,
855 std::vector<bool>& res_is_ref)
const {
858 }
else if (nnz()==1) {
859 g << g.workel(res[0]) <<
" = " << g.constant(to_double()) <<
";\n";
861 if (to_double()==0) {
862 g << g.clear(g.work(res[0], nnz(),
false), nnz()) <<
'\n';
864 g << g.fill(g.work(res[0], nnz(),
false), nnz(), g.constant(to_double())) <<
'\n';
// Nonzero reference. The loop over nz (body not visible here) may defer to
// the generic MXNode::get_nzref. Otherwise the result is a new Constant with
// sparsity sp and the same value.
869 template<
typename Value>
870 MX Constant<Value>::get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const {
873 for (std::vector<casadi_int>::const_iterator k=nz.begin(); k!=nz.end(); ++k) {
876 return MXNode::get_nzref(sp, nz);
880 return MX::create(
new Constant<Value>(sp, v_));
// Nonzero assignment. Assigning a constant zero into a zero-valued constant
// is special-cased (body not visible). Everything else goes to the generic
// MXNode implementation.
883 template<
typename Value>
884 MX Constant<Value>::get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const {
885 if (y.is_constant() && y->is_zero() && v_.value==0) {
890 return MXNode::get_nzassign(y, nz);
// Projection onto sp. The first branch (condition not visible) keeps a
// uniform Constant on sp. A dense target densifies get_DM(). Anything else
// uses the generic MXNode projection.
893 template<
typename Value>
894 MX Constant<Value>::get_project(
const Sparsity& sp)
const {
896 return MX::create(
new Constant<Value>(sp, v_));
897 }
else if (sp.is_dense()) {
898 return densify(get_DM());
900 return MXNode::get_project(sp);
// Text representation, built in a stringstream. Scalars and empty patterns
// are handled first, then special values. The `v_.value!=v_.value` test is
// true only for NaN. The last visible lines print "all_<value>(" for the
// general case.
904 template<
typename Value>
906 Constant<Value>::disp(
const std::vector<std::string>& arg)
const {
907 std::stringstream ss;
908 if (sparsity().is_scalar()) {
910 if (sparsity().nnz()==0) {
915 }
else if (sparsity().is_empty()) {
922 }
else if (v_.value==1) {
924 }
else if (v_.value!=v_.value) {
926 }
else if (v_.value==std::numeric_limits<double>::infinity()) {
928 }
else if (v_.value==-std::numeric_limits<double>::infinity()) {
931 ss <<
"all_" << v_.value <<
"(";
// Two nodes are equal when the other node has the same value (via is_value)
// and the same sparsity pattern. The depth argument is unused.
941 template<
typename Value>
942 bool Constant<Value>::is_equal(
const MXNode* node, casadi_int depth)
const {
943 return node->is_value(to_double()) && sparsity()==node->sparsity();
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.