PhysicsBasedAnimationToolkit 0.0.10
Cross-platform C++20 library of algorithms and data structures commonly used in computer graphics research on physically-based simulation.
Loading...
Searching...
No Matches
Blas.cuh
Go to the documentation of this file.
1
10
11#ifndef PBAT_GPU_IMPL_MATH_BLAS_H
12#define PBAT_GPU_IMPL_MATH_BLAS_H
13
#include "Matrix.cuh"
#include "pbat/gpu/impl/common/Cuda.cuh"

#include <cstdio>
#include <cublas_v2.h>
#include <cuda/api/stream.hpp>
#include <exception>
#include <memory>
#include <stdexcept>
#include <type_traits>
23
/**
 * @brief Evaluates a cuBLAS call exactly once and throws if it did not succeed.
 *
 * On failure, prints the numeric status together with the file and line of the
 * call site, then throws std::runtime_error("cublas error").
 *
 * The body is wrapped in do { ... } while (0) so that `CUBLAS_CHECK(call);`
 * behaves as a single statement, including inside an unbraced if/else.
 * Call sites must terminate the macro invocation with a semicolon.
 *
 * @param err Expression yielding a cublasStatus_t.
 */
#define CUBLAS_CHECK(err)                                                  \
    do                                                                     \
    {                                                                      \
        cublasStatus_t const cublasCheckStatus_ = (err);                   \
        if (cublasCheckStatus_ != CUBLAS_STATUS_SUCCESS)                   \
        {                                                                  \
            std::printf(                                                   \
                "cublas error %d at %s:%d\n",                              \
                static_cast<int>(cublasCheckStatus_),                      \
                __FILE__,                                                  \
                __LINE__);                                                 \
            throw std::runtime_error("cublas error");                      \
        }                                                                  \
    } while (0)
33
34namespace pbat::gpu::impl::math {
35
/**
 * @brief Wrapper around a cuBLAS handle exposing a small set of BLAS routines
 * (copy, axpy, gemv and triangular solves) for float and double scalars.
 *
 * Instances cannot be copied or moved. Every routine accepts an optional
 * stream which is handed to TrySetStream() before the cuBLAS call is issued.
 */
class Blas
{
  public:
    /**
     * @brief Construct a Blas object for the given device.
     * @param device Device to use. Defaults to the device returned by
     * common::Device(common::EDeviceSelectionPreference::HighestComputeCapability).
     * @note Defined out of line; presumably creates the cuBLAS handle for
     * `device` -- confirm against the implementation file.
     */
    Blas(
        cuda::device_t device =
            common::Device(common::EDeviceSelectionPreference::HighestComputeCapability));

    // Non-copyable and non-movable.
    Blas(Blas const&)            = delete;
    Blas(Blas&&)                 = delete;
    Blas& operator=(Blas const&) = delete;
    Blas& operator=(Blas&&)      = delete;

    /**
     * @brief Access the underlying cuBLAS handle.
     * @return The stored handle (returned by value).
     */
    cublasHandle_t Handle() const { return mHandle; }

    /**
     * @brief Copy vector x into vector y via cublasScopy / cublasDcopy.
     * @tparam TVectorX Source vector type
     * @tparam TVectorY Destination vector type
     * @tparam TScalar Scalar type selecting the cuBLAS routine (float or double)
     * @param x Source vector
     * @param y Destination vector
     * @param stream Optional stream forwarded to TrySetStream()
     */
    template <CVector TVectorX, CVector TVectorY, class TScalar = typename TVectorX::ValueType>
    void
    Copy(TVectorX const& x, TVectorY& y, std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Compute y = alpha * x + y via cublasSaxpy / cublasDaxpy.
     * @tparam TVectorX Type of x
     * @tparam TVectorY Type of y
     * @tparam TScalar Scalar type selecting the cuBLAS routine (float or double)
     * @param x Vector scaled by alpha
     * @param y Vector updated in place
     * @param alpha Scaling factor
     * @param stream Optional stream forwarded to TrySetStream()
     * @note Unlike the other routines of this class, this one is not const-qualified.
     */
    template <CVector TVectorX, CVector TVectorY, class TScalar = typename TVectorX::ValueType>
    void Axpy(
        TVectorX const& x,
        TVectorY& y,
        TScalar alpha                          = TScalar(1),
        std::shared_ptr<cuda::stream_t> stream = nullptr);

    /**
     * @brief Compute y = alpha * op(A) * x + beta * y via cublasSgemv / cublasDgemv,
     * where op is given by A.Operation().
     * @tparam TMatrixA Matrix type
     * @tparam TVectorX Type of x
     * @tparam TVectorY Type of y
     * @tparam TScalar Scalar type selecting the cuBLAS routine (float or double)
     * @param A Matrix operand
     * @param x Input vector
     * @param y Output vector, updated in place
     * @param alpha Scaling of the product
     * @param beta Scaling of the previous contents of y
     * @param stream Optional stream forwarded to TrySetStream()
     */
    template <
        CMatrix TMatrixA,
        CVector TVectorX,
        CVector TVectorY,
        class TScalar = typename TMatrixA::ValueType>
    void Gemv(
        TMatrixA const& A,
        TVectorX const& x,
        TVectorY& y,
        TScalar alpha                          = TScalar(1),
        TScalar beta                           = TScalar(0),
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Left-sided triangular solve using the upper triangle of A, with a
     * matrix right-hand side that is overwritten by the solution.
     * @param A Triangular matrix
     * @param B Right-hand side matrix, modified in place
     * @param alpha Scaling applied to the right-hand side
     * @param bHasUnitDiagonal Whether the diagonal of A is to be treated as unit
     * @param stream Optional stream forwarded through Trsm()
     */
    template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar = typename TMatrixA::ValueType>
    void UpperTriangularSolve(
        TMatrixA const& A,
        TMatrixB& B,
        TScalar alpha                          = TScalar(1),
        bool bHasUnitDiagonal                  = false,
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Left-sided triangular solve using the lower triangle of A, with a
     * matrix right-hand side that is overwritten by the solution.
     * @param A Triangular matrix
     * @param B Right-hand side matrix, modified in place
     * @param alpha Scaling applied to the right-hand side
     * @param bHasUnitDiagonal Whether the diagonal of A is to be treated as unit
     * @param stream Optional stream forwarded through Trsm()
     */
    template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar = typename TMatrixA::ValueType>
    void LowerTriangularSolve(
        TMatrixA const& A,
        TMatrixB& B,
        TScalar alpha                          = TScalar(1),
        bool bHasUnitDiagonal                  = false,
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Upper triangular solve with a vector right-hand side. The vector is
     * wrapped in a MatrixView<TScalar> and passed to the matrix overload.
     * @param A Triangular matrix
     * @param B Right-hand side vector, modified in place
     * @param alpha Scaling applied to the right-hand side
     * @param bHasUnitDiagonal Whether the diagonal of A is to be treated as unit
     * @param stream Optional stream forwarded through Trsm()
     */
    template <CMatrix TMatrixA, CVector TVectorB, class TScalar = typename TMatrixA::ValueType>
    void UpperTriangularSolve(
        TMatrixA const& A,
        TVectorB& B,
        TScalar alpha                          = TScalar(1),
        bool bHasUnitDiagonal                  = false,
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Lower triangular solve with a vector right-hand side. The vector is
     * wrapped in a MatrixView<TScalar> and passed to the matrix overload.
     * @param A Triangular matrix
     * @param B Right-hand side vector, modified in place
     * @param alpha Scaling applied to the right-hand side
     * @param bHasUnitDiagonal Whether the diagonal of A is to be treated as unit
     * @param stream Optional stream forwarded through Trsm()
     */
    template <CMatrix TMatrixA, CVector TVectorB, class TScalar = typename TMatrixA::ValueType>
    void LowerTriangularSolve(
        TMatrixA const& A,
        TVectorB& B,
        TScalar alpha                          = TScalar(1),
        bool bHasUnitDiagonal                  = false,
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief General triangular solve with multiple right-hand sides via
     * cublasStrsm / cublasDtrsm.
     * @param side Side on which A appears in the system
     * @param uplo Which triangle of A is referenced
     * @param diag Whether the diagonal of A is treated as unit
     * @param A Triangular matrix; its operation is taken from A.Operation()
     * @param B Right-hand side matrix, modified in place
     * @param alpha Scaling applied to the right-hand side
     * @param stream Optional stream forwarded to TrySetStream()
     */
    template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar = typename TMatrixA::ValueType>
    void Trsm(
        cublasSideMode_t side,
        cublasFillMode_t uplo,
        cublasDiagType_t diag,
        TMatrixA const& A,
        TMatrixB& B,
        TScalar alpha                          = TScalar(1),
        std::shared_ptr<cuda::stream_t> stream = nullptr) const;

    /**
     * @brief Destructor.
     * @note Defined out of line; presumably releases the cuBLAS handle -- confirm
     * against the implementation file.
     */
    ~Blas();

  protected:
    /**
     * @brief Called by every routine before issuing its cuBLAS call.
     * @param stream Stream requested by the caller, possibly null
     * @note Defined out of line; presumably binds the handle to `stream` when it
     * is non-null -- confirm against the implementation file.
     */
    void TrySetStream(std::shared_ptr<cuda::stream_t> stream) const;

  private:
    cublasHandle_t mHandle; ///< cuBLAS handle passed to every cuBLAS call
    cuda::device_t mDevice; ///< Device this object was constructed with
};
124
/**
 * @brief Copy x into y using cublasScopy (float) or cublasDcopy (double).
 *
 * x.Rows() elements are copied, reading x with stride x.Increment() and writing
 * y with stride y.Increment(). No size compatibility check is made between x and y.
 *
 * @tparam TScalar Must be float or double; any other type is rejected at compile
 * time rather than silently compiling to a no-op.
 * @param x Source vector
 * @param y Destination vector
 * @param stream Optional stream, handed to TrySetStream() before the call
 * @throws std::runtime_error (via CUBLAS_CHECK) if cuBLAS reports a failure
 */
template <CVector TVectorX, CVector TVectorY, class TScalar>
inline void Blas::Copy(TVectorX const& x, TVectorY& y, std::shared_ptr<cuda::stream_t> stream) const
{
    static_assert(
        std::is_same_v<TScalar, float> || std::is_same_v<TScalar, double>,
        "Blas::Copy only supports float and double scalars");
    TrySetStream(stream);
    if constexpr (std::is_same_v<TScalar, float>)
    {
        CUBLAS_CHECK(
            cublasScopy(mHandle, x.Rows(), x.Raw(), x.Increment(), y.Raw(), y.Increment()));
    }
    else
    {
        CUBLAS_CHECK(
            cublasDcopy(mHandle, x.Rows(), x.Raw(), x.Increment(), y.Raw(), y.Increment()));
    }
}
140
/**
 * @brief Compute y = alpha * x + y using cublasSaxpy (float) or cublasDaxpy (double).
 *
 * x.Rows() elements are processed, with strides x.Increment() and y.Increment().
 * alpha is passed to cuBLAS by address from this function's own parameter.
 *
 * @tparam TScalar Must be float or double; any other type is rejected at compile
 * time rather than silently compiling to a no-op.
 * @param x Vector scaled by alpha
 * @param y Vector updated in place
 * @param alpha Scaling factor
 * @param stream Optional stream, handed to TrySetStream() before the call
 * @throws std::runtime_error (via CUBLAS_CHECK) if cuBLAS reports a failure
 */
template <CVector TVectorX, CVector TVectorY, class TScalar>
inline void
Blas::Axpy(TVectorX const& x, TVectorY& y, TScalar alpha, std::shared_ptr<cuda::stream_t> stream)
{
    static_assert(
        std::is_same_v<TScalar, float> || std::is_same_v<TScalar, double>,
        "Blas::Axpy only supports float and double scalars");
    TrySetStream(stream);
    if constexpr (std::is_same_v<TScalar, float>)
    {
        CUBLAS_CHECK(
            cublasSaxpy(mHandle, x.Rows(), &alpha, x.Raw(), x.Increment(), y.Raw(), y.Increment()));
    }
    else
    {
        CUBLAS_CHECK(
            cublasDaxpy(mHandle, x.Rows(), &alpha, x.Raw(), x.Increment(), y.Raw(), y.Increment()));
    }
}
157
/**
 * @brief Compute y = alpha * op(A) * x + beta * y using cublasSgemv (float) or
 * cublasDgemv (double), with op taken from A.Operation().
 *
 * A.Rows() and A.Cols() are passed as the matrix dimensions and
 * A.LeadingDimensions() as its leading dimension. No compatibility check is made
 * between the sizes of A, x and y.
 *
 * @tparam TScalar Must be float or double; any other type is rejected at compile
 * time rather than silently compiling to a no-op.
 * @param A Matrix operand
 * @param x Input vector
 * @param y Output vector, updated in place
 * @param alpha Scaling of the matrix-vector product
 * @param beta Scaling of the previous contents of y
 * @param stream Optional stream, handed to TrySetStream() before the call
 * @throws std::runtime_error (via CUBLAS_CHECK) if cuBLAS reports a failure
 */
template <CMatrix TMatrixA, CVector TVectorX, CVector TVectorY, class TScalar>
inline void Blas::Gemv(
    TMatrixA const& A,
    TVectorX const& x,
    TVectorY& y,
    TScalar alpha,
    TScalar beta,
    std::shared_ptr<cuda::stream_t> stream) const
{
    static_assert(
        std::is_same_v<TScalar, float> || std::is_same_v<TScalar, double>,
        "Blas::Gemv only supports float and double scalars");
    TrySetStream(stream);
    if constexpr (std::is_same_v<TScalar, float>)
    {
        CUBLAS_CHECK(cublasSgemv(
            mHandle,
            A.Operation(),
            A.Rows(),
            A.Cols(),
            &alpha,
            A.Raw(),
            A.LeadingDimensions(),
            x.Raw(),
            x.Increment(),
            &beta,
            y.Raw(),
            y.Increment()));
    }
    else
    {
        CUBLAS_CHECK(cublasDgemv(
            mHandle,
            A.Operation(),
            A.Rows(),
            A.Cols(),
            &alpha,
            A.Raw(),
            A.LeadingDimensions(),
            x.Raw(),
            x.Increment(),
            &beta,
            y.Raw(),
            y.Increment()));
    }
}
201
/**
 * @brief Triangular solve against the upper triangle of A with a matrix
 * right-hand side.
 *
 * Forwards to Trsm() with CUBLAS_SIDE_LEFT and CUBLAS_FILL_MODE_UPPER.
 *
 * @param A Triangular matrix
 * @param B Right-hand side matrix, handed to Trsm() by mutable reference
 * @param alpha Scaling factor forwarded to Trsm()
 * @param bHasUnitDiagonal Selects CUBLAS_DIAG_UNIT when true, otherwise
 * CUBLAS_DIAG_NON_UNIT
 * @param stream Optional stream forwarded to Trsm()
 */
template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar>
inline void Blas::UpperTriangularSolve(
    TMatrixA const& A,
    TMatrixB& B,
    TScalar alpha,
    bool bHasUnitDiagonal,
    std::shared_ptr<cuda::stream_t> stream) const
{
    cublasDiagType_t const diagonalKind =
        bHasUnitDiagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
    Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER, diagonalKind, A, B, alpha, stream);
}
219
/**
 * @brief Triangular solve against the lower triangle of A with a matrix
 * right-hand side.
 *
 * Forwards to Trsm() with CUBLAS_SIDE_LEFT and CUBLAS_FILL_MODE_LOWER.
 *
 * @param A Triangular matrix
 * @param B Right-hand side matrix, handed to Trsm() by mutable reference
 * @param alpha Scaling factor forwarded to Trsm()
 * @param bHasUnitDiagonal Selects CUBLAS_DIAG_UNIT when true, otherwise
 * CUBLAS_DIAG_NON_UNIT
 * @param stream Optional stream forwarded to Trsm()
 */
template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar>
inline void Blas::LowerTriangularSolve(
    TMatrixA const& A,
    TMatrixB& B,
    TScalar alpha,
    bool bHasUnitDiagonal,
    std::shared_ptr<cuda::stream_t> stream) const
{
    cublasDiagType_t const diagonalKind =
        bHasUnitDiagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
    Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER, diagonalKind, A, B, alpha, stream);
}
237
/**
 * @brief Upper triangular solve with a vector right-hand side.
 *
 * Builds a MatrixView<TScalar> from B and delegates to the matrix overload of
 * UpperTriangularSolve().
 *
 * @param A Triangular matrix
 * @param B Right-hand side vector, used to construct the matrix view
 * @param alpha Scaling factor forwarded to the matrix overload
 * @param bHasUnitDiagonal Forwarded to the matrix overload
 * @param stream Optional stream forwarded to the matrix overload
 */
template <CMatrix TMatrixA, CVector TVectorB, class TScalar>
inline void Blas::UpperTriangularSolve(
    TMatrixA const& A,
    TVectorB& B,
    TScalar alpha,
    bool bHasUnitDiagonal,
    std::shared_ptr<cuda::stream_t> stream) const
{
    // View the vector through the matrix interface so that the CMatrix overload is selected.
    MatrixView<TScalar> rhsAsMatrix(B);
    UpperTriangularSolve(A, rhsAsMatrix, alpha, bHasUnitDiagonal, stream);
}
249
/**
 * @brief Lower triangular solve with a vector right-hand side.
 *
 * Builds a MatrixView<TScalar> from B and delegates to the matrix overload of
 * LowerTriangularSolve().
 *
 * @param A Triangular matrix
 * @param B Right-hand side vector, used to construct the matrix view
 * @param alpha Scaling factor forwarded to the matrix overload
 * @param bHasUnitDiagonal Forwarded to the matrix overload
 * @param stream Optional stream forwarded to the matrix overload
 */
template <CMatrix TMatrixA, CVector TVectorB, class TScalar>
inline void Blas::LowerTriangularSolve(
    TMatrixA const& A,
    TVectorB& B,
    TScalar alpha,
    bool bHasUnitDiagonal,
    std::shared_ptr<cuda::stream_t> stream) const
{
    // View the vector through the matrix interface so that the CMatrix overload is selected.
    MatrixView<TScalar> rhsAsMatrix(B);
    LowerTriangularSolve(A, rhsAsMatrix, alpha, bHasUnitDiagonal, stream);
}
261
/**
 * @brief Triangular solve with multiple right-hand sides using cublasDtrsm
 * (double) or cublasStrsm (float).
 *
 * B.Rows() and B.Cols() are passed as the problem dimensions, and the operation
 * applied to A is taken from A.Operation(). B is passed to cuBLAS as the
 * in/out operand.
 *
 * The cuBLAS status is checked with CUBLAS_CHECK, consistently with the other
 * routines of this class, so failures are reported instead of being discarded.
 *
 * @tparam TScalar Must be float or double; any other type is rejected at compile
 * time rather than silently compiling to a no-op.
 * @param side Side on which A appears in the system
 * @param uplo Which triangle of A is referenced
 * @param diag Whether the diagonal of A is treated as unit
 * @param A Triangular matrix
 * @param B Right-hand side matrix, modified in place
 * @param alpha Scaling applied to the right-hand side
 * @param stream Optional stream, handed to TrySetStream() before the call
 * @throws std::runtime_error (via CUBLAS_CHECK) if cuBLAS reports a failure
 */
template <CMatrix TMatrixA, CMatrix TMatrixB, class TScalar>
inline void Blas::Trsm(
    cublasSideMode_t side,
    cublasFillMode_t uplo,
    cublasDiagType_t diag,
    TMatrixA const& A,
    TMatrixB& B,
    TScalar alpha,
    std::shared_ptr<cuda::stream_t> stream) const
{
    static_assert(
        std::is_same_v<TScalar, float> || std::is_same_v<TScalar, double>,
        "Blas::Trsm only supports float and double scalars");
    TrySetStream(stream);
    if constexpr (std::is_same_v<TScalar, double>)
    {
        CUBLAS_CHECK(cublasDtrsm(
            mHandle,
            side,
            uplo,
            A.Operation(),
            diag,
            B.Rows(),
            B.Cols(),
            &alpha,
            A.Raw(),
            A.LeadingDimensions(),
            B.Raw(),
            B.LeadingDimensions()));
    }
    else
    {
        CUBLAS_CHECK(cublasStrsm(
            mHandle,
            side,
            uplo,
            A.Operation(),
            diag,
            B.Rows(),
            B.Cols(),
            &alpha,
            A.Raw(),
            A.LeadingDimensions(),
            B.Raw(),
            B.LeadingDimensions()));
    }
}
306
307} // namespace pbat::gpu::impl::math
308
309#endif // PBAT_GPU_IMPL_MATH_BLAS_H
Matrix and vector cuBLAS abstractions.
Definition Matrix.cuh:24
Definition Matrix.cuh:37
GPU implementations of math functions.
Definition Blas.cu:7