Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_rightScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
11#define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
12
19
20#include "TpetraCore_config.h"
21#include "Kokkos_Core.hpp"
22#include "Kokkos_ArithTraits.hpp"
23#include <type_traits>
24
25namespace Tpetra {
26namespace Details {
27
36template<class LocalSparseMatrixType,
37 class ScalingFactorsViewType,
38 const bool divide>
40public:
41 using val_type =
42 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
43 using mag_type = typename ScalingFactorsViewType::non_const_value_type;
44 static_assert (ScalingFactorsViewType::rank == 1,
45 "scalingFactors must be a rank-1 Kokkos::View.");
46 using device_type = typename LocalSparseMatrixType::device_type;
47 using LO = typename LocalSparseMatrixType::ordinal_type;
48 using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
49
60 RightScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
61 const ScalingFactorsViewType& scalingFactors,
62 const bool assumeSymmetric) :
63 A_lcl_ (A_lcl),
64 scalingFactors_ (scalingFactors),
65 assumeSymmetric_ (assumeSymmetric)
66 {}
67
68 KOKKOS_INLINE_FUNCTION void
69 operator () (const typename policy_type::member_type & team) const
70 {
71 using KAM = Kokkos::ArithTraits<mag_type>;
72
73 const LO lclRow = team.league_rank();
74 auto curRow = A_lcl_.row (lclRow);
75 const LO numEnt = curRow.length;
76 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
77 const LO lclColInd = curRow.colidx(k);
78 const mag_type curColNorm = scalingFactors_(lclColInd);
79 // Users are responsible for any divisions or multiplications by
80 // zero.
81 const mag_type scalingFactor = assumeSymmetric_ ?
82 KAM::sqrt (curColNorm) : curColNorm;
83 if (divide) { // constexpr, so should get compiled out
84 curRow.value(k) = curRow.value(k) / scalingFactor;
85 }
86 else {
87 curRow.value(k) = curRow.value(k) * scalingFactor;
88 }
89 });
90 }
91
92private:
93 LocalSparseMatrixType A_lcl_;
94 typename ScalingFactorsViewType::const_type scalingFactors_;
95 bool assumeSymmetric_;
96};
97
111template<class LocalSparseMatrixType, class ScalingFactorsViewType>
112void
113rightScaleLocalCrsMatrix (const LocalSparseMatrixType& A_lcl,
114 const ScalingFactorsViewType& scalingFactors,
115 const bool assumeSymmetric,
116 const bool divide = true)
117{
118 using device_type = typename LocalSparseMatrixType::device_type;
119 using execution_space = typename device_type::execution_space;
120 using LO = typename LocalSparseMatrixType::ordinal_type;
121 using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
122
123 const LO lclNumRows = A_lcl.numRows ();
124 if (divide) {
125 using functor_type =
126 RightScaleLocalCrsMatrix<LocalSparseMatrixType,
127 typename ScalingFactorsViewType::const_type, true>;
128 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
129 Kokkos::parallel_for ("rightScaleLocalCrsMatrix",
130 policy_type (lclNumRows, Kokkos::AUTO), functor);
131 }
132 else {
133 using functor_type =
134 RightScaleLocalCrsMatrix<LocalSparseMatrixType,
135 typename ScalingFactorsViewType::const_type, false>;
136 functor_type functor (A_lcl, scalingFactors, assumeSymmetric);
137 Kokkos::parallel_for ("rightScaleLocalCrsMatrix",
138 policy_type (lclNumRows, Kokkos::AUTO), functor);
139 }
140}
141
142} // namespace Details
143} // namespace Tpetra
144
145#endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Nonmember function that computes a residual: computes R = B - A * X.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.
Namespace Tpetra contains the class and methods constituting the Tpetra library.