PhysicsBasedAnimationToolkit 0.0.10
Cross-platform C++20 library of algorithms and data structures commonly used in computer graphics research on physically-based simulation.
Loading...
Searching...
No Matches
LinearSolver.cuh
Go to the documentation of this file.
1
10
11#ifndef PBAT_GPU_IMPL_MATH_LINEARSOLVER_CUH
12#define PBAT_GPU_IMPL_MATH_LINEARSOLVER_CUH
13
#include "Blas.cuh"
#include "Matrix.cuh"
#include "pbat/gpu/impl/common/Buffer.cuh"
#include "pbat/gpu/impl/common/Cuda.cuh"

#include <algorithm>
#include <cstdio>
#include <cuda/api/device.hpp>
#include <cuda/api/stream.hpp>
#include <cusolverDn.h>
#include <memory>
#include <stdexcept>
#include <type_traits>
26
/**
 * @brief Evaluates a cuSOLVER call and throws if it did not succeed.
 *
 * The expression `err` is evaluated exactly once. If its status differs from
 * CUSOLVER_STATUS_SUCCESS, the numeric status code and the call site (file and line)
 * are printed via std::printf and a std::runtime_error("cusolver error") is thrown.
 *
 * The expansion is wrapped in do { ... } while (0) so that `CUSOLVER_CHECK(x);` forms
 * exactly one statement and can be used in an unbraced if/else without a dangling-else
 * compile error. The status is cast to int explicitly because it is printed with %d.
 */
#define CUSOLVER_CHECK(err)                                        \
    do                                                             \
    {                                                              \
        cusolverStatus_t const err_ = (err);                       \
        if (err_ != cusolverStatus_t::CUSOLVER_STATUS_SUCCESS)     \
        {                                                          \
            std::printf(                                           \
                "cusolver error %d at %s:%d\n",                    \
                static_cast<int>(err_),                            \
                __FILE__,                                          \
                __LINE__);                                         \
            throw std::runtime_error("cusolver error");            \
        }                                                          \
    } while (0)
36
37namespace pbat::gpu::impl::math {
38
39class LinearSolver
40{
41 public:
42 LinearSolver(
43 cuda::device_t device =
44 common::Device(common::EDeviceSelectionPreference::HighestComputeCapability));
45
46 LinearSolver(LinearSolver const&) = delete;
47 LinearSolver(LinearSolver&&) = delete;
48 LinearSolver& operator=(LinearSolver const&) = delete;
49 LinearSolver& operator=(LinearSolver&&) = delete;
50
51 cusolverDnHandle_t Handle() const { return mCusolverHandle; }
52
53 template <CMatrix TMatrixA, class TScalar = TMatrixA::ValueType>
54 int GeqrfWorkspace(TMatrixA const& A) const;
55
56 template <CMatrix TMatrixQR, CVector TVectorTau, class TScalar = TMatrixQR::ValueType>
57 void Geqrf(
58 TMatrixQR& QR,
59 TVectorTau& tau,
60 common::Buffer<TScalar>& workspace,
61 std::shared_ptr<cuda::stream_t> stream = nullptr) const;
62
63 template <CMatrix TMatrixQ, CMatrix TMatrixB, class TScalar = TMatrixQ::ValueType>
64 int OrmqrWorkspace(TMatrixQ const& Q, TMatrixB const& B, bool bMultiplyFromLeft = true) const;
65
66 template <CMatrix TMatrixQ, CVector TVectorB, class TScalar = TMatrixQ::ValueType>
67 int OrmqrWorkspace(TMatrixQ const& Q, TVectorB const& B, bool bMultiplyFromLeft = true) const;
68
69 template <
70 CMatrix TMatrixQ,
71 CVector TVectorTau,
72 CMatrix TMatrixB,
73 class TScalar = TMatrixQ::ValueType>
74 void Ormqr(
75 TMatrixQ const& Q,
76 TVectorTau const& tau,
77 TMatrixB& B,
78 common::Buffer<TScalar>& workspace,
79 bool bMultiplyFromLeft = true,
80 std::shared_ptr<cuda::stream_t> stream = nullptr) const;
81
82 template <
83 CMatrix TMatrixQ,
84 CVector TVectorTau,
85 CVector TVectorB,
86 class TScalar = TMatrixQ::ValueType>
87 void Ormqr(
88 TMatrixQ const& Q,
89 TVectorTau const& tau,
90 TVectorB& B,
91 common::Buffer<TScalar>& workspace,
92 bool bMultiplyFromLeft = true,
93 std::shared_ptr<cuda::stream_t> stream = nullptr) const;
94
95 ~LinearSolver();
96
97 protected:
98 void TrySetStream(std::shared_ptr<cuda::stream_t> stream) const;
99
100 private:
101 cusolverDnHandle_t mCusolverHandle;
102 cuda::device_t mDevice;
103};
104
105template <CMatrix TMatrixA, class TScalar>
106inline int LinearSolver::GeqrfWorkspace(TMatrixA const& A) const
107{
108 int workspace{0};
109 if constexpr (std::is_same_v<TScalar, float>)
110 {
111 CUSOLVER_CHECK(cusolverDnSgeqrf_bufferSize(
112 mCusolverHandle,
113 A.Rows(),
114 A.Cols(),
115 nullptr,
116 A.LeadingDimensions(),
117 &workspace));
118 }
119 if constexpr (std::is_same_v<TScalar, double>)
120 {
121 CUSOLVER_CHECK(cusolverDnDgeqrf_bufferSize(
122 mCusolverHandle,
123 A.Rows(),
124 A.Cols(),
125 nullptr,
126 A.LeadingDimensions(),
127 &workspace));
128 }
129 return workspace;
130}
131
132template <CMatrix TMatrixQR, CVector TVectorTau, class TScalar>
133inline void LinearSolver::Geqrf(
134 TMatrixQR& QR,
135 TVectorTau& tau,
136 common::Buffer<TScalar>& workspace,
137 std::shared_ptr<cuda::stream_t> stream) const
138{
139 TrySetStream(stream);
140 if constexpr (std::is_same_v<TScalar, float>)
141 {
142 CUSOLVER_CHECK(cusolverDnSgeqrf(
143 mCusolverHandle,
144 QR.Rows(),
145 QR.Cols(),
146 QR.Raw(),
147 QR.LeadingDimensions(),
148 tau.Raw(),
149 workspace.Raw(),
150 static_cast<int>(workspace.Size()),
151 nullptr));
152 }
153 if constexpr (std::is_same_v<TScalar, double>)
154 {
155 CUSOLVER_CHECK(cusolverDnDgeqrf(
156 mCusolverHandle,
157 QR.Rows(),
158 QR.Cols(),
159 QR.Raw(),
160 QR.LeadingDimensions(),
161 tau.Raw(),
162 workspace.Raw(),
163 static_cast<int>(workspace.Size()),
164 nullptr));
165 }
166}
167
168template <CMatrix TMatrixQ, CMatrix TMatrixB, class TScalar>
169int LinearSolver::OrmqrWorkspace(TMatrixQ const& Q, TMatrixB const& B, bool bMultiplyFromLeft) const
170{
171 int workspace{0};
172 auto side = bMultiplyFromLeft ? cublasSideMode_t::CUBLAS_SIDE_LEFT :
173 cublasSideMode_t::CUBLAS_SIDE_RIGHT;
174 if constexpr (std::is_same_v<TScalar, float>)
175 {
176 CUSOLVER_CHECK(cusolverDnSormqr_bufferSize(
177 mCusolverHandle,
178 side,
179 Q.Operation(),
180 B.Rows(),
181 B.Cols(),
182 Q.Cols(),
183 nullptr,
184 Q.LeadingDimensions(),
185 nullptr,
186 nullptr,
187 B.LeadingDimensions(),
188 &workspace));
189 }
190 if constexpr (std::is_same_v<TScalar, double>)
191 {
192 CUSOLVER_CHECK(cusolverDnDormqr_bufferSize(
193 mCusolverHandle,
194 side,
195 Q.Operation(),
196 B.Rows(),
197 B.Cols(),
198 Q.Cols(),
199 nullptr,
200 Q.LeadingDimensions(),
201 nullptr,
202 nullptr,
203 B.LeadingDimensions(),
204 &workspace));
205 }
206 return workspace;
207}
208
209template <CMatrix TMatrixQ, CVector TVectorB, class TScalar>
210inline int
211LinearSolver::OrmqrWorkspace(TMatrixQ const& Q, TVectorB const& B, bool bMultiplyFromLeft) const
212{
213 MatrixView<TScalar> BB(const_cast<TVectorB&>(B));
214 return OrmqrWorkspace(Q, BB, bMultiplyFromLeft);
215}
216
217template <CMatrix TMatrixQ, CVector TVectorTau, CMatrix TMatrixB, class TScalar>
218inline void LinearSolver::Ormqr(
219 TMatrixQ const& Q,
220 TVectorTau const& tau,
221 TMatrixB& B,
222 common::Buffer<TScalar>& workspace,
223 bool bMultiplyFromLeft,
224 std::shared_ptr<cuda::stream_t> stream) const
225{
226 TrySetStream(stream);
227 auto side = bMultiplyFromLeft ? cublasSideMode_t::CUBLAS_SIDE_LEFT :
228 cublasSideMode_t::CUBLAS_SIDE_RIGHT;
229 if constexpr (std::is_same_v<TScalar, float>)
230 {
231 CUSOLVER_CHECK(cusolverDnSormqr(
232 mCusolverHandle,
233 side,
234 Q.Operation(),
235 B.Rows(),
236 B.Cols(),
237 Q.Cols(),
238 Q.Raw(),
239 Q.LeadingDimensions(),
240 tau.Raw(),
241 B.Raw(),
242 B.LeadingDimensions(),
243 workspace.Raw(),
244 static_cast<int>(workspace.Size()),
245 nullptr));
246 }
247 if constexpr (std::is_same_v<TScalar, double>)
248 {
249 CUSOLVER_CHECK(cusolverDnDormqr(
250 mCusolverHandle,
251 side,
252 Q.Operation(),
253 B.Rows(),
254 B.Cols(),
255 Q.Cols(),
256 Q.Raw(),
257 Q.LeadingDimensions(),
258 tau.Raw(),
259 B.Raw(),
260 B.LeadingDimensions(),
261 workspace.Raw(),
262 static_cast<int>(workspace.Size()),
263 nullptr));
264 }
265}
266
267template <CMatrix TMatrixQ, CVector TVectorTau, CVector TVectorB, class TScalar>
268inline void LinearSolver::Ormqr(
269 TMatrixQ const& Q,
270 TVectorTau const& tau,
271 TVectorB& B,
272 common::Buffer<TScalar>& workspace,
273 bool bMultiplyFromLeft,
274 std::shared_ptr<cuda::stream_t> stream) const
275{
276 MatrixView<TScalar> BB(B);
277 Ormqr(Q, tau, BB, workspace, bMultiplyFromLeft, stream);
278}
279
280} // namespace pbat::gpu::impl::math
281
282#endif // PBAT_GPU_IMPL_MATH_LINEARSOLVER_CUH
BLAS API wrapper over cuBLAS.
Matrix and vector cuBLAS abstractions.
Definition Buffer.cuh:21
Definition Matrix.cuh:24
Definition Matrix.cuh:37
GPU implementations of math functions.
Definition Blas.cu:7