26 #ifndef CASADI_CONSTANT_MX_HPP
27 #define CASADI_CONSTANT_MX_HPP
29 #include "mx_node.hpp"
32 #include "serializing_stream.hpp"
/// \brief Abstract base class for MX expression nodes that represent a constant.
///
/// Derives from MXNode. The concrete subclasses declared alongside it are
/// ConstantDM, ConstantFile, ZeroByZero and the Constant<Value> template.
48 class CASADI_EXPORT ConstantMX :
public MXNode {
/// \brief Construct a constant node with sparsity pattern \a sp.
51 explicit ConstantMX(
const Sparsity& sp);
/// \brief Destructor; declared pure virtual, so ConstantMX cannot be instantiated directly.
54 ~ConstantMX()
override = 0;
/// \brief Factory: constant with sparsity \a sp and integer value \a val.
57 static ConstantMX* create(
const Sparsity& sp, casadi_int val);
/// \brief Factory for a plain \c int value; forwards to the casadi_int overload
/// after a static_cast, so both integer types take the same path.
58 static ConstantMX* create(
const Sparsity& sp,
int val) {
59 return create(sp,
static_cast<casadi_int
>(val));
/// \brief Factory: constant with sparsity \a sp and floating-point value \a val.
63 static ConstantMX* create(
const Sparsity& sp,
double val);
/// \brief Factory: constant holding the numerical matrix \a val.
66 static ConstantMX* create(
const Matrix<double>& val);
/// \brief Factory: constant with sparsity \a sp whose data is identified by
/// \a fname (presumably read from that file, cf. ConstantFile -- confirm in the .cpp).
69 static ConstantMX* create(
const Sparsity& sp,
const std::string& fname);
/// \brief Numerical evaluation; pure virtual, each concrete subclass implements it.
72 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override = 0;
/// \brief Symbolic (SXElem) evaluation; pure virtual, each concrete subclass implements it.
75 int eval_sx(
const SXElem** arg, SXElem** res,
76 casadi_int* iw, SXElem* w)
const override = 0;
/// \brief Evaluate with MX arguments \a arg, writing outputs to \a res.
81 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/// \brief Forward-mode AD: sensitivities \a fsens for seeds \a fseed.
86 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
87 std::vector<std::vector<MX> >& fsens)
const override;
/// \brief Reverse-mode AD: sensitivities \a asens for adjoint seeds \a aseed.
92 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
93 std::vector<std::vector<MX> >& asens)
const override;
/// \brief Forward sparsity-pattern propagation on bvec_t bit vectors.
98 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// \brief Reverse sparsity-pattern propagation on bvec_t bit vectors.
103 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/// \brief Operation code; always OP_CONST for constant nodes.
108 casadi_int op()
const override {
return OP_CONST;}
/// \brief Value as a scalar double; pure virtual.
111 double to_double()
const override = 0;
/// \brief Value as a numerical matrix; pure virtual.
114 Matrix<double> get_DM()
const override = 0;
/// \brief Dot product of this constant with \a y.
120 MX get_dot(
const MX& y)
const override;
/// \brief Truth value of the constant (Python-style __nonzero__).
123 bool __nonzero__()
const override;
/// \brief Whether this node may be used as a function input.
128 bool is_valid_input()
const override;
/// \brief Primitive handling: count, enumerate, split and join primitives.
133 casadi_int n_primitives()
const override;
138 void primitives(std::vector<MX>::iterator& it)
const override;
/// \brief Shared implementation behind the split_primitives overloads, generic in the matrix type T.
142 void split_primitives_gen(
const T& x,
typename std::vector<T>::iterator& it)
const;
148 void split_primitives(
const MX& x, std::vector<MX>::iterator& it)
const override;
149 void split_primitives(
const SX& x, std::vector<SX>::iterator& it)
const override;
150 void split_primitives(
const DM& x, std::vector<DM>::iterator& it)
const override;
/// \brief Shared implementation behind the join_primitives overloads, generic in the matrix type T.
155 T join_primitives_gen(
typename std::vector<T>::const_iterator& it)
const;
161 MX join_primitives(std::vector<MX>::const_iterator& it)
const override;
162 SX join_primitives(std::vector<SX>::const_iterator& it)
const override;
163 DM join_primitives(std::vector<DM>::const_iterator& it)
const override;
/// \brief Always false for constant nodes.
169 bool has_duplicates()
const override {
return false;}
/// \brief No-op for constant nodes.
174 void reset_input()
const override {}
/// \brief Reconstruct a constant node from the stream \a s.
179 static MXNode* deserialize(DeserializingStream& s);
/// \brief Deserializing constructor; forwards \a s to MXNode.
184 explicit ConstantMX(DeserializingStream& s) : MXNode(s) {}
/// \brief Constant node that stores an explicit numerical matrix in x_.
188 class CASADI_EXPORT ConstantDM :
public ConstantMX {
/// \brief Construct from the matrix \a x; the node adopts x's sparsity and keeps a copy in x_.
194 explicit ConstantDM(
const Matrix<double>& x) : ConstantMX(x.sparsity()), x_(x) {}
/// \brief Destructor (nothing to release explicitly).
197 ~ConstantDM()
override {}
/// \brief Printable representation of the node.
202 std::string disp(
const std::vector<std::string>& arg)
const override {
/// \brief Numerical evaluation: copies the range x_->begin()..x_->end()
/// (presumably the stored nonzeros -- confirm Matrix::operator->) into res[0].
209 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
210 std::copy(x_->begin(), x_->end(), res[0]);
/// \brief Symbolic evaluation: copies the same stored values into res[0] as SXElem.
217 int eval_sx(
const SXElem** arg, SXElem** res,
218 casadi_int* iw, SXElem* w)
const override {
219 std::copy(x_->begin(), x_->end(), res[0]);
/// \brief Generate C code for this node.
226 void generate(CodeGenerator& g,
227 const std::vector<casadi_int>& arg,
228 const std::vector<casadi_int>& res)
const override;
/// \brief Value predicates on the stored matrix; defined out of line.
234 bool is_one()
const override;
235 bool is_minus_one()
const override;
236 bool is_eye()
const override;
/// \brief Scalar value, taken from x_.scalar().
239 double to_double()
const override {
return x_.scalar();}
/// \brief Returns a copy of the stored matrix.
242 Matrix<double> get_DM()
const override {
return x_;}
/// \brief Equality check against another node (defined out of line).
247 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
/// \brief Serialize the node without type information.
257 void serialize_body(SerializingStream& s)
const override;
/// \brief Serialize the type information.
261 void serialize_type(SerializingStream& s)
const override;
/// \brief Deserializing constructor.
266 explicit ConstantDM(DeserializingStream& s);
/// \brief Constant node backed by a std::vector<double> (x_) tied to a file name.
///
/// NOTE(review): the constructor takes a file name; presumably x_ is loaded from
/// that file -- confirm in the out-of-line constructor.
270 class CASADI_EXPORT ConstantFile :
public ConstantMX {
/// \brief Construct with sparsity \a x and data file name \a fname.
276 explicit ConstantFile(
const Sparsity& x,
const std::string& fname);
/// \brief Destructor (nothing to release explicitly).
279 ~ConstantFile()
override {}
/// \brief Code-generation hook (reference counting of generated resources).
284 void codegen_incref(CodeGenerator& g, std::set<void*>& added)
const override;
/// \brief Printable representation of the node.
289 std::string disp(
const std::vector<std::string>& arg)
const override;
/// \brief Scalar value (defined out of line).
292 double to_double()
const override;
/// \brief Value as a numerical matrix (defined out of line).
295 Matrix<double> get_DM()
const override;
/// \brief Numerical evaluation: copies all of x_ into res[0].
300 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
301 std::copy(x_.begin(), x_.end(), res[0]);
/// \brief Symbolic evaluation: copies all of x_ into res[0] as SXElem.
308 int eval_sx(
const SXElem** arg, SXElem** res,
309 casadi_int* iw, SXElem* w)
const override {
310 std::copy(x_.begin(), x_.end(), res[0]);
/// \brief Generate C code for this node.
317 void generate(CodeGenerator& g,
318 const std::vector<casadi_int>& arg,
319 const std::vector<casadi_int>& res)
const override;
/// \brief Register this node's code-generation dependencies with \a g.
324 void add_dependency(CodeGenerator& g)
const override;
/// Data values copied to the output by eval/eval_sx
/// (presumably one per structural nonzero -- confirm).
334 std::vector<double> x_;
/// \brief Serialize the node without type information.
339 void serialize_body(SerializingStream& s)
const override;
/// \brief Serialize the type information.
343 void serialize_type(SerializingStream& s)
const override;
/// \brief Deserializing constructor.
348 explicit ConstantFile(DeserializingStream& s);
/// \brief Constant node with an empty 0-by-0 sparsity pattern.
352 class CASADI_EXPORT ZeroByZero :
public ConstantMX {
/// \brief Constructs the node with Sparsity(0, 0).
357 explicit ZeroByZero() : ConstantMX(Sparsity(0, 0)) {
/// \brief Access the shared instance: a function-local static ZeroByZero.
365 static ZeroByZero* getInstance() {
366 static ZeroByZero instance;
/// \brief Destructor.
371 ~ZeroByZero()
override {
/// \brief Printable representation of the node.
378 std::string disp(
const std::vector<std::string>& arg)
const override;
/// \brief Numerical evaluation of an empty matrix.
384 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override {
/// \brief Symbolic evaluation of an empty matrix.
389 int eval_sx(
const SXElem** arg, SXElem** res,
390 casadi_int* iw, SXElem* w)
const override {
/// \brief Emits no code (empty body).
397 void generate(CodeGenerator& g,
398 const std::vector<casadi_int>& arg,
399 const std::vector<casadi_int>& res)
const override {}
/// \brief Scalar value; always 0.
402 double to_double()
const override {
return 0;}
/// \brief Returns a default-constructed DM.
405 DM get_DM()
const override {
return DM(); }
/// \brief Expression-manipulation overrides for the empty matrix (defined out of line).
408 MX get_project(
const Sparsity& sp)
const override;
411 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
414 MX get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const override;
417 MX get_transpose()
const override;
420 MX get_unary(casadi_int op)
const override;
423 MX _get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const override;
426 MX get_reshape(
const Sparsity& sp)
const override;
/// \brief Always true: the empty matrix is accepted as a function input.
431 bool is_valid_input()
const override {
return true;}
/// \brief Name accessor; returns a reference to a function-local static string.
436 const std::string& name()
const override {
437 static std::string dummyname;
/// \brief Serialize the type information.
444 void serialize_type(SerializingStream& s)
const override;
/// \brief Serialize the node without type information.
448 void serialize_body(SerializingStream& s)
const override;
/// \brief Value holder for Constant<Value> whose value is stored in the object
/// (member \c value, set by the constructor) rather than fixed in the type.
456 struct RuntimeConst {
/// \brief Store \a v as the constant value.
459 RuntimeConst(T v) : value(v) {}
/// \brief One-character type tag; see the specializations below.
460 static char type_char();
/// \brief Writes the stored value under the key "Constant::value".
461 void serialize_type(SerializingStream& s)
const {
462 s.pack(
"Constant::value", value);
/// \brief Reads a value from the key "Constant::value" and wraps it in a new RuntimeConst.
464 static RuntimeConst deserialize(DeserializingStream& s) {
466 s.unpack(
"Constant::value", v);
467 return RuntimeConst(v);
/// Generic tag 'u' for value types without a dedicated specialization.
472 inline char RuntimeConst<T>::type_char() {
return 'u'; }
/// Tag 'I' for casadi_int values.
475 inline char RuntimeConst<casadi_int>::type_char() {
return 'I'; }
/// Tag 'D' for double values.
478 inline char RuntimeConst<double>::type_char() {
return 'D'; }
/// \brief Value holder for Constant<Value> whose value is the compile-time
/// integer \a v (exposed as the static member \c value).
481 struct CompiletimeConst {
482 static const int value = v;
/// \brief One-character type tag; see the specializations below.
483 static char type_char();
/// \brief Writes nothing: the value is implied by the type tag for v in {0, -1, 1}.
/// NOTE(review): for any other v the tag is the generic 'u', which carries no
/// value information -- confirm such instantiations are never serialized.
484 void serialize_type(SerializingStream& s)
const {}
/// \brief Reads nothing; returns a default-constructed holder.
485 static CompiletimeConst deserialize(DeserializingStream& s) {
486 return CompiletimeConst();
/// Generic tag 'u'.
491 inline char CompiletimeConst<v>::type_char() {
return 'u'; }
/// Tag '0' for the constant 0.
494 inline char CompiletimeConst<0>::type_char() {
return '0'; }
/// Tag 'm' for the constant -1 ("minus one").
496 inline char CompiletimeConst<(-1)>::type_char() {
return 'm'; }
/// Tag '1' for the constant 1.
498 inline char CompiletimeConst<1>::type_char() {
return '1'; }
/// \brief Constant node in which every structural nonzero equals one value,
/// v_.value, supplied by the holder type \a Value (e.g. RuntimeConst or
/// CompiletimeConst -- presumably the only two in use).
501 template<
typename Value>
502 class CASADI_EXPORT Constant :
public ConstantMX {
/// \brief Construct with sparsity \a sp and value holder \a v (default-constructed if omitted).
508 explicit Constant(
const Sparsity& sp, Value v = Value()) : ConstantMX(sp), v_(v) {}
/// \brief Deserializing constructor; base state from \a s, value holder \a v from the caller.
513 explicit Constant(DeserializingStream& s,
const Value& v);
/// \brief Destructor (nothing to release explicitly).
516 ~Constant()
override {}
/// \brief Printable representation of the node.
521 std::string disp(
const std::vector<std::string>& arg)
const override;
/// \brief Numerical evaluation (defined out of line).
527 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/// \brief Symbolic evaluation (defined out of line).
530 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
/// \brief Generate C code for this node.
535 void generate(CodeGenerator& g,
536 const std::vector<casadi_int>& arg,
537 const std::vector<casadi_int>& res)
const override;
/// \brief Value predicates, answered by comparing v_.value directly.
/// is_eye additionally requires a diagonal sparsity pattern.
/// NOTE(review): is_value uses exact == on doubles; that is fine for the
/// integer holders but is an exact bit-level match for floating values.
542 bool is_zero()
const override {
return v_.value==0;}
543 bool is_one()
const override {
return v_.value==1;}
544 bool is_eye()
const override {
return v_.value==1 && sparsity().is_diag();}
545 bool is_value(
double val)
const override {
return v_.value==val;}
/// \brief The value converted to double.
548 double to_double()
const override {
549 return static_cast<double>(v_.value);
/// \brief Matrix with this node's sparsity whose nonzeros all equal to_double().
553 Matrix<double> get_DM()
const override {
554 return Matrix<double>(sparsity(), to_double(),
false);
/// \brief Expression-manipulation overrides that try to keep the result a constant.
558 MX get_project(
const Sparsity& sp)
const override;
561 MX get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const override;
564 MX get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const override;
567 MX get_transpose()
const override;
570 MX get_unary(casadi_int op)
const override;
573 MX _get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const override;
576 MX get_reshape(
const Sparsity& sp)
const override;
579 MX get_horzcat(
const std::vector<MX>& x)
const override;
582 MX get_vertcat(
const std::vector<MX>& x)
const override;
/// \brief Equality check against another node.
587 bool is_equal(
const MXNode* node, casadi_int depth)
const override;
/// \brief Serialize the node without type information.
592 void serialize_body(SerializingStream& s)
const override;
/// \brief Serialize the type information.
596 void serialize_type(SerializingStream& s)
const override;
/// \brief Writes the holder's tag (Value::type_char()) under "ConstantMX::type",
/// then lets the holder v_ append its own type data.
601 template<
typename Value>
602 void Constant<Value>::serialize_type(SerializingStream& s)
const {
604 s.pack(
"ConstantMX::type", Value::type_char());
605 v_.serialize_type(s);
/// \brief Serialize the node without type information.
608 template<
typename Value>
609 void Constant<Value>::serialize_body(SerializingStream& s)
const {
/// \brief Deserializing constructor: the ConstantMX base is restored from \a s,
/// and the value holder \a v is supplied by the caller.
613 template<
typename Value>
614 Constant<Value>::Constant(DeserializingStream& s,
const Value& v) : ConstantMX(s), v_(v) {
/// \brief Horizontal concatenation.
///
/// If any argument does not have this node's value, falls back to
/// ConstantMX::get_horzcat. Otherwise returns a single constant MX over the
/// horizontally concatenated sparsity patterns.
617 template<
typename Value>
618 MX Constant<Value>::get_horzcat(
const std::vector<MX>& x)
const {
// presumably inside a loop over the elements i of x
621 if (!i->is_value(to_double())) {
623 return ConstantMX::get_horzcat(x);
// All arguments share the value: concatenate the patterns only.
628 std::vector<Sparsity> sp;
629 for (
auto&& i : x) sp.push_back(i.sparsity());
630 return MX(horzcat(sp), v_.value,
false);
/// \brief Vertical concatenation.
///
/// If any argument does not have this node's value, falls back to
/// ConstantMX::get_vertcat. Otherwise returns a single constant MX over the
/// vertically concatenated sparsity patterns.
633 template<
typename Value>
634 MX Constant<Value>::get_vertcat(
const std::vector<MX>& x)
const {
// presumably inside a loop over the elements i of x
637 if (!i->is_value(to_double())) {
639 return ConstantMX::get_vertcat(x);
// All arguments share the value: concatenate the patterns only.
644 std::vector<Sparsity> sp;
645 for (
auto&& i : x) sp.push_back(i.sparsity());
646 return MX(vertcat(sp), v_.value,
false);
/// \brief Reshape: a new Constant with the same value holder on pattern \a sp.
649 template<
typename Value>
650 MX Constant<Value>::get_reshape(
const Sparsity& sp)
const {
651 return MX::create(
new Constant<Value>(sp, v_));
/// \brief Transpose: a new Constant with the same value on the transposed pattern.
654 template<
typename Value>
655 MX Constant<Value>::get_transpose()
const {
656 return MX::create(
new Constant<Value>(sparsity().T(), v_));
/// \brief Constant-fold a unary operation.
///
/// Applies \a op to the value to obtain ret. The result is a constant on the
/// same pattern when op passes F0XChecker (presumably "f(0) == 0", which keeps
/// the pattern) or when the pattern is dense. Otherwise structural nonzeros
/// get ret and structural zeros (pattern_inverse) get op applied to 0.
659 template<
typename Value>
660 MX Constant<Value>::get_unary(casadi_int op)
const {
663 casadi_math<double>::fun(op, to_double(), 0.0, ret);
664 if (operation_checker<F0XChecker>(op) || sparsity().is_dense()) {
665 return MX(sparsity(), ret);
// NOTE(review): if this point is only reached after the check at 664 failed,
// operation_checker<F0XChecker>(op) is false here and this condition can never
// hold -- confirm against the full control flow.
668 if (
is_zero() && operation_checker<F0XChecker>(op)) {
669 return MX(sparsity(), ret,
false);
671 return repmat(MX(ret), size1(), size2());
// ret2 = op applied to 0, used for the structurally zero entries.
675 casadi_math<double>::fun(op, 0, 0.0, ret2);
676 return DM(sparsity(), ret,
false)
677 +
DM(sparsity().pattern_inverse(), ret2,
false);
/// \brief Simplify or constant-fold a binary operation "this op y".
///
/// \param op   operation code
/// \param y    right-hand operand
/// \param ScX  left operand (this) is treated as a scalar
/// \param ScY  right operand is treated as a scalar
///
/// Requires matching sparsity unless one side is scalar. Scalar operands whose
/// operation would not preserve sparsity are projected to dense. Algebraic
/// identities for special values (0, 1, -1, 2, e) are applied. If y is also a
/// non-ConstantDM constant, the result is folded numerically. Everything else
/// goes to MXNode::_get_binary.
681 template<
typename Value>
682 MX Constant<Value>::_get_binary(casadi_int op,
const MX& y,
bool ScX,
bool ScY)
const {
683 casadi_assert_dev(sparsity()==y.sparsity() || ScX || ScY);
// Scalar left operand with an op that fails FX0Checker: densify and retry.
685 if (ScX && !operation_checker<FX0Checker>(op)) {
687 casadi_math<double>::fun(op, nnz()> 0 ? to_double(): 0.0, 0, ret);
690 Sparsity f = Sparsity::dense(y.size1(), y.size2());
691 MX yy = project(y, f);
692 return MX(f, shared_from_this<MX>())->_get_binary(op, yy,
false,
false);
694 }
// Scalar right operand with an op that fails F0XChecker: densify this node and retry.
else if (ScY && !operation_checker<F0XChecker>(op)) {
696 if (y->op()==OP_CONST &&
dynamic_cast<const ConstantDM*
>(y.get())==
nullptr) {
698 casadi_math<double>::fun(op, 0, y.nnz()>0 ? y->to_double() : 0, ret);
702 Sparsity f = Sparsity::dense(size1(), size2());
703 MX xx = project(shared_from_this<MX>(), f);
704 return xx->_get_binary(op, MX(f, y),
false,
false);
// Identity simplifications on the constant's value. The case labels are not
// visible here; by their results these read as 0+y, 0-y, c*y (c = 1, -1, 2),
// c/y (c = 1, -1) and c^y (c = 0, 1, e) -- confirm against the switch.
710 if (v_.value==0)
return ScY && !y->is_zero() ? repmat(y, size1(), size2()) : y;
713 if (v_.value==0)
return ScY && !y->is_zero() ? repmat(-y, size1(), size2()) : -y;
716 if (v_.value==1)
return y;
717 if (v_.value==-1)
return -y;
718 if (v_.value==2)
return y->get_unary(OP_TWICE);
721 if (v_.value==1)
return y->get_unary(OP_INV);
722 if (v_.value==-1)
return -y->get_unary(OP_INV);
// NOTE(review): if this is the power case, 0^0 evaluates to 1 under C pow,
// so returning all zeros is only exact where y is nonzero -- confirm intent.
725 if (v_.value==0)
return MX::zeros(y.sparsity());
726 if (v_.value==1)
return MX::ones(y.sparsity());
727 if (v_.value==std::exp(1.0))
return y->get_unary(OP_EXP);
// Both operands are constants (y not a ConstantDM): fold to one constant on y's pattern.
734 if (y->op()==OP_CONST &&
dynamic_cast<const ConstantDM*
>(y.get())==
nullptr) {
735 double y_value = y.nnz()>0 ? y->to_double() : 0;
// NOTE(review): compares the integer nnz() with 0.0 and mixes 0 / to_double() in
// the ternary; it is equivalent to `nnz()> 0 ? to_double(): 0.0` as written at 687.
737 casadi_math<double>::fun(op, nnz()> 0.0 ? to_double(): 0, y_value, ret);
739 return MX(y.sparsity(), ret,
false);
// No simplification applied: generic handling.
743 return MXNode::_get_binary(op, y, ScX, ScY);
/// \brief Numerical evaluation: writes to_double() into the nnz() entries of res[0].
/// The inputs arg and the work arrays iw and w are not read here.
746 template<
typename Value>
747 int Constant<Value>::eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const {
748 std::fill(res[0], res[0]+nnz(), to_double());
/// \brief Symbolic evaluation: writes SXElem(v_.value) into the nnz() entries of res[0].
752 template<
typename Value>
753 int Constant<Value>::
754 eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const {
755 std::fill(res[0], res[0]+nnz(), SXElem(v_.value));
/// \brief Emit C code that writes the constant into the work vector of res[0].
///
/// A single nonzero becomes a scalar assignment. A zero value uses g.clear.
/// Any other value uses g.fill with the constant. A preceding branch
/// (presumably for nnz()==0) is not shown.
759 template<
typename Value>
760 void Constant<Value>::generate(CodeGenerator& g,
761 const std::vector<casadi_int>& arg,
762 const std::vector<casadi_int>& res)
const {
765 }
else if (nnz()==1) {
766 g << g.workel(res[0]) <<
" = " << g.constant(to_double()) <<
";\n";
// Exact compare is intentional: selects the cheaper clear() for an all-zero block.
768 if (to_double()==0) {
769 g << g.clear(g.work(res[0], nnz()), nnz()) <<
'\n';
771 g << g.fill(g.work(res[0], nnz()), nnz(), g.constant(to_double())) <<
'\n';
/// \brief Nonzero reference: selecting nonzeros nz of a constant gives a constant.
///
/// Each index in nz is checked. If one fails the check (presumably a negative,
/// i.e. "no element", index -- confirm), MXNode::get_nzref is used. Otherwise a
/// new Constant with the same value on pattern \a sp is returned.
776 template<
typename Value>
777 MX Constant<Value>::get_nzref(
const Sparsity& sp,
const std::vector<casadi_int>& nz)
const {
780 for (std::vector<casadi_int>::const_iterator k=nz.begin(); k!=nz.end(); ++k) {
783 return MXNode::get_nzref(sp, nz);
787 return MX::create(
new Constant<Value>(sp, v_));
/// \brief Nonzero assignment of y into this constant.
///
/// Special-cases assigning a zero constant y into a zero-valued constant.
/// All other cases use MXNode::get_nzassign.
790 template<
typename Value>
791 MX Constant<Value>::get_nzassign(
const MX& y,
const std::vector<casadi_int>& nz)
const {
792 if (y.is_constant() && y->is_zero() && v_.value==0) {
797 return MXNode::get_nzassign(y, nz);
/// \brief Project onto sparsity \a sp.
///
/// In the first case (its condition is not visible here) the result is a
/// Constant with the same value on \a sp. A dense \a sp yields the densified
/// get_DM(). Everything else uses MXNode::get_project.
800 template<
typename Value>
801 MX Constant<Value>::get_project(
const Sparsity& sp)
const {
803 return MX::create(
new Constant<Value>(sp, v_));
804 }
else if (sp.is_dense()) {
805 return densify(get_DM());
807 return MXNode::get_project(sp);
/// \brief Build a printable description of the constant in a stringstream.
///
/// Special-cases:
/// - scalar (with a zero-nnz sub-case)
/// - empty
/// - value 1
/// - NaN
/// - +inf
/// - -inf
///
/// The general case prints "all_<value>(".
811 template<
typename Value>
813 Constant<Value>::disp(
const std::vector<std::string>& arg)
const {
814 std::stringstream ss;
815 if (sparsity().is_scalar()) {
817 if (sparsity().nnz()==0) {
822 }
else if (sparsity().is_empty()) {
829 }
else if (v_.value==1) {
831 }
// IEEE NaN is the only value that compares unequal to itself.
else if (v_.value!=v_.value) {
833 }
else if (v_.value==std::numeric_limits<double>::infinity()) {
835 }
else if (v_.value==-std::numeric_limits<double>::infinity()) {
838 ss <<
"all_" << v_.value <<
"(";
/// \brief Two nodes are equal when \a node reports this node's value (is_value)
/// and has an identical sparsity pattern. The \a depth argument is not used.
848 template<
typename Value>
849 bool Constant<Value>::is_equal(
const MXNode* node, casadi_int depth)
const {
850 return node->is_value(to_double()) && sparsity()==node->sparsity();
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.