symbolic_qr.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "symbolic_qr.hpp"
27 
28 #ifdef WITH_DL
29 #include <cstdlib>
30 #endif // WITH_DL
31 
32 namespace casadi {
33 
34  extern "C"
35  int CASADI_LINSOL_SYMBOLICQR_EXPORT
36  casadi_register_linsol_symbolicqr(LinsolInternal::Plugin* plugin) {
37  plugin->creator = SymbolicQr::creator;
38  plugin->name = "symbolicqr";
39  plugin->doc = SymbolicQr::meta_doc.c_str();
40  plugin->version = CASADI_VERSION;
41  plugin->options = &SymbolicQr::options_;
42  plugin->deserialize = &SymbolicQr::deserialize;
43  return 0;
44  }
45 
46  extern "C"
47  void CASADI_LINSOL_SYMBOLICQR_EXPORT casadi_load_linsol_symbolicqr() {
49  }
50 
51  SymbolicQr::SymbolicQr(const std::string& name, const Sparsity& sp) :
52  LinsolInternal(name, sp) {
53  }
54 
56  clear_mem();
57  }
58 
// NOTE(review): the lines that introduce this initializer (listing lines 59-60) are not
// visible here; presumably they define SymbolicQr::options_ and name the parent option
// set -- confirm against the original file.
// The visible part declares a single option, "fopts", of type OT_DICT.
61  {{"fopts",
62  {OT_DICT,
63  "Options to be passed to generated function objects"}}
64  }
65  };
66 
67  void SymbolicQr::init(const Dict& opts) {
68  // Call the base class initializer
70 
71  // Read options
72  for (auto&& op : opts) {
73  if (op.first=="fopts") {
74  fopts_ = op.second;
75  }
76  }
77 
78  // Symbolic expression for A
79  SX A = SX::sym("A", sp_);
80 
81  // BTF factorization
82  std::vector<casadi_int> rowperm, colperm, rowblock, colblock, coarse_rowblock, coarse_colblock;
83  sp_.btf(rowperm, colperm, rowblock, colblock, coarse_rowblock, coarse_colblock);
84 
85  // Get the inverted column permutation
86  std::vector<casadi_int> inv_colperm(colperm.size());
87  for (casadi_int k=0; k<colperm.size(); ++k)
88  inv_colperm[colperm[k]] = k;
89 
90  // Get the inverted row permutation
91  std::vector<casadi_int> inv_rowperm(rowperm.size());
92  for (casadi_int k=0; k<rowperm.size(); ++k)
93  inv_rowperm[rowperm[k]] = k;
94 
95  // Permute the linear system
96  SX Aperm = A(rowperm, colperm); // NOLINT(cppcoreguidelines-slicing)
97 
98  // Generate the QR factorization function
99  SX Q1, R1;
100  qr(Aperm, Q1, R1);
101  factorize_ = Function("QR_fact", {A}, {Q1, R1}, fopts_);
102 
103  // Symbolic expressions for solve function
104  SX Q = SX::sym("Q", Q1.sparsity());
105  SX R = SX::sym("R", R1.sparsity());
106  SX b = SX::sym("b", sp_.size2(), 1);
107 
108  // Solve non-transposed
109  // We have Pb' * Q * R * Px * x = b <=> x = Px' * inv(R) * Q' * Pb * b
110 
111  // Permute the right hand sides
112  SX bperm = b(rowperm, Slice()); // NOLINT(cppcoreguidelines-slicing)
113 
114  // Solve the factorized system
115  SX xperm = SX::solve(R, mtimes(Q.T(), bperm));
116 
117  // Permute back the solution
118  SX x = xperm(inv_colperm, Slice()); // NOLINT(cppcoreguidelines-slicing)
119 
120  // Generate the QR solve function
121  std::vector<SX> solv_in = {Q, R, b};
122  solve_ = Function("QR_solv", solv_in, {x}, fopts_);
123 
124  // Solve transposed
125  // We have (Pb' * Q * R * Px)' * x = b
126  // <=> Px' * R' * Q' * Pb * x = b
127  // <=> x = Pb' * Q * inv(R') * Px * b
128 
129  // Permute the right hand side
130  bperm = b(colperm, Slice()); // NOLINT(cppcoreguidelines-slicing)
131 
132  // Solve the factorized system
133  xperm = mtimes(Q, SX::solve(R.T(), bperm));
134 
135  // Permute back the solution
136  x = xperm(inv_rowperm, Slice()); // NOLINT(cppcoreguidelines-slicing)
137 
138  // Mofify the QR solve function
139  solveT_ = Function("QR_solv_T", solv_in, {x}, fopts_);
140  }
141 
142  int SymbolicQr::init_mem(void* mem) const {
143  if (LinsolInternal::init_mem(mem)) return 1;
144  auto m = static_cast<SymbolicQrMemory*>(mem);
145 
146  m->alloc(solveT_);
147  m->alloc(solve_);
148  m->alloc(factorize_);
149 
150  // Temporary storage
151  m->w.resize(m->w.size() + sp_.size1());
152 
153  // Allocate storage for QR factorization
154  m->q.resize(factorize_.nnz_out(0));
155  m->r.resize(factorize_.nnz_out(1));
156  return 0;
157  }
158 
159  int SymbolicQr::nfact(void* mem, const double* A) const {
160  auto m = static_cast<SymbolicQrMemory*>(mem);
161 
162  // Factorize
163  std::fill_n(get_ptr(m->arg), factorize_.n_in(), nullptr);
164  m->arg[0] = A;
165  std::fill_n(get_ptr(m->res), factorize_.n_out(), nullptr);
166  m->res[0] = get_ptr(m->q);
167  m->res[1] = get_ptr(m->r);
168  if (factorize_(get_ptr(m->arg), get_ptr(m->res), get_ptr(m->iw), get_ptr(m->w))) return 1;
169  return 0;
170  }
171 
172  int SymbolicQr::solve(void* mem, const double* A, double* x, casadi_int nrhs, bool tr) const {
173  auto m = static_cast<SymbolicQrMemory*>(mem);
174 
175  // Select solve function
176  const Function& solv = tr ? solveT_ : solve_;
177 
178  // Solve for all right hand sides
179  std::fill_n(get_ptr(m->arg), solv.n_in(), nullptr);
180  m->arg[0] = get_ptr(m->q);
181  m->arg[1] = get_ptr(m->r);
182  std::fill_n(get_ptr(m->res), solv.n_out(), nullptr);
183  for (casadi_int i=0; i<nrhs; ++i) {
184  std::copy_n(x, nrow(), get_ptr(m->w)); // Copy x to a temporary
185  m->arg[2] = get_ptr(m->w);
186  m->res[0] = x;
187  if (solv(get_ptr(m->arg), get_ptr(m->res),
188  get_ptr(m->iw), get_ptr(m->w)+nrow(), 0)) return 1;
189  x += nrow();
190  }
191  return 0;
192  }
193 
194  void SymbolicQr::linsol_eval_sx(const SXElem** arg, SXElem** res,
195  casadi_int* iw, SXElem* w, void* mem,
196  bool tr, casadi_int nrhs) const {
197  //auto m = static_cast<SymbolicQrMemory*>(mem);
198  casadi_assert_dev(arg[0]!=nullptr);
199  casadi_assert_dev(arg[1]!=nullptr);
200  casadi_assert_dev(res[0]!=nullptr);
201 
202  // Get A and factorize it
203  SX A = SX::zeros(sp_);
204  std::copy(arg[1], arg[1]+A.nnz(), A->begin());
205  std::vector<SX> v = factorize_(A);
206 
207  // Select solve function
208  const Function& solv = tr ? solveT_ : solve_;
209 
210  // Solve for every right hand side
211  v.push_back(SX::zeros(A.size1()));
212  const SXElem* a=arg[0];
213  SXElem* r=res[0];
214  for (casadi_int i=0; i<nrhs; ++i) {
215  std::copy(a, a+v[2].nnz(), v[2]->begin());
216  SX rr = solv(v).at(0);
217  std::copy(rr->begin(), rr->end(), r);
218  r += rr.nnz();
219  a += v[2].nnz();
220  }
221  }
222 
224  arg.resize(std::max(arg.size(), f.sz_arg()));
225  res.resize(std::max(res.size(), f.sz_res()));
226  iw.resize(std::max(iw.size(), f.sz_iw()));
227  w.resize(std::max(w.size(), f.sz_w()));
228  }
229 
231  s.version("SymbolicQr", 1);
232  s.unpack("SymbolicQr::factorize", factorize_);
233  s.unpack("SymbolicQr::solve", solve_);
234  s.unpack("SymbolicQr::solveT", solveT_);
235  s.unpack("SymbolicQr::fopts", fopts_);
236  }
237 
240  s.version("SymbolicQr", 1);
241  s.pack("SymbolicQr::factorize", factorize_);
242  s.pack("SymbolicQr::solve", solve_);
243  s.pack("SymbolicQr::solveT", solveT_);
244  s.pack("SymbolicQr::fopts", fopts_);
245  }
246 
247 } // namespace casadi
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
static const Options options_
Options.
Function object.
Definition: function.hpp:60
casadi_int nnz_out() const
Get number of output nonzeros.
Definition: function.cpp:855
size_t sz_res() const
Get required length of res field.
Definition: function.cpp:1085
size_t sz_iw() const
Get required length of iw field.
Definition: function.cpp:1087
casadi_int n_out() const
Get the number of function outputs.
Definition: function.cpp:823
casadi_int n_in() const
Get the number of function inputs.
Definition: function.cpp:819
size_t sz_w() const
Get required length of w field.
Definition: function.cpp:1089
size_t sz_arg() const
Get required length of arg field.
Definition: function.cpp:1083
casadi_int nnz() const
Get the number of (structural) non-zero elements.
casadi_int size1() const
Get the first dimension (i.e. number of rows)
static Matrix< Scalar > sym(const std::string &name, casadi_int nrow=1, casadi_int ncol=1)
Create an nrow-by-ncol symbolic primitive.
static Matrix< Scalar > zeros(casadi_int nrow=1, casadi_int ncol=1)
Create a dense matrix or a matrix with specified sparsity with all entries zero.
void init(const Dict &opts) override
Initialize.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
casadi_int nnz() const
casadi_int nrow() const
Get sparsity pattern.
int init_mem(void *mem) const override
Initialize memory block.
Sparse matrix class. SX and DM are specializations.
Definition: matrix_decl.hpp:99
Matrix< Scalar > T() const
Transpose the matrix.
const Sparsity & sparsity() const
Const access the sparsity - reference to data member.
static Matrix< Scalar > solve(const Matrix< Scalar > &A, const Matrix< Scalar > &b)
static void registerPlugin(const Plugin &plugin, bool needs_lock=true)
Register an integrator in the factory.
void clear_mem()
Clear all memory (called from destructor)
The basic scalar symbolic class of CasADi.
Definition: sx_elem.hpp:75
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
Class representing a Slice.
Definition: slice.hpp:48
General sparsity class.
Definition: sparsity.hpp:106
casadi_int size1() const
Get the number of rows.
Definition: sparsity.cpp:124
casadi_int size2() const
Get the number of columns.
Definition: sparsity.cpp:128
casadi_int btf(std::vector< casadi_int > &rowperm, std::vector< casadi_int > &colperm, std::vector< casadi_int > &rowblock, std::vector< casadi_int > &colblock, std::vector< casadi_int > &coarse_rowblock, std::vector< casadi_int > &coarse_colblock) const
Calculate the block triangular form (BTF)
Definition: sparsity.cpp:711
int init_mem(void *mem) const override
Initialize memory block.
static const Options options_
Options.
Definition: symbolic_qr.hpp:91
SymbolicQr(const std::string &name, const Sparsity &sp)
Definition: symbolic_qr.cpp:51
void init(const Dict &opts) override
Initialize.
Definition: symbolic_qr.cpp:67
static LinsolInternal * creator(const std::string &name, const Sparsity &sp)
Create a new Linsol.
Definition: symbolic_qr.hpp:85
static const std::string meta_doc
A documentation string.
~SymbolicQr() override
Definition: symbolic_qr.cpp:55
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int solve(void *mem, const double *A, double *x, casadi_int nrhs, bool tr) const override
void linsol_eval_sx(const SXElem **arg, SXElem **res, casadi_int *iw, SXElem *w, void *mem, bool tr, casadi_int nrhs) const override
Evaluate symbolically (SX)
int nfact(void *mem, const double *A) const override
Numeric factorization.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
The casadi namespace.
Definition: archiver.cpp:28
int CASADI_LINSOL_SYMBOLICQR_EXPORT casadi_register_linsol_symbolicqr(LinsolInternal::Plugin *plugin)
Definition: symbolic_qr.cpp:36
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.
void CASADI_LINSOL_SYMBOLICQR_EXPORT casadi_load_linsol_symbolicqr()
Definition: symbolic_qr.cpp:47
Options metadata for a class.
Definition: options.hpp:40
Memory for SymbolicQR
Definition: symbolic_qr.hpp:49
void alloc(const Function &f)
std::vector< casadi_int > iw
Definition: symbolic_qr.hpp:53
std::vector< double * > res
Definition: symbolic_qr.hpp:52
std::vector< const double * > arg
Definition: symbolic_qr.hpp:51
std::vector< double > w
Definition: symbolic_qr.hpp:54