bspline.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BSPLINE_HPP
27 #define CASADI_BSPLINE_HPP
28 
29 #include "mx_node.hpp"
30 #include <map>
31 #include <stack>
32 
34 
35 namespace casadi {
42  class CASADI_EXPORT BSplineCommon : public MXNode {
43  public:
44 
46  BSplineCommon(const std::vector<double>& knots,
47  const std::vector<casadi_int>& offset,
48  const std::vector<casadi_int>& degree,
49  casadi_int m,
50  const std::vector<casadi_int>& lookup_mode);
51 
53  ~BSplineCommon() override {}
54 
55  static void prepare(casadi_int m, const std::vector<casadi_int>& offset,
56  const std::vector<casadi_int>& degree, casadi_int &coeffs_size,
57  std::vector<casadi_int>& coeffs_dims, std::vector<casadi_int>& strides);
58 
59  static casadi_int get_coeff_size(casadi_int m, const std::vector<casadi_int>& offset,
60  const std::vector<casadi_int>& degree);
61 
62  std::vector<double> knots_;
63  std::vector<casadi_int> offset_;
64  std::vector<casadi_int> degree_;
65  casadi_int m_;
66  std::vector<casadi_int> lookup_mode_;
67 
68  // Derived fields
69  std::vector<casadi_int> strides_;
70  std::vector<casadi_int> coeffs_dims_;
71  casadi_int coeffs_size_;
72 
79  mutable MX jac_cache_;
80 
81  virtual MX jac_cached() const = 0;
82 
86  static size_t n_iw(const std::vector<casadi_int> &degree);
87 
91  static size_t n_w(const std::vector<casadi_int> &degree);
92 
96  size_t sz_iw() const override;
97 
101  size_t sz_w() const override;
102 
106  casadi_int op() const override { return OP_BSPLINE;}
107 
111  void ad_forward(const std::vector<std::vector<MX> >& fseed,
112  std::vector<std::vector<MX> >& fsens) const override;
113 
117  void ad_reverse(const std::vector<std::vector<MX> >& aseed,
118  std::vector<std::vector<MX> >& asens) const override;
119 
123  void generate(CodeGenerator& g,
124  const std::vector<casadi_int>& arg,
125  const std::vector<casadi_int>& res) const override;
126 
130  virtual std::string generate(CodeGenerator& g,
131  const std::vector<casadi_int>& arg) const = 0;
132 
136  static MXNode* deserialize(DeserializingStream& s);
137 
138  template<class M>
139  M derivative_coeff(casadi_int i, const M& coeffs) const;
140 
141  template<class T>
142  MX jac(const MX& x, const T& coeffs) const;
143 
147  void serialize_body(SerializingStream& s) const override;
148 
149  protected:
150 
154  explicit BSplineCommon(DeserializingStream& s);
155 
156  };
157 
166  class CASADI_EXPORT BSpline : public BSplineCommon {
167  public:
168 
169  static MX create(const MX& x, const std::vector< std::vector<double> >& knots,
170  const std::vector<double>& coeffs,
171  const std::vector<casadi_int>& degree,
172  casadi_int m,
173  const Dict& opts);
174 
176  BSpline(const MX& x, const std::vector<double>& knots,
177  const std::vector<casadi_int>& offset,
178  const std::vector<double>& coeffs,
179  const std::vector<casadi_int>& degree,
180  casadi_int m,
181  const std::vector<casadi_int>& lookup_mode);
182 
184  ~BSpline() override {}
185 
187  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
188 
192  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
193 
197  std::string generate(CodeGenerator& g,
198  const std::vector<casadi_int>& arg) const override;
199 
203  std::string disp(const std::vector<std::string>& arg) const override;
204 
205  // Numeric coefficients
206  std::vector<double> coeffs_;
207 
208  MX jac_cached() const override;
209 
220  static DM dual(const std::vector<double>& x,
221  const std::vector< std::vector<double> >& knots,
222  const std::vector<casadi_int>& degree,
223  const Dict& opts);
227  void serialize_body(SerializingStream& s) const override;
231  void serialize_type(SerializingStream& s) const override;
232 
236  explicit BSpline(DeserializingStream& s);
237  };
238 
239  // Symbolic coefficients
240  class CASADI_EXPORT BSplineParametric : public BSplineCommon {
241  public:
242  static MX create(const MX& x, const MX& coeffs,
243  const std::vector< std::vector<double> >& knots,
244  const std::vector<casadi_int>& degree,
245  casadi_int m,
246  const Dict& opts);
247 
249  BSplineParametric(const MX& x, const MX& coeffs,
250  const std::vector<double>& knots,
251  const std::vector<casadi_int>& offset,
252  const std::vector<casadi_int>& degree,
253  casadi_int m,
254  const std::vector<casadi_int>& lookup_mode);
255 
257  ~BSplineParametric() override {}
258 
260  int eval(const double** arg, double** res, casadi_int* iw, double* w) const override;
261 
265  void eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const override;
266 
267  MX jac_cached() const override;
268 
272  std::string generate(CodeGenerator& g,
273  const std::vector<casadi_int>& arg) const override;
274 
278  std::string disp(const std::vector<std::string>& arg) const override;
279 
283  void serialize_type(SerializingStream& s) const override;
284 
288  explicit BSplineParametric(DeserializingStream& s) : BSplineCommon(s) {}
289  };
290 
291 } // namespace casadi
293 
294 #endif // CASADI_BSPLINE_HPP
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
Matrix< double > DM
Definition: dm_fwd.hpp:33