80 if (
Get<bool>(currentLevel,
"Filtering") ==
false) {
81 GetOStream(
Runtime0) <<
"Filtered matrix is not being constructed as no filtering is being done" << std::endl;
82 Set(currentLevel,
"A", A);
87 bool lumping = pL.get<
bool>(
"filtered matrix: use lumping");
91 bool use_spread_lumping = pL.get<
bool>(
"filtered matrix: use spread lumping");
92 if (use_spread_lumping && (!lumping))
93 throw std::runtime_error(
"Must also request 'filtered matrix: use lumping' in order to use spread lumping");
95 if (use_spread_lumping) {
99 double DdomAllowGrowthRate = 1.1;
100 double DdomCap = 2.0;
101 if (use_spread_lumping) {
102 DdomAllowGrowthRate = pL.get<
double>(
"filtered matrix: spread lumping diag dom growth factor");
103 DdomCap = pL.get<
double>(
"filtered matrix: spread lumping diag dom cap");
105 bool use_root_stencil = lumping && pL.get<
bool>(
"filtered matrix: use root stencil");
106 if (use_root_stencil)
108 double dirichlet_threshold = pL.get<
double>(
"filtered matrix: Dirichlet threshold");
109 if (dirichlet_threshold >= 0.0)
110 GetOStream(
Runtime0) <<
"Filtering Dirichlet threshold of " << dirichlet_threshold << std::endl;
112 if (use_root_stencil || pL.get<
bool>(
"filtered matrix: reuse graph"))
119 FILE* f = fopen(
"graph.dat",
"w");
120 size_t numGRows = G->GetNodeNumVertices();
121 for (
size_t i = 0; i < numGRows; i++) {
123 auto indsG = G->getNeighborVertices(i);
124 for (
size_t j = 0; j < (size_t)indsG.length; j++) {
125 fprintf(f,
"%d %d 1.0\n", (
int)i, (
int)indsG(j));
131 RCP<ParameterList> fillCompleteParams(
new ParameterList);
132 fillCompleteParams->set(
"No Nonlocal Changes",
true);
134 RCP<Matrix> filteredA;
135 if (use_root_stencil) {
136 filteredA = MatrixFactory::Build(A->getCrsGraph());
137 filteredA->fillComplete(fillCompleteParams);
138 filteredA->resumeFill();
139 BuildNewUsingRootStencil(*A, *G, dirichlet_threshold, currentLevel, *filteredA, use_spread_lumping, DdomAllowGrowthRate, DdomCap);
140 filteredA->fillComplete(fillCompleteParams);
142 }
else if (pL.get<
bool>(
"filtered matrix: reuse graph")) {
143 filteredA = MatrixFactory::Build(A->getCrsGraph());
144 filteredA->resumeFill();
145 BuildReuse(*A, *G, (lumping != use_spread_lumping), dirichlet_threshold, *filteredA);
150 filteredA->fillComplete(fillCompleteParams);
153 filteredA = MatrixFactory::Build(A->getRowMap(), A->getColMap(), A->getLocalMaxNumRowEntries());
154 BuildNew(*A, *G, (lumping != use_spread_lumping), dirichlet_threshold, *filteredA);
158 filteredA->fillComplete(A->getDomainMap(), A->getRangeMap(), fillCompleteParams);
162 Xpetra::IO<SC, LO, GO, NO>::Write(
"filteredA.dat", *filteredA);
165 Xpetra::IO<SC, LO, GO, NO>::Write(
"A.dat", *A);
166 RCP<Matrix> origFilteredA = MatrixFactory::Build(A->getRowMap(), A->getColMap(), A->getLocalMaxNumRowEntries());
167 BuildNew(*A, *G, lumping, dirichlet_threshold, *origFilteredA);
168 if (use_spread_lumping)
ExperimentalLumping(*A, *origFilteredA, DdomAllowGrowthRate, DdomCap);
169 origFilteredA->fillComplete(A->getDomainMap(), A->getRangeMap(), fillCompleteParams);
170 Xpetra::IO<SC, LO, GO, NO>::Write(
"origFilteredA.dat", *origFilteredA);
173 filteredA->SetFixedBlockSize(A->GetFixedBlockSize());
175 if (pL.get<
bool>(
"filtered matrix: reuse eigenvalue")) {
180 filteredA->SetMaxEigenvalueEstimate(A->GetMaxEigenvalueEstimate());
183 Set(currentLevel,
"A", filteredA);
// BuildReuse: computes filtered values for every local row of A and stores
// them in filteredA (via replaceLocalValues, or in place when
// ASSUME_DIRECT_ACCESS_TO_ROW is defined).
//
//   A               - source matrix; rows are read with getLocalRowView.
//   G               - graph whose neighbor lists (indsG) decide which column
//                     entries are flagged in `filter`.
//   lumping         - when false, the branch at original line 255 is taken;
//                     otherwise entries are accumulated into diagExtra and
//                     added to the diagonal entry.
//   dirichletThresh - when >= 0 and the real part of the updated diagonal is
//                     <= this value, the diagonal is overwritten with one.
//   filteredA       - destination matrix.
//
// NOTE(review): this excerpt omits several original lines (e.g. the
// declarations of vals, numGRows, indsG, diagIndex and diagExtra, and some
// branch bodies), so the comments below describe only what is visible.
203 BuildReuse(
const Matrix& A,
const LWGraph& G,
const bool lumping,
double dirichletThresh, Matrix& filteredA)
const {
204 using TST =
typename Teuchos::ScalarTraits<SC>;
205 SC zero = TST::zero();
// Number of matrix rows per graph vertex.
207 size_t blkSize = A.GetFixedBlockSize();
209 ArrayView<const LO> inds;
210 ArrayView<const SC> valsA;
211#ifdef ASSUME_DIRECT_ACCESS_TO_ROW
// Per-column flag array, sized to the larger of (blkSize * number of local
// elements in G's import map) and the number of local elements in A's column
// map. The loop that follows walks the graph rows.
217 Array<char> filter(std::max(blkSize * G.
GetImportMap()->getLocalNumElements(),
218 A.getColMap()->getLocalNumElements()),
222 for (
size_t i = 0; i < numGRows; i++) {
// Flag every DOF column belonging to each graph neighbor in indsG.
225 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
226 for (
size_t k = 0; k < blkSize; k++)
227 filter[indsG(j) * blkSize + k] = 1;
// Process each of the blkSize matrix rows that belong to graph row i.
229 for (
size_t k = 0; k < blkSize; k++) {
230 LO row = i * blkSize + k;
232 A.getLocalRowView(row, inds, valsA);
234 size_t nnz = inds.size();
238#ifdef ASSUME_DIRECT_ACCESS_TO_ROW
// Direct-access variant: view filteredA's row storage, cast away constness,
// and copy A's row values into it so later updates happen in place.
240 ArrayView<const SC> vals1;
241 filteredA.getLocalRowView(row, inds, vals1);
242 vals = ArrayView<SC>(
const_cast<SC*
>(vals1.getRawPtr()), nnz);
244 memcpy(vals.getRawPtr(), valsA.getRawPtr(), nnz *
sizeof(SC));
// Alternative: work on a copy of A's row values (the preprocessor directive
// separating this from the lines above is not shown in this excerpt).
246 vals = Array<SC>(valsA);
249 SC ZERO = Teuchos::ScalarTraits<SC>::zero();
// Sum of the original row's values.
251 SC A_rowsum = ZERO, F_rowsum = ZERO;
252 for (LO l = 0; l < (LO)inds.size(); l++)
253 A_rowsum += valsA[l];
// No lumping: visit entries whose column is not flagged in filter (the action
// taken on them is on a line omitted from this excerpt).
255 if (lumping ==
false) {
256 for (
size_t j = 0; j < nnz; j++)
257 if (!filter[inds[j]])
// Second pass over the row: flagged columns are tested for being the diagonal
// (inds[j] == row), and entry values are accumulated into diagExtra.
264 for (
size_t j = 0; j < nnz; j++) {
265 if (filter[inds[j]]) {
266 if (inds[j] == row) {
273 diagExtra += vals[j];
// If a diagonal position was recorded (diagIndex != -1), add the accumulated
// amount to it.
283 if (diagIndex != -1) {
285 vals[diagIndex] += diagExtra;
// With a non-negative threshold, a diagonal whose real part is at or below
// the threshold is replaced by one.
286 if (dirichletThresh >= 0.0 && TST::real(vals[diagIndex]) <= dirichletThresh) {
288 for (LO l = 0; l < (LO)nnz; l++)
291 vals[diagIndex] = TST::one();
296#ifndef ASSUME_DIRECT_ACCESS_TO_ROW
// Without direct row access, push the modified values back into filteredA.
299 filteredA.replaceLocalValues(row, inds, vals);
// Clear the flags set for this graph row so filter can be reused.
304 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
305 for (
size_t k = 0; k < blkSize; k++)
306 filter[indsG(j) * blkSize + k] = 0;
// BuildNewUsingRootStencil: rewrites the values of filteredA aggregate by
// aggregate. Entries that get dropped are set to ZERO, flagged in
// vals_dropped_indicator, and their original values are accumulated per row
// in diagExtra. When use_spread_lumping is false, diagExtra is added to each
// row's diagonal entry. Finally every row is written back with
// replaceLocalValues and globally summed counters are printed.
//
//   A                   - source matrix (rows read with getLocalRowView).
//   G                   - graph; its import map sizes the `filter` array.
//   dirichletThresh     - threshold used in the diagonal check near the end.
//   currentLevel        - not referenced on any line visible in this excerpt.
//   filteredA           - destination matrix; must be a CrsMatrixWrap (see
//                         note at the dynamic_cast).
//   use_spread_lumping  - when true, the diagonal-lumping loop is skipped.
//   DdomAllowGrowthRate,
//   DdomCap             - not referenced on any line visible in this excerpt.
//
// NOTE(review): this excerpt omits many original lines (declarations of
// aggregates, amalgInfo, nodesInAgg, vals, indsG, goodNodeNeighbors,
// nodeNeighbors, maxAggSize, several else-branches and closing braces), so
// the comments below describe only what is visible.
407 BuildNewUsingRootStencil(
const Matrix& A,
const LWGraph& G,
double dirichletThresh,
Level& currentLevel, Matrix& filteredA,
bool use_spread_lumping,
double DdomAllowGrowthRate,
double DdomCap)
const {
408 using TST =
typename Teuchos::ScalarTraits<SC>;
409 using Teuchos::arcp_const_cast;
410 SC ZERO = Teuchos::ScalarTraits<SC>::zero();
411 SC ONE = Teuchos::ScalarTraits<SC>::one();
// Sentinel for "no index found".
412 LO INVALID = Teuchos::OrdinalTraits<LO>::invalid();
417 LO numAggs = aggregates->GetNumAggregates();
420 size_t blkSize = A.GetFixedBlockSize();
421 size_t numRows = A.getMap()->getLocalNumElements();
422 ArrayView<const LO> indsA;
423 ArrayView<const SC> valsA;
424 ArrayRCP<const size_t> rowptr;
425 ArrayRCP<const LO> inds;
426 ArrayRCP<const SC> vals_const;
// NOTE(review): the dynamic_cast result is dereferenced without a null check
// visible in this excerpt; if filteredA is not a CrsMatrixWrap this would
// dereference a null pointer — confirm a guard exists upstream.
432 RCP<CrsMatrix> filteredAcrs =
dynamic_cast<const CrsMatrixWrap*
>(&filteredA)->getCrsMatrix();
// Fetch the row pointers, column indices and values of filteredA, then cast
// constness off the values array so it can be modified below.
433 filteredAcrs->getAllValues(rowptr, inds, vals_const);
434 vals = arcp_const_cast<SC>(vals_const);
// One flag per stored value: true once that entry has been dropped.
435 Array<bool> vals_dropped_indicator(vals.size(),
false);
438 RCP<const Map> rowMap = A.getRowMap();
439 RCP<const Map> colMap = A.getColMap();
// Per-row position of the diagonal within the row (INVALID until found) and
// per-row accumulator of dropped values.
444 Array<LO> diagIndex(numRows, INVALID);
445 Array<SC> diagExtra(numRows, ZERO);
// Host mirrors for the aggregate views (the enclosing declaration is not
// shown in this excerpt).
452 typename Aggregates::LO_view::HostMirror ptr_h, nodes_h, unaggregated_h;
// Compute the nodes-in-aggregate views, then mirror and copy them so they can
// be indexed with the *_h members below.
454 aggregates->ComputeNodesInAggregate(nodesInAgg.ptr, nodesInAgg.nodes, nodesInAgg.unaggregated);
455 nodesInAgg.ptr_h = Kokkos::create_mirror_view(nodesInAgg.ptr);
456 nodesInAgg.nodes_h = Kokkos::create_mirror_view(nodesInAgg.nodes);
457 nodesInAgg.unaggregated_h = Kokkos::create_mirror_view(nodesInAgg.unaggregated);
458 Kokkos::deep_copy(nodesInAgg.ptr_h, nodesInAgg.ptr);
459 Kokkos::deep_copy(nodesInAgg.nodes_h, nodesInAgg.nodes);
460 Kokkos::deep_copy(nodesInAgg.unaggregated_h, nodesInAgg.unaggregated);
// Node -> aggregate id lookup.
461 Teuchos::ArrayRCP<const LO> vertex2AggId = aggregates->GetVertex2AggId()->getData(0);
// Per-graph-column flag array, all false initially.
463 LO graphNumCols = G.
GetImportMap()->getLocalNumElements();
464 Array<bool> filter(graphNumCols,
false);
// Rows of unaggregated nodes: the diagonal entry is set to ONE; another entry
// assignment sets ZERO (the branch separating the two is omitted here).
// DOF rows at or beyond numRows are skipped.
467 for (LO i = 0; i < (LO)nodesInAgg.unaggregated_h.extent(0); i++) {
468 for (LO m = 0; m < (LO)blkSize; m++) {
469 LO row = amalgInfo->ComputeLocalDOF(nodesInAgg.unaggregated_h(i), m);
470 if (row >= (LO)numRows)
continue;
471 size_t index_start = rowptr[row];
472 A.getLocalRowView(row, indsA, valsA);
473 for (LO k = 0; k < (LO)indsA.size(); k++) {
474 if (row == indsA[k]) {
475 vals[index_start + k] = ONE;
478 vals[index_start + k] = ZERO;
// Per-aggregate counter, reset after each aggregate is processed.
483 std::vector<LO> badCount(numAggs, 0);
// Largest aggregate size (number of nodes).
487 for (LO i = 0; i < numAggs; i++)
488 maxAggSize = std::max(maxAggSize, nodesInAgg.ptr_h(i + 1) - nodesInAgg.ptr_h(i));
// Local statistics; three of them are summed across ranks at the end.
494 size_t numNewDrops = 0;
495 size_t numOldDrops = 0;
496 size_t numFixedDiags = 0;
497 size_t numSymDrops = 0;
// Main loop over aggregates; empty aggregates are skipped.
499 for (LO i = 0; i < numAggs; i++) {
500 LO numNodesInAggregate = nodesInAgg.ptr_h(i + 1) - nodesInAgg.ptr_h(i);
501 if (numNodesInAggregate == 0)
continue;
// Locate the aggregate's root node; an exception is raised if none is found.
504 LO root_node = INVALID;
505 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
506 if (aggregates->IsRoot(nodesInAgg.nodes_h(k))) {
507 root_node = nodesInAgg.nodes_h(k);
512 TEUCHOS_TEST_FOR_EXCEPTION(root_node == INVALID,
519 goodAggNeighbors.resize(0);
// Collect the aggregate ids of the nodes in goodNodeNeighbors.
520 for (LO k = 0; k < (LO)goodNodeNeighbors.length; k++) {
521 goodAggNeighbors.push_back(vertex2AggId[goodNodeNeighbors(k)]);
528 badAggNeighbors.resize(0);
// Scan the root node's DOF rows of A: for local columns with non-zero
// magnitude, record the column node's aggregate if it is not among
// goodAggNeighbors.
// NOTE(review): std::binary_search requires goodAggNeighbors to be sorted;
// the sort is not visible in this excerpt — confirm.
529 for (LO j = 0; j < (LO)blkSize; j++) {
530 LO row = amalgInfo->ComputeLocalDOF(root_node, j);
531 if (row >= (LO)numRows)
continue;
532 A.getLocalRowView(row, indsA, valsA);
533 for (LO k = 0; k < (LO)indsA.size(); k++) {
534 if ((indsA[k] < (LO)numRows) && (TST::magnitude(valsA[k]) != TST::magnitude(ZERO))) {
535 LO node = amalgInfo->ComputeLocalNode(indsA[k]);
536 LO agg = vertex2AggId[node];
537 if (!std::binary_search(goodAggNeighbors.begin(), goodAggNeighbors.end(), agg))
538 badAggNeighbors.push_back(agg);
// For every node of this aggregate, count how often each valid aggregate id
// appears among its nodeNeighbors.
547 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
549 for (LO kk = 0; kk < nodeNeighbors.length; kk++) {
550 if ((vertex2AggId[nodeNeighbors(kk)] >= 0) && (vertex2AggId[nodeNeighbors(kk)] < numAggs))
551 (badCount[vertex2AggId[nodeNeighbors(kk)]])++;
555 reallyBadAggNeighbors.resize(0);
// Keep only those bad aggregate neighbors whose count is at most one.
556 for (LO k = 0; k < (LO)badAggNeighbors.size(); k++) {
557 if (badCount[badAggNeighbors[k]] <= 1) reallyBadAggNeighbors.push_back(badAggNeighbors[k]);
// Reset the counters touched above so badCount is clean for the next aggregate.
559 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
561 for (LO kk = 0; kk < nodeNeighbors.length; kk++) {
562 if ((vertex2AggId[nodeNeighbors(kk)] >= 0) && (vertex2AggId[nodeNeighbors(kk)] < numAggs))
563 badCount[vertex2AggId[nodeNeighbors(kk)]] = 0;
// In every row belonging to a "really bad" neighboring aggregate, drop the
// not-yet-dropped entries whose column node lies in the current aggregate i:
// flag them, zero them, and add A's value to that row's diagExtra.
569 for (LO b = 0; b < (LO)reallyBadAggNeighbors.size(); b++) {
570 LO bad_agg = reallyBadAggNeighbors[b];
571 for (LO k = nodesInAgg.ptr_h(bad_agg); k < nodesInAgg.ptr_h(bad_agg + 1); k++) {
572 LO bad_node = nodesInAgg.nodes_h(k);
573 for (LO j = 0; j < (LO)blkSize; j++) {
574 LO bad_row = amalgInfo->ComputeLocalDOF(bad_node, j);
575 if (bad_row >= (LO)numRows)
continue;
576 size_t index_start = rowptr[bad_row];
577 A.getLocalRowView(bad_row, indsA, valsA);
578 for (LO l = 0; l < (LO)indsA.size(); l++) {
579 if (indsA[l] < (LO)numRows && vertex2AggId[amalgInfo->ComputeLocalNode(indsA[l])] == i && vals_dropped_indicator[index_start + l] ==
false) {
580 vals_dropped_indicator[index_start + l] =
true;
581 vals[index_start + l] = ZERO;
582 diagExtra[bad_row] += valsA[l];
// Now process the rows of the current aggregate's own nodes.
593 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
594 LO row_node = nodesInAgg.nodes_h(k);
// Flag the graph neighbors (indsG) of this node in filter.
598 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
599 filter[indsG(j)] =
true;
601 for (LO m = 0; m < (LO)blkSize; m++) {
602 LO row = amalgInfo->ComputeLocalDOF(row_node, m);
603 if (row >= (LO)numRows)
continue;
604 size_t index_start = rowptr[row];
605 A.getLocalRowView(row, indsA, valsA);
// Classify each entry of the row: the diagonal keeps A's value; flagged
// ("good") local columns are checked against reallyBadAggNeighbors; entries
// that end up dropped are zeroed, flagged, and added to diagExtra[row].
// Several branch lines between these statements are omitted in this excerpt.
607 for (LO l = 0; l < (LO)indsA.size(); l++) {
608 int col_node = amalgInfo->ComputeLocalNode(indsA[l]);
609 bool is_good = filter[col_node];
610 if (indsA[l] == row) {
612 vals[index_start + l] = valsA[l];
617 if (vals_dropped_indicator[index_start + l] ==
true) {
627 if (is_good && indsA[l] < (LO)numRows) {
628 int agg = vertex2AggId[col_node];
629 if (std::binary_search(reallyBadAggNeighbors.begin(), reallyBadAggNeighbors.end(), agg))
634 vals[index_start + l] = valsA[l];
636 if (!filter[col_node])
640 diagExtra[row] += valsA[l];
641 vals[index_start + l] = ZERO;
642 vals_dropped_indicator[index_start + l] =
true;
// Clear the flags set for this node so filter can be reused.
649 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
650 filter[indsG(j)] =
false;
// Plain lumping (no spread lumping): add each row's accumulated dropped
// values to its diagonal entry, for rows where a diagonal index was recorded.
655 if (!use_spread_lumping) {
657 for (LO row = 0; row < (LO)numRows; row++) {
658 if (diagIndex[row] != INVALID) {
659 size_t index_start = rowptr[row];
660 size_t diagIndexInMatrix = index_start + diagIndex[row];
662 vals[diagIndexInMatrix] += diagExtra[row];
663 SC A_rowsum = ZERO, A_absrowsum = ZERO, F_rowsum = ZERO;
// If the lumped diagonal's real part is at/below a non-negative threshold, or
// equals zero, compute row sums of A and of the filtered row, then overwrite
// the diagonal with one.
665 if ((dirichletThresh >= 0.0 && TST::real(vals[diagIndexInMatrix]) <= dirichletThresh) || TST::real(vals[diagIndexInMatrix]) == ZERO) {
667 A.getLocalRowView(row, indsA, valsA);
671 for (LO l = 0; l < (LO)indsA.size(); l++) {
672 A_rowsum += valsA[l];
673 A_absrowsum += std::abs(valsA[l]);
675 for (LO l = 0; l < (LO)indsA.size(); l++)
676 F_rowsum += vals[index_start + l];
690 for (
size_t l = rowptr[row]; l < rowptr[row + 1]; l++) {
693 vals[diagIndexInMatrix] = TST::one();
// Write each row's index/value slice back into filteredA.
703 for (LO row = 0; row < (LO)numRows; row++) {
704 filteredA.replaceLocalValues(row, inds(rowptr[row], rowptr[row + 1] - rowptr[row]), vals(rowptr[row], rowptr[row + 1] - rowptr[row]));
// Sum the local counters over the row map's communicator and report them.
708 size_t g_newDrops = 0, g_oldDrops = 0, g_fixedDiags = 0;
710 MueLu_sumAll(A.getRowMap()->getComm(), numNewDrops, g_newDrops);
711 MueLu_sumAll(A.getRowMap()->getComm(), numOldDrops, g_oldDrops);
712 MueLu_sumAll(A.getRowMap()->getComm(), numFixedDiags, g_fixedDiags);
713 GetOStream(
Runtime0) <<
"Filtering out " << g_newDrops <<
" edges, in addition to the " << g_oldDrops <<
" edges dropped earlier" << std::endl;
714 GetOStream(
Runtime0) <<
"Fixing " << g_fixedDiags <<
" zero diagonal values" << std::endl;
731 using TST =
typename Teuchos::ScalarTraits<SC>;
732 SC zero = TST::zero();
735 ArrayView<const LO> inds;
736 ArrayView<const SC> vals;
737 ArrayView<const LO> finds;
740 SC PosOffSum, NegOffSum, PosOffDropSum, NegOffDropSum;
741 SC diag, gamma, alpha;
742 LO NumPosKept, NumNegKept;
746 SC PosFilteredSum, NegFilteredSum;
749 SC rho = as<Scalar>(irho);
750 SC rho2 = as<Scalar>(irho2);
752 for (LO row = 0; row < (LO)A.getRowMap()->getLocalNumElements(); row++) {
753 noLumpDdom = as<Scalar>(10000.0);
766 ArrayView<const SC> tvals;
767 A.getLocalRowView(row, inds, vals);
768 size_t nnz = inds.size();
769 if (nnz == 0)
continue;
770 filteredA.getLocalRowView(row, finds, tvals);
772 fvals = ArrayView<SC>(
const_cast<SC*
>(tvals.getRawPtr()), nnz);
774 LO diagIndex = -1, fdiagIndex = -1;
778 PosOffDropSum = zero;
779 NegOffDropSum = zero;
785 for (
size_t j = 0; j < nnz; j++) {
786 if (inds[j] == row) {
790 if (TST::real(vals[j]) > TST::real(zero))
791 PosOffSum += vals[j];
793 NegOffSum += vals[j];
796 PosOffDropSum = PosOffSum;
797 NegOffDropSum = NegOffSum;
801 for (
size_t jj = 0; jj < (size_t)finds.size(); jj++) {
802 while (inds[j] != finds[jj]) j++;
804 if (finds[jj] == row)
807 if (TST::real(vals[j]) > TST::real(zero)) {
808 PosOffDropSum -= fvals[jj];
809 if (TST::real(fvals[jj]) != TST::real(zero)) NumPosKept++;
811 NegOffDropSum -= fvals[jj];
812 if (TST::real(fvals[jj]) != TST::real(zero)) NumNegKept++;
818 if (TST::magnitude(diag) != TST::magnitude(zero))
819 noLumpDdom = (PosOffSum - NegOffSum) / diag;
824 Target = rho * noLumpDdom;
825 if (TST::magnitude(Target) <= TST::magnitude(rho)) Target = rho2;
827 PosFilteredSum = PosOffSum - PosOffDropSum;
828 NegFilteredSum = NegOffSum - NegOffDropSum;
838 diag += PosOffDropSum;
841 gamma = -NegOffDropSum - PosFilteredSum;
843 if (TST::real(gamma) < TST::real(zero)) {
851 if (fdiagIndex != -1) fvals[fdiagIndex] = diag;
853 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
854 while (inds[j] != finds[jj]) j++;
856 if ((j != diagIndex) && (TST::real(vals[j]) > TST::real(zero)) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)))
857 fvals[jj] = -gamma * (vals[j] / PosFilteredSum);
869 bool flipPosOffDiagsToNeg =
false;
881 if ((TST::real(diag) > TST::real(gamma)) &&
882 (TST::real((-NegFilteredSum) / (diag - gamma)) <= TST::real(Target))) {
889 if (fdiagIndex != -1) fvals[fdiagIndex] = diag - gamma;
890 }
else if (NumNegKept > 0) {
896 numer = -NegFilteredSum - Target * (diag - gamma);
897 denom = gamma * (Target - TST::one());
919 if (TST::magnitude(denom) < TST::magnitude(numer))
922 alpha = numer / denom;
923 if (TST::real(alpha) < TST::real(zero)) alpha = zero;
924 if (TST::real(diag) < TST::real((one - alpha) * gamma)) alpha = TST::one();
928 if (fdiagIndex != -1) fvals[fdiagIndex] = diag - (one - alpha) * gamma;
938 SC temp = (NegFilteredSum + alpha * gamma) / NegFilteredSum;
940 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
941 while (inds[j] != finds[jj]) j++;
943 if ((jj != fdiagIndex) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)) &&
944 (TST::real(vals[j]) < TST::real(zero)))
945 fvals[jj] = temp * vals[j];
950 if (NumPosKept > 0) {
953 flipPosOffDiagsToNeg =
true;
956 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
957 while (inds[j] != finds[jj]) j++;
959 if ((j != diagIndex) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)) &&
960 (TST::real(vals[j]) > TST::real(zero)))
961 fvals[jj] = -gamma / ((SC)NumPosKept);
966 if (!flipPosOffDiagsToNeg) {
971 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
972 while (inds[j] != finds[jj]) j++;
974 if ((jj != fdiagIndex) && (TST::real(vals[j]) > TST::real(zero))) fvals[jj] = zero;