linsol.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
#include "linsol_internal.hpp"
#include "mx_node.hpp"
#include "filesystem_impl.hpp"

#include <iomanip>
#include <limits>
#include <locale>
#include <sstream>
29 
30 namespace casadi {
31 
33  }
34 
35  Linsol::Linsol(const std::string& name, const std::string& solver,
36  const Sparsity& sp, const Dict& opts) {
37  own(LinsolInternal::getPlugin(solver).creator(name, sp));
38  (*this)->construct(opts);
39  }
40 
42  return static_cast<LinsolInternal*>(SharedObject::operator->());
43  }
44 
46  return static_cast<const LinsolInternal*>(SharedObject::operator->());
47  }
48 
50  return dynamic_cast<const LinsolInternal*>(ptr)!=nullptr;
51  }
52 
53  bool Linsol::has_plugin(const std::string& name) {
54  return LinsolInternal::has_plugin(name);
55  }
56 
57  void Linsol::load_plugin(const std::string& name) {
59  }
60 
61  std::string Linsol::doc(const std::string& name) {
62  return LinsolInternal::getPlugin(name).doc;
63  }
64 
65  std::string Linsol::plugin_name() const {
66  return (*this)->plugin_name();
67  }
68 
69  const Sparsity& Linsol::sparsity() const {
70  return (*this)->sp_;
71  }
72 
73  DM Linsol::solve(const DM& A, const DM& B, bool tr) const {
74  casadi_assert(A.size1()==B.size1(),
75  "Linsol::solve: Dimension mismatch. A and b must have matching row count. "
76  "Got " + A.dim() + " and " + B.dim() + ".");
77 
78  scoped_checkout<Linsol> mem(*this);
79  auto m = static_cast<LinsolMemory*>((*this)->memory(mem));
80 
81  // Reset statistics
82  for (auto&& s : m->fstats) s.second.reset();
83  if (m->t_total) m->t_total->tic();
84  // Symbolic factorization
85  if (sfact(A.ptr(), mem)) casadi_error("Linsol::solve: 'sfact' failed");
86 
87  // Numeric factorization
88  if (nfact(A.ptr(), mem)) casadi_error("Linsol::solve: 'nfact' failed");
89 
90  // Solve
91  DM x = densify(B);
92  if (solve(A.ptr(), x.ptr(), x.size2(), false, mem))
93  casadi_error("Linsol::solve: 'solve' failed");
94  // Show statistics
95  if (m->t_total) m->t_total->toc();
96 
97  (*this)->print_time(m->fstats);
98  return x;
99  }
100 
101  MX Linsol::solve(const MX& A, const MX& B, bool tr) const {
102  return A->get_solve(B, tr, *this);
103  }
104 
105  void Linsol::sfact(const DM& A) const {
106  if (A.sparsity()!=sparsity()) return sfact(project(A, sparsity()));
107  if (sfact(A.ptr())) casadi_error("'sfact' failed");
108  }
109 
110  int Linsol::sfact(const double* A, int mem) const {
111  if (A==nullptr) return 1;
112  auto m = static_cast<LinsolMemory*>((*this)->memory(mem));
113 
114  // Factorization will be needed after this step
115  m->is_sfact = m->is_nfact = false;
116 
117  if (m->t_total) m->fstats.at("sfact").tic();
118  // Perform pivoting
119  if ((*this)->sfact(m, A)) return 1;
120  if (m->t_total) m->fstats.at("sfact").toc();
121 
122  // Mark as (successfully) pivoted
123  m->is_sfact = true;
124  return 0;
125  }
126 
127  void Linsol::nfact(const DM& A) const {
128  if (A.sparsity()!=sparsity()) return nfact(project(A, sparsity()));
129  if (nfact(A.ptr())) casadi_error("'nfact' failed");
130  }
131 
132  int Linsol::nfact(const double* A, int mem) const {
133  if (A==nullptr) return 1;
134  auto m = static_cast<LinsolMemory*>((*this)->memory(mem));
135 
136  // Perform pivoting, if required
137  if (!m->is_sfact) {
138  if (sfact(A, mem)) return 1;
139  }
140 
141  m->is_nfact = false;
142  if (m->t_total) m->fstats.at("nfact").tic();
143  int flag = (*this)->nfact(m, A);
144  if (m->t_total) m->fstats.at("nfact").toc();
145  if (flag && (*this)->regularity_check_) {
146  // Collect nonzeros
147  std::vector<std::string> nonzeros(sparsity().nnz());
148  for (size_t nz = 0; nz < nonzeros.size(); ++nz)
149  nonzeros[nz] = std::to_string(A[nz]);
150  // Create .m file
151  std::ofstream mfile;
152  std::string fname = (*this)->class_name() + "_" + (*this)->name_ + "_debug.m";
153  Filesystem::open(mfile, fname);
154  Dict opts;
155  opts["name"] = "A";
156  opts["nonzeros"] = nonzeros;
157  sparsity().export_code("matlab", mfile, opts);
158  mfile.close();
159  casadi_error("Numerical factorization failed for " + (*this)->name_
160  + "[" + (*this)->class_name() + "]. Linear system saved to '" + fname + "'");
161  }
162  m->is_nfact = true;
163  return flag;
164  }
165 
166  casadi_int Linsol::neig(const DM& A) const {
167  if (A.sparsity()!=sparsity()) return neig(project(A, sparsity()));
168  casadi_int n = neig(A.ptr());
169  casadi_assert(n>=0, "'neig' failed");
170  return n;
171  }
172 
173  casadi_int Linsol::neig(const double* A, int mem) const {
174  return (*this)->neig((*this)->memory(mem), A);
175  }
176 
177  casadi_int Linsol::rank(const DM& A) const {
178  if (A.sparsity()!=sparsity()) return rank(project(A, sparsity()));
179  casadi_int n = rank(A.ptr());
180  casadi_assert(n>=0, "'rank' failed");
181  return n;
182  }
183 
184  casadi_int Linsol::rank(const double* A, int mem) const {
185  return (*this)->rank((*this)->memory(mem), A);
186  }
187 
188  int Linsol::solve(const double* A, double* x, casadi_int nrhs, bool tr, int mem) const {
189  auto m = static_cast<LinsolMemory*>((*this)->memory(mem));
190  casadi_assert(m->is_nfact, "Linear system has not been factorized");
191  if (m->t_total) m->fstats.at("solve").tic();
192  int ret = (*this)->solve(m, A, x, nrhs, tr);
193  if (m->t_total) m->fstats.at("solve").toc();
194  return ret;
195  }
196 
197  casadi_int Linsol::checkout() const {
198  return (*this)->checkout();
199  }
200 
201  void Linsol::release(int mem) const {
202  (*this)->release(mem);
203  }
204 
205 
206  bool has_linsol(const std::string& name) {
207  return Linsol::has_plugin(name);
208  }
209 
210  void load_linsol(const std::string& name) {
211  Linsol::load_plugin(name);
212  }
213 
214  std::string doc_linsol(const std::string& name) {
215  return Linsol::doc(name);
216  }
217 
218  Dict Linsol::stats(int mem) const {
219  casadi_assert((*this)->has_memory(mem),
220  "No stats available since Linsol did not solve a problem yet.");
221  return (*this)->get_stats((*this)->memory(mem));
222  }
223 
225  // TODO(jgillis): I don't get why LinsolInternal:: this is necessary
226  return (*this)->LinsolInternal::serialize(s);
227  }
228 
230  Linsol linsol;
231  linsol.own(LinsolInternal::deserialize(s));
232  linsol->finalize();
233  return linsol;
234  }
235 
237  Linsol ret;
238  ret.own(node);
239  return ret;
240  }
241 
242 } // namespace casadi
Helper class for Serialization.
void tic()
Start timing.
Definition: timing.cpp:40
static void open(std::ofstream &, const std::string &path, std::ios_base::openmode mode=std::ios_base::out)
Definition: filesystem.cpp:115
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
std::string dim(bool with_nz=false) const
Get string representation of dimensions.
SharedObjectInternal * operator->() const
Access a member function or object.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Linear solver.
Definition: linsol.hpp:55
static bool has_plugin(const std::string &name)
Check if a plugin is available.
Definition: linsol.cpp:53
casadi_int rank(const DM &A) const
Matrix rank.
Definition: linsol.cpp:177
static void load_plugin(const std::string &name)
Explicitly load a plugin dynamically.
Definition: linsol.cpp:57
casadi_int checkout() const
Checkout a memory object.
Definition: linsol.cpp:197
void nfact(const DM &A) const
Numeric factorization of the linear system.
Definition: linsol.cpp:127
Linsol()
Default constructor.
Definition: linsol.cpp:32
const Sparsity & sparsity() const
Get linear system sparsity.
Definition: linsol.cpp:69
static bool test_cast(const SharedObjectInternal *ptr)
Check if a particular cast is allowed.
Definition: linsol.cpp:49
Dict stats(int mem=1) const
Get all statistics obtained at the end of the last evaluate call.
Definition: linsol.cpp:218
std::string plugin_name() const
Query plugin name.
Definition: linsol.cpp:65
static std::string doc(const std::string &name)
Get solver specific documentation.
Definition: linsol.cpp:61
static Linsol deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Definition: linsol.cpp:229
DM solve(const DM &A, const DM &B, bool tr=false) const
Definition: linsol.cpp:73
LinsolInternal * operator->()
Access functions of the node.
Definition: linsol.cpp:41
static Linsol create(LinsolInternal *node)
Create from node.
Definition: linsol.cpp:236
void sfact(const DM &A) const
Symbolic factorization of the linear system, e.g. selecting pivots.
Definition: linsol.cpp:105
void serialize(SerializingStream &s) const
Serialize an object.
Definition: linsol.cpp:224
void release(int mem) const
Release a memory object.
Definition: linsol.cpp:201
casadi_int neig(const DM &A) const
Number of negative eigenvalues.
Definition: linsol.cpp:166
virtual MX get_solve(const MX &r, bool tr, const Linsol &linear_solver) const
Solve a system of linear equations.
Definition: mx_node.cpp:649
MX - Matrix expression.
Definition: mx.hpp:92
const Sparsity & sparsity() const
Const access the sparsity - reference to data member.
Scalar * ptr()
static bool has_plugin(const std::string &pname, bool verbose=false)
Check if a plugin is available or can be loaded.
static Plugin & getPlugin(const std::string &pname)
Load and get the creator function.
static Plugin load_plugin(const std::string &pname, bool register_plugin=true, bool needs_lock=true)
Load a plugin dynamically.
virtual void finalize()
Finalize the object creation.
Helper class for Serialization.
General sparsity class.
Definition: sparsity.hpp:106
void export_code(const std::string &lang, std::ostream &stream=casadi::uout(), const Dict &options=Dict()) const
Export matrix in specific language.
Definition: sparsity.cpp:778
The casadi namespace.
Definition: archiver.cpp:28
std::string doc_linsol(const std::string &name)
Get the documentation string for a plugin.
Definition: linsol.cpp:214
bool has_linsol(const std::string &name)
Check if a particular plugin is available.
Definition: linsol.cpp:206
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
void load_linsol(const std::string &name)
Explicitly load a plugin dynamically.
Definition: linsol.cpp:210