26 #ifndef CASADI_BSPLINE_INTERPOLANT_HPP
27 #define CASADI_BSPLINE_INTERPOLANT_HPP
29 #include "casadi/core/interpolant_impl.hpp"
30 #include <casadi/solvers/casadi_interpolant_bspline_export.h>
// B-spline interpolation plugin class, derived from Interpolant.
// NOTE(review): this listing appears to have lost some original lines
// (closing braces, trailing parameters); the comments below describe only
// what is visible here.
55 class BSplineInterpolant :
public Interpolant {
// Constructor. Visible parameters: instance name, grid, offset and values.
// NOTE(review): the parameter list is not closed in this listing; creator()
// below forwards a fifth argument `m` -- confirm against the full declaration.
58 BSplineInterpolant(
const std::string& name,
59 const std::vector<double>& grid,
60 const std::vector<casadi_int>& offset,
61 const std::vector<double>& values,
// Destructor, overriding the base-class destructor; defined elsewhere.
65 ~BSplineInterpolant()
override;
// Returns the plugin name, the literal "bspline".
68 const char* plugin_name()
const override {
return "bspline";}
// Returns the class name, the literal "BSplineInterpolant".
71 std::string class_name()
const override {
return "BSplineInterpolant";}
// Static factory: allocates a BSplineInterpolant with `new` and returns it
// as an Interpolant*.
// NOTE(review): the line declaring `m` and the closing brace are not visible
// in this listing.
74 static Interpolant* creator(
const std::string& name,
75 const std::vector<double>& grid,
76 const std::vector<casadi_int>& offset,
77 const std::vector<double>& values,
79 return new BSplineInterpolant(name, grid, offset, values, m);
// Always returns true, regardless of the index i.
// NOTE(review): presumably marks every input as differentiable -- confirm
// against the base-class contract.
83 bool get_diff_in(casadi_int i)
override {
return true; }
// Initialization from an options dictionary; defined elsewhere.
86 void init(
const Dict& opts)
override;
// Numerical evaluation; defined elsewhere.
// NOTE(review): arg/res look like input/output buffers and iw/w like work
// vectors -- confirm against the base-class documentation.
89 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
// Always returns true.
93 bool has_jacobian()
const override {
return true;}
// Returns a Function for the given name, input/output names and options;
// defined elsewhere.
94 Function get_jacobian(
const std::string& name,
95 const std::vector<std::string>& inames,
96 const std::vector<std::string>& onames,
97 const Dict& opts)
const override;
// Always returns true, regardless of nfwd.
107 bool has_forward(casadi_int nfwd)
const override {
return true; }
// Returns a Function for nfwd and the given name, input/output names and
// options; defined elsewhere.
108 Function get_forward(casadi_int nfwd,
const std::string& name,
109 const std::vector<std::string>& inames,
110 const std::vector<std::string>& onames,
111 const Dict& opts)
const override;
// Always returns true, regardless of nadj.
120 bool has_reverse(casadi_int nadj)
const override {
return true; }
// Returns a Function for nadj and the given name, input/output names and
// options; defined elsewhere.
121 Function get_reverse(casadi_int nadj,
const std::string& name,
122 const std::vector<std::string>& inames,
123 const std::vector<std::string>& onames,
124 const Dict& opts)
const override;
// Always returns true.
128 bool has_codegen()
const override {
return true;}
// Code-generation hook taking the CodeGenerator; defined elsewhere.
131 void codegen_body(CodeGenerator& g)
const override;
// Code-generation hook taking the CodeGenerator; defined elsewhere.
134 void codegen_declarations(CodeGenerator& g)
const override;
// Static string member; defined elsewhere.
137 static const std::string meta_doc;
// Static Options instance returned by get_options(); defined elsewhere.
141 static const Options options_;
// Returns a reference to the static options_ member.
142 const Options& get_options()
const override {
return options_;}
// Takes a map from FunctionInternal* to Function and a maximum depth;
// defined elsewhere.
149 void find(std::map<FunctionInternal*, Function>& all_fun, casadi_int max_depth)
const override;
// Delegates to S_->sp_forward with the same arguments and returns its result.
// NOTE(review): S_ is not declared in this listing, and the closing brace of
// this body is not visible.
152 int sp_forward(
const bvec_t** arg, bvec_t** res,
153 casadi_int* iw, bvec_t* w,
void* mem)
const override {
154 return S_->sp_forward(arg, res, iw, w, mem);
// Delegates to S_->sp_reverse with the same arguments and returns its result.
// NOTE(review): the closing brace of this body is not visible in this listing.
158 int sp_reverse(bvec_t** arg, bvec_t** res,
159 casadi_int* iw, bvec_t* w,
void* mem)
const override {
160 return S_->sp_reverse(arg, res, iw, w, mem);
// Static helper taking a vector x and a degree; defined elsewhere.
// construct_graph() calls it on a padded knot vector.
163 static std::vector<double> greville_points(
const std::vector<double>& x, casadi_int deg);
// Serialization hook taking the SerializingStream; defined elsewhere.
165 void serialize_body(SerializingStream &s)
const override;
// Allocates a BSplineInterpolant from the stream with `new`, using the
// deserializing constructor below, and returns it as a ProtoFunction*.
168 static ProtoFunction* deserialize(DeserializingStream& s) {
return new BSplineInterpolant(s); }
// Deserializing constructor; explicit, so a stream is never implicitly
// converted to an interpolant.
172 explicit BSplineInterpolant(DeserializingStream& s);
// Static helper taking a vector x and an integer k; defined elsewhere.
// construct_graph() calls it with a per-dimension grid and degree.
174 static std::vector<double> not_a_knot(
const std::vector<double>& x, casadi_int k);
// Builds the MX B-spline expression for input x and the given values;
// the template definition follows the class in this header.
176 template <
typename M>
177 MX construct_graph(
const MX& x,
const M& values,
const Dict& linsol_options,
const Dict& opts);
// Algorithm selector; construct_graph() switches on it.
179 enum FittingAlgorithm {ALG_NOT_A_KNOT, ALG_SMOOTH_LINEAR};
// Passed to solve() in construct_graph().
182 std::string linear_solver_;
// Selected fitting algorithm.
183 FittingAlgorithm algorithm_;
// Multiplied by the smallest grid gap to get the knot offset `step` in
// construct_graph().
184 double smooth_linear_frac_;
// One degree entry per dimension; construct_graph() loops over it.
185 std::vector<casadi_int> degree_;
// Builds the MX B-spline expression for input x from the stored grid and the
// supplied values, dispatching on algorithm_.
// M is the type of `values`; x is the MX input; linsol_options is forwarded
// to solve(); opts is scanned for an "inline" entry.
// NOTE(review): several original lines (case labels, closing braces and the
// declarations of opts_bspline, opts_dual, J, linear, m and res) are not
// visible in this listing; the comments describe only the visible statements.
189 template <
typename M>
190 MX BSplineInterpolant::construct_graph(
const MX& x,
const M& values,
191 const Dict& linsol_options,
const Dict& opts) {
// Split the flat grid_ into one vector per dimension, using consecutive
// offset_ entries as [begin, end) bounds.
193 std::vector< std::vector<double> > grid;
194 for (casadi_int k=0;k<degree_.size();++k) {
195 std::vector<double> local_grid(grid_.begin()+offset_[k], grid_.begin()+offset_[k+1]);
196 grid.push_back(local_grid);
// Pick up the "inline" option if present; it defaults to false.
199 bool do_inline =
false;
200 for (
auto&& op : opts) {
201 if (op.first==
"inline") {
202 do_inline = op.second;
// Options handed to MX::bspline in both branches.
207 opts_bspline[
"lookup_mode"] = lookup_modes_;
208 opts_bspline[
"inline"] = do_inline;
210 switch (algorithm_) {
// First branch: per-dimension knot vectors from not_a_knot().
213 std::vector< std::vector<double> > knots;
214 for (casadi_int k=0;k<degree_.size();++k)
215 knots.push_back(not_a_knot(grid[k], degree_[k]));
217 opts_dual[
"lookup_mode"] = lookup_modes_;
// J must be square before it is used as the system matrix in solve().
221 casadi_assert_dev(J.size1()==J.size2());
// Reshape values to m_ rows, transpose, and solve J*C_opt = V for C_opt.
223 M V = M::reshape(values, m_, -1).T();
224 M C_opt = solve(J, V, linear_solver_, linsol_options);
// Only when the values are not parametric: compute the 1-norm of the
// residual J*C_opt - V and report it if verbose_ is set.
226 if (!has_parametric_values()) {
227 double fit =
static_cast<double>(norm_1(mtimes(J, C_opt) - V));
228 if (verbose_) casadi_message(
"Lookup table fitting error: " + str(fit));
231 return MX::bspline(x, C_opt.T(), knots, degree_, m_, opts_bspline);
233 case ALG_SMOOTH_LINEAR:
235 casadi_int n_dim = degree_.size();
// Create a "linear" interpolant on the per-dimension grid: from m_ when the
// values are parametric, otherwise from the stored values_.
238 if (has_parametric_values()) {
239 linear =
interpolant(
"linear",
"linear", grid, m_);
241 linear =
interpolant(
"linear",
"linear", grid, values_);
244 std::vector< std::vector<double> > egrid;
245 std::vector< std::vector<double> > new_grid;
247 for (casadi_int k=0;k<n_dim;++k) {
248 casadi_assert(degree_[k]==3,
"Only degree 3 supported for 'smooth_linear'.");
251 const std::vector<double>& g = grid[k];
// Find the smallest gap between neighbouring grid points into m.
// NOTE(review): g.size()-1 is unsigned arithmetic; for an empty g it wraps
// around instead of giving -1, and front()/back() below would be invalid.
// Presumably callers guarantee at least two points per dimension -- confirm.
255 for (casadi_int i=0;i<g.size()-1;++i) {
256 double delta = g[i+1]-g[i];
257 if (delta<m) m = delta;
259 double step = smooth_linear_frac_*m;
// Build the refined knot list: the first point and first+step, then each
// interior point with neighbours at -step and +step, then last-step and the
// last point.
262 std::vector<double> new_g;
263 new_g.push_back(g.front());
264 new_g.push_back(g.front()+step);
265 for (casadi_int i=1;i<g.size()-1;++i) {
266 new_g.push_back(g[i]-step);
267 new_g.push_back(g[i]);
268 new_g.push_back(g[i]+step);
270 new_g.push_back(g.back()-step);
271 new_g.push_back(g.back());
// NOTE(review): new_grid is filled here, but no later read of it is visible
// in this listing.
272 new_grid.push_back(new_g);
// Prepend and append degree_[k] extra copies of the first and last knot.
275 double v1 = new_g.front();
276 double vend = new_g.back();
277 new_g.insert(new_g.begin(), degree_[k], v1);
278 new_g.insert(new_g.end(), degree_[k], vend);
// Store greville_points() of the padded knot vector for this dimension.
283 egrid.push_back(greville_points(new_g, degree_[k]));
// Combine the per-dimension point sets via meshgrid(); N is the flat length
// divided by the number of dimensions.
286 std::vector<double> mg = meshgrid(egrid);
287 casadi_int N = mg.size()/n_dim;
// Reshape to n_dim rows by N columns and evaluate the linear interpolant
// there, passing `values` as a second argument in the parametric case.
290 DM arg = DM::reshape(mg, n_dim, N);
292 if (has_parametric_values()) {
293 res = linear(std::vector<M>{M(arg), values});
295 res = linear(std::vector<M>{M(arg)});
// The first output of the linear interpolant is passed to MX::bspline as the
// second argument.
298 return MX::bspline(x, res[0], grid, degree_, m_, opts_bspline);
// Unconditional assertion failure.
// NOTE(review): presumably the default branch for an unhandled algorithm_
// value; the label is not visible in this listing.
301 casadi_assert_dev(
false);
static DM bspline_dual(const std::vector< double > &x, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, const Dict &opts=Dict())
CASADI_EXPORT Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts=Dict())
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.