26 requires std::is_same_v<typename TMatrix::ValueType, float> or
27 std::is_same_v<typename TMatrix::ValueType, double>;
28 {a.Raw()}->std::same_as<
typename TMatrix::ValueType*>;
29 {a.Rows()}->std::convertible_to<
int>;
30 {a.Cols()}->std::convertible_to<
int>;
31 {a.LeadingDimensions()}->std::convertible_to<
int>;
32 {a.Operation()}->std::convertible_to<cublasOperation_t>;
39 requires std::is_same_v<typename TVector::ValueType, float> or
40 std::is_same_v<typename TVector::ValueType, double>;
41 {a.Raw()}->std::same_as<
typename TVector::ValueType*>;
42 {a.Rows()}->std::convertible_to<
int>;
43 {a.Increment()}->std::convertible_to<
int>;
49 using ValueType = std::remove_cvref_t<T>;
56 ValueType* Raw() {
return data; }
57 ValueType
const* Raw()
const {
return data; }
58 auto Rows()
const {
return n; }
59 constexpr auto Cols()
const {
return 1; }
60 auto Increment()
const {
return inc; }
63 VectorView<ValueType> Slice(
auto row,
auto rows,
auto inc)
const
65 return VectorView<ValueType>{
const_cast<ValueType*
>(
data) + row, rows,
inc};
67 VectorView<ValueType> Segment(
auto row,
auto rows)
const {
return Slice(row, rows, 1); }
68 VectorView<ValueType> Head(
auto rows)
const {
return Slice(0, rows, 1); }
69 VectorView<ValueType> Tail(
auto rows)
const {
return Slice(
n - rows, rows, 1); }
75 using ValueType = std::remove_cvref_t<T>;
88 cublasOperation_t opIn = cublasOperation_t::CUBLAS_OP_N)
89 :
data(dataIn),
m(mIn),
n(nIn),
ld(ldIn),
op(opIn)
93 throw std::invalid_argument(
94 "MatrixView::MatrixView(ValueType* data, int m, int n, int ld) -> ld < m");
98 template <CVector TVector>
101 if (v.Increment() != 1)
103 throw std::invalid_argument(
104 "MatrixView::MatrixView(TVector const& v) -> v.Increment() must be 1");
109 ValueType* Raw() {
return data; }
110 ValueType
const* Raw()
const {
return data; }
111 auto Rows()
const {
return m; }
112 auto Cols()
const {
return n; }
113 auto LeadingDimensions()
const {
return ld; }
114 auto Operation()
const {
return op; }
117 MatrixView<ValueType> Transposed()
const
119 return MatrixView<ValueType>{
120 const_cast<ValueType*
>(
data),
124 op == CUBLAS_OP_N ? CUBLAS_OP_T : CUBLAS_OP_N};
126 MatrixView<ValueType> SubMatrix(
auto row,
auto col,
auto rows,
auto cols)
const
128 return MatrixView<ValueType>{
const_cast<ValueType*
>(
data) +
ld * col + row, rows, cols,
ld};
130 MatrixView<ValueType> LeftCols(
auto cols)
const {
return SubMatrix(0, 0,
m, cols); }
131 MatrixView<ValueType> RightCols(
auto cols)
const
133 return SubMatrix(0, Cols() - cols,
m, cols);
135 MatrixView<ValueType> TopRows(
auto rows)
const {
return SubMatrix(0, 0, rows, Cols()); }
136 MatrixView<ValueType> BottomRows(
auto rows)
const
138 return SubMatrix(Rows() - rows, 0, rows, Cols());
140 MatrixView<ValueType> Col(
auto col)
const {
return SubMatrix(0, col, Rows(), Cols()); }
141 MatrixView<ValueType> Row(
auto row)
const {
return SubMatrix(row, 0, 1, Cols()); }
142 VectorView<ValueType> Flattened()
const
144 return VectorView<ValueType>{
const_cast<ValueType*
>(
data), Rows() * Cols(), 1};
151 using ValueType = std::remove_cvref_t<T>;
154 Matrix(
auto rows,
auto cols) :
data(rows * cols),
m(
static_cast<int>(rows)) {}
160 ValueType* Raw() {
return data.Raw(); }
161 ValueType
const* Raw()
const {
return data.Raw(); }
162 auto Rows()
const {
return m; }
163 auto Cols()
const {
return static_cast<int>(data.Size()) / m; }
164 auto LeadingDimensions()
const {
return m; }
165 auto Operation()
const {
return CUBLAS_OP_N; }
168 MatrixView<ValueType> View()
const
170 return MatrixView<ValueType>{
const_cast<ValueType*
>(data.Raw()), m, Cols(), m};
172 MatrixView<ValueType> SubMatrix(
auto row,
auto col,
auto rows,
auto cols)
const
174 ValueType* a =
const_cast<ValueType*
>(data.Raw());
175 return MatrixView<ValueType>{a + m * col + row, rows, cols, m};
177 MatrixView<ValueType> LeftCols(
auto cols)
const {
return SubMatrix(0, 0, m, cols); }
178 MatrixView<ValueType> RightCols(
auto cols)
const
180 return SubMatrix(0, Cols() - cols, m, cols);
182 MatrixView<ValueType> TopRows(
auto rows)
const {
return SubMatrix(0, 0, rows, Cols()); }
183 MatrixView<ValueType> BottomRows(
auto rows)
const
185 return SubMatrix(Rows() - rows, 0, rows, Cols());
187 MatrixView<ValueType> Col(
auto col)
const {
return SubMatrix(0, col, Rows(), Cols()); }
188 MatrixView<ValueType> Row(
auto row)
const {
return SubMatrix(row, 0, 1, Cols()); }
189 VectorView<ValueType> Flattened()
const
191 return VectorView<ValueType>{
const_cast<ValueType*
>(data.Raw()), Rows() * Cols(), 1};
193 MatrixView<ValueType> Transposed()
const
195 return MatrixView<ValueType>{
const_cast<ValueType*
>(data.Raw()), m, Cols(), m, CUBLAS_OP_T};
202 using ValueType = std::remove_cvref_t<T>;
205 Vector(
auto rows) :
data(rows),
n(
static_cast<int>(rows)) {}
211 ValueType* Raw() {
return data.Raw(); }
212 ValueType
const* Raw()
const {
return data.Raw(); }
213 auto Rows()
const {
return n; }
214 constexpr auto Cols()
const {
return 1; }
215 constexpr auto Increment()
const {
return 1; }
218 VectorView<ValueType> Slice(
auto row,
auto rows,
auto inc)
const
220 ValueType* a =
const_cast<ValueType*
>(data.Raw());
221 return VectorView<ValueType>{a + row, rows, inc};
223 VectorView<ValueType> Segment(
auto row,
auto rows)
const {
return Slice(row, rows, 1); }
224 VectorView<ValueType> Head(
auto rows)
const {
return Slice(0, rows, 1); }
225 VectorView<ValueType> Tail(
auto rows)
const {
return Slice(n - rows, rows, 1); }