139 AMGXOperator(
const Teuchos::RCP<Tpetra::CrsMatrix<SC, LO, GO, NO> >& inA, Teuchos::ParameterList& paramListIn) {
140 RCP<const Teuchos::Comm<int> > comm = inA->getRowMap()->getComm();
141 int numProcs = comm->getSize();
142 int myRank = comm->getRank();
144 RCP<Teuchos::Time> amgxTimer = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: initialize");
152 AMGX_SAFE_CALL(AMGX_install_signal_handler());
153 Teuchos::ParameterList configs = paramListIn.sublist(
"amgx:params",
true);
154 if (configs.isParameter(
"json file")) {
155 AMGX_SAFE_CALL(AMGX_config_create_from_file(&
Config_, (
const char*)&configs.get<std::string>(
"json file")[0]));
157 std::ostringstream oss;
159 ParameterList::ConstIterator itr;
160 for (itr = configs.begin(); itr != configs.end(); ++itr) {
161 const std::string& name = configs.name(itr);
162 const ParameterEntry& entry = configs.entry(itr);
166 std::string configString = oss.str();
167 if (configString ==
"") {
171 AMGX_SAFE_CALL(AMGX_config_create(&
Config_, configString.c_str()));
183 RCP<const Teuchos::MpiComm<int> > tmpic = Teuchos::rcp_dynamic_cast<const Teuchos::MpiComm<int> >(comm->duplicate());
186 RCP<const Teuchos::OpaqueWrapper<MPI_Comm> > rawMpiComm = tmpic->getRawMpiComm();
187 MPI_Comm mpiComm = *rawMpiComm;
196 cudaGetDeviceCount(&numGPUDevices);
197 int device[] = {(comm->getRank() % numGPUDevices)};
199 AMGX_config_add_parameters(&
Config_,
"communicator=MPI");
207 AMGX_Mode mode = AMGX_mode_dDDI;
214 amgxTimer->incrementNumCalls();
216 std::vector<int> amgx2muelu;
220 RCP<const Tpetra::Import<LO, GO, NO> > importer = inA->getCrsGraph()->getImporter();
224 Tpetra::Distributor distributor = importer->getDistributor();
226 Array<int> sendRanks = distributor.getProcsTo();
227 Array<int> recvRanks = distributor.getProcsFrom();
229 std::sort(sendRanks.begin(), sendRanks.end());
230 std::sort(recvRanks.begin(), recvRanks.end());
233 if (sendRanks.size() != recvRanks.size()) {
236 for (
int i = 0; i < sendRanks.size(); i++) {
237 if (recvRanks[i] != sendRanks[i])
243 "AMGX requires that the processors that we send to and receive from are the same. "
244 "This is not the case: we send to {"
245 << sendRanks <<
"} and receive from {" << recvRanks <<
"}");
247 int num_neighbors = sendRanks.size();
248 const int* neighbors = &sendRanks[0];
254 Tpetra::Details::HashTable<int, int> hashTable(3 * num_neighbors);
255 for (
int i = 0; i < num_neighbors; i++)
256 hashTable.add(neighbors[i], i);
259 ArrayView<const int> exportLIDs = importer->getExportLIDs();
260 ArrayView<const int> exportPIDs = importer->getExportPIDs();
261 Array<int> importPIDs;
262 Tpetra::Import_Util::getPids(*importer, importPIDs,
true );
265 RCP<const Map> rowMap = inA->getRowMap();
266 RCP<const Map> colMap = inA->getColMap();
268 int N = rowMap->getLocalNumElements(), Nc = colMap->getLocalNumElements();
271 int numUniqExports = 0;
272 for (
int i = 0; i < exportLIDs.size(); i++)
278 int localOffset = 0, exportOffset = N - numUniqExports;
280 for (
int i = 0; i < exportLIDs.size(); i++)
284 for (
int i = 0; i < N; i++)
288 int importOffset = N;
289 for (
int k = 0; k < num_neighbors; k++)
290 for (
int i = 0; i < importPIDs.size(); i++)
291 if (importPIDs[i] != -1 && hashTable.get(importPIDs[i]) == k)
299 std::vector<std::vector<int> > sendDatas(num_neighbors);
300 std::vector<int> send_sizes(num_neighbors, 0);
301 for (
int i = 0; i < exportPIDs.size(); i++) {
302 int index = hashTable.get(exportPIDs[i]);
303 sendDatas[index].push_back(
muelu2amgx_[exportLIDs[i]]);
308 std::vector<const int*> send_maps(num_neighbors);
309 for (
int i = 0; i < num_neighbors; i++)
310 send_maps[i] = &(sendDatas[i][0]);
316 std::vector<std::vector<int> > recvDatas(num_neighbors);
317 std::vector<int> recv_sizes(num_neighbors, 0);
318 for (
int i = 0; i < importPIDs.size(); i++)
319 if (importPIDs[i] != -1) {
320 int index = hashTable.get(importPIDs[i]);
326 std::vector<const int*> recv_maps(num_neighbors);
327 for (
int i = 0; i < num_neighbors; i++)
328 recv_maps[i] = &(recvDatas[i][0]);
333 AMGX_SAFE_CALL(AMGX_matrix_comm_from_maps_one_ring(
A_, 1, num_neighbors, neighbors, &send_sizes[0], &send_maps[0], &recv_sizes[0], &recv_maps[0]));
335 AMGX_vector_bind(
X_,
A_);
336 AMGX_vector_bind(
Y_,
A_);
339 RCP<Teuchos::Time> matrixTransformTimer = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: transform matrix");
340 matrixTransformTimer->start();
342 ArrayRCP<const size_t> ia_s;
343 ArrayRCP<const int> ja;
344 ArrayRCP<const double> a;
345 inA->getAllValues(ia_s, ja, a);
347 ArrayRCP<int> ia(ia_s.size());
348 for (
int i = 0; i < ia.size(); i++)
349 ia[i] = Teuchos::as<int>(ia_s[i]);
351 N_ = inA->getLocalNumRows();
352 int nnz = inA->getLocalNumEntries();
354 matrixTransformTimer->stop();
355 matrixTransformTimer->incrementNumCalls();
359 RCP<Teuchos::Time> matrixTimer = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: transfer matrix CPU->GPU");
360 matrixTimer->start();
362 AMGX_matrix_upload_all(
A_,
N_, nnz, 1, 1, &ia[0], &ja[0], &a[0], NULL);
366 std::vector<int> ia_new(ia.size());
367 std::vector<int> ja_new(ja.size());
368 std::vector<double> a_new(a.size());
371 for (
int i = 0; i <
N_; i++) {
372 int oldRow = amgx2muelu[i];
374 ia_new[i + 1] = ia_new[i] + (ia[oldRow + 1] - ia[oldRow]);
376 for (
int j = ia[oldRow]; j < ia[oldRow + 1]; j++) {
377 int offset = j - ia[oldRow];
379 a_new[ia_new[i] + offset] = a[j];
387 for (
int j = ia_new[i]; j < ia_new[i + 1] - 1; j++)
388 if (ja_new[j] > ja_new[j + 1]) {
389 std::swap(ja_new[j], ja_new[j + 1]);
390 std::swap(a_new[j], a_new[j + 1]);
393 }
while (swapped ==
true);
396 AMGX_matrix_upload_all(
A_,
N_, nnz, 1, 1, &ia_new[0], &ja_new[0], &a_new[0], NULL);
399 matrixTimer->incrementNumCalls();
404 RCP<Teuchos::Time> realSetupTimer = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: setup (total)");
405 realSetupTimer->start();
407 realSetupTimer->stop();
408 realSetupTimer->incrementNumCalls();
410 vectorTimer1_ = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: transfer vectors CPU->GPU");
411 vectorTimer2_ = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: transfer vector GPU->CPU");
412 solverTimer_ = Teuchos::TimeMonitor::getNewTimer(
"MueLu: AMGX: Solve (total)");