constant_mx.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_CONSTANT_MX_HPP
27 #define CASADI_CONSTANT_MX_HPP
28 
29 #include "mx_node.hpp"
30 #include <iomanip>
31 #include <iostream>
32 #include "serializing_stream.hpp"
33 
35 
36 namespace casadi {
37 
48  class CASADI_EXPORT ConstantMX : public MXNode {
49  public:
51  explicit ConstantMX(const Sparsity& sp);
52 
54  ~ConstantMX() override = 0;
55 
56  // Creator (all values are the same integer)
57  static ConstantMX* create(const Sparsity& sp, casadi_int val);
58  static ConstantMX* create(const Sparsity& sp, int val) {
59  return create(sp, static_cast<casadi_int>(val));
60  }
61 
62  // Creator (all values are the same floating point value)
63  static ConstantMX* create(const Sparsity& sp, double val);
64 
65  // Creator (values may be different)
66  static ConstantMX* create(const Matrix<double>& val);
67 
68  // Creator (values may be different)
69  static ConstantMX* create(const Sparsity& sp, const std::string& fname);
70 
71  // Creator (values may be different)
72  static ConstantMX* create(const Matrix<double>& val, const std::string& name);
73 
75  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override = 0;
76 
78  int eval_sx(const SXElem** arg, SXElem** res,
79  casadi_int* iw, SXElem* w) const override = 0;
80 
84  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
85 
89  void ad_forward(const std::vector<std::vector<MX> >& fseed,
90  std::vector<std::vector<MX> >& fsens) const override;
91 
95  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
96  std::vector<std::vector<MX> >& asens) const override;
97 
101  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
102 
106  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
107 
111  casadi_int op() const override { return OP_CONST;}
112 
114  double to_double() const override = 0;
115 
117  Matrix<double> get_DM() const override = 0;
118 
120  // virtual MX get_mac(const MX& y) const;
121 
123  MX get_dot(const MX& y) const override;
124 
126  bool __nonzero__() const override;
127 
131  bool is_valid_input() const override;
132 
136  casadi_int n_primitives() const override;
137 
141  void primitives(std::vector<MX>::iterator& it) const override;
142 
144  template<typename T>
145  void split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const;
146 
148 
151  void split_primitives(const MX& x, std::vector<MX>::iterator& it) const override;
152  void split_primitives(const SX& x, std::vector<SX>::iterator& it) const override;
153  void split_primitives(const DM& x, std::vector<DM>::iterator& it) const override;
155 
157  template<typename T>
158  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
159 
161 
164  MX join_primitives(std::vector<MX>::const_iterator& it) const override;
165  SX join_primitives(std::vector<SX>::const_iterator& it) const override;
166  DM join_primitives(std::vector<DM>::const_iterator& it) const override;
168 
172  bool has_duplicates() const override { return false;}
173 
177  void reset_input() const override {}
178 
182  static MXNode* deserialize(DeserializingStream& s);
183 
187  explicit ConstantMX(DeserializingStream& s) : MXNode(s) {}
188  };
189 
191  class CASADI_EXPORT ConstantDM : public ConstantMX {
192  public:
193 
197  explicit ConstantDM(const Matrix<double>& x) : ConstantMX(x.sparsity()), x_(x) {}
198 
200  ~ConstantDM() override {}
201 
205  std::string disp(const std::vector<std::string>& arg) const override {
206  return x_.get_str();
207  }
208 
212  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
213  std::copy(x_->begin(), x_->end(), res[0]);
214  return 0;
215  }
216 
220  int eval_sx(const SXElem** arg, SXElem** res,
221  casadi_int* iw, SXElem* w) const override {
222  std::copy(x_->begin(), x_->end(), res[0]);
223  return 0;
224  }
225 
229  void generate(CodeGenerator& g,
230  const std::vector<casadi_int>& arg,
231  const std::vector<casadi_int>& res,
232  const std::vector<bool>& arg_is_ref,
233  std::vector<bool>& res_is_ref) const override;
234 
238  bool is_zero() const override;
239  bool is_one() const override;
240  bool is_minus_one() const override;
241  bool is_eye() const override;
242 
244  double to_double() const override {return x_.scalar();}
245 
247  Matrix<double> get_DM() const override { return x_;}
248 
252  bool is_equal(const MXNode* node, casadi_int depth) const override;
253 
257  Matrix<double> x_;
258 
262  void serialize_body(SerializingStream& s) const override;
266  void serialize_type(SerializingStream& s) const override;
267 
271  explicit ConstantDM(DeserializingStream& s);
272  };
273 
275  class CASADI_EXPORT ConstantFile : public ConstantMX {
276  public:
277 
281  explicit ConstantFile(const Sparsity& x, const std::string& fname);
282 
284  ~ConstantFile() override {}
285 
289  void codegen_incref(CodeGenerator& g, std::set<void*>& added) const override;
290 
294  std::string disp(const std::vector<std::string>& arg) const override;
295 
297  double to_double() const override;
298 
300  Matrix<double> get_DM() const override;
301 
305  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
306  std::copy(x_.begin(), x_.end(), res[0]);
307  return 0;
308  }
309 
313  int eval_sx(const SXElem** arg, SXElem** res,
314  casadi_int* iw, SXElem* w) const override {
315  std::copy(x_.begin(), x_.end(), res[0]);
316  return 0;
317  }
318 
322  void generate(CodeGenerator& g,
323  const std::vector<casadi_int>& arg,
324  const std::vector<casadi_int>& res,
325  const std::vector<bool>& arg_is_ref,
326  std::vector<bool>& res_is_ref) const override;
327 
331  void add_dependency(CodeGenerator& g) const override;
332 
336  std::string fname_;
337 
341  std::vector<double> x_;
342 
346  void serialize_body(SerializingStream& s) const override;
350  void serialize_type(SerializingStream& s) const override;
351 
355  explicit ConstantFile(DeserializingStream& s);
356  };
357 
359  class CASADI_EXPORT ConstantPool : public ConstantMX {
360  public:
361 
365  explicit ConstantPool(const DM& x, const std::string& name);
366 
368  ~ConstantPool() override {}
369 
373  std::string disp(const std::vector<std::string>& arg) const override;
374 
376  double to_double() const override;
377 
379  Matrix<double> get_DM() const override;
380 
384  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
385  if (res[0]) std::copy(x_.begin(), x_.end(), res[0]);
386  return 0;
387  }
388 
392  int eval_sx(const SXElem** arg, SXElem** res,
393  casadi_int* iw, SXElem* w) const override {
394  casadi_error("eval_sx not supported");
395  return 0;
396  }
397 
401  void generate(CodeGenerator& g,
402  const std::vector<casadi_int>& arg,
403  const std::vector<casadi_int>& res,
404  const std::vector<bool>& arg_is_ref,
405  std::vector<bool>& res_is_ref) const override;
406 
410  void add_dependency(CodeGenerator& g) const override;
411 
415  std::string name_;
416 
420  std::vector<double> x_;
421 
425  void serialize_body(SerializingStream& s) const override;
426 
430  void serialize_type(SerializingStream& s) const override;
431 
435  explicit ConstantPool(DeserializingStream& s);
436  };
437 
439  class CASADI_EXPORT ZeroByZero : public ConstantMX {
440  private:
444  explicit ZeroByZero() : ConstantMX(Sparsity(0, 0)) {
445  initSingleton();
446  }
447 
448  public:
452  static ZeroByZero* getInstance() {
453  static ZeroByZero instance;
454  return &instance;
455  }
456 
458  ~ZeroByZero() override {
459  destroySingleton();
460  }
461 
465  std::string disp(const std::vector<std::string>& arg) const override;
466 
471  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
472  return 0;
473  }
474 
476  int eval_sx(const SXElem** arg, SXElem** res,
477  casadi_int* iw, SXElem* w) const override {
478  return 0;
479  }
480 
484  void generate(CodeGenerator& g,
485  const std::vector<casadi_int>& arg,
486  const std::vector<casadi_int>& res,
487  const std::vector<bool>& arg_is_ref,
488  std::vector<bool>& res_is_ref) const override {}
489 
491  double to_double() const override { return 0;}
492 
494  DM get_DM() const override { return DM(); }
495 
497  MX get_project(const Sparsity& sp) const override;
498 
500  MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const override;
501 
503  MX get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const override;
504 
506  MX get_transpose() const override;
507 
509  MX get_unary(casadi_int op) const override;
510 
512  MX _get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const override;
513 
515  MX get_reshape(const Sparsity& sp) const override;
516 
520  bool is_valid_input() const override { return true;}
521 
525  const std::string& name() const override {
526  static std::string dummyname;
527  return dummyname;
528  }
529 
533  void serialize_type(SerializingStream& s) const override;
537  void serialize_body(SerializingStream& s) const override;
538 
539  };
540 
544  template<typename T>
545  struct RuntimeConst {
546  const T value;
547  RuntimeConst() {}
548  RuntimeConst(T v) : value(v) {}
549  static char type_char();
550  void serialize_type(SerializingStream& s) const {
551  s.pack("Constant::value", value);
552  }
553  static RuntimeConst deserialize(DeserializingStream& s) {
554  T v;
555  s.unpack("Constant::value", v);
556  return RuntimeConst(v);
557  }
558  };
559 
560  template<typename T>
561  inline char RuntimeConst<T>::type_char() { return 'u'; }
562 
563  template<>
564  inline char RuntimeConst<casadi_int>::type_char() { return 'I'; }
565 
566  template<>
567  inline char RuntimeConst<double>::type_char() { return 'D'; }
568 
569  template<int v>
570  struct CompiletimeConst {
571  static const int value = v;
572  static char type_char();
573  void serialize_type(SerializingStream& s) const {}
574  static CompiletimeConst deserialize(DeserializingStream& s) {
575  return CompiletimeConst();
576  }
577  };
578 
579  template<int v>
580  inline char CompiletimeConst<v>::type_char() { return 'u'; }
581 
582  template<>
583  inline char CompiletimeConst<0>::type_char() { return '0'; }
584  template<>
585  inline char CompiletimeConst<(-1)>::type_char() { return 'm'; }
586  template<>
587  inline char CompiletimeConst<1>::type_char() { return '1'; }
588 
590  template<typename Value>
591  class CASADI_EXPORT Constant : public ConstantMX {
592  public:
593 
597  explicit Constant(const Sparsity& sp, Value v = Value()) : ConstantMX(sp), v_(v) {}
598 
602  explicit Constant(DeserializingStream& s, const Value& v);
603 
605  ~Constant() override {}
606 
610  std::string disp(const std::vector<std::string>& arg) const override;
611 
616  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
617 
619  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
620 
624  void generate(CodeGenerator& g,
625  const std::vector<casadi_int>& arg,
626  const std::vector<casadi_int>& res,
627  const std::vector<bool>& arg_is_ref,
628  std::vector<bool>& res_is_ref) const override;
629 
633  bool is_zero() const override { return v_.value==0;}
634  bool is_one() const override { return v_.value==1;}
635  bool is_eye() const override { return v_.value==1 && sparsity().is_diag();}
636  bool is_value(double val) const override { return v_.value==val;}
637 
639  double to_double() const override {
640  return static_cast<double>(v_.value);
641  }
642 
644  Matrix<double> get_DM() const override {
645  return Matrix<double>(sparsity(), to_double(), false);
646  }
647 
649  MX get_project(const Sparsity& sp) const override;
650 
652  MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const override;
653 
655  MX get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const override;
656 
658  MX get_transpose() const override;
659 
661  MX get_unary(casadi_int op) const override;
662 
664  MX _get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const override;
665 
667  MX get_reshape(const Sparsity& sp) const override;
668 
670  MX get_horzcat(const std::vector<MX>& x) const override;
671 
673  MX get_vertcat(const std::vector<MX>& x) const override;
674 
678  bool is_equal(const MXNode* node, casadi_int depth) const override;
679 
683  void serialize_body(SerializingStream& s) const override;
687  void serialize_type(SerializingStream& s) const override;
688 
689  Value v_;
690  };
691 
692  template<typename Value>
693  void Constant<Value>::serialize_type(SerializingStream& s) const {
695  s.pack("ConstantMX::type", Value::type_char());
696  v_.serialize_type(s);
697  }
698 
699  template<typename Value>
700  void Constant<Value>::serialize_body(SerializingStream& s) const {
702  }
703 
704  template<typename Value>
705  Constant<Value>::Constant(DeserializingStream& s, const Value& v) : ConstantMX(s), v_(v) {
706  }
707 
708  template<typename Value>
709  MX Constant<Value>::get_horzcat(const std::vector<MX>& x) const {
710  // Check if all arguments have the same constant value
711  for (auto&& i : x) {
712  if (!i->is_value(to_double())) {
713  // Not all the same value, fall back to base class
714  return ConstantMX::get_horzcat(x);
715  }
716  }
717 
718  // Assemble the sparsity pattern
719  std::vector<Sparsity> sp;
720  for (auto&& i : x) sp.push_back(i.sparsity());
721  return MX(horzcat(sp), v_.value, false);
722  }
723 
724  template<typename Value>
725  MX Constant<Value>::get_vertcat(const std::vector<MX>& x) const {
726  // Check if all arguments have the same constant value
727  for (auto&& i : x) {
728  if (!i->is_value(to_double())) {
729  // Not all the same value, fall back to base class
730  return ConstantMX::get_vertcat(x);
731  }
732  }
733 
734  // Assemble the sparsity pattern
735  std::vector<Sparsity> sp;
736  for (auto&& i : x) sp.push_back(i.sparsity());
737  return MX(vertcat(sp), v_.value, false);
738  }
739 
740  template<typename Value>
741  MX Constant<Value>::get_reshape(const Sparsity& sp) const {
742  return MX::create(new Constant<Value>(sp, v_));
743  }
744 
745  template<typename Value>
746  MX Constant<Value>::get_transpose() const {
747  return MX::create(new Constant<Value>(sparsity().T(), v_));
748  }
749 
750  template<typename Value>
751  MX Constant<Value>::get_unary(casadi_int op) const {
752  // Constant folding
753  double ret(0);
754  casadi_math<double>::fun(op, to_double(), 0.0, ret);
755  if (operation_checker<F0XChecker>(op) || sparsity().is_dense()) {
756  return MX(sparsity(), ret);
757  } else {
758  if (v_.value==0) {
759  if (is_zero() && operation_checker<F0XChecker>(op)) {
760  return MX(sparsity(), ret, false);
761  } else {
762  return repmat(MX(ret), size1(), size2());
763  }
764  }
765  double ret2;
766  casadi_math<double>::fun(op, 0, 0.0, ret2);
767  return DM(sparsity(), ret, false)
768  + DM(sparsity().pattern_inverse(), ret2, false);
769  }
770  }
771 
772  template<typename Value>
773  MX Constant<Value>::_get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const {
774  casadi_assert_dev(sparsity()==y.sparsity() || ScX || ScY);
775 
776  if (ScX && !operation_checker<FX0Checker>(op)) {
777  double ret;
778  casadi_math<double>::fun(op, nnz()> 0 ? to_double(): 0.0, 0, ret);
779 
780  if (ret!=0) {
781  Sparsity f = Sparsity::dense(y.size1(), y.size2());
782  MX yy = project(y, f);
783  return MX(f, shared_from_this<MX>())->_get_binary(op, yy, false, false);
784  }
785  } else if (ScY && !operation_checker<F0XChecker>(op)) {
786  bool grow = true;
787  if (y->op()==OP_CONST && dynamic_cast<const ConstantDM*>(y.get())==nullptr) {
788  double ret;
789  casadi_math<double>::fun(op, 0, y.nnz()>0 ? y->to_double() : 0, ret);
790  grow = ret!=0;
791  }
792  if (grow) {
793  Sparsity f = Sparsity::dense(size1(), size2());
794  MX xx = project(shared_from_this<MX>(), f);
795  return xx->_get_binary(op, MX(f, y), false, false);
796  }
797  }
798 
799  switch (op) {
800  case OP_ADD:
801  if (v_.value==0) return ScY && !y->is_zero() ? repmat(y, size1(), size2()) : y;
802  break;
803  case OP_SUB:
804  if (v_.value==0) return ScY && !y->is_zero() ? repmat(-y, size1(), size2()) : -y;
805  break;
806  case OP_MUL:
807  if (v_.value==1) return y;
808  if (v_.value==-1) return -y;
809  if (v_.value==2) return y->get_unary(OP_TWICE);
810  break;
811  case OP_DIV:
812  if (v_.value==1) return y->get_unary(OP_INV);
813  if (v_.value==-1) return -y->get_unary(OP_INV);
814  break;
815  case OP_POW:
816  // Note: v_.value can still lead to one when a y entry is zero
817  if (v_.value==1) return MX::ones(y.sparsity());
818  if (v_.value==std::exp(1.0)) return y->get_unary(OP_EXP);
819  break;
820  default: break; //no rule
821  }
822 
823  // Constant folding
824  // NOTE: ugly, should use a function instead of a cast
825  if (y->op()==OP_CONST && dynamic_cast<const ConstantDM*>(y.get())==nullptr) {
826  double y_value = y.nnz()>0 ? y->to_double() : 0;
827  double ret;
828  casadi_math<double>::fun(op, nnz()> 0.0 ? to_double(): 0, y_value, ret);
829 
830  return MX(y.sparsity(), ret, false);
831  }
832 
833  // Fallback
834  return MXNode::_get_binary(op, y, ScX, ScY);
835  }
836 
837  template<typename Value>
838  int Constant<Value>::eval(const double** arg, double** res, casadi_int* iw, double* w) const {
839  std::fill(res[0], res[0]+nnz(), to_double());
840  return 0;
841  }
842 
843  template<typename Value>
844  int Constant<Value>::
845  eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
846  std::fill(res[0], res[0]+nnz(), SXElem(v_.value));
847  return 0;
848  }
849 
850  template<typename Value>
851  void Constant<Value>::generate(CodeGenerator& g,
852  const std::vector<casadi_int>& arg,
853  const std::vector<casadi_int>& res,
854  const std::vector<bool>& arg_is_ref,
855  std::vector<bool>& res_is_ref) const {
856  if (nnz()==0) {
857  // Quick return
858  } else if (nnz()==1) {
859  g << g.workel(res[0]) << " = " << g.constant(to_double()) << ";\n";
860  } else {
861  if (to_double()==0) {
862  g << g.clear(g.work(res[0], nnz(), false), nnz()) << '\n';
863  } else {
864  g << g.fill(g.work(res[0], nnz(), false), nnz(), g.constant(to_double())) << '\n';
865  }
866  }
867  }
868 
869  template<typename Value>
870  MX Constant<Value>::get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const {
871  if (v_.value!=0) {
872  // Check if any "holes"
873  for (std::vector<casadi_int>::const_iterator k=nz.begin(); k!=nz.end(); ++k) {
874  if (*k<0) {
875  // Do not simplify
876  return MXNode::get_nzref(sp, nz);
877  }
878  }
879  }
880  return MX::create(new Constant<Value>(sp, v_));
881  }
882 
883  template<typename Value>
884  MX Constant<Value>::get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const {
885  if (y.is_constant() && y->is_zero() && v_.value==0) {
886  return y;
887  }
888 
889  // Fall-back
890  return MXNode::get_nzassign(y, nz);
891  }
892 
893  template<typename Value>
894  MX Constant<Value>::get_project(const Sparsity& sp) const {
895  if (is_zero()) {
896  return MX::create(new Constant<Value>(sp, v_));
897  } else if (sp.is_dense()) {
898  return densify(get_DM());
899  } else {
900  return MXNode::get_project(sp);
901  }
902  }
903 
904  template<typename Value>
905  std::string
906  Constant<Value>::disp(const std::vector<std::string>& arg) const {
907  std::stringstream ss;
908  if (sparsity().is_scalar()) {
909  // Print scalar
910  if (sparsity().nnz()==0) {
911  ss << "00";
912  } else {
913  ss << v_.value;
914  }
915  } else if (sparsity().is_empty()) {
916  // Print empty
917  sparsity().disp(ss);
918  } else {
919  // Print value
920  if (v_.value==0) {
921  ss << "zeros(";
922  } else if (v_.value==1) {
923  ss << "ones(";
924  } else if (v_.value!=v_.value) {
925  ss << "nan(";
926  } else if (v_.value==std::numeric_limits<double>::infinity()) {
927  ss << "inf(";
928  } else if (v_.value==-std::numeric_limits<double>::infinity()) {
929  ss << "-inf(";
930  } else {
931  ss << "all_" << v_.value << "(";
932  }
933 
934  // Print sparsity
935  sparsity().disp(ss);
936  ss << ")";
937  }
938  return ss.str();
939  }
940 
941  template<typename Value>
942  bool Constant<Value>::is_equal(const MXNode* node, casadi_int depth) const {
943  return node->is_value(to_double()) && sparsity()==node->sparsity();
944  }
945 
946 
947 } // namespace casadi
949 
950 
951 #endif // CASADI_CONSTANT_MX_HPP
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
The casadi namespace.
Definition: archiver.hpp:32
Matrix< SXElem > SX
Definition: sx_fwd.hpp:32
bool is_zero(const T &x)
Matrix< double > DM
Definition: dm_fwd.hpp:33