Amesos2 - Direct Sparse Solver Interfaces Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1// @HEADER
2// *****************************************************************************
3// Amesos2: Templated Direct Sparse Solver Package
4//
5// Copyright 2011 NTESS and the Amesos2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11#define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
12
14#include "Amesos2_cuSOLVER_TypeMap.hpp"
15
16#include <cuda.h>
17#include <cusolverSp.h>
18#include <cusolverDn.h>
19#include <cusparse.h>
20#include <cusolverSp_LOWLEVEL_PREVIEW.h>
21
22#ifdef HAVE_TEUCHOS_COMPLEX
23#include <cuComplex.h>
24#endif
25
26namespace Amesos2 {
27
28 template <>
29 struct FunctionMap<cuSOLVER,double>
30 {
31 static cusolverStatus_t bufferInfo(
32 cusolverSpHandle_t handle,
33 int size,
34 int nnz,
35 cusparseMatDescr_t & desc,
36 const double * values,
37 const int * rowPtr,
38 const int * colIdx,
39 csrcholInfo_t & chol_info,
40 size_t * internalDataInBytes,
41 size_t * workspaceInBytes)
42 {
43 cusolverStatus_t status =
44 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
46 return status;
47 }
48
49 static cusolverStatus_t numeric(
50 cusolverSpHandle_t handle,
51 int size,
52 int nnz,
53 cusparseMatDescr_t & desc,
54 const double * values,
55 const int * rowPtr,
56 const int * colIdx,
57 csrcholInfo_t & chol_info,
58 void * buffer)
59 {
60 cusolverStatus_t status = cusolverSpDcsrcholFactor(
61 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
62 return status;
63 }
64
65 static cusolverStatus_t solve(
66 cusolverSpHandle_t handle,
67 int size,
68 const double * b,
69 double * x,
70 csrcholInfo_t & chol_info,
71 void * buffer)
72 {
73 cusolverStatus_t status = cusolverSpDcsrcholSolve(
74 handle, size, b, x, chol_info, buffer);
75 return status;
76 }
77 };
78
79 template <>
80 struct FunctionMap<cuSOLVER,float>
81 {
82 static cusolverStatus_t bufferInfo(
83 cusolverSpHandle_t handle,
84 int size,
85 int nnz,
86 cusparseMatDescr_t & desc,
87 const float * values,
88 const int * rowPtr,
89 const int * colIdx,
90 csrcholInfo_t & chol_info,
91 size_t * internalDataInBytes,
92 size_t * workspaceInBytes)
93 {
94 cusolverStatus_t status =
95 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
96 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
97 return status;
98 }
99
100 static cusolverStatus_t numeric(
101 cusolverSpHandle_t handle,
102 int size,
103 int nnz,
104 cusparseMatDescr_t & desc,
105 const float * values,
106 const int * rowPtr,
107 const int * colIdx,
108 csrcholInfo_t & chol_info,
109 void * buffer)
110 {
111 cusolverStatus_t status = cusolverSpScsrcholFactor(
112 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
113 return status;
114 }
115
116 static cusolverStatus_t solve(
117 cusolverSpHandle_t handle,
118 int size,
119 const float * b,
120 float * x,
121 csrcholInfo_t & chol_info,
122 void * buffer)
123 {
124 cusolverStatus_t status = cusolverSpScsrcholSolve(
125 handle, size, b, x, chol_info, buffer);
126 return status;
127 }
128 };
129
130#ifdef HAVE_TEUCHOS_COMPLEX
131 template <>
132 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
133 {
134 static cusolverStatus_t bufferInfo(
135 cusolverSpHandle_t handle,
136 int size,
137 int nnz,
138 cusparseMatDescr_t & desc,
139 const void * values,
140 const int * rowPtr,
141 const int * colIdx,
142 csrcholInfo_t & chol_info,
143 size_t * internalDataInBytes,
144 size_t * workspaceInBytes)
145 {
146 typedef cuDoubleComplex scalar_t;
147 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
148 cusolverStatus_t status =
149 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
150 cu_values, rowPtr, colIdx, chol_info,
151 internalDataInBytes, workspaceInBytes);
152 return status;
153 }
154
155 static cusolverStatus_t numeric(
156 cusolverSpHandle_t handle,
157 int size,
158 int nnz,
159 cusparseMatDescr_t & desc,
160 const void * values,
161 const int * rowPtr,
162 const int * colIdx,
163 csrcholInfo_t & chol_info,
164 void * buffer)
165 {
166 typedef cuDoubleComplex scalar_t;
167 const scalar_t * cu_values =
168 reinterpret_cast<const scalar_t *>(values);
169 cusolverStatus_t status = cusolverSpZcsrcholFactor(
170 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
171 return status;
172 }
173
174 static cusolverStatus_t solve(
175 cusolverSpHandle_t handle,
176 int size,
177 const void * b,
178 void * x,
179 csrcholInfo_t & chol_info,
180 void * buffer)
181 {
182 typedef cuDoubleComplex scalar_t;
183 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
184 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
185 cusolverStatus_t status = cusolverSpZcsrcholSolve(
186 handle, size, cu_b, cu_x, chol_info, buffer);
187 return status;
188 }
189 };
190
191 template <>
192 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
193 {
194 static cusolverStatus_t bufferInfo(
195 cusolverSpHandle_t handle,
196 int size,
197 int nnz,
198 cusparseMatDescr_t & desc,
199 const void * values,
200 const int * rowPtr,
201 const int * colIdx,
202 csrcholInfo_t & chol_info,
203 size_t * internalDataInBytes,
204 size_t * workspaceInBytes)
205 {
206 typedef cuFloatComplex scalar_t;
207 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
208 cusolverStatus_t status =
209 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
210 cu_values, rowPtr, colIdx, chol_info,
211 internalDataInBytes, workspaceInBytes);
212 return status;
213 }
214
215 static cusolverStatus_t numeric(
216 cusolverSpHandle_t handle,
217 int size,
218 int nnz,
219 cusparseMatDescr_t & desc,
220 const void * values,
221 const int * rowPtr,
222 const int * colIdx,
223 csrcholInfo_t & chol_info,
224 void * buffer)
225 {
226 typedef cuFloatComplex scalar_t;
227 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
228 cusolverStatus_t status = cusolverSpCcsrcholFactor(
229 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
230 return status;
231 }
232
233 static cusolverStatus_t solve(
234 cusolverSpHandle_t handle,
235 int size,
236 const void * b,
237 void * x,
238 csrcholInfo_t & chol_info,
239 void * buffer)
240 {
241 typedef cuFloatComplex scalar_t;
242 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
243 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
244 cusolverStatus_t status = cusolverSpCcsrcholSolve(
245 handle, size, cu_b, cu_x, chol_info, buffer);
246 return status;
247 }
248 };
249#endif
250
251} // end namespace Amesos2
252
253#endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
Declaration of Function mapping class for Amesos2.
Amesos2 interface to cuSOLVER.
Definition Amesos2_cuSOLVER_decl.hpp:26
const int size
Definition klu2_simple.cpp:50
Dispatches calls to the appropriate third-party library (TPL) functions based on scalar type.
Definition Amesos2_FunctionMap.hpp:43