26 #ifndef CASADI_SUNDIALS_INTERFACE_HPP
27 #define CASADI_SUNDIALS_INTERFACE_HPP
29 #include <casadi/interfaces/sundials/casadi_sundials_common_export.h>
30 #include "casadi/core/integrator_impl.hpp"
32 #include <nvector/nvector_serial.h>
33 #include <sundials/sundials_dense.h>
34 #include <sundials/sundials_iterative.h>
35 #include <sundials/sundials_types.h>
/// Integrator memory extended with the working data of a SUNDIALS solver.
/// Derives from IntegratorMemory and adds N_Vector handles, raw Jacobian
/// buffers and solver statistics fields.
/// NOTE(review): the closing brace of this struct (and any constructor or
/// destructor) is not visible in this chunk; nothing here shows who allocates,
/// initializes or frees these members — confirm in the implementation file.
43 struct SundialsMemory :
public IntegratorMemory {
// SUNDIALS vector handles. Names suggest combined differential/algebraic
// state (xz), its time derivative (xzdot) and quadratures (q) — TODO confirm.
45 N_Vector xz, xzdot, q;
// Counterparts carrying an "r" prefix — presumably the backward (adjoint)
// problem; verify against the implementation.
48 N_Vector rxz, rxzdot, ruq;
// Raw buffers named exactly like the four output arguments of
// SundialsInterface::calc_jacF (jac_ode_x, jac_alg_x, jac_ode_z, jac_alg_z).
60 double *jac_ode_x, *jac_alg_x, *jac_ode_z, *jac_alg_z;
// Counters and step-size/time values. The names mirror the integrator and
// nonlinear-solver statistics reported by SUNDIALS (nsteps, nfevals,
// nlinsetups, netfails, hinused, hlast, hcur, tcur, nniters, nncfails) —
// presumably filled from those queries; confirm in the implementation.
66 long nsteps, nfevals, nlinsetups, netfails;
68 double hinused, hlast, hcur, tcur;
69 long nniters, nncfails;
// Same set of statistics with a "B" suffix — presumably for the backward
// integration; TODO confirm.
72 long nstepsB, nfevalsB, nlinsetupsB, netfailsB;
74 double hinusedB, hlastB, hcurB, tcurB;
75 long nnitersB, nncfailsB;
// "_off" variants of the backward counters — presumably the offsets handled
// by SundialsInterface::save_offsets / add_offsets; verify there.
78 long nstepsB_off, nfevalsB_off, nlinsetupsB_off, netfailsB_off;
79 long nnitersB_off, nncfailsB_off;
/// Common base class for integrators built on SUNDIALS; derives from
/// Integrator. The destructor is declared pure virtual, so this class cannot
/// be instantiated directly.
/// NOTE(review): access specifiers and the closing brace of the class are not
/// visible in this chunk.
100 class SundialsInterface :
public Integrator {
/// Constructor (declaration only; defined elsewhere).
/// @param name  Name of the integrator instance.
/// @param dae   Function describing the problem — presumably the DAE to be
///              integrated; confirm against Integrator's constructor.
/// @param t0    presumably the initial time — TODO confirm.
/// @param tout  presumably the grid of output times — TODO confirm.
103 SundialsInterface(
const std::string& name,
const Function& dae,
104 double t0,
const std::vector<double>& tout);
/// Pure virtual destructor: makes the class abstract. A pure virtual
/// destructor still needs a definition, which must live outside this
/// declaration (not visible here).
107 ~SundialsInterface()
override=0;
/// Class-wide option table; returned by get_options(). Defined elsewhere.
111 static const Options options_;
/// Returns a reference to the static option table options_.
112 const Options& get_options()
const override {
return options_;}
/// Initialization hook overriding the base class (declaration only).
/// @param opts  Option dictionary — presumably the source of the settings
///              members declared in this class; confirm in the implementation.
116 void init(
const Dict& opts)
override;
/// Work-vector setup overriding the base class (declaration only).
/// @param mem  Opaque memory pointer — presumably a SundialsMemory (see to_mem).
/// @param arg, res, iw, w  Passed by reference to pointer, so the
///        implementation can move them — presumably to carve its own buffers
///        out of the shared work arrays; TODO confirm.
119 void set_work(
void* mem,
const double**& arg,
double**& res,
120 casadi_int*& iw,
double*& w)
const override;
/// Memory-object initialization overriding the base class (declaration only).
/// @param mem  Opaque memory pointer.
/// @return int — presumably a status code (0 on success); TODO confirm.
123 int init_mem(
void* mem)
const override;
/// Returns the stored relative tolerance reltol_.
126 double get_reltol()
const override {
return reltol_;}
/// Returns the stored absolute tolerance abstol_.
129 double get_abstol()
const override {
return abstol_;}
/// Declaration only. Name and signature suggest evaluation of the forward
/// DAE right-hand side — TODO confirm in the implementation.
/// @param m     Solver memory.
/// @param t     Time.
/// @param x, z  Read-only inputs (presumably differential and algebraic state).
/// @param ode, alg  Non-const buffers — presumably the outputs.
/// @return int — presumably a status flag; confirm the convention.
132 int calc_daeF(SundialsMemory* m,
double t,
const double* x,
const double* z,
133 double* ode,
double* alg)
const;
/// Declaration only. Name and signature suggest evaluation of the backward
/// (adjoint) DAE — TODO confirm in the implementation.
/// @param m     Solver memory.
/// @param t     Time.
/// @param x, z  Read-only inputs (presumably forward states).
/// @param rx, rz, rp  Read-only inputs (presumably backward states/parameters).
/// @param adj_x, adj_z  Non-const buffers — presumably the outputs.
/// @return int — presumably a status flag; confirm the convention.
136 int calc_daeB(SundialsMemory* m,
double t,
const double* x,
const double* z,
137 const double* rx,
const double* rz,
const double* rp,
double* adj_x,
double* adj_z)
const;
/// Declaration only — name suggests evaluation of the forward quadratures.
/// NOTE(review): the parameter list is cut off in this listing (it ends on a
/// trailing comma); the remaining parameter(s) and qualifiers are not visible
/// and must be taken from the original header.
140 int calc_quadF(SundialsMemory* m,
double t,
const double* x,
const double* z,
/// Declaration only. Name and signature suggest evaluation of the backward
/// quadratures — TODO confirm in the implementation.
/// @param m     Solver memory.
/// @param t     Time.
/// @param x, z, rx, rz  Read-only inputs.
/// @param adj_p, adj_u  Non-const buffers — presumably the outputs.
/// @return int — presumably a status flag; confirm the convention.
144 int calc_quadB(SundialsMemory* m,
double t,
const double* x,
const double* z,
145 const double* rx,
const double* rz,
double* adj_p,
double* adj_u)
const;
/// Declaration only. Name and signature suggest a forward
/// Jacobian-times-vector product — TODO confirm in the implementation.
/// @param m     Solver memory.
/// @param t     Time.
/// @param x, z  Read-only inputs.
/// @param fwd_x, fwd_z  Read-only inputs (presumably the seed direction).
/// @param fwd_ode, fwd_alg  Non-const buffers — presumably the outputs.
/// @return int — presumably a status flag; confirm the convention.
148 int calc_jtimesF(SundialsMemory* m,
double t,
const double* x,
const double* z,
149 const double* fwd_x,
const double* fwd_z,
double* fwd_ode,
double* fwd_alg)
const;
/// Declaration only. Name and signature suggest evaluation of the four
/// Jacobian blocks of the forward problem — TODO confirm in the implementation.
/// @param m     Solver memory.
/// @param t     Time.
/// @param x, z  Read-only inputs.
/// @param jac_ode_x, jac_alg_x, jac_ode_z, jac_alg_z  Non-const buffers —
///        presumably the outputs; SundialsMemory has members of the same names.
/// @return int — presumably a status flag; confirm the convention.
152 int calc_jacF(SundialsMemory* m,
double t,
const double* x,
const double* z,
153 double* jac_ode_x,
double* jac_alg_x,
double* jac_ode_z,
double* jac_alg_z)
const;
/// Statistics query overriding the base class (declaration only).
/// @param mem  Opaque memory pointer.
/// @return Dict — presumably built from the counters in SundialsMemory; TODO confirm.
156 Dict get_stats(
void* mem)
const override;
/// Statistics printout overriding the base class (declaration only).
/// @param mem  Integrator memory whose statistics are presumably printed.
159 void print_stats(IntegratorMemory* mem)
const override;
/// Reset hook for the forward problem, overriding the base class
/// (declaration only).
/// @param mem  Integrator memory.
/// @param u, x, z, p  Read-only arrays — presumably controls, differential
///        state, algebraic state and parameters; TODO confirm.
162 void reset(IntegratorMemory* mem,
const double* u,
const double* x,
163 const double* z,
const double* p)
const override;
/// Reset hook for the backward problem, overriding the base class
/// (declaration only).
/// @param mem  Integrator memory.
166 void resetB(IntegratorMemory* mem)
const override;
/// Overrides the base class (declaration only). The name suggests an impulse
/// applied to the backward states — TODO confirm in the implementation.
/// @param mem  Integrator memory.
/// @param rx, rz, rp  Read-only arrays.
169 void impulseB(IntegratorMemory* mem,
170 const double* rx,
const double* rz,
const double* rp)
const override;
/// Declaration only — presumably clears the statistics fields of `m`;
/// confirm in the implementation.
173 void reset_stats(SundialsMemory* m)
const;
/// Declaration only — presumably stores the current backward counters into
/// the "_off" members of `m`; confirm in the implementation.
176 void save_offsets(SundialsMemory* m)
const;
/// Declaration only — presumably adds the saved "_off" values back onto the
/// backward counters of `m`; confirm in the implementation.
179 void add_offsets(SundialsMemory* m)
const;
/// Converts an opaque memory pointer to SundialsMemory* with static_cast and
/// checks the result with casadi_assert_dev.
/// NOTE(review): the return statement and the closing brace are not visible
/// in this listing; given the return type the function presumably returns `m`.
182 static SundialsMemory* to_mem(
void *mem) {
183 SundialsMemory* m =
static_cast<SundialsMemory*
>(mem);
184 casadi_assert_dev(m);
/// Positional indices with a JTIMESF_ prefix — presumably the inputs of the
/// function evaluated by calc_jtimesF (confirm in the implementation).
/// Enumerators keep their implicit values 0..6; JTIMESF_NUM_IN is the count (7).
enum JtimesFIn {
  JTIMESF_T,
  JTIMESF_X,
  JTIMESF_Z,
  JTIMESF_P,
  JTIMESF_U,
  JTIMESF_FWD_X,
  JTIMESF_FWD_Z,
  JTIMESF_NUM_IN  // number of entries above
};
/// Positional indices with a JTIMESF_ prefix — presumably the outputs of the
/// function evaluated by calc_jtimesF (confirm in the implementation).
/// JTIMESF_NUM_OUT is the count of entries (2).
enum JtimesFOut {
  JTIMESF_FWD_ODE,
  JTIMESF_FWD_ALG,
  JTIMESF_NUM_OUT  // number of entries above
};
/// Positional indices with a JACF_ prefix, in the same order as the four
/// output arguments of calc_jacF (ode_x, alg_x, ode_z, alg_z).
/// JACF_NUM_OUT is the count of entries (4).
enum JacFOut {
  JACF_ODE_X,
  JACF_ALG_X,
  JACF_ODE_Z,
  JACF_ALG_Z,
  JACF_NUM_OUT  // number of entries above
};
// Solver settings stored on the interface object. Nothing in this chunk shows
// where they are assigned — presumably from the option dictionary passed to
// init(); confirm in the implementation.
// Tolerances; returned by get_abstol() and get_reltol().
198 double abstol_, reltol_;
199 casadi_int max_num_steps_;
202 casadi_int steps_per_checkpoint_;
203 bool disable_internal_warnings_;
204 casadi_int max_multistep_order_;
// Linear solver name and its option dictionary.
205 std::string linear_solver_;
206 Dict linear_solver_options_;
207 casadi_int max_krylov_;
209 bool second_order_correction_;
211 double max_step_size_;
212 double nonlin_conv_coeff_;
213 casadi_int max_order_;
/// Selector stored in newton_scheme_. Enumerator names suggest a direct
/// solver and the GMRES, BiCGStab and TFQMR iterative schemes — confirm the
/// mapping in the implementation. Implicit values 0..3 are preserved.
enum NewtonScheme {
  SD_DIRECT,
  SD_GMRES,
  SD_BCGSTAB,
  SD_TFQMR
} newton_scheme_;
/// Selector stored in interp_. Enumerator names suggest polynomial versus
/// Hermite interpolation — confirm where it is consumed. Implicit values
/// 0..1 are preserved.
enum InterpType {
  SD_POLYNOMIAL,
  SD_HERMITE
} interp_;
/// Empty tag/placeholder type; it declares no members.
struct LinSolDataDense {};
/// Debug helper: writes "<id> = <v>" to the stream returned by uout(),
/// terminated with std::endl (newline plus flush).
/// NOTE(review): the closing brace of this function is not visible in this
/// listing.
230 static void printvar(
const std::string&
id,
double v) {
231 uout() <<
id <<
" = " << v << std::endl;
/// Debug helper for vectors: copies the range
/// [NV_DATA_S(v), NV_DATA_S(v) + NV_LENGTH_S(v)) into a temporary
/// std::vector<double> and writes "<id> = <contents>" to uout(), terminated
/// with std::endl. The _S accessor macros are the serial-vector ones, so `v`
/// is expected to be a serial N_Vector.
/// NOTE(review): the closing brace of this function is not visible in this
/// listing.
234 static void printvar(
const std::string&
id, N_Vector v) {
235 std::vector<double> tmp(NV_DATA_S(v), NV_DATA_S(v)+NV_LENGTH_S(v));
236 uout() <<
id <<
" = " << tmp << std::endl;
/// Serialization hook overriding the base class (declaration only).
/// @param s  Stream the object state is presumably written to; the set of
///           fields written is defined in the implementation.
240 void serialize_body(SerializingStream &s)
const override;
/// Constructs the object from a deserializing stream (declaration only);
/// `explicit` prevents implicit conversion from DeserializingStream.
/// Presumably the counterpart of serialize_body — confirm field order there.
244 explicit SundialsInterface(DeserializingStream& s);
249 std::vector<double> tmp(NV_DATA_S(v), NV_DATA_S(v)+NV_LENGTH_S(v));
CASADI_EXPORT std::ostream & uout()
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
bool is_regular(const std::vector< T > &v)
Checks if array does not contain NaN or Inf.