// casadi_blazing_3d_boor_eval.hpp
//
// MIT No Attribution
//
// Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl, KU Leuven.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this
// software and associated documentation files (the "Software"), to deal in the Software
// without restriction, including without limitation the rights to use, copy, modify,
// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//

20 // SYMBOL "blazing_3d_boor_eval"
21 template<typename T1>
22 void casadi_blazing_3d_boor_eval(T1* f, T1* J, T1* H, const T1* all_knots, const casadi_int* offset, const T1* c, const T1* dc, const T1* ddc, const T1* all_x, const casadi_int* lookup_mode, casadi_int* iw, T1* w) { // NOLINT(whitespace/line_length)
23  casadi_int n_dims = 3;
24  casadi_int m = 1;
25  casadi_int n_iter, k, i, pivot;
26  casadi_int *boor_offset, *starts, *index, *coeff_offset;
27  T1 *cumprod;
28  boor_offset = iw; iw+=n_dims+1;
29  starts = iw; iw+=n_dims;
30  index = iw; iw+=n_dims;
31  coeff_offset = iw;
32  cumprod = w; w+= n_dims+1;
33  boor_offset[0] = 0;
34  cumprod[n_dims] = 1;
35  coeff_offset[n_dims] = 0;
36 
37  casadi_int stride1 = offset[1]-offset[0]-4;
38  casadi_int stride2 = (offset[2]-offset[1]-4)*stride1;
39 
40  simde__m256d zero = simde_mm256_set1_pd(0.0);
41 
42  simde__m256d boor_start_0000 = zero;
43  simde__m256d boor_start_1111 = simde_mm256_set1_pd(1.0);
44  simde__m256d boor_start_0001 = simde_mm256_set_pd(1.0, 0.0, 0.0, 0.0);
45  simde__m256d boor_start_0010 = simde_mm256_set_pd(0.0, 1.0, 0.0, 0.0);
46 
47  simde__m256d boor0_d3;
48  simde__m256d boor0_d2;
49  simde__m256d boor0_d1;
50  simde__m256d boor0_d0;
51 
52  simde__m256d boor1_d3;
53  simde__m256d boor1_d2;
54  simde__m256d boor1_d1;
55  simde__m256d boor1_d0;
56 
57  simde__m256d boor2_d3;
58  simde__m256d boor2_d2;
59  simde__m256d boor2_d1;
60  simde__m256d boor2_d0;
61 
62  const T1* knots;
63  T1 x;
64  casadi_int degree, n_knots, n_b, L, start;
65  degree = 3;
66  knots = all_knots + offset[0];
67  n_knots = offset[0+1]-offset[0];
68  n_b = n_knots-degree-1;
69  x = all_x[0];
70  L = casadi_low(x, knots+degree, n_knots-2*degree, lookup_mode[0]);
71  start = L;
72  if (start>n_b-degree-1) start = n_b-degree-1;
73  starts[0] = start;
74  boor0_d3 = boor_start_0000;
75  if (x>=knots[0] && x<=knots[n_knots-1]) {
76  if (x==knots[1]) {
77  boor0_d3 = boor_start_1111;
78  } else if (x==knots[n_knots-1]) {
79  boor0_d3 = boor_start_0001;
80  } else if (knots[L+degree]==x) {
81  boor0_d3 = boor_start_0010;
82  } else {
83  boor0_d3 = boor_start_0001;
84  }
85  }
86  casadi_blazing_de_boor(x, knots+start, &boor0_d0, &boor0_d1, &boor0_d2, &boor0_d3);
87 
88  knots = all_knots + offset[1];
89  n_knots = offset[1+1]-offset[1];
90  n_b = n_knots-degree-1;
91  x = all_x[1];
92  L = casadi_low(x, knots+degree, n_knots-2*degree, lookup_mode[1]);
93  start = L;
94  if (start>n_b-degree-1) start = n_b-degree-1;
95  starts[1] = start;
96  boor1_d3 = boor_start_0000;
97  if (x>=knots[0] && x<=knots[n_knots-1]) {
98  if (x==knots[1]) {
99  boor1_d3 = boor_start_1111;
100  } else if (x==knots[n_knots-1]) {
101  boor1_d3 = boor_start_0001;
102  } else if (knots[L+degree]==x) {
103  boor1_d3 = boor_start_0010;
104  } else {
105  boor1_d3 = boor_start_0001;
106  }
107  }
108  casadi_blazing_de_boor(x, knots+start, &boor1_d0, &boor1_d1, &boor1_d2, &boor1_d3);
109 
110  knots = all_knots + offset[2];
111  n_knots = offset[2+1]-offset[2];
112  n_b = n_knots-degree-1;
113  x = all_x[2];
114  L = casadi_low(x, knots+degree, n_knots-2*degree, lookup_mode[2]);
115  start = L;
116  if (start>n_b-degree-1) start = n_b-degree-1;
117  starts[2] = start;
118  boor2_d3 = boor_start_0000;
119  if (x>=knots[0] && x<=knots[n_knots-1]) {
120  if (x==knots[1]) {
121  boor2_d3 = boor_start_1111;
122  } else if (x==knots[n_knots-1]) {
123  boor2_d3 = boor_start_0001;
124  } else if (knots[L+degree]==x) {
125  boor2_d3 = boor_start_0010;
126  } else {
127  boor2_d3 = boor_start_0001;
128  }
129  }
130  casadi_blazing_de_boor(x, knots+start, &boor2_d0, &boor2_d1, &boor2_d2, &boor2_d3);
131 
132  simde__m256d C[16];
133 
134  for (int j=0;j<4;++j) {
135  for (int k=0;k<4;++k) {
136  C[j+4*k] = simde_mm256_loadu_pd(c+(starts[1]+j)*stride1+(starts[2]+k)*stride2+starts[0]);
137  }
138  }
139 
140  simde__m256d a, b0, b1, b2, b3, c0, c1, c2, c3, r;
141  simde__m256d ab[4], cab[4];
142  simde__m128d r0, r1;
143 
144  a = boor0_d0;
145  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
146  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
147  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
148  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
149 
150  c0 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
151  c1 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
152  c2 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
153  c3 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
154 
155  // Need to compute sum_abc C_abc A_a B_b C_c
156 
157  // Step 1: Outer product a b: A_a B_b
158  ab[0] = simde_mm256_mul_pd(a, b0);
159  ab[1] = simde_mm256_mul_pd(a, b1);
160  ab[2] = simde_mm256_mul_pd(a, b2);
161  ab[3] = simde_mm256_mul_pd(a, b3);
162 
163  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
164  // cab <- cab + ab[i]*C[i]
165  for (int i=0;i<4;++i) {
166  cab[i] = simde_mm256_set1_pd(0);
167  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
168  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
169  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
170  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
171  }
172 
173  if (f) {
174  // Reduce over the c direction
175  r = simde_mm256_set1_pd(0);
176  r = simde_mm256_fmadd_pd(cab[0], c0, r);
177  r = simde_mm256_fmadd_pd(cab[1], c1, r);
178  r = simde_mm256_fmadd_pd(cab[2], c2, r);
179  r = simde_mm256_fmadd_pd(cab[3], c3, r);
180 
181  // Sum all r entries
182  r0 = simde_mm256_castpd256_pd128(r);
183  r1 = simde_mm256_extractf128_pd(r, 1);
184  r0 = simde_mm_add_pd(r0, r1);
185  f[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
186  }
187 
188  // First derivative
189  if (dc && J) {
190  stride1 = offset[1]-offset[0]-4-1;
191  stride2 = (offset[2]-offset[1]-4)*stride1;
192  for (int j=0;j<4;++j) {
193  for (int k=0;k<4;++k) {
194  C[j+4*k] = simde_mm256_loadu_pd(
195  dc+(starts[1]+j)*stride1+(starts[2]+k)*stride2+starts[0]-1);
196  }
197  }
198  dc += stride2*(offset[3]-offset[2]-4);
199 
200  a = boor0_d1;
201  ab[0] = simde_mm256_mul_pd(a, b0);
202  ab[1] = simde_mm256_mul_pd(a, b1);
203  ab[2] = simde_mm256_mul_pd(a, b2);
204  ab[3] = simde_mm256_mul_pd(a, b3);
205  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
206  // cab <- cab + ab[i]*C[i]
207  for (int i=0;i<4;++i) {
208  cab[i] = simde_mm256_set1_pd(0);
209  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
210  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
211  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
212  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
213  }
214 
215  // Reduce over the c direction
216  r = simde_mm256_set1_pd(0);
217  r = simde_mm256_fmadd_pd(cab[0], c0, r);
218  r = simde_mm256_fmadd_pd(cab[1], c1, r);
219  r = simde_mm256_fmadd_pd(cab[2], c2, r);
220  r = simde_mm256_fmadd_pd(cab[3], c3, r);
221 
222  // Sum all r entries
223  r0 = simde_mm256_castpd256_pd128(r);
224  r1 = simde_mm256_extractf128_pd(r, 1);
225  r0 = simde_mm_add_pd(r0, r1);
226  J[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
227 
228 
229  stride1 = offset[1]-offset[0]-4;
230  stride2 = (offset[2]-offset[1]-4-1)*stride1;
231  for (int j=0;j<4;++j) {
232  for (int k=0;k<4;++k) {
233  if (j==0) {
234  C[j+4*k] = zero;
235  } else {
236  C[j+4*k] = simde_mm256_loadu_pd(
237  dc+(starts[1]+j-1)*stride1+(starts[2]+k)*stride2+starts[0]);
238  }
239  }
240  }
241  dc += stride2*(offset[3]-offset[2]-4);
242 
243  a = boor0_d0;
244 
245  b0 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
246  b1 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
247  b2 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
248  b3 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
249 
250  ab[0] = simde_mm256_mul_pd(a, b0);
251  ab[1] = simde_mm256_mul_pd(a, b1);
252  ab[2] = simde_mm256_mul_pd(a, b2);
253  ab[3] = simde_mm256_mul_pd(a, b3);
254 
255  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
256  // cab <- cab + ab[i]*C[i]
257  for (int i=0;i<4;++i) {
258  cab[i] = simde_mm256_set1_pd(0);
259  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
260  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
261  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
262  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
263  }
264 
265  // Reduce over the c direction
266  r = simde_mm256_set1_pd(0);
267  r = simde_mm256_fmadd_pd(cab[0], c0, r);
268  r = simde_mm256_fmadd_pd(cab[1], c1, r);
269  r = simde_mm256_fmadd_pd(cab[2], c2, r);
270  r = simde_mm256_fmadd_pd(cab[3], c3, r);
271 
272  // Sum all r entries
273  r0 = simde_mm256_castpd256_pd128(r);
274  r1 = simde_mm256_extractf128_pd(r, 1);
275  r0 = simde_mm_add_pd(r0, r1);
276  J[1] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
277 
278  stride1 = offset[1]-offset[0]-4;
279  stride2 = (offset[2]-offset[1]-4)*stride1;
280  for (int j=0;j<4;++j) {
281  for (int k=0;k<4;++k) {
282  if (k==0) {
283  C[j+4*k] = zero;
284  } else {
285  C[j+4*k] = simde_mm256_loadu_pd(
286  dc+(starts[1]+j)*stride1+(starts[2]+k-1)*stride2+starts[0]);
287  }
288  }
289  }
290 
291  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
292  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
293  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
294  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
295 
296  c0 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
297  c1 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
298  c2 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
299  c3 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
300 
301  ab[0] = simde_mm256_mul_pd(a, b0);
302  ab[1] = simde_mm256_mul_pd(a, b1);
303  ab[2] = simde_mm256_mul_pd(a, b2);
304  ab[3] = simde_mm256_mul_pd(a, b3);
305 
306  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
307  // cab <- cab + ab[i]*C[i]
308  for (int i=0;i<4;++i) {
309  cab[i] = simde_mm256_set1_pd(0);
310  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
311  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
312  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
313  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
314  }
315 
316  // Reduce over the c direction
317  r = simde_mm256_set1_pd(0);
318  r = simde_mm256_fmadd_pd(cab[0], c0, r);
319  r = simde_mm256_fmadd_pd(cab[1], c1, r);
320  r = simde_mm256_fmadd_pd(cab[2], c2, r);
321  r = simde_mm256_fmadd_pd(cab[3], c3, r);
322 
323  // Sum all r entries
324  r0 = simde_mm256_castpd256_pd128(r);
325  r1 = simde_mm256_extractf128_pd(r, 1);
326  r0 = simde_mm_add_pd(r0, r1);
327  J[2] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
328 
329  }
330 
331  if (ddc && H) {
332  stride1 = offset[1]-offset[0]-4-2;
333  stride2 = (offset[2]-offset[1]-4)*stride1;
334  for (int j=0;j<4;++j) {
335  for (int k=0;k<4;++k) {
336  C[j+4*k] = simde_mm256_loadu_pd(
337  ddc+(starts[1]+j)*stride1+(starts[2]+k)*stride2+starts[0]-2);
338  }
339  }
340  ddc += stride2*(offset[3]-offset[2]-4);
341 
342  a = boor0_d2;
343  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
344  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
345  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
346  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
347 
348  c0 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
349  c1 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
350  c2 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
351  c3 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
352 
353  ab[0] = simde_mm256_mul_pd(a, b0);
354  ab[1] = simde_mm256_mul_pd(a, b1);
355  ab[2] = simde_mm256_mul_pd(a, b2);
356  ab[3] = simde_mm256_mul_pd(a, b3);
357  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
358  // cab <- cab + ab[i]*C[i]
359  for (int i=0;i<4;++i) {
360  cab[i] = simde_mm256_set1_pd(0);
361  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
362  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
363  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
364  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
365  }
366 
367  // Reduce over the c direction
368  r = simde_mm256_set1_pd(0);
369  r = simde_mm256_fmadd_pd(cab[0], c0, r);
370  r = simde_mm256_fmadd_pd(cab[1], c1, r);
371  r = simde_mm256_fmadd_pd(cab[2], c2, r);
372  r = simde_mm256_fmadd_pd(cab[3], c3, r);
373 
374  // Sum all r entries
375  r0 = simde_mm256_castpd256_pd128(r);
376  r1 = simde_mm256_extractf128_pd(r, 1);
377  r0 = simde_mm_add_pd(r0, r1);
378  H[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
379 
380  stride1 = offset[1]-offset[0]-4;
381  stride2 = (offset[2]-offset[1]-4-2)*stride1;
382  for (int j=0;j<4;++j) {
383  for (int k=0;k<4;++k) {
384  if (j<=1) {
385  C[j+4*k] = zero;
386  } else {
387  C[j+4*k] = simde_mm256_loadu_pd(
388  ddc+(starts[1]+j-2)*stride1+(starts[2]+k)*stride2+starts[0]);
389  }
390  }
391  }
392  ddc += stride2*(offset[3]-offset[2]-4);
393 
394  a = boor0_d0;
395  b0 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
396  b1 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
397  b2 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
398  b3 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
399 
400  c0 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
401  c1 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
402  c2 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
403  c3 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
404 
405  ab[0] = simde_mm256_mul_pd(a, b0);
406  ab[1] = simde_mm256_mul_pd(a, b1);
407  ab[2] = simde_mm256_mul_pd(a, b2);
408  ab[3] = simde_mm256_mul_pd(a, b3);
409  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
410  // cab <- cab + ab[i]*C[i]
411  for (int i=0;i<4;++i) {
412  cab[i] = simde_mm256_set1_pd(0);
413  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
414  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
415  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
416  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
417  }
418 
419  // Reduce over the c direction
420  r = simde_mm256_set1_pd(0);
421  r = simde_mm256_fmadd_pd(cab[0], c0, r);
422  r = simde_mm256_fmadd_pd(cab[1], c1, r);
423  r = simde_mm256_fmadd_pd(cab[2], c2, r);
424  r = simde_mm256_fmadd_pd(cab[3], c3, r);
425 
426  // Sum all r entries
427  r0 = simde_mm256_castpd256_pd128(r);
428  r1 = simde_mm256_extractf128_pd(r, 1);
429  r0 = simde_mm_add_pd(r0, r1);
430  H[4] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
431 
432  stride1 = offset[1]-offset[0]-4;
433  stride2 = (offset[2]-offset[1]-4)*stride1;
434  for (int j=0;j<4;++j) {
435  for (int k=0;k<4;++k) {
436  if (k<=1) {
437  C[j+4*k] = zero;
438  } else {
439  C[j+4*k] = simde_mm256_loadu_pd(
440  ddc+(starts[1]+j)*stride1+(starts[2]+k-2)*stride2+starts[0]);
441  }
442  }
443  }
444  ddc += stride2*(offset[3]-offset[2]-4-2);
445 
446  a = boor0_d0;
447  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
448  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
449  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
450  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
451 
452  c0 = simde_mm256_permute4x64_pd(boor2_d2, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
453  c1 = simde_mm256_permute4x64_pd(boor2_d2, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
454  c2 = simde_mm256_permute4x64_pd(boor2_d2, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
455  c3 = simde_mm256_permute4x64_pd(boor2_d2, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
456 
457  ab[0] = simde_mm256_mul_pd(a, b0);
458  ab[1] = simde_mm256_mul_pd(a, b1);
459  ab[2] = simde_mm256_mul_pd(a, b2);
460  ab[3] = simde_mm256_mul_pd(a, b3);
461  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
462  // cab <- cab + ab[i]*C[i]
463  for (int i=0;i<4;++i) {
464  cab[i] = simde_mm256_set1_pd(0);
465  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
466  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
467  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
468  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
469  }
470 
471  // Reduce over the c direction
472  r = simde_mm256_set1_pd(0);
473  r = simde_mm256_fmadd_pd(cab[0], c0, r);
474  r = simde_mm256_fmadd_pd(cab[1], c1, r);
475  r = simde_mm256_fmadd_pd(cab[2], c2, r);
476  r = simde_mm256_fmadd_pd(cab[3], c3, r);
477 
478  // Sum all r entries
479  r0 = simde_mm256_castpd256_pd128(r);
480  r1 = simde_mm256_extractf128_pd(r, 1);
481  r0 = simde_mm_add_pd(r0, r1);
482  H[8] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
483 
484  stride1 = offset[1]-offset[0]-5;
485  stride2 = (offset[2]-offset[1]-5)*stride1;
486  for (int j=0;j<4;++j) {
487  for (int k=0;k<4;++k) {
488  if (j==0) {
489  C[j+4*k] = zero;
490  } else {
491  C[j+4*k] = simde_mm256_loadu_pd(
492  ddc+(starts[1]+j-1)*stride1+(starts[2]+k)*stride2+starts[0]-1);
493  }
494  }
495  }
496  ddc += stride2*(offset[3]-offset[2]-4);
497 
498  a = boor0_d1;
499 
500  b0 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
501  b1 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
502  b2 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
503  b3 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
504 
505  c0 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
506  c1 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
507  c2 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
508  c3 = simde_mm256_permute4x64_pd(boor2_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
509 
510  ab[0] = simde_mm256_mul_pd(a, b0);
511  ab[1] = simde_mm256_mul_pd(a, b1);
512  ab[2] = simde_mm256_mul_pd(a, b2);
513  ab[3] = simde_mm256_mul_pd(a, b3);
514  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
515  // cab <- cab + ab[i]*C[i]
516  for (int i=0;i<4;++i) {
517  cab[i] = simde_mm256_set1_pd(0);
518  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
519  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
520  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
521  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
522  }
523 
524  // Reduce over the c direction
525  r = simde_mm256_set1_pd(0);
526  r = simde_mm256_fmadd_pd(cab[0], c0, r);
527  r = simde_mm256_fmadd_pd(cab[1], c1, r);
528  r = simde_mm256_fmadd_pd(cab[2], c2, r);
529  r = simde_mm256_fmadd_pd(cab[3], c3, r);
530 
531  // Sum all r entries
532  r0 = simde_mm256_castpd256_pd128(r);
533  r1 = simde_mm256_extractf128_pd(r, 1);
534  r0 = simde_mm_add_pd(r0, r1);
535  H[1] = H[3] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
536 
537  stride1 = offset[1]-offset[0]-4;
538  stride2 = (offset[2]-offset[1]-5)*stride1;
539  for (int j=0;j<4;++j) {
540  for (int k=0;k<4;++k) {
541  if (k==0) {
542  C[j+4*k] = zero;
543  } else {
544  C[j+4*k] = simde_mm256_loadu_pd(
545  ddc+(starts[1]+j-1)*stride1+(starts[2]+k-1)*stride2+starts[0]);
546  }
547  }
548  }
549  ddc += stride2*(offset[3]-offset[2]-5);
550 
551  a = boor0_d0;
552 
553  b0 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
554  b1 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
555  b2 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
556  b3 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
557 
558  c0 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
559  c1 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
560  c2 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
561  c3 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
562 
563  ab[0] = simde_mm256_mul_pd(a, b0);
564  ab[1] = simde_mm256_mul_pd(a, b1);
565  ab[2] = simde_mm256_mul_pd(a, b2);
566  ab[3] = simde_mm256_mul_pd(a, b3);
567  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
568  // cab <- cab + ab[i]*C[i]
569  for (int i=0;i<4;++i) {
570  cab[i] = simde_mm256_set1_pd(0);
571  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
572  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
573  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
574  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
575  }
576 
577  // Reduce over the c direction
578  r = simde_mm256_set1_pd(0);
579  r = simde_mm256_fmadd_pd(cab[0], c0, r);
580  r = simde_mm256_fmadd_pd(cab[1], c1, r);
581  r = simde_mm256_fmadd_pd(cab[2], c2, r);
582  r = simde_mm256_fmadd_pd(cab[3], c3, r);
583 
584  // Sum all r entries
585  r0 = simde_mm256_castpd256_pd128(r);
586  r1 = simde_mm256_extractf128_pd(r, 1);
587  r0 = simde_mm_add_pd(r0, r1);
588  H[5] = H[7] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
589 
590 
591 
592  stride1 = offset[1]-offset[0]-5;
593  stride2 = (offset[2]-offset[1]-4)*stride1;
594  for (int j=0;j<4;++j) {
595  for (int k=0;k<4;++k) {
596  if (k==0) {
597  C[j+4*k] = zero;
598  } else {
599  C[j+4*k] = simde_mm256_loadu_pd(
600  ddc+(starts[1]+j)*stride1+(starts[2]+k-1)*stride2+starts[0]-1);
601  }
602  }
603  }
604  ddc += stride2*(offset[3]-offset[2]-5);
605 
606  a = boor0_d1;
607 
608  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
609  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
610  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
611  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
612 
613  c0 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
614  c1 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
615  c2 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
616  c3 = simde_mm256_permute4x64_pd(boor2_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));
617 
618  ab[0] = simde_mm256_mul_pd(a, b0);
619  ab[1] = simde_mm256_mul_pd(a, b1);
620  ab[2] = simde_mm256_mul_pd(a, b2);
621  ab[3] = simde_mm256_mul_pd(a, b3);
622  // Sum over b axis: sum_b C_abc * (A_a B_b)_b
623  // cab <- cab + ab[i]*C[i]
624  for (int i=0;i<4;++i) {
625  cab[i] = simde_mm256_set1_pd(0);
626  cab[i] = simde_mm256_fmadd_pd(ab[0], C[4*i+0], cab[i]);
627  cab[i] = simde_mm256_fmadd_pd(ab[1], C[4*i+1], cab[i]);
628  cab[i] = simde_mm256_fmadd_pd(ab[2], C[4*i+2], cab[i]);
629  cab[i] = simde_mm256_fmadd_pd(ab[3], C[4*i+3], cab[i]);
630  }
631 
632  // Reduce over the c direction
633  r = simde_mm256_set1_pd(0);
634  r = simde_mm256_fmadd_pd(cab[0], c0, r);
635  r = simde_mm256_fmadd_pd(cab[1], c1, r);
636  r = simde_mm256_fmadd_pd(cab[2], c2, r);
637  r = simde_mm256_fmadd_pd(cab[3], c3, r);
638 
639  // Sum all r entries
640  r0 = simde_mm256_castpd256_pd128(r);
641  r1 = simde_mm256_extractf128_pd(r, 1);
642  r0 = simde_mm_add_pd(r0, r1);
643  H[2] = H[6] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
644  }
645 }