Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_TSQRADAPTOR_HPP
11#define TPETRA_TSQRADAPTOR_HPP
12
16
17#include "Tpetra_ConfigDefs.hpp"
18
19#ifdef HAVE_TPETRA_TSQR
20# include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
21# include "Tsqr.hpp" // full (internode + intranode) TSQR
22# include "Tsqr_DistTsqr.hpp" // internode TSQR
23// Subclass of TSQR::MessengerBase, implemented using Teuchos
24// communicator template helper functions
25# include "Tsqr_TeuchosMessenger.hpp"
26# include "Tpetra_MultiVector.hpp"
27# include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
28# include <stdexcept>
29
30namespace Tpetra {
31
53 template<class MV>
54 class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
55 public:
56 using scalar_type = typename MV::scalar_type;
57 using ordinal_type = typename MV::local_ordinal_type;
58 using dense_matrix_type =
59 Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
60 using magnitude_type =
61 typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
62
63 private:
64 using node_tsqr_factory_type =
65 TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
66 typename MV::device_type>;
67 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
68 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
69 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
70
71 TSQR::MatView<ordinal_type, scalar_type>
72 get_mat_view(MV& X)
73 {
74 TEUCHOS_ASSERT( ! tsqr_.is_null() );
75 // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
76 // object in Q to make sure it is the same communicator as the
77 // one we are using in our dist_tsqr_type implementation.
78
79 const ordinal_type lclNumRows(X.getLocalLength());
80 const ordinal_type numCols(X.getNumVectors());
81 scalar_type* X_ptr = nullptr;
82 // LAPACK and BLAS functions require "LDA" >= 1, even if the
83 // corresponding matrix dimension is zero.
84 ordinal_type X_stride = 1;
85 if(tsqr_->wants_device_memory()) {
86 auto X_view = X.getLocalViewDevice(Access::ReadWrite);
87 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
88 X_stride = static_cast<ordinal_type>(X_view.stride(1));
89 if(X_stride == 0) {
90 X_stride = ordinal_type(1); // see note above
91 }
92 }
93 else {
94 auto X_view = X.getLocalViewHost(Access::ReadWrite);
95 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
96 X_stride = static_cast<ordinal_type>(X_view.stride(1));
97 if(X_stride == 0) {
98 X_stride = ordinal_type(1); // see note above
99 }
100 }
101 using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
102 return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
103 }
104
105 public:
112 TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist) :
113 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
114 distTsqr_(new dist_tsqr_type),
115 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
116 {
117 setParameterList(plist);
118 }
119
121 TsqrAdaptor() :
122 nodeTsqr_(node_tsqr_factory_type::getNodeTsqr()),
123 distTsqr_(new dist_tsqr_type),
124 tsqr_(new tsqr_type(nodeTsqr_, distTsqr_))
125 {
126 setParameterList(Teuchos::null);
127 }
128
130 Teuchos::RCP<const Teuchos::ParameterList>
131 getValidParameters() const
132 {
133 if(defaultParams_.is_null()) {
134 auto params = Teuchos::parameterList("TSQR implementation");
135 params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
136 params->set("DistTsqr", *(distTsqr_->getValidParameters()));
137 defaultParams_ = params;
138 }
139 return defaultParams_;
140 }
141
167 void
168 setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist)
169 {
170 auto params = plist.is_null() ?
171 Teuchos::parameterList(*getValidParameters()) : plist;
172 using Teuchos::sublist;
173 nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
174 distTsqr_->setParameterList(sublist(params, "DistTsqr"));
175
176 this->setMyParamList(params);
177 }
178
200 void
201 factorExplicit(MV& A,
202 MV& Q,
203 dense_matrix_type& R,
204 const bool forceNonnegativeDiagonal=false)
205 {
206 TEUCHOS_TEST_FOR_EXCEPTION
207 (! A.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
208 "factorExplicit: Input MultiVector A must have constant stride.");
209 TEUCHOS_TEST_FOR_EXCEPTION
210 (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
211 "factorExplicit: Input MultiVector Q must have constant stride.");
212 prepareTsqr(Q); // Finish initializing TSQR.
213 TEUCHOS_ASSERT( ! tsqr_.is_null() );
214
215 auto A_view = get_mat_view(A);
216 auto Q_view = get_mat_view(Q);
217 constexpr bool contiguousCacheBlocks = false;
218 tsqr_->factorExplicitRaw(A_view.extent(0),
219 A_view.extent(1),
220 A_view.data(), A_view.stride(1),
221 Q_view.data(), Q_view.stride(1),
222 R.values(), R.stride(),
223 contiguousCacheBlocks,
224 forceNonnegativeDiagonal);
225 }
226
257 int
258 revealRank(MV& Q,
259 dense_matrix_type& R,
260 const magnitude_type& tol)
261 {
262 TEUCHOS_TEST_FOR_EXCEPTION
263 (! Q.isConstantStride(), std::invalid_argument, "TsqrAdaptor::"
264 "revealRank: Input MultiVector Q must have constant stride.");
265 prepareTsqr(Q); // Finish initializing TSQR.
266
267 auto Q_view = get_mat_view(Q);
268 constexpr bool contiguousCacheBlocks = false;
269 return tsqr_->revealRankRaw(Q_view.extent(0),
270 Q_view.extent(1),
271 Q_view.data(), Q_view.stride(1),
272 R.values(), R.stride(),
273 tol, contiguousCacheBlocks);
274 }
275
276 private:
278 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
279
281 Teuchos::RCP<dist_tsqr_type> distTsqr_;
282
284 Teuchos::RCP<tsqr_type> tsqr_;
285
287 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
288
290 bool ready_ = false;
291
312 void
313 prepareTsqr(const MV& mv)
314 {
315 if(! ready_) {
316 prepareDistTsqr(mv);
317 ready_ = true;
318 }
319 }
320
327 void
328 prepareDistTsqr(const MV& mv)
329 {
330 using Teuchos::RCP;
331 using Teuchos::rcp_implicit_cast;
332 using mess_type = TSQR::TeuchosMessenger<scalar_type>;
333 using base_mess_type = TSQR::MessengerBase<scalar_type>;
334
335 auto comm = mv.getMap()->getComm();
336 RCP<mess_type> mess(new mess_type(comm));
337 auto messBase = rcp_implicit_cast<base_mess_type>(mess);
338 distTsqr_->init(messBase);
339 }
340 };
341
342} // namespace Tpetra
343
344#endif // HAVE_TPETRA_TSQR
345
346#endif // TPETRA_TSQRADAPTOR_HPP
Namespace Tpetra contains the class and methods constituting the Tpetra library.