26 #include "bspline_interpolant.hpp"
27 #include "casadi/core/bspline.hpp"
32 int CASADI_INTERPOLANT_BSPLINE_EXPORT
35 plugin->name =
"bspline";
37 plugin->version = CASADI_VERSION;
40 plugin->exposed.do_inline =
nullptr;
53 "Sets, for each grid dimension, the degree of the spline."}},
56 "Solver used for constructing the coefficient tensor."}},
57 {
"linear_solver_options",
59 "Options to be passed to the linear solver."}},
62 "Algorithm used for fitting the data: 'not_a_knot' (default, same as Matlab),"
63 " 'smooth_linear'."}},
64 {
"smooth_linear_frac",
66 "When 'smooth_linear' algorithm is active, determines sharpness between"
67 " 0 (sharp, as linear interpolation) and 0.5 (smooth)."
68 "Default value is 0.1."}}
78 const std::vector<double>& grid,
79 const std::vector<casadi_int>& offset,
80 const std::vector<double>& values,
87 std::vector<double> ret;
89 casadi_int m = (k-1)/2;
90 casadi_assert(x.size()>=2*m+2,
"Need more data points");
91 for (casadi_int i=0;i<k+1;++i) ret.push_back(x[0]);
92 for (casadi_int i=0;i<x.size()-2*m-2;++i) ret.push_back(x[m+1+i]);
93 for (casadi_int i=0;i<k+1;++i) ret.push_back(x[x.size()-1]);
95 casadi_error(
"Not implemented");
111 Dict linear_solver_options;
114 for (
auto&& op : opts) {
115 if (op.first==
"degree") {
117 }
else if (op.first==
"linear_solver") {
119 }
else if (op.first==
"linear_solver_options") {
120 linear_solver_options = op.second.to_dict();
121 }
else if (op.first==
"algorithm") {
122 std::string alg = op.second.to_string();
123 if (alg==
"not_a_knot") {
125 }
else if (alg==
"smooth_linear") {
128 casadi_error(
"Algorithm option invalid: " +
get_options().
info(
"algorithm"));
130 }
else if (op.first==
"smooth_linear_frac") {
133 "smooth_linear_frac must be in ]0,0.5[");
151 S_ =
Function(
"wrapper", {x, coeff}, {e}, {
"x",
"c"}, {
"f"});
154 S_ =
Function(
"wrapper", {x}, {e}, {
"x"}, {
"f"});
164 casadi_int max_depth)
const {
170 casadi_int dim = x.size()-deg-1;
171 std::vector<double> ret(dim);
172 for (casadi_int i = 0; i < dim; ++i) {
174 for (casadi_int j = 0; j < deg; j++) {
177 ret[i] = ret[i] / deg;
183 casadi_int* iw,
double* w,
void* mem)
const {
184 setup(mem, arg, res, iw, w);
186 return S_(arg, res, iw, w, m);
199 const std::vector<std::string>& inames,
200 const std::vector<std::string>& onames,
201 const Dict& opts)
const {
206 get_forward(casadi_int nfwd,
const std::string& name,
207 const std::vector<std::string>& inames,
208 const std::vector<std::string>& onames,
209 const Dict& opts)
const {
214 get_reverse(casadi_int nadj,
const std::string& name,
215 const std::vector<std::string>& inames,
216 const std::vector<std::string>& onames,
217 const Dict& opts)
const {
222 s.
version(
"BSplineInterpolant", 1);
223 s.
unpack(
"BSplineInterpolant::s",
S_);
229 s.
version(
"BSplineInterpolant", 1);
230 s.
pack(
"BSplineInterpolant::s",
S_);
~BSplineInterpolant() override
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
std::string linear_solver_
Only used during init, no need to serialize these.
void codegen_body(CodeGenerator &g) const override
Generate code for the body of the C function.
static std::vector< double > not_a_knot(const std::vector< double > &x, casadi_int k)
std::vector< casadi_int > degree_
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void init(const Dict &opts) override
Initialize.
MX construct_graph(const MX &x, const M &values, const Dict &linsol_options, const Dict &opts)
static Interpolant * creator(const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m)
Create a new Interpolant.
FittingAlgorithm algorithm_
Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Return function that calculates forward derivatives. forward(nfwd) returns a cached instance if available.
Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Return function that calculates adjoint derivatives. reverse(nadj) returns a cached instance if available.
const Options & get_options() const override
Options.
static std::vector< double > greville_points(const std::vector< double > &x, casadi_int deg)
int eval(const double **arg, double **res, casadi_int *iw, double *w, void *mem) const override
Evaluate numerically.
double smooth_linear_frac_
static const Options options_
Options.
void codegen_declarations(CodeGenerator &g) const override
Generate code for the declarations of the C function.
BSplineInterpolant(const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m)
static const std::string meta_doc
A documentation string.
void find(std::map< FunctionInternal *, Function > &all_fun, casadi_int max_depth) const override
Function get_jacobian(const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const override
Full Jacobian.
Helper class for C code generation.
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
void alloc_iw(size_t sz_iw, bool persistent=false)
Ensure required length of iw field.
void alloc_res(size_t sz_res, bool persistent=false)
Ensure required length of res field.
void alloc_arg(size_t sz_arg, bool persistent=false)
Ensure required length of arg field.
virtual void codegen_body(CodeGenerator &g) const
Generate code for the function body.
virtual Function get_reverse(casadi_int nadj, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const
Return function that calculates adjoint derivatives.
virtual Dict info() const
virtual void codegen_declarations(CodeGenerator &g) const
Generate code for the declarations of the C function.
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
virtual Function get_jacobian(const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const
Return Jacobian of all input elements with respect to all output elements.
void setup(void *mem, const double **arg, double **res, casadi_int *iw, double *w) const
Set the (persistent and temporary) work vectors.
void add_embedded(std::map< FunctionInternal *, Function > &all_fun, const Function &dep, casadi_int max_depth) const
virtual Function get_forward(casadi_int nfwd, const std::string &name, const std::vector< std::string > &inames, const std::vector< std::string > &onames, const Dict &opts) const
Return function that calculates forward derivatives.
size_t sz_res() const
Get required length of res field.
size_t sz_iw() const
Get required length of iw field.
size_t sz_w() const
Get required length of w field.
size_t sz_arg() const
Get required length of arg field.
static MX sym(const std::string &name, casadi_int nrow=1, casadi_int ncol=1)
Create an nrow-by-ncol symbolic primitive.
static const Options options_
Options.
casadi_int coeff_size() const
Size of the flattened coefficients vector.
void init(const Dict &opts) override
Initialize.
std::vector< casadi_int > offset_
bool has_parametric_grid() const
Is the grid parametric?
std::vector< double > values_
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
bool has_parametric_values() const
Are the values parametric?
static void registerPlugin(const Plugin &plugin, bool needs_lock=true)
Register a plugin in the factory.
void clear_mem()
Clear all memory (called from destructor)
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
int CASADI_INTERPOLANT_BSPLINE_EXPORT casadi_register_interpolant_bspline(Interpolant::Plugin *plugin)
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
void CASADI_INTERPOLANT_BSPLINE_EXPORT casadi_load_interpolant_bspline()