bspline_interpolant.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BSPLINE_INTERPOLANT_HPP
27 #define CASADI_BSPLINE_INTERPOLANT_HPP
28 
29 #include "casadi/core/interpolant_impl.hpp"
30 #include <casadi/solvers/casadi_interpolant_bspline_export.h>
31 
40 
41 namespace casadi {
42  class BSplineCommon;
55  class CASADI_INTERPOLANT_BSPLINE_EXPORT BSplineInterpolant : public Interpolant {
56  public:
57  // Constructor
58  BSplineInterpolant(const std::string& name,
59  const std::vector<double>& grid,
60  const std::vector<casadi_int>& offset,
61  const std::vector<double>& values,
62  casadi_int m);
63 
64  // Destructor
65  ~BSplineInterpolant() override;
66 
67  // Get name of the plugin
68  const char* plugin_name() const override { return "bspline";}
69 
70  // Get name of the class
71  std::string class_name() const override { return "BSplineInterpolant";}
72 
74  static Interpolant* creator(const std::string& name,
75  const std::vector<double>& grid,
76  const std::vector<casadi_int>& offset,
77  const std::vector<double>& values,
78  casadi_int m) {
79  return new BSplineInterpolant(name, grid, offset, values, m);
80  }
81 
82  // Is differentiable? Deferred to bspline
83  bool get_diff_in(casadi_int i) override { return true; }
84 
85  // Initialize
86  void init(const Dict& opts) override;
87 
89  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
90 
92 
93  bool has_jacobian() const override { return true;}
94  Function get_jacobian(const std::string& name,
95  const std::vector<std::string>& inames,
96  const std::vector<std::string>& onames,
97  const Dict& opts) const override;
99 
100 
102 
107  bool has_forward(casadi_int nfwd) const override { return true; }
108  Function get_forward(casadi_int nfwd, const std::string& name,
109  const std::vector<std::string>& inames,
110  const std::vector<std::string>& onames,
111  const Dict& opts) const override;
113 
115 
120  bool has_reverse(casadi_int nadj) const override { return true; }
121  Function get_reverse(casadi_int nadj, const std::string& name,
122  const std::vector<std::string>& inames,
123  const std::vector<std::string>& onames,
124  const Dict& opts) const override;
126 
128  bool has_codegen() const override { return true;}
129 
131  void codegen_body(CodeGenerator& g) const override;
132 
134  void codegen_declarations(CodeGenerator& g) const override;
135 
137  static const std::string meta_doc;
138 
140 
141  static const Options options_;
142  const Options& get_options() const override { return options_;}
144 
145  // Spline Function
147 
148  // Get all embedded functions, recursively
149  void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth) const override;
150 
152  int sp_forward(const bvec_t** arg, bvec_t** res,
153  casadi_int* iw, bvec_t* w, void* mem) const override {
154  return S_->sp_forward(arg, res, iw, w, mem);
155  }
156 
158  int sp_reverse(bvec_t** arg, bvec_t** res,
159  casadi_int* iw, bvec_t* w, void* mem) const override {
160  return S_->sp_reverse(arg, res, iw, w, mem);
161  }
162 
163  static std::vector<double> greville_points(const std::vector<double>& x, casadi_int deg);
164 
165  void serialize_body(SerializingStream &s) const override;
166 
169 
170  protected:
173 
174  static std::vector<double> not_a_knot(const std::vector<double>& x, casadi_int k);
175 
176  template <typename M>
177  MX construct_graph(const MX& x, const M& values, const Dict& linsol_options, const Dict& opts);
178 
179  enum FittingAlgorithm {ALG_NOT_A_KNOT, ALG_SMOOTH_LINEAR};
180 
182  std::string linear_solver_;
185  std::vector<casadi_int> degree_;
186  };
187 
188 
189  template <typename M>
190  MX BSplineInterpolant::construct_graph(const MX& x, const M& values,
191  const Dict& linsol_options, const Dict& opts) {
192 
193  std::vector< std::vector<double> > grid;
194  for (casadi_int k=0;k<degree_.size();++k) {
195  std::vector<double> local_grid(grid_.begin()+offset_[k], grid_.begin()+offset_[k+1]);
196  grid.push_back(local_grid);
197  }
198 
199  bool do_inline = false;
200  for (auto&& op : opts) {
201  if (op.first=="inline") {
202  do_inline = op.second;
203  }
204  }
205 
206  Dict opts_bspline;
207  opts_bspline["lookup_mode"] = lookup_modes_;
208  opts_bspline["inline"] = do_inline;
209 
210  switch (algorithm_) {
211  case ALG_NOT_A_KNOT:
212  {
213  std::vector< std::vector<double> > knots;
214  for (casadi_int k=0;k<degree_.size();++k)
215  knots.push_back(not_a_knot(grid[k], degree_[k]));
216  Dict opts_dual;
217  opts_dual["lookup_mode"] = lookup_modes_;
218 
219  DM J = MX::bspline_dual(meshgrid(grid), knots, degree_, opts_dual);
220 
221  casadi_assert_dev(J.size1()==J.size2());
222 
223  M V = M::reshape(values, m_, -1).T();
224  M C_opt = solve(J, V, linear_solver_, linsol_options);
225 
226  if (!has_parametric_values()) {
227  double fit = static_cast<double>(norm_1(mtimes(J, C_opt) - V));
228  if (verbose_) casadi_message("Lookup table fitting error: " + str(fit));
229  }
230 
231  return MX::bspline(x, C_opt.T(), knots, degree_, m_, opts_bspline);
232  }
233  case ALG_SMOOTH_LINEAR:
234  {
235  casadi_int n_dim = degree_.size();
236  // Linear fit
237  Function linear;
238  if (has_parametric_values()) {
239  linear = interpolant("linear", "linear", grid, m_);
240  } else {
241  linear = interpolant("linear", "linear", grid, values_);
242  }
243 
244  std::vector< std::vector<double> > egrid;
245  std::vector< std::vector<double> > new_grid;
246 
247  for (casadi_int k=0;k<n_dim;++k) {
248  casadi_assert(degree_[k]==3, "Only degree 3 supported for 'smooth_linear'.");
249 
250  // Add extra knots
251  const std::vector<double>& g = grid[k];
252 
253  // Determine smallest gap.
254  double m = inf;
255  for (casadi_int i=0;i<g.size()-1;++i) {
256  double delta = g[i+1]-g[i];
257  if (delta<m) m = delta;
258  }
259  double step = smooth_linear_frac_*m;
260 
261  // Add extra knots
262  std::vector<double> new_g;
263  new_g.push_back(g.front());
264  new_g.push_back(g.front()+step);
265  for (casadi_int i=1;i<g.size()-1;++i) {
266  new_g.push_back(g[i]-step);
267  new_g.push_back(g[i]);
268  new_g.push_back(g[i]+step);
269  }
270  new_g.push_back(g.back()-step);
271  new_g.push_back(g.back());
272  new_grid.push_back(new_g);
273 
274  // Correct multiplicity
275  double v1 = new_g.front();
276  double vend = new_g.back();
277  new_g.insert(new_g.begin(), degree_[k], v1);
278  new_g.insert(new_g.end(), degree_[k], vend);
279 
280  grid[k] = new_g;
281 
282  // Compute greville points
283  egrid.push_back(greville_points(new_g, degree_[k]));
284  }
285 
286  std::vector<double> mg = meshgrid(egrid);
287  casadi_int N = mg.size()/n_dim;
288 
289  // Evaluate linear interpolation on greville grid
290  DM arg = DM::reshape(mg, n_dim, N);
291  std::vector<M> res;
292  if (has_parametric_values()) {
293  res = linear(std::vector<M>{M(arg), values});
294  } else {
295  res = linear(std::vector<M>{M(arg)});
296  }
297 
298  return MX::bspline(x, res[0], grid, degree_, m_, opts_bspline);
299  }
300  default:
301  casadi_assert_dev(false);
302  }
303  return MX(); // Cannot happen
304  }
305 
306 } // namespace casadi
307 
309 #endif // CASADI_BSPLINE_INTERPOLANT_HPP
'bspline' plugin for Interpolant
std::string linear_solver_
Only used during init, no need to serialize these.
int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const override
Propagate sparsity backwards.
static std::vector< double > not_a_knot(const std::vector< double > &x, casadi_int k)
int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const override
Propagate sparsity forward.
std::vector< casadi_int > degree_
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
MX construct_graph(const MX &x, const M &values, const Dict &linsol_options, const Dict &opts)
static Interpolant * creator(const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m)
Create a new Interpolant.
bool has_jacobian() const override
Full Jacobian.
const char * plugin_name() const override
const Options & get_options() const override
Options.
static std::vector< double > greville_points(const std::vector< double > &x, casadi_int deg)
static const Options options_
Options.
bool has_reverse(casadi_int nadj) const override
Return function that calculates adjoint derivatives reverse(nadj) returns a cached instance if availa...
bool get_diff_in(casadi_int i) override
Which inputs are differentiable?
std::string class_name() const override
Get type name.
static const std::string meta_doc
A documentation string.
bool has_codegen() const override
Is codegen supported?
bool has_forward(casadi_int nfwd) const override
Return function that calculates forward derivatives forward(nfwd) returns a cached instance if availa...
Helper class for C code generation.
Helper class for Serialization.
virtual int sp_forward(const bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const
Propagate sparsity forward.
virtual int sp_reverse(bvec_t **arg, bvec_t **res, casadi_int *iw, bvec_t *w, void *mem) const
Propagate sparsity backwards.
Function object.
Definition: function.hpp:60
casadi_int size2() const
Get the second dimension (i.e. number of columns)
casadi_int size1() const
Get the first dimension (i.e. number of rows)
std::vector< std::string > lookup_modes_
static std::vector< double > meshgrid(const std::vector< std::vector< double > > &grid)
Definition: interpolant.cpp:83
std::vector< casadi_int > offset_
std::vector< double > grid_
std::vector< double > values_
bool has_parametric_values() const
Is parametric?
MX - Matrix expression.
Definition: mx.hpp:92
MX T() const
Transpose the matrix.
Definition: mx.cpp:1029
static MX bspline(const MX &x, const DM &coeffs, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, casadi_int m, const Dict &opts=Dict())
Definition: mx.cpp:2116
static DM bspline_dual(const std::vector< double > &x, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, const Dict &opts=Dict())
Definition: mx.cpp:2133
static Matrix< double > reshape(const Matrix< double > &x, casadi_int nrow, casadi_int ncol)
Base class for FunctionInternal and LinsolInternal.
bool verbose_
Verbose printout.
Helper class for Serialization.
Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts)
The casadi namespace.
Definition: archiver.cpp:28
T norm_1(const std::vector< T > &x)
unsigned long long bvec_t
std::vector< casadi_int > find(const std::vector< T > &v)
find nonzeros
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
const double inf
infinity
Definition: calculus.hpp:50
MX do_inline(const MX &x, const std::vector< std::vector< double > > &knots, const MX &coeffs, casadi_int m, const std::vector< casadi_int > &degree, const std::vector< casadi_int > &lookup_mode)
Definition: bspline.cpp:211
Options metadata for a class.
Definition: options.hpp:40