11#ifndef PBAT_GPU_IMPL_MATH_LINEARSOLVER_CUH
12#define PBAT_GPU_IMPL_MATH_LINEARSOLVER_CUH
16#include "pbat/gpu/impl/common/Buffer.cuh"
17#include "pbat/gpu/impl/common/Cuda.cuh"
21#include <cuda/api/device.hpp>
22#include <cuda/api/stream.hpp>
23#include <cusolverDn.h>
// CUSOLVER_CHECK(err): evaluates a cuSOLVER call exactly once (stored in err_).
// If the status is not CUSOLVER_STATUS_SUCCESS, it prints the numeric status with
// __FILE__/__LINE__ and then throws std::runtime_error("cusolver error").
// NOTE(review): several lines of this macro, including its enclosing braces, are
// not visible in this listing. This description covers only the visible lines.
// NOTE(review): diagnostics go to stdout via std::printf; stderr is the more
// conventional destination for error reports.
27#define CUSOLVER_CHECK(err) \
29 cusolverStatus_t err_ = (err); \
30 if (err_ != cusolverStatus_t::CUSOLVER_STATUS_SUCCESS) \
32 std::printf("cusolver error %d at %s:%d\n", err_, __FILE__, __LINE__); \
33 throw std::runtime_error("cusolver error"); \
43 cuda::device_t device =
44 common::Device(common::EDeviceSelectionPreference::HighestComputeCapability));
// LinearSolver is neither copyable nor movable: copy construction, move
// construction, copy assignment and move assignment are all explicitly deleted.
// NOTE(review): presumably this is because the object owns mCusolverHandle.
// Confirm against the destructor (not shown) before relaxing it.
46 LinearSolver(LinearSolver
const&) =
delete;
47 LinearSolver(LinearSolver&&) =
delete;
48 LinearSolver& operator=(LinearSolver
const&) =
delete;
49 LinearSolver& operator=(LinearSolver&&) =
delete;
// Returns the raw cuSOLVER dense handle stored in mCusolverHandle.
51 cusolverDnHandle_t Handle()
const {
return mCusolverHandle; }
// Workspace-size query for a QR factorization (geqrf) of A.
// The definition calls cusolverDnSgeqrf_bufferSize for TScalar = float and
// cusolverDnDgeqrf_bufferSize for TScalar = double.
// TScalar defaults to TMatrixA::ValueType.
// NOTE(review): presumably this returns the lwork value reported by cuSOLVER.
// The return statement is not visible in this listing.
53 template <CMatrix TMatrixA,
class TScalar = TMatrixA::ValueType>
54 int GeqrfWorkspace(TMatrixA
const& A)
const;
// QR factorization via cusolverDn{S,D}geqrf (see the definition).
// The optional stream is handed to TrySetStream before the cuSOLVER call.
// NOTE(review): part of this parameter list is not visible in this listing.
// The definition shows QR, a common::Buffer<TScalar>& workspace, and this stream.
56 template <CMatrix TMatrixQR, CVector TVectorTau,
class TScalar = TMatrixQR::ValueType>
61 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
// Workspace-size query for an ormqr product of the Q factor held in Q with the
// matrix B. The definition calls cusolverDn{S,D}ormqr_bufferSize.
// bMultiplyFromLeft selects CUBLAS_SIDE_LEFT (the default) or CUBLAS_SIDE_RIGHT.
63 template <CMatrix TMatrixQ, CMatrix TMatrixB,
class TScalar = TMatrixQ::ValueType>
64 int OrmqrWorkspace(TMatrixQ
const& Q, TMatrixB
const& B,
bool bMultiplyFromLeft =
true)
const;
// Vector overload of the ormqr workspace query. The definition wraps B in a
// MatrixView<TScalar> and defers to the matrix overload.
// bMultiplyFromLeft defaults to true (left-side multiplication).
66 template <CMatrix TMatrixQ, CVector TVectorB,
class TScalar = TMatrixQ::ValueType>
67 int OrmqrWorkspace(TMatrixQ
const& Q, TVectorB
const& B,
bool bMultiplyFromLeft =
true)
const;
// Ormqr declaration. The definitions call cusolverDn{S,D}ormqr with the given
// tau and a caller-provided workspace.
// bMultiplyFromLeft (default true) selects the side. The optional stream is
// handed to TrySetStream.
// NOTE(review): the template header and several parameter lines are not visible
// in this listing. Presumably this is the matrix-B overload, matching the order
// of the definitions below; confirm.
73 class TScalar = TMatrixQ::ValueType>
76 TVectorTau
const& tau,
79 bool bMultiplyFromLeft =
true,
80 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
// Second Ormqr declaration, with the same defaults (left side, no stream).
// NOTE(review): the template header and several parameter lines are not visible
// in this listing. Presumably this is the vector-B overload, whose definition
// forwards to the matrix overload via MatrixView; confirm.
86 class TScalar = TMatrixQ::ValueType>
89 TVectorTau
const& tau,
92 bool bMultiplyFromLeft =
true,
93 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
// Called by Geqrf and the matrix Ormqr with the caller's (possibly null) stream
// before issuing cuSOLVER calls.
// NOTE(review): the definition is not visible here. Presumably it binds a
// non-null stream to mCusolverHandle and ignores nullptr; confirm.
98 void TrySetStream(std::shared_ptr<cuda::stream_t> stream)
const;
// mCusolverHandle: the cuSOLVER dense handle used by every call and returned by
// Handle().
// mDevice: presumably the device passed to the constructor (which defaults to
// the highest-compute-capability device). Confirm.
101 cusolverDnHandle_t mCusolverHandle;
102 cuda::device_t mDevice;
// Dispatches on TScalar: float uses cusolverDnSgeqrf_bufferSize and double uses
// cusolverDnDgeqrf_bufferSize. A.LeadingDimensions() is passed as the leading
// dimension, and each call is wrapped in CUSOLVER_CHECK.
// NOTE(review): the two branches are independent `if constexpr`s, and no
// rejection of other TScalar types is visible in this listing. Consider a
// static_assert so unsupported scalar types fail at compile time.
105template <CMatrix TMatrixA,
class TScalar>
106inline int LinearSolver::GeqrfWorkspace(TMatrixA
const& A)
const
109 if constexpr (std::is_same_v<TScalar, float>)
111 CUSOLVER_CHECK(cusolverDnSgeqrf_bufferSize(
116 A.LeadingDimensions(),
119 if constexpr (std::is_same_v<TScalar, double>)
121 CUSOLVER_CHECK(cusolverDnDgeqrf_bufferSize(
126 A.LeadingDimensions(),
// Forwards the optional stream to TrySetStream, then runs cusolverDnSgeqrf
// (float) or cusolverDnDgeqrf (double).
// QR.LeadingDimensions() is the leading dimension and workspace.Size() is the
// workspace length.
// NOTE(review): static_cast<int>(workspace.Size()) truncates if Size() returns
// a wider type holding more than INT_MAX elements. Consider asserting that the
// size fits in int.
132template <CMatrix TMatrixQR, CVector TVectorTau,
class TScalar>
133inline void LinearSolver::Geqrf(
136 common::Buffer<TScalar>& workspace,
137 std::shared_ptr<cuda::stream_t> stream)
const
139 TrySetStream(stream);
140 if constexpr (std::is_same_v<TScalar, float>)
142 CUSOLVER_CHECK(cusolverDnSgeqrf(
147 QR.LeadingDimensions(),
150 static_cast<int>(workspace.Size()),
153 if constexpr (std::is_same_v<TScalar, double>)
155 CUSOLVER_CHECK(cusolverDnDgeqrf(
160 QR.LeadingDimensions(),
163 static_cast<int>(workspace.Size()),
// Maps bMultiplyFromLeft to CUBLAS_SIDE_LEFT or CUBLAS_SIDE_RIGHT, then queries
// cusolverDnSormqr_bufferSize (float) or cusolverDnDormqr_bufferSize (double)
// with the leading dimensions of Q and B.
// NOTE(review): unlike the sibling out-of-class definitions, this one is not
// marked `inline`. That is harmless for a function template but inconsistent.
168template <CMatrix TMatrixQ, CMatrix TMatrixB,
class TScalar>
169int LinearSolver::OrmqrWorkspace(TMatrixQ
const& Q, TMatrixB
const& B,
bool bMultiplyFromLeft)
const
172 auto side = bMultiplyFromLeft ? cublasSideMode_t::CUBLAS_SIDE_LEFT :
173 cublasSideMode_t::CUBLAS_SIDE_RIGHT;
174 if constexpr (std::is_same_v<TScalar, float>)
176 CUSOLVER_CHECK(cusolverDnSormqr_bufferSize(
184 Q.LeadingDimensions(),
187 B.LeadingDimensions(),
190 if constexpr (std::is_same_v<TScalar, double>)
192 CUSOLVER_CHECK(cusolverDnDormqr_bufferSize(
200 Q.LeadingDimensions(),
203 B.LeadingDimensions(),
// Vector overload: wraps B in a MatrixView<TScalar> and returns the result of
// the matrix overload of OrmqrWorkspace.
// NOTE(review): const_cast removes const from B, presumably because the
// MatrixView constructor (not shown) takes a non-const reference. The size query
// should never write through BB; confirm that MatrixView does not mutate B.
209template <CMatrix TMatrixQ, CVector TVectorB,
class TScalar>
211LinearSolver::OrmqrWorkspace(TMatrixQ
const& Q, TVectorB
const& B,
bool bMultiplyFromLeft)
const
213 MatrixView<TScalar> BB(
const_cast<TVectorB&
>(B));
214 return OrmqrWorkspace(Q, BB, bMultiplyFromLeft);
// Forwards the optional stream to TrySetStream and maps bMultiplyFromLeft to
// CUBLAS_SIDE_LEFT or CUBLAS_SIDE_RIGHT.
// It then calls cusolverDnSormqr (float) or cusolverDnDormqr (double) with the
// leading dimensions of Q and B. workspace.Size(), cast to int, is the
// workspace length.
// NOTE(review): as in Geqrf, the int cast truncates if the workspace exceeds
// INT_MAX elements.
217template <CMatrix TMatrixQ, CVector TVectorTau, CMatrix TMatrixB,
class TScalar>
218inline void LinearSolver::Ormqr(
220 TVectorTau
const& tau,
222 common::Buffer<TScalar>& workspace,
223 bool bMultiplyFromLeft,
224 std::shared_ptr<cuda::stream_t> stream)
const
226 TrySetStream(stream);
227 auto side = bMultiplyFromLeft ? cublasSideMode_t::CUBLAS_SIDE_LEFT :
228 cublasSideMode_t::CUBLAS_SIDE_RIGHT;
229 if constexpr (std::is_same_v<TScalar, float>)
231 CUSOLVER_CHECK(cusolverDnSormqr(
239 Q.LeadingDimensions(),
242 B.LeadingDimensions(),
244 static_cast<int>(workspace.Size()),
247 if constexpr (std::is_same_v<TScalar, double>)
249 CUSOLVER_CHECK(cusolverDnDormqr(
257 Q.LeadingDimensions(),
260 B.LeadingDimensions(),
262 static_cast<int>(workspace.Size()),
// Vector overload: wraps B in a MatrixView<TScalar> and forwards every argument
// (Q, tau, the view, workspace, bMultiplyFromLeft, stream) to the matrix
// overload of Ormqr.
// NOTE(review): unlike the vector OrmqrWorkspace, no const_cast is used here.
// Presumably B is taken by non-const reference (that parameter line is not
// visible in this listing); confirm.
267template <CMatrix TMatrixQ, CVector TVectorTau, CVector TVectorB,
class TScalar>
268inline void LinearSolver::Ormqr(
270 TVectorTau
const& tau,
272 common::Buffer<TScalar>& workspace,
273 bool bMultiplyFromLeft,
274 std::shared_ptr<cuda::stream_t> stream)
const
276 MatrixView<TScalar> BB(B);
277 Ormqr(Q, tau, BB, workspace, bMultiplyFromLeft, stream);
BLAS API wrapper over cuBLAS.
Matrix and vector cuBLAS abstractions.
GPU implementations of math functions.
Definition Blas.cu:7