11#ifndef PBAT_GPU_IMPL_MATH_BLAS_H
12#define PBAT_GPU_IMPL_MATH_BLAS_H
15#include "pbat/gpu/impl/common/Cuda.cuh"
19#include <cuda/api/stream.hpp>
24#define CUBLAS_CHECK(err) \
26 cublasStatus_t err_ = (err); \
27 if (err_ != cublasStatus_t::CUBLAS_STATUS_SUCCESS) \
29 std::printf("cublas error %d at %s:%d\n", err_, __FILE__, __LINE__); \
30 throw std::runtime_error("cublas error"); \
40 cuda::device_t device =
41 common::Device(common::EDeviceSelectionPreference::HighestComputeCapability));
42 Blas(Blas
const&) =
delete;
43 Blas(Blas&&) =
delete;
44 Blas& operator=(Blas
const&) =
delete;
45 Blas& operator=(Blas&&) =
delete;
47 cublasHandle_t Handle()
const {
return mHandle; }
49 template <CVector TVectorX, CVector TVectorY,
class TScalar = TVectorX::ValueType>
51 Copy(TVectorX
const& x, TVectorY& y, std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
53 template <CVector TVectorX, CVector TVectorY,
class TScalar = TVectorX::ValueType>
57 TScalar alpha = TScalar(1),
58 std::shared_ptr<cuda::stream_t> stream =
nullptr);
64 class TScalar = TMatrixA::ValueType>
69 TScalar alpha = TScalar(1),
70 TScalar beta = TScalar(0),
71 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
73 template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar = TMatrixA::ValueType>
74 void UpperTriangularSolve(
77 TScalar alpha = TScalar(1),
78 bool bHasUnitDiagonal =
false,
79 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
81 template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar = TMatrixA::ValueType>
82 void LowerTriangularSolve(
85 TScalar alpha = TScalar(1),
86 bool bHasUnitDiagonal =
false,
87 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
89 template <CMatrix TMatrixA, CVector TVectorB,
class TScalar = TMatrixA::ValueType>
90 void UpperTriangularSolve(
93 TScalar alpha = TScalar(1),
94 bool bHasUnitDiagonal =
false,
95 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
97 template <CMatrix TMatrixA, CVector TVectorB,
class TScalar = TMatrixA::ValueType>
98 void LowerTriangularSolve(
101 TScalar alpha = TScalar(1),
102 bool bHasUnitDiagonal =
false,
103 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
105 template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar = TMatrixA::ValueType>
107 cublasSideMode_t side,
108 cublasFillMode_t uplo,
109 cublasDiagType_t diag,
112 TScalar alpha = TScalar(1),
113 std::shared_ptr<cuda::stream_t> stream =
nullptr)
const;
118 void TrySetStream(std::shared_ptr<cuda::stream_t> stream)
const;
121 cublasHandle_t mHandle;
122 cuda::device_t mDevice;
125template <CVector TVectorX, CVector TVectorY,
class TScalar>
126inline void Blas::Copy(TVectorX
const& x, TVectorY& y, std::shared_ptr<cuda::stream_t> stream)
const
128 TrySetStream(stream);
129 if constexpr (std::is_same_v<TScalar, float>)
132 cublasScopy(mHandle, x.Rows(), x.Raw(), x.Increment(), y.Raw(), y.Increment()));
134 if constexpr (std::is_same_v<TScalar, double>)
137 cublasDcopy(mHandle, x.Rows(), x.Raw(), x.Increment(), y.Raw(), y.Increment()));
141template <CVector TVectorX, CVector TVectorY,
class TScalar>
143Blas::Axpy(TVectorX
const& x, TVectorY& y, TScalar alpha, std::shared_ptr<cuda::stream_t> stream)
145 TrySetStream(stream);
146 if constexpr (std::is_same_v<TScalar, float>)
149 cublasSaxpy(mHandle, x.Rows(), &alpha, x.Raw(), x.Increment(), y.Raw(), y.Increment()));
151 if constexpr (std::is_same_v<TScalar, double>)
154 cublasDaxpy(mHandle, x.Rows(), &alpha, x.Raw(), x.Increment(), y.Raw(), y.Increment()));
158template <CMatrix TMatrixA, CVector TVectorX, CVector TVectorY,
class TScalar>
159inline void Blas::Gemv(
165 std::shared_ptr<cuda::stream_t> stream)
const
167 TrySetStream(stream);
168 if constexpr (std::is_same_v<TScalar, float>)
170 CUBLAS_CHECK(cublasSgemv(
177 A.LeadingDimensions(),
184 if constexpr (std::is_same_v<TScalar, double>)
186 CUBLAS_CHECK(cublasDgemv(
193 A.LeadingDimensions(),
202template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar>
203inline void Blas::UpperTriangularSolve(
207 bool bHasUnitDiagonal,
208 std::shared_ptr<cuda::stream_t> stream)
const
212 CUBLAS_FILL_MODE_UPPER,
213 bHasUnitDiagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT,
220template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar>
221inline void Blas::LowerTriangularSolve(
225 bool bHasUnitDiagonal,
226 std::shared_ptr<cuda::stream_t> stream)
const
230 CUBLAS_FILL_MODE_LOWER,
231 bHasUnitDiagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT,
238template <CMatrix TMatrixA, CVector TVectorB,
class TScalar>
239inline void Blas::UpperTriangularSolve(
243 bool bHasUnitDiagonal,
244 std::shared_ptr<cuda::stream_t> stream)
const
246 MatrixView<TScalar> BB(B);
247 UpperTriangularSolve(A, BB, alpha, bHasUnitDiagonal, stream);
250template <CMatrix TMatrixA, CVector TVectorB,
class TScalar>
251inline void Blas::LowerTriangularSolve(
255 bool bHasUnitDiagonal,
256 std::shared_ptr<cuda::stream_t> stream)
const
258 MatrixView<TScalar> BB(B);
259 LowerTriangularSolve(A, BB, alpha, bHasUnitDiagonal, stream);
262template <CMatrix TMatrixA, CMatrix TMatrixB,
class TScalar>
263inline void Blas::Trsm(
264 cublasSideMode_t side,
265 cublasFillMode_t uplo,
266 cublasDiagType_t diag,
270 std::shared_ptr<cuda::stream_t> stream)
const
272 TrySetStream(stream);
273 if constexpr (std::is_same_v<TScalar, double>)
285 A.LeadingDimensions(),
287 B.LeadingDimensions());
289 if constexpr (std::is_same_v<TScalar, float>)
301 A.LeadingDimensions(),
303 B.LeadingDimensions());
Matrix and vector cuBLAS abstractions.
GPU implementations of math functions.
Definition Blas.cu:7