26 #ifndef CASADI_INTEGRATOR_IMPL_HPP
27 #define CASADI_INTEGRATOR_IMPL_HPP
29 #include "integrator.hpp"
30 #include "oracle_function.hpp"
31 #include "plugin_interface.hpp"
32 #include "casadi_enum.hpp"
/// Per-instance memory block for Integrator.
/// Extends OracleMemory with buffer pointers for the forward problem, the
/// adjoint problem and event bookkeeping.
/// NOTE(review): the pointers are presumably assigned from the work array in
/// Integrator::set_work — confirm in the implementation file.
41 struct CASADI_EXPORT IntegratorMemory :
public OracleMemory {
// Forward-problem buffers. The names suggest: quadratures (q), differential
// states (x), algebraic states (z), parameters (p), controls (u), event
// function values (e), their time derivatives (edot), previous event values
// (old_e) and state derivatives (xdot, zdot). This is inferred from names;
// confirm against the implementation.
43 double *q, *x, *z, *p, *u, *e, *edot, *old_e, *xdot, *zdot;
// Adjoint (backward-problem) counterparts of x, z, p and q.
45 double *adj_x, *adj_z, *adj_p, *adj_q;
// Per-event trigger flags (presumably one entry per event function).
63 casadi_int *event_triggered;
// Iteration counter for event localization (inferred from name).
67 casadi_int event_iter;
// Number of events encountered so far (inferred from name).
69 casadi_int num_events;
// Index of the most recently triggered event (inferred from name).
71 casadi_int event_index;
/// Memory passed to the forward sparsity-propagation helpers
/// (fdae_sp_forward, fquad_sp_forward, bdae_sp_forward, bquad_sp_forward).
75 struct CASADI_EXPORT SpForwardMem {
/// Memory passed to the reverse sparsity-propagation helpers
/// (fdae_sp_reverse, fquad_sp_reverse, bdae_sp_reverse, bquad_sp_reverse).
83 struct CASADI_EXPORT SpReverseMem {
/// Abstract base class for integrator plugins.
/// Integrates a DAE oracle starting at t0 and reports results on the output
/// grid tout, exposed as a function through OracleFunction. It supports event
/// handling, forward/reverse sparsity propagation and derivative generation.
/// Concrete solvers plug in through PluginInterface<Integrator> (see solvers_).
98 Integrator :
public OracleFunction,
public PluginInterface<Integrator> {
/// Constructor.
/// \param name   Function name.
/// \param oracle DAE oracle function to be integrated.
/// \param t0     Initial time.
/// \param tout   Output time grid.
103 Integrator(
const std::string& name,
const Function& oracle,
104 double t0,
const std::vector<double>& tout);
/// Pure virtual destructor. Like any pure virtual destructor, it still needs
/// an out-of-line definition, because derived destructors call it.
109 ~Integrator()
override=0;
/// Number of function inputs (INTEGRATOR_NUM_IN).
115 size_t get_n_in()
override {
return INTEGRATOR_NUM_IN;}
/// Number of function outputs (INTEGRATOR_NUM_OUT).
116 size_t get_n_out()
override {
return INTEGRATOR_NUM_OUT;}
/// Sparsity pattern of input i.
123 Sparsity get_sparsity_in(casadi_int i)
override;
/// Sparsity pattern of output i.
124 Sparsity get_sparsity_out(casadi_int i)
override;
/// Name of input i, taken from the integrator_in() scheme.
131 std::string get_name_in(casadi_int i)
override {
return integrator_in(i);}
/// Name of output i, taken from the integrator_out() scheme.
132 std::string get_name_out(casadi_int i)
override {
return integrator_out(i);}
/// Initialize a memory block.
138 int init_mem(
void* mem)
const override;
/// Options accepted by Integrator, returned by get_options().
144 static const Options options_;
145 const Options& get_options()
const override {
return options_;}
/// Initialize from an options dictionary.
151 void init(
const Dict& opts)
override;
/// Set up the memory block from the argument, result and work arrays.
/// The pointers are taken by reference, presumably so they can be advanced
/// past the space consumed here.
156 void set_work(
void* mem,
const double**& arg,
double**& res,
157 casadi_int*& iw,
double*& w)
const override;
/// Virtual factory hook returning a Function built from opts.
/// Overridden by FixedStepIntegrator.
160 virtual Function create_advanced(
const Dict& opts);
/// Algebraic state used at initialization. The default ignores x0 and
/// returns z0 unchanged.
162 virtual MX algebraic_state_init(
const MX& x0,
const MX& z0)
const {
return z0; }
/// Map the algebraic state to its output form. The default is the identity.
163 virtual MX algebraic_state_output(
const MX& Z)
const {
return Z; }
/// Setters: load the q, x, z, p and u components of memory block m from
/// read-only caller buffers.
166 void set_q(IntegratorMemory* m,
const double* q)
const;
169 void set_x(IntegratorMemory* m,
const double* x)
const;
172 void set_z(IntegratorMemory* m,
const double* z)
const;
175 void set_p(IntegratorMemory* m,
const double* p)
const;
178 void set_u(IntegratorMemory* m,
const double* u)
const;
/// Getters: write the q, x and z components of memory block m into caller
/// buffers.
181 void get_q(IntegratorMemory* m,
double* q)
const;
184 void get_x(IntegratorMemory* m,
double* x)
const;
187 void get_z(IntegratorMemory* m,
double* z)
const;
/// Reset the solver memory. The default implementation does nothing.
/// first_call presumably flags the first reset of an evaluation; confirm.
192 virtual void reset(IntegratorMemory* mem,
bool first_call)
const {}
/// Index of the next stopping point after output index k, given control
/// values u. Presumably this is where u changes; confirm in the .cpp.
197 casadi_int next_stop(casadi_int k,
const double* u)
const;
/// Compute time derivatives of the event functions in m. Returns an int
/// status (presumably 0 on success).
202 int calc_edot(IntegratorMemory* m)
const;
/// Predict upcoming events for the current step. Returns an int status.
207 int predict_events(IntegratorMemory* m)
const;
/// Trigger an event. ind is a non-const pointer, so the callee may update
/// it. Returns an int status.
212 int trigger_event(IntegratorMemory* m, casadi_int* ind)
const;
/// Advance the forward integration, including event handling.
/// Returns an int status.
217 int advance(IntegratorMemory* m)
const;
/// Solver-specific: advance the forward integration without event handling.
222 virtual int advance_noevent(IntegratorMemory* mem)
const = 0;
/// Solver-specific: reset the backward (adjoint) problem.
227 virtual void resetB(IntegratorMemory* mem)
const = 0;
/// Backward-problem counterpart of next_stop.
232 casadi_int next_stopB(casadi_int k,
const double* u)
const;
/// Solver-specific: apply adjoint seeds adj_x, adj_z and adj_q as an
/// impulse to the backward problem.
237 virtual void impulseB(IntegratorMemory* mem,
238 const double* adj_x,
const double* adj_z,
const double* adj_q)
const = 0;
/// Solver-specific: integrate the backward problem given controls u, writing
/// adjoint results to adj_x, adj_p and adj_u.
243 virtual void retreat(IntegratorMemory* mem,
const double* u,
244 double* adj_x,
double* adj_p,
double* adj_u)
const = 0;
/// Numerical evaluation.
249 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
/// Print solver statistics. The default implementation prints nothing.
254 virtual void print_stats(IntegratorMemory* mem)
const {}
/// Forward sparsity propagation through the forward DAE (ode/alg outputs).
257 int fdae_sp_forward(SpForwardMem* m,
const bvec_t* x,
258 const bvec_t* p,
const bvec_t* u, bvec_t* ode, bvec_t* alg)
const;
/// Forward sparsity propagation through the forward quadrature.
261 int fquad_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
262 const bvec_t* p,
const bvec_t* u, bvec_t* quad)
const;
/// Forward sparsity propagation through the backward DAE.
265 int bdae_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
266 const bvec_t* p,
const bvec_t* u,
const bvec_t* adj_ode,
const bvec_t* adj_quad,
267 bvec_t* adj_x, bvec_t* adj_z)
const;
/// Forward sparsity propagation through the backward quadrature.
270 int bquad_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
271 const bvec_t* p,
const bvec_t* u,
const bvec_t* adj_ode,
const bvec_t* adj_alg,
272 const bvec_t* adj_quad, bvec_t* adj_p, bvec_t* adj_u)
const;
/// Forward sparsity propagation for the whole integrator function.
277 int sp_forward(
const bvec_t** arg, bvec_t** res,
278 casadi_int* iw, bvec_t* w,
void* mem)
const override;
/// Reverse sparsity propagation through the forward DAE.
281 int fdae_sp_reverse(SpReverseMem* m, bvec_t* x,
282 bvec_t* p, bvec_t* u, bvec_t* ode, bvec_t* alg)
const;
/// Reverse sparsity propagation through the forward quadrature.
285 int fquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
286 bvec_t* p, bvec_t* u, bvec_t* quad)
const;
/// Reverse sparsity propagation through the backward DAE.
289 int bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
290 bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* adj_quad,
291 bvec_t* adj_x, bvec_t* adj_z)
const;
/// Reverse sparsity propagation through the backward quadrature.
294 int bquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
295 bvec_t* p, bvec_t* u, bvec_t* adj_ode, bvec_t* adj_alg, bvec_t* adj_quad,
296 bvec_t* adj_p, bvec_t* adj_u)
const;
/// Reverse sparsity propagation for the whole integrator function.
301 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w,
void* mem)
const override;
/// Forward and reverse sparsity propagation are always available.
305 bool has_spfwd()
const override {
return true;}
306 bool has_sprev()
const override {
return true;}
/// Generate a function for nfwd forward directional derivatives.
313 Function get_forward(casadi_int nfwd,
const std::string& name,
314 const std::vector<std::string>& inames,
315 const std::vector<std::string>& onames,
316 const Dict& opts)
const override;
/// Forward derivatives are offered for any number of directions.
317 bool has_forward(casadi_int nfwd)
const override {
return true;}
/// Generate a function for nadj adjoint (reverse) directional derivatives.
324 Function get_reverse(casadi_int nadj,
const std::string& name,
325 const std::vector<std::string>& inames,
326 const std::vector<std::string>& onames,
327 const Dict& opts)
const override;
/// Reverse derivatives are offered only when ne_ == 0. ne_ is declared
/// outside this excerpt; presumably it is the number of event functions
/// (confirm).
328 bool has_reverse(casadi_int nadj)
const override {
return ne_ == 0;}
/// Options for the integrator created for derivatives. fwd selects
/// forward (true) versus reverse (false).
334 virtual Dict getDerivativeOptions(
bool fwd)
const;
/// Build the forward-sensitivity DAE, templated on the symbolic type.
340 template<
typename MatType> Function get_forward_dae(
const std::string& name)
const;
/// Build the augmented DAE function.
341 Function augmented_dae()
const;
/// Presumably returns true iff the first n entries of v are all zero;
/// confirm in the .cpp.
345 static bool all_zero(
const double* v, casadi_int n);
/// Sparsity of an augmented Jacobian assembled from patterns J and J1.
348 Sparsity sp_jac_aug(
const Sparsity& J,
const Sparsity& J1)
const;
/// Jacobian sparsity of the forward DAE (non-const: may update cached state).
351 Sparsity sp_jac_dae();
/// Jacobian sparsity of the backward DAE (non-const: may update cached state).
354 Sparsity sp_jac_rdae();
/// Stored Jacobian sparsity patterns, presumably filled by the two
/// functions above.
357 Sparsity sp_jac_dae_, sp_jac_rdae_;
/// Number of output times, i.e. tout_.size() (converted to casadi_int).
360 inline casadi_int nt()
const {
return tout_.size();}
/// Output indices of the forward DAE, with names from dae_out().
366 enum DaeOut { DAE_ODE, DAE_ALG, DAE_NUM_OUT};
367 static std::vector<std::string> dae_out() {
return {
"ode",
"alg"}; }
/// Output indices of the forward quadrature, with names from quad_out().
368 enum QuadOut { QUAD_QUAD, QUAD_NUM_OUT};
369 static std::vector<std::string> quad_out() {
return {
"quad"}; }
/// Input indices of the backward dynamics function. Judging by the
/// enumerator names, these cover time, primal x/z/p/u, forward outputs and
/// adjoint seeds. Names come from bdyn_in().
370 enum BDynIn { BDYN_T, BDYN_X, BDYN_Z, BDYN_P, BDYN_U,
371 BDYN_OUT_ODE, BDYN_OUT_ALG, BDYN_OUT_QUAD, BDYN_OUT_ZERO,
372 BDYN_ADJ_ODE, BDYN_ADJ_ALG, BDYN_ADJ_QUAD, BDYN_ADJ_ZERO, BDYN_NUM_IN};
373 static std::string bdyn_in(casadi_int i);
374 static std::vector<std::string> bdyn_in();
/// Output indices of the backward dynamics function, with names from
/// bdyn_out().
375 enum BDynOut { BDYN_ADJ_T, BDYN_ADJ_X, BDYN_ADJ_Z, BDYN_ADJ_P, BDYN_ADJ_U, BDYN_NUM_OUT};
376 static std::string bdyn_out(casadi_int i);
377 static std::vector<std::string> bdyn_out();
/// Output indices of the backward DAE, with names from bdae_out().
378 enum DAEBOut { BDAE_ADJ_X, BDAE_ADJ_Z, BDAE_NUM_OUT};
379 static std::vector<std::string> bdae_out() {
return {
"adj_x",
"adj_z"}; }
/// Output indices of the backward quadrature, with names from bquad_out().
380 enum QuadBOut { BQUAD_ADJ_P, BQUAD_ADJ_U, BQUAD_NUM_OUT};
381 static std::vector<std::string> bquad_out() {
return {
"adj_p",
"adj_u"}; }
/// Output time grid (its size is nt()).
388 std::vector<double> tout_;
/// Number of forward and adjoint sensitivity directions (inferred from names).
391 casadi_int nfwd_, nadj_;
/// Forward-problem dimensions: x, z and q. The "1" suffix presumably denotes
/// the non-augmented, per-direction size; confirm.
397 casadi_int nx_, nz_, nq_, nx1_, nz1_, nq1_;
/// Backward-problem dimensions (r-prefixed), same convention as above.
400 casadi_int nrx_, nrz_, nrq_, nuq_, nrx1_, nrz1_, nrq1_, nuq1_;
/// Parameter dimensions, forward and backward.
403 casadi_int np_, nrp_, np1_, nrp1_;
/// Control dimensions.
406 casadi_int nu_, nu1_;
/// Nominal values for x and z (presumably used for scaling).
415 std::vector<double> nom_x_, nom_z_;
/// Options forwarded when building augmented/derivative integrators
/// (presumably).
418 Dict augmented_options_;
/// Transition function, presumably applied when an event triggers.
427 Function transition_;
/// Event-handling limits and tolerance (inferred from names).
430 casadi_int max_event_iter_;
433 casadi_int max_events_;
439 double event_acceptable_tol_;
/// Plugin factory signature; it matches the Integrator constructor.
442 typedef Integrator* (*Creator)(
const std::string& name,
const Function& oracle,
443 double t0,
const std::vector<double>& tout);
/// Registry of available integrator plugins, keyed by name.
449 static std::map<std::string, Plugin> solvers_;
/// Mutex for solvers_, compiled only with CASADI_WITH_THREADSAFE_SYMBOLICS.
/// NOTE(review): the matching #endif is not visible in this excerpt; confirm
/// that it follows mutex_solvers_.
451 #ifdef CASADI_WITH_THREADSAFE_SYMBOLICS
452 static std::mutex mutex_solvers_;
/// Infix used for this plugin family (presumably in plugin naming).
456 static const std::string infix_;
/// Build an oracle Function called name from a name-to-expression map d.
459 template<
typename XType>
460 static Function map2oracle(
const std::string& name,
const std::map<std::string, XType>& d);
/// Serialize the class-specific contents.
465 void serialize_body(SerializingStream &s)
const override;
/// Serialize the type information (plugin identity).
469 void serialize_type(SerializingStream &s)
const override;
/// Reconstruct an integrator from a stream.
474 static ProtoFunction* deserialize(DeserializingStream& s);
/// Base-function tag written during serialization.
479 std::string serialize_base_function()
const override {
return "Integrator"; }
/// Presumably whether input i is defined on the output time grid.
482 static bool grid_in(casadi_int i);
/// Presumably whether output i is defined on the output time grid.
485 static bool grid_out(casadi_int i);
/// Maps output i to its adjoint counterpart index (inferred from name).
488 static casadi_int adjmap_out(casadi_int i);
/// Deserializing constructor.
494 explicit Integrator(DeserializingStream& s);
/// Memory block for FixedStepIntegrator, extending IntegratorMemory with
/// step-specific buffers.
555 struct CASADI_EXPORT FixedStepMemory :
public IntegratorMemory {
// Step-function variables v (presumably sized by nv_ in FixedStepIntegrator),
// plus values from the previous step.
557 double *v, *v_prev, *q_prev;
// Backward-problem counterparts: rv, control adjoints, and previous-step
// adjoints of p and u (inferred from names).
560 double *rv, *adj_u, *adj_p_prev, *adj_u_prev;
// Stored forward trajectory of x and v, presumably replayed during the
// backward sweep; confirm.
563 double *x_tape, *v_tape;
/// Base class for integrators that advance with an explicit step function
/// (stepF/stepB) instead of an adaptive solver. Subclasses supply
/// setup_step(). nk_target_ and disc_ presumably describe the step grid.
566 class CASADI_EXPORT FixedStepIntegrator :
public Integrator {
/// Constructor. The arguments have the same meaning as for Integrator.
570 explicit FixedStepIntegrator(
const std::string& name,
const Function& dae,
571 double t0,
const std::vector<double>& tout);
/// Destructor.
574 ~FixedStepIntegrator()
override;
/// Options accepted by FixedStepIntegrator, returned by get_options().
580 static const Options options_;
581 const Options& get_options()
const override {
return options_;}
/// Initialize from an options dictionary.
585 void init(
const Dict& opts)
override;
/// Set up the memory block from the argument, result and work arrays.
590 void set_work(
void* mem,
const double**& arg,
double**& res,
591 casadi_int*& iw,
double*& w)
const override;
/// Override of the Integrator::create_advanced factory hook.
594 Function create_advanced(
const Dict& opts)
override;
/// Allocate a FixedStepMemory block; it is released by free_mem below.
599 void* alloc_mem()
const override {
return new FixedStepMemory();}
/// Initialize a memory block.
604 int init_mem(
void* mem)
const override;
/// Free a block from alloc_mem. The static_cast restores the concrete type
/// so that delete runs the correct destructor.
609 void free_mem(
void *mem)
const override {
delete static_cast<FixedStepMemory*
>(mem);}
/// Subclass hook: set up the step function.
612 virtual void setup_step() = 0;
/// Overrides of the Integrator forward/backward stepping interface.
617 void reset(IntegratorMemory* mem,
bool first_call)
const override;
622 int advance_noevent(IntegratorMemory* mem)
const override;
625 void resetB(IntegratorMemory* mem)
const override;
628 void impulseB(IntegratorMemory* mem,
629 const double* adj_x,
const double* adj_z,
const double* adj_q)
const override;
634 void retreat(IntegratorMemory* mem,
const double* u,
635 double* adj_x,
double* adj_p,
double* adj_u)
const override;
/// One forward step of length h starting at time t. It reads (x0, v0) and
/// writes xf, vf and qf (end-of-step state, step variables and quadrature,
/// judging by names).
638 void stepF(FixedStepMemory* m,
double t,
double h,
639 const double* x0,
const double* v0,
double* xf,
double* vf,
double* qf)
const;
/// One backward (adjoint) step matching stepF. It reads the forward values
/// and the seeds adj_xf/rv0 and writes adj_x0, adj_p and adj_u.
642 void stepB(FixedStepMemory* m,
double t,
double h,
643 const double* x0,
const double* xf,
const double* vf,
644 const double* adj_xf,
const double* rv0,
645 double* adj_x0,
double* adj_p,
double* adj_u)
const;
/// Target number of steps (inferred from name).
648 casadi_int nk_target_;
/// Presumably the step-index offset of each output interval; confirm.
651 std::vector<casadi_int> disc_;
/// Dimensions of the step variables v and their backward counterparts rv.
654 casadi_int nv_, nv1_, nrv_, nrv1_;
/// Serialize the class-specific contents.
659 void serialize_body(SerializingStream &s)
const override;
/// Deserializing constructor.
665 explicit FixedStepIntegrator(DeserializingStream& s);
/// FixedStepIntegrator variant for implicit step schemes (inferred from the
/// name). It adds its own options, initialization and serialization.
668 class CASADI_EXPORT ImplicitFixedStepIntegrator :
public FixedStepIntegrator {
/// Constructor. The arguments have the same meaning as for Integrator.
672 explicit ImplicitFixedStepIntegrator(
const std::string& name,
const Function& dae,
673 double t0,
const std::vector<double>& tout);
/// Destructor.
676 ~ImplicitFixedStepIntegrator()
override;
/// Options accepted by ImplicitFixedStepIntegrator, returned by get_options().
682 static const Options options_;
683 const Options& get_options()
const override {
return options_;}
/// Initialize from an options dictionary.
687 void init(
const Dict& opts)
override;
/// Serialize the class-specific contents.
692 void serialize_body(SerializingStream &s)
const override;
/// Deserializing constructor.
698 explicit ImplicitFixedStepIntegrator(DeserializingStream& s);
CASADI_EXPORT std::vector< std::string > integrator_in()
Get input scheme of integrators.
CASADI_EXPORT std::vector< std::string > integrator_out()
Get output scheme of integrators.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.