constant_mx.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_CONSTANT_MX_HPP
27 #define CASADI_CONSTANT_MX_HPP
28 
29 #include "mx_node.hpp"
30 #include <iomanip>
31 #include <iostream>
32 #include "serializing_stream.hpp"
33 
35 
36 namespace casadi {
37 
48  class CASADI_EXPORT ConstantMX : public MXNode {
49  public:
51  explicit ConstantMX(const Sparsity& sp);
52 
54  ~ConstantMX() override = 0;
55 
56  // Creator (all values are the same integer)
57  static ConstantMX* create(const Sparsity& sp, casadi_int val);
58  static ConstantMX* create(const Sparsity& sp, int val) {
59  return create(sp, static_cast<casadi_int>(val));
60  }
61 
62  // Creator (all values are the same floating point value)
63  static ConstantMX* create(const Sparsity& sp, double val);
64 
65  // Creator (values may be different)
66  static ConstantMX* create(const Matrix<double>& val);
67 
68  // Creator (values may be different)
69  static ConstantMX* create(const Sparsity& sp, const std::string& fname);
70 
72  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override = 0;
73 
75  int eval_sx(const SXElem** arg, SXElem** res,
76  casadi_int* iw, SXElem* w) const override = 0;
77 
81  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
82 
86  void ad_forward(const std::vector<std::vector<MX> >& fseed,
87  std::vector<std::vector<MX> >& fsens) const override;
88 
92  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
93  std::vector<std::vector<MX> >& asens) const override;
94 
98  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
99 
103  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
104 
108  casadi_int op() const override { return OP_CONST;}
109 
111  double to_double() const override = 0;
112 
114  Matrix<double> get_DM() const override = 0;
115 
117  // virtual MX get_mac(const MX& y) const;
118 
120  MX get_dot(const MX& y) const override;
121 
123  bool __nonzero__() const override;
124 
128  bool is_valid_input() const override;
129 
133  casadi_int n_primitives() const override;
134 
138  void primitives(std::vector<MX>::iterator& it) const override;
139 
141  template<typename T>
142  void split_primitives_gen(const T& x, typename std::vector<T>::iterator& it) const;
143 
145 
148  void split_primitives(const MX& x, std::vector<MX>::iterator& it) const override;
149  void split_primitives(const SX& x, std::vector<SX>::iterator& it) const override;
150  void split_primitives(const DM& x, std::vector<DM>::iterator& it) const override;
152 
154  template<typename T>
155  T join_primitives_gen(typename std::vector<T>::const_iterator& it) const;
156 
158 
161  MX join_primitives(std::vector<MX>::const_iterator& it) const override;
162  SX join_primitives(std::vector<SX>::const_iterator& it) const override;
163  DM join_primitives(std::vector<DM>::const_iterator& it) const override;
165 
169  bool has_duplicates() const override { return false;}
170 
174  void reset_input() const override {}
175 
179  static MXNode* deserialize(DeserializingStream& s);
180 
184  explicit ConstantMX(DeserializingStream& s) : MXNode(s) {}
185  };
186 
188  class CASADI_EXPORT ConstantDM : public ConstantMX {
189  public:
190 
194  explicit ConstantDM(const Matrix<double>& x) : ConstantMX(x.sparsity()), x_(x) {}
195 
197  ~ConstantDM() override {}
198 
202  std::string disp(const std::vector<std::string>& arg) const override {
203  return x_.get_str();
204  }
205 
209  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
210  std::copy(x_->begin(), x_->end(), res[0]);
211  return 0;
212  }
213 
217  int eval_sx(const SXElem** arg, SXElem** res,
218  casadi_int* iw, SXElem* w) const override {
219  std::copy(x_->begin(), x_->end(), res[0]);
220  return 0;
221  }
222 
226  void generate(CodeGenerator& g,
227  const std::vector<casadi_int>& arg,
228  const std::vector<casadi_int>& res) const override;
229 
233  bool is_zero() const override;
234  bool is_one() const override;
235  bool is_minus_one() const override;
236  bool is_eye() const override;
237 
239  double to_double() const override {return x_.scalar();}
240 
242  Matrix<double> get_DM() const override { return x_;}
243 
247  bool is_equal(const MXNode* node, casadi_int depth) const override;
248 
252  Matrix<double> x_;
253 
257  void serialize_body(SerializingStream& s) const override;
261  void serialize_type(SerializingStream& s) const override;
262 
266  explicit ConstantDM(DeserializingStream& s);
267  };
268 
270  class CASADI_EXPORT ConstantFile : public ConstantMX {
271  public:
272 
276  explicit ConstantFile(const Sparsity& x, const std::string& fname);
277 
279  ~ConstantFile() override {}
280 
284  void codegen_incref(CodeGenerator& g, std::set<void*>& added) const override;
285 
289  std::string disp(const std::vector<std::string>& arg) const override;
290 
292  double to_double() const override;
293 
295  Matrix<double> get_DM() const override;
296 
300  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
301  std::copy(x_.begin(), x_.end(), res[0]);
302  return 0;
303  }
304 
308  int eval_sx(const SXElem** arg, SXElem** res,
309  casadi_int* iw, SXElem* w) const override {
310  std::copy(x_.begin(), x_.end(), res[0]);
311  return 0;
312  }
313 
317  void generate(CodeGenerator& g,
318  const std::vector<casadi_int>& arg,
319  const std::vector<casadi_int>& res) const override;
320 
324  void add_dependency(CodeGenerator& g) const override;
325 
329  std::string fname_;
330 
334  std::vector<double> x_;
335 
339  void serialize_body(SerializingStream& s) const override;
343  void serialize_type(SerializingStream& s) const override;
344 
348  explicit ConstantFile(DeserializingStream& s);
349  };
350 
352  class CASADI_EXPORT ZeroByZero : public ConstantMX {
353  private:
357  explicit ZeroByZero() : ConstantMX(Sparsity(0, 0)) {
358  initSingleton();
359  }
360 
361  public:
365  static ZeroByZero* getInstance() {
366  static ZeroByZero instance;
367  return &instance;
368  }
369 
371  ~ZeroByZero() override {
372  destroySingleton();
373  }
374 
378  std::string disp(const std::vector<std::string>& arg) const override;
379 
384  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override {
385  return 0;
386  }
387 
389  int eval_sx(const SXElem** arg, SXElem** res,
390  casadi_int* iw, SXElem* w) const override {
391  return 0;
392  }
393 
397  void generate(CodeGenerator& g,
398  const std::vector<casadi_int>& arg,
399  const std::vector<casadi_int>& res) const override {}
400 
402  double to_double() const override { return 0;}
403 
405  DM get_DM() const override { return DM(); }
406 
408  MX get_project(const Sparsity& sp) const override;
409 
411  MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const override;
412 
414  MX get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const override;
415 
417  MX get_transpose() const override;
418 
420  MX get_unary(casadi_int op) const override;
421 
423  MX _get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const override;
424 
426  MX get_reshape(const Sparsity& sp) const override;
427 
431  bool is_valid_input() const override { return true;}
432 
436  const std::string& name() const override {
437  static std::string dummyname;
438  return dummyname;
439  }
440 
444  void serialize_type(SerializingStream& s) const override;
448  void serialize_body(SerializingStream& s) const override;
449 
450  };
451 
455  template<typename T>
456  struct RuntimeConst {
457  const T value;
458  RuntimeConst() {}
459  RuntimeConst(T v) : value(v) {}
460  static char type_char();
461  void serialize_type(SerializingStream& s) const {
462  s.pack("Constant::value", value);
463  }
464  static RuntimeConst deserialize(DeserializingStream& s) {
465  T v;
466  s.unpack("Constant::value", v);
467  return RuntimeConst(v);
468  }
469  };
470 
471  template<typename T>
472  inline char RuntimeConst<T>::type_char() { return 'u'; }
473 
474  template<>
475  inline char RuntimeConst<casadi_int>::type_char() { return 'I'; }
476 
477  template<>
478  inline char RuntimeConst<double>::type_char() { return 'D'; }
479 
480  template<int v>
481  struct CompiletimeConst {
482  static const int value = v;
483  static char type_char();
484  void serialize_type(SerializingStream& s) const {}
485  static CompiletimeConst deserialize(DeserializingStream& s) {
486  return CompiletimeConst();
487  }
488  };
489 
490  template<int v>
491  inline char CompiletimeConst<v>::type_char() { return 'u'; }
492 
493  template<>
494  inline char CompiletimeConst<0>::type_char() { return '0'; }
495  template<>
496  inline char CompiletimeConst<(-1)>::type_char() { return 'm'; }
497  template<>
498  inline char CompiletimeConst<1>::type_char() { return '1'; }
499 
501  template<typename Value>
502  class CASADI_EXPORT Constant : public ConstantMX {
503  public:
504 
508  explicit Constant(const Sparsity& sp, Value v = Value()) : ConstantMX(sp), v_(v) {}
509 
513  explicit Constant(DeserializingStream& s, const Value& v);
514 
516  ~Constant() override {}
517 
521  std::string disp(const std::vector<std::string>& arg) const override;
522 
527  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
528 
530  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
531 
535  void generate(CodeGenerator& g,
536  const std::vector<casadi_int>& arg,
537  const std::vector<casadi_int>& res) const override;
538 
542  bool is_zero() const override { return v_.value==0;}
543  bool is_one() const override { return v_.value==1;}
544  bool is_eye() const override { return v_.value==1 && sparsity().is_diag();}
545  bool is_value(double val) const override { return v_.value==val;}
546 
548  double to_double() const override {
549  return static_cast<double>(v_.value);
550  }
551 
553  Matrix<double> get_DM() const override {
554  return Matrix<double>(sparsity(), to_double(), false);
555  }
556 
558  MX get_project(const Sparsity& sp) const override;
559 
561  MX get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const override;
562 
564  MX get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const override;
565 
567  MX get_transpose() const override;
568 
570  MX get_unary(casadi_int op) const override;
571 
573  MX _get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const override;
574 
576  MX get_reshape(const Sparsity& sp) const override;
577 
579  MX get_horzcat(const std::vector<MX>& x) const override;
580 
582  MX get_vertcat(const std::vector<MX>& x) const override;
583 
587  bool is_equal(const MXNode* node, casadi_int depth) const override;
588 
592  void serialize_body(SerializingStream& s) const override;
596  void serialize_type(SerializingStream& s) const override;
597 
598  Value v_;
599  };
600 
601  template<typename Value>
602  void Constant<Value>::serialize_type(SerializingStream& s) const {
604  s.pack("ConstantMX::type", Value::type_char());
605  v_.serialize_type(s);
606  }
607 
608  template<typename Value>
609  void Constant<Value>::serialize_body(SerializingStream& s) const {
611  }
612 
613  template<typename Value>
614  Constant<Value>::Constant(DeserializingStream& s, const Value& v) : ConstantMX(s), v_(v) {
615  }
616 
617  template<typename Value>
618  MX Constant<Value>::get_horzcat(const std::vector<MX>& x) const {
619  // Check if all arguments have the same constant value
620  for (auto&& i : x) {
621  if (!i->is_value(to_double())) {
622  // Not all the same value, fall back to base class
623  return ConstantMX::get_horzcat(x);
624  }
625  }
626 
627  // Assemble the sparsity pattern
628  std::vector<Sparsity> sp;
629  for (auto&& i : x) sp.push_back(i.sparsity());
630  return MX(horzcat(sp), v_.value, false);
631  }
632 
633  template<typename Value>
634  MX Constant<Value>::get_vertcat(const std::vector<MX>& x) const {
635  // Check if all arguments have the same constant value
636  for (auto&& i : x) {
637  if (!i->is_value(to_double())) {
638  // Not all the same value, fall back to base class
639  return ConstantMX::get_vertcat(x);
640  }
641  }
642 
643  // Assemble the sparsity pattern
644  std::vector<Sparsity> sp;
645  for (auto&& i : x) sp.push_back(i.sparsity());
646  return MX(vertcat(sp), v_.value, false);
647  }
648 
649  template<typename Value>
650  MX Constant<Value>::get_reshape(const Sparsity& sp) const {
651  return MX::create(new Constant<Value>(sp, v_));
652  }
653 
654  template<typename Value>
655  MX Constant<Value>::get_transpose() const {
656  return MX::create(new Constant<Value>(sparsity().T(), v_));
657  }
658 
659  template<typename Value>
660  MX Constant<Value>::get_unary(casadi_int op) const {
661  // Constant folding
662  double ret(0);
663  casadi_math<double>::fun(op, to_double(), 0.0, ret);
664  if (operation_checker<F0XChecker>(op) || sparsity().is_dense()) {
665  return MX(sparsity(), ret);
666  } else {
667  if (v_.value==0) {
668  if (is_zero() && operation_checker<F0XChecker>(op)) {
669  return MX(sparsity(), ret, false);
670  } else {
671  return repmat(MX(ret), size1(), size2());
672  }
673  }
674  double ret2;
675  casadi_math<double>::fun(op, 0, 0.0, ret2);
676  return DM(sparsity(), ret, false)
677  + DM(sparsity().pattern_inverse(), ret2, false);
678  }
679  }
680 
681  template<typename Value>
682  MX Constant<Value>::_get_binary(casadi_int op, const MX& y, bool ScX, bool ScY) const {
683  casadi_assert_dev(sparsity()==y.sparsity() || ScX || ScY);
684 
685  if (ScX && !operation_checker<FX0Checker>(op)) {
686  double ret;
687  casadi_math<double>::fun(op, nnz()> 0 ? to_double(): 0.0, 0, ret);
688 
689  if (ret!=0) {
690  Sparsity f = Sparsity::dense(y.size1(), y.size2());
691  MX yy = project(y, f);
692  return MX(f, shared_from_this<MX>())->_get_binary(op, yy, false, false);
693  }
694  } else if (ScY && !operation_checker<F0XChecker>(op)) {
695  bool grow = true;
696  if (y->op()==OP_CONST && dynamic_cast<const ConstantDM*>(y.get())==nullptr) {
697  double ret;
698  casadi_math<double>::fun(op, 0, y.nnz()>0 ? y->to_double() : 0, ret);
699  grow = ret!=0;
700  }
701  if (grow) {
702  Sparsity f = Sparsity::dense(size1(), size2());
703  MX xx = project(shared_from_this<MX>(), f);
704  return xx->_get_binary(op, MX(f, y), false, false);
705  }
706  }
707 
708  switch (op) {
709  case OP_ADD:
710  if (v_.value==0) return ScY && !y->is_zero() ? repmat(y, size1(), size2()) : y;
711  break;
712  case OP_SUB:
713  if (v_.value==0) return ScY && !y->is_zero() ? repmat(-y, size1(), size2()) : -y;
714  break;
715  case OP_MUL:
716  if (v_.value==1) return y;
717  if (v_.value==-1) return -y;
718  if (v_.value==2) return y->get_unary(OP_TWICE);
719  break;
720  case OP_DIV:
721  if (v_.value==1) return y->get_unary(OP_INV);
722  if (v_.value==-1) return -y->get_unary(OP_INV);
723  break;
724  case OP_POW:
725  if (v_.value==0) return MX::zeros(y.sparsity());
726  if (v_.value==1) return MX::ones(y.sparsity());
727  if (v_.value==std::exp(1.0)) return y->get_unary(OP_EXP);
728  break;
729  default: break; //no rule
730  }
731 
732  // Constant folding
733  // NOTE: ugly, should use a function instead of a cast
734  if (y->op()==OP_CONST && dynamic_cast<const ConstantDM*>(y.get())==nullptr) {
735  double y_value = y.nnz()>0 ? y->to_double() : 0;
736  double ret;
737  casadi_math<double>::fun(op, nnz()> 0.0 ? to_double(): 0, y_value, ret);
738 
739  return MX(y.sparsity(), ret, false);
740  }
741 
742  // Fallback
743  return MXNode::_get_binary(op, y, ScX, ScY);
744  }
745 
746  template<typename Value>
747  int Constant<Value>::eval(const double** arg, double** res, casadi_int* iw, double* w) const {
748  std::fill(res[0], res[0]+nnz(), to_double());
749  return 0;
750  }
751 
752  template<typename Value>
753  int Constant<Value>::
754  eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
755  std::fill(res[0], res[0]+nnz(), SXElem(v_.value));
756  return 0;
757  }
758 
759  template<typename Value>
760  void Constant<Value>::generate(CodeGenerator& g,
761  const std::vector<casadi_int>& arg,
762  const std::vector<casadi_int>& res) const {
763  if (nnz()==0) {
764  // Quick return
765  } else if (nnz()==1) {
766  g << g.workel(res[0]) << " = " << g.constant(to_double()) << ";\n";
767  } else {
768  if (to_double()==0) {
769  g << g.clear(g.work(res[0], nnz()), nnz()) << '\n';
770  } else {
771  g << g.fill(g.work(res[0], nnz()), nnz(), g.constant(to_double())) << '\n';
772  }
773  }
774  }
775 
776  template<typename Value>
777  MX Constant<Value>::get_nzref(const Sparsity& sp, const std::vector<casadi_int>& nz) const {
778  if (v_.value!=0) {
779  // Check if any "holes"
780  for (std::vector<casadi_int>::const_iterator k=nz.begin(); k!=nz.end(); ++k) {
781  if (*k<0) {
782  // Do not simplify
783  return MXNode::get_nzref(sp, nz);
784  }
785  }
786  }
787  return MX::create(new Constant<Value>(sp, v_));
788  }
789 
790  template<typename Value>
791  MX Constant<Value>::get_nzassign(const MX& y, const std::vector<casadi_int>& nz) const {
792  if (y.is_constant() && y->is_zero() && v_.value==0) {
793  return y;
794  }
795 
796  // Fall-back
797  return MXNode::get_nzassign(y, nz);
798  }
799 
800  template<typename Value>
801  MX Constant<Value>::get_project(const Sparsity& sp) const {
802  if (is_zero()) {
803  return MX::create(new Constant<Value>(sp, v_));
804  } else if (sp.is_dense()) {
805  return densify(get_DM());
806  } else {
807  return MXNode::get_project(sp);
808  }
809  }
810 
811  template<typename Value>
812  std::string
813  Constant<Value>::disp(const std::vector<std::string>& arg) const {
814  std::stringstream ss;
815  if (sparsity().is_scalar()) {
816  // Print scalar
817  if (sparsity().nnz()==0) {
818  ss << "00";
819  } else {
820  ss << v_.value;
821  }
822  } else if (sparsity().is_empty()) {
823  // Print empty
824  sparsity().disp(ss);
825  } else {
826  // Print value
827  if (v_.value==0) {
828  ss << "zeros(";
829  } else if (v_.value==1) {
830  ss << "ones(";
831  } else if (v_.value!=v_.value) {
832  ss << "nan(";
833  } else if (v_.value==std::numeric_limits<double>::infinity()) {
834  ss << "inf(";
835  } else if (v_.value==-std::numeric_limits<double>::infinity()) {
836  ss << "-inf(";
837  } else {
838  ss << "all_" << v_.value << "(";
839  }
840 
841  // Print sparsity
842  sparsity().disp(ss);
843  ss << ")";
844  }
845  return ss.str();
846  }
847 
848  template<typename Value>
849  bool Constant<Value>::is_equal(const MXNode* node, casadi_int depth) const {
850  return node->is_value(to_double()) && sparsity()==node->sparsity();
851  }
852 
853 
854 } // namespace casadi
856 
857 
858 #endif // CASADI_CONSTANT_MX_HPP
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
The casadi namespace.
Matrix< SXElem > SX
Definition: sx_fwd.hpp:32
bool is_zero(const T &x)
Matrix< double > DM
Definition: dm_fwd.hpp:33