26 #ifndef CASADI_SOLVE_IMPL_HPP
27 #define CASADI_SOLVE_IMPL_HPP
30 #include "linsol_internal.hpp"
37 "Solve::Solve: dimension mismatch. Got r " + r.
dim() +
" and A " + A.
dim());
45 ss <<
"(" << mod_prefix() << arg.at(1) << mod_suffix();
47 ss <<
"\\" << arg.at(0) <<
")";
53 Solve<Tr>(r, A), linsol_(linear_solver) {
58 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
59 scoped_checkout<Linsol> mem(linsol_);
61 auto m =
static_cast<LinsolMemory*
>(linsol_->memory(mem));
63 for (
auto&& s : m->fstats) s.second.reset();
64 if (m->t_total) m->t_total->tic();
66 if (linsol_.sfact(arg[1], mem))
return 1;
67 if (linsol_.nfact(arg[1], mem))
return 1;
68 if (linsol_.solve(arg[1], res[0], this->dep(0).size2(), Tr, mem))
return 1;
70 linsol_->print_time(m->fstats);
77 linsol_->linsol_eval_sx(arg, res, iw, w, linsol_->memory(0), Tr, this->dep(0).size2());
84 res[0] =
MX(arg[0].size());
86 res[0] = solve(arg[1], arg[0], Tr);
92 std::vector<std::vector<MX> >& fsens)
const {
94 std::vector<MX> arg(this->n_dep());
95 for (casadi_int i=0; i<arg.size(); ++i) arg[i] = this->dep(i);
96 std::vector<MX> res(this->nout());
97 for (casadi_int i=0; i<res.size(); ++i) res[i] = this->get_output(i);
100 casadi_int nfwd = fseed.size();
101 const MX& A = arg[1];
102 const MX&
X = res[0];
105 std::vector<MX> rhs(nfwd);
106 std::vector<casadi_int> col_offset(nfwd+1, 0);
107 for (casadi_int d=0; d<nfwd; ++d) {
108 const MX& B_hat = fseed[d][0];
109 const MX& A_hat = fseed[d][1];
110 rhs[d] = Tr ? B_hat - mtimes(A_hat.
T(),
X) : B_hat - mtimes(A_hat,
X);
111 col_offset[d+1] = col_offset[d] + rhs[d].size2();
113 rhs = horzsplit(solve(A, horzcat(rhs), Tr), col_offset);
117 for (casadi_int d=0; d<nfwd; ++d) {
119 fsens[d][0] = rhs[d];
125 std::vector<std::vector<MX> >& asens)
const {
127 std::vector<MX> arg(this->n_dep());
128 for (casadi_int i=0; i<arg.size(); ++i) arg[i] = this->dep(i);
129 std::vector<MX> res(this->nout());
130 for (casadi_int i=0; i<res.size(); ++i) res[i] = this->get_output(i);
133 casadi_int nadj = aseed.size();
134 const MX& A = arg[1];
135 const MX&
X = res[0];
138 std::vector<MX> rhs(nadj);
139 std::vector<casadi_int> col_offset(nadj+1, 0);
140 for (casadi_int d=0; d<nadj; ++d) {
141 rhs[d] = aseed[d][0];
142 col_offset[d+1] = col_offset[d] + rhs[d].size2();
144 rhs = horzsplit(solve(A, horzcat(rhs), !Tr), col_offset);
148 for (casadi_int d=0; d<nadj; ++d) {
158 if (asens[d][1].is_empty(
true)) {
165 if (asens[d][0].is_empty(
true)) {
166 asens[d][0] = rhs[d];
168 asens[d][0] += rhs[d];
176 casadi_int nrhs = dep(0).size2();
179 const Sparsity& A_sp = this->A_sp();
180 const casadi_int* A_colind = A_sp.
colind();
181 const casadi_int* A_row = A_sp.
row();
182 casadi_int n = A_sp.
size1();
185 const bvec_t *B=arg[0], *A = arg[1];
190 for (casadi_int r=0; r<nrhs; ++r) {
192 std::copy(B, B+n, tmp);
195 for (casadi_int cc=0; cc<n; ++cc) {
196 for (casadi_int k=A_colind[cc]; k<A_colind[cc+1]; ++k) {
197 casadi_int rr = A_row[k];
198 tmp[Tr ? cc : rr] |= A[k];
203 std::fill(
X,
X+n, 0);
204 A_sp.spsolve(
X, tmp, Tr);
216 casadi_int nrhs = dep(0).size2();
219 const Sparsity& A_sp = this->A_sp();
220 const casadi_int* A_colind = A_sp.
colind();
221 const casadi_int* A_row = A_sp.
row();
222 casadi_int n = A_sp.
size1();
225 bvec_t *B=arg[0], *A=arg[1], *
X=res[0];
229 for (casadi_int r=0; r<nrhs; ++r) {
231 std::fill(tmp, tmp+n, 0);
232 A_sp.spsolve(tmp,
X, !Tr);
235 std::fill(
X,
X+n, 0);
238 for (casadi_int i=0; i<n; ++i) B[i] |= tmp[i];
241 for (casadi_int cc=0; cc<n; ++cc) {
242 for (casadi_int k=A_colind[cc]; k<A_colind[cc+1]; ++k) {
243 casadi_int rr = A_row[k];
244 A[k] |= tmp[Tr ? cc : rr];
257 return this->sparsity().size1();
262 const std::vector<casadi_int>& arg,
263 const std::vector<casadi_int>& res,
264 const std::vector<bool>& arg_is_ref,
265 std::vector<bool>& res_is_ref)
const {
267 casadi_int nrhs = this->dep(0).size2();
270 g.local(
"rr",
"casadi_real",
"*");
271 g <<
"rr = " << g.work(res[0], this->nnz(),
false) <<
";\n";
274 g.local(
"ss",
"const casadi_real",
"*");
275 g <<
"ss = " << g.work(arg[1], this->dep(1).nnz(), arg_is_ref[1]) <<
";\n";
278 if (arg[0]!=res[0] || arg_is_ref[0]) {
279 g << g.copy(g.work(arg[0], this->nnz(), arg_is_ref[0]), this->nnz(),
"rr") <<
'\n';
282 linsol_->
generate(g,
"ss",
"rr", nrhs, Tr);
293 s.
pack(
"Solve::Tr", Tr);
303 s.
unpack(
"Solve::Tr", tr);
304 casadi_error(
"Not implemented");
310 s.
pack(
"Solve::Linsol", linsol_);
326 s.
unpack(
"Solve::Tr", tr);
341 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
342 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
348 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
349 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
359 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
360 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
366 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
367 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
377 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
379 std::lock_guard<std::mutex> lock(A_sp_mtx_);
382 if (A_sp_.is_null()) {
383 const Sparsity& no_diag = this->dep(1).sparsity();
397 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
398 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
405 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
406 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
417 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
418 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
425 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
426 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
432 const std::vector<casadi_int>& arg,
433 const std::vector<casadi_int>& res,
434 const std::vector<bool>& arg_is_ref,
435 std::vector<bool>& res_is_ref)
const {
437 casadi_int nrhs = this->dep(0).size2();
439 if (arg[0]!=res[0] || arg_is_ref[0]) {
440 g << g.copy(g.work(arg[0], this->nnz(), arg_is_ref[0]),
442 g.work(res[0], this->nnz(),
false)) <<
'\n';
445 g << g.triusolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz(), arg_is_ref[1]),
446 g.work(res[0], this->nnz(),
false), Tr,
false, nrhs) <<
'\n';
451 const std::vector<casadi_int>& arg,
452 const std::vector<casadi_int>& res,
453 const std::vector<bool>& arg_is_ref,
454 std::vector<bool>& res_is_ref)
const {
456 casadi_int nrhs = this->dep(0).size2();
458 if (arg[0]!=res[0] || arg_is_ref[0]) {
459 g << g.copy(g.work(arg[0], this->nnz(), arg_is_ref[0]),
461 g.work(res[0], this->nnz(),
false)) <<
'\n';
464 g << g.trilsolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz(), arg_is_ref[1]),
465 g.work(res[0], this->nnz(),
false), Tr,
false, nrhs) <<
'\n';
470 const std::vector<casadi_int>& arg,
471 const std::vector<casadi_int>& res,
472 const std::vector<bool>& arg_is_ref,
473 std::vector<bool>& res_is_ref)
const {
475 casadi_int nrhs = this->dep(0).size2();
477 if (arg[0]!=res[0] || arg_is_ref[0]) {
478 g << g.copy(g.work(arg[0], this->nnz(), arg_is_ref[0]),
480 g.work(res[0], this->nnz(),
false)) <<
'\n';
483 g << g.triusolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz(), arg_is_ref[1]),
484 g.work(res[0], this->nnz(),
false), Tr,
true, nrhs) <<
'\n';
489 const std::vector<casadi_int>& arg,
490 const std::vector<casadi_int>& res,
491 const std::vector<bool>& arg_is_ref,
492 std::vector<bool>& res_is_ref)
const {
494 casadi_int nrhs = this->dep(0).size2();
496 if (arg[0]!=res[0] || arg_is_ref[0]) {
497 g << g.copy(g.work(arg[0], this->nnz(), arg_is_ref[0]),
499 g.work(res[0], this->nnz(),
false)) <<
'\n';
502 g << g.trilsolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz(), arg_is_ref[1]),
503 g.work(res[0], this->nnz(),
false), Tr,
true, nrhs) <<
'\n';
Helper class for C code generation.
std::string generate(const std::string &prefix="")
Generate file(s)
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Sparsity sparsity() const
Get the sparsity pattern.
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
std::string dim(bool with_nz=false) const
Get string representation of dimensions.
static MX zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
Linear solve operation with a linear solver instance.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
size_t sz_w() const override
Get required length of w field.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void serialize_type(SerializingStream &s) const override
Serialize type information.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
LinsolCall(const MX &r, const MX &A, const Linsol &linear_solver)
Constructor.
Linsol linsol_
Linear solver (may be shared between multiple nodes)
Node class for MX objects.
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
MX T() const
Transpose the matrix.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
Linear solve with unity diagonal added.
const Sparsity & A_sp() const override
Sparsity pattern for the linear system.
SolveUnity(const MX &r, const MX &A)
Constructor.
An MX atomic for linear solver solution: x = r * A^-1 or x = r * A^-T.
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Solve(const MX &r, const MX &A)
Constructor.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
void serialize_type(SerializingStream &s) const override
Serialize type information.
casadi_int colind(casadi_int cc) const
Get a reference to the colindex of column cc (see class description)
casadi_int size1() const
Get the number of rows.
static Sparsity diag(casadi_int nrow)
Create diagonal sparsity pattern.
casadi_int row(casadi_int el) const
Get the row of a non-zero element.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
TrilSolveUnity(const MX &r, const MX &A)
Constructor.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
TrilSolve(const MX &r, const MX &A)
Constructor.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
TriuSolveUnity(const MX &r, const MX &A)
Constructor.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
TriuSolve(const MX &r, const MX &A)
Constructor.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res, const std::vector< bool > &arg_is_ref, std::vector< bool > &res_is_ref) const override
Generate code for the operation.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)