Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_lclDot.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_LCLDOT_HPP
11#define TPETRA_DETAILS_LCLDOT_HPP
12
19
20#include "Kokkos_DualView.hpp"
21#include "Kokkos_ArithTraits.hpp"
22#include "KokkosBlas1_dot.hpp"
23#include "Teuchos_ArrayView.hpp"
24#include "Teuchos_TestForException.hpp"
25
26namespace Tpetra {
27namespace Details {
28
29template<class RV, class XMV>
30void
31lclDot (const RV& dotsOut,
32 const XMV& X_lcl,
33 const XMV& Y_lcl,
34 const size_t lclNumRows,
35 const size_t numVecs,
36 const size_t whichVecsX[],
37 const size_t whichVecsY[],
38 const bool constantStrideX,
39 const bool constantStrideY)
40{
41 using Kokkos::ALL;
42 using Kokkos::subview;
43 typedef typename RV::non_const_value_type dot_type;
44#ifdef HAVE_TPETRA_DEBUG
45 const char prefix[] = "Tpetra::MultiVector::lclDotImpl: ";
46#endif // HAVE_TPETRA_DEBUG
47
48 static_assert (Kokkos::is_view<RV>::value,
49 "Tpetra::MultiVector::lclDotImpl: "
50 "The first argument dotsOut is not a Kokkos::View.");
51 static_assert (RV::rank == 1, "Tpetra::MultiVector::lclDotImpl: "
52 "The first argument dotsOut must have rank 1.");
53 static_assert (Kokkos::is_view<XMV>::value,
54 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
55 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
56 static_assert (XMV::rank == 2, "Tpetra::MultiVector::lclDotImpl: "
57 "X_lcl and Y_lcl must have rank 2.");
58
59 // In case the input dimensions don't match, make sure that we
60 // don't overwrite memory that doesn't belong to us, by using
61 // subset views with the minimum dimensions over all input.
62 const std::pair<size_t, size_t> rowRng (0, lclNumRows);
63 const std::pair<size_t, size_t> colRng (0, numVecs);
64 RV theDots = subview (dotsOut, colRng);
65 XMV X = subview (X_lcl, rowRng, ALL ());
66 XMV Y = subview (Y_lcl, rowRng, ALL ());
67
68#ifdef HAVE_TPETRA_DEBUG
69 if (lclNumRows != 0) {
70 TEUCHOS_TEST_FOR_EXCEPTION
71 (X.extent (0) != lclNumRows, std::logic_error, prefix <<
72 "X.extent(0) = " << X.extent (0) << " != lclNumRows "
73 "= " << lclNumRows << ". "
74 "Please report this bug to the Tpetra developers.");
75 TEUCHOS_TEST_FOR_EXCEPTION
76 (Y.extent (0) != lclNumRows, std::logic_error, prefix <<
77 "Y.extent(0) = " << Y.extent (0) << " != lclNumRows "
78 "= " << lclNumRows << ". "
79 "Please report this bug to the Tpetra developers.");
80 // If a MultiVector is constant stride, then numVecs should
81 // equal its View's number of columns. Otherwise, numVecs
82 // should be less than its View's number of columns.
83 TEUCHOS_TEST_FOR_EXCEPTION
84 (constantStrideX &&
85 (X.extent (0) != lclNumRows || X.extent (1) != numVecs),
86 std::logic_error, prefix << "X is " << X.extent (0) << " x " <<
87 X.extent (1) << " (constant stride), which differs from the "
88 "local dimensions " << lclNumRows << " x " << numVecs << ". "
89 "Please report this bug to the Tpetra developers.");
90 TEUCHOS_TEST_FOR_EXCEPTION
91 (! constantStrideX &&
92 (X.extent (0) != lclNumRows || X.extent (1) < numVecs),
93 std::logic_error, prefix << "X is " << X.extent (0) << " x " <<
94 X.extent (1) << " (NOT constant stride), but the local "
95 "dimensions are " << lclNumRows << " x " << numVecs << ". "
96 "Please report this bug to the Tpetra developers.");
97 TEUCHOS_TEST_FOR_EXCEPTION
98 (constantStrideY &&
99 (Y.extent (0) != lclNumRows || Y.extent (1) != numVecs),
100 std::logic_error, prefix << "Y is " << Y.extent (0) << " x " <<
101 Y.extent (1) << " (constant stride), which differs from the "
102 "local dimensions " << lclNumRows << " x " << numVecs << ". "
103 "Please report this bug to the Tpetra developers.");
104 TEUCHOS_TEST_FOR_EXCEPTION
105 (! constantStrideY &&
106 (Y.extent (0) != lclNumRows || Y.extent (1) < numVecs),
107 std::logic_error, prefix << "Y is " << Y.extent (0) << " x " <<
108 Y.extent (1) << " (NOT constant stride), but the local "
109 "dimensions are " << lclNumRows << " x " << numVecs << ". "
110 "Please report this bug to the Tpetra developers.");
111 }
112#endif // HAVE_TPETRA_DEBUG
113
114 if (lclNumRows == 0) {
115 const dot_type zero = Kokkos::ArithTraits<dot_type>::zero ();
116 // DEEP_COPY REVIEW - NOT TESTED
117 Kokkos::deep_copy (theDots, zero);
118 }
119 else { // lclNumRows != 0
120 if (constantStrideX && constantStrideY) {
121 if (X.extent (1) == 1) {
122 typename RV::non_const_value_type result =
123 KokkosBlas::dot (subview (X, ALL (), 0), subview (Y, ALL (), 0));
124 // DEEP_COPY REVIEW - NOT TESTED
125 Kokkos::deep_copy (theDots, result);
126 }
127 else {
128 KokkosBlas::dot (theDots, X, Y);
129 }
130 }
131 else { // not constant stride
132 // NOTE (mfh 15 Jul 2014) This does a kernel launch for
133 // every column. It might be better to have a kernel that
134 // does the work all at once. On the other hand, we don't
135 // prioritize performance of MultiVector views of
136 // noncontiguous columns.
137 for (size_t k = 0; k < numVecs; ++k) {
138 const size_t X_col = constantStrideX ? k : whichVecsX[k];
139 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
140 KokkosBlas::dot (subview (theDots, k), subview (X, ALL (), X_col),
141 subview (Y, ALL (), Y_col));
142 } // for each column
143 } // constantStride
144 } // lclNumRows != 0
145}
146
147} // namespace Details
148} // namespace Tpetra
149
150#endif // TPETRA_DETAILS_LCLDOT_HPP
Nonmember function that computes a residual: R = B - A * X.
Namespace Tpetra contains the class and methods constituting the Tpetra library.