Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_scaleBlockDiagonal.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
11#define TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
12
#include "TpetraCore_config.h"
#include "Tpetra_CrsMatrix.hpp"
#include "Tpetra_Details_Behavior.hpp"
#include "Teuchos_ScalarTraits.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_LU_Decl.hpp"
#include "KokkosBatched_LU_Serial_Impl.hpp"
#include "KokkosBatched_Trsm_Decl.hpp"
#include "KokkosBatched_Trsm_Serial_Impl.hpp"
22
23
30
31namespace Tpetra {
32namespace Details {
33
34
35template<class MultiVectorType>
36void inverseScaleBlockDiagonal(MultiVectorType & blockDiagonal, bool doTranspose, MultiVectorType & multiVectorToBeScaled) {
37 using LO = typename MultiVectorType::local_ordinal_type;
38 using range_type = Kokkos::RangePolicy<typename MultiVectorType::node_type::execution_space, LO>;
39 using namespace KokkosBatched;
40 typename MultiVectorType::impl_scalar_type SC_one = Teuchos::ScalarTraits<typename MultiVectorType::impl_scalar_type>::one();
41
42 // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
44 TEUCHOS_TEST_FOR_EXCEPTION(!blockDiagonal.getMap()->isSameAs(*multiVectorToBeScaled.getMap()),
45 std::runtime_error, "Tpetra::Details::scaledBlockDiagonal was given incompatible maps");
46 }
47
48 LO numrows = blockDiagonal.getLocalLength();
49 LO blocksize = blockDiagonal.getNumVectors();
50 LO numblocks = numrows / blocksize;
51
52 // Get Kokkos versions of objects
53
54 auto blockDiag = blockDiagonal.getLocalViewDevice(Access::OverwriteAll);
55 auto toScale = multiVectorToBeScaled.getLocalViewDevice(Access::ReadWrite);
56
57 typedef Algo::Level3::Unblocked algo_type;
58 Kokkos::parallel_for("scaleBlockDiagonal",range_type(0,numblocks),KOKKOS_LAMBDA(const LO i){
59 Kokkos::pair<LO,LO> row_range(i*blocksize,(i+1)*blocksize);
60 auto A = Kokkos::subview(blockDiag,row_range, Kokkos::ALL());
61 auto B = Kokkos::subview(toScale, row_range, Kokkos::ALL());
62
63 // Factor
64 SerialLU<algo_type>::invoke(A);
65
66 if(doTranspose) {
67 // Solve U^T
68 SerialTrsm<Side::Left,Uplo::Upper,Trans::Transpose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
69 // Solver L^T
70 SerialTrsm<Side::Left,Uplo::Lower,Trans::Transpose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
71 }
72 else {
73 // Solve L
74 SerialTrsm<Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,algo_type>::invoke(SC_one,A,B);
75 // Solve U
76 SerialTrsm<Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,algo_type>::invoke(SC_one,A,B);
77 }
78 });
79}
80
81} // namespace Details
82} // namespace Tpetra
83
84#endif // TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
static bool debug()
Whether Tpetra is in debug mode.
Nonmember function that computes a residual Computes R = B - A * X.
Namespace Tpetra contains the class and methods constituting the Tpetra library.