26 #include "interpolant_impl.hpp"
27 #include "casadi_misc.hpp"
28 #include "mx_node.hpp"
29 #include "casadi_low.hpp"
47 std::vector<casadi_int>& offset, std::vector<double>& stacked) {
51 offset.reserve(grid.size()+1);
53 for (
auto&& g : grid) offset.push_back(offset.back()+g.size());
57 stacked.reserve(offset.back());
58 for (
auto&& g : grid) stacked.insert(stacked.end(), g.begin(), g.end());
63 casadi_assert(!grid.empty(),
"At least one input required");
66 for (
auto&& g : grid) {
67 casadi_assert(
is_increasing(g),
"Gridpoints must be strictly increasing");
68 casadi_assert(
is_regular(g),
"Gridpoints must be regular");
69 casadi_assert(g.size()>=2,
"Need at least two grid points for every input");
75 casadi_assert(!grid_dims.empty(),
"At least one dimension required");
78 for (casadi_int d : grid_dims) {
79 casadi_assert(d>=2,
"Need at least two grid points for every input");
84 std::vector<casadi_int> cnts(grid.size()+1, 0);
85 std::vector<casadi_int> sizes(grid.size(), 0);
86 for (casadi_int k=0;k<grid.size();++k) sizes[k]= grid[k].size();
88 casadi_int total_iter = 1;
89 for (casadi_int k=0;k<grid.size();++k) total_iter*= sizes[k];
91 casadi_int n_dims = grid.size();
93 std::vector<double> ret(total_iter*n_dims);
94 for (casadi_int i=0;i<total_iter;++i) {
96 for (casadi_int j=0;j<grid.size();++j) {
97 ret[i*n_dims+j] = grid[j][cnts[j]];
102 while (j<n_dims && cnts[j]==sizes[j]) {
119 for (casadi_int k=0;k<offset.size()-1;++k) {
120 ret *= offset[k+1]-offset[k];
126 const std::string& solver,
127 const std::vector<std::vector<double> >& grid,
128 const std::vector<double>& values,
132 std::vector<casadi_int> offset;
134 std::vector<double> stacked;
138 for (
auto&& g : grid) nel *= g.size();
139 casadi_assert(values.size() % nel== 0,
140 "Inconsistent number of elements. Must be a multiple of " +
141 str(nel) +
", but got " +
str(values.size()) +
" instead.");
145 casadi_int m = values.size()/nel;
150 const std::string& name,
151 const std::vector<double>& grid,
152 const std::vector<casadi_int>& offset,
153 const std::vector<double>& values,
159 options[
"inline"] =
true;
164 do_inline(name, grid, offset, values, m, options);
167 .creator(name, grid, offset, values, m), options);
172 const std::string& solver,
173 const std::vector<casadi_int>& grid_dims,
174 const std::vector<double>& values,
179 casadi_uint nel =
product(grid_dims);
180 casadi_assert(values.size() % nel== 0,
181 "Inconsistent number of elements. Must be a multiple of " +
182 str(nel) +
", but got " +
str(values.size()) +
" instead.");
184 casadi_int m = values.size()/nel;
186 cumsum0(grid_dims), values, m, opts);
190 const std::string& solver,
191 const std::vector<std::vector<double> >& grid,
197 std::vector<casadi_int> offset;
199 std::vector<double> stacked;
206 const std::string& solver,
207 const std::vector<casadi_int>& grid_dims,
212 cumsum0(grid_dims), std::vector<double>{}, m, opts);
217 const std::vector<double>& grid,
218 const std::vector<casadi_int>& offset,
219 const std::vector<double>& values,
221 :
FunctionInternal(name), m_(m), grid_(grid), offset_(offset), values_(values) {
233 casadi_assert_dev(
false);
237 casadi_assert_dev(i==0);
242 if (i==0)
return "x";
245 casadi_assert_dev(
false);
249 casadi_assert_dev(i==0);
255 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
256 std::mutex Interpolant::mutex_solvers_;
265 "Specifies, for each grid dimension, the lookup algorithm used to find the correct index. "
266 "'linear' uses a for-loop + break; (default when #knots<=100), "
267 "'exact' uses floored division (only for uniform grids), "
268 "'binary' uses a binary search. (default when #knots>100)."}},
271 "Implement the lookup table in MX primitives. "
272 "Useful when you need derivatives with respect to grid and/or coefficients. "
273 "Such derivatives are fundamentally dense, so use with caution."}},
276 "Evaluate a batch of different inputs at once (default 1)."}}
303 for (
auto&& op : opts) {
304 if (op.first==
"lookup_mode") {
306 }
else if (op.first==
"batch_x") {
320 const std::vector<std::string>& modes,
const std::vector<double>& knots,
321 const std::vector<casadi_int>& offset,
322 const std::vector<casadi_int>& margin_left,
const std::vector<casadi_int>& margin_right) {
323 casadi_assert_dev(modes.empty() || modes.size()==offset.size()-1);
325 std::vector<casadi_int> ret;
326 for (casadi_int i=0;i<offset.size()-1;++i) {
327 casadi_int n = offset[i+1]-offset[i];
331 for (casadi_int i=0;i<offset.size()-1;++i) {
333 if (!knots.empty()) {
334 casadi_int m_left = margin_left.empty() ? 0 : margin_left[i];
335 casadi_int m_right = margin_right.empty() ? 0 : margin_right[i];
337 std::vector<double> grid(
338 knots.begin()+offset[i]+m_left,
339 knots.begin()+offset[i+1]-m_right);
351 s.
pack(
"Interpolant::m",
m_);
369 int version = s.
version(
"Interpolant", 1, 2);
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
void alloc_iw(size_t sz_iw, bool persistent=false)
Ensure required length of iw field.
void init(const Dict &opts) override
Initialize.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
static const Options options_
Options.
void serialize_type(SerializingStream &s) const override
Serialize type information.
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
static Function create(FunctionInternal *node)
Create from node.
std::string get_name_out(casadi_int i) override
Names of function input and outputs.
std::vector< std::string > lookup_modes_
static const Options options_
Options.
~Interpolant() override
Destructor.
static std::vector< double > meshgrid(const std::vector< std::vector< double > > &grid)
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
casadi_int coeff_size() const
Size of the flattened coefficients vector.
void init(const Dict &opts) override
Initialize.
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
std::vector< casadi_int > offset_
static void check_grid(const std::vector< std::vector< double > > &grid)
static Function construct(const std::string &solver, const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m, const Dict &opts)
Construct a new Interpolant.
std::vector< double > grid_
Interpolant(const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m)
Constructor.
bool has_parametric_grid() const
Is parametric?
static void stack_grid(const std::vector< std::vector< double > > &grid, std::vector< casadi_int > &offset, std::vector< double > &stacked)
std::vector< double > values_
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
casadi_int arg_values() const
casadi_int arg_grid() const
static const std::string infix_
Infix.
std::string get_name_in(casadi_int i) override
Names of function input and outputs.
bool has_parametric_values() const
Is parametric?
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void serialize_type(SerializingStream &s) const override
Serialize type information.
static std::vector< casadi_int > interpret_lookup_mode(const std::vector< std::string > &modes, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< casadi_int > &margin_left=std::vector< casadi_int >(), const std::vector< casadi_int > &margin_right=std::vector< casadi_int >())
Convert from (optional) lookup modes labels to enum.
static std::map< std::string, Plugin > solvers_
Collection of solvers.
static casadi_int interpret_lookup_mode(const std::string &lookup_mode, casadi_int n)
static bool has_plugin(const std::string &pname, bool verbose=false)
Check if a plugin is available or can be loaded.
void serialize_type(SerializingStream &s) const
Serialize type information.
static Plugin & getPlugin(const std::string &pname)
Load and get the creator function.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
static Plugin load_plugin(const std::string &pname, bool register_plugin=true, bool needs_lock=true)
Load a plugin dynamically.
Base class for FunctionInternal and LinsolInternal.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
static Sparsity dense(casadi_int nrow, casadi_int ncol=1)
Create a dense rectangular sparsity pattern.
Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts)
std::string doc_interpolant(const std::string &name)
Get the documentation string for a plugin.
void load_interpolant(const std::string &name)
Explicitly load a plugin dynamically.
bool has_interpolant(const std::string &name)
Check if a particular plugin is available.
bool is_equally_spaced(const std::vector< double > &v)
T product(const std::vector< T > &values)
product
bool is_increasing(const std::vector< T > &v)
Check if the vector is strictly increasing.
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
std::vector< T > cumsum0(const std::vector< T > &values)
cumulative sum, starting with zero
constexpr casadi_int LOOKUP_EXACT
bool is_regular(const std::vector< T > &v)
Checks if array does not contain NaN or Inf.
Dict extract_from_dict(const Dict &d, const std::string &key, T &value)
MX do_inline(const MX &x, const std::vector< std::vector< double > > &knots, const MX &coeffs, casadi_int m, const std::vector< casadi_int > &degree, const std::vector< casadi_int > &lookup_mode)
Options metadata for a class.