26 #ifndef CASADI_BSPLINE_INTERPOLANT_HPP
27 #define CASADI_BSPLINE_INTERPOLANT_HPP
29 #include "casadi/core/interpolant_impl.hpp"
30 #include <casadi/solvers/casadi_interpolant_bspline_export.h>
// BSplineInterpolant: an Interpolant subclass whose plugin name is "bspline".
// NOTE(review): this listing embeds the original source line numbers and skips
// lines (the last constructor/creator parameter, access specifiers and several
// closing braces are not visible) -- confirm details against the full header.
55 class BSplineInterpolant :
public Interpolant {
// Constructor: takes the function name and the grid, offset and values vectors.
58 BSplineInterpolant(
const std::string& name,
59 const std::vector<double>& grid,
60 const std::vector<casadi_int>& offset,
61 const std::vector<double>& values,
// Destructor; defined out of line.
65 ~BSplineInterpolant()
override;
// Returns the plugin identifier string "bspline".
68 const char* plugin_name()
const override {
return "bspline";}
// Returns the class name string "BSplineInterpolant".
71 std::string class_name()
const override {
return "BSplineInterpolant";}
// Factory: allocates a BSplineInterpolant with `new` from name, grid, offset,
// values and m, and returns it as an Interpolant*.
// NOTE(review): the declaration of `m` and who deletes the result are not visible here.
74 static Interpolant* creator(
const std::string& name,
75 const std::vector<double>& grid,
76 const std::vector<casadi_int>& offset,
77 const std::vector<double>& values,
79 return new BSplineInterpolant(name, grid, offset, values, m);
// Returns true for every input index; the argument i is not inspected.
83 bool get_diff_in(casadi_int i)
override {
return true; }
// Initialization from an options dictionary; defined out of line.
86 void init(
const Dict& opts)
override;
// Numerical evaluation entry point; defined out of line. Returns an int.
89 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
// Always reports that a Jacobian is available.
93 bool has_jacobian()
const override {
return true;}
// Builds the Jacobian Function for the given name, input/output names and options;
// defined out of line.
94 Function get_jacobian(
const std::string& name,
95 const std::vector<std::string>& inames,
96 const std::vector<std::string>& onames,
97 const Dict& opts)
const override;
// Always reports forward-derivative support, regardless of nfwd.
107 bool has_forward(casadi_int nfwd)
const override {
return true; }
// Builds the forward-derivative Function for nfwd directions; defined out of line.
108 Function get_forward(casadi_int nfwd,
const std::string& name,
109 const std::vector<std::string>& inames,
110 const std::vector<std::string>& onames,
111 const Dict& opts)
const override;
// Always reports reverse-derivative support, regardless of nadj.
120 bool has_reverse(casadi_int nadj)
const override {
return true; }
// Builds the reverse-derivative Function for nadj directions; defined out of line.
121 Function get_reverse(casadi_int nadj,
const std::string& name,
122 const std::vector<std::string>& inames,
123 const std::vector<std::string>& onames,
124 const Dict& opts)
const override;
// Always reports that code generation is supported.
128 bool has_codegen()
const override {
return true;}
// Code-generation hook for the function body; defined out of line.
131 void codegen_body(CodeGenerator& g)
const override;
// Code-generation hook for declarations; defined out of line.
134 void codegen_declarations(CodeGenerator& g)
const override;
// Static documentation string; defined elsewhere.
137 static const std::string meta_doc;
// Static option table, returned by get_options() below.
141 static const Options options_;
142 const Options& get_options()
const override {
return options_;}
// Forward sparsity propagation: forwards all arguments unchanged to S_->sp_forward.
// NOTE(review): the declaration of S_ is not visible in this listing.
149 int sp_forward(
const bvec_t** arg, bvec_t** res,
150 casadi_int* iw, bvec_t* w,
void* mem)
const override {
151 return S_->sp_forward(arg, res, iw, w, mem);
// Reverse sparsity propagation: forwards all arguments unchanged to S_->sp_reverse.
155 int sp_reverse(bvec_t** arg, bvec_t** res,
156 casadi_int* iw, bvec_t* w,
void* mem)
const override {
157 return S_->sp_reverse(arg, res, iw, w, mem);
// Static helper taking a vector x and a degree; construct_graph calls it with an
// end-padded knot vector to obtain the evaluation points for one dimension.
160 static std::vector<double> greville_points(
const std::vector<double>& x, casadi_int deg);
// Serialization hook; defined out of line.
162 void serialize_body(SerializingStream &s)
const override;
// Deserialization factory: allocates a BSplineInterpolant from the stream using
// the deserializing constructor declared below.
165 static ProtoFunction* deserialize(DeserializingStream& s) {
return new BSplineInterpolant(s); }
// Deserializing constructor; explicit to prevent implicit conversion from a stream.
169 explicit BSplineInterpolant(DeserializingStream& s);
// Static helper taking grid points x and a degree k; construct_graph uses its
// result as the knot vector of one dimension in the ALG_NOT_A_KNOT branch.
171 static std::vector<double> not_a_knot(
const std::vector<double>& x, casadi_int k);
// Builds the B-spline expression for input x from `values` (type M); the
// definition follows the class.
173 template <
typename M>
174 MX construct_graph(
const MX& x,
const M& values,
const Dict& linsol_options,
const Dict& opts);
// Fitting algorithms that construct_graph switches on.
176 enum FittingAlgorithm {ALG_NOT_A_KNOT, ALG_SMOOTH_LINEAR};
// Linear solver name passed to solve() in the ALG_NOT_A_KNOT branch.
179 std::string linear_solver_;
// Selected fitting algorithm.
180 FittingAlgorithm algorithm_;
// Multiplier applied to the smallest grid spacing to get the knot offset `step`
// in the ALG_SMOOTH_LINEAR branch.
181 double smooth_linear_frac_;
// Spline degree per dimension; its size is used as the number of dimensions.
182 std::vector<casadi_int> degree_;
// Builds and returns the MX B-spline expression evaluated at x.
//   x              - expression the spline is evaluated at
//   values         - tabulated values (type M) to fit
//   linsol_options - options passed to solve() in the ALG_NOT_A_KNOT branch
//   opts           - scanned only for the "inline" key
// The branch taken depends on algorithm_ (ALG_NOT_A_KNOT or ALG_SMOOTH_LINEAR).
// NOTE(review): this listing skips lines; the declarations of opts_bspline,
// opts_dual, J, linear, m and res, the first case label and the closing braces
// are not visible here -- confirm against the full header.
186 template <
typename M>
187 MX BSplineInterpolant::construct_graph(
const MX& x,
const M& values,
188 const Dict& linsol_options,
const Dict& opts) {
// Split the flat grid_ into one vector per dimension, using consecutive
// entries of offset_ as slice boundaries.
190 std::vector< std::vector<double> > grid;
191 for (casadi_int k=0;k<degree_.size();++k) {
192 std::vector<double> local_grid(grid_.begin()+offset_[k], grid_.begin()+offset_[k+1]);
193 grid.push_back(local_grid);
// Pick up the "inline" option if present; defaults to false.
196 bool do_inline =
false;
197 for (
auto&& op : opts) {
198 if (op.first==
"inline") {
199 do_inline = op.second;
// Options forwarded to MX::bspline in both branches.
204 opts_bspline[
"lookup_mode"] = lookup_modes_;
205 opts_bspline[
"inline"] = do_inline;
207 switch (algorithm_) {
// --- First branch: knots come from not_a_knot(), one vector per dimension.
210 std::vector< std::vector<double> > knots;
211 for (casadi_int k=0;k<degree_.size();++k)
212 knots.push_back(not_a_knot(grid[k], degree_[k]));
214 opts_dual[
"lookup_mode"] = lookup_modes_;
// J must be square for the solve below.
218 casadi_assert_dev(J.size1()==J.size2());
// Reshape values to m_ rows, transpose, and solve the system with matrix J
// for the coefficients C_opt.
220 M V = M::reshape(values, m_, -1).T();
221 M C_opt = solve(J, V, linear_solver_, linsol_options);
// For non-parametric values, compute the 1-norm of the residual J*C_opt - V
// and report it when verbose_ is set.
223 if (!has_parametric_values()) {
224 double fit =
static_cast<double>(norm_1(mtimes(J, C_opt) - V));
225 if (verbose_) casadi_message(
"Lookup table fitting error: " + str(fit));
228 return MX::bspline(x, C_opt.T(), knots, degree_, m_, opts_bspline);
// --- ALG_SMOOTH_LINEAR: sample a linear interpolant at Greville points of a
// refined knot vector and use the samples as coefficients.
230 case ALG_SMOOTH_LINEAR:
232 casadi_int n_dim = degree_.size();
// The linear interpolant is built from m_ when values are parametric,
// otherwise from the stored values_.
235 if (has_parametric_values()) {
236 linear =
interpolant(
"linear",
"linear", grid, m_);
238 linear =
interpolant(
"linear",
"linear", grid, values_);
241 std::vector< std::vector<double> > egrid;
// NOTE(review): new_grid is filled below but never read in the visible lines.
242 std::vector< std::vector<double> > new_grid;
244 for (casadi_int k=0;k<n_dim;++k) {
245 casadi_assert(degree_[k]==3,
"Only degree 3 supported for 'smooth_linear'.");
248 const std::vector<double>& g = grid[k];
// Track the smallest spacing between consecutive grid points in m.
// NOTE(review): g.size()-1 is unsigned, so an empty g would wrap around;
// presumably callers guarantee at least two points -- confirm.
252 for (casadi_int i=0;i<g.size()-1;++i) {
253 double delta = g[i+1]-g[i];
254 if (delta<m) m = delta;
// Offset of the extra knots: a fraction of the smallest spacing.
256 double step = smooth_linear_frac_*m;
// Refined grid: one extra knot just inside each end point, and an extra knot
// on both sides of every interior point.
259 std::vector<double> new_g;
260 new_g.push_back(g.front());
261 new_g.push_back(g.front()+step);
262 for (casadi_int i=1;i<g.size()-1;++i) {
263 new_g.push_back(g[i]-step);
264 new_g.push_back(g[i]);
265 new_g.push_back(g[i]+step);
267 new_g.push_back(g.back()-step);
268 new_g.push_back(g.back());
269 new_grid.push_back(new_g);
// Pad both ends with degree_[k] extra copies of the end values.
272 double v1 = new_g.front();
273 double vend = new_g.back();
274 new_g.insert(new_g.begin(), degree_[k], v1);
275 new_g.insert(new_g.end(), degree_[k], vend);
// Evaluation points for this dimension, from the padded knot vector.
280 egrid.push_back(greville_points(new_g, degree_[k]));
// Combine the per-dimension points with meshgrid; N is the flat result's
// length divided by n_dim.
283 std::vector<double> mg = meshgrid(egrid);
284 casadi_int N = mg.size()/n_dim;
// Reshape the flat point list to n_dim rows by N columns and evaluate the
// linear interpolant there, passing values as well when they are parametric.
287 DM arg = DM::reshape(mg, n_dim, N);
289 if (has_parametric_values()) {
290 res = linear(std::vector<M>{M(arg), values});
292 res = linear(std::vector<M>{M(arg)});
// NOTE(review): grid is passed as the knot argument here, but no visible line
// replaces grid[k] with the padded new_g -- the listing skips lines; confirm.
295 return MX::bspline(x, res[0], grid, degree_, m_, opts_bspline);
// Unconditional assertion failure; presumably the fallback for an unknown
// algorithm_ value -- confirm.
298 casadi_assert_dev(
false);
static DM bspline_dual(const std::vector< double > &x, const std::vector< std::vector< double > > &knots, const std::vector< casadi_int > &degree, const Dict &opts=Dict())
CASADI_EXPORT Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts=Dict())
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.