Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_iallreduce.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11#define TPETRA_DETAILS_IALLREDUCE_HPP
12
28
29#include "TpetraCore_config.h"
30#include "Teuchos_EReductionType.hpp"
31#ifdef HAVE_TPETRACORE_MPI
34#endif // HAVE_TPETRACORE_MPI
35#include "Tpetra_Details_temporaryViewUtils.hpp"
37#include "Kokkos_Core.hpp"
38#include <memory>
39#include <stdexcept>
40#include <type_traits>
41#include <functional>
42
43#ifndef DOXYGEN_SHOULD_SKIP_THIS
44namespace Teuchos {
45 // forward declaration of Comm
46 template<class OrdinalType> class Comm;
47} // namespace Teuchos
48#endif // NOT DOXYGEN_SHOULD_SKIP_THIS
49
50namespace Tpetra {
51namespace Details {
52
53#ifdef HAVE_TPETRACORE_MPI
/// \brief Return a string describing the given MPI error code.
///
/// Only declared here; the definition lives elsewhere.  In this file it
/// is used to build the exception message when MPI_Wait fails.
///
/// \param errCode [in] Error code returned by an MPI function.
std::string
getMpiErrorString (const int errCode);
55#endif
56
64public:
66 virtual ~CommRequest () {}
67
72 virtual void wait () {}
73
77 virtual void cancel () {}
78};
79
80// Don't rely on anything in this namespace.
81namespace Impl {
82
84std::shared_ptr<CommRequest>
85emptyCommRequest ();
86
87#ifdef HAVE_TPETRACORE_MPI
88#if MPI_VERSION >= 3
89template<typename InputViewType, typename OutputViewType, typename ResultViewType>
90struct MpiRequest : public CommRequest
91{
92 MpiRequest(const InputViewType& send, const OutputViewType& recv, const ResultViewType& result, MPI_Request req_)
93 : sendBuf(send), recvBuf(recv), resultBuf(result), req(req_)
94 {}
95
96 ~MpiRequest()
97 {
98 //this is a no-op if wait() or cancel() have already been called
99 cancel();
100 }
101
106 void wait () override
107 {
108 if (req != MPI_REQUEST_NULL) {
109 const int err = MPI_Wait (&req, MPI_STATUS_IGNORE);
110 TEUCHOS_TEST_FOR_EXCEPTION
111 (err != MPI_SUCCESS, std::runtime_error,
112 "MpiCommRequest::wait: MPI_Wait failed with error \""
113 << getMpiErrorString (err));
114 // MPI_Wait should set the MPI_Request to MPI_REQUEST_NULL on
115 // success. We'll do it here just to be conservative.
116 req = MPI_REQUEST_NULL;
117 //Since recvBuf contains the result, copy it to the user's resultBuf.
118 Kokkos::deep_copy(resultBuf, recvBuf);
119 }
120 }
121
125 void cancel () override
126 {
127 //BMK: Per https://www.mpi-forum.org/docs/mpi-3.1/mpi31-report/node126.htm,
128 //MPI_Cancel cannot be used for collectives like iallreduce.
129 req = MPI_REQUEST_NULL;
130 }
131
132private:
133 InputViewType sendBuf;
134 OutputViewType recvBuf;
135 ResultViewType resultBuf;
136 //This request is active if and only if req != MPI_REQUEST_NULL.
137 MPI_Request req;
138};
139
142MPI_Request
143iallreduceRaw (const void* sendbuf,
144 void* recvbuf,
145 const int count,
146 MPI_Datatype mpiDatatype,
147 const Teuchos::EReductionType op,
148 MPI_Comm comm);
149#endif
150
152void
153allreduceRaw (const void* sendbuf,
154 void* recvbuf,
155 const int count,
156 MPI_Datatype mpiDatatype,
157 const Teuchos::EReductionType op,
158 MPI_Comm comm);
159
160template<class InputViewType, class OutputViewType>
161std::shared_ptr<CommRequest>
162iallreduceImpl (const InputViewType& sendbuf,
163 const OutputViewType& recvbuf,
164 const ::Teuchos::EReductionType op,
165 const ::Teuchos::Comm<int>& comm)
166{
167 using Packet = typename InputViewType::non_const_value_type;
168 if(comm.getSize() == 1)
169 {
170 Kokkos::deep_copy(recvbuf, sendbuf);
171 return emptyCommRequest();
172 }
173 Packet examplePacket;
174 MPI_Datatype mpiDatatype = sendbuf.extent(0) ?
175 MpiTypeTraits<Packet>::getType (examplePacket) :
176 MPI_BYTE;
177 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
178 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos (comm);
179 //Note BMK: Nonblocking collectives like iallreduce cannot use GPU buffers.
180 //See https://www.open-mpi.org/faq/?category=runcuda#mpi-cuda-support
181 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
182 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
183 std::shared_ptr<CommRequest> req;
184 //Next, if input/output alias and comm is an intercomm, make a deep copy of input.
185 //Not possible to do in-place allreduce for intercomm.
186 if(isInterComm(comm) && sendMPI.data() == recvMPI.data())
187 {
188 //Can't do in-place collective on an intercomm,
189 //so use a separate 1D copy as the input.
190 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing("tempInput"), sendMPI.extent(0));
191 for(size_t i = 0; i < sendMPI.extent(0); i++)
192 tempInput(i) = sendMPI.data()[i];
193#if MPI_VERSION >= 3
194 //MPI 3+: use async allreduce
195 MPI_Request mpiReq = iallreduceRaw((const void*) tempInput.data(), (void*) recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
196 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
197#else
198 //Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
199 allreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
200 Kokkos::deep_copy(recvbuf, recvMPI);
201 req = emptyCommRequest();
202#endif
203 }
204 else
205 {
206#if MPI_VERSION >= 3
207 //MPI 3+: use async allreduce
208 MPI_Request mpiReq = iallreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
209 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
210#else
211 //Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
212 allreduceRaw((const void*) sendMPI.data(), (void*) recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
213 Kokkos::deep_copy(recvbuf, recvMPI);
214 req = emptyCommRequest();
215#endif
216 }
217 if(datatypeNeedsFree)
218 MPI_Type_free(&mpiDatatype);
219 return req;
220}
221
222#else
223
224//No MPI: reduction is always the same as input.
225template<class InputViewType, class OutputViewType>
226std::shared_ptr<CommRequest>
227iallreduceImpl (const InputViewType& sendbuf,
228 const OutputViewType& recvbuf,
229 const ::Teuchos::EReductionType,
230 const ::Teuchos::Comm<int>&)
231{
232 Kokkos::deep_copy(recvbuf, sendbuf);
233 return emptyCommRequest();
234}
235
236#endif // HAVE_TPETRACORE_MPI
237
238} // namespace Impl
239
240//
241// SKIP DOWN TO HERE
242//
243
269template<class InputViewType, class OutputViewType>
270std::shared_ptr<CommRequest>
271iallreduce (const InputViewType& sendbuf,
272 const OutputViewType& recvbuf,
273 const ::Teuchos::EReductionType op,
274 const ::Teuchos::Comm<int>& comm)
275{
276 static_assert (Kokkos::is_view<InputViewType>::value,
277 "InputViewType must be a Kokkos::View specialization.");
278 static_assert (Kokkos::is_view<OutputViewType>::value,
279 "OutputViewType must be a Kokkos::View specialization.");
280 constexpr int rank = static_cast<int> (OutputViewType::rank);
281 static_assert (static_cast<int> (InputViewType::rank) == rank,
282 "InputViewType and OutputViewType must have the same rank.");
283 static_assert (rank == 0 || rank == 1,
284 "InputViewType and OutputViewType must both have "
285 "rank 0 or rank 1.");
286 typedef typename OutputViewType::non_const_value_type packet_type;
287 static_assert (std::is_same<typename OutputViewType::value_type,
288 packet_type>::value,
289 "OutputViewType must be a nonconst Kokkos::View.");
290 static_assert (std::is_same<typename InputViewType::non_const_value_type,
291 packet_type>::value,
292 "InputViewType and OutputViewType must be Views "
293 "whose entries have the same type.");
294 //Make sure layouts are contiguous (don't accept strided 1D view)
295 static_assert (!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
296 "Input/Output views must be contiguous (not LayoutStride)");
297 static_assert (!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
298 "Input/Output views must be contiguous (not LayoutStride)");
299
300 return Impl::iallreduceImpl<InputViewType, OutputViewType> (sendbuf, recvbuf, op, comm);
301}
302
303std::shared_ptr<CommRequest>
304iallreduce (const int localValue,
305 int& globalValue,
306 const ::Teuchos::EReductionType op,
307 const ::Teuchos::Comm<int>& comm);
308
309} // namespace Details
310} // namespace Tpetra
311
312#endif // TPETRA_DETAILS_IALLREDUCE_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex...
Declaration of Tpetra::Details::extractMpiCommFromTeuchos.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Nonmember function that computes a residual: R = B - A * X.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace for new Tpetra features that are not ready for public release, but are ready for evaluation...
Namespace Tpetra contains the class and methods constituting the Tpetra library.