PhysicsBasedAnimationToolkit 0.0.10
Cross-platform C++20 library of algorithms and data structures commonly used in computer graphics research on physically-based simulation.
Loading...
Searching...
No Matches
Reductions.h
1#ifndef PBAT_MATH_LINALG_MINI_REDUCTIONS_H
2#define PBAT_MATH_LINALG_MINI_REDUCTIONS_H
3
4#include "Api.h"
5#include "Concepts.h"
6#include "Product.h"
7#include "pbat/HostDevice.h"
8
9#include <type_traits>
10#include <utility>
11
12namespace pbat {
13namespace math {
14namespace linalg {
15namespace mini {
16
17template <class /*CMatrix*/ TMatrix>
18class ConstDiagonal
19{
20 public:
21 using NestedType = TMatrix;
22 using ScalarType = typename NestedType::ScalarType;
23 using SelfType = ConstDiagonal<NestedType>;
24
25 static auto constexpr kRows = NestedType::kRows;
26 static auto constexpr kCols = 1;
27 static bool constexpr bRowMajor = false;
28
29 PBAT_HOST_DEVICE ConstDiagonal(NestedType const& A) : mA(A) {}
30
31 PBAT_HOST_DEVICE ScalarType operator()(auto i, [[maybe_unused]] auto j) const
32 {
33 return mA(i, i);
34 }
35
36 // Vector(ized) access
37 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return mA(i, i); }
38 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
39
40 PBAT_MINI_READ_API(SelfType)
41
42 private:
43 NestedType const& mA;
44};
45
46template <class /*CMatrix*/ TMatrix>
47class Diagonal
48{
49 public:
50 using NestedType = TMatrix;
51 using ScalarType = typename NestedType::ScalarType;
52 using SelfType = Diagonal<NestedType>;
53
54 static auto constexpr kRows = NestedType::kRows;
55 static auto constexpr kCols = 1;
56 static bool constexpr bRowMajor = false;
57
58 PBAT_HOST_DEVICE Diagonal(NestedType& A) : mA(A) {}
59
60 PBAT_HOST_DEVICE ScalarType operator()(auto i, [[maybe_unused]] auto j) const
61 {
62 return mA(i, i);
63 }
64 PBAT_HOST_DEVICE ScalarType& operator()(auto i, [[maybe_unused]] auto j) { return mA(i, i); }
65
66 // Vector(ized) access
67 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return mA(i, i); }
68 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
69 PBAT_HOST_DEVICE ScalarType& operator()(auto i) { return mA(i, i); }
70 PBAT_HOST_DEVICE ScalarType& operator[](auto i) { return (*this)(i); }
71
72 PBAT_HOST_DEVICE void SetConstant(auto k) { AssignScalar(*this, k); }
73
74 PBAT_MINI_READ_WRITE_API(SelfType)
75
76 private:
77 NestedType& mA;
78};
79
80template <CMatrix TMatrix>
81PBAT_HOST_DEVICE auto Diag(TMatrix const& A)
82{
83 return ConstDiagonal<TMatrix>(A);
84}
85
86template <CMatrix TMatrix>
87PBAT_HOST_DEVICE auto Diag(TMatrix& A)
88{
89 return Diagonal<TMatrix>(A);
90}
91
92template <class /*CMatrix*/ TMatrix>
93PBAT_HOST_DEVICE auto Trace(TMatrix&& A)
94{
95 using MatrixType = std::remove_cvref_t<TMatrix>;
96 PBAT_MINI_CHECK_CMATRIX(MatrixType);
97 static_assert(
98 MatrixType::kRows == MatrixType::kCols,
99 "Cannot compute trace of non-square matrix");
100 using IntegerType = std::remove_const_t<decltype(MatrixType::kRows)>;
101 auto sum = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) {
102 return (std::forward<TMatrix>(A)(I, I) + ...);
103 };
104 return sum(std::make_integer_sequence<IntegerType, MatrixType::kRows>{});
105}
106
107template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
108PBAT_HOST_DEVICE auto Dot(TLhsMatrix&& A, TRhsMatrix&& B)
109{
110 return Trace(std::forward<TLhsMatrix>(A).Transpose() * std::forward<TRhsMatrix>(B));
111}
112
113// clang-format off
114#define PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(FunctionName, BinaryOp) \
115 template <CMatrix TMatrix> \
116 PBAT_HOST_DEVICE auto FunctionName(TMatrix const& A) \
117 { \
118 using MatrixType = TMatrix; \
119 using IntegerType = std::remove_const_t<decltype(TMatrix::kRows)>; \
120 if constexpr (MatrixType::bRowMajor) \
121 { \
122 auto fCols = \
123 [&]<IntegerType... J>(IntegerType i, std::integer_sequence<IntegerType, J...>) { \
124 return (static_cast<bool>(A(i, J)) BinaryOp ...); \
125 }; \
126 auto fRows = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) { \
127 return (fCols(I, std::make_integer_sequence<IntegerType, MatrixType::kCols>()) \
128 BinaryOp...); \
129 }; \
130 return fRows(std::make_integer_sequence<IntegerType, MatrixType::kRows>()); \
131 } \
132 else \
133 { \
134 auto fRows = \
135 [&]<IntegerType... I>(IntegerType j, std::integer_sequence<IntegerType, I...>) { \
136 return (static_cast<bool>(A(I, j)) BinaryOp ...); \
137 }; \
138 auto fCols = [&]<IntegerType... J>(std::integer_sequence<IntegerType, J...>) { \
139 return (fRows(J, std::make_integer_sequence<IntegerType, MatrixType::kRows>()) \
140 BinaryOp...); \
141 }; \
142 return fCols(std::make_integer_sequence<IntegerType, MatrixType::kCols>()); \
143 } \
144 }
145// clang-format on
146
147PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(All, and)
148PBAT_MINI_DEFINE_BINARY_PREDICATE_REDUCTION(Any, or)
149
150} // namespace mini
151} // namespace linalg
152} // namespace math
153} // namespace pbat
154
155#endif // PBAT_MATH_LINALG_MINI_REDUCTIONS_H
Definition Reductions.h:19
Mini linear algebra related functionality.
Definition Assign.h:12
Linear Algebra related functionality.
Definition FilterEigenvalues.h:7
Math related functionality.
Definition Concepts.h:19
The main namespace of the library.
Definition Aliases.h:15