26 #ifndef CASADI_INTEGRATOR_IMPL_HPP
27 #define CASADI_INTEGRATOR_IMPL_HPP
29 #include "integrator.hpp"
30 #include "oracle_function.hpp"
31 #include "plugin_interface.hpp"
32 #include "casadi_enum.hpp"
/// Memory object used by Integrator's stepping hooks (reset, advance,
/// resetB, impulseB, retreat all take an IntegratorMemory*); extends OracleMemory.
/// NOTE(review): the struct's members are not part of this listing.
41 struct CASADI_EXPORT IntegratorMemory :
public OracleMemory {
/// Scratch memory passed as the first argument to Integrator's
/// *_sp_forward helpers (bit-vector propagation over bvec_t).
/// NOTE(review): members are not part of this listing.
53 struct CASADI_EXPORT SpForwardMem {
/// Scratch memory passed as the first argument to Integrator's
/// *_sp_reverse helpers (bit-vector propagation over bvec_t).
/// NOTE(review): members are not part of this listing.
61 struct CASADI_EXPORT SpReverseMem {
/// Base class for integrator implementations, built on OracleFunction and
/// registered through PluginInterface<Integrator>. It declares pure virtual
/// members (the destructor and the reset/advance/resetB/impulseB/retreat
/// hooks), so only concrete subclasses can be instantiated.
/// NOTE(review): the `class CASADI_EXPORT` keywords and the closing brace
/// were lost in this listing.
76 Integrator :
public OracleFunction,
public PluginInterface<Integrator> {
81 Integrator(
const std::string& name,
const Function& oracle,
82 double t0,
const std::vector<double>& tout);
87 ~Integrator()
override=0;
93 size_t get_n_in()
override {
return INTEGRATOR_NUM_IN;}
94 size_t get_n_out()
override {
return INTEGRATOR_NUM_OUT;}
101 Sparsity get_sparsity_in(casadi_int i)
override;
102 Sparsity get_sparsity_out(casadi_int i)
override;
109 std::string get_name_in(casadi_int i)
override {
return integrator_in(i);}
110 std::string get_name_out(casadi_int i)
override {
return integrator_out(i);}
116 int init_mem(
void* mem)
const override;
122 static const Options options_;
123 const Options& get_options()
const override {
return options_;}
129 void init(
const Dict& opts)
override;
132 virtual Function create_advanced(
const Dict& opts);
134 virtual MX algebraic_state_init(
const MX& x0,
const MX& z0)
const {
return z0; }
135 virtual MX algebraic_state_output(
const MX& Z)
const {
return Z; }
140 virtual void reset(IntegratorMemory* mem,
141 const double* u,
const double* x,
const double* z,
const double* p)
const = 0;
146 casadi_int next_stop(casadi_int k,
const double* u)
const;
151 virtual void advance(IntegratorMemory* mem,
152 const double* u,
double* x,
double* z,
double* q)
const = 0;
157 virtual void resetB(IntegratorMemory* mem)
const = 0;
162 casadi_int next_stopB(casadi_int k,
const double* u)
const;
167 virtual void impulseB(IntegratorMemory* mem,
168 const double* rx,
const double* rz,
const double* rp)
const = 0;
173 virtual void retreat(IntegratorMemory* mem,
const double* u,
174 double* rx,
double* rq,
double* uq)
const = 0;
179 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w,
void* mem)
const override;
184 virtual void print_stats(IntegratorMemory* mem)
const {}
187 int fdae_sp_forward(SpForwardMem* m,
const bvec_t* x,
188 const bvec_t* p,
const bvec_t* u, bvec_t* ode, bvec_t* alg)
const;
191 int fquad_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
192 const bvec_t* p,
const bvec_t* u, bvec_t* quad)
const;
195 int bdae_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
196 const bvec_t* p,
const bvec_t* u,
const bvec_t* rx,
const bvec_t* rp,
197 bvec_t* adj_x, bvec_t* adj_z)
const;
200 int bquad_sp_forward(SpForwardMem* m,
const bvec_t* x,
const bvec_t* z,
201 const bvec_t* p,
const bvec_t* u,
const bvec_t* rx,
const bvec_t* rz,
const bvec_t* rp,
202 bvec_t* adj_p, bvec_t* adj_u)
const;
207 int sp_forward(
const bvec_t** arg, bvec_t** res,
208 casadi_int* iw, bvec_t* w,
void* mem)
const override;
211 int fdae_sp_reverse(SpReverseMem* m, bvec_t* x,
212 bvec_t* p, bvec_t* u, bvec_t* ode, bvec_t* alg)
const;
215 int fquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
216 bvec_t* p, bvec_t* u, bvec_t* quad)
const;
219 int bdae_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
220 bvec_t* p, bvec_t* u, bvec_t* rx, bvec_t* rp,
221 bvec_t* adj_x, bvec_t* adj_z)
const;
224 int bquad_sp_reverse(SpReverseMem* m, bvec_t* x, bvec_t* z,
225 bvec_t* p, bvec_t* u, bvec_t* rx, bvec_t* rz, bvec_t* rp,
226 bvec_t* adj_p, bvec_t* adj_u)
const;
231 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w,
void* mem)
const override;
235 bool has_spfwd()
const override {
return true;}
236 bool has_sprev()
const override {
return true;}
243 Function get_forward(casadi_int nfwd,
const std::string& name,
244 const std::vector<std::string>& inames,
245 const std::vector<std::string>& onames,
246 const Dict& opts)
const override;
247 bool has_forward(casadi_int nfwd)
const override {
return true;}
254 Function get_reverse(casadi_int nadj,
const std::string& name,
255 const std::vector<std::string>& inames,
256 const std::vector<std::string>& onames,
257 const Dict& opts)
const override;
258 bool has_reverse(casadi_int nadj)
const override {
return true;}
264 virtual Dict getDerivativeOptions(
bool fwd)
const;
270 template<
typename MatType> Function get_forward_dae(
const std::string& name)
const;
271 Function augmented_dae()
const;
275 static bool all_zero(
const double* v, casadi_int n);
278 Sparsity sp_jac_aug(
const Sparsity& J,
const Sparsity& J1)
const;
281 Sparsity sp_jac_dae();
284 Sparsity sp_jac_rdae();
287 Sparsity sp_jac_dae_, sp_jac_rdae_;
290 inline casadi_int nt()
const {
return tout_.size();}
296 enum DaeOut { DAE_ODE, DAE_ALG, DAE_NUM_OUT};
297 static std::vector<std::string> dae_out() {
return {
"ode",
"alg"}; }
298 enum QuadOut { QUAD_QUAD, QUAD_NUM_OUT};
299 static std::vector<std::string> quad_out() {
return {
"quad"}; }
300 enum BDynIn { BDYN_T, BDYN_X, BDYN_Z, BDYN_P, BDYN_U,
301 BDYN_OUT_ODE, BDYN_OUT_ALG, BDYN_OUT_QUAD,
302 BDYN_ADJ_ODE, BDYN_ADJ_ALG, BDYN_ADJ_QUAD, BDYN_NUM_IN};
303 static std::string bdyn_in(casadi_int i);
304 static std::vector<std::string> bdyn_in();
305 enum BDynOut { BDYN_ADJ_T, BDYN_ADJ_X, BDYN_ADJ_Z, BDYN_ADJ_P, BDYN_ADJ_U, BDYN_NUM_OUT};
306 static std::string bdyn_out(casadi_int i);
307 static std::vector<std::string> bdyn_out();
308 enum DAEBOut { BDAE_ADJ_X, BDAE_ADJ_Z, BDAE_NUM_OUT};
309 static std::vector<std::string> bdae_out() {
return {
"adj_x",
"adj_z"}; }
310 enum QuadBOut { BQUAD_ADJ_P, BQUAD_ADJ_U, BQUAD_NUM_OUT};
311 static std::vector<std::string> bquad_out() {
return {
"adj_p",
"adj_u"}; }
318 std::vector<double> tout_;
321 casadi_int nfwd_, nadj_;
327 casadi_int nx_, nz_, nq_, nx1_, nz1_, nq1_;
330 casadi_int nrx_, nrz_, nrq_, nuq_, nrx1_, nrz1_, nrq1_, nuq1_;
333 casadi_int np_, nrp_, np1_, nrp1_;
336 casadi_int nu_, nu1_;
339 std::vector<double> nom_x_, nom_z_;
342 Dict augmented_options_;
351 typedef Integrator* (*Creator)(
const std::string& name,
const Function& oracle,
352 double t0,
const std::vector<double>& tout);
358 static std::map<std::string, Plugin> solvers_;
361 static const std::string infix_;
364 template<
typename XType>
365 static Function map2oracle(
const std::string& name,
const std::map<std::string, XType>& d);
370 void serialize_body(SerializingStream &s)
const override;
374 void serialize_type(SerializingStream &s)
const override;
379 static ProtoFunction* deserialize(DeserializingStream& s);
384 std::string serialize_base_function()
const override {
return "Integrator"; }
387 static bool grid_in(casadi_int i);
390 static bool grid_out(casadi_int i);
393 static casadi_int adjmap_out(casadi_int i);
399 explicit Integrator(DeserializingStream& s);
/// Memory object for FixedStepIntegrator: extends IntegratorMemory with the
/// raw double* work pointers declared below; allocated by
/// FixedStepIntegrator::alloc_mem and deleted by its free_mem.
460 struct CASADI_EXPORT FixedStepMemory :
public IntegratorMemory {
// Non-owning work pointers, listed in their original declaration order.
// Names mirror Integrator's x/z/q/p/u and backward rx/rz/rq/rp/uq;
// *_prev presumably hold previous-step values and *_tape stored trajectories.
double* x;
double* z;
double* rx;
double* rz;
double* rq;
double* x_prev;
double* rx_prev;
double* v;
double* p;
double* u;
double* q;
double* v_prev;
double* q_prev;
double* rv;
double* rp;
double* uq;
double* rq_prev;
double* uq_prev;
double* x_tape;
double* v_tape;
/// Integrator subclass that overrides reset/advance/resetB/impulseB/retreat
/// and implements them via per-step stepF/stepB. It adds the pure virtual
/// setup_step, so concrete subclasses must still provide that.
474 class CASADI_EXPORT FixedStepIntegrator :
public Integrator {
478 explicit FixedStepIntegrator(
const std::string& name,
const Function& dae,
479 double t0,
const std::vector<double>& tout);
482 ~FixedStepIntegrator()
override;
488 static const Options options_;
489 const Options& get_options()
const override {
return options_;}
493 void init(
const Dict& opts)
override;
498 void set_work(
void* mem,
const double**& arg,
double**& res,
499 casadi_int*& iw,
double*& w)
const override;
502 Function create_advanced(
const Dict& opts)
override;
507 void* alloc_mem()
const override {
return new FixedStepMemory();}
512 int init_mem(
void* mem)
const override;
517 void free_mem(
void *mem)
const override {
delete static_cast<FixedStepMemory*
>(mem);}
520 virtual void setup_step() = 0;
525 void reset(IntegratorMemory* mem,
526 const double* u,
const double* x,
const double* z,
const double* p)
const override;
531 void advance(IntegratorMemory* mem,
532 const double* u,
double* x,
double* z,
double* q)
const override;
535 void resetB(IntegratorMemory* mem)
const override;
538 void impulseB(IntegratorMemory* mem,
539 const double* rx,
const double* rz,
const double* rp)
const override;
544 void retreat(IntegratorMemory* mem,
const double* u,
545 double* rx,
double* rq,
double* uq)
const override;
548 void stepF(FixedStepMemory* m,
double t,
double h,
549 const double* x0,
const double* v0,
double* xf,
double* vf,
double* qf)
const;
552 void stepB(FixedStepMemory* m,
double t,
double h,
553 const double* x0,
const double* xf,
const double* vf,
554 const double* rx0,
const double* rv0,
555 double* rxf,
double* rqf,
double* uqf)
const;
558 casadi_int nk_target_;
561 std::vector<casadi_int> disc_;
564 casadi_int nv_, nv1_, nrv_, nrv1_;
569 void serialize_body(SerializingStream &s)
const override;
575 explicit FixedStepIntegrator(DeserializingStream& s);
/// FixedStepIntegrator subclass with its own option table, init and
/// serialization. It does not declare setup_step here, so it is presumably
/// still abstract -- verify against the full header.
578 class CASADI_EXPORT ImplicitFixedStepIntegrator :
public FixedStepIntegrator {
582 explicit ImplicitFixedStepIntegrator(
const std::string& name,
const Function& dae,
583 double t0,
const std::vector<double>& tout);
586 ~ImplicitFixedStepIntegrator()
override;
592 static const Options options_;
593 const Options& get_options()
const override {
return options_;}
597 void init(
const Dict& opts)
override;
602 void serialize_body(SerializingStream &s)
const override;
608 explicit ImplicitFixedStepIntegrator(DeserializingStream& s);
CASADI_EXPORT std::vector< std::string > integrator_in()
Get input scheme of integrators.
CASADI_EXPORT std::vector< std::string > integrator_out()
Get output scheme of integrators.
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.