14#ifndef _ZOLTAN2_XPETRACRSMATRIXADAPTER_HPP_
15#define _ZOLTAN2_XPETRACRSMATRIXADAPTER_HPP_
22#include <Xpetra_CrsMatrix.hpp>
52template <
typename User,
typename UserCoord=User>
56#ifndef DOXYGEN_SHOULD_SKIP_THIS
63 using xmatrix_t = Xpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t>;
65 using userCoord_t = UserCoord;
75 int nWeightsPerRow=0);
127 return matrix_->getLocalNumRows();
131 return matrix_->getLocalNumCols();
135 return matrix_->getLocalNumEntries();
142 ArrayView<const gno_t> rowView = rowMap_->getLocalElementList();
143 rowIds = rowView.getRawPtr();
147 ArrayRCP<const gno_t> &colIds)
const
149 ArrayRCP< const lno_t > localColumnIds;
150 ArrayRCP<const scalar_t> values;
151 matrix_->getAllValues(offsets,localColumnIds,values);
156 ArrayRCP<const gno_t> &colIds,
157 ArrayRCP<const scalar_t> &values)
const {
158 ArrayRCP< const lno_t > localColumnIds;
159 matrix_->getAllValues(offsets,localColumnIds,values);
169 if(idx<0 || idx >= nWeightsPerRow_)
171 std::ostringstream emsg;
172 emsg << __FILE__ <<
":" << __LINE__
173 <<
" Invalid row weight index " << idx << std::endl;
174 throw std::runtime_error(emsg.str());
178 rowWeights_[idx].getStridedList(length,
weights, stride);
183 template <
typename Adapter>
187 template <
typename Adapter>
193 RCP<const User> inmatrix_;
194 RCP<const xmatrix_t> matrix_;
195 RCP<const Xpetra::Map<lno_t, gno_t, node_t> > rowMap_;
196 RCP<const Xpetra::Map<lno_t, gno_t, node_t> > colMap_;
198 ArrayRCP<gno_t> columnIds_;
201 ArrayRCP<StridedData<lno_t, scalar_t> > rowWeights_;
202 ArrayRCP<bool> numNzWeight_;
204 bool mayHaveDiagonalEntries;
211template <
typename User,
typename UserCoord>
213 const RCP<const User> &inmatrix,
int nWeightsPerRow):
214 inmatrix_(inmatrix), matrix_(), rowMap_(), colMap_(),
216 nWeightsPerRow_(nWeightsPerRow), rowWeights_(), numNzWeight_(),
217 mayHaveDiagonalEntries(true)
221 matrix_ = rcp_const_cast<const xmatrix_t>(
226 rowMap_ = matrix_->getRowMap();
227 colMap_ = matrix_->getColMap();
229 size_t nrows = matrix_->getLocalNumRows();
230 size_t nnz = matrix_->getLocalNumEntries();
233 ArrayRCP< const offset_t > offset;
234 ArrayRCP< const lno_t > localColumnIds;
235 ArrayRCP< const scalar_t > values;
236 matrix_->getAllValues(offset,localColumnIds,values);
237 columnIds_.resize(nnz, 0);
239 for(
offset_t i = 0; i < offset[nrows]; i++){
240 columnIds_[i] = colMap_->getGlobalElement(localColumnIds[i]);
243 if (nWeightsPerRow_ > 0){
244 rowWeights_ = arcp(
new input_t [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
245 numNzWeight_ = arcp(
new bool [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
246 for (
int i=0; i < nWeightsPerRow_; i++)
247 numNzWeight_[i] =
false;
252template <
typename User,
typename UserCoord>
254 const scalar_t *weightVal,
int stride,
int idx)
260 std::ostringstream emsg;
261 emsg << __FILE__ <<
"," << __LINE__
262 <<
" error: setWeights not yet supported for"
263 <<
" columns or nonzeros."
265 throw std::runtime_error(emsg.str());
270template <
typename User,
typename UserCoord>
272 const scalar_t *weightVal,
int stride,
int idx)
275 if(idx<0 || idx >= nWeightsPerRow_)
277 std::ostringstream emsg;
278 emsg << __FILE__ <<
":" << __LINE__
279 <<
" Invalid row weight index " << idx << std::endl;
280 throw std::runtime_error(emsg.str());
284 ArrayRCP<const scalar_t> weightV(weightVal, 0, nvtx*stride,
false);
285 rowWeights_[idx] = input_t(weightV, stride);
289template <
typename User,
typename UserCoord>
297 std::ostringstream emsg;
298 emsg << __FILE__ <<
"," << __LINE__
299 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
300 <<
" columns" << std::endl;
301 throw std::runtime_error(emsg.str());
306template <
typename User,
typename UserCoord>
310 if(idx<0 || idx >= nWeightsPerRow_)
312 std::ostringstream emsg;
313 emsg << __FILE__ <<
":" << __LINE__
314 <<
" Invalid row weight index " << idx << std::endl;
315 throw std::runtime_error(emsg.str());
319 numNzWeight_[idx] =
true;
323template <
typename User,
typename UserCoord>
324 template <
typename Adapter>
326 const User &in, User *&out,
331 ArrayRCP<gno_t> importList;
335 (solution,
this, importList);
341 importList.getRawPtr());
342 out =
const_cast<User *
>(outPtr.get());
347template <
typename User,
typename UserCoord>
348 template <
typename Adapter>
350 const User &in, RCP<User> &out,
355 ArrayRCP<gno_t> importList;
359 (solution,
this, importList);
365 importList.getRawPtr());
Zoltan2::BasicUserTypes< zscalar_t, zlno_t, zgno_t > user_t
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through the call stack.
Defines the MatrixAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
Traits of Xpetra classes, including migration method.
enum MatrixEntityType getPrimaryEntityType() const
A PartitioningSolution is a solution to a partitioning problem.
The StridedData class manages lists of weights or coordinates.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
bool CRSViewAvailable() const
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) format.
void applyPartitioningSolution(const User &in, RCP< User > &out, const PartitioningSolution< Adapter > &solution) const
bool useNumNonzerosAsRowWeight(int idx) const
Indicate whether row weight with index idx should be the global number of nonzeros in the row.
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const
void getRowIDsView(const gno_t *&rowIds) const
size_t getLocalNumColumns() const
Returns the number of columns on this process.
size_t getLocalNumEntries() const
Returns the number of nonzeros on this process.
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
size_t getLocalNumRows() const
Returns the number of rows on this process.
int getNumWeightsPerRow() const
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matrix rows.
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds, ArrayRCP< const scalar_t > &values) const
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row.
XpetraCrsMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
Created by mbenlioglu on Aug 31, 2020.
size_t getImportList(const PartitioningSolution< SolutionAdapter > &solution, const DataAdapter *const data, ArrayRCP< typename DataAdapter::gno_t > &imports)
From a PartitioningSolution, get a list of IDs to be imported. Assumes part numbers in PartitioningSolution are to be used as rank numbers.
static RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows)
Migrate the object. Given a user object and a new row distribution, create and return a new user object with the new distribution.
static RCP< User > convertToXpetra(const RCP< User > &a)
Convert the object to its Xpetra wrapped version.