26 #ifndef CASADI_SUNDIALS_INTERFACE_HPP
27 #define CASADI_SUNDIALS_INTERFACE_HPP
29 #include <casadi/interfaces/sundials/casadi_sundials_common_export.h>
30 #include "casadi/core/integrator_impl.hpp"
32 #include <nvector/nvector_serial.h>
33 #include <sundials/sundials_dense.h>
34 #include <sundials/sundials_iterative.h>
35 #include <sundials/sundials_types.h>
/** \brief Memory block used by SUNDIALS-based integrators.
 *
 * Extends IntegratorMemory with SUNDIALS N_Vector handles, raw pointers to
 * Jacobian buffers, and counters for solver statistics. The fields shown
 * here are declared without in-class initializers.
 */
43 struct SundialsMemory :
public IntegratorMemory {
// N_Vector handles; per naming: states (xz), their time derivatives (xzdot)
// and quadratures (q) -- confirm against the implementation.
45 N_Vector v_xz, v_xzdot, v_q;
// N_Vector handles for the adjoint (backward) problem, per the "adj" prefix.
48 N_Vector v_adj_xz, v_adj_xzdot, v_adj_pu;
// Raw pointers to Jacobian blocks; the names match the output parameters of
// calc_jacF. NOTE(review): ownership/lifetime is not visible here -- confirm.
54 double *jac_ode_x, *jac_alg_x, *jac_ode_z, *jac_alg_z;
// Solver statistics (forward run, per naming).
60 long nsteps, nfevals, nlinsetups, netfails;
62 double hinused, hlast, hcur, tcur;
63 long nniters, nncfails;
// Solver statistics for the backward run ("B" suffix).
66 long nstepsB, nfevalsB, nlinsetupsB, netfailsB;
68 double hinusedB, hlastB, hcurB, tcurB;
69 long nnitersB, nncfailsB;
// Offsets of the backward statistics; presumably written by save_offsets()
// and consumed by add_offsets() -- confirm in the implementation.
72 long nstepsB_off, nfevalsB_off, nlinsetupsB_off, netfailsB_off;
73 long nnitersB_off, nncfailsB_off;
/** \brief Common base class for SUNDIALS-based integrators.
 *
 * Derives from Integrator. The destructor is declared pure virtual, so this
 * class cannot be instantiated directly.
 */
91 class SundialsInterface :
public Integrator {
/** \brief Constructor.
 * \param name  instance name
 * \param dae   DAE function to be integrated
 * \param t0    initial time
 * \param tout  output time grid
 * NOTE(review): parameter meanings are inferred from their names -- confirm
 * in the implementation file.
 */
94 SundialsInterface(
const std::string& name,
const Function& dae,
95 double t0,
const std::vector<double>& tout);
// Pure virtual (= 0) together with override: keeps the class abstract.
// A pure virtual destructor still needs an out-of-line definition, because
// derived-class destructors invoke it.
98 ~SundialsInterface()
override=0;
102 static const Options options_;
103 const Options& get_options()
const override {
return options_;}
107 void init(
const Dict& opts)
override;
110 void set_work(
void* mem,
const double**& arg,
double**& res,
111 casadi_int*& iw,
double*& w)
const override;
114 int init_mem(
void* mem)
const override;
117 double get_reltol()
const override {
return reltol_;}
120 double get_abstol()
const override {
return abstol_;}
123 int calc_daeF(SundialsMemory* m,
double t,
const double* x,
const double* z,
124 double* ode,
double* alg)
const;
127 int calc_daeB(SundialsMemory* m,
double t,
const double* x,
const double* z,
128 const double* adj_ode,
const double* adj_alg,
const double* adj_quad,
129 double* adj_x,
double* adj_z)
const;
/** \brief Forward quadrature evaluation; inputs include time t and states x, z.
 * NOTE(review): int return presumably a status code -- confirm.
 */
132 int calc_quadF(SundialsMemory* m,
double t,
const double* x,
const double* z,
/** \brief Backward quadrature evaluation.
 *
 * Reads adj_ode and adj_alg at (t, x, z) and writes adj_p and adj_u
 * (per naming: adjoint results w.r.t. p and u -- confirm).
 */
136 int calc_quadB(SundialsMemory* m,
double t,
const double* x,
const double* z,
137 const double* adj_ode,
const double* adj_alg,
double* adj_p,
double* adj_u)
const;
140 int calc_jtimesF(SundialsMemory* m,
double t,
const double* x,
const double* z,
141 const double* fwd_x,
const double* fwd_z,
double* fwd_ode,
double* fwd_alg)
const;
144 int calc_jacF(SundialsMemory* m,
double t,
const double* x,
const double* z,
145 double* jac_ode_x,
double* jac_alg_x,
double* jac_ode_z,
double* jac_alg_z)
const;
148 Dict get_stats(
void* mem)
const override;
151 void print_stats(IntegratorMemory* mem)
const override;
154 void reset(IntegratorMemory* mem,
bool first_call)
const override;
157 void resetB(IntegratorMemory* mem)
const override;
160 void impulseB(IntegratorMemory* mem,
161 const double* adj_x,
const double* adj_z,
const double* adj_q)
const override;
164 void reset_stats(SundialsMemory* m)
const;
167 void save_offsets(SundialsMemory* m)
const;
170 void add_offsets(SundialsMemory* m)
const;
/** \brief Convert an opaque memory pointer to SundialsMemory*.
 *
 * Uses static_cast, which does not check the dynamic type: callers must pass
 * a pointer that actually refers to a SundialsMemory. The only check performed
 * is casadi_assert_dev on the converted pointer (non-null).
 * NOTE(review): presumably returns m -- remaining lines not shown here.
 */
173 static SundialsMemory* to_mem(
void *mem) {
174 SundialsMemory* m =
static_cast<SundialsMemory*
>(mem);
// Assert the pointer is non-null.
175 casadi_assert_dev(m);
/// Presumably the input slots of the forward Jacobian-times-vector function.
/// JTIMESF_NUM_IN equals the number of preceding enumerators (the input count).
enum JtimesFIn {
  JTIMESF_T,
  JTIMESF_X,
  JTIMESF_Z,
  JTIMESF_P,
  JTIMESF_U,
  JTIMESF_FWD_X,
  JTIMESF_FWD_Z,
  JTIMESF_NUM_IN
};

/// Output slots of the same function; JTIMESF_NUM_OUT is the count.
enum JtimesFOut {
  JTIMESF_FWD_ODE,
  JTIMESF_FWD_ALG,
  JTIMESF_NUM_OUT
};

/// Presumably the output slots of the forward Jacobian function;
/// JACF_NUM_OUT is the count.
enum JacFOut {
  JACF_ODE_X,
  JACF_ALG_X,
  JACF_ODE_Z,
  JACF_ALG_Z,
  JACF_NUM_OUT
};
// Tolerances; returned by get_abstol() / get_reltol().
189 double abstol_, reltol_;
// The option values below are presumably populated by init() from the
// options Dict (and by the deserializing constructor) -- confirm.
190 casadi_int max_num_steps_;
193 casadi_int steps_per_checkpoint_;
194 bool disable_internal_warnings_;
195 casadi_int max_multistep_order_;
// Name and options of the linear solver (per naming).
196 std::string linear_solver_;
197 Dict linear_solver_options_;
198 casadi_int max_krylov_;
200 bool second_order_correction_;
202 double max_step_size_;
203 double nonlin_conv_coeff_;
204 casadi_int max_order_;
// Linear solve scheme, per the enumerator names: direct, or iterative
// GMRES / BiCGStab / TFQMR. Unscoped enum, so the SD_* names are visible
// directly in the class scope.
212 enum NewtonScheme {SD_DIRECT, SD_GMRES, SD_BCGSTAB, SD_TFQMR} newton_scheme_;
// Interpolation type, per the enumerator names: polynomial or Hermite.
215 enum InterpType {SD_POLYNOMIAL, SD_HERMITE} interp_;
// Empty struct with no members. NOTE(review): no use is visible here --
// confirm whether it is still needed.
218 struct LinSolDataDense {};
/** \brief Debug print helper: writes "<id> = <v>" to uout().
 *
 * Terminates with std::endl, so the stream is also flushed on every call.
 */
221 static void printvar(
const std::string&
id,
double v) {
222 uout() <<
id <<
" = " << v << std::endl;
/** \brief Debug print overload for an N_Vector.
 *
 * Copies the NV_LENGTH_S(v) entries starting at NV_DATA_S(v) into a temporary
 * std::vector<double> and writes "<id> = <contents>" to uout(), flushing via
 * std::endl. NV_DATA_S / NV_LENGTH_S are the serial N_Vector accessors, so v
 * is assumed to be a serial N_Vector.
 */
225 static void printvar(
const std::string&
id, N_Vector v) {
226 std::vector<double> tmp(NV_DATA_S(v), NV_DATA_S(v)+NV_LENGTH_S(v));
227 uout() <<
id <<
" = " << tmp << std::endl;
231 void serialize_body(SerializingStream &s)
const override;
235 explicit SundialsInterface(DeserializingStream& s);
240 std::vector<double> tmp(NV_DATA_S(v), NV_DATA_S(v)+NV_LENGTH_S(v));
CASADI_EXPORT std::ostream & uout()
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.
bool is_regular(const std::vector< T > &v)
Checks if array does not contain NaN or Inf.