solve.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SOLVE_HPP
27 #define CASADI_SOLVE_HPP
28 
29 #include "mx_node.hpp"
30 #include "casadi_call.hpp"
31 
32 namespace casadi {
46  template<bool Tr>
47  class CASADI_EXPORT Solve : public MXNode {
48  public:
52  Solve(const MX& r, const MX& A);
53 
57  ~Solve() override {}
58 
62  std::string disp(const std::vector<std::string>& arg) const override;
63 
67  virtual std::string mod_prefix() const {return "";}
68 
72  virtual std::string mod_suffix() const {return "";}
73 
77  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
78 
82  void ad_forward(const std::vector<std::vector<MX> >& fseed,
83  std::vector<std::vector<MX> >& fsens) const override;
84 
88  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
89  std::vector<std::vector<MX> >& asens) const override;
90 
92  casadi_int n_inplace() const override { return 1;}
93 
97  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
98 
102  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;
103 
107  casadi_int op() const override { return OP_SOLVE;}
108 
110  Dict info() const override {
111  return {{"tr", Tr}};
112  }
113 
115  virtual MX solve(const MX& A, const MX& B, bool tr) const = 0;
116 
118  virtual const Sparsity& A_sp() const { return dep(1).sparsity();}
119 
123  void serialize_body(SerializingStream& s) const override;
124 
128  void serialize_type(SerializingStream& s) const override;
129 
133  static MXNode* deserialize(DeserializingStream& s);
134 
138  explicit Solve(DeserializingStream& s);
139  };
140 
147  template<bool Tr>
148  class CASADI_EXPORT LinsolCall : public Solve<Tr> {
149  public:
150 
154  LinsolCall(const MX& r, const MX& A, const Linsol& linear_solver);
155 
159  ~LinsolCall() override {}
160 
162  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
163 
165  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
166 
170  size_t sz_w() const override;
171 
175  void generate(CodeGenerator& g,
176  const std::vector<casadi_int>& arg,
177  const std::vector<casadi_int>& res,
178  const std::vector<bool>& arg_is_ref,
179  std::vector<bool>& res_is_ref) const override;
180 
183 
185  MX solve(const MX& A, const MX& B, bool tr) const override {
186  return linsol_.solve(A, B, tr);
187  }
188 
192  void serialize_body(SerializingStream& s) const override;
193 
197  void serialize_type(SerializingStream& s) const override;
198 
202  static MXNode* deserialize(DeserializingStream& s);
203 
207  explicit LinsolCall(DeserializingStream& s);
208  };
209 
216  template<bool Tr>
217  class CASADI_EXPORT TriuSolve : public Solve<Tr> {
218  public:
219 
223  TriuSolve(const MX& r, const MX& A);
224 
228  ~TriuSolve() override {}
229 
231  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
232 
234  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
235 
237  MX solve(const MX& A, const MX& B, bool tr) const override {
238  return A->get_solve_triu(B, tr);
239  }
240 
244  void generate(CodeGenerator& g,
245  const std::vector<casadi_int>& arg,
246  const std::vector<casadi_int>& res,
247  const std::vector<bool>& arg_is_ref,
248  std::vector<bool>& res_is_ref) const override;
249  };
250 
257  template<bool Tr>
258  class CASADI_EXPORT TrilSolve : public Solve<Tr> {
259  public:
260 
264  TrilSolve(const MX& r, const MX& A);
265 
269  ~TrilSolve() override {}
270 
272  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
273 
275  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
276 
278  MX solve(const MX& A, const MX& B, bool tr) const override {
279  return A->get_solve_tril(B, tr);
280  }
281 
285  void generate(CodeGenerator& g,
286  const std::vector<casadi_int>& arg,
287  const std::vector<casadi_int>& res,
288  const std::vector<bool>& arg_is_ref,
289  std::vector<bool>& res_is_ref) const override;
290  };
291 
298  template<bool Tr>
299  class CASADI_EXPORT SolveUnity : public Solve<Tr> {
300  public:
301 
305  SolveUnity(const MX& r, const MX& A);
306 
310  ~SolveUnity() override {}
311 
315  std::string mod_prefix() const override {return "(I - ";}
316 
320  std::string mod_suffix() const override {return ")";}
321 
323  const Sparsity& A_sp() const override;
324 
325  // Sparsity pattern of linear system, cached
326  mutable Sparsity A_sp_;
327 
328 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
330  mutable std::mutex A_sp_mtx_;
331 #endif // CASADI_WITH_THREADSAFE_SYMBOLICS
332  };
333 
340  template<bool Tr>
341  class CASADI_EXPORT TriuSolveUnity : public SolveUnity<Tr> {
342  public:
343 
347  TriuSolveUnity(const MX& r, const MX& A);
348 
352  ~TriuSolveUnity() override {}
353 
355  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
356 
358  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
359 
361  MX solve(const MX& A, const MX& B, bool tr) const override {
362  return A->get_solve_triu_unity(B, tr);
363  }
364 
368  void generate(CodeGenerator& g,
369  const std::vector<casadi_int>& arg,
370  const std::vector<casadi_int>& res,
371  const std::vector<bool>& arg_is_ref,
372  std::vector<bool>& res_is_ref) const override;
373  };
374 
381  template<bool Tr>
382  class CASADI_EXPORT TrilSolveUnity : public SolveUnity<Tr> {
383  public:
384 
388  TrilSolveUnity(const MX& r, const MX& A);
389 
393  ~TrilSolveUnity() override {}
394 
396  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
397 
399  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;
400 
402  MX solve(const MX& A, const MX& B, bool tr) const override {
403  return A->get_solve_tril_unity(B, tr);
404  }
405 
409  void generate(CodeGenerator& g,
410  const std::vector<casadi_int>& arg,
411  const std::vector<casadi_int>& res,
412  const std::vector<bool>& arg_is_ref,
413  std::vector<bool>& res_is_ref) const override;
414  };
415 
416 } // namespace casadi
417 
418 #endif // CASADI_SOLVE_HPP
Helper class for C code generation.
Helper class for Serialization.
Linear solve operation with a linear solver instance.
Definition: solve.hpp:148
~LinsolCall() override
Destructor.
Definition: solve.hpp:159
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:185
Linsol linsol_
Linear solver (may be shared between multiple nodes)
Definition: solve.hpp:182
Linear solver.
Definition: linsol.hpp:55
DM solve(const DM &A, const DM &B, bool tr=false) const
Node class for MX objects.
Definition: mx_node.hpp:51
MX - Matrix expression.
Definition: mx.hpp:92
Helper class for Serialization.
Linear solve with unity diagonal added.
Definition: solve.hpp:299
~SolveUnity() override
Destructor.
Definition: solve.hpp:310
Sparsity A_sp_
Definition: solve.hpp:326
std::string mod_suffix() const override
Modifier for linear system, after argument.
Definition: solve.hpp:320
std::string mod_prefix() const override
Modifier for linear system, before argument.
Definition: solve.hpp:315
An MX atomic for linear solver solution: x = r * A^-1 or x = r * A^-T.
Definition: solve.hpp:47
Dict info() const override
Definition: solve.hpp:110
virtual const Sparsity & A_sp() const
Sparsity pattern for the linear system.
Definition: solve.hpp:118
virtual std::string mod_prefix() const
Modifier for linear system, before argument.
Definition: solve.hpp:67
casadi_int op() const override
Get the operation.
Definition: solve.hpp:107
virtual std::string mod_suffix() const
Modifier for linear system, after argument.
Definition: solve.hpp:72
virtual MX solve(const MX &A, const MX &B, bool tr) const =0
Solve another system with the same factorization.
~Solve() override
Destructor.
Definition: solve.hpp:57
casadi_int n_inplace() const override
Can the operation be performed in place (i.e. overwrite the result)?
Definition: solve.hpp:92
General sparsity class.
Definition: sparsity.hpp:106
Linear solve with a lower triangular matrix, unity diagonal.
Definition: solve.hpp:382
~TrilSolveUnity() override
Destructor.
Definition: solve.hpp:393
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:402
Linear solve with a lower triangular matrix.
Definition: solve.hpp:258
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:278
~TrilSolve() override
Destructor.
Definition: solve.hpp:269
Linear solve with an upper triangular matrix, unity diagonal.
Definition: solve.hpp:341
~TriuSolveUnity() override
Destructor.
Definition: solve.hpp:352
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:361
Linear solve with an upper triangular matrix.
Definition: solve.hpp:217
~TriuSolve() override
Destructor.
Definition: solve.hpp:228
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:237
The casadi namespace.
Definition: archiver.hpp:32
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.