solve.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_SOLVE_HPP
27 #define CASADI_SOLVE_HPP
28 
29 #include "mx_node.hpp"
30 #include "casadi_call.hpp"
31 
32 namespace casadi {
46  template<bool Tr>  // Tr is exposed as the "tr" entry of info(); presumably true selects x = r * A^-T
47  class CASADI_EXPORT Solve : public MXNode {  // MX atomic for a linear solve: x = r * A^-1 or x = r * A^-T
48  public:
52  Solve(const MX& r, const MX& A);  // r: right-hand side, A: system matrix (A_sp() reads A back as dep(1))
53 
57  ~Solve() override {}
58 
62  std::string disp(const std::vector<std::string>& arg) const override;  // textual form; presumably wraps A with mod_prefix()/mod_suffix()
63 
67  virtual std::string mod_prefix() const {return "";}  // modifier printed before the system argument; empty unless overridden
68 
72  virtual std::string mod_suffix() const {return "";}  // modifier printed after the system argument; empty unless overridden
73 
77  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;  // symbolic evaluation on MX arguments
78 
82  void ad_forward(const std::vector<std::vector<MX> >& fseed,
83  std::vector<std::vector<MX> >& fsens) const override;  // forward-mode AD: seeds fseed -> sensitivities fsens
84 
88  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
89  std::vector<std::vector<MX> >& asens) const override;  // reverse-mode AD: adjoint seeds aseed -> asens
90 
92  casadi_int n_inplace() const override { return 1;}  // operation can be performed in place (returns 1)
93 
97  int sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;  // forward sparsity propagation on bvec_t patterns
98 
102  int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const override;  // reverse sparsity propagation on bvec_t patterns
103 
107  casadi_int op() const override { return OP_SOLVE;}  // operation code: OP_SOLVE
108 
110  Dict info() const override {  // node info: reports the template flag Tr under key "tr"
111  return {{"tr", Tr}};
112  }
113 
115  virtual MX solve(const MX& A, const MX& B, bool tr) const = 0;  // solve another system; each subclass supplies the kernel
116 
118  virtual const Sparsity& A_sp() const { return dep(1).sparsity();}  // sparsity of the linear system; defaults to dependency 1 (A)
119 
123  void serialize_body(SerializingStream& s) const override;  // write node contents to s
124 
128  void serialize_type(SerializingStream& s) const override;  // write type information to s
129 
133  static MXNode* deserialize(DeserializingStream& s);  // factory: rebuild a node from s
134 
138  explicit Solve(DeserializingStream& s);  // deserializing constructor
139  };
140 
147  template<bool Tr>
148  class CASADI_EXPORT LinsolCall : public Solve<Tr> {  // Solve node backed by a Linsol linear-solver instance
149  public:
150 
154  LinsolCall(const MX& r, const MX& A, const Linsol& linear_solver);  // r: right-hand side, A: matrix; linear_solver presumably stored in linsol_
155 
159  ~LinsolCall() override {}
160 
162  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;  // numeric evaluation
163 
165  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;  // evaluation on SX elements
166 
170  size_t sz_w() const override;  // size of the real work vector w required
171 
175  void generate(CodeGenerator& g,
176  const std::vector<casadi_int>& arg,
177  const std::vector<casadi_int>& res) const override;  // emit C code through g
178 
180  Linsol linsol_;  // Linear solver (may be shared between multiple nodes); required by solve() below
181 
183  MX solve(const MX& A, const MX& B, bool tr) const override {  // solve another system with the same factorization
184  return linsol_.solve(A, B, tr);  // delegated to the stored linear solver
185  }
186 
190  void serialize_body(SerializingStream& s) const override;  // write node contents to s
191 
195  void serialize_type(SerializingStream& s) const override;  // write type information to s
196 
200  static MXNode* deserialize(DeserializingStream& s);  // factory: rebuild a node from s
201 
205  explicit LinsolCall(DeserializingStream& s);  // deserializing constructor
206  };
207 
214  template<bool Tr>
215  class CASADI_EXPORT TriuSolve : public Solve<Tr> {  // linear solve with an upper triangular matrix
216  public:
217 
221  TriuSolve(const MX& r, const MX& A);  // r: right-hand side, A: upper triangular system matrix
222 
226  ~TriuSolve() override {}
227 
229  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;  // numeric evaluation
230 
232  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;  // evaluation on SX elements
233 
235  MX solve(const MX& A, const MX& B, bool tr) const override {  // build another upper triangular solve node
236  return A->get_solve_triu(B, tr);
237  }
238 
242  void generate(CodeGenerator& g, const std::vector<casadi_int>& arg,
243  const std::vector<casadi_int>& res) const override;  // emit C code through g
244  };
245 
252  template<bool Tr>
253  class CASADI_EXPORT TrilSolve : public Solve<Tr> {  // linear solve with a lower triangular matrix
254  public:
255 
259  TrilSolve(const MX& r, const MX& A);  // r: right-hand side, A: lower triangular system matrix
260 
264  ~TrilSolve() override {}
265 
267  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;  // numeric evaluation
268 
270  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;  // evaluation on SX elements
271 
273  MX solve(const MX& A, const MX& B, bool tr) const override {  // build another lower triangular solve node
274  return A->get_solve_tril(B, tr);
275  }
276 
280  void generate(CodeGenerator& g, const std::vector<casadi_int>& arg,
281  const std::vector<casadi_int>& res) const override;  // emit C code through g
282  };
283 
290  template<bool Tr>
291  class CASADI_EXPORT SolveUnity : public Solve<Tr> {  // linear solve with unity diagonal added, displayed as (I - A)
292  public:
293 
297  SolveUnity(const MX& r, const MX& A);  // r: right-hand side, A: matrix whose diagonal is taken as unity
298 
302  ~SolveUnity() override {}
303 
307  std::string mod_prefix() const override {return "(I - ";}  // printed before the system argument
308 
312  std::string mod_suffix() const override {return ")";}  // printed after the system argument
313 
315  const Sparsity& A_sp() const override;  // sparsity of the linear system; presumably A's pattern plus the diagonal, cached in A_sp_
316 
317  // Sparsity pattern of linear system, cached
318  mutable Sparsity A_sp_;  // mutable, presumably so the const A_sp() can fill it lazily
319  };
320 
327  template<bool Tr>
328  class CASADI_EXPORT TriuSolveUnity : public SolveUnity<Tr> {  // linear solve with an upper triangular matrix, unity diagonal
329  public:
330 
334  TriuSolveUnity(const MX& r, const MX& A);  // r: right-hand side, A: upper triangular matrix (unity diagonal)
335 
339  ~TriuSolveUnity() override {}
340 
342  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;  // numeric evaluation
343 
345  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;  // evaluation on SX elements
346 
348  MX solve(const MX& A, const MX& B, bool tr) const override {  // build another upper triangular, unity-diagonal solve node
349  return A->get_solve_triu_unity(B, tr);
350  }
351 
355  void generate(CodeGenerator& g, const std::vector<casadi_int>& arg,
356  const std::vector<casadi_int>& res) const override;  // emit C code through g
357  };
358 
365  template<bool Tr>
366  class CASADI_EXPORT TrilSolveUnity : public SolveUnity<Tr> {  // linear solve with a lower triangular matrix, unity diagonal
367  public:
368 
372  TrilSolveUnity(const MX& r, const MX& A);  // r: right-hand side, A: lower triangular matrix (unity diagonal)
373 
377  ~TrilSolveUnity() override {}
378 
380  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;  // numeric evaluation
381 
383  int eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const override;  // evaluation on SX elements
384 
386  MX solve(const MX& A, const MX& B, bool tr) const override {  // build another lower triangular, unity-diagonal solve node
387  return A->get_solve_tril_unity(B, tr);
388  }
389 
393  void generate(CodeGenerator& g, const std::vector<casadi_int>& arg,
394  const std::vector<casadi_int>& res) const override;  // emit C code through g
395  };
396 
397 } // namespace casadi
398 
399 #endif // CASADI_SOLVE_HPP
Helper class for C code generation.
Helper class for Serialization.
Linear solve operation with a linear solver instance.
Definition: solve.hpp:148
~LinsolCall() override
Destructor.
Definition: solve.hpp:159
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:183
Linsol linsol_
Linear solver (may be shared between multiple nodes)
Definition: solve.hpp:180
Linear solver.
Definition: linsol.hpp:55
DM solve(const DM &A, const DM &B, bool tr=false) const
Node class for MX objects.
Definition: mx_node.hpp:50
MX - Matrix expression.
Definition: mx.hpp:84
Helper class for Serialization.
Linear solve with unity diagonal added.
Definition: solve.hpp:291
~SolveUnity() override
Destructor.
Definition: solve.hpp:302
Sparsity A_sp_
Definition: solve.hpp:318
std::string mod_suffix() const override
Modifier for linear system, after argument.
Definition: solve.hpp:312
std::string mod_prefix() const override
Modifier for linear system, before argument.
Definition: solve.hpp:307
An MX atomic for linear solver solution: x = r * A^-1 or x = r * A^-T.
Definition: solve.hpp:47
Dict info() const override
Definition: solve.hpp:110
virtual const Sparsity & A_sp() const
Sparsity pattern for the linear system.
Definition: solve.hpp:118
virtual std::string mod_prefix() const
Modifier for linear system, before argument.
Definition: solve.hpp:67
casadi_int op() const override
Get the operation.
Definition: solve.hpp:107
virtual std::string mod_suffix() const
Modifier for linear system, after argument.
Definition: solve.hpp:72
virtual MX solve(const MX &A, const MX &B, bool tr) const =0
Solve another system with the same factorization.
~Solve() override
Destructor.
Definition: solve.hpp:57
casadi_int n_inplace() const override
Can the operation be performed inplace (i.e. overwrite the result)
Definition: solve.hpp:92
General sparsity class.
Definition: sparsity.hpp:99
Linear solve with a lower triangular matrix, unity diagonal.
Definition: solve.hpp:366
~TrilSolveUnity() override
Destructor.
Definition: solve.hpp:377
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:386
Linear solve with a lower triangular matrix.
Definition: solve.hpp:253
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:273
~TrilSolve() override
Destructor.
Definition: solve.hpp:264
Linear solve with an upper triangular matrix, unity diagonal.
Definition: solve.hpp:328
~TriuSolveUnity() override
Destructor.
Definition: solve.hpp:339
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:348
Linear solve with an upper triangular matrix.
Definition: solve.hpp:215
~TriuSolve() override
Destructor.
Definition: solve.hpp:226
MX solve(const MX &A, const MX &B, bool tr) const override
Solve another system with the same factorization.
Definition: solve.hpp:235
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.