bspline_interpolant.hpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BSPLINE_INTERPOLANT_HPP
27 #define CASADI_BSPLINE_INTERPOLANT_HPP
28 
29 #include "casadi/core/interpolant_impl.hpp"
30 #include <casadi/solvers/casadi_interpolant_bspline_export.h>
31 
40 
41 namespace casadi {
42  class BSplineCommon;
55  class BSplineInterpolant : public Interpolant {
56  public:
57  // Constructor
58  BSplineInterpolant(const std::string& name,
59  const std::vector<double>& grid,
60  const std::vector<casadi_int>& offset,
61  const std::vector<double>& values,
62  casadi_int m);
63 
64  // Destructor
65  ~BSplineInterpolant() override;
66 
67  // Get name of the plugin
68  const char* plugin_name() const override { return "bspline";}
69 
70  // Get name of the class
71  std::string class_name() const override { return "BSplineInterpolant";}
72 
74  static Interpolant* creator(const std::string& name,
75  const std::vector<double>& grid,
76  const std::vector<casadi_int>& offset,
77  const std::vector<double>& values,
78  casadi_int m) {
79  return new BSplineInterpolant(name, grid, offset, values, m);
80  }
81 
82  // Is differentiable? Deferred to bspline
83  bool get_diff_in(casadi_int i) override { return true; }
84 
85  // Initialize
86  void init(const Dict& opts) override;
87 
89  int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
90 
92 
93  bool has_jacobian() const override { return true;}
94  Function get_jacobian(const std::string& name,
95  const std::vector<std::string>& inames,
96  const std::vector<std::string>& onames,
97  const Dict& opts) const override;
99 
100 
102 
107  bool has_forward(casadi_int nfwd) const override { return true; }
108  Function get_forward(casadi_int nfwd, const std::string& name,
109  const std::vector<std::string>& inames,
110  const std::vector<std::string>& onames,
111  const Dict& opts) const override;
113 
115 
120  bool has_reverse(casadi_int nadj) const override { return true; }
121  Function get_reverse(casadi_int nadj, const std::string& name,
122  const std::vector<std::string>& inames,
123  const std::vector<std::string>& onames,
124  const Dict& opts) const override;
126 
128  bool has_codegen() const override { return true;}
129 
131  void codegen_body(CodeGenerator& g) const override;
132 
134  void codegen_declarations(CodeGenerator& g) const override;
135 
137  static const std::string meta_doc;
138 
140 
141  static const Options options_;
142  const Options& get_options() const override { return options_;}
144 
145  // Spline Function
146  Function S_;
147 
149  int sp_forward(const bvec_t** arg, bvec_t** res,
150  casadi_int* iw, bvec_t* w, void* mem) const override {
151  return S_->sp_forward(arg, res, iw, w, mem);
152  }
153 
155  int sp_reverse(bvec_t** arg, bvec_t** res,
156  casadi_int* iw, bvec_t* w, void* mem) const override {
157  return S_->sp_reverse(arg, res, iw, w, mem);
158  }
159 
160  static std::vector<double> greville_points(const std::vector<double>& x, casadi_int deg);
161 
162  void serialize_body(SerializingStream &s) const override;
163 
165  static ProtoFunction* deserialize(DeserializingStream& s) { return new BSplineInterpolant(s); }
166 
167  protected:
169  explicit BSplineInterpolant(DeserializingStream& s);
170 
171  static std::vector<double> not_a_knot(const std::vector<double>& x, casadi_int k);
172 
173  template <typename M>
174  MX construct_graph(const MX& x, const M& values, const Dict& linsol_options, const Dict& opts);
175 
176  enum FittingAlgorithm {ALG_NOT_A_KNOT, ALG_SMOOTH_LINEAR};
177 
179  std::string linear_solver_;
180  FittingAlgorithm algorithm_;
181  double smooth_linear_frac_;
182  std::vector<casadi_int> degree_;
183  };
184 
185 
186  template <typename M>
187  MX BSplineInterpolant::construct_graph(const MX& x, const M& values,
188  const Dict& linsol_options, const Dict& opts) {
189 
190  std::vector< std::vector<double> > grid;
191  for (casadi_int k=0;k<degree_.size();++k) {
192  std::vector<double> local_grid(grid_.begin()+offset_[k], grid_.begin()+offset_[k+1]);
193  grid.push_back(local_grid);
194  }
195 
196  bool do_inline = false;
197  for (auto&& op : opts) {
198  if (op.first=="inline") {
199  do_inline = op.second;
200  }
201  }
202 
203  Dict opts_bspline;
204  opts_bspline["lookup_mode"] = lookup_modes_;
205  opts_bspline["inline"] = do_inline;
206 
207  switch (algorithm_) {
208  case ALG_NOT_A_KNOT:
209  {
210  std::vector< std::vector<double> > knots;
211  for (casadi_int k=0;k<degree_.size();++k)
212  knots.push_back(not_a_knot(grid[k], degree_[k]));
213  Dict opts_dual;
214  opts_dual["lookup_mode"] = lookup_modes_;
215 
216  DM J = MX::bspline_dual(meshgrid(grid), knots, degree_, opts_dual);
217 
218  casadi_assert_dev(J.size1()==J.size2());
219 
220  M V = M::reshape(values, m_, -1).T();
221  M C_opt = solve(J, V, linear_solver_, linsol_options);
222 
223  if (!has_parametric_values()) {
224  double fit = static_cast<double>(norm_1(mtimes(J, C_opt) - V));
225  if (verbose_) casadi_message("Lookup table fitting error: " + str(fit));
226  }
227 
228  return MX::bspline(x, C_opt.T(), knots, degree_, m_, opts_bspline);
229  }
230  case ALG_SMOOTH_LINEAR:
231  {
232  casadi_int n_dim = degree_.size();
233  // Linear fit
234  Function linear;
235  if (has_parametric_values()) {
236  linear = interpolant("linear", "linear", grid, m_);
237  } else {
238  linear = interpolant("linear", "linear", grid, values_);
239  }
240 
241  std::vector< std::vector<double> > egrid;
242  std::vector< std::vector<double> > new_grid;
243 
244  for (casadi_int k=0;k<n_dim;++k) {
245  casadi_assert(degree_[k]==3, "Only degree 3 supported for 'smooth_linear'.");
246 
247  // Add extra knots
248  const std::vector<double>& g = grid[k];
249 
250  // Determine smallest gap.
251  double m = inf;
252  for (casadi_int i=0;i<g.size()-1;++i) {
253  double delta = g[i+1]-g[i];
254  if (delta<m) m = delta;
255  }
256  double step = smooth_linear_frac_*m;
257 
258  // Add extra knots
259  std::vector<double> new_g;
260  new_g.push_back(g.front());
261  new_g.push_back(g.front()+step);
262  for (casadi_int i=1;i<g.size()-1;++i) {
263  new_g.push_back(g[i]-step);
264  new_g.push_back(g[i]);
265  new_g.push_back(g[i]+step);
266  }
267  new_g.push_back(g.back()-step);
268  new_g.push_back(g.back());
269  new_grid.push_back(new_g);
270 
271  // Correct multiplicity
272  double v1 = new_g.front();
273  double vend = new_g.back();
274  new_g.insert(new_g.begin(), degree_[k], v1);
275  new_g.insert(new_g.end(), degree_[k], vend);
276 
277  grid[k] = new_g;
278 
279  // Compute greville points
280  egrid.push_back(greville_points(new_g, degree_[k]));
281  }
282 
283  std::vector<double> mg = meshgrid(egrid);
284  casadi_int N = mg.size()/n_dim;
285 
286  // Evaluate linear interpolation on greville grid
287  DM arg = DM::reshape(mg, n_dim, N);
288  std::vector<M> res;
289  if (has_parametric_values()) {
290  res = linear(std::vector<M>{M(arg), values});
291  } else {
292  res = linear(std::vector<M>{M(arg)});
293  }
294 
295  return MX::bspline(x, res[0], grid, degree_, m_, opts_bspline);
296  }
297  default:
298  casadi_assert_dev(false);
299  }
300  return MX(); // Cannot happen
301  }
302 
303 } // namespace casadi
304 
306 #endif // CASADI_BSPLINE_INTERPOLANT_HPP
static DM bspline_dual(const std::vector< double > &x, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, const Dict &opts=Dict())
friend MX bspline(const MX &x, const DM &coeffs, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, casadi_int m, const Dict &opts=Dict())
Definition: mx.hpp:761
CASADI_EXPORT Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts=Dict())
The casadi namespace.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
Matrix< double > DM
Definition: dm_fwd.hpp:33