1#ifndef PBAT_MATH_LINALG_MINI_MATRIX_H
2#define PBAT_MATH_LINALG_MINI_MATRIX_H
6#include "pbat/HostDevice.h"
11#include <initializer_list>
20template <
class TScalar,
int M,
int N = 1>
24 using ScalarType = TScalar;
27 static auto constexpr kRows = M;
28 static auto constexpr kCols = N;
29 static bool constexpr bRowMajor =
false;
31 PBAT_HOST_DEVICE ScalarType operator()([[maybe_unused]]
auto i, [[maybe_unused]]
auto j)
const
37 PBAT_HOST_DEVICE ScalarType operator()([[maybe_unused]]
auto i)
const {
return ScalarType{1}; }
38 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
40 template <auto S, auto T>
41 PBAT_HOST_DEVICE
auto Slice([[maybe_unused]]
auto i, [[maybe_unused]]
auto j)
const
45 PBAT_HOST_DEVICE
auto Col([[maybe_unused]]
auto j)
const
49 PBAT_HOST_DEVICE
auto Row([[maybe_unused]]
auto i)
const
54 PBAT_MINI_DIMENSIONS_API
55 PBAT_MINI_CONST_TRANSPOSE_API(SelfType)
58template <
class TScalar,
int M,
int N = 1>
62 using ScalarType = TScalar;
65 static auto constexpr kRows = M;
66 static auto constexpr kCols = N;
67 static bool constexpr bRowMajor =
false;
69 PBAT_HOST_DEVICE ScalarType operator()([[maybe_unused]]
auto i, [[maybe_unused]]
auto j)
const
75 PBAT_HOST_DEVICE ScalarType operator()([[maybe_unused]]
auto i)
const {
return ScalarType{0}; }
76 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
78 template <auto S, auto T>
79 PBAT_HOST_DEVICE
auto Slice([[maybe_unused]]
auto i, [[maybe_unused]]
auto j)
const
83 PBAT_HOST_DEVICE
auto Col([[maybe_unused]]
auto j)
const
87 PBAT_HOST_DEVICE
auto Row([[maybe_unused]]
auto i)
const
92 PBAT_MINI_DIMENSIONS_API
93 PBAT_MINI_CONST_TRANSPOSE_API(SelfType)
96template <
class TScalar,
int M,
int N>
100 using ScalarType = TScalar;
103 static auto constexpr kRows = M;
104 static auto constexpr kCols = N;
105 static bool constexpr bRowMajor =
false;
107 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
109 return static_cast<ScalarType
>(i == j);
113 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
114 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
116 PBAT_MINI_READ_API(SelfType)
119template <
class TScalar,
int M,
int N = 1>
123 using ScalarType = TScalar;
124 using SelfType = SMatrix<ScalarType, M, N>;
125#include "pbat/warning/Push.h"
126#include "pbat/warning/SignConversion.h"
127 using StorageType = std::array<ScalarType, M * N>;
128#include "pbat/warning/Pop.h"
129 using IndexType =
typename StorageType::size_type;
131 static auto constexpr kRows = M;
132 static auto constexpr kCols = N;
133 static bool constexpr bRowMajor =
false;
135 PBAT_HOST_DEVICE SMatrix() : a() {}
137 template <
class... T>
138 PBAT_HOST_DEVICE SMatrix(T... values) : a{values...}
142 template <
class TMatrix>
143 PBAT_HOST_DEVICE SMatrix(TMatrix&& B) : a()
145 Assign(*
this, std::forward<TMatrix>(B));
148 PBAT_HOST_DEVICE
void SetConstant(ScalarType k) { AssignScalar(*
this, k); }
150 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
152 auto k =
static_cast<IndexType
>(j * M + i);
155 PBAT_HOST_DEVICE ScalarType& operator()(
auto i,
auto j)
157 auto k =
static_cast<IndexType
>(j * M + i);
162 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const
164 auto k =
static_cast<IndexType
>(i);
167 PBAT_HOST_DEVICE ScalarType& operator()(
auto i)
169 auto k =
static_cast<IndexType
>(i);
172 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const
174 auto k =
static_cast<IndexType
>(i);
177 PBAT_HOST_DEVICE ScalarType& operator[](
auto i)
179 auto k =
static_cast<IndexType
>(i);
183 PBAT_HOST_DEVICE
void SetZero() { memset(a.data(), 0, kRows * kCols *
sizeof(ScalarType)); }
185 ScalarType* Data() {
return a.data(); }
186 ScalarType
const* Data()
const {
return a.data(); }
188 PBAT_MINI_READ_WRITE_API(SelfType)
194template <
class TScalar,
int M>
197template <
class TScalar,
int M,
int N>
201 using ScalarType = TScalar;
202 using SelfType = SMatrixView<ScalarType, M, N>;
204 static auto constexpr kRows = M;
205 static auto constexpr kCols = N;
206 static bool constexpr bRowMajor =
false;
208 PBAT_HOST_DEVICE SMatrixView(ScalarType* a) : mA(a) {}
210 PBAT_HOST_DEVICE
void SetConstant(ScalarType k) { AssignScalar(*
this, k); }
212 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const {
return mA[j * M + i]; }
213 PBAT_HOST_DEVICE ScalarType& operator()(
auto i,
auto j) {
return mA[j * M + i]; }
216 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return mA[i]; }
217 PBAT_HOST_DEVICE ScalarType& operator()(
auto i) {
return mA[i]; }
218 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return mA[i]; }
219 PBAT_HOST_DEVICE ScalarType& operator[](
auto i) {
return mA[i]; }
221 PBAT_HOST_DEVICE
void SetZero() { memset(mA, 0, kRows * kCols *
sizeof(ScalarType)); }
223 ScalarType* Data() {
return mA; }
224 ScalarType
const* Data()
const {
return mA; }
226 PBAT_MINI_READ_WRITE_API(SelfType)
232template <
class TScalar,
int M>
235template <
class TScalar,
int M>
236PBAT_HOST_DEVICE
auto Unit(
auto i)
241template <auto M, auto N,
class TScalar>
242PBAT_HOST_DEVICE
auto FromFlatBuffer(TScalar* buf, std::int64_t bi)
244 return SMatrixView<TScalar, M, N>(buf + M * N * bi);
247template <
class TScalar, CMatrix TIndexMatrix>
248PBAT_HOST_DEVICE
auto FromFlatBuffer(TScalar* buf, TIndexMatrix
const& inds)
250 using IntegerType =
typename TIndexMatrix::ScalarType;
251 static_assert(std::is_integral_v<IntegerType>,
"inds must be matrix of indices");
252 auto constexpr M = TIndexMatrix::kRows;
253 auto constexpr N = TIndexMatrix::kCols;
254 using ScalarType = std::remove_cvref_t<TScalar>;
255 SMatrix<ScalarType, M, N> A{};
257 ForRange<0, N>([&]<
auto j>() { ForRange<0, M>([&]<
auto i>() { A(i, j) = buf[inds(i, j)]; }); });
261template <CMatrix TMatrix>
263ToFlatBuffer(TMatrix
const& A,
typename TMatrix::ScalarType* buf, std::int64_t bi)
265 auto constexpr M = TMatrix::kRows;
266 auto constexpr N = TMatrix::kCols;
267 FromFlatBuffer<M, N>(buf, bi) = A;
270template <CMatrix TMatrix, CMatrix TIndexMatrix>
272ToFlatBuffer(TMatrix
const& A, TIndexMatrix
const& inds,
typename TMatrix::ScalarType* buf)
274 auto constexpr MA = TMatrix::kRows;
275 auto constexpr NA = TMatrix::kCols;
276 auto constexpr MI = TIndexMatrix::kRows;
277 auto constexpr NI = TIndexMatrix::kCols;
278 static_assert(MA == MI or MI == 1,
"A must have same rows as inds or inds is a row vector");
279 static_assert(NA == NI,
"A must have same cols as inds");
281 if constexpr (MA > 1 and MI == 1)
285 ForRange<0, NA>([&]<
auto j>() {
286 ForRange<0, MA>([&]<
auto i>() { buf[MA * inds(0, j) + i] = A(i, j); });
292 [&]<
auto j>() { ForRange<0, MA>([&]<
auto i>() { buf[inds(i, j)] = A(i, j); }); });
296template <auto M, auto N,
class TScalar>
297PBAT_HOST_DEVICE
auto FromBuffers(std::array<TScalar*, M> buf, std::int64_t bi)
299 using ScalarType = std::remove_const_t<TScalar>;
302 ForRange<0, M>([&]<
auto i>() { A.Row(i) = FromFlatBuffer<1, N>(buf[i], bi); });
306template <auto K,
class TScalar, CMatrix TIndexMatrix>
307PBAT_HOST_DEVICE
auto FromBuffers(std::array<TScalar*, K> buf, TIndexMatrix
const& inds)
309 using IntegerType =
typename TIndexMatrix::ScalarType;
310 static_assert(std::is_integral_v<IntegerType>,
"inds must be matrix of indices");
311 auto constexpr M = TIndexMatrix::kRows;
312 auto constexpr N = TIndexMatrix::kCols;
313 using ScalarType = std::remove_cvref_t<TScalar>;
317 [&]<
auto k>() { A.template Slice<M, N>(k * M, 0) = FromFlatBuffer(buf[k], inds); });
321template <CMatrix TMatrix, auto M>
323ToBuffers(TMatrix
const& A, std::array<typename TMatrix::ScalarType*, M> buf, std::int64_t bi)
325 static_assert(M == TMatrix::kRows,
"A must have same rows as number of buffers");
326 auto constexpr N = TMatrix::kCols;
328 ForRange<0, M>([&]<
auto i>() { FromFlatBuffer<1, N>(buf[i], bi) = A.Row(i); });
331template <CMatrix TMatrix, CMatrix TIndexMatrix, auto K>
332PBAT_HOST_DEVICE
void ToBuffers(
334 TIndexMatrix
const& inds,
335 std::array<typename TMatrix::ScalarType*, K> buf)
337 auto constexpr MA = TMatrix::kRows;
338 auto constexpr NA = TMatrix::kCols;
339 auto constexpr MI = TIndexMatrix::kRows;
340 auto constexpr NI = TIndexMatrix::kCols;
341 static_assert(MA % MI == 0,
"Rows of A must be multiple of rows of inds");
342 static_assert(NA == NI,
"A and inds must have same number of columns");
343 static_assert(MA / MI == K,
"A must have number of rows == #buffers*#rows of inds");
346 [&]<
auto k>() { ToFlatBuffer(A.template Slice<MI, NI>(k * MI, 0), inds, buf[k]); });
constexpr void ForRange(F &&f)
Compile-time for loop over a range of values.
Definition ConstexprFor.h:55
Mini linear algebra related functionality.
Definition Assign.h:12
Linear Algebra related functionality.
Definition FilterEigenvalues.h:7
Math related functionality.
Definition Concepts.h:19
The main namespace of the library.
Definition Aliases.h:15