sundials_interface.cpp
1 /*
2  * This file is part of CasADi.
3  *
4  * CasADi -- A symbolic framework for dynamic optimization.
5  * Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl,
6  * KU Leuven. All rights reserved.
7  * Copyright (C) 2011-2014 Greg Horn
8  *
9  * CasADi is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 3 of the License, or (at your option) any later version.
13  *
14  * CasADi is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with CasADi; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  *
23  */
24 
25 
26 #include "sundials_interface.hpp"
27 
28 #include "casadi/core/casadi_misc.hpp"
29 
30 INPUTSCHEME(IntegratorInput)
31 OUTPUTSCHEME(IntegratorOutput)
32 
33 namespace casadi {
34 
35 SundialsInterface::SundialsInterface(const std::string& name, const Function& dae,
36  double t0, const std::vector<double>& tout)
37  : Integrator(name, dae, t0, tout) {
38 }
39 
41 }
42 
45  {{"max_num_steps",
46  {OT_INT,
47  "Maximum number of integrator steps"}},
48  {"reltol",
49  {OT_DOUBLE,
50  "Relative tolerence for the IVP solution"}},
51  {"abstol",
52  {OT_DOUBLE,
53  "Absolute tolerence for the IVP solution"}},
54  {"newton_scheme",
55  {OT_STRING,
56  "Linear solver scheme in the Newton method: DIRECT|gmres|bcgstab|tfqmr"}},
57  {"max_krylov",
58  {OT_INT,
59  "Maximum Krylov subspace size"}},
60  {"sensitivity_method",
61  {OT_STRING,
62  "Sensitivity method: SIMULTANEOUS|staggered"}},
63  {"max_multistep_order",
64  {OT_INT,
65  "Maximum order for the (variable-order) multistep method"}},
66  {"use_preconditioner",
67  {OT_BOOL,
68  "Precondition the iterative solver [default: true]"}},
69  {"stop_at_end",
70  {OT_BOOL,
71  "[DEPRECATED] Stop the integrator at the end of the interval"}},
72  {"disable_internal_warnings",
73  {OT_BOOL,
74  "Disable SUNDIALS internal warning messages"}},
75  {"quad_err_con",
76  {OT_BOOL,
77  "Should the quadratures affect the step size control"}},
78  {"fsens_err_con",
79  {OT_BOOL,
80  "include the forward sensitivities in all error controls"}},
81  {"steps_per_checkpoint",
82  {OT_INT,
83  "Number of steps between two consecutive checkpoints"}},
84  {"interpolation_type",
85  {OT_STRING,
86  "Type of interpolation for the adjoint sensitivities"}},
87  {"linear_solver",
88  {OT_STRING,
89  "A custom linear solver creator function [default: qr]"}},
90  {"linear_solver_options",
91  {OT_DICT,
92  "Options to be passed to the linear solver"}},
93  {"second_order_correction",
94  {OT_BOOL,
95  "Second order correction in the augmented system Jacobian [true]"}},
96  {"step0",
97  {OT_DOUBLE,
98  "initial step size [default: 0/estimated]"}},
99  {"max_step_size",
100  {OT_DOUBLE,
101  "Max step size [default: 0/inf]"}},
102  {"max_order",
103  {OT_DOUBLE,
104  "Maximum order"}},
105  {"nonlin_conv_coeff",
106  {OT_DOUBLE,
107  "Coefficient in the nonlinear convergence test"}},
108  {"scale_abstol",
109  {OT_BOOL,
110  "Scale absolute tolerance by nominal value"}}
111  }
112 };
113 
114 void SundialsInterface::init(const Dict& opts) {
115  // Call the base class method
116  Integrator::init(opts);
117 
118  // Default options
119  abstol_ = 1e-8;
120  reltol_ = 1e-6;
121  max_num_steps_ = 10000;
122  stop_at_end_ = true;
123  use_precon_ = true;
124  max_krylov_ = 10;
125  linear_solver_ = "qr";
126  std::string newton_scheme = "direct";
127  quad_err_con_ = false;
128  std::string interpolation_type = "hermite";
133  step0_ = 0;
134  max_step_size_ = 0;
135  max_order_ = 0;
136  nonlin_conv_coeff_ = 0;
137  scale_abstol_ = false;
138 
139  // Read options
140  for (auto&& op : opts) {
141  if (op.first=="abstol") {
142  abstol_ = op.second;
143  } else if (op.first=="reltol") {
144  reltol_ = op.second;
145  } else if (op.first=="max_num_steps") {
146  max_num_steps_ = op.second;
147  } else if (op.first=="stop_at_end") {
148  stop_at_end_ = op.second;
149  if (!stop_at_end_) {
150  casadi_warning("The 'stop_at_end' option has been deprecated and is currently ignored");
151  }
152  } else if (op.first=="use_preconditioner") {
153  use_precon_ = op.second;
154  } else if (op.first=="max_krylov") {
155  max_krylov_ = op.second;
156  } else if (op.first=="newton_scheme") {
157  newton_scheme = op.second.to_string();
158  } else if (op.first=="linear_solver") {
159  linear_solver_ = op.second.to_string();
160  } else if (op.first=="linear_solver_options") {
161  linear_solver_options_ = op.second;
162  } else if (op.first=="quad_err_con") {
163  quad_err_con_ = op.second;
164  } else if (op.first=="interpolation_type") {
165  interpolation_type = op.second.to_string();
166  } else if (op.first=="steps_per_checkpoint") {
167  steps_per_checkpoint_ = op.second;
168  } else if (op.first=="disable_internal_warnings") {
169  disable_internal_warnings_ = op.second;
170  } else if (op.first=="max_multistep_order") {
171  max_multistep_order_ = op.second;
172  } else if (op.first=="second_order_correction") {
173  second_order_correction_ = op.second;
174  } else if (op.first=="step0") {
175  step0_ = op.second;
176  } else if (op.first=="max_step_size") {
177  max_step_size_ = op.second;
178  } else if (op.first=="max_order") {
179  max_order_ = op.second;
180  } else if (op.first=="nonlin_conv_coeff") {
181  nonlin_conv_coeff_ = op.second;
182  } else if (op.first=="scale_abstol") {
183  scale_abstol_ = op.second;
184  }
185  }
186 
187  // Type of Newton scheme
188  if (newton_scheme=="direct") {
190  } else if (newton_scheme=="gmres") {
192  } else if (newton_scheme=="bcgstab") {
194  } else if (newton_scheme=="tfqmr") {
196  } else {
197  casadi_error("Unknown Newton scheme: " + newton_scheme);
198  }
199 
200  // Interpolation_type
201  if (interpolation_type=="hermite") {
203  } else if (interpolation_type=="polynomial") {
205  } else {
206  casadi_error("Unknown interpolation type: " + interpolation_type);
207  }
208 
209  // If derivative, use Jacobian from non-augmented system if possible
210  SundialsInterface* d = 0;
211  if (nfwd_ > 0 && !derivative_of_.is_null()) {
213  casadi_assert_dev(d != nullptr);
214  }
215 
216  // Get Jacobian function, forward problem
217  Function jacF;
218  Sparsity jacF_sp;
219  if (d == 0) {
220  // New Jacobian function
221  jacF = create_function("jacF", {"t", "x", "z", "p", "u"},
222  {"jac:ode:x", "jac:alg:x", "jac:ode:z", "jac:alg:z"});
223  jacF_sp = jacF.sparsity_out(JACF_ODE_X) + Sparsity::diag(nx1_);
224  if (nz_ > 0) {
225  jacF_sp = horzcat(vertcat(jacF_sp, jacF.sparsity_out(JACF_ALG_X)),
226  vertcat(jacF.sparsity_out(JACF_ODE_Z), jacF.sparsity_out(JACF_ALG_Z)));
227  }
228  } else {
229  // Reuse existing Jacobian function
230  jacF = d->get_function("jacF");
231  set_function(jacF, jacF.name(), true);
232  linsolF_ = d->linsolF_;
233  jacF_sp = linsolF_.sparsity();
234  }
235  alloc_w(jacF_sp.nnz(), true); // jacF
236 
237  // Linear solver for forward problem
238  if (linsolF_.is_null()) {
239  linsolF_ = Linsol("linsolF", linear_solver_, jacF_sp, linear_solver_options_);
240  }
241 
242  // Attach functions to calculate DAE and quadrature RHS all-at-once
243  if (nfwd_ > 0) {
244  create_forward("daeF", nfwd_);
245  if (nq_ > 0) create_forward("quadF", nfwd_);
246  if (nadj_ > 0) {
247  create_forward("daeB", nfwd_);
248  if (nrq_ > 0 || nuq_ > 0) create_forward("quadB", nfwd_);
249  }
250  }
251 
252  // Attach functions for jacobian information, foward problem
254  create_function("jtimesF", {"t", "x", "z", "p", "u", "fwd:x", "fwd:z"},
255  {"fwd:ode", "fwd:alg"});
256  if (nfwd_ > 0) {
257  create_forward("jtimesF", nfwd_);
258  }
259  }
260 
261 
262  // For Jacobian calculation
263  alloc_w(jacF.nnz_out(JACF_ODE_X), true); // jac_ode_x
264  alloc_w(jacF.nnz_out(JACF_ALG_X), true); // jac_alg_x
265  alloc_w(jacF.nnz_out(JACF_ODE_Z), true); // jac_ode_z
266  alloc_w(jacF.nnz_out(JACF_ALG_Z), true); // jac_alg_z
267 
268  // Transposing the Jacobian (for calculating jacB)
269  // This will be unnecessary once linsolF_ is used for both forward and adjoint
270  // cf. #3047
271  alloc_w(nx_ + nz_); // casadi_trans
272  alloc_iw(nx_ + nz_); // casadi_trans
273 }
274 
275 void SundialsInterface::set_work(void* mem, const double**& arg, double**& res,
276  casadi_int*& iw, double*& w) const {
277  auto m = static_cast<SundialsMemory*>(mem);
278 
279  // Set work in base classes
280  Integrator::set_work(mem, arg, res, iw, w);
281 
282  // Work vectors
283  m->jacF = w; w += linsolF_.sparsity().nnz();
284 
285  // Work vectors
286  const Function& jacF = get_function("jacF");
287  m->jac_ode_x = w; w += jacF.nnz_out(JACF_ODE_X);
288  m->jac_alg_x = w; w += jacF.nnz_out(JACF_ALG_X);
289  m->jac_ode_z = w; w += jacF.nnz_out(JACF_ODE_Z);
290  m->jac_alg_z = w; w += jacF.nnz_out(JACF_ALG_Z);
291 }
292 
293 int SundialsInterface::init_mem(void* mem) const {
294  if (Integrator::init_mem(mem)) return 1;
295  auto m = static_cast<SundialsMemory*>(mem);
296 
297  // Allocate NVectors
298  m->v_xz = N_VNew_Serial(nx_ + nz_);
299  m->v_q = N_VNew_Serial(nq_);
300  m->v_adj_xz = N_VNew_Serial(nrx_ + nrz_);
301  m->v_adj_pu = N_VNew_Serial(nrq_ + nuq_);
302 
303  // Absolute tolerances as NVector
304  if (scale_abstol_) {
305  // Allocate NVector
306  m->abstolv = N_VNew_Serial(nx_ + nz_);
307  // Get pointer to data
308  double* abstolv = NV_DATA_S(m->abstolv);
309  // States
310  for (casadi_int d = 0; d <= nfwd_; ++d) {
311  for (casadi_int i = 0; i < nx1_; ++i) *abstolv++ = abstol_ * nom_x_[i];
312  }
313  // Algebraic variables
314  for (casadi_int d = 0; d <= nfwd_; ++d) {
315  for (casadi_int i = 0; i < nz1_; ++i) *abstolv++ = abstol_ * nom_z_[i];
316  }
317  // Consistency check
318  casadi_assert_dev(abstolv == NV_DATA_S(m->abstolv) + nx_ + nz_);
319  } else {
320  m->abstolv = nullptr;
321  }
322 
323  m->mem_linsolF = linsolF_.checkout();
324 
325  // Reset stats
326  reset_stats(m);
327 
328  return 0;
329 }
330 
331 void SundialsInterface::reset(IntegratorMemory* mem, bool first_call) const {
332  auto m = static_cast<SundialsMemory*>(mem);
333 
334  // Reset the base classes
335  Integrator::reset(mem, first_call);
336 
337  // Reset stats
338  if (first_call) reset_stats(m);
339 
340  // Set the state
341  casadi_copy(m->q, nq_, NV_DATA_S(m->v_q));
342  casadi_copy(m->x, nx_, NV_DATA_S(m->v_xz));
343  casadi_copy(m->z, nz_, NV_DATA_S(m->v_xz) + nx_);
344 }
345 
347  // Reset stats, forward problem
348  m->nsteps = m->nfevals = m->nlinsetups = m->netfails = 0;
349  m->qlast = m->qcur = -1;
350  m->tcur = t0_;
351  m->hinused = m->hlast = m->hcur = casadi::nan;
352  m->nniters = m->nncfails = 0;
353 
354  // Reset stats, backward problem
355  m->nstepsB = m->nfevalsB = m->nlinsetupsB = m->netfailsB = 0;
356  m->qlastB = m->qcurB = -1;
357  m->hinusedB = m->hlastB = m->hcurB = m->tcurB = casadi::nan;
358  m->nnitersB = m->nncfailsB = 0;
359 
360  // Set offsets to zero
361  save_offsets(m);
362 }
363 
365  // Retrieve stats offset, backward problem
366  m->nstepsB_off = m->nstepsB;
367  m->nfevalsB_off = m->nfevalsB;
369  m->netfailsB_off = m->netfailsB;
370  m->nnitersB_off = m->nnitersB;
371  m->nncfailsB_off = m->nncfailsB;
372 }
373 
375  // Add stats offsets, backward problem
376  m->nstepsB += m->nstepsB_off;
377  m->nfevalsB += m->nfevalsB_off;
378  m->nlinsetupsB += m->nlinsetupsB_off;
379  m->netfailsB += m->netfailsB_off;
380  m->nnitersB += m->nnitersB_off;
381  m->nncfailsB += m->nncfailsB_off;
382 
383 }
384 
386  auto m = static_cast<SundialsMemory*>(mem);
387 
388  // Clear seeds
389  casadi_clear(m->adj_q, nrp_);
390  casadi_clear(NV_DATA_S(m->v_adj_xz), nrx_ + nrz_);
391 
392  // Reset summation states
393  N_VConst(0., m->v_adj_pu);
394 }
395 
397  const double* adj_x, const double* adj_z, const double* adj_q) const {
398  auto m = static_cast<SundialsMemory*>(mem);
399 
400  // Add impulse to backward parameters
401  casadi_axpy(nrp_, 1., adj_q, m->adj_q);
402 
403  // Add impulse to backward state
404  casadi_axpy(nrx_, 1., adj_x, NV_DATA_S(m->v_adj_xz));
405 
406  // Add impulse to algebraic variables:
407  // If nonzero, this has to be propagated to an impulse in backward state
408  // casadi_copy(adj_z, nrz_, NV_DATA_S(m->v_adj_xz) + nrx_);
409  casadi_axpy(nrz_, 1., adj_z, NV_DATA_S(m->v_adj_xz) + nrx_);
410 }
411 
413  this->v_xz = nullptr;
414  this->v_q = nullptr;
415  this->v_adj_xz = nullptr;
416  this->v_adj_pu = nullptr;
417  this->first_callB = true;
418  this->abstolv = nullptr;
419  this->mem_linsolF = -1;
420 }
421 
423  if (this->v_xz) N_VDestroy_Serial(this->v_xz);
424  if (this->v_q) N_VDestroy_Serial(this->v_q);
425  if (this->v_adj_xz) N_VDestroy_Serial(this->v_adj_xz);
426  if (this->v_adj_pu) N_VDestroy_Serial(this->v_adj_pu);
427  if (this->abstolv) N_VDestroy_Serial(this->abstolv);
428 }
429 
431  Dict stats = Integrator::get_stats(mem);
432  auto m = static_cast<SundialsMemory*>(mem);
433 
434  // Counters, forward problem
435  stats["nsteps"] = static_cast<casadi_int>(m->nsteps);
436  stats["nfevals"] = static_cast<casadi_int>(m->nfevals);
437  stats["nlinsetups"] = static_cast<casadi_int>(m->nlinsetups);
438  stats["netfails"] = static_cast<casadi_int>(m->netfails);
439  stats["qlast"] = m->qlast;
440  stats["qcur"] = m->qcur;
441  stats["hinused"] = m->hinused;
442  stats["hlast"] = m->hlast;
443  stats["hcur"] = m->hcur;
444  stats["tcur"] = m->tcur;
445  stats["nniters"] = static_cast<casadi_int>(m->nniters);
446  stats["nncfails"] = static_cast<casadi_int>(m->nncfails);
447 
448  // Counters, backward problem
449  stats["nstepsB"] = static_cast<casadi_int>(m->nstepsB);
450  stats["nfevalsB"] = static_cast<casadi_int>(m->nfevalsB);
451  stats["nlinsetupsB"] = static_cast<casadi_int>(m->nlinsetupsB);
452  stats["netfailsB"] = static_cast<casadi_int>(m->netfailsB);
453  stats["qlastB"] = m->qlastB;
454  stats["qcurB"] = m->qcurB;
455  stats["hinusedB"] = m->hinusedB;
456  stats["hlastB"] = m->hlastB;
457  stats["hcurB"] = m->hcurB;
458  stats["tcurB"] = m->tcurB;
459  stats["nnitersB"] = static_cast<casadi_int>(m->nnitersB);
460  stats["nncfailsB"] = static_cast<casadi_int>(m->nncfailsB);
461  return stats;
462 }
463 
465  auto m = to_mem(mem);
466  print("FORWARD INTEGRATION:\n");
467  print("Number of steps taken by SUNDIALS: %ld\n", m->nsteps);
468  print("Number of calls to the user's f function: %ld\n", m->nfevals);
469  print("Number of calls made to the linear solver setup function: %ld\n", m->nlinsetups);
470  print("Number of error test failures: %ld\n", m->netfails);
471  print("Method order used on the last internal step: %d\n", m->qlast);
472  print("Method order to be used on the next internal step: %d\n", m->qcur);
473  print("Actual value of initial step size: %g\n", m->hinused);
474  print("Step size taken on the last internal step: %g\n", m->hlast);
475  print("Step size to be attempted on the next internal step: %g\n", m->hcur);
476  print("Current internal time reached: %g\n", m->tcur);
477  print("Number of nonlinear iterations performed: %ld\n", m->nniters);
478  print("Number of nonlinear convergence failures: %ld\n", m->nncfails);
479  if (nrx_>0) {
480  print("BACKWARD INTEGRATION:\n");
481  print("Number of steps taken by SUNDIALS: %ld\n", m->nstepsB);
482  print("Number of calls to the user's f function: %ld\n", m->nfevalsB);
483  print("Number of calls made to the linear solver setup function: %ld\n", m->nlinsetupsB);
484  print("Number of error test failures: %ld\n", m->netfailsB);
485  print("Method order used on the last internal step: %d\n" , m->qlastB);
486  print("Method order to be used on the next internal step: %d\n", m->qcurB);
487  print("Actual value of initial step size: %g\n", m->hinusedB);
488  print("Step size taken on the last internal step: %g\n", m->hlastB);
489  print("Step size to be attempted on the next internal step: %g\n", m->hcurB);
490  print("Current internal time reached: %g\n", m->tcurB);
491  print("Number of nonlinear iterations performed: %ld\n", m->nnitersB);
492  print("Number of nonlinear convergence failures: %ld\n", m->nncfailsB);
493  }
494  print("\n");
495 }
496 
498  int version = s.version("SundialsInterface", 1, 2);
499  s.unpack("SundialsInterface::abstol", abstol_);
500  s.unpack("SundialsInterface::reltol", reltol_);
501  s.unpack("SundialsInterface::max_num_steps", max_num_steps_);
502  s.unpack("SundialsInterface::stop_at_end", stop_at_end_);
503  s.unpack("SundialsInterface::quad_err_con", quad_err_con_);
504  s.unpack("SundialsInterface::steps_per_checkpoint", steps_per_checkpoint_);
505  s.unpack("SundialsInterface::disable_internal_warnings", disable_internal_warnings_);
506  s.unpack("SundialsInterface::max_multistep_order", max_multistep_order_);
507  s.unpack("SundialsInterface::linear_solver", linear_solver_);
508  s.unpack("SundialsInterface::linear_solver_options", linear_solver_options_);
509 
510  s.unpack("SundialsInterface::max_krylov", max_krylov_);
511  s.unpack("SundialsInterface::use_precon", use_precon_);
512  s.unpack("SundialsInterface::second_order_correction", second_order_correction_);
513 
514  s.unpack("SundialsInterface::step0", step0_);
515  if (version>=2) {
516  s.unpack("SundialsInterface::max_step_size", max_step_size_);
517  } else {
518  max_step_size_ = 0;
519  }
520 
521  s.unpack("SundialsInterface::nonlin_conv_coeff", nonlin_conv_coeff_);
522  s.unpack("SundialsInterface::max_order", max_order_);
523  s.unpack("SundialsInterface::scale_abstol", scale_abstol_);
524 
525  s.unpack("SundialsInterface::linsolF", linsolF_);
526 
527  int newton_scheme;
528  s.unpack("SundialsInterface::newton_scheme", newton_scheme);
529  newton_scheme_ = static_cast<NewtonScheme>(newton_scheme);
530 
531  int interp;
532  s.unpack("SundialsInterface::interp", interp);
533  interp_ = static_cast<InterpType>(interp);
534 
535 }
536 
539  s.version("SundialsInterface", 2);
540  s.pack("SundialsInterface::abstol", abstol_);
541  s.pack("SundialsInterface::reltol", reltol_);
542  s.pack("SundialsInterface::max_num_steps", max_num_steps_);
543  s.pack("SundialsInterface::stop_at_end", stop_at_end_);
544  s.pack("SundialsInterface::quad_err_con", quad_err_con_);
545  s.pack("SundialsInterface::steps_per_checkpoint", steps_per_checkpoint_);
546  s.pack("SundialsInterface::disable_internal_warnings", disable_internal_warnings_);
547  s.pack("SundialsInterface::max_multistep_order", max_multistep_order_);
548 
549  s.pack("SundialsInterface::linear_solver", linear_solver_);
550  s.pack("SundialsInterface::linear_solver_options", linear_solver_options_);
551  s.pack("SundialsInterface::max_krylov", max_krylov_);
552  s.pack("SundialsInterface::use_precon", use_precon_);
553  s.pack("SundialsInterface::second_order_correction", second_order_correction_);
554 
555  s.pack("SundialsInterface::step0", step0_);
556  s.pack("SundialsInterface::max_step_size", max_step_size_);
557 
558  s.pack("SundialsInterface::nonlin_conv_coeff", nonlin_conv_coeff_);
559  s.pack("SundialsInterface::max_order", max_order_);
560  s.pack("SundialsInterface::scale_abstol", scale_abstol_);
561 
562  s.pack("SundialsInterface::linsolF", linsolF_);
563 
564  s.pack("SundialsInterface::newton_scheme", static_cast<int>(newton_scheme_));
565  s.pack("SundialsInterface::interp", static_cast<int>(interp_));
566 }
567 
568 int SundialsInterface::calc_daeF(SundialsMemory* m, double t, const double* x, const double* z,
569  double* ode, double* alg) const {
570  // Evaluate nondifferentiated
571  m->arg[DYN_T] = &t; // t
572  m->arg[DYN_X] = x; // x
573  m->arg[DYN_Z] = z; // z
574  m->arg[DYN_P] = m->p; // p
575  m->arg[DYN_U] = m->u; // u
576  m->res[DAE_ODE] = ode; // ode
577  m->res[DAE_ALG] = alg; // alg
578  if (calc_function(m, "daeF")) return 1;
579  // Evaluate sensitivities
580  if (nfwd_ > 0) {
581  m->arg[DYN_NUM_IN + DAE_ODE] = ode; // out:ode
582  m->arg[DYN_NUM_IN + DAE_ALG] = alg; // out:alg
583  m->arg[DYN_NUM_IN + DAE_NUM_OUT + DYN_T] = 0; // fwd:t
584  m->arg[DYN_NUM_IN + DAE_NUM_OUT + DYN_X] = x + nx1_; // fwd:x
585  m->arg[DYN_NUM_IN + DAE_NUM_OUT + DYN_Z] = z ? z + nz1_ : 0; // fwd:z
586  m->arg[DYN_NUM_IN + DAE_NUM_OUT + DYN_P] = m->p + np1_; // fwd:p
587  m->arg[DYN_NUM_IN + DAE_NUM_OUT + DYN_U] = m->u + nu1_; // fwd:u
588  m->res[DAE_ODE] = ode ? ode + nx1_ : 0; // fwd:ode
589  m->res[DAE_ALG] = alg ? alg + nz1_ : 0; // fwd:alg
590  if (calc_function(m, forward_name("daeF", nfwd_))) return 1;
591  }
592  return 0;
593 }
594 
595 int SundialsInterface::calc_daeB(SundialsMemory* m, double t, const double* x, const double* z,
596  const double* adj_ode, const double* adj_alg, const double* adj_quad,
597  double* adj_x, double* adj_z) const {
598  // Evaluate nondifferentiated
599  m->arg[BDYN_T] = &t; // t
600  m->arg[BDYN_X] = x; // x
601  m->arg[BDYN_Z] = z; // z
602  m->arg[BDYN_P] = m->p; // p
603  m->arg[BDYN_U] = m->u; // u
604  m->arg[BDYN_OUT_ODE] = nullptr; // out_ode
605  m->arg[BDYN_OUT_ALG] = nullptr; // out_alg
606  m->arg[BDYN_OUT_QUAD] = nullptr; // out_quad
607  m->arg[BDYN_OUT_ZERO] = nullptr; // out_zero
608  m->arg[BDYN_ADJ_ODE] = adj_ode; // adj_ode
609  m->arg[BDYN_ADJ_ALG] = adj_alg; // adj_alg
610  m->arg[BDYN_ADJ_QUAD] = adj_quad; // adj_quad
611  m->arg[BDYN_ADJ_ZERO] = nullptr; // adj_zero
612  m->res[BDAE_ADJ_X] = adj_x; // adj_x
613  m->res[BDAE_ADJ_Z] = adj_z; // adj_z
614  if (calc_function(m, "daeB")) return 1;
615  // Evaluate sensitivities
616  if (nfwd_ > 0) {
617  m->arg[BDYN_NUM_IN + BDAE_ADJ_X] = adj_x; // out:adj_x
618  m->arg[BDYN_NUM_IN + BDAE_ADJ_Z] = adj_z; // out:adj_z
619  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_T] = 0; // fwd:t
620  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_X] = x ? x + nx1_ : x; // fwd:x
621  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_Z] = z ? z + nz1_ : z; // fwd:z
622  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_P] = m->p + np1_; // fwd:p
623  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_U] = m->u + nu1_; // fwd:u
624  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_OUT_ODE] = nullptr; // fwd:out_ode
625  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_OUT_ALG] = nullptr; // fwd:out_alg
626  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_OUT_QUAD] = nullptr; // fwd:out_quad
627  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_OUT_ZERO] = nullptr; // fwd:out_zero
629  adj_ode ? adj_ode + nrx1_ * nadj_ : 0; // fwd:adj_ode
631  adj_alg ? adj_alg + nrz1_ * nadj_ : 0; // fwd:adj_alg
633  adj_quad ? adj_quad + nrp1_ * nadj_ : 0; // fwd:adj_quad
634  m->arg[BDYN_NUM_IN + BDAE_NUM_OUT + BDYN_ADJ_ZERO] = nullptr; // fwd:adj_zero
635  m->res[BDAE_ADJ_X] = adj_x ? adj_x + nrx1_ * nadj_ : 0; // fwd:adj_x
636  m->res[BDAE_ADJ_Z] = adj_z ? adj_z + nrz1_ * nadj_ : 0; // fwd:adj_z
637  if (calc_function(m, forward_name("daeB", nfwd_))) return 1;
638  }
639  return 0;
640 }
641 
642 int SundialsInterface::calc_quadF(SundialsMemory* m, double t, const double* x, const double* z,
643  double* quad) const {
644  m->arg[DYN_T] = &t; // t
645  m->arg[DYN_X] = x; // x
646  m->arg[DYN_Z] = z; // z
647  m->arg[DYN_P] = m->p; // p
648  m->arg[DYN_U] = m->u; // u
649  m->res[QUAD_QUAD] = quad; // quad
650  if (calc_function(m, "quadF")) return 1;
651  // Evaluate sensitivities
652  if (nfwd_ > 0) {
653  m->arg[DYN_NUM_IN + QUAD_QUAD] = quad; // out:quad
654  m->arg[DYN_NUM_IN + QUAD_NUM_OUT + DYN_T] = 0; // fwd:t
655  m->arg[DYN_NUM_IN + QUAD_NUM_OUT + DYN_X] = x + nx1_; // fwd:x
656  m->arg[DYN_NUM_IN + QUAD_NUM_OUT + DYN_Z] = z ? z + nz1_ : 0; // fwd:z
657  m->arg[DYN_NUM_IN + QUAD_NUM_OUT + DYN_P] = m->p + np1_; // fwd:p
658  m->arg[DYN_NUM_IN + QUAD_NUM_OUT + DYN_U] = m->u + nu1_; // fwd:u
659  m->res[QUAD_QUAD] = quad ? quad + nq1_ : 0; // fwd:quad
660  if (calc_function(m, forward_name("quadF", nfwd_))) return 1;
661  }
662  return 0;
663 }
664 
665 int SundialsInterface::calc_quadB(SundialsMemory* m, double t, const double* x, const double* z,
666  const double* adj_ode, const double* adj_alg, double* adj_p, double* adj_u) const {
667  // Evaluate nondifferentiated
668  m->arg[BDYN_T] = &t; // t
669  m->arg[BDYN_X] = x; // x
670  m->arg[BDYN_Z] = z; // z
671  m->arg[BDYN_P] = m->p; // p
672  m->arg[BDYN_U] = m->u; // u
673  m->arg[BDYN_OUT_ODE] = nullptr; // out_ode
674  m->arg[BDYN_OUT_ALG] = nullptr; // out_alg
675  m->arg[BDYN_OUT_QUAD] = nullptr; // out_quad
676  m->arg[BDYN_OUT_ZERO] = nullptr; // out_zero
677  m->arg[BDYN_ADJ_ODE] = adj_ode; // adj_ode
678  m->arg[BDYN_ADJ_ALG] = adj_alg; // adj_alg
679  m->arg[BDYN_ADJ_QUAD] = m->adj_q; // adj_quad
680  m->arg[BDYN_ADJ_ZERO] = nullptr; // adj_zero
681  m->res[BQUAD_ADJ_P] = adj_p; // adj_p
682  m->res[BQUAD_ADJ_U] = adj_u; // adj_u
683  if (calc_function(m, "quadB")) return 1;
684  // Evaluate sensitivities
685  if (nfwd_ > 0) {
686  m->arg[BDYN_NUM_IN + BQUAD_ADJ_P] = adj_p; // out:adj_p
687  m->arg[BDYN_NUM_IN + BQUAD_ADJ_U] = adj_u; // out:adj_u
688  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_T] = 0; // fwd:t
689  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_X] = x ? x + nx1_ : 0; // fwd:x
690  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_Z] = z ? z + nz1_ : 0; // fwd:z
691  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_P] = m->p + np1_; // fwd:p
692  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_U] = m->u + nu1_; // fwd:u
693  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_OUT_ODE] = nullptr; // fwd:out_ode
694  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_OUT_ALG] = nullptr; // fwd:out_alg
695  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_OUT_QUAD] = nullptr; // fwd:out_quad
696  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_OUT_ZERO] = nullptr; // fwd:out_zero
698  adj_ode ? adj_ode + nrx1_ * nadj_ : 0; // fwd:adj_ode
700  adj_alg ? adj_alg + nrz1_ * nadj_ : 0; // fwd:adj_alg
701  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_ADJ_QUAD] = m->adj_q + nrp1_ * nadj_; // fwd:adj_quad
702  m->arg[BDYN_NUM_IN + BQUAD_NUM_OUT + BDYN_ADJ_ZERO] = nullptr; // fwd:adj_zero
703  m->res[BQUAD_ADJ_P] = adj_p + nrq1_ * nadj_; // fwd:adj_p
704  m->res[BQUAD_ADJ_U] = adj_u + nuq1_ * nadj_; // fwd:adj_u
705  if (calc_function(m, forward_name("quadB", nfwd_))) return 1;
706  }
707  return 0;
708 }
709 
710 int SundialsInterface::calc_jtimesF(SundialsMemory* m, double t, const double* x, const double* z,
711  const double* fwd_x, const double* fwd_z, double* fwd_ode, double* fwd_alg) const {
712  // Evaluate nondifferentiated
713  m->arg[JTIMESF_T] = &t; // t
714  m->arg[JTIMESF_X] = x; // x
715  m->arg[JTIMESF_Z] = z; // z
716  m->arg[JTIMESF_P] = m->p; // p
717  m->arg[JTIMESF_U] = m->u; // u
718  m->arg[JTIMESF_FWD_X] = fwd_x; // fwd:x
719  m->arg[JTIMESF_FWD_Z] = fwd_z; // fwd:z
720  m->res[JTIMESF_FWD_ODE] = fwd_ode; // fwd:ode
721  m->res[JTIMESF_FWD_ALG] = fwd_alg; // fwd:alg
722  if (calc_function(m, "jtimesF")) return 1;
723  // Evaluate sensitivities
724  if (nfwd_ > 0) {
725  m->arg[JTIMESF_NUM_IN + JTIMESF_FWD_ODE] = fwd_ode; // out:fwd:ode
726  m->arg[JTIMESF_NUM_IN + JTIMESF_FWD_ALG] = fwd_alg; // out:fwd:alg
727  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_T] = 0; // fwd:t
728  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_X] = x + nx1_; // fwd:x
729  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_Z] = z + nz1_; // fwd:z
730  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_P] = m->p + np1_; // fwd:p
731  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_U] = m->u + nu1_; // fwd:u
732  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_FWD_X] = fwd_x + nx1_; // fwd:fwd:x
733  m->arg[JTIMESF_NUM_IN + JTIMESF_NUM_OUT + JTIMESF_FWD_Z] = fwd_z + nz1_; // fwd:fwd:z
734  m->res[JTIMESF_FWD_ODE] = fwd_ode + nx1_; // fwd:fwd:ode
735  m->res[JTIMESF_FWD_ALG] = fwd_alg + nz1_; // fwd:fwd:alg
736  if (calc_function(m, forward_name("jtimesF", nfwd_))) return 1;
737  }
738  // Successful return
739  return 0;
740 }
741 
742 int SundialsInterface::calc_jacF(SundialsMemory* m, double t, const double* x, const double* z,
743  double* jac_ode_x, double* jac_alg_x, double* jac_ode_z, double* jac_alg_z) const {
744  // Calculate Jacobian
745  m->arg[DYN_T] = &t;
746  m->arg[DYN_X] = x;
747  m->arg[DYN_Z] = z;
748  m->arg[DYN_P] = m->p;
749  m->arg[DYN_U] = m->u;
750  m->res[JACF_ODE_X] = jac_ode_x;
751  m->res[JACF_ALG_X] = jac_alg_x;
752  m->res[JACF_ODE_Z] = jac_ode_z;
753  m->res[JACF_ALG_Z] = jac_alg_z;
754  return calc_function(m, "jacF");
755 }
756 
757 } // namespace casadi
Helper class for Serialization.
void unpack(Sparsity &e)
Reconstruct an object from the input stream.
void version(const std::string &name, int v)
void alloc_iw(size_t sz_iw, bool persistent=false)
Ensure required length of iw field.
static std::string forward_name(const std::string &fcn, casadi_int nfwd)
Helper function: Get name of forward derivative function.
void alloc_w(size_t sz_w, bool persistent=false)
Ensure required length of w field.
Function derivative_of_
If the function is the derivative of another function.
Function object.
Definition: function.hpp:60
casadi_int nnz_out() const
Get number of output nonzeros.
Definition: function.cpp:855
const Sparsity & sparsity_out(casadi_int ind) const
Get sparsity of a given output.
Definition: function.cpp:1031
FunctionInternal * get() const
Definition: function.cpp:353
const std::string & name() const
Name of the function.
Definition: function.cpp:1307
bool is_null() const
Is a null pointer?
Internal storage for integrator related data.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
casadi_int nfwd_
Number of sensitivities.
void init(const Dict &opts) override
Initialize.
Definition: integrator.cpp:642
virtual void reset(IntegratorMemory *mem, bool first_call) const
Reset the forward solver at the start or after an event.
static const Options options_
Options.
std::vector< double > nom_x_
casadi_int nrx_
Number of states for the backward integration.
int init_mem(void *mem) const override
Initialize memory block.
Definition: integrator.cpp:912
std::vector< double > nom_z_
void set_work(void *mem, const double **&arg, double **&res, casadi_int *&iw, double *&w) const override
Set the (persistent) work vectors.
Definition: integrator.cpp:882
double t0_
Initial time.
casadi_int nx_
Number of states for the forward integration.
Linear solver.
Definition: linsol.hpp:55
casadi_int checkout() const
Checkout a memory object.
Definition: linsol.cpp:197
const Sparsity & sparsity() const
Get linear system sparsity.
Definition: linsol.cpp:69
void set_function(const Function &fcn, const std::string &fname, bool jit=false)
Function create_function(const Function &oracle, const std::string &fname, const std::vector< std::string > &s_in, const std::vector< std::string > &s_out, const Function::AuxOut &aux=Function::AuxOut(), const Dict &opts=Dict())
Function create_forward(const std::string &fname, casadi_int nfwd)
int calc_function(OracleMemory *m, const std::string &fcn, const double *const *arg=nullptr, int thread_id=0) const
std::vector< std::string > get_function() const override
Get list of dependency functions.
Dict get_stats(void *mem) const override
Get all statistics.
void print(const char *fmt,...) const
C-style formatted printing during evaluation.
Helper class for Serialization.
void version(const std::string &name, int v)
void pack(const Sparsity &e)
Serializes an object to the output stream.
General sparsity class.
Definition: sparsity.hpp:106
static Sparsity diag(casadi_int nrow)
Create diagonal sparsity pattern *.
Definition: sparsity.hpp:190
casadi_int nnz() const
Get the number of (structural) non-zeros.
Definition: sparsity.cpp:148
Linsol linsolF_
Linear solver.
void set_work(void *mem, const double **&arg, double **&res, casadi_int *&iw, double *&w) const override
Set the (persistent) work vectors.
enum casadi::SundialsInterface::InterpType interp_
static SundialsMemory * to_mem(void *mem)
Cast to memory object.
void impulseB(IntegratorMemory *mem, const double *adj_x, const double *adj_z, const double *adj_q) const override
Introduce an impulse into the backwards integration at the current time.
Dict get_stats(void *mem) const override
Get all statistics.
int calc_jacF(SundialsMemory *m, double t, const double *x, const double *z, double *jac_ode_x, double *jac_alg_x, double *jac_ode_z, double *jac_alg_z) const
NewtonScheme
Supported iterative solvers in Sundials.
void print_stats(IntegratorMemory *mem) const override
Print solver statistics.
int calc_quadF(SundialsMemory *m, double t, const double *x, const double *z, double *quad) const
SundialsInterface(const std::string &name, const Function &dae, double t0, const std::vector< double > &tout)
Constructor.
void serialize_body(SerializingStream &s) const override
Serialize an object without type information.
int calc_jtimesF(SundialsMemory *m, double t, const double *x, const double *z, const double *fwd_x, const double *fwd_z, double *fwd_ode, double *fwd_alg) const
enum casadi::SundialsInterface::NewtonScheme newton_scheme_
void reset(IntegratorMemory *mem, bool first_call) const override
Reset the forward solver at the start or after an event.
void reset_stats(SundialsMemory *m) const
Reset stats.
int calc_daeF(SundialsMemory *m, double t, const double *x, const double *z, double *ode, double *alg) const
int init_mem(void *mem) const override
Initialize memory block.
void add_offsets(SundialsMemory *m) const
Add stats offsets to stats.
~SundialsInterface() override=0
Destructor.
int calc_quadB(SundialsMemory *m, double t, const double *x, const double *z, const double *adj_ode, const double *adj_alg, double *adj_p, double *adj_u) const
void init(const Dict &opts) override
Initialize.
void resetB(IntegratorMemory *mem) const override
Reset the backward problem and take time to tf.
void save_offsets(SundialsMemory *m) const
Save stats offsets before reset.
int calc_daeB(SundialsMemory *m, double t, const double *x, const double *z, const double *adj_ode, const double *adj_alg, const double *adj_quad, double *adj_x, double *adj_z) const
static const Options options_
Options.
The casadi namespace.
Definition: archiver.cpp:28
void casadi_copy(const T1 *x, casadi_int n, T1 *y)
COPY: y <-x.
@ DYN_NUM_IN
Definition: integrator.hpp:196
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
const double nan
Not a number.
Definition: calculus.hpp:53
void casadi_axpy(casadi_int n, T1 alpha, const T1 *x, T1 *y)
AXPY: y <- a*x + y.
void casadi_clear(T1 *x, casadi_int n)
CLEAR: x <- 0.
Options metadata for a class.
Definition: options.hpp:40
long nstepsB
Stats, backward integration.
long nstepsB_off
Offsets for stats in backward integration.
long nsteps
Stats, forward integration.
int mem_linsolF
Linear solver memory objects.