Intrepid2
Intrepid2_DataCombiners.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Intrepid2 Package
4//
5// Copyright 2007 NTESS and the Intrepid2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9//
10// Intrepid2_DataCombiners.hpp
11// Trilinos
12//
13// Created by Roberts, Nathan V on 5/31/23.
14//
15
16#ifndef Intrepid2_DataCombiners_hpp
17#define Intrepid2_DataCombiners_hpp
18
23
25#include "Intrepid2_Data.hpp"
28#include "Intrepid2_ScalarView.hpp"
29
30namespace Intrepid2 {
// Forward declaration of the Data class template, which wraps a Kokkos::View; its definition comes from the
// Intrepid2_Data.hpp header included above.
31 template<class DataScalar,typename DeviceType>
32 class Data;
33
34 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
35 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB, bool includeInnerLoop=false>
36 struct InPlaceCombinationFunctor
37 {
38 private:
39 ThisUnderlyingViewType this_underlying_;
40 AUnderlyingViewType A_underlying_;
41 BUnderlyingViewType B_underlying_;
42 BinaryOperator binaryOperator_;
43 int innerLoopSize_;
44 public:
45 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
46 BinaryOperator binaryOperator)
47 :
48 this_underlying_(this_underlying),
49 A_underlying_(A_underlying),
50 B_underlying_(B_underlying),
51 binaryOperator_(binaryOperator)
52 {
53 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
54 }
55
56 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
57 BinaryOperator binaryOperator, int innerLoopSize)
58 :
59 this_underlying_(this_underlying),
60 A_underlying_(A_underlying),
61 B_underlying_(B_underlying),
62 binaryOperator_(binaryOperator),
63 innerLoopSize_(innerLoopSize)
64 {
65 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
66 }
67
68 template<class ...IntArgs, bool M=includeInnerLoop>
69 KOKKOS_INLINE_FUNCTION
70 enable_if_t<!M, void>
71 operator()(const IntArgs&... args) const
72 {
73 auto & result = ArgExtractorThis::get( this_underlying_, args... );
74 const auto & A_val = ArgExtractorA::get( A_underlying_, args... );
75 const auto & B_val = ArgExtractorB::get( B_underlying_, args... );
76
77 result = binaryOperator_(A_val,B_val);
78 }
79
80 template<class ...IntArgs, bool M=includeInnerLoop>
81 KOKKOS_INLINE_FUNCTION
82 enable_if_t<M, void>
83 operator()(const IntArgs&... args) const
84 {
85 using int_type = std::tuple_element_t<0, std::tuple<IntArgs...>>;
86 for (int_type iFinal=0; iFinal<static_cast<int_type>(innerLoopSize_); iFinal++)
87 {
88 auto & result = ArgExtractorThis::get( this_underlying_, args..., iFinal );
89 const auto & A_val = ArgExtractorA::get( A_underlying_, args..., iFinal );
90 const auto & B_val = ArgExtractorB::get( B_underlying_, args..., iFinal );
91
92 result = binaryOperator_(A_val,B_val);
93 }
94 }
95 };
96
98 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType>
99 struct InPlaceCombinationFunctorConstantCase
100 {
101 private:
102 ThisUnderlyingViewType this_underlying_;
103 AUnderlyingViewType A_underlying_;
104 BUnderlyingViewType B_underlying_;
105 BinaryOperator binaryOperator_;
106 public:
107 InPlaceCombinationFunctorConstantCase(ThisUnderlyingViewType this_underlying,
108 AUnderlyingViewType A_underlying,
109 BUnderlyingViewType B_underlying,
110 BinaryOperator binaryOperator)
111 :
112 this_underlying_(this_underlying),
113 A_underlying_(A_underlying),
114 B_underlying_(B_underlying),
115 binaryOperator_(binaryOperator)
116 {
117 INTREPID2_TEST_FOR_EXCEPTION(this_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
118 INTREPID2_TEST_FOR_EXCEPTION(A_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
119 INTREPID2_TEST_FOR_EXCEPTION(B_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
120 }
121
122 KOKKOS_INLINE_FUNCTION
123 void operator()(const int arg0) const
124 {
125 auto & result = this_underlying_(0);
126 const auto & A_val = A_underlying_(0);
127 const auto & B_val = B_underlying_(0);
128
129 result = binaryOperator_(A_val,B_val);
130 }
131 };
132
// Argument extractor for the destination of a store when that destination is a Data object rather than a View.
// get() forwards all integer arguments to view.getWritableEntryWithPassThroughOption(), passing the compile-time
// flag passThroughBlockDiagonalArgs as its first argument, and returns the resulting writable reference.
// ViewType must therefore provide getWritableEntryWithPassThroughOption(); DataCombiner instantiates this
// extractor with passThroughBlockDiagonalArgs = true.
134 template<bool passThroughBlockDiagonalArgs>
136 {
137 template<class ViewType, class ...IntArgs>
138 static KOKKOS_INLINE_FUNCTION typename ViewType::reference_type get(const ViewType &view, const IntArgs&... intArgs)
139 {
140 return view.getWritableEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
141 }
142 };
143
// Argument extractor for a read-only operand that is passed as a Data object rather than a View.
// get() forwards all integer arguments to view.getEntryWithPassThroughOption(), passing the compile-time flag
// passThroughBlockDiagonalArgs as its first argument, and returns the entry as ViewType::const_reference_type.
// DataCombiner instantiates this extractor with passThroughBlockDiagonalArgs = true.
145 template<bool passThroughBlockDiagonalArgs>
147 {
148 template<class ViewType, class ...IntArgs>
149 static KOKKOS_INLINE_FUNCTION typename ViewType::const_reference_type get(const ViewType &view, const IntArgs&... intArgs)
150 {
151 return view.getEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
152 }
153 };
154
155// static class for combining two Data objects using a specified binary operator
// (All members shown are static; the public entry point is storeInPlaceCombination(thisData, A, B, binaryOperator).)
156 template <class DataScalar,typename DeviceType, class BinaryOperator>
158{
// Entry reference types for a writable ScalarView<DataScalar,DeviceType> and a read-only
// ScalarView<const DataScalar,DeviceType>.
// NOTE(review): neither alias is referenced in the portion of the class visible here.
159 using reference_type = typename ScalarView<DataScalar,DeviceType>::reference_type;
160 using const_reference_type = typename ScalarView<const DataScalar,DeviceType>::reference_type;
161public:
// Launches the in-place combination kernel. It constructs a Functor from the three containers and binaryOperator,
// then runs it with Kokkos::parallel_for over policy, under the label "compute in-place". Each container may be an
// underlying View or a Data object, depending on the call site.
// The extractor arguments argThis, argA, and argB are not read by the statements shown. They appear to serve only
// to deduce ArgExtractorThis, ArgExtractorA, and ArgExtractorB, which determine how each container is indexed.
// NOTE(review): the Functor alias is declared on a line not present in this listing. It is presumably
// InPlaceCombinationFunctor instantiated with these template arguments; confirm.
163 template<class PolicyType, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
164 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB>
165 static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying,
166 AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying,
167 BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
168 {
170 Functor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
171 Kokkos::parallel_for("compute in-place", policy, functor);
172 }
173
// storeInPlaceCombination with compile-time rank, for rank < 7. The signature line is not present in this
// listing; per its call sites, it takes (thisData, A, B, binaryOperator).
//
// Each of thisData, A, and B is classified as:
//   - constant: rank-1 underlying View with a single entry;
//   - full: the underlying View matches the logical shape;
//   - 1D: rank-1 underlying View.
// Based on that classification, each container is passed to the View-based storeInPlaceCombination() in one of
// two ways. Where a fast path is handled, it is passed as its underlying View with a matching argument extractor.
// Otherwise it is passed as the Data object itself with a full-argument extractor. The kernel runs over
// thisData.dataExtentRangePolicy<rank>().
175 template<int rank>
176 static
177 enable_if_t<rank != 7, void>
179 {
180 auto policy = thisData.template dataExtentRangePolicy<rank>();
181
182 const bool A_1D = A.getUnderlyingViewRank() == 1;
183 const bool B_1D = B.getUnderlyingViewRank() == 1;
184 const bool this_1D = thisData.getUnderlyingViewRank() == 1;
185 const bool A_constant = A_1D && (A.getUnderlyingViewSize() == 1);
186 const bool B_constant = B_1D && (B.getUnderlyingViewSize() == 1);
187 const bool this_constant = this_1D && (thisData.getUnderlyingViewSize() == 1);
188 const bool A_full = A.underlyingMatchesLogical();
189 const bool B_full = B.underlyingMatchesLogical();
190 const bool this_full = thisData.underlyingMatchesLogical();
191
193
195 const FullArgExtractorData<true> fullArgsData; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
196 const FullArgExtractorWritableData<true> fullArgsWritable; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
197
// NOTE(review): constArg, fullArgs, and arg0..arg5 are used below, but their declarations are on lines not present
// in this listing. Presumably:
//   - constArg passes a single 0 index;
//   - fullArgs passes all indices;
//   - argN passes only the N-th index.
// Confirm against the argument-extractor header.
204
205 // this lambda returns -1 if there is not a rank-1 underlying view whose data extent matches the logical extent in the corresponding dimension;
206 // otherwise, it returns the logical index of the corresponding dimension.
// NOTE(review): as written, the lambda returns the first dimension d < rank whose variation type is GENERAL, or -1
// if there is none. It does not itself check the rank of the underlying View; callers pair it with the *_1D flags.
207 auto get1DArgIndex = [](const Data<DataScalar,DeviceType> &data) -> int
208 {
209 const auto & variationTypes = data.getVariationTypes();
210 for (int d=0; d<rank; d++)
211 {
212 if (variationTypes[d] == GENERAL)
213 {
214 return d;
215 }
216 }
217 return -1;
218 };
// Case: this is constant; every container is read through its rank-1 underlying View using constArg.
219 if (this_constant)
220 {
221 // then A, B are constant, too
222 auto thisAE = constArg;
223 auto AAE = constArg;
224 auto BAE = constArg;
225 auto & this_underlying = thisData.template getUnderlyingView<1>();
226 auto & A_underlying = A.template getUnderlyingView<1>();
227 auto & B_underlying = B.template getUnderlyingView<1>();
228 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
229 }
230 else if (this_full && A_full && B_full)
231 {
// Case: all three containers are full; each underlying View (of rank `rank`) is indexed with fullArgs.
232 auto thisAE = fullArgs;
233 auto AAE = fullArgs;
234 auto BAE = fullArgs;
235
236 auto & this_underlying = thisData.template getUnderlyingView<rank>();
237 auto & A_underlying = A.template getUnderlyingView<rank>();
238 auto & B_underlying = B.template getUnderlyingView<rank>();
239
240 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
241 }
242 else if (A_constant)
243 {
// Case: A is constant and this is not. A is read through its rank-1 underlying View using constArg.
244 auto AAE = constArg;
245 auto & A_underlying = A.template getUnderlyingView<1>();
246 if (this_full)
247 {
248 auto thisAE = fullArgs;
249 auto & this_underlying = thisData.template getUnderlyingView<rank>();
250
251 if (B_full)
252 {
253 auto BAE = fullArgs;
254 auto & B_underlying = B.template getUnderlyingView<rank>();
255 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
256 }
257 else // this_full, not B_full: B may have modular data, etc.
258 {
259 auto BAE = fullArgsData;
260 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
261 }
262 }
263 else // this is not full
264 {
265 // below, we optimize for the case of 1D data in B, when A is constant. Still need to handle other cases…
266 if (B_1D && (get1DArgIndex(B) != -1) )
267 {
268 // since A is constant, that implies that this_1D is true, and has the same 1DArgIndex
269 const int argIndex = get1DArgIndex(B);
270 auto & B_underlying = B.template getUnderlyingView<1>();
271 auto & this_underlying = thisData.template getUnderlyingView<1>();
272 switch (argIndex)
273 {
274 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, AAE, arg0); break;
275 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, AAE, arg1); break;
276 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, AAE, arg2); break;
277 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, AAE, arg3); break;
278 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, AAE, arg4); break;
279 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, AAE, arg5); break;
280 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
281 }
282 }
283 else
284 {
285 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
286 auto thisAE = fullArgsWritable;
287 auto BAE = fullArgsData;
288 storeInPlaceCombination(policy, thisData, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
289 }
290 }
291 }
292 else if (B_constant)
293 {
// Case: B is constant, and neither this nor A is. B is read through its rank-1 underlying View using constArg.
294 auto BAE = constArg;
295 auto & B_underlying = B.template getUnderlyingView<1>();
296 if (this_full)
297 {
298 auto thisAE = fullArgs;
299 auto & this_underlying = thisData.template getUnderlyingView<rank>();
300 if (A_full)
301 {
302 auto AAE = fullArgs;
303 auto & A_underlying = A.template getUnderlyingView<rank>();
304
305 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
306 }
307 else // this_full, not A_full: A may have modular data, etc.
308 {
309 // use A (the Data object). This could be further optimized by using A's underlying View and an appropriately-defined ArgExtractor.
310 auto AAE = fullArgsData;
311 storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
312 }
313 }
314 else // this is not full
315 {
316 // below, we optimize for the case of 1D data in A, when B is constant. Still need to handle other cases…
317 if (A_1D && (get1DArgIndex(A) != -1) )
318 {
319 // since B is constant, that implies that this_1D is true, and has the same 1DArgIndex as A
320 const int argIndex = get1DArgIndex(A);
321 auto & A_underlying = A.template getUnderlyingView<1>();
322 auto & this_underlying = thisData.template getUnderlyingView<1>();
323 switch (argIndex)
324 {
325 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, BAE); break;
326 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, BAE); break;
327 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, BAE); break;
328 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, BAE); break;
329 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, BAE); break;
330 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, BAE); break;
331 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
332 }
333 }
334 else
335 {
336 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
337 auto thisAE = fullArgsWritable;
338 auto AAE = fullArgsData;
339 storeInPlaceCombination(policy, thisData, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
340 }
341 }
342 }
343 else // neither A nor B constant
344 {
345 if (this_1D && (get1DArgIndex(thisData) != -1))
346 {
347 // possible ways that "this" could have full-extent, 1D data
348 // 1. A constant, B 1D
349 // 2. A 1D, B constant
350 // 3. A 1D, B 1D
351 // The constant possibilities are already addressed above, leaving us with (3). Note that A and B don't have to be full-extent, however
352 const int argThis = get1DArgIndex(thisData);
353 const int argA = get1DArgIndex(A); // if not full-extent, will be -1
354 const int argB = get1DArgIndex(B); // ditto
355
// NOTE(review): get1DArgIndex() returns -1 only when a container has no GENERAL dimension, so the -1 checks below
// do not by themselves establish that A or B has a rank-1 underlying View.
// Also, A_underlying and B_underlying are fetched as rank-1 Views unconditionally. This happens even on the
// branches below that instead read A or B through the Data object.
// Confirm that getUnderlyingView<1>() is valid for both A and B whenever this branch is taken.
356 auto & A_underlying = A.template getUnderlyingView<1>();
357 auto & B_underlying = B.template getUnderlyingView<1>();
358 auto & this_underlying = thisData.template getUnderlyingView<1>();
359 if ((argA != -1) && (argB != -1))
360 {
361#ifdef INTREPID2_HAVE_DEBUG
362 INTREPID2_TEST_FOR_EXCEPTION(argA != argThis, std::logic_error, "Unexpected 1D arg combination.");
363 INTREPID2_TEST_FOR_EXCEPTION(argB != argThis, std::logic_error, "Unexpected 1D arg combination.");
364#endif
365 switch (argThis)
366 {
367 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, arg0); break;
368 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, arg1); break;
369 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, arg2); break;
370 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, arg3); break;
371 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, arg4); break;
372 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, arg5); break;
373 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
374 }
375 }
376 else if (argA != -1)
377 {
378 // B is not full-extent in dimension argThis; use the Data object
379 switch (argThis)
380 {
381 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg0, arg0, fullArgsData); break;
382 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg1, arg1, fullArgsData); break;
383 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg2, arg2, fullArgsData); break;
384 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg3, arg3, fullArgsData); break;
385 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg4, arg4, fullArgsData); break;
386 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg5, arg5, fullArgsData); break;
387 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
388 }
389 }
390 else
391 {
392 // A is not full-extent in dimension argThis; use the Data object
393 switch (argThis)
394 {
395 case 0: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg0, fullArgsData, arg0); break;
396 case 1: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg1, fullArgsData, arg1); break;
397 case 2: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg2, fullArgsData, arg2); break;
398 case 3: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg3, fullArgsData, arg3); break;
399 case 4: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg4, fullArgsData, arg4); break;
400 case 5: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg5, fullArgsData, arg5); break;
401 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
402 }
403 }
404 }
405 else if (this_full)
406 {
407 // This case uses A,B Data objects; could be optimized by dividing into subcases and using underlying Views with appropriate ArgExtractors.
408 auto & this_underlying = thisData.template getUnderlyingView<rank>();
409 auto thisAE = fullArgs;
410
411 if (A_full)
412 {
413 auto & A_underlying = A.template getUnderlyingView<rank>();
414 auto AAE = fullArgs;
415
416 if (B_1D && (get1DArgIndex(B) != -1))
417 {
418 const int argIndex = get1DArgIndex(B);
419 auto & B_underlying = B.template getUnderlyingView<1>();
420 switch (argIndex)
421 {
422 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg0); break;
423 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg1); break;
424 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg2); break;
425 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg3); break;
426 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg4); break;
427 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg5); break;
428 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
429 }
430 }
431 else
432 {
433 // A is full; B is not full, but not constant or full-extent 1D
434 // unoptimized in B access:
// NOTE(review): BAE is declared on a line not present in this listing (presumably fullArgsData, since B is passed
// as a Data object here).
436 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
437 }
438 }
439 else // A is not full
440 {
441 if (A_1D && (get1DArgIndex(A) != -1))
442 {
443 const int argIndex = get1DArgIndex(A);
444 auto & A_underlying = A.template getUnderlyingView<1>();
445 if (B_full)
446 {
447 auto & B_underlying = B.template getUnderlyingView<rank>();
448 auto BAE = fullArgs;
449 switch (argIndex)
450 {
451 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg0, BAE); break;
452 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg1, BAE); break;
453 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg2, BAE); break;
454 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg3, BAE); break;
455 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg4, BAE); break;
456 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg5, BAE); break;
457 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
458 }
459 }
460 else
461 {
462 auto BAE = fullArgsData;
463 switch (argIndex)
464 {
465 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg0, BAE); break;
466 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg1, BAE); break;
467 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg2, BAE); break;
468 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg3, BAE); break;
469 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg4, BAE); break;
470 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg5, BAE); break;
471 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
472 }
473 }
474 }
475 else // A not full, and not full-extent 1D
476 {
477 // unoptimized in A, B accesses.
478 auto AAE = fullArgsData;
479 auto BAE = fullArgsData;
480 storeInPlaceCombination(policy, this_underlying, A, B, binaryOperator, thisAE, AAE, BAE);
481 }
482 }
483 }
484 else
485 {
486 // completely un-optimized case: we use Data objects for this, A, B.
487 auto thisAE = fullArgsWritable;
488 auto AAE = fullArgsData;
489 auto BAE = fullArgsData;
490 storeInPlaceCombination(policy, thisData, A, B, binaryOperator, thisAE, AAE, BAE);
491 }
492 }
493 }
494
// storeInPlaceCombination with compile-time rank, for rank 7 (not optimized). The signature line is not present in
// this listing; per its call sites, it takes (thisData, A, B, binaryOperator).
// This path always reads and writes through the Data objects themselves. The functor is built with
// includeInnerLoop = true and an inner-loop size of thisData.getDataExtent(6). Presumably the policy therefore
// spans the leading dimensions while the functor loops over the last one.
// NOTE(review): the Functor alias and the extractor declarations are on lines not present in this listing.
// Functor is presumably InPlaceCombinationFunctor with includeInnerLoop as its last template argument. Confirm
// this, and confirm that its five-argument constructor accepts includeInnerLoop == true.
496 template<int rank>
497 static
498 enable_if_t<rank == 7, void>
500 {
501 auto policy = thisData.template dataExtentRangePolicy<rank>();
502
503 using DataType = Data<DataScalar,DeviceType>;
507
508 const ordinal_type dim6 = thisData.getDataExtent(6);
509 const bool includeInnerLoop = true;
511 Functor functor(thisData, A, B, binaryOperator, dim6);
512 Kokkos::parallel_for("compute in-place", policy, functor);
513 }
514
515 static void storeInPlaceCombination(Data<DataScalar,DeviceType> &thisData, const Data<DataScalar,DeviceType> &A, const Data<DataScalar,DeviceType> &B, BinaryOperator binaryOperator)
516 {
517 using ExecutionSpace = typename DeviceType::execution_space;
518
519#ifdef INTREPID2_HAVE_DEBUG
520 // check logical extents
521 for (int d=0; d<rank_; d++)
522 {
523 INTREPID2_TEST_FOR_EXCEPTION(A.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
524 INTREPID2_TEST_FOR_EXCEPTION(B.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
525 }
526 // TODO: add some checks that data extent of this suffices to accept combined A + B data.
527#endif
528
529 const bool this_constant = (thisData.getUnderlyingViewRank() == 1) && (thisData.getUnderlyingViewSize() == 1);
530
531 // we special-case for constant output here; since the constant case is essentially all overhead, we want to avoid as much of the overhead of storeInPlaceCombination() as possible…
532 if (this_constant)
533 {
534 // constant data
535 Kokkos::RangePolicy<ExecutionSpace> policy(ExecutionSpace(),0,1); // just 1 entry
536
537 auto this_underlying = thisData.template getUnderlyingView<1>();
538 auto A_underlying = A.template getUnderlyingView<1>();
539 auto B_underlying = B.template getUnderlyingView<1>();
540
541 using ConstantCaseFunctor = InPlaceCombinationFunctorConstantCase<decltype(binaryOperator), decltype(this_underlying),
542 decltype(A_underlying), decltype(B_underlying)>;
543
544 ConstantCaseFunctor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
545 Kokkos::parallel_for("compute in-place", policy,functor);
546 }
547 else
548 {
549 switch (thisData.rank())
550 {
551 case 1: storeInPlaceCombination<1>(thisData, A, B, binaryOperator); break;
552 case 2: storeInPlaceCombination<2>(thisData, A, B, binaryOperator); break;
553 case 3: storeInPlaceCombination<3>(thisData, A, B, binaryOperator); break;
554 case 4: storeInPlaceCombination<4>(thisData, A, B, binaryOperator); break;
555 case 5: storeInPlaceCombination<5>(thisData, A, B, binaryOperator); break;
556 case 6: storeInPlaceCombination<6>(thisData, A, B, binaryOperator); break;
557 case 7: storeInPlaceCombination<7>(thisData, A, B, binaryOperator); break;
558 default:
559 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(true, std::logic_error, "unhandled rank in switch");
560 }
561 }
562 }
563};
564
565} // end namespace Intrepid2
566
567// We do ETI for basic double arithmetic on default device.
568//template<class Scalar> struct ScalarSumFunctor;
569//template<class Scalar> struct ScalarDifferenceFunctor;
570//template<class Scalar> struct ScalarProductFunctor;
571//template<class Scalar> struct ScalarQuotientFunctor;
572
577
578#endif /* Intrepid2_DataCombiners_hpp */
Header file with various static argument-extractor classes. These are useful for writing efficient,...
Defines functors for use with Data objects: so far, we include simple arithmetical functors for sum,...
Defines DataVariationType enum that specifies the types of variation possible within a Data object.
@ GENERAL
arbitrary variation
Defines the Data class, a wrapper around a Kokkos::View that allows data that is constant or repeatin...
#define INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(test, x, msg)
static enable_if_t< rank==7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank of 7. (Not optimized; expect...
static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying, AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying, BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
storeInPlaceCombination implementation for rank < 7, with compile-time underlying views and argument ...
static enable_if_t< rank !=7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank < 7.
Wrapper around a Kokkos::View that allows data that is constant or repeating in various logical dimen...
KOKKOS_INLINE_FUNCTION int extent_int(const int &r) const
Returns the logical extent in the specified dimension.
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewSize() const
returns the number of entries in the View that stores the unique data
KOKKOS_INLINE_FUNCTION int getDataExtent(const ordinal_type &d) const
returns the true extent of the data corresponding to the logical dimension provided; if the data does...
KOKKOS_INLINE_FUNCTION bool underlyingMatchesLogical() const
Returns true if the underlying container has exactly the same rank and extents as the logical contain...
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewRank() const
returns the rank of the View that stores the unique data
KOKKOS_INLINE_FUNCTION unsigned rank() const
Returns the logical rank of the Data container.
Argument extractor class which ignores the input arguments in favor of passing a single 0 argument to...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
Argument extractor class which passes all arguments to the provided container.
Argument extractor class which passes a single argument, indicated by the template parameter whichArg...