26 #include "linsol_internal.hpp"
27 #include "mx_node.hpp"
28 #include "filesystem_impl.hpp"
38 (*this)->construct(opts);
66 return (*this)->plugin_name();
75 "Linsol::solve: Dimension mismatch. A and b must have matching row count. "
76 "Got " + A.
dim() +
" and " + B.
dim() +
".");
79 auto m =
static_cast<LinsolMemory*
>((*this)->memory(mem));
82 for (
auto&& s : m->fstats) s.second.reset();
85 if (
sfact(A.
ptr(), mem)) casadi_error(
"Linsol::solve: 'sfact' failed");
88 if (
nfact(A.
ptr(), mem)) casadi_error(
"Linsol::solve: 'nfact' failed");
93 casadi_error(
"Linsol::solve: 'solve' failed");
95 if (m->t_total) m->t_total->toc();
97 (*this)->print_time(m->fstats);
107 if (
sfact(A.
ptr())) casadi_error(
"'sfact' failed");
111 if (A==
nullptr)
return 1;
112 auto m =
static_cast<LinsolMemory*
>((*this)->memory(mem));
117 if (m->t_total) m->fstats.at(
"sfact").tic();
119 if ((*this)->sfact(m, A))
return 1;
120 if (m->t_total) m->fstats.at(
"sfact").toc();
129 if (
nfact(A.
ptr())) casadi_error(
"'nfact' failed");
133 if (A==
nullptr)
return 1;
134 auto m =
static_cast<LinsolMemory*
>((*this)->memory(mem));
138 if (
sfact(A, mem))
return 1;
142 if (m->t_total) m->fstats.at(
"nfact").tic();
143 int flag = (*this)->nfact(m, A);
144 if (m->t_total) m->fstats.at(
"nfact").toc();
145 if (flag && (*this)->regularity_check_) {
147 std::vector<std::string> nonzeros(
sparsity().nnz());
148 for (
size_t nz = 0; nz < nonzeros.size(); ++nz)
149 nonzeros[nz] = std::to_string(A[nz]);
152 std::string fname = (*this)->class_name() +
"_" + (*this)->name_ +
"_debug.m";
156 opts[
"nonzeros"] = nonzeros;
159 casadi_error(
"Numerical factorization failed for " + (*this)->name_
160 +
"[" + (*this)->class_name() +
"]. Linear system saved to '" + fname +
"'");
169 casadi_assert(n>=0,
"'neig' failed");
174 return (*this)->neig((*this)->memory(mem), A);
180 casadi_assert(n>=0,
"'rank' failed");
185 return (*this)->rank((*this)->memory(mem), A);
188 int Linsol::solve(
const double* A,
double* x, casadi_int nrhs,
bool tr,
int mem)
const {
189 auto m =
static_cast<LinsolMemory*
>((*this)->memory(mem));
190 casadi_assert(m->is_nfact,
"Linear system has not been factorized");
191 if (m->t_total) m->fstats.at(
"solve").tic();
192 int ret = (*this)->solve(m, A, x, nrhs, tr);
193 if (m->t_total) m->fstats.at(
"solve").toc();
198 return (*this)->checkout();
202 (*this)->release(mem);
219 casadi_assert((*this)->has_memory(mem),
220 "No stats available since Linsol did not solve a problem yet.");
221 return (*this)->get_stats((*this)->memory(mem));
226 return (*this)->LinsolInternal::serialize(s);
Helper class for Serialization.
static void open(std::ofstream &, const std::string &path, std::ios_base::openmode mode=std::ios_base::out)
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
std::string dim(bool with_nz=false) const
Get string representation of dimensions.
void own(SharedObjectInternal *node)
SharedObjectInternal * operator->() const
Access a member function or object.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
static bool has_plugin(const std::string &name)
Check if a plugin is available.
casadi_int rank(const DM &A) const
Matrix rank.
static void load_plugin(const std::string &name)
Explicitly load a plugin dynamically.
casadi_int checkout() const
Checkout a memory object.
void nfact(const DM &A) const
Numeric factorization of the linear system.
Linsol()
Default constructor.
const Sparsity & sparsity() const
Get linear system sparsity.
static bool test_cast(const SharedObjectInternal *ptr)
Check if a particular cast is allowed.
Dict stats(int mem=1) const
Get all statistics obtained at the end of the last evaluate call.
std::string plugin_name() const
Query plugin name.
static std::string doc(const std::string &name)
Get solver specific documentation.
static Linsol deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
DM solve(const DM &A, const DM &B, bool tr=false) const
LinsolInternal * operator->()
Access functions of the node.
static Linsol create(LinsolInternal *node)
Create from node.
void sfact(const DM &A) const
Symbolic factorization of the linear system, e.g. selecting pivots.
void serialize(SerializingStream &s) const
Serialize an object.
void release(int mem) const
Release a memory object.
casadi_int neig(const DM &A) const
Number of negative eigenvalues.
virtual MX get_solve(const MX &r, bool tr, const Linsol &linear_solver) const
Solve a system of linear equations.
const Sparsity & sparsity() const
Const access the sparsity - reference to data member.
static bool has_plugin(const std::string &pname, bool verbose=false)
Check if a plugin is available or can be loaded.
static Plugin & getPlugin(const std::string &pname)
Load and get the creator function.
static Plugin load_plugin(const std::string &pname, bool register_plugin=true, bool needs_lock=true)
Load a plugin dynamically.
virtual void finalize()
Finalize the object creation.
Helper class for Serialization.
void export_code(const std::string &lang, std::ostream &stream=casadi::uout(), const Dict &options=Dict()) const
Export matrix in specific language.
std::string doc_linsol(const std::string &name)
Get the documentation string for a plugin.
bool has_linsol(const std::string &name)
Check if a particular plugin is available.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
void load_linsol(const std::string &name)
Explicitly load a plugin dynamically.