26 #ifndef CASADI_SOLVE_IMPL_HPP
27 #define CASADI_SOLVE_IMPL_HPP
30 #include "linsol_internal.hpp"
37 "Solve::Solve: dimension mismatch. Got r " + r.
dim() +
" and A " + A.
dim());
// Expression printer fragment: writes "(" followed by arg.at(1) wrapped between
// mod_prefix() and mod_suffix() to the stream ss.
45 ss <<
"(" << mod_prefix() << arg.at(1) << mod_suffix();
// Then a single backslash character, arg.at(0) and the closing ")".
47 ss <<
"\\" << arg.at(0) <<
")";
53 Solve<Tr>(r, A), linsol_(linear_solver) {
// Numerical evaluation body: the result buffer res[0] is first filled with the
// nonzeros of dep(0) (skipped when input and output already alias), and the
// solver is later handed res[0] to work on.
58 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
// NOTE(review): presumably checks out a memory slot of linsol_ for the lifetime of `mem` - confirm.
59 scoped_checkout<Linsol> mem(linsol_);
// Access the solver memory record behind the checked-out slot.
61 auto m =
static_cast<LinsolMemory*
>(linsol_->memory(mem));
// Reset every entry of the statistics map and start the total timer if one exists.
63 for (
auto&& s : m->fstats) s.second.reset();
64 if (m->t_total) m->t_total->tic();
// Run sfact, nfact and solve in that order on the nonzeros arg[1]; a nonzero
// return from any of them aborts the evaluation with status 1.
// NOTE(review): sfact/nfact presumably mean symbolic/numeric factorization - confirm.
66 if (linsol_.sfact(arg[1], mem))
return 1;
67 if (linsol_.nfact(arg[1], mem))
return 1;
// solve() receives res[0], the column count of dep(0) and the transpose flag Tr.
68 if (linsol_.solve(arg[1], res[0], this->dep(0).size2(), Tr, mem))
return 1;
// Reached only when all three steps returned zero.
70 linsol_->print_time(m->fstats);
// Symbolic (SX) evaluation is delegated to the linear solver's linsol_eval_sx,
// using memory(0), the transpose flag Tr and the column count of dep(0).
77 linsol_->linsol_eval_sx(arg, res, iw, w, linsol_->memory(0), Tr, this->dep(0).size2());
// NOTE(review): these two assignments are presumably alternative branches; the
// condition selecting between them is on lines not visible here.
// First alternative: res[0] becomes an MX constructed from the size of arg[0].
84 res[0] =
MX(arg[0].size());
// Second alternative: res[0] is the symbolic solve of arg[1] against arg[0] with flag Tr.
86 res[0] = solve(arg[1], arg[0], Tr);
92 std::vector<std::vector<MX> >& fsens)
const {
// Forward-mode derivative body.
// Gather this node's dependencies and outputs as MX expressions.
94 std::vector<MX> arg(this->n_dep());
95 for (casadi_int i=0; i<arg.size(); ++i) arg[i] = this->dep(i);
96 std::vector<MX> res(this->nout());
97 for (casadi_int i=0; i<res.size(); ++i) res[i] = this->get_output(i);
// One forward direction per entry of fseed; A is the second dependency and X
// the (single) output of this node.
100 casadi_int nfwd = fseed.size();
101 const MX& A = arg[1];
102 const MX& X = res[0];
// Build one right-hand side per direction and record cumulative column
// offsets so the stacked solution can be split again afterwards.
105 std::vector<MX> rhs(nfwd);
106 std::vector<casadi_int> col_offset(nfwd+1, 0);
107 for (casadi_int d=0; d<nfwd; ++d) {
108 const MX& B_hat = fseed[d][0];
109 const MX& A_hat = fseed[d][1];
// rhs = B_hat - A_hat*X, using the transpose of A_hat when Tr is set.
110 rhs[d] = Tr ? B_hat - mtimes(A_hat.
T(), X) : B_hat - mtimes(A_hat, X);
111 col_offset[d+1] = col_offset[d] + rhs[d].size2();
// Solve all directions in a single call on the horizontally concatenated
// right-hand sides, then split the result back per direction.
113 rhs = horzsplit(solve(A, horzcat(rhs), Tr), col_offset);
// Store each direction's solution as the forward sensitivity of output 0.
117 for (casadi_int d=0; d<nfwd; ++d) {
119 fsens[d][0] = rhs[d];
125 std::vector<std::vector<MX> >& asens)
const {
// Reverse-mode derivative body.
// Gather this node's dependencies and outputs as MX expressions.
127 std::vector<MX> arg(this->n_dep());
128 for (casadi_int i=0; i<arg.size(); ++i) arg[i] = this->dep(i);
129 std::vector<MX> res(this->nout());
130 for (casadi_int i=0; i<res.size(); ++i) res[i] = this->get_output(i);
// One adjoint direction per entry of aseed; A is the second dependency and X
// the (single) output of this node.
133 casadi_int nadj = aseed.size();
134 const MX& A = arg[1];
135 const MX& X = res[0];
// The adjoint seeds of output 0 are the right-hand sides; cumulative column
// offsets are recorded for splitting the stacked solution.
138 std::vector<MX> rhs(nadj);
139 std::vector<casadi_int> col_offset(nadj+1, 0);
140 for (casadi_int d=0; d<nadj; ++d) {
141 rhs[d] = aseed[d][0];
142 col_offset[d+1] = col_offset[d] + rhs[d].size2();
// Solve all directions at once with the opposite transpose flag (!Tr).
144 rhs = horzsplit(solve(A, horzcat(rhs), !Tr), col_offset);
148 for (casadi_int d=0; d<nadj; ++d) {
// NOTE(review): the update of asens[d][1] guarded by this test is on lines not visible here.
158 if (asens[d][1].is_empty(
true)) {
// Sensitivity of the first dependency: assigned directly while the slot is
// still empty. NOTE(review): the += below presumably sits in an else branch
// whose keyword is on a line not visible here - confirm.
165 if (asens[d][0].is_empty(
true)) {
166 asens[d][0] = rhs[d];
168 asens[d][0] += rhs[d];
// Forward sparsity (dependency bit-vector) propagation body.
// Number of right-hand sides = number of columns of the first dependency.
176 casadi_int nrhs = dep(0).size2();
// Sparsity pattern of the linear system (column offsets and row indices) and
// its number of rows n.
179 const Sparsity& A_sp = this->A_sp();
180 const casadi_int* A_colind = A_sp.
colind();
181 const casadi_int* A_row = A_sp.
row();
182 casadi_int n = A_sp.
size1();
// Dependency bits of the right-hand side (B) and of the matrix nonzeros (A).
// NOTE(review): X and tmp used below are declared on lines not visible here.
185 const bvec_t *B=arg[0], *A = arg[1];
190 for (casadi_int r=0; r<nrhs; ++r) {
// Start from the dependency bits of the current right-hand-side column.
192 std::copy(B, B+n, tmp);
// OR in the bits of every matrix nonzero: at its column index when Tr is
// set, at its row index otherwise.
195 for (casadi_int cc=0; cc<n; ++cc) {
196 for (casadi_int k=A_colind[cc]; k<A_colind[cc+1]; ++k) {
197 casadi_int rr = A_row[k];
198 tmp[Tr ? cc : rr] |= A[k];
// Clear the output column, then let the pattern's spsolve propagate tmp into X.
203 std::fill(X, X+n, 0);
204 A_sp.spsolve(X, tmp, Tr);
// Reverse sparsity (dependency bit-vector) propagation body.
// Number of right-hand sides = number of columns of the first dependency.
216 casadi_int nrhs = dep(0).size2();
// Sparsity pattern of the linear system (column offsets and row indices) and
// its number of rows n.
219 const Sparsity& A_sp = this->A_sp();
220 const casadi_int* A_colind = A_sp.
colind();
221 const casadi_int* A_row = A_sp.
row();
222 casadi_int n = A_sp.
size1();
// All three buffers are writable here: seeds are read from X and accumulated
// into B and A. NOTE(review): tmp is declared on a line not visible here.
225 bvec_t *B=arg[0], *A=arg[1], *X=res[0];
229 for (casadi_int r=0; r<nrhs; ++r) {
// Propagate the output seeds X into tmp using the opposite transpose flag.
231 std::fill(tmp, tmp+n, 0);
232 A_sp.spsolve(tmp, X, !Tr);
// The output seeds have been consumed; clear them.
235 std::fill(X, X+n, 0);
// Accumulate the propagated bits into the right-hand-side seeds.
238 for (casadi_int i=0; i<n; ++i) B[i] |= tmp[i];
// Accumulate into every matrix nonzero the bits found at its column index
// (Tr set) or its row index (Tr not set).
241 for (casadi_int cc=0; cc<n; ++cc) {
242 for (casadi_int k=A_colind[cc]; k<A_colind[cc+1]; ++k) {
243 casadi_int rr = A_row[k];
244 A[k] |= tmp[Tr ? cc : rr];
// Returns the number of rows of this node's sparsity pattern.
// NOTE(review): enclosing signature not visible; presumably the work-vector size query (sz_w) - confirm.
257 return this->sparsity().size1();
262 const std::vector<casadi_int>& arg,
263 const std::vector<casadi_int>& res)
const {
// Code-generation body for the solver-backed node.
// Number of right-hand sides = number of columns of the first dependency.
265 casadi_int nrhs = this->dep(0).size2();
// Declare a local casadi_real pointer "rr" in the generated code and point it
// at the work-vector location of the result.
268 g.local(
"rr",
"casadi_real",
"*");
269 g <<
"rr = " << g.work(res[0], this->nnz()) <<
";\n";
// Likewise "ss" points at the work-vector location of the second dependency.
272 g.local(
"ss",
"casadi_real",
"*");
273 g <<
"ss = " << g.work(arg[1], this->dep(1).nnz()) <<
";\n";
// Emit a copy of the first dependency into rr unless input and result share
// the same work-vector index.
276 if (arg[0]!=res[0]) {
277 g << g.copy(g.work(arg[0], this->nnz()), this->nnz(),
"rr") <<
'\n';
// Let the linear solver emit its own code, operating on ss and rr.
280 linsol_->
generate(g,
"ss",
"rr", nrhs, Tr);
// Serialization fragments. NOTE(review): each statement group below belongs to
// a different function whose signature is not visible here.
// Write the transpose flag under the key "Solve::Tr".
291 s.
pack(
"Solve::Tr", Tr);
// Read the transpose flag back, then raise "Not implemented".
301 s.
unpack(
"Solve::Tr", tr);
302 casadi_error(
"Not implemented");
// Write the linear solver instance under the key "Solve::Linsol".
308 s.
pack(
"Solve::Linsol", linsol_);
// Read the transpose flag; what is done with it is on lines not visible here.
324 s.
unpack(
"Solve::Tr", tr);
// Evaluation fragments for the triangular solves. Every pair follows the same
// pattern: copy the nonzeros of dep(0) into res[0] unless they already alias,
// then call the triangular solve on res[0] with the pattern of dep(1), its
// nonzeros arg[1], the transpose flag Tr and the column count of dep(0).
// NOTE(review): the literal `false` is presumably the unit-diagonal flag - confirm
// against casadi_triusolve/casadi_trilsolve.
// casadi_triusolve, first occurrence.
339 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
340 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
// casadi_triusolve, second occurrence (enclosing signatures not visible).
346 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
347 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
// casadi_trilsolve, first occurrence.
357 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
358 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
// casadi_trilsolve, second occurrence.
364 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
365 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
false, this->dep(0).size2());
// Runs only while the cached pattern A_sp_ is still null; starts from the
// sparsity of the second dependency, named no_diag here.
// NOTE(review): how A_sp_ is built from no_diag is on lines not visible here.
376 if (A_sp_.is_null()) {
377 const Sparsity& no_diag = this->dep(1).sparsity();
// Evaluation fragments matching the non-unity triangular solves, except that
// the boolean literal passed after Tr is `true`.
// NOTE(review): presumably this selects the unit-diagonal variant - confirm.
// Each pair copies dep(0)'s nonzeros into res[0] unless they alias, then
// calls the triangular solve on res[0].
// casadi_triusolve, first occurrence.
391 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
392 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
// casadi_triusolve, second occurrence (enclosing signatures not visible).
399 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
400 casadi_triusolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
// casadi_trilsolve, first occurrence.
411 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
412 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
// casadi_trilsolve, second occurrence.
419 if (arg[0] != res[0]) std::copy(arg[0], arg[0] + this->dep(0).nnz(), res[0]);
420 casadi_trilsolve(this->dep(1).sparsity(), arg[1], res[0], Tr,
true, this->dep(0).size2());
426 const std::vector<casadi_int>& res)
const {
// Code-generation body using g.triusolve with the literal flag `false`.
// Number of right-hand sides = number of columns of the first dependency.
428 casadi_int nrhs = this->dep(0).size2();
// Emit a copy of the first dependency into the result unless both share the
// same work-vector index.
430 if (arg[0]!=res[0]) {
431 g << g.copy(g.work(arg[0], this->nnz()), this->nnz(), g.work(res[0], this->nnz())) <<
'\n';
// Emit the triangular solve acting on the result's work-vector location.
434 g << g.triusolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz()),
435 g.work(res[0], this->nnz()), Tr,
false, nrhs) <<
'\n';
440 const std::vector<casadi_int>& res)
const {
// Code-generation body using g.trilsolve with the literal flag `false`.
// Number of right-hand sides = number of columns of the first dependency.
442 casadi_int nrhs = this->dep(0).size2();
// Emit a copy of the first dependency into the result unless both share the
// same work-vector index.
444 if (arg[0]!=res[0]) {
445 g << g.copy(g.work(arg[0], this->nnz()), this->nnz(), g.work(res[0], this->nnz())) <<
'\n';
// Emit the triangular solve acting on the result's work-vector location.
448 g << g.trilsolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz()),
449 g.work(res[0], this->nnz()), Tr,
false, nrhs) <<
'\n';
454 const std::vector<casadi_int>& res)
const {
// Code-generation body using g.triusolve with the literal flag `true`.
// Number of right-hand sides = number of columns of the first dependency.
456 casadi_int nrhs = this->dep(0).size2();
// Emit a copy of the first dependency into the result unless both share the
// same work-vector index.
458 if (arg[0]!=res[0]) {
459 g << g.copy(g.work(arg[0], this->nnz()), this->nnz(), g.work(res[0], this->nnz())) <<
'\n';
// Emit the triangular solve acting on the result's work-vector location.
462 g << g.triusolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz()),
463 g.work(res[0], this->nnz()), Tr,
true, nrhs) <<
'\n';
468 const std::vector<casadi_int>& res)
const {
// Code-generation body using g.trilsolve with the literal flag `true`.
// Number of right-hand sides = number of columns of the first dependency.
470 casadi_int nrhs = this->dep(0).size2();
// Emit a copy of the first dependency into the result unless both share the
// same work-vector index.
472 if (arg[0]!=res[0]) {
473 g << g.copy(g.work(arg[0], this->nnz()), this->nnz(), g.work(res[0], this->nnz())) <<
'\n';
// Emit the triangular solve acting on the result's work-vector location.
476 g << g.trilsolve(this->dep(1).sparsity(), g.work(arg[1], this->dep(1).nnz()),
477 g.work(res[0], this->nnz()), Tr,
true, nrhs) <<
'\n';
Helper class for C code generation.
std::string generate(const std::string &prefix="")
Generate file(s)
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
Sparsity sparsity() const
Get the sparsity pattern.
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
std::string dim(bool with_nz=false) const
Get string representation of dimensions.
static MX zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
Linear solve operation with a linear solver instance.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
size_t sz_w() const override
Get required length of w field.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res) const override
Generate code for the operation.
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void serialize_type(SerializingStream &s) const override
Serialize type information.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
LinsolCall(const MX &r, const MX &A, const Linsol &linear_solver)
Constructor.
Linsol linsol_
Linear solver (may be shared between multiple nodes)
Node class for MX objects.
virtual void serialize_type(SerializingStream &s) const
Serialize type information.
virtual void serialize_body(SerializingStream &s) const
Serialize an object without type information.
MX T() const
Transpose the matrix.
Helper class for Serialization.
void pack(const Sparsity &e)
Serializes an object to the output stream.
Linear solve with unity diagonal added.
const Sparsity & A_sp() const override
Sparsity pattern for the linear system.
SolveUnity(const MX &r, const MX &A)
Constructor.
An MX atomic for linear solver solution: x = A^-1 * r or x = A^-T * r.
static MXNode * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Solve(const MX &r, const MX &A)
Constructor.
void ad_forward(const std::vector< std::vector< MX > > &fseed, std::vector< std::vector< MX > > &fsens) const override
Calculate forward mode directional derivatives.
std::string disp(const std::vector< std::string > &arg) const override
Print expression.
void ad_reverse(const std::vector< std::vector< MX > > &aseed, std::vector< std::vector< MX > > &asens) const override
Calculate reverse mode directional derivatives.
void eval_mx(const std::vector< MX > &arg, std::vector< MX > &res) const override
Evaluate symbolically (MX)
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity backwards.
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w) const override
Propagate sparsity forward.
void serialize_type(SerializingStream &s) const override
Serialize type information.
casadi_int colind(casadi_int cc) const
Get a reference to the colindex of column cc (see class description)
casadi_int size1() const
Get the number of rows.
static Sparsity diag(casadi_int nrow)
Create diagonal sparsity pattern.
casadi_int row(casadi_int el) const
Get the row of a non-zero element.
TrilSolveUnity(const MX &r, const MX &A)
Constructor.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res) const override
Generate code for the operation.
TrilSolve(const MX &r, const MX &A)
Constructor.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res) const override
Generate code for the operation.
TriuSolveUnity(const MX &r, const MX &A)
Constructor.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res) const override
Generate code for the operation.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
void generate(CodeGenerator &g, const std::vector< casadi_int > &arg, const std::vector< casadi_int > &res) const override
Generate code for the operation.
int eval(const double **arg, double **res, casadi_int *iw, double *w) const override
Evaluate the function numerically.
TriuSolve(const MX &r, const MX &A)
Constructor.
int eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w) const override
Evaluate the function symbolically (SX)