casadi_blazing_2d_boor_eval.hpp
1 //
2 // MIT No Attribution
3 //
4 // Copyright (C) 2010-2023 Joel Andersson, Joris Gillis, Moritz Diehl, KU Leuven.
5 //
6 // Permission is hereby granted, free of charge, to any person obtaining a copy of this
7 // software and associated documentation files (the "Software"), to deal in the Software
8 // without restriction, including without limitation the rights to use, copy, modify,
9 // merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
10 // permit persons to whom the Software is furnished to do so.
11 //
12 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
13 // INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
14 // PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
15 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
16 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
17 // SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18 //
19 
// SYMBOL "blazing_2d_boor_eval"
// Evaluate a two-dimensional tensor-product B-spline of hard-coded degree 3
// at the point all_x, using 4-lane SIMD (SIMDe/AVX) double arithmetic: the
// four nonzero cubic basis functions of each dimension fit in one simde__m256d.
//
// Outputs (each optional):
//   f  - spline value (1 entry); skipped when null
//   J  - gradient [df/dx0, df/dx1] (2 entries); computed only when dc && J
//   H  - Hessian [H00, H01, H10, H11] (4 entries, H[1]==H[2]); only when ddc && H
// Inputs:
//   all_knots   - knot vectors of both dimensions, concatenated
//   offset      - offset[k]..offset[k+1] delimits dimension k inside all_knots
//   c           - spline coefficients, dimension 0 contiguous
//   dc          - coefficients of the first-derivative splines: the d/dx0
//                 block followed by the d/dx1 block
//   ddc         - coefficients of the second-derivative splines: d2/dx0^2,
//                 then d2/dx1^2, then the mixed d2/(dx0 dx1) block
//   all_x       - evaluation point (2 entries)
//   lookup_mode - per-dimension interval-lookup strategy passed to casadi_low
//   iw          - integer work vector
//   w           - real work vector
//
// NOTE(review): all intrinsics operate on double lanes (simde__m256d,
// simde_mm_cvtsd_f64), so the template effectively requires T1 == double.
template<typename T1>
void casadi_blazing_2d_boor_eval(T1* f, T1* J, T1* H, const T1* all_knots, const casadi_int* offset, const T1* c, const T1* dc, const T1* ddc, const T1* all_x, const casadi_int* lookup_mode, casadi_int* iw, T1* w) { // NOLINT(whitespace/line_length)
  casadi_int n_dims = 2;
  casadi_int m = 1;
  // n_iter, i, pivot (and most of the work-vector slots carved out below) are
  // never used in this specialized 2D path; the declarations and work-vector
  // layout appear to mirror the generic n-D boor_eval so the same work-vector
  // sizing can be shared -- TODO confirm against the generic implementation.
  casadi_int n_iter, i, pivot;
  casadi_int *boor_offset, *starts, *index, *coeff_offset;
  T1 *cumprod;
  // Partition the caller-supplied work vectors
  boor_offset = iw; iw+=n_dims+1;
  starts = iw; iw+=n_dims;
  index = iw; iw+=n_dims;
  coeff_offset = iw;
  cumprod = w; w+= n_dims+1;
  boor_offset[0] = 0;
  cumprod[n_dims] = 1;
  coeff_offset[n_dims] = 0;

  // Coefficient count along dimension 0: n_knots0 - degree - 1
  casadi_int stride1 = offset[1]-offset[0]-4;

  simde__m256d zero = simde_mm256_set1_pd(0.0);

  // Degree-0 seed vectors for the de Boor recursion. Note that
  // simde_mm256_set_pd lists lanes from highest (lane 3) down to lane 0.
  simde__m256d boor_start_0000 = zero;
  simde__m256d boor_start_1111 = simde_mm256_set1_pd(1.0);
  simde__m256d boor_start_0001 = simde_mm256_set_pd(1.0, 0.0, 0.0, 0.0);
  simde__m256d boor_start_0010 = simde_mm256_set_pd(0.0, 1.0, 0.0, 0.0);

  // Basis values for dimension 0, filled by casadi_blazing_de_boor below:
  // boor0_d0 pairs with the value coefficients c, boor0_d1 with the
  // first-derivative coefficients dc, boor0_d2 with ddc.
  simde__m256d boor0_d3;
  simde__m256d boor0_d2;
  simde__m256d boor0_d1;
  simde__m256d boor0_d0;

  // Same for dimension 1
  simde__m256d boor1_d3;
  simde__m256d boor1_d2;
  simde__m256d boor1_d1;
  simde__m256d boor1_d0;

  const T1* knots;
  T1 x;
  casadi_int degree, n_knots, n_b, L, start;
  degree = 3;
  // --- Dimension 0: interval lookup and basis evaluation ---
  knots = all_knots + offset[0];
  n_knots = offset[0+1]-offset[0];
  n_b = n_knots-degree-1;  // number of basis functions in this dimension
  x = all_x[0];
  // Interval index among the non-repeated interior knots
  L = casadi_low(x, knots+degree, n_knots-2*degree, lookup_mode[0]);
  start = L;
  // Clamp so the 4-wide coefficient window stays in range
  if (start>n_b-degree-1) start = n_b-degree-1;
  starts[0] = start;
  // Select the degree-0 seed; stays all-zero when x is outside the knot span
  boor0_d3 = boor_start_0000;
  if (x>=knots[0] && x<=knots[n_knots-1]) {
    if (x==knots[1]) {
      boor0_d3 = boor_start_1111;
    } else if (x==knots[n_knots-1]) {
      // x at the right boundary
      boor0_d3 = boor_start_0001;
    } else if (knots[L+degree]==x) {
      // x coincides with an interior knot
      boor0_d3 = boor_start_0010;
    } else {
      boor0_d3 = boor_start_0001;
    }
  }
  casadi_blazing_de_boor(x, knots+start, &boor0_d0, &boor0_d1, &boor0_d2, &boor0_d3);

  // --- Dimension 1: same procedure ---
  knots = all_knots + offset[1];
  n_knots = offset[1+1]-offset[1];
  n_b = n_knots-degree-1;
  x = all_x[1];
  L = casadi_low(x, knots+degree, n_knots-2*degree, lookup_mode[1]);
  start = L;
  if (start>n_b-degree-1) start = n_b-degree-1;
  starts[1] = start;
  boor1_d3 = boor_start_0000;
  if (x>=knots[0] && x<=knots[n_knots-1]) {
    if (x==knots[1]) {
      boor1_d3 = boor_start_1111;
    } else if (x==knots[n_knots-1]) {
      boor1_d3 = boor_start_0001;
    } else if (knots[L+degree]==x) {
      boor1_d3 = boor_start_0010;
    } else {
      boor1_d3 = boor_start_0001;
    }
  }
  casadi_blazing_de_boor(x, knots+start, &boor1_d0, &boor1_d1, &boor1_d2, &boor1_d3);

  // 4x4 window of coefficients around (starts[0], starts[1]);
  // C[j] holds 4 consecutive dim-0 coefficients of dim-1 row starts[1]+j
  simde__m256d C[4];

  for (int j=0;j<4;++j) {
    C[j] = simde_mm256_loadu_pd(c+(starts[1]+j)*stride1+starts[0]);
  }

  simde__m256d a, b0, b1, b2, b3, c0, c1, c2, c3, r;
  simde__m256d ab[4];
  simde__m128d r0, r1;

  // a = dim-0 basis vector; b0..b3 broadcast each dim-1 basis value
  a = boor0_d0;
  b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
  b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
  b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
  b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));

  // Need to compute sum_ab C_ab A_a B_b

  // Step 1: Outer product a b: A_a B_b
  ab[0] = simde_mm256_mul_pd(a, b0);
  ab[1] = simde_mm256_mul_pd(a, b1);
  ab[2] = simde_mm256_mul_pd(a, b2);
  ab[3] = simde_mm256_mul_pd(a, b3);

  // Sum over b axis: sum_b C_ab * (A_a B_b)_b
  // r <- r + ab[i]*C[i]
  r = simde_mm256_set1_pd(0);
  r = simde_mm256_fmadd_pd(ab[0], C[0], r);
  r = simde_mm256_fmadd_pd(ab[1], C[1], r);
  r = simde_mm256_fmadd_pd(ab[2], C[2], r);
  r = simde_mm256_fmadd_pd(ab[3], C[3], r);

  if (f) {
    // Sum all cab entries (horizontal add of the four lanes of r)
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    f[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
  }

  // First derivative (gradient)
  if (dc && J) {
    // J[0] = df/dx0: the d/dx0 coefficient block has one fewer entry per
    // dim-0 row, hence the reduced stride and the -1 window shift.
    // NOTE(review): the -1 (and -2 below) shifts assume starts[0] is large
    // enough -- confirm the coefficient layout guarantees this at the caller.
    stride1 = offset[1]-offset[0]-4-1;
    for (int j=0;j<4;++j) {
      C[j] = simde_mm256_loadu_pd(dc+(starts[1]+j)*stride1+starts[0]-1);
    }
    // Advance past the d/dx0 block to the d/dx1 coefficient block
    dc += stride1*(offset[2]-offset[1]-4);

    // Pair derivative basis in dim 0 with the unchanged dim-1 broadcasts
    a = boor0_d1;
    ab[0] = simde_mm256_mul_pd(a, b0);
    ab[1] = simde_mm256_mul_pd(a, b1);
    ab[2] = simde_mm256_mul_pd(a, b2);
    ab[3] = simde_mm256_mul_pd(a, b3);

    // Sum over b axis: sum_b C_ab * (A_a B_b)_b
    // r <- r + ab[i]*C[i]
    r = simde_mm256_set1_pd(0);
    r = simde_mm256_fmadd_pd(ab[0], C[0], r);
    r = simde_mm256_fmadd_pd(ab[1], C[1], r);
    r = simde_mm256_fmadd_pd(ab[2], C[2], r);
    r = simde_mm256_fmadd_pd(ab[3], C[3], r);

    // Sum all r entries
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    J[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));


    // J[1] = df/dx1: full dim-0 stride again; the window shifts by one row in
    // dim 1 and row j==0 is zeroed rather than loaded.
    stride1 = offset[1]-offset[0]-4;
    for (int j=0;j<4;++j) {
      if (j==0) {
        C[j] = zero;
      } else {
        C[j] = simde_mm256_loadu_pd(dc+(starts[1]+j-1)*stride1+starts[0]);
      }
    }

    // Value basis in dim 0, derivative basis in dim 1
    a = boor0_d0;

    b0 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
    b1 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
    b2 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
    b3 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));

    ab[0] = simde_mm256_mul_pd(a, b0);
    ab[1] = simde_mm256_mul_pd(a, b1);
    ab[2] = simde_mm256_mul_pd(a, b2);
    ab[3] = simde_mm256_mul_pd(a, b3);

    // Sum over b axis: sum_b C_ab * (A_a B_b)_b
    // r <- r + ab[i]*C[i]
    r = simde_mm256_set1_pd(0);
    r = simde_mm256_fmadd_pd(ab[0], C[0], r);
    r = simde_mm256_fmadd_pd(ab[1], C[1], r);
    r = simde_mm256_fmadd_pd(ab[2], C[2], r);
    r = simde_mm256_fmadd_pd(ab[3], C[3], r);

    // Sum all r entries
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    J[1] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));
  }

  // Second derivative (Hessian)
  if (ddc && H) {
    // H[0] = d2f/dx0^2: two fewer coefficients per dim-0 row, window -2
    stride1 = offset[1]-offset[0]-4-2;
    for (int j=0;j<4;++j) {
      C[j] = simde_mm256_loadu_pd(ddc+(starts[1]+j)*stride1+starts[0]-2);
    }
    // Advance past the d2/dx0^2 block to the d2/dx1^2 block
    ddc += stride1*(offset[2]-offset[1]-4);

    // Second-derivative basis in dim 0, value basis in dim 1
    a = boor0_d2;
    b0 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
    b1 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
    b2 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
    b3 = simde_mm256_permute4x64_pd(boor1_d0, SIMDE_MM_SHUFFLE(3, 3, 3, 3));

    ab[0] = simde_mm256_mul_pd(a, b0);
    ab[1] = simde_mm256_mul_pd(a, b1);
    ab[2] = simde_mm256_mul_pd(a, b2);
    ab[3] = simde_mm256_mul_pd(a, b3);
    // Sum over b axis: sum_b C_ab * (A_a B_b)_b
    // r <- r + ab[i]*C[i]
    r = simde_mm256_set1_pd(0);
    r = simde_mm256_fmadd_pd(ab[0], C[0], r);
    r = simde_mm256_fmadd_pd(ab[1], C[1], r);
    r = simde_mm256_fmadd_pd(ab[2], C[2], r);
    r = simde_mm256_fmadd_pd(ab[3], C[3], r);

    // Sum all r entries
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    H[0] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));

    // H[3] = d2f/dx1^2: full dim-0 stride, window shifted two rows in dim 1;
    // rows j<=1 are zeroed rather than loaded.
    stride1 = offset[1]-offset[0]-4;
    for (int j=0;j<4;++j) {
      if (j<=1) {
        C[j] = zero;
      } else {
        C[j] = simde_mm256_loadu_pd(ddc+(starts[1]+j-2)*stride1+starts[0]);
      }
    }
    // Advance past the d2/dx1^2 block to the mixed-derivative block
    ddc += stride1*(offset[2]-offset[1]-4-2);

    a = boor0_d0;
    b0 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
    b1 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
    b2 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
    b3 = simde_mm256_permute4x64_pd(boor1_d2, SIMDE_MM_SHUFFLE(3, 3, 3, 3));

    ab[0] = simde_mm256_mul_pd(a, b0);
    ab[1] = simde_mm256_mul_pd(a, b1);
    ab[2] = simde_mm256_mul_pd(a, b2);
    ab[3] = simde_mm256_mul_pd(a, b3);
    // Sum over b axis: sum_b C_ab * (A_a B_b)_b
    // r <- r + ab[i]*C[i]
    r = simde_mm256_set1_pd(0);
    r = simde_mm256_fmadd_pd(ab[0], C[0], r);
    r = simde_mm256_fmadd_pd(ab[1], C[1], r);
    r = simde_mm256_fmadd_pd(ab[2], C[2], r);
    r = simde_mm256_fmadd_pd(ab[3], C[3], r);

    // Sum all r entries
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    H[3] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));

    // Mixed derivative: one fewer coefficient per dim-0 row, window shifted
    // by one in both dimensions; row j==0 is zeroed rather than loaded.
    stride1 = offset[1]-offset[0]-5;
    for (int j=0;j<4;++j) {
      if (j==0) {
        C[j] = zero;
      } else {
        C[j] = simde_mm256_loadu_pd(ddc+(starts[1]+j-1)*stride1+starts[0]-1);
      }
    }
    // NOTE(review): ddc is not used after this point, so this advance is dead
    // code; it also reads offset[3], while a 2D spline needs only
    // offset[0..2] -- confirm the caller allocates offset with >= 4 entries.
    ddc += stride1*(offset[3]-offset[2]-5);

    // First-derivative basis in both dimensions
    a = boor0_d1;

    b0 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(0, 0, 0, 0));
    b1 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(1, 1, 1, 1));
    b2 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(2, 2, 2, 2));
    b3 = simde_mm256_permute4x64_pd(boor1_d1, SIMDE_MM_SHUFFLE(3, 3, 3, 3));

    ab[0] = simde_mm256_mul_pd(a, b0);
    ab[1] = simde_mm256_mul_pd(a, b1);
    ab[2] = simde_mm256_mul_pd(a, b2);
    ab[3] = simde_mm256_mul_pd(a, b3);
    // Sum over b axis: sum_b C_ab * (A_a B_b)_b
    // r <- r + ab[i]*C[i]
    r = simde_mm256_set1_pd(0);
    r = simde_mm256_fmadd_pd(ab[0], C[0], r);
    r = simde_mm256_fmadd_pd(ab[1], C[1], r);
    r = simde_mm256_fmadd_pd(ab[2], C[2], r);
    r = simde_mm256_fmadd_pd(ab[3], C[3], r);

    // Sum all r entries
    r0 = simde_mm256_castpd256_pd128(r);
    r1 = simde_mm256_extractf128_pd(r, 1);
    r0 = simde_mm_add_pd(r0, r1);
    // Hessian is symmetric: write the mixed derivative to both off-diagonals
    H[1] = H[2] = simde_mm_cvtsd_f64(simde_mm_add_sd(r0, simde_mm_unpackhi_pd(r0, r0)));

  }

}