10#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11#define TPETRA_DETAILS_IALLREDUCE_HPP
29#include "TpetraCore_config.h"
30#include "Teuchos_EReductionType.hpp"
31#ifdef HAVE_TPETRACORE_MPI
35#include "Tpetra_Details_temporaryViewUtils.hpp"
37#include "Kokkos_Core.hpp"
43#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Forward declaration of the Comm class template, parameterized on its
// ordinal type. The enclosing namespace is not visible in this listing;
// presumably this is Teuchos::Comm, which the functions below take as
// `::Teuchos::Comm<int>` -- TODO confirm.
46 template<
class OrdinalType>
class Comm;
53#ifdef HAVE_TPETRACORE_MPI
// Declaration only: takes an integer error code and returns a std::string.
// MpiRequest::wait() streams the result into its exception message when
// MPI_Wait fails, so errCode is presumably an MPI return code.
54std::string getMpiErrorString (
const int errCode);
84std::shared_ptr<CommRequest>
87#ifdef HAVE_TPETRACORE_MPI
// CommRequest implementation that pairs an MPI_Request handle with the
// three views involved in the operation: the send buffer, the receive
// buffer handed to MPI, and the caller-facing result view.
// Holding the views as members presumably keeps the buffers alive while the
// operation is pending -- TODO confirm against Kokkos::View semantics.
// NOTE(review): this listing is a lossy extraction; braces and the
// declaration of the `req` member are not visible here.
89template<
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
90struct MpiRequest :
public CommRequest
// Constructor: copies the three view handles and the request handle into
// the corresponding members.
92 MpiRequest(
const InputViewType& send,
const OutputViewType& recv,
const ResultViewType& result, MPI_Request req_)
93 : sendBuf(send), recvBuf(recv), resultBuf(result), req(req_)
// wait(): when the stored handle is not MPI_REQUEST_NULL, calls MPI_Wait
// (ignoring the status) and reports any return code other than MPI_SUCCESS
// as std::runtime_error through TEUCHOS_TEST_FOR_EXCEPTION. It then resets
// the handle to MPI_REQUEST_NULL and deep-copies recvBuf into resultBuf.
// NOTE(review): with the braces missing it is unclear whether the reset and
// the deep_copy sit inside the `if` -- confirm against the real header.
// NOTE(review): the message opens an escaped quote before the MPI error
// string but never appends a closing quote, and its prefix reads
// "MpiCommRequest::wait" although this struct is named MpiRequest.
106 void wait ()
override
108 if (req != MPI_REQUEST_NULL) {
109 const int err = MPI_Wait (&req, MPI_STATUS_IGNORE);
110 TEUCHOS_TEST_FOR_EXCEPTION
111 (err != MPI_SUCCESS, std::runtime_error,
112 "MpiCommRequest::wait: MPI_Wait failed with error \""
113 << getMpiErrorString (err));
116 req = MPI_REQUEST_NULL;
118 Kokkos::deep_copy(resultBuf, recvBuf);
// cancel(): the only visible statement resets the stored handle to
// MPI_REQUEST_NULL; no MPI call appears in this listing.
125 void cancel ()
override
129 req = MPI_REQUEST_NULL;
// View handles captured at construction (see constructor above).
133 InputViewType sendBuf;
134 OutputViewType recvBuf;
135 ResultViewType resultBuf;
// Partial declaration of iallreduceRaw: the return type and several
// parameters are missing from this listing. Callers in this file invoke it as
// (send pointer, receive pointer, element count, MPI datatype, reduction op,
// raw MPI communicator) and store the result in an MPI_Request.
143iallreduceRaw (
const void* sendbuf,
146 MPI_Datatype mpiDatatype,
147 const Teuchos::EReductionType op,
// Partial declaration of allreduceRaw: the return type and several
// parameters are missing from this listing. Callers in this file invoke it
// with the same six arguments as iallreduceRaw and ignore any return value;
// presumably it is the blocking counterpart -- TODO confirm.
153allreduceRaw (
const void* sendbuf,
156 MPI_Datatype mpiDatatype,
157 const Teuchos::EReductionType op,
// MPI-enabled implementation: reduces the entries of sendbuf across the
// processes of comm with reduction `op`, delivers the result to recvbuf, and
// returns a CommRequest handle for the operation.
// NOTE(review): this listing is a lossy extraction. Braces, the else-arm of
// the datatype ternary, whatever selects between the nonblocking and blocking
// sequences below (presumably a preprocessor check), and the final return
// statement are not visible.
160template<
class InputViewType,
class OutputViewType>
161std::shared_ptr<CommRequest>
162iallreduceImpl (
const InputViewType& sendbuf,
163 const OutputViewType& recvbuf,
164 const ::Teuchos::EReductionType op,
165 const ::Teuchos::Comm<int>& comm)
167 using Packet =
typename InputViewType::non_const_value_type;
// Single-process communicator: the result is just the input, so copy it and
// hand back an empty request.
168 if(comm.getSize() == 1)
170 Kokkos::deep_copy(recvbuf, sendbuf);
171 return emptyCommRequest();
// Ask MpiTypeTraits for the MPI datatype matching Packet, but only when the
// send view has a nonzero first extent (the alternative value of the
// ternary is not visible here).
173 Packet examplePacket;
174 MPI_Datatype mpiDatatype = sendbuf.extent(0) ?
175 MpiTypeTraits<Packet>::getType (examplePacket) :
177 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
178 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos (comm);
// Obtain versions of both views whose data() pointers are handed to MPI.
// presumably toMPISafe returns the view itself or a temporary copy in
// MPI-accessible memory -- TODO confirm in temporaryViewUtils.
181 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
182 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
183 std::shared_ptr<CommRequest> req;
// Case 1: intercommunicator with send and receive aliasing the same memory.
// The input is first copied element by element into a separate host view
// so that distinct send and receive pointers are passed to MPI.
186 if(
isInterComm(comm) && sendMPI.data() == recvMPI.data())
190 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
191 for(
size_t i = 0; i < sendMPI.extent(0); i++)
192 tempInput(i) = sendMPI.data()[i];
// Nonblocking variant: start the reduction and wrap the MPI_Request in an
// MpiRequest that also stores tempInput, recvMPI and the user's recvbuf.
195 MPI_Request mpiReq = iallreduceRaw((
const void*) tempInput.data(), (
void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
196 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(tempInput),
decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
// Blocking variant: reduce immediately, copy the result to the user's view,
// and return an empty request.
// NOTE(review): this call passes sendMPI rather than tempInput even though
// it sits in the aliasing branch -- confirm whether that is intended.
199 allreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
200 Kokkos::deep_copy(recvbuf, recvMPI);
201 req = emptyCommRequest();
// Case 2 (buffers not aliased on an intercommunicator): same two variants,
// using sendMPI directly as the input.
208 MPI_Request mpiReq = iallreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
209 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(sendMPI),
decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
212 allreduceRaw((
const void*) sendMPI.data(), (
void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
213 Kokkos::deep_copy(recvbuf, recvMPI);
214 req = emptyCommRequest();
// Release the datatype when the traits class says it must be freed.
// NOTE(review): on the nonblocking path this runs while the request may still
// be pending; verify against the MPI_Type_free rules for in-flight operations.
217 if(datatypeNeedsFree)
218 MPI_Type_free(&mpiDatatype);
// Second definition of iallreduceImpl with the same parameter types but the
// reduction op and communicator left unnamed and unused: it copies sendbuf
// into recvbuf and returns an empty request.
// presumably this is the variant compiled when MPI is disabled; the
// preprocessor branch that selects it is not visible in this listing.
225template<
class InputViewType,
class OutputViewType>
226std::shared_ptr<CommRequest>
227iallreduceImpl (
const InputViewType& sendbuf,
228 const OutputViewType& recvbuf,
229 const ::Teuchos::EReductionType,
230 const ::Teuchos::Comm<int>&)
232 Kokkos::deep_copy(recvbuf, sendbuf);
233 return emptyCommRequest();
// Public entry point for the view-based all-reduce: validates the two view
// types at compile time, then forwards all four arguments to
// Impl::iallreduceImpl and returns its CommRequest handle.
//
// sendbuf  input view (rank 0 or 1, not LayoutStride)
// recvbuf  output view of the same rank (not LayoutStride)
// op       reduction operation
// comm     communicator over which to reduce
//
// NOTE(review): this listing is a lossy extraction; braces and the second
// template argument of the two std::is_same assertions are not visible.
269template<
class InputViewType,
class OutputViewType>
270std::shared_ptr<CommRequest>
271iallreduce (
const InputViewType& sendbuf,
272 const OutputViewType& recvbuf,
273 const ::Teuchos::EReductionType op,
274 const ::Teuchos::Comm<int>& comm)
// Both arguments must be Kokkos::View specializations.
276 static_assert (Kokkos::is_view<InputViewType>::value,
277 "InputViewType must be a Kokkos::View specialization.");
278 static_assert (Kokkos::is_view<OutputViewType>::value,
279 "OutputViewType must be a Kokkos::View specialization.");
// Ranks must match and be 0 or 1.
280 constexpr int rank =
static_cast<int> (OutputViewType::rank);
281 static_assert (
static_cast<int> (InputViewType::rank) == rank,
282 "InputViewType and OutputViewType must have the same rank.");
283 static_assert (rank == 0 || rank == 1,
284 "InputViewType and OutputViewType must both have "
285 "rank 0 or rank 1.");
// Value-type checks. Per the messages, the output view must be nonconst and
// both views must hold the same entry type; presumably each comparison is
// against packet_type, but that argument line is missing here.
286 typedef typename OutputViewType::non_const_value_type packet_type;
287 static_assert (std::is_same<
typename OutputViewType::value_type,
289 "OutputViewType must be a nonconst Kokkos::View.");
290 static_assert (std::is_same<
typename InputViewType::non_const_value_type,
292 "InputViewType and OutputViewType must be Views "
293 "whose entries have the same type.");
// Reject LayoutStride on either view; the implementation passes the raw
// data() pointer and extent(0) to MPI as a single buffer.
295 static_assert (!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
296 "Input/Output views must be contiguous (not LayoutStride)");
297 static_assert (!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
298 "Input/Output views must be contiguous (not LayoutStride)");
300 return Impl::iallreduceImpl<InputViewType, OutputViewType> (sendbuf, recvbuf, op, comm);
// Declaration of a non-template iallreduce overload that takes a single int
// local value plus the reduction op and communicator, and returns a
// CommRequest handle.
// NOTE(review): one parameter line between localValue and op is missing from
// this listing (presumably the output argument) -- confirm in the real header.
303std::shared_ptr<CommRequest>
304iallreduce (
const int localValue,
306 const ::Teuchos::EReductionType op,
307 const ::Teuchos::Comm<int>& comm);
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex...
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Nonmember function that computes a residual: R = B - A * X.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace for new Tpetra features that are not ready for public release, but are ready for evaluation...
Namespace Tpetra contains the class and methods constituting the Tpetra library.