linsol_qr.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
#include "linsol_qr.hpp"

#include <algorithm>
#include <limits>

#include "casadi/core/global_options.hpp"
28 
29 namespace casadi {
30 
31  extern "C"
32  int CASADI_LINSOL_QR_EXPORT
33  casadi_register_linsol_qr(LinsolInternal::Plugin* plugin) {
34  plugin->creator = LinsolQr::creator;
35  plugin->name = "qr";
36  plugin->doc = LinsolQr::meta_doc.c_str();
37  plugin->version = CASADI_VERSION;
38  plugin->options = &LinsolQr::options_;
39  plugin->deserialize = &LinsolQr::deserialize;
40  return 0;
41  }
42 
43  extern "C"
44  void CASADI_LINSOL_QR_EXPORT casadi_load_linsol_qr() {
46  }
47 
48  LinsolQr::LinsolQr(const std::string& name, const Sparsity& sp)
49  : LinsolInternal(name, sp) {
50  }
51 
53  clear_mem();
54  }
55 
58  {{"eps",
59  {OT_DOUBLE,
60  "Minimum R entry before singularity is declared [1e-12]"}},
61  {"cache",
62  {OT_DOUBLE,
63  "Amount of factorisations to remember (thread-local) [0]"}}
64  }
65  };
66 
67  void LinsolQr::init(const Dict& opts) {
68  // Call the init method of the base class
70 
71  // Read options
72  eps_ = 1e-12;
73  n_cache_ = 0;
74  for (auto&& op : opts) {
75  if (op.first=="eps") {
76  eps_ = op.second;
77  } else if (op.first=="cache") {
78  n_cache_ = op.second;
79  }
80  }
81 
82  // Symbolic factorization
84  }
85 
89  }
90 
91  int LinsolQr::init_mem(void* mem) const {
92  if (LinsolInternal::init_mem(mem)) return 1;
93  auto m = static_cast<LinsolQrMemory*>(mem);
94 
95  // Memory for numerical solution
96  m->v.resize(sp_v_.nnz());
97  m->r.resize(sp_r_.nnz());
98  m->beta.resize(ncol());
99  m->w.resize(nrow() + ncol());
100 
101  m->cache.resize(cache_stride_*n_cache_);
102  m->cache_loc.resize(n_cache_, -1);
103 
104  return 0;
105  }
106 
107  int LinsolQr::sfact(void* mem, const double* A) const {
108  return 0;
109  }
110 
111  int LinsolQr::nfact(void* mem, const double* A) const {
112  auto m = static_cast<LinsolQrMemory*>(mem);
113 
114  // Check for a cache hit
115  double* cache = nullptr;
116  bool cache_hit = cache_check(A, get_ptr(m->cache), get_ptr(m->cache_loc),
117  cache_stride_, n_cache_, sp_.nnz(), &cache);
118 
119  if (cache && cache_hit) {
120  cache += sp_.nnz();
121  // Retrieve from cache and return early
122  casadi_copy(cache, sp_v_.nnz(), get_ptr(m->v)); cache+=sp_v_.nnz();
123  casadi_copy(cache, sp_r_.nnz(), get_ptr(m->r)); cache+=sp_r_.nnz();
124  casadi_copy(cache, ncol(), get_ptr(m->beta)); cache+=ncol();
125  return 0;
126  }
127 
128  // Cache miss -> compute result
129  casadi_qr(sp_, A, get_ptr(m->w),
130  sp_v_, get_ptr(m->v), sp_r_, get_ptr(m->r),
131  get_ptr(m->beta), get_ptr(prinv_), get_ptr(pc_));
132  // Check singularity
133  double rmin;
134  casadi_int irmin, nullity;
135  nullity = casadi_qr_singular(&rmin, &irmin, get_ptr(m->r), sp_r_, get_ptr(pc_), eps_);
136  if (nullity) {
137  if (verbose_) {
138  print("Singularity detected: Rank %lld<%lld\n", ncol()-nullity, ncol());
139  print("First singular R entry: %g<%g, corresponding to row %lld\n", rmin, eps_, irmin);
140  casadi_qr_colcomb(get_ptr(m->w), get_ptr(m->r), sp_r_, get_ptr(pc_), eps_, 0);
141  print("Linear combination of columns:\n[");
142  for (casadi_int k=0; k<ncol(); ++k) print(k==0 ? "%g" : ", %g", m->w[k]);
143  print("]\n");
144  }
145  return 1;
146  }
147 
148  if (cache) { // Store result in cache
149  casadi_copy(A, sp_.nnz(), cache); cache+=sp_.nnz();
150  casadi_copy(get_ptr(m->v), sp_v_.nnz(), cache); cache+=sp_v_.nnz();
151  casadi_copy(get_ptr(m->r), sp_r_.nnz(), cache); cache+=sp_r_.nnz();
152  casadi_copy(get_ptr(m->beta), ncol(), cache); cache+=ncol();
153  }
154  return 0;
155  }
156 
157  int LinsolQr::solve(void* mem, const double* A, double* x, casadi_int nrhs, bool tr) const {
158  auto m = static_cast<LinsolQrMemory*>(mem);
159  casadi_qr_solve(x, nrhs, tr,
160  sp_v_, get_ptr(m->v), sp_r_, get_ptr(m->r),
161  get_ptr(m->beta), get_ptr(prinv_), get_ptr(pc_), get_ptr(m->w));
162  return 0;
163  }
164 
165  void LinsolQr::generate(CodeGenerator& g, const std::string& A, const std::string& x,
166  casadi_int nrhs, bool tr) const {
167  // Codegen the integer vectors
168  std::string prinv = g.constant(prinv_);
169  std::string pc = g.constant(pc_);
170  std::string sp = g.sparsity(sp_);
171  std::string sp_v = g.sparsity(sp_v_);
172  std::string sp_r = g.sparsity(sp_r_);
173 
174  // Place in block to avoid conflicts caused by local variables
175  g << "{\n";
176  g.comment("FIXME(@jaeandersson): Memory allocation can be avoided");
177  g << "casadi_real v[" << sp_v_.nnz() << "], "
178  "r[" << sp_r_.nnz() << "], "
179  "beta[" << ncol() << "], "
180  "w[" << nrow() + ncol() << "];\n";
181 
182  if (n_cache_) {
183  g << "casadi_real *c;\n";
184  g << "casadi_real cache[" << cache_stride_*n_cache_ << "];\n";
185  g << "int cache_loc[" << n_cache_ << "] = {";
186  for (casadi_int i=0;i<n_cache_;++i) {
187  g << "-1,";
188  }
189  g << "};\n";
190  g << "if (" << g.cache_check(A, "cache", "cache_loc",
191  cache_stride_, n_cache_, sp_.nnz(), "&c") << ") {\n";
192  casadi_int offset = sp_.nnz();
193  g.comment("Retrieve from cache");
194  g << g.copy("c+" + str(offset), sp_v_.nnz(), "v") << "\n"; offset+=sp_v_.nnz();
195  g << g.copy("c+" + str(offset), sp_r_.nnz(), "r") << "\n"; offset+=sp_r_.nnz();
196  g << g.copy("c+" + str(offset), ncol(), "beta") << "\n"; offset+=ncol();
197  g << "} else {\n";
198  }
199 
200  // Factorize
201  g << g.qr(sp, A, "w", sp_v, "v", sp_r, "r", "beta", prinv, pc) << "\n";
202 
203  if (n_cache_) {
204  casadi_int offset = 0;
205  g.comment("Store in cache");
206  g << g.copy(A, sp_.nnz(), "c") << "\n";; offset+=sp_.nnz();
207  g << g.copy("v", sp_v_.nnz(), "c+"+str(offset)) << "\n"; offset+=sp_v_.nnz();
208  g << g.copy("r", sp_r_.nnz(), "c+"+str(offset)) << "\n"; offset+=sp_r_.nnz();
209  g << g.copy("beta", ncol(), "c+"+str(offset)) << "\n"; offset+=ncol();
210  g << "}\n";
211  }
212 
213  // Solve
214  g << g.qr_solve(x, nrhs, tr, sp_v, "v", sp_r, "r", "beta", prinv, pc, "w") << "\n";
215 
216  // End of block
217  g << "}\n";
218  }
219 
221  int version = s.version("LinsolQr", 1, 2);
222  s.unpack("LinsolQr::prinv", prinv_);
223  s.unpack("LinsolQr::pc", pc_);
224  s.unpack("LinsolQr::sp_v", sp_v_);
225  s.unpack("LinsolQr::sp_r", sp_r_);
226  s.unpack("LinsolQr::eps", eps_);
227  if (version>1) {
228  s.unpack("LinsolQr::n_cache", n_cache_);
229  } else {
230  n_cache_ = 1;
231  }
232  }
233 
236  s.version("LinsolQr", 2);
237  s.pack("LinsolQr::prinv", prinv_);
238  s.pack("LinsolQr::pc", pc_);
239  s.pack("LinsolQr::sp_v", sp_v_);
240  s.pack("LinsolQr::sp_r", sp_r_);
241  s.pack("LinsolQr::eps", eps_);
242  s.pack("LinsolQr::n_cache", n_cache_);
243  }
244 
245 } // namespace casadi
Helper class for C code generation.
std::string copy(const std::string &arg, std::size_t n, const std::string &res)
Create a copy operation.
void comment(const std::string &s)
Write a comment line (ignored if not verbose)
std::string constant(const std::vector< casadi_int > &v)
Represent an array constant; adding it when new.
std::string cache_check(const std::string &key, const std::string &cache, const std::string &loc, casadi_int stride, casadi_int sz, casadi_int key_sz, const std::string &val)
cache check
std::string qr_solve(const std::string &x, casadi_int nrhs, bool tr, const std::string &sp_v, const std::string &v, const std::string &sp_r, const std::string &r, const std::string &beta, const std::string &prinv, const std::string &pc, const std::string &w)
QR solve.
std::string sparsity(const Sparsity &sp, bool canonical=true)
std::string qr(const std::string &sp, const std::string &A, const std::string &w, const std::string &sp_v, const std::string &v, const std::string &sp_r, const std::string &r, const std::string &beta, const std::string &prinv, const std::string &pc)
QR factorization.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
void init(const Dict &opts) override
Initialize.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
casadi_int nrow() const
Get sparsity pattern.
casadi_int ncol() const
int init_mem(void *mem) const override
Initialize memory block.
static const std::string meta_doc
A documentation string.
Definition: linsol_qr.hpp:110
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
Definition: linsol_qr.hpp:125
int sfact(void *mem, const double *A) const override
Definition: linsol_qr.cpp:107
~LinsolQr() override
Definition: linsol_qr.cpp:52
int solve(void *mem, const double *A, double *x, casadi_int nrhs, bool tr) const override
Definition: linsol_qr.cpp:157
int nfact(void *mem, const double *A) const override
Numeric factorization.
Definition: linsol_qr.cpp:111
std::vector< casadi_int > pc_
Definition: linsol_qr.hpp:113
std::vector< casadi_int > prinv_
Symbolic factorization.
Definition: linsol_qr.hpp:113
casadi_int cache_stride_
Definition: linsol_qr.hpp:119
static LinsolInternal * creator(const std::string &name, const Sparsity &sp)
Create a new LinsolInternal.
Definition: linsol_qr.hpp:62
void generate(CodeGenerator &g, const std::string &A, const std::string &x, casadi_int nrhs, bool tr) const override
Generate C code.
Definition: linsol_qr.cpp:165
int init_mem(void *mem) const override
Initialize memory block.
Definition: linsol_qr.cpp:91
casadi_int n_cache_
Cache size.
Definition: linsol_qr.hpp:118
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
Definition: linsol_qr.cpp:234
static const Options options_
Options.
Definition: linsol_qr.hpp:77
void finalize() override
Finalize the object creation.
Definition: linsol_qr.cpp:86
LinsolQr(const std::string &name, const Sparsity &sp)
Definition: linsol_qr.cpp:48
void init(const Dict &opts) override
Initialize.
Definition: linsol_qr.cpp:67
static void registerPlugin(const Plugin &plugin, bool needs_lock=true)
Register a plugin in the factory.
void print(const char *fmt,...) const
C-style formatted printing during evaluation.
bool verbose_
Verbose printout.
virtual void finalize()
Finalize the object creation.
static const Options options_
Options.
void clear_mem()
Clear all memory (called from destructor)
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
General sparsity class.
Definition: sparsity.hpp:106
casadi_int nnz() const
Get the number of (structural) non-zeros.
Definition: sparsity.cpp:148
void qr_sparse(Sparsity &V, Sparsity &R, std::vector< casadi_int > &prinv, std::vector< casadi_int > &pc, bool amd=true) const
Symbolic QR factorization.
Definition: sparsity.cpp:654
The casadi namespace.
Definition: archiver.cpp:28
void casadi_copy(const T1 *x, casadi_int n, T1 *y)
COPY: y <-x.
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
void CASADI_LINSOL_QR_EXPORT casadi_load_linsol_qr()
Definition: linsol_qr.cpp:44
T * get_ptr(std::vector< T > &v)
Get a pointer to the data contained in the vector.
int CASADI_LINSOL_QR_EXPORT casadi_register_linsol_qr(LinsolInternal::Plugin *plugin)
Definition: linsol_qr.cpp:33
std::vector< double > v
Definition: linsol_qr.hpp:44
Options metadata for a class.
Definition: options.hpp:40