26 #ifndef CASADI_EINSTEIN_HPP
27 #define CASADI_EINSTEIN_HPP
29 #include "mx_node.hpp"
40 class CASADI_EXPORT Einstein :
public MXNode {
/** \brief Constructor
 *
 * Takes three MX operands (C, A, B), one dimension vector per operand
 * (dim_c, dim_a, dim_b) and one index vector per operand (c, a, b).
 *
 * NOTE(review): presumably builds the Einstein-summation node "C += A*B",
 * with c/a/b holding the index labels of each tensor -- confirm against the
 * implementation file.
 */
46 Einstein(
const MX& C,
const MX& A,
const MX& B,
47 const std::vector<casadi_int>& dim_c,
const std::vector<casadi_int>& dim_a,
48 const std::vector<casadi_int>& dim_b,
49 const std::vector<casadi_int>& c,
const std::vector<casadi_int>& a,
50 const std::vector<casadi_int>& b);
55 ~Einstein()
override {}
/** \brief Build a string for this node from the strings of its arguments
 *
 * \param arg one string per argument
 * \return the resulting string
 *
 * NOTE(review): presumably the human-readable representation used when
 * printing expressions -- confirm against the implementation file.
 */
60 std::string disp(
const std::vector<std::string>& arg)
const override;
/** \brief Code-generation hook
 *
 * \param g   code generator to write to
 * \param arg integer identifiers of the inputs
 * \param res integer identifiers of the outputs
 *
 * NOTE(review): presumably arg/res are work-vector indices and the method
 * emits C code for the operation -- confirm against the implementation file.
 */
65 void generate(CodeGenerator& g,
66 const std::vector<casadi_int>& arg,
67 const std::vector<casadi_int>& res)
const override;
71 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
/** \brief Evaluate on double-precision buffers
 *
 * \param arg input buffers
 * \param res output buffers
 * \param iw  integer work vector
 * \param w   double work vector
 * \return integer status (NOTE(review): presumably 0 on success -- confirm)
 */
74 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/** \brief Evaluate on SXElem buffers
 *
 * Same argument layout as eval(), with SXElem in place of double.
 *
 * \return integer status (NOTE(review): presumably 0 on success -- confirm)
 */
77 int eval_sx(
const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w)
const override;
/** \brief Evaluate on MX expressions
 *
 * \param arg input expressions
 * \param res output expressions, written by this method
 */
82 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/** \brief Forward-mode derivative hook
 *
 * \param fseed forward seeds (outer vector: NOTE(review) presumably one entry
 *              per derivative direction -- confirm)
 * \param fsens forward sensitivities, written by this method
 */
87 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
88 std::vector<std::vector<MX> >& fsens)
const override;
/** \brief Reverse-mode (adjoint) derivative hook
 *
 * \param aseed adjoint seeds (outer vector: NOTE(review) presumably one entry
 *              per derivative direction -- confirm)
 * \param asens adjoint sensitivities, written by this method
 */
93 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
94 std::vector<std::vector<MX> >& asens)
const override;
/** \brief Forward propagation on bvec_t buffers
 *
 * Same argument layout as eval(), with bvec_t in place of double.
 * NOTE(review): presumably propagates dependency/sparsity bit patterns from
 * inputs to outputs -- confirm against the implementation file.
 *
 * \return integer status (NOTE(review): presumably 0 on success -- confirm)
 */
99 int sp_forward(
const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
/** \brief Reverse propagation on bvec_t buffers
 *
 * Unlike sp_forward(), the argument buffers are non-const here.
 * NOTE(review): presumably propagates dependency/sparsity bit patterns from
 * outputs back to inputs -- confirm against the implementation file.
 *
 * \return integer status (NOTE(review): presumably 0 on success -- confirm)
 */
104 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w)
const override;
109 casadi_int op()
const override {
return OP_EINSTEIN;}
112 casadi_int n_inplace()
const override {
return 1;}
117 bool is_equal(
const MXNode* node, casadi_int depth)
const override {
118 return sameOpAndDeps(node, depth) &&
dynamic_cast<const Einstein*
>(node)!=
nullptr;
124 size_t sz_w()
const override {
return sparsity().size1();}
127 Dict info()
const override {
128 return {{
"dim_a", dim_a_}, {
"dim_b", dim_b_}, {
"dim_c", dim_c_},
129 {
"a", a_}, {
"b", b_}, {
"c", c_},
130 {
"iter_dims", iter_dims_},
131 {
"strides_a", strides_a_}, {
"strides_b", strides_b_}, {
"strides_c", strides_c_},
132 {
"n_iter", n_iter_}};
// Dimension vectors for C, A and B (exported by info() as "dim_c", "dim_a",
// "dim_b"). NOTE(review): presumably the constructor's dim_c/dim_a/dim_b -- confirm.
136 std::vector<casadi_int> dim_c_, dim_a_, dim_b_;
// Index vectors for C, A and B (exported by info() as "c", "a", "b").
// NOTE(review): presumably the constructor's c/a/b index labels -- confirm.
138 std::vector<casadi_int> c_, a_, b_;
// Exported by info() as "iter_dims".
// NOTE(review): presumably the extents of the iteration space -- confirm.
140 std::vector<casadi_int> iter_dims_;
// Exported by info() as "strides_a" / "strides_b" / "strides_c".
// NOTE(review): presumably per-iteration-dimension strides into A, B and C -- confirm.
142 std::vector<casadi_int> strides_a_;
143 std::vector<casadi_int> strides_b_;
144 std::vector<casadi_int> strides_c_;
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.