PhysicsBasedAnimationToolkit 0.0.10
Cross-platform C++20 library of algorithms and data structures commonly used in computer graphics research on physically-based simulation.
Loading...
Searching...
No Matches
BinaryOperations.h
1#ifndef PBAT_MATH_LINALG_MINI_BINARYOPERATIONS_H
2#define PBAT_MATH_LINALG_MINI_BINARYOPERATIONS_H
3
#include "Api.h"
#include "Concepts.h"
#include "Scale.h"
#include "pbat/HostDevice.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <type_traits>
#include <utility>
13
14namespace pbat {
15namespace math {
16namespace linalg {
17namespace mini {
18
19template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
20class Sum
21{
22 public:
23 using LhsNestedType = TLhsMatrix;
24 using RhsNestedType = TRhsMatrix;
25
26 using ScalarType = typename LhsNestedType::ScalarType;
27 using SelfType = Sum<LhsNestedType, RhsNestedType>;
28
29 static auto constexpr kRows = LhsNestedType::kRows;
30 static auto constexpr kCols = RhsNestedType::kCols;
31 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
32
33 PBAT_HOST_DEVICE Sum(LhsNestedType const& _A, RhsNestedType const& _B) : A(_A), B(_B)
34 {
35 static_assert(
36 LhsNestedType::kRows == RhsNestedType::kRows and
37 LhsNestedType::kCols == RhsNestedType::kCols,
38 "Invalid matrix sum dimensions");
39 }
40
41 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return A(i, j) + B(i, j); }
42
43 // Vector(ized) access
44 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
45 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
46
47 PBAT_MINI_READ_API(SelfType)
48
49 private:
50 LhsNestedType const& A;
51 RhsNestedType const& B;
52};
53
54template <class /*CMatrix*/ TLhsMatrix>
55class SumScalar
56{
57 public:
58 using LhsNestedType = TLhsMatrix;
59
60 using ScalarType = typename LhsNestedType::ScalarType;
61 using SelfType = SumScalar<LhsNestedType>;
62
63 static auto constexpr kRows = LhsNestedType::kRows;
64 static auto constexpr kCols = LhsNestedType::kCols;
65 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
66
67 PBAT_HOST_DEVICE SumScalar(LhsNestedType const& A, ScalarType k) : mA(A), mK(k) {}
68
69 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return mA(i, j) + mK; }
70
71 // Vector(ized) access
72 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
73 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
74
75 PBAT_MINI_READ_API(SelfType)
76
77 private:
78 LhsNestedType const& mA;
79 ScalarType mK;
80};
81
82template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
83class Subtraction
84{
85 public:
86 using LhsNestedType = TLhsMatrix;
87 using RhsNestedType = TRhsMatrix;
88
89 using ScalarType = typename LhsNestedType::ScalarType;
90 using SelfType = Subtraction<LhsNestedType, RhsNestedType>;
91
92 static auto constexpr kRows = LhsNestedType::kRows;
93 static auto constexpr kCols = RhsNestedType::kCols;
94 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
95
96 PBAT_HOST_DEVICE Subtraction(LhsNestedType const& _A, RhsNestedType const& _B) : A(_A), B(_B)
97 {
98 static_assert(
99 LhsNestedType::kRows == RhsNestedType::kRows and
100 LhsNestedType::kCols == RhsNestedType::kCols,
101 "Invalid matrix sum dimensions");
102 }
103
104 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return A(i, j) - B(i, j); }
105
106 // Vector(ized) access
107 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
108 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
109
110 PBAT_MINI_READ_API(SelfType)
111
112 private:
113 LhsNestedType const& A;
114 RhsNestedType const& B;
115};
116
117template <class /*CMatrix*/ TLhsMatrix>
118class SubtractionScalar
119{
120 public:
121 using LhsNestedType = TLhsMatrix;
122 using ScalarType = typename LhsNestedType::ScalarType;
123 using SelfType = SubtractionScalar<LhsNestedType>;
124
125 static auto constexpr kRows = LhsNestedType::kRows;
126 static auto constexpr kCols = LhsNestedType::kCols;
127 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
128
129 PBAT_HOST_DEVICE SubtractionScalar(LhsNestedType const& A, ScalarType k) : mA(A), mK(k) {}
130
131 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return mA(i, j) - mK; }
132
133 // Vector(ized) access
134 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
135 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
136
137 PBAT_MINI_READ_API(SelfType)
138
139 private:
140 LhsNestedType const& mA;
141 ScalarType mK;
142};
143
144template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
145class Minimum
146{
147 public:
148 using LhsNestedType = TLhsMatrix;
149 using RhsNestedType = TRhsMatrix;
150
151 using ScalarType = typename LhsNestedType::ScalarType;
152 using SelfType = Minimum<LhsNestedType, RhsNestedType>;
153
154 static auto constexpr kRows = LhsNestedType::kRows;
155 static auto constexpr kCols = RhsNestedType::kCols;
156 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
157
158 PBAT_HOST_DEVICE Minimum(LhsNestedType const& _A, RhsNestedType const& _B) : A(_A), B(_B)
159 {
160 static_assert(
161 LhsNestedType::kRows == RhsNestedType::kRows and
162 LhsNestedType::kCols == RhsNestedType::kCols,
163 "Invalid matrix minimum dimensions");
164 }
165
166 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const
167 {
168 using namespace std;
169 return min(A(i, j), B(i, j));
170 }
171
172 // Vector(ized) access
173 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
174 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
175
176 PBAT_MINI_READ_API(SelfType)
177
178 private:
179 LhsNestedType const& A;
180 RhsNestedType const& B;
181};
182
183template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
184class Maximum
185{
186 public:
187 using LhsNestedType = TLhsMatrix;
188 using RhsNestedType = TRhsMatrix;
189
190 using ScalarType = typename LhsNestedType::ScalarType;
191 using SelfType = Maximum<LhsNestedType, RhsNestedType>;
192
193 static auto constexpr kRows = LhsNestedType::kRows;
194 static auto constexpr kCols = RhsNestedType::kCols;
195 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
196
197 PBAT_HOST_DEVICE Maximum(LhsNestedType const& _A, RhsNestedType const& _B) : A(_A), B(_B)
198 {
199 static_assert(
200 LhsNestedType::kRows == RhsNestedType::kRows and
201 LhsNestedType::kCols == RhsNestedType::kCols,
202 "Invalid matrix maximum dimensions");
203 }
204
205 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const
206 {
207 using namespace std;
208 return max(A(i, j), B(i, j));
209 }
210
211 // Vector(ized) access
212 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
213 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
214
215 PBAT_MINI_READ_API(SelfType)
216
217 private:
218 LhsNestedType const& A;
219 RhsNestedType const& B;
220};
221
222template <class /*CMatrix*/ TMatrix, class Compare>
223class MatrixScalarPredicate
224{
225 public:
226 using CompareType = Compare;
227 using NestedType = TMatrix;
228 using ScalarType = bool;
229 using SelfType = MatrixScalarPredicate<NestedType, CompareType>;
230
231 static auto constexpr kRows = NestedType::kRows;
232 static auto constexpr kCols = NestedType::kCols;
233 static bool constexpr bRowMajor = NestedType::bRowMajor;
234
235 PBAT_HOST_DEVICE
236 MatrixScalarPredicate(NestedType const& A, typename NestedType::ScalarType k, CompareType comp)
237 : mA(A), mK(k), mComparator(comp)
238 {
239 }
240
241 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const
242 {
243 return mComparator(mA(i, j), mK);
244 }
245
246 // Vector(ized) access
247 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
248 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
249
250 PBAT_MINI_READ_API(SelfType)
251
252 private:
253 NestedType const& mA;
254 typename NestedType::ScalarType mK;
255 CompareType mComparator;
256};
257
258template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix, class Compare>
259class MatrixMatrixPredicate
260{
261 public:
262 using CompareType = Compare;
263 using LhsNestedType = TLhsMatrix;
264 using RhsNestedType = TRhsMatrix;
265 using ScalarType = bool;
266 using SelfType = MatrixMatrixPredicate<LhsNestedType, RhsNestedType, CompareType>;
267
268 static auto constexpr kRows = LhsNestedType::kRows;
269 static auto constexpr kCols = LhsNestedType::kCols;
270 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
271
272 PBAT_HOST_DEVICE
273 MatrixMatrixPredicate(LhsNestedType const& A, RhsNestedType const& B, CompareType comp)
274 : mA(A), mB(B), mComparator(comp)
275 {
276 static_assert(
277 LhsNestedType::kRows == RhsNestedType::kRows and
278 LhsNestedType::kCols == RhsNestedType::kCols,
279 "A and B must have same dimensions");
280 }
281
282 PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const
283 {
284 return mComparator(mA(i, j), mB(i, j));
285 }
286
287 // Vector(ized) access
288 PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
289 PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
290
291 PBAT_MINI_READ_API(SelfType)
292
293 private:
294 LhsNestedType const& mA;
295 RhsNestedType const& mB;
296 CompareType mComparator;
297};
298
299template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
300PBAT_HOST_DEVICE auto operator+(TLhsMatrix&& A, TRhsMatrix&& B)
301{
302 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
303 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
304 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
305 {
306 return SumScalar<LhsMatrixType>(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
307 }
308 else
309 {
310 return Sum<LhsMatrixType, RhsMatrixType>(
311 std::forward<TLhsMatrix>(A),
312 std::forward<TRhsMatrix>(B));
313 }
314}
315
316template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
317PBAT_HOST_DEVICE auto operator+=(TLhsMatrix&& A, TRhsMatrix&& B)
318{
319 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
320 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
321 {
322 AddAssignScalar(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
323 }
324 else
325 {
326 AddAssign(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
327 }
328 return A;
329}
330
331template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
332PBAT_HOST_DEVICE auto operator-(TLhsMatrix&& A, TRhsMatrix&& B)
333{
334 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
335 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
336 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
337 {
339 std::forward<TLhsMatrix>(A),
340 std::forward<TRhsMatrix>(B));
341 }
342 else
343 {
345 std::forward<TLhsMatrix>(A),
346 std::forward<TRhsMatrix>(B));
347 }
348}
349
350template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
351PBAT_HOST_DEVICE auto operator-=(TLhsMatrix&& A, TRhsMatrix&& B)
352{
353 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
354 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
355 {
356 SubtractAssignScalar(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
357 }
358 else
359 {
360 SubtractAssign(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
361 }
362 return A;
363}
364
365template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
366PBAT_HOST_DEVICE auto Min(TLhsMatrix&& A, TRhsMatrix&& B)
367{
368 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
369 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
371 std::forward<TLhsMatrix>(A),
372 std::forward<TRhsMatrix>(B));
373}
374
375template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
376PBAT_HOST_DEVICE auto Max(TLhsMatrix&& A, TRhsMatrix&& B)
377{
378 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
379 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
381 std::forward<TLhsMatrix>(A),
382 std::forward<TRhsMatrix>(B));
383}
384
/**
 * @brief Defines `Operator(A, k)`, returning a lazy MatrixScalarPredicate whose entry (i,j) is
 * `Comparator<S>{}(A(i,j), k)`, with S the matrix's scalar type.
 * @param Operator Name of the function/operator to define (e.g. `operator<`)
 * @param Comparator Binary predicate class template (e.g. `std::less`)
 */
#define PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(Operator, Comparator)                  \
    template <CMatrix TMatrix>                                                          \
    PBAT_HOST_DEVICE auto Operator(TMatrix const& A, typename TMatrix::ScalarType k)    \
    {                                                                                   \
        using PredicateType =                                                           \
            MatrixScalarPredicate<TMatrix, Comparator<typename TMatrix::ScalarType>>;   \
        return PredicateType(A, k, typename PredicateType::CompareType{});              \
    }
393
/**
 * @brief Defines `Operator(A, B)`, returning a lazy MatrixMatrixPredicate whose entry (i,j) is
 * `Comparator<S>{}(A(i,j), B(i,j))`, with S the left operand's scalar type.
 * @param Operator Name of the function/operator to define (e.g. `operator==`)
 * @param Comparator Binary predicate class template (e.g. `std::equal_to`)
 */
#define PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(Operator, Comparator)                  \
    template <CMatrix TLhsMatrix, CMatrix TRhsMatrix>                                   \
    PBAT_HOST_DEVICE auto Operator(TLhsMatrix const& A, TRhsMatrix const& B)            \
    {                                                                                   \
        using PredicateType = MatrixMatrixPredicate<                                    \
            TLhsMatrix,                                                                 \
            TRhsMatrix,                                                                 \
            Comparator<typename TLhsMatrix::ScalarType>>;                               \
        return PredicateType(A, B, typename PredicateType::CompareType{});              \
    }
402
403PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator<, std::less)
404PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator>, std::greater)
405PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator==, std::equal_to)
406PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator!=, std::not_equal_to)
407PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator<=, std::less_equal)
408PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(operator>=, std::greater_equal)
409
410PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator<, std::less)
411PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator>, std::greater)
412PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator==, std::equal_to)
413PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator!=, std::not_equal_to)
414PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator<=, std::less_equal)
415PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator>=, std::greater_equal)
416
417PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator&&, std::logical_and)
418PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(operator||, std::logical_or)
419
420} // namespace mini
421} // namespace linalg
422} // namespace math
423} // namespace pbat
424
425#endif // PBAT_MATH_LINALG_MINI_BINARYOPERATIONS_H
Definition BinaryOperations.h:185
Definition BinaryOperations.h:146
Definition BinaryOperations.h:84
Definition BinaryOperations.h:119
Definition BinaryOperations.h:56
Mini linear algebra related functionality.
Definition Assign.h:12
Linear Algebra related functionality.
Definition FilterEigenvalues.h:7
Math related functionality.
Definition Concepts.h:19
The main namespace of the library.
Definition Aliases.h:15