1#ifndef PBAT_MATH_LINALG_MINI_REDUCTIONS_H
2#define PBAT_MATH_LINALG_MINI_REDUCTIONS_H
7#include "pbat/HostDevice.h"
// ConstDiagonal: read-only view that exposes the diagonal of a mini matrix
// as a kRows x 1 column vector (kCols = 1, bRowMajor = false), where
// element i reads A(i, i).
//
// NOTE(review): this listing is missing several source lines of the class
// (the class head, the body of the two-argument operator(), and the member
// declarations); the comments below describe only what is visible.
17template <
class TMatrix>
21 using NestedType = TMatrix;
22 using ScalarType =
typename NestedType::ScalarType;
23 using SelfType = ConstDiagonal<NestedType>;
// The diagonal length is taken from the row count of the viewed matrix.
// NOTE(review): for a non-square TMatrix with kRows > kCols, A(i, i) would
// index past the last column; confirm callers only pass square matrices
// (or consider min(kRows, kCols)).
25 static auto constexpr kRows = NestedType::kRows;
26 static auto constexpr kCols = 1;
27 static bool constexpr bRowMajor =
false;
// Binds the view to A. NOTE(review): the declaration of mA is not visible
// here; if it is stored as a reference, the view must not outlive A.
29 PBAT_HOST_DEVICE ConstDiagonal(NestedType
const& A) : mA(A) {}
// Matrix-style (i, j) read. j is marked [[maybe_unused]], which suggests it
// is ignored; the body is not visible in this listing (presumably it
// returns mA(i, i)) -- confirm.
31 PBAT_HOST_DEVICE ScalarType operator()(
auto i, [[maybe_unused]]
auto j)
const
// Vector-style read: element i of the view is A(i, i), returned by value.
37 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return mA(i, i); }
// Subscript read; forwards to operator()(i).
38 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
// Read-only mini-matrix API for this view type (macro not defined in this listing).
40 PBAT_MINI_READ_API(SelfType)
// Diagonal: mutable view that exposes the diagonal of a mini matrix as a
// kRows x 1 column vector (kCols = 1, bRowMajor = false). Const accessors
// return A(i, i) by value; non-const accessors return ScalarType& obtained
// from mA(i, i).
//
// NOTE(review): this listing is missing several source lines of the class
// (the class head, the body of the const two-argument operator(), and the
// member declarations); the comments below describe only what is visible.
46template <
class TMatrix>
50 using NestedType = TMatrix;
51 using ScalarType =
typename NestedType::ScalarType;
52 using SelfType = Diagonal<NestedType>;
// The diagonal length is taken from the row count of the viewed matrix.
// NOTE(review): for a non-square TMatrix with kRows > kCols, A(i, i) would
// index past the last column; confirm callers only pass square matrices.
54 static auto constexpr kRows = NestedType::kRows;
55 static auto constexpr kCols = 1;
56 static bool constexpr bRowMajor =
false;
// Binds the view to a non-const A. NOTE(review): the declaration of mA is
// not visible here; if it is stored as a reference, the view must not
// outlive A.
58 PBAT_HOST_DEVICE Diagonal(NestedType& A) : mA(A) {}
// Const matrix-style (i, j) read. j is marked [[maybe_unused]], which
// suggests it is ignored; the body is not visible in this listing
// (presumably it returns mA(i, i)) -- confirm.
60 PBAT_HOST_DEVICE ScalarType operator()(
auto i, [[maybe_unused]]
auto j)
const
// Mutable matrix-style (i, j) access; j is not used and the reference
// returned by mA(i, i) is handed back.
64 PBAT_HOST_DEVICE ScalarType& operator()(
auto i, [[maybe_unused]]
auto j) {
return mA(i, i); }
// Vector-style reads: element i is A(i, i), returned by value.
67 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return mA(i, i); }
68 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
// Vector-style mutable access: returns the reference produced by mA(i, i);
// operator[] forwards to operator()(i).
69 PBAT_HOST_DEVICE ScalarType& operator()(
auto i) {
return mA(i, i); }
70 PBAT_HOST_DEVICE ScalarType& operator[](
auto i) {
return (*
this)(i); }
// Calls AssignScalar(*this, k) (declared elsewhere); presumably sets every
// diagonal entry of the viewed matrix to k -- confirm.
72 PBAT_HOST_DEVICE
void SetConstant(
auto k) { AssignScalar(*
this, k); }
// Read/write mini-matrix API for this view type (macro not defined in this listing).
74 PBAT_MINI_READ_WRITE_API(SelfType)
// Diag (const overload): takes a const CMatrix and is meant to produce a
// read-only diagonal view of it.
// NOTE(review): the function body is not visible in this listing; presumably
// it returns ConstDiagonal<TMatrix>(A), mirroring the non-const overload --
// confirm.
80template <CMatrix TMatrix>
81PBAT_HOST_DEVICE
auto Diag(TMatrix
const& A)
// Diag (non-const overload): returns a Diagonal<TMatrix> view constructed
// from the non-const matrix A, exposing its diagonal as a column vector with
// mutable element access.
// NOTE(review): the braces of the function body are not visible in this listing.
86template <CMatrix TMatrix>
87PBAT_HOST_DEVICE
auto Diag(TMatrix& A)
89 return Diagonal<TMatrix>(A);
// Trace: returns the sum of the diagonal entries A(I, I) for I in [0, kRows).
//
// The sum is unrolled at compile time with a fold expression over
// std::make_integer_sequence, so MatrixType::kRows must be a constant
// expression. The condition kRows == kCols with the message "Cannot compute
// trace of non-square matrix" rejects non-square inputs (presumably inside a
// static_assert whose opening line is not visible here).
// PBAT_MINI_CHECK_CMATRIX is defined elsewhere; presumably it checks that
// MatrixType models CMatrix.
//
// NOTE(review): the unary `+` fold has no initial value, so an empty pack
// (a 0x0 matrix) is ill-formed. Also, std::forward<TMatrix>(A) is applied
// once per diagonal entry; that is harmless for plain element access, but if
// TMatrix were an rvalue with an rvalue-qualified call operator it would
// reuse a forwarded object. Calling A(I, I) directly would avoid the issue.
// NOTE(review): the opening `static_assert(` line and the closing braces are
// not visible in this listing.
92template <
class TMatrix>
93PBAT_HOST_DEVICE
auto Trace(TMatrix&& A)
95 using MatrixType = std::remove_cvref_t<TMatrix>;
96 PBAT_MINI_CHECK_CMATRIX(MatrixType);
98 MatrixType::kRows == MatrixType::kCols,
99 "Cannot compute trace of non-square matrix");
100 using IntegerType = std::remove_const_t<
decltype(MatrixType::kRows)>;
// Expands to A(0, 0) + A(1, 1) + ... + A(kRows - 1, kRows - 1).
101 auto sum = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) {
102 return (std::forward<TMatrix>(A)(I, I) + ...);
104 return sum(std::make_integer_sequence<IntegerType, MatrixType::kRows>{});
// Dot: Frobenius inner product of A and B, computed as
// trace(A^T * B) = sum_{i,j} A(i, j) * B(i, j).
// Both arguments are perfectly forwarded into A.Transpose() * B, and the
// result is whatever Trace returns for that product. Trace requires the
// product to be square.
// NOTE(review): whether Transpose() and operator* build lazy expressions is
// not visible here; if they do not, the full product A^T * B is formed just
// to read its diagonal -- worth confirming on hot paths.
// NOTE(review): the braces of the function body are not visible in this listing.
107template <
class TLhsMatrix,
class TRhsMatrix>
108PBAT_HOST_DEVICE
auto Dot(TLhsMatrix&& A, TRhsMatrix&& B)
110 return Trace(std::forward<TLhsMatrix>(A).Transpose() * std::forward<TRhsMatrix>(B));
// PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(FunctionName, BinaryOp)
//
// Defines `template <CMatrix TMatrix> auto FunctionName(TMatrix const& A)`.
// The generated function converts each entry A(i, j) to bool and combines
// the results with BinaryOp, using compile-time fold expressions over
// std::make_integer_sequence. The traversal follows the storage order:
// - when MatrixType::bRowMajor is true, the outer fold runs over rows and
//   the inner fold over columns;
// - otherwise, columns are outer and rows are inner.
// Because the inner operands are static_cast<bool>, folds over the built-in
// `and` / `or` short-circuit.
//
// Instantiated below as All (BinaryOp `and`) and Any (BinaryOp `or`).
// Assuming the outer folds also use BinaryOp (their operator sits on lines
// not visible here):
// - All(A) is true iff every entry converts to true;
// - Any(A) is true iff at least one entry converts to true.
//
// NOTE(review): several lines of the macro body are not visible in this
// listing: the inner-lambda heads, the operator of the outer folds, the
// else branch, and the closing braces. The last visible macro line still
// ends in a `\` line continuation.
114#define PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(FunctionName, BinaryOp) \
115 template <CMatrix TMatrix> \
116 PBAT_HOST_DEVICE auto FunctionName(TMatrix const& A) \
118 using MatrixType = TMatrix; \
119 using IntegerType = std::remove_const_t<decltype(TMatrix::kRows)>; \
120 if constexpr (MatrixType::bRowMajor) \
123 [&]<IntegerType... J>(IntegerType i, std::integer_sequence<IntegerType, J...>) { \
124 return (static_cast<bool>(A(i, J)) BinaryOp ...); \
126 auto fRows = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) { \
127 return (fCols(I, std::make_integer_sequence<IntegerType, MatrixType::kCols>()) \
130 return fRows(std::make_integer_sequence<IntegerType, MatrixType::kRows>()); \
135 [&]<IntegerType... I>(IntegerType j, std::integer_sequence<IntegerType, I...>) { \
136 return (static_cast<bool>(A(I, j)) BinaryOp ...); \
138 auto fCols = [&]<IntegerType... J>(std::integer_sequence<IntegerType, J...>) { \
139 return (fRows(J, std::make_integer_sequence<IntegerType, MatrixType::kRows>()) \
142 return fCols(std::make_integer_sequence<IntegerType, MatrixType::kCols>()); \
147PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(All, and)
148PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(Any, or)
Definition Reductions.h:19
Mini linear algebra related functionality.
Definition Assign.h:12
Linear Algebra related functionality.
Definition FilterEigenvalues.h:7
Math related functionality.
Definition Concepts.h:19
The main namespace of the library.
Definition Aliases.h:15