26 #ifndef CASADI_EINSTEIN_HPP
27 #define CASADI_EINSTEIN_HPP
29 #include "mx_node.hpp"
47 const std::vector<casadi_int>& dim_c,
const std::vector<casadi_int>& dim_a,
48 const std::vector<casadi_int>& dim_b,
49 const std::vector<casadi_int>& c,
const std::vector<casadi_int>& a,
50 const std::vector<casadi_int>& b);
/// Overrides a base-class virtual and returns a std::string.
/// `arg` holds one string per entry; presumably these are the already-printed
/// dependencies that get embedded in the textual form of this node
/// (TODO confirm against the base-class documentation).
60 std::string disp(
const std::vector<std::string>& arg)
const override;
66 const std::vector<casadi_int>& arg,
67 const std::vector<casadi_int>& res,
68 const std::vector<bool>& arg_is_ref,
69 std::vector<bool>& res_is_ref)
const override;
/// Non-virtual evaluation routine parameterised on the element type `T`
/// (the `template` header is not visible in this listing).
/// Takes the same pointer-array layout as eval() and eval_sx() with `T` in
/// place of `double` / `SXElem`, so it is presumably their shared
/// implementation (NOTE(review): confirm in the .cpp file).
73 int eval_gen(
const T** arg, T** res, casadi_int* iw, T* w)
const;
/// Evaluation on `double` buffers; overrides a base-class virtual.
/// `arg` is a read-only array of input pointers and `res` an array of output
/// pointers; `iw` and `w` are presumably integer and real work vectors
/// (sz_w() in this class reports the required length of `w`).
/// The meaning of the int return value is not visible here; presumably a
/// status code (TODO confirm).
76 int eval(
const double** arg,
double** res, casadi_int* iw,
double* w)
const override;
/// Evaluation on `SXElem` buffers; overrides a base-class virtual.
/// Same parameter layout as eval(), with `SXElem` replacing `double` for the
/// inputs, outputs and the `w` work vector.
/// The meaning of the int return value is not visible here; presumably a
/// status code (TODO confirm).
79 int eval_sx(
const SXElem** arg,
SXElem** res, casadi_int* iw,
SXElem* w)
const override;
/// Evaluation on MX expressions; overrides a base-class virtual.
/// `arg` is read-only; `res` is taken by non-const reference and so is the
/// output container (whether it must be pre-sized by the caller is not
/// visible here; TODO confirm).
84 void eval_mx(
const std::vector<MX>& arg, std::vector<MX>& res)
const override;
/// Overrides a base-class virtual; presumably forward-mode derivative
/// propagation (inferred from the name; TODO confirm).
/// `fseed` is a read-only vector of MX vectors; `fsens` has the same nested
/// type and is taken by non-const reference, so it receives the results.
89 void ad_forward(
const std::vector<std::vector<MX> >& fseed,
90 std::vector<std::vector<MX> >& fsens)
const override;
/// Overrides a base-class virtual; presumably reverse-mode (adjoint)
/// derivative propagation (inferred from the name; TODO confirm).
/// `aseed` is a read-only vector of MX vectors; `asens` has the same nested
/// type and is taken by non-const reference, so it receives the results
/// (whether it is overwritten or accumulated into is not visible here).
95 void ad_reverse(
const std::vector<std::vector<MX> >& aseed,
96 std::vector<std::vector<MX> >& asens)
const override;
/// Overrides a base-class virtual operating on `bvec_t` buffers
/// (`bvec_t` is `unsigned long long` according to this listing).
/// Same parameter layout as eval(); presumably propagates dependency bit
/// patterns from inputs to outputs (inferred from the name; TODO confirm).
101 int sp_forward(
const bvec_t** arg,
bvec_t** res, casadi_int* iw,
bvec_t* w)
const override;
120 return sameOpAndDeps(node, depth) &&
dynamic_cast<const Einstein*
>(node)!=
nullptr;
/// Get required length of w field.
/// The length reported is the first dimension (`size1()`) of this node's
/// sparsity pattern.
126 size_t sz_w()
const override {
return sparsity().size1();}
130 return {{
"dim_a", dim_a_}, {
"dim_b", dim_b_}, {
"dim_c", dim_c_},
131 {
"a", a_}, {
"b", b_}, {
"c", c_},
132 {
"iter_dims", iter_dims_},
133 {
"strides_a", strides_a_}, {
"strides_b", strides_b_}, {
"strides_c", strides_c_},
134 {
"n_iter", n_iter_}};
// Integer vectors named after the constructor parameters dim_c / dim_a / dim_b
// seen earlier in this listing; presumably the stored dimension lists of the
// result (c) and the two operands (a, b). TODO confirm in the .cpp file.
// They are also exported under the keys "dim_c", "dim_a", "dim_b" by the
// dictionary-returning method in this listing.
138 std::vector<casadi_int> dim_c_,
dim_a_, dim_b_;
// Integer vectors named after the constructor parameters c / a / b;
// presumably the per-dimension index labels of each tensor (TODO confirm).
// Exported under the keys "c", "a", "b" by the same method.
140 std::vector<casadi_int> c_,
a_, b_;
Helper class for C code generation.
An MX atomic for an Einstein product.
size_t sz_w() const override
Get required length of w field.
~Einstein() override
Destructor.
Dict info() const override
std::vector< casadi_int > dim_a_
std::vector< casadi_int > strides_a_
casadi_int op() const override
Get the operation.
std::vector< casadi_int > iter_dims_
casadi_int n_inplace() const override
Can the operation be performed in place (i.e. overwrite the result)?
std::vector< casadi_int > strides_c_
bool is_equal(const MXNode *node, casadi_int depth) const override
Check if two nodes are equivalent up to a given depth.
std::vector< casadi_int > a_
std::vector< casadi_int > strides_b_
Node class for MX objects.
The basic scalar symbolic class of CasADi.
unsigned long long bvec_t
GenericType::Dict Dict
C++ equivalent of Python's dict or MATLAB's struct.