25 #include "casadi_runtime.hpp"
26 #include "../casadi_misc.hpp"
30 #ifndef CASADI_CASADI_SHARED_HPP
31 #define CASADI_CASADI_SHARED_HPP
// NOTE(review): fragmentary view of einstein_process. The leading part of the
// signature (the operands A, B, C used below), several body lines and all
// closing braces are not visible in this listing. The comments describe only
// what the visible lines show.
//
// Visible parameters:
//   dim_a, dim_b, dim_c : dimension lists; product(dim_x) must equal the
//                         numel() of the matching operand.
//   a, b, c             : index lists; a and b must be as long as dim_a and
//                         dim_b, and c no longer than a and b combined.
//   iter_dims           : output, appended with one dimension per dim_map entry.
//   strides_a/b/c       : outputs, resized to iter_dims.size()+1; slot 0
//                         accumulates x[j]*cumprod terms, slots 1.. receive
//                         cumprod values addressed through the lookup vector.
36 const std::vector<casadi_int>& dim_a,
const std::vector<casadi_int>& dim_b,
const std::vector<casadi_int>& dim_c,
37 const std::vector<casadi_int>& a,
const std::vector<casadi_int>& b,
const std::vector<casadi_int>& c,
38 std::vector<casadi_int>& iter_dims,
39 std::vector<casadi_int>& strides_a, std::vector<casadi_int>& strides_b, std::vector<casadi_int>& strides_c
// All three operands must be dense vectors.
42 casadi_assert_dev(A.is_vector() && A.is_dense());
43 casadi_assert_dev(B.is_vector() && B.is_dense());
44 casadi_assert_dev(C.is_vector() && C.is_dense());
// The element count of each operand must match the product of its dimensions.
47 casadi_assert_dev(A.numel()==product(dim_a));
48 casadi_assert_dev(B.numel()==product(dim_b));
49 casadi_assert_dev(C.numel()==product(dim_c));
// One index entry per dimension for a and b.
// NOTE(review): no matching check for c/dim_c is visible here — confirm.
51 casadi_assert_dev(dim_a.size()==a.size());
52 casadi_assert_dev(dim_b.size()==b.size());
54 casadi_assert_dev(c.size()<=a.size()+b.size());
// Maps an index label to the dimension recorded for it.
56 std::map<casadi_int, casadi_int> dim_map;
// Record the dimension of each label used by a; a label that was already
// recorded must have the same dimension.
// NOTE(review): the definition of ai is not visible; presumably a[i].
59 for (casadi_int i=0;i<a.size();++i) {
62 auto al = dim_map.find(ai);
63 if (al==dim_map.end()) {
64 dim_map[ai] = dim_a[i];
66 casadi_assert_dev(al->second==dim_a[i]);
// Same consistency pass for b (bi presumably b[i] — not visible).
70 for (casadi_int i=0;i<b.size();++i) {
73 auto bl = dim_map.find(bi);
74 if (bl==dim_map.end()) {
75 dim_map[bi] = dim_b[i];
77 casadi_assert_dev(bl->second==dim_b[i]);
// Same consistency pass for c (ci presumably c[i] — not visible).
81 for (casadi_int i=0;i<c.size();++i) {
84 auto cl = dim_map.find(ci);
85 if (cl==dim_map.end()) {
86 dim_map[ci] = dim_c[i];
88 casadi_assert_dev(cl->second==dim_c[i]);
// Copy the (label, dimension) entries into a vector so they can be reordered.
92 std::vector< std::pair<casadi_int, casadi_int> > dim_map_pair;
93 for (
const auto & i : dim_map) dim_map_pair.push_back(i);
// Order the entries by ascending dimension (the mapped value).
95 std::sort(dim_map_pair.begin(), dim_map_pair.end(),
96 [](
const std::pair<casadi_int, casadi_int>& a,
const std::pair<casadi_int, casadi_int>& b) { return a.second < b.second;});
98 std::vector<casadi_int> dim_map_keys;
// NOTE(review): n_iter starts at 1; its update is not visible in this listing
// (presumably it accumulates the product of iter_dims) — confirm.
100 casadi_int n_iter = 1;
// In sorted order: keep the negated label as a lookup key and append the
// label's dimension to iter_dims.
101 for (
const auto& e : dim_map_pair) {
103 dim_map_keys.push_back(-e.first);
104 iter_dims.push_back(e.second);
// One extra leading slot (index 0) per stride vector, in addition to one slot
// per iteration dimension.
108 strides_a.resize(iter_dims.size()+1);
110 strides_b.resize(iter_dims.size()+1);
112 strides_c.resize(iter_dims.size()+1);
// Lookup vector built from the negated labels; left empty when there are no keys.
114 std::vector<casadi_int> lu;
116 if (!dim_map_keys.empty()) lu =
lookupvector(dim_map_keys);
// NOTE(review): in each loop below both assignments are shown back to back;
// presumably a hidden branch on the sign of the index entry selects one of
// them (lu is indexed with the negated entry). The per-iteration update of
// cumprod, and any reset between the three loops, is also not visible.
119 casadi_int cumprod = 1;
120 for (casadi_int j=0;j<a.size();++j) {
122 strides_a[1+lu[-a[j]]] = cumprod;
124 strides_a[0]+=a[j]*cumprod;
129 for (casadi_int j=0;j<b.size();++j) {
131 strides_b[1+lu[-b[j]]] = cumprod;
133 strides_b[0]+=b[j]*cumprod;
138 for (casadi_int j=0;j<c.size();++j) {
140 strides_c[1+lu[-c[j]]] = cumprod;
142 strides_c[0]+=c[j]*cumprod;
// NOTE(review): fragmentary view of einstein_eval. The leading part of the
// signature (including n_iter), the guards on n, the declarations of sub, a,
// b, c, a3, b3, c3 and all closing braces are not visible in this listing.
// The comments describe only what the visible lines show.
//
// Visible parameters:
//   iter_dims     : sizes of the iteration dimensions; the last three are
//                   handled by dedicated nested loops.
//   strides_a/b/c : entry 0 is an offset added to the data pointer; entries
//                   1.. are per-dimension strides.
//   a_in, b_in    : read-only input buffers.
//   c_in          : writable buffer.
162 const std::vector<casadi_int>& iter_dims,
163 const std::vector<casadi_int>& strides_a,
const std::vector<casadi_int>& strides_b,
const std::vector<casadi_int>& strides_c,
164 const T* a_in,
const T* b_in, T* c_in) {
// Extents of the three innermost loops; a value of 1 gives a single pass.
168 casadi_int iter_dim1 = 1, iter_dim2 = 1, iter_dim3 = 1;
170 casadi_int n = iter_dims.size();
// Strides of the three innermost loops, defaulting to 0.
172 casadi_int stridea1=0, strideb1=0, stridec1=0;
173 casadi_int stridea2=0, strideb2=0, stridec2=0;
174 casadi_int stridea3=0, strideb3=0, stridec3=0;
// Innermost loop takes the last iteration dimension.
// NOTE(review): presumably guarded by a check on n that is not visible here.
176 iter_dim3 = iter_dims[n-1];
177 stridea3 = strides_a[n];
178 strideb3 = strides_b[n];
179 stridec3 = strides_c[n];
// Middle loop takes the second-to-last dimension (guard not visible).
182 iter_dim2 = iter_dims[n-2];
183 stridea2 = strides_a[n-1];
184 strideb2 = strides_b[n-1];
185 stridec2 = strides_c[n-1];
// Outer explicit loop takes the third-to-last dimension (guard not visible).
188 iter_dim1 = iter_dims[n-3];
189 stridea1 = strides_a[n-2];
190 strideb1 = strides_b[n-2];
191 stridec1 = strides_c[n-2];
// Raw pointers for the generic dimension loop; the +1 skips the offset slot so
// that ptr_strides_x[j] pairs with ptr_iter_dims[j].
195 const casadi_int* ptr_iter_dims = get_ptr(iter_dims);
197 const casadi_int *ptr_strides_a = get_ptr(strides_a)+1;
198 const casadi_int *ptr_strides_b = get_ptr(strides_b)+1;
199 const casadi_int *ptr_strides_c = get_ptr(strides_c)+1;
// Apply the offset stored in slot 0 to each data pointer.
202 const T* a_perm = a_in+strides_a[0];
203 const T* b_perm = b_in+strides_b[0];
204 T* c_perm = c_in+strides_c[0];
// The three innermost dimensions are iterated explicitly below, so divide
// their extents out of the flat iteration count.
206 n_iter/= iter_dim1*iter_dim2*iter_dim3;
209 for (casadi_int i=0;i<n_iter;++i) {
// Decompose the flat counter over the first n-3 dimensions and accumulate
// the matching stride offsets.
// NOTE(review): sub, a, b and c are declared on lines not visible here;
// presumably sub starts at i and a/b/c at a_perm/b_perm/c_perm.
218 for (casadi_int j=0;j<n-3;++j) {
219 casadi_int ind = sub % ptr_iter_dims[j];
220 a+= ptr_strides_a[j]*ind;
221 b+= ptr_strides_b[j]*ind;
222 c+= ptr_strides_c[j]*ind;
223 sub/= ptr_iter_dims[j];
// Explicit triple loop over the three innermost dimensions.
// NOTE(review): the pointer set-up and increments using stridea1..stridec3
// are not visible in this listing.
229 for (casadi_int i1=0;i1<iter_dim1;++i1) {
233 for (casadi_int i2=0;i2<iter_dim2;++i2) {
237 for (casadi_int i3=0;i3<iter_dim3;++i3) {
// Invoke Contraction<T> on the current elements.
239 Contraction<T>(*a3, *b3, *c3);
void einstein_eval(casadi_int n_iter, const std::vector< casadi_int > &iter_dims, const std::vector< casadi_int > &strides_a, const std::vector< casadi_int > &strides_b, const std::vector< casadi_int > &strides_c, const T *a_in, const T *b_in, T *c_in)
casadi_int einstein_process(const T &A, const T &B, const T &C, const std::vector< casadi_int > &dim_a, const std::vector< casadi_int > &dim_b, const std::vector< casadi_int > &dim_c, const std::vector< casadi_int > &a, const std::vector< casadi_int > &b, const std::vector< casadi_int > &c, std::vector< casadi_int > &iter_dims, std::vector< casadi_int > &strides_a, std::vector< casadi_int > &strides_b, std::vector< casadi_int > &strides_c)
void Contraction(const T &a, const T &b, T &r)
CASADI_EXPORT std::vector< casadi_int > lookupvector(const std::vector< casadi_int > &v, casadi_int size)
Returns a vector for quickly looking up entries of the supplied list.