expm.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "expm_impl.hpp"
27 #include "sparsity_interface.hpp"
28 #include <typeinfo>
29 
30 namespace casadi {
31 
32  bool has_expm(const std::string& name) {
33  return Expm::has_plugin(name);
34  }
35 
36  void load_expm(const std::string& name) {
37  Expm::load_plugin(name);
38  }
39 
40  std::string doc_expm(const std::string& name) {
41  return Expm::getPlugin(name).doc;
42  }
43 
44  Function expmsol(const std::string& name, const std::string& solver,
45  const Sparsity& A, const Dict& opts) {
46  return Function::create(Expm::instantiate(name, solver, A), opts);
47  }
48 
49  casadi_int expm_n_in() {
50  return 2;
51  }
52 
53  casadi_int expm_n_out() {
54  return 1;
55  }
56 
57  // Constructor
58  Expm::Expm(const std::string& name, const Sparsity &A)
59  : FunctionInternal(name), A_(Sparsity::dense(A.size1(), A.size2())) {
60 
61  casadi_assert_dev(A.is_square());
62 
63  }
64 
66  switch (i) {
67  case 0:
68  return A_;
69  case 1:
70  return Sparsity::dense(1, 1);
71  default: break;
72  }
73  return Sparsity();
74  }
75 
77  switch (i) {
78  case 0:
79  return A_;
80  default: break;
81  }
82  return Sparsity();
83  }
84 
87  {{"const_A",
88  {OT_BOOL,
89  "Assume A is constant. Default: false."}}
90  }
91  };
92 
93  void Expm::init(const Dict& opts) {
94  // Call the init method of the base class
96 
97  const_A_ = false;
98 
99  // Read options
100  for (auto&& op : opts) {
101  if (op.first=="const_A") {
102  const_A_ = op.second;
103  }
104  }
105 
106  }
107 
108  Function Expm::get_forward(casadi_int nfwd, const std::string& name,
109  const std::vector<std::string>& inames,
110  const std::vector<std::string>& onames,
111  const Dict& opts) const {
112  MX A = MX::sym("A", A_);
113  MX t = MX::sym("t");
114  MX Y = MX::sym("Y", A_);
115  MX Adot = MX::sym("Adot", A_);
116  MX tdot = MX::sym("tdot");
117 
118  MX Ydot = mtimes(A, Y)*tdot;
119 
120  if (!const_A_) {
121  DM N = DM::zeros(A_.size());
122 
123  MX extended = MX::blockcat({{A, Adot}, {N, A}});
124  MX R = expm(extended*t);
125 
126  Ydot += R(Slice(0, A_.size1()), Slice(A_.size1(), 2*A_.size1()));
127  }
128 
129  Function ret = Function(name, {A, t, Y, Adot, tdot}, {Ydot});
130 
131  return ret.map(name, "serial", nfwd,
132  std::vector<casadi_int>{0, 1, 2}, std::vector<casadi_int>{});
133 
134  }
135 
136  Function Expm::get_reverse(casadi_int nadj, const std::string& name,
137  const std::vector<std::string>& inames,
138  const std::vector<std::string>& onames,
139  const Dict& opts) const {
140  MX A = MX::sym("A", A_);
141  MX t = MX::sym("t");
142  MX Y = MX::sym("Y", A_);
143  MX Ybar = MX::sym("Ybar", A_);
144 
145  MX tbar = sum2(sum1(Ybar*mtimes(A, Y)));
146  MX Abar;
147  if (const_A_) {
148  Abar = MX(Sparsity(A_.size()));
149  } else {
150  DM N = DM::zeros(A_.size());
151 
152  MX At = A.T();
153  MX extended = MX::blockcat({{At, Ybar}, {N, At}});
154  MX R = expm(extended*t);
155 
156  Abar = R(Slice(0, A_.size1()), // NOLINT(cppcoreguidelines-slicing)
157  Slice(A_.size1(), 2*A_.size1()));
158  }
159  Function ret = Function(name, {A, t, Y, Ybar}, {Abar, tbar});
160 
161  return ret.map(name, "serial", nadj,
162  std::vector<casadi_int>{0, 1, 2}, std::vector<casadi_int>{});
163  }
164 
165  Sparsity Expm::get_jac_sparsity(casadi_int oind, casadi_int iind, bool symmetric) const {
166  if (const_A_ && iind == 0) {
167  return Sparsity(nnz_out(oind), nnz_in(iind));
168  }
169  // Fallback to base class
170  return FunctionInternal::get_jac_sparsity(oind, iind, symmetric);
171  }
172 
174  }
175 
176  std::map<std::string, Expm::Plugin> Expm::solvers_;
177 
178 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
179  std::mutex Expm::mutex_solvers_;
180 #endif // CASADI_WITH_THREADSAFE_SYMBOLICS
181 
182  const std::string Expm::infix_ = "expm";
183 
184 } // namespace casadi
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
Definition: expm.cpp:76
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nfwd forward derivatives.
Definition: expm.cpp:108
void init(const Dict &opts) override
Initialize.
Definition: expm.cpp:93
Sparsity A_
Definition: expm_impl.hpp:121
static const std::string infix_
Infix.
Definition: expm_impl.hpp:115
~Expm() override=0
Definition: expm.cpp:173
Sparsity get_jac_sparsity(casadi_int oind, casadi_int iind, bool symmetric) const override
Generate the sparsity of a Jacobian block.
Definition: expm.cpp:165
Expm(const std::string &name, const Sparsity &A)
Definition: expm.cpp:58
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
Definition: expm.cpp:65
static const Options options_
Options.
Definition: expm_impl.hpp:63
static std::map< std::string, Plugin > solvers_
Collection of solvers.
Definition: expm_impl.hpp:108
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Generate a function that calculates nadj adjoint derivatives.
Definition: expm.cpp:136
Internal class for Function.
void init(const Dict &opts) override
Initialize.
casadi_int nnz_in() const
Number of input/output nonzeros.
static const Options options_
Options.
virtual Sparsity get_jac_sparsity(casadi_int oind, casadi_int iind, bool symmetric) const
Get Jacobian sparsity.
casadi_int nnz_out() const
Number of input/output nonzeros.
Function object.
Definition: function.hpp:60
static Function create(FunctionInternal *node)
Create from node.
Definition: function.cpp:336
Function map(casadi_int n, const std::string &parallelization="serial") const
Create a mapped version of this function.
Definition: function.cpp:709
static MX sym(const std::string &name, casadi_int nrow=1, casadi_int ncol=1)
Create an nrow-by-ncol symbolic primitive.
static MatType zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
MX - Matrix expression.
Definition: mx.hpp:92
static MX blockcat(const std::vector< std::vector< MX > > &v)
Definition: mx.cpp:1197
MX T() const
Transpose the matrix.
Definition: mx.cpp:1029
static bool has_plugin(const std::string &pname, bool verbose=false)
Check if a plugin is available or can be loaded.
static Expm * instantiate(const std::string &fname, const std::string &pname, Problem problem)
static Plugin & getPlugin(const std::string &pname)
Load and get the creator function.
static Plugin load_plugin(const std::string &pname, bool register_plugin=true, bool needs_lock=true)
Load a plugin dynamically.
Class representing a Slice.
Definition: slice.hpp:48
General sparsity class.
Definition: sparsity.hpp:106
casadi_int size1() const
Get the number of rows.
Definition: sparsity.cpp:124
static Sparsity dense(casadi_int nrow, casadi_int ncol=1)
Create a dense rectangular sparsity pattern.
Definition: sparsity.cpp:1012
std::pair< casadi_int, casadi_int > size() const
Get the shape.
Definition: sparsity.cpp:152
bool is_square() const
Is square?
Definition: sparsity.cpp:293
Function expmsol(const std::string &name, const std::string &solver, const Sparsity &A, const Dict &opts)
Definition: expm.cpp:44
bool has_expm(const std::string &name)
Check if a particular plugin is available.
Definition: expm.cpp:32
void load_expm(const std::string &name)
Explicitly load a plugin dynamically.
Definition: expm.cpp:36
casadi_int expm_n_out()
Get the number of expm solver outputs.
Definition: expm.cpp:53
casadi_int expm_n_in()
Get the number of expm solver inputs.
Definition: expm.cpp:49
std::string doc_expm(const std::string &name)
Get the documentation string for a plugin.
Definition: expm.cpp:40
The casadi namespace.
Definition: archiver.cpp:28
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
Options metadata for a class.
Definition: options.hpp:40