interpolant.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "interpolant_impl.hpp"
27 #include "casadi_misc.hpp"
28 #include "mx_node.hpp"
29 #include "casadi_low.hpp"
30 #include <typeinfo>
31 
32 namespace casadi {
33 
34  bool has_interpolant(const std::string& name) {
35  return Interpolant::has_plugin(name);
36  }
37 
38  void load_interpolant(const std::string& name) {
40  }
41 
42  std::string doc_interpolant(const std::string& name) {
43  return Interpolant::getPlugin(name).doc;
44  }
45 
46  void Interpolant::stack_grid(const std::vector< std::vector<double> >& grid,
47  std::vector<casadi_int>& offset, std::vector<double>& stacked) {
48 
49  // Get offset for each input dimension
50  offset.clear();
51  offset.reserve(grid.size()+1);
52  offset.push_back(0);
53  for (auto&& g : grid) offset.push_back(offset.back()+g.size());
54 
55  // Stack input grids
56  stacked.clear();
57  stacked.reserve(offset.back());
58  for (auto&& g : grid) stacked.insert(stacked.end(), g.begin(), g.end());
59  }
60 
61  void Interpolant::check_grid(const std::vector< std::vector<double> >& grid) {
62  // Dimension at least 1
63  casadi_assert(!grid.empty(), "At least one input required");
64 
65  // Grid must be strictly increasing
66  for (auto&& g : grid) {
67  casadi_assert(is_increasing(g), "Gridpoints must be strictly increasing");
68  casadi_assert(is_regular(g), "Gridpoints must be regular");
69  casadi_assert(g.size()>=2, "Need at least two grid points for every input");
70  }
71  }
72 
73  void Interpolant::check_grid(const std::vector<casadi_int> & grid_dims) {
74  // Dimension at least 1
75  casadi_assert(!grid_dims.empty(), "At least one dimension required");
76 
77  // Grid must be strictly increasing
78  for (casadi_int d : grid_dims) {
79  casadi_assert(d>=2, "Need at least two grid points for every input");
80  }
81  }
82 
83  std::vector<double> Interpolant::meshgrid(const std::vector< std::vector<double> >& grid) {
84  std::vector<casadi_int> cnts(grid.size()+1, 0);
85  std::vector<casadi_int> sizes(grid.size(), 0);
86  for (casadi_int k=0;k<grid.size();++k) sizes[k]= grid[k].size();
87 
88  casadi_int total_iter = 1;
89  for (casadi_int k=0;k<grid.size();++k) total_iter*= sizes[k];
90 
91  casadi_int n_dims = grid.size();
92 
93  std::vector<double> ret(total_iter*n_dims);
94  for (casadi_int i=0;i<total_iter;++i) {
95 
96  for (casadi_int j=0;j<grid.size();++j) {
97  ret[i*n_dims+j] = grid[j][cnts[j]];
98  }
99 
100  cnts[0]++;
101  casadi_int j = 0;
102  while (j<n_dims && cnts[j]==sizes[j]) {
103  cnts[j] = 0;
104  j++;
105  cnts[j]++;
106  }
107 
108  }
109 
110  return ret;
111  }
112 
113  casadi_int Interpolant::coeff_size() const {
114  return coeff_size(offset_, m_);
115  }
116 
117  casadi_int Interpolant::coeff_size(const std::vector<casadi_int>& offset, casadi_int m) {
118  casadi_int ret = 1;
119  for (casadi_int k=0;k<offset.size()-1;++k) {
120  ret *= offset[k+1]-offset[k];
121  }
122  return m*ret;
123  }
124 
125  Function interpolant(const std::string& name,
126  const std::string& solver,
127  const std::vector<std::vector<double> >& grid,
128  const std::vector<double>& values,
129  const Dict& opts) {
131  // Get offset for each input dimension
132  std::vector<casadi_int> offset;
133  // Stack input grids
134  std::vector<double> stacked;
135 
136  // Consistency check, number of elements
137  casadi_uint nel=1;
138  for (auto&& g : grid) nel *= g.size();
139  casadi_assert(values.size() % nel== 0,
140  "Inconsistent number of elements. Must be a multiple of " +
141  str(nel) + ", but got " + str(values.size()) + " instead.");
142 
143  Interpolant::stack_grid(grid, offset, stacked);
144 
145  casadi_int m = values.size()/nel;
146  return Interpolant::construct(solver, name, stacked, offset, values, m, opts);
147  }
148 
149  Function Interpolant::construct(const std::string& solver,
150  const std::string& name,
151  const std::vector<double>& grid,
152  const std::vector<casadi_int>& offset,
153  const std::vector<double>& values,
154  casadi_int m,
155  const Dict& opts) {
156  bool do_inline = false;
157  Dict options = extract_from_dict(opts, "inline", do_inline);
158  if (do_inline && !Interpolant::getPlugin(solver).exposed.do_inline) {
159  options["inline"] = true;
160  do_inline = false;
161  }
162  if (do_inline && Interpolant::getPlugin(solver).exposed.do_inline) {
163  return Interpolant::getPlugin(solver).exposed.
164  do_inline(name, grid, offset, values, m, options);
165  } else {
167  .creator(name, grid, offset, values, m), options);
168  }
169  }
170 
171  Function interpolant(const std::string& name,
172  const std::string& solver,
173  const std::vector<casadi_int>& grid_dims,
174  const std::vector<double>& values,
175  const Dict& opts) {
176  Interpolant::check_grid(grid_dims);
177 
178  // Consistency check, number of elements
179  casadi_uint nel = product(grid_dims);
180  casadi_assert(values.size() % nel== 0,
181  "Inconsistent number of elements. Must be a multiple of " +
182  str(nel) + ", but got " + str(values.size()) + " instead.");
183 
184  casadi_int m = values.size()/nel;
185  return Interpolant::construct(solver, name, std::vector<double>{},
186  cumsum0(grid_dims), values, m, opts);
187  }
188 
189  Function interpolant(const std::string& name,
190  const std::string& solver,
191  const std::vector<std::vector<double> >& grid,
192  casadi_int m,
193  const Dict& opts) {
195 
196  // Get offset for each input dimension
197  std::vector<casadi_int> offset;
198  // Stack input grids
199  std::vector<double> stacked;
200 
201  Interpolant::stack_grid(grid, offset, stacked);
202  return Interpolant::construct(solver, name, stacked, offset, std::vector<double>{}, m, opts);
203  }
204 
205  Function interpolant(const std::string& name,
206  const std::string& solver,
207  const std::vector<casadi_int>& grid_dims,
208  casadi_int m,
209  const Dict& opts) {
210  Interpolant::check_grid(grid_dims);
211  return Interpolant::construct(solver, name, std::vector<double>{},
212  cumsum0(grid_dims), std::vector<double>{}, m, opts);
213  }
214 
216  Interpolant(const std::string& name,
217  const std::vector<double>& grid,
218  const std::vector<casadi_int>& offset,
219  const std::vector<double>& values,
220  casadi_int m)
221  : FunctionInternal(name), m_(m), grid_(grid), offset_(offset), values_(values) {
222  // Number of grid points
223  ndim_ = offset_.size()-1;
224  }
225 
227  }
228 
230  if (i==0) return Sparsity::dense(ndim_, batch_x_);
231  if (arg_values(i)) return Sparsity::dense(coeff_size());
232  if (arg_grid(i)) return Sparsity::dense(offset_.back());
233  casadi_assert_dev(false);
234  }
235 
237  casadi_assert_dev(i==0);
238  return Sparsity::dense(m_, batch_x_);
239  }
240 
241  std::string Interpolant::get_name_in(casadi_int i) {
242  if (i==0) return "x";
243  if (arg_values(i)) return "c";
244  if (arg_grid(i)) return "g";
245  casadi_assert_dev(false);
246  }
247 
248  std::string Interpolant::get_name_out(casadi_int i) {
249  casadi_assert_dev(i==0);
250  return "f";
251  }
252 
253  std::map<std::string, Interpolant::Plugin> Interpolant::solvers_;
254 
255 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
256  std::mutex Interpolant::mutex_solvers_;
257 #endif // CASADI_WITH_THREADSAFE_SYMBOLICS
258 
259  const std::string Interpolant::infix_ = "interpolant";
260 
263  {{"lookup_mode",
265  "Specifies, for each grid dimension, the lookup algorithm used to find the correct index. "
266  "'linear' uses a for-loop + break; (default when #knots<=100), "
267  "'exact' uses floored division (only for uniform grids), "
268  "'binary' uses a binary search. (default when #knots>100)."}},
269  {"inline",
270  {OT_BOOL,
271  "Implement the lookup table in MX primitives. "
272  "Useful when you need derivatives with respect to grid and/or coefficients. "
273  "Such derivatives are fundamentally dense, so use with caution."}},
274  {"batch_x",
275  {OT_INT,
276  "Evaluate a batch of different inputs at once (default 1)."}}
277  }
278  };
279 
280  bool Interpolant::arg_values(casadi_int i) const {
281  if (!has_parametric_values()) return false;
282  return arg_values()==i;
283  }
284  bool Interpolant::arg_grid(casadi_int i) const {
285  if (!has_parametric_grid()) return false;
286  return arg_grid()==i;
287  }
288 
289  casadi_int Interpolant::arg_values() const {
290  casadi_assert_dev(has_parametric_values());
291  return 1+has_parametric_grid();
292  }
293  casadi_int Interpolant::arg_grid() const {
294  casadi_assert_dev(has_parametric_grid());
295  return 1;
296  }
297 
298  void Interpolant::init(const Dict& opts) {
299 
300  batch_x_ = 1;
301 
302  // Read options
303  for (auto&& op : opts) {
304  if (op.first=="lookup_mode") {
305  lookup_modes_ = op.second;
306  } else if (op.first=="batch_x") {
307  batch_x_ = op.second;
308  }
309  }
310 
311  // Call the base class initializer
313 
314  // Needed by casadi_interpn
315  alloc_w(ndim_, true);
316  alloc_iw(2*ndim_, true);
317  }
318 
319  std::vector<casadi_int> Interpolant::interpret_lookup_mode(
320  const std::vector<std::string>& modes, const std::vector<double>& knots,
321  const std::vector<casadi_int>& offset,
322  const std::vector<casadi_int>& margin_left, const std::vector<casadi_int>& margin_right) {
323  casadi_assert_dev(modes.empty() || modes.size()==offset.size()-1);
324 
325  std::vector<casadi_int> ret;
326  for (casadi_int i=0;i<offset.size()-1;++i) {
327  casadi_int n = offset[i+1]-offset[i];
328  ret.push_back(Low::interpret_lookup_mode(modes.empty() ? "auto": modes[i], n));
329  }
330 
331  for (casadi_int i=0;i<offset.size()-1;++i) {
332  if (ret[i]==LOOKUP_EXACT) {
333  if (!knots.empty()) {
334  casadi_int m_left = margin_left.empty() ? 0 : margin_left[i];
335  casadi_int m_right = margin_right.empty() ? 0 : margin_right[i];
336 
337  std::vector<double> grid(
338  knots.begin()+offset[i]+m_left,
339  knots.begin()+offset[i+1]-m_right);
340  casadi_assert_dev(is_increasing(grid) && is_equally_spaced(grid));
341  }
342  }
343  }
344  return ret;
345  }
346 
349  s.version("Interpolant", 2);
350  s.pack("Interpolant::ndim", ndim_);
351  s.pack("Interpolant::m", m_);
352  s.pack("Interpolant::grid", grid_);
353  s.pack("Interpolant::offset", offset_);
354  s.pack("Interpolant::values", values_);
355  s.pack("Interpolant::lookup_modes", lookup_modes_);
356  s.pack("Interpolant::batch_x", batch_x_);
357  }
358 
362  }
363 
366  }
367 
369  int version = s.version("Interpolant", 1, 2);
370  s.unpack("Interpolant::ndim", ndim_);
371  s.unpack("Interpolant::m", m_);
372  s.unpack("Interpolant::grid", grid_);
373  s.unpack("Interpolant::offset", offset_);
374  s.unpack("Interpolant::values", values_);
375  s.unpack("Interpolant::lookup_modes", lookup_modes_);
376  if (version==1) {
377  batch_x_ = 1;
378  } else {
379  s.unpack("Interpolant::batch_x", batch_x_);
380  }
381  }
382 
383 } // namespace casadi
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
Internal class for Function.
void alloc_iw(size_t sz_iw, bool persistent=false)
Ensure required length of iw field.
void init(const Dict &opts) override
Initialize.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
static const Options options_
Options.
void serialize_type(SerializingStream &s) const override
Serialize type information.
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
Function object.
Definition: function.hpp:60
static Function create(FunctionInternal *node)
Create from node.
Definition: function.cpp:336
std::string get_name_out(casadi_int i) override
Names of function input and outputs.
std::vector< std::string > lookup_modes_
static const Options options_
Options.
~Interpolant() override
Destructor.
static std::vector< double > meshgrid(const std::vector< std::vector< double > > &grid)
Definition: interpolant.cpp:83
Sparsity get_sparsity_out(casadi_int i) override
Sparsities of function inputs and outputs.
casadi_int coeff_size() const
Size of the flattened coefficients vector.
void init(const Dict &opts) override
Initialize.
Sparsity get_sparsity_in(casadi_int i) override
Sparsities of function inputs and outputs.
std::vector< casadi_int > offset_
static void check_grid(const std::vector< std::vector< double > > &grid)
Definition: interpolant.cpp:61
static Function construct(const std::string &solver, const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m, const Dict &opts)
Construct a new Interpolant.
std::vector< double > grid_
Interpolant(const std::string &name, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< double > &values, casadi_int m)
Constructor.
bool has_parametric_grid() const
Is parametric?
static void stack_grid(const std::vector< std::vector< double > > &grid, std::vector< casadi_int > &offset, std::vector< double > &stacked)
Definition: interpolant.cpp:46
std::vector< double > values_
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
casadi_int arg_values() const
casadi_int arg_grid() const
static const std::string infix_
Infix.
std::string get_name_in(casadi_int i) override
Names of function input and outputs.
bool has_parametric_values() const
Is parametric?
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
void serialize_type(SerializingStream &s) const override
Serialize type information.
static std::vector< casadi_int > interpret_lookup_mode(const std::vector< std::string > &modes, const std::vector< double > &grid, const std::vector< casadi_int > &offset, const std::vector< casadi_int > &margin_left=std::vector< casadi_int >(), const std::vector< casadi_int > &margin_right=std::vector< casadi_int >())
Convert from (optional) lookup modes labels to enum.
static std::map< std::string, Plugin > solvers_
Collection of solvers.
static casadi_int interpret_lookup_mode(const std::string &lookup_mode, casadi_int n)
Definition: casadi_low.cpp:61
static bool has_plugin(const std::string &pname, bool verbose=false)
Check if a plugin is available or can be loaded.
void serialize_type(SerializingStream &s) const
Serialize type information.
static Plugin & getPlugin(const std::string &pname)
Load and get the creator function.
static ProtoFunction * deserialize(DeserializingStream &s)
Deserialize with type disambiguation.
static Plugin load_plugin(const std::string &pname, bool register_plugin=true, bool needs_lock=true)
Load a plugin dynamically.
Base class for FunctionInternal and LinsolInternal.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
General sparsity class.
Definition: sparsity.hpp:106
static Sparsity dense(casadi_int nrow, casadi_int ncol=1)
Create a dense rectangular sparsity pattern.
Definition: sparsity.cpp:1012
Function interpolant(const std::string &name, const std::string &solver, const std::vector< std::vector< double > > &grid, const std::vector< double > &values, const Dict &opts)
std::string doc_interpolant(const std::string &name)
Get the documentation string for a plugin.
Definition: interpolant.cpp:42
void load_interpolant(const std::string &name)
Explicitly load a plugin dynamically.
Definition: interpolant.cpp:38
bool has_interpolant(const std::string &name)
Check if a particular plugin is available.
Definition: interpolant.cpp:34
The casadi namespace.
Definition: archiver.cpp:28
bool is_equally_spaced(const std::vector< double > &v)
T product(const std::vector< T > &values)
product
bool is_increasing(const std::vector< T > &v)
Check if the vector is strictly increasing.
@ OT_STRINGVECTOR
std::string str(const T &v)
String representation, any type.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
std::vector< T > cumsum0(const std::vector< T > &values)
cumulative sum, starting with zero
constexpr casadi_int LOOKUP_EXACT
bool is_regular(const std::vector< T > &v)
Checks if array does not contain NaN or Inf.
Dict extract_from_dict(const Dict &d, const std::string &key, T &value)
MX do_inline(const MX &x, const std::vector< std::vector< double > > &knots, const MX &coeffs, casadi_int m, const std::vector< casadi_int > &degree, const std::vector< casadi_int > &lookup_mode)
Definition: bspline.cpp:211
Options metadata for a class.
Definition: options.hpp:40