1#ifndef PBAT_MATH_LINALG_MINI_BINARYOPERATIONS_H
2#define PBAT_MATH_LINALG_MINI_BINARYOPERATIONS_H
#include "pbat/HostDevice.h"

#include <algorithm>
#include <functional>
#include <type_traits>
#include <utility>
19template <
class TLhsMatrix,
class TRhsMatrix>
23 using LhsNestedType = TLhsMatrix;
24 using RhsNestedType = TRhsMatrix;
26 using ScalarType =
typename LhsNestedType::ScalarType;
27 using SelfType = Sum<LhsNestedType, RhsNestedType>;
29 static auto constexpr kRows = LhsNestedType::kRows;
30 static auto constexpr kCols = RhsNestedType::kCols;
31 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
33 PBAT_HOST_DEVICE Sum(LhsNestedType
const& _A, RhsNestedType
const& _B) : A(_A), B(_B)
36 LhsNestedType::kRows == RhsNestedType::kRows and
37 LhsNestedType::kCols == RhsNestedType::kCols,
38 "Invalid matrix sum dimensions");
41 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const {
return A(i, j) + B(i, j); }
44 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
45 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
47 PBAT_MINI_READ_API(SelfType)
50 LhsNestedType
const& A;
51 RhsNestedType
const& B;
54template <
class TLhsMatrix>
58 using LhsNestedType = TLhsMatrix;
60 using ScalarType =
typename LhsNestedType::ScalarType;
61 using SelfType = SumScalar<LhsNestedType>;
63 static auto constexpr kRows = LhsNestedType::kRows;
64 static auto constexpr kCols = LhsNestedType::kCols;
65 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
67 PBAT_HOST_DEVICE SumScalar(LhsNestedType
const& A, ScalarType k) : mA(A), mK(k) {}
69 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const {
return mA(i, j) + mK; }
72 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
73 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
75 PBAT_MINI_READ_API(SelfType)
78 LhsNestedType
const& mA;
82template <
class TLhsMatrix,
class TRhsMatrix>
86 using LhsNestedType = TLhsMatrix;
87 using RhsNestedType = TRhsMatrix;
89 using ScalarType =
typename LhsNestedType::ScalarType;
90 using SelfType = Subtraction<LhsNestedType, RhsNestedType>;
92 static auto constexpr kRows = LhsNestedType::kRows;
93 static auto constexpr kCols = RhsNestedType::kCols;
94 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
96 PBAT_HOST_DEVICE Subtraction(LhsNestedType
const& _A, RhsNestedType
const& _B) : A(_A), B(_B)
99 LhsNestedType::kRows == RhsNestedType::kRows and
100 LhsNestedType::kCols == RhsNestedType::kCols,
101 "Invalid matrix sum dimensions");
104 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const {
return A(i, j) - B(i, j); }
107 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
108 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
110 PBAT_MINI_READ_API(SelfType)
113 LhsNestedType
const& A;
114 RhsNestedType
const& B;
117template <
class TLhsMatrix>
118class SubtractionScalar
121 using LhsNestedType = TLhsMatrix;
122 using ScalarType =
typename LhsNestedType::ScalarType;
123 using SelfType = SubtractionScalar<LhsNestedType>;
125 static auto constexpr kRows = LhsNestedType::kRows;
126 static auto constexpr kCols = LhsNestedType::kCols;
127 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
129 PBAT_HOST_DEVICE SubtractionScalar(LhsNestedType
const& A, ScalarType k) : mA(A), mK(k) {}
131 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const {
return mA(i, j) - mK; }
134 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
135 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
137 PBAT_MINI_READ_API(SelfType)
140 LhsNestedType
const& mA;
144template <
class TLhsMatrix,
class TRhsMatrix>
148 using LhsNestedType = TLhsMatrix;
149 using RhsNestedType = TRhsMatrix;
151 using ScalarType =
typename LhsNestedType::ScalarType;
152 using SelfType = Minimum<LhsNestedType, RhsNestedType>;
154 static auto constexpr kRows = LhsNestedType::kRows;
155 static auto constexpr kCols = RhsNestedType::kCols;
156 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
158 PBAT_HOST_DEVICE Minimum(LhsNestedType
const& _A, RhsNestedType
const& _B) : A(_A), B(_B)
161 LhsNestedType::kRows == RhsNestedType::kRows and
162 LhsNestedType::kCols == RhsNestedType::kCols,
163 "Invalid matrix minimum dimensions");
166 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
169 return min(A(i, j), B(i, j));
173 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
174 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
176 PBAT_MINI_READ_API(SelfType)
179 LhsNestedType
const& A;
180 RhsNestedType
const& B;
183template <
class TLhsMatrix,
class TRhsMatrix>
187 using LhsNestedType = TLhsMatrix;
188 using RhsNestedType = TRhsMatrix;
190 using ScalarType =
typename LhsNestedType::ScalarType;
191 using SelfType = Maximum<LhsNestedType, RhsNestedType>;
193 static auto constexpr kRows = LhsNestedType::kRows;
194 static auto constexpr kCols = RhsNestedType::kCols;
195 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
197 PBAT_HOST_DEVICE Maximum(LhsNestedType
const& _A, RhsNestedType
const& _B) : A(_A), B(_B)
200 LhsNestedType::kRows == RhsNestedType::kRows and
201 LhsNestedType::kCols == RhsNestedType::kCols,
202 "Invalid matrix maximum dimensions");
205 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
208 return max(A(i, j), B(i, j));
212 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
213 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
215 PBAT_MINI_READ_API(SelfType)
218 LhsNestedType
const& A;
219 RhsNestedType
const& B;
222template <
class TMatrix,
class Compare>
223class MatrixScalarPredicate
226 using CompareType = Compare;
227 using NestedType = TMatrix;
228 using ScalarType = bool;
229 using SelfType = MatrixScalarPredicate<NestedType, CompareType>;
231 static auto constexpr kRows = NestedType::kRows;
232 static auto constexpr kCols = NestedType::kCols;
233 static bool constexpr bRowMajor = NestedType::bRowMajor;
236 MatrixScalarPredicate(NestedType
const& A,
typename NestedType::ScalarType k, CompareType comp)
237 : mA(A), mK(k), mComparator(comp)
241 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
243 return mComparator(mA(i, j), mK);
247 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
248 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
250 PBAT_MINI_READ_API(SelfType)
253 NestedType
const& mA;
254 typename NestedType::ScalarType mK;
255 CompareType mComparator;
258template <
class TLhsMatrix,
class TRhsMatrix,
class Compare>
259class MatrixMatrixPredicate
262 using CompareType = Compare;
263 using LhsNestedType = TLhsMatrix;
264 using RhsNestedType = TRhsMatrix;
265 using ScalarType = bool;
266 using SelfType = MatrixMatrixPredicate<LhsNestedType, RhsNestedType, CompareType>;
268 static auto constexpr kRows = LhsNestedType::kRows;
269 static auto constexpr kCols = LhsNestedType::kCols;
270 static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
273 MatrixMatrixPredicate(LhsNestedType
const& A, RhsNestedType
const& B, CompareType comp)
274 : mA(A), mB(B), mComparator(comp)
277 LhsNestedType::kRows == RhsNestedType::kRows and
278 LhsNestedType::kCols == RhsNestedType::kCols,
279 "A and B must have same dimensions");
282 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
284 return mComparator(mA(i, j), mB(i, j));
288 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
289 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
291 PBAT_MINI_READ_API(SelfType)
294 LhsNestedType
const& mA;
295 RhsNestedType
const& mB;
296 CompareType mComparator;
299template <
class TLhsMatrix,
class TRhsMatrix>
300PBAT_HOST_DEVICE
auto operator+(TLhsMatrix&& A, TRhsMatrix&& B)
302 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
303 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
304 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
310 return Sum<LhsMatrixType, RhsMatrixType>(
311 std::forward<TLhsMatrix>(A),
312 std::forward<TRhsMatrix>(B));
/**
 * @brief In-place addition `A += B`.
 *
 * If @p B is an arithmetic scalar (after stripping cv/ref qualifiers), this calls
 * `AddAssignScalar(A, B)`. Otherwise it calls `AddAssign(A, B)`. Both arguments are
 * perfectly forwarded. Both helpers are defined elsewhere; they are presumably in Assign.h.
 *
 * NOTE(review): this listing has lost lines. The embedded numbering skips the braces, the
 * `else`, and one line just before the closing brace. The `auto` return type is deduced
 * from whatever return statement that lost line holds, presumably `return A;`, which would
 * return a decayed copy of A. Confirm this against the original source before relying on the
 * return value.
 */
316template <
class TLhsMatrix,
class TRhsMatrix>
317PBAT_HOST_DEVICE
auto operator+=(TLhsMatrix&& A, TRhsMatrix&& B)
// Strip cv/ref so that e.g. `double const&` still counts as arithmetic
319 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
320 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
// Scalar right operand
322 AddAssignScalar(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
// Matrix right operand (the else-branch)
326 AddAssign(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
331template <
class TLhsMatrix,
class TRhsMatrix>
332PBAT_HOST_DEVICE
auto operator-(TLhsMatrix&& A, TRhsMatrix&& B)
334 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
335 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
336 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
339 std::forward<TLhsMatrix>(A),
340 std::forward<TRhsMatrix>(B));
345 std::forward<TLhsMatrix>(A),
346 std::forward<TRhsMatrix>(B));
/**
 * @brief In-place subtraction `A -= B`.
 *
 * If @p B is an arithmetic scalar (after stripping cv/ref qualifiers), this calls
 * `SubtractAssignScalar(A, B)`. Otherwise it calls `SubtractAssign(A, B)`. Both arguments
 * are perfectly forwarded. Both helpers are defined elsewhere; they are presumably in Assign.h.
 *
 * NOTE(review): this listing has lost lines. The embedded numbering skips the braces, the
 * `else`, and one line just before the closing brace. The `auto` return type is deduced
 * from whatever return statement that lost line holds, presumably `return A;`, which would
 * return a decayed copy of A. Confirm this against the original source before relying on the
 * return value.
 */
350template <
class TLhsMatrix,
class TRhsMatrix>
351PBAT_HOST_DEVICE
auto operator-=(TLhsMatrix&& A, TRhsMatrix&& B)
// Strip cv/ref so that e.g. `double const&` still counts as arithmetic
353 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
354 if constexpr (std::is_arithmetic_v<RhsMatrixType>)
// Scalar right operand
356 SubtractAssignScalar(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
// Matrix right operand (the else-branch)
360 SubtractAssign(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
365template <
class TLhsMatrix,
class TRhsMatrix>
366PBAT_HOST_DEVICE
auto Min(TLhsMatrix&& A, TRhsMatrix&& B)
368 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
369 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
371 std::forward<TLhsMatrix>(A),
372 std::forward<TRhsMatrix>(B));
375template <
class TLhsMatrix,
class TRhsMatrix>
376PBAT_HOST_DEVICE
auto Max(TLhsMatrix&& A, TRhsMatrix&& B)
378 using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
379 using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
381 std::forward<TLhsMatrix>(A),
382 std::forward<TRhsMatrix>(B));
/**
 * @brief Defines `Operator(A, k)` for a CMatrix @p A and a scalar @p k of its ScalarType.
 *
 * The generated function returns a MatrixScalarPredicate whose element (i, j) is
 * `Comparator<ScalarType>{}(A(i, j), k)`.
 *
 * @param Operator Name of the function/operator to define (e.g. `operator<`)
 * @param Comparator Binary function-object template (e.g. `std::less`)
 */
#define PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(Operator, Comparator)                \
    template <CMatrix TMatrix>                                                       \
    PBAT_HOST_DEVICE auto Operator(TMatrix const& A, typename TMatrix::ScalarType k) \
    {                                                                                \
        using ScalarType  = typename TMatrix::ScalarType;                            \
        using CompareType = Comparator<ScalarType>;                                  \
        return MatrixScalarPredicate<TMatrix, CompareType>(A, k, CompareType{});     \
    }
/**
 * @brief Defines `Operator(A, B)` for two CMatrix operands.
 *
 * The generated function returns a MatrixMatrixPredicate whose element (i, j) is
 * `Comparator<ScalarType>{}(A(i, j), B(i, j))`, where ScalarType is the left operand's
 * ScalarType.
 *
 * @param Operator Name of the function/operator to define (e.g. `operator==`)
 * @param Comparator Binary function-object template (e.g. `std::equal_to`)
 */
#define PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(Operator, Comparator)                          \
    template <CMatrix TLhsMatrix, CMatrix TRhsMatrix>                                          \
    PBAT_HOST_DEVICE auto Operator(TLhsMatrix const& A, TRhsMatrix const& B)                   \
    {                                                                                          \
        using ScalarType  = typename TLhsMatrix::ScalarType;                                   \
        using CompareType = Comparator<ScalarType>;                                            \
        return MatrixMatrixPredicate<TLhsMatrix, TRhsMatrix, CompareType>(A, B, CompareType{}); \
    }
403PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator<, std::less)
404PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator>, std::greater)
405PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator==, std::equal_to)
406PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator!=, std::not_equal_to)
407PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator<=, std::less_equal)
408PBAT_MINI_DEFINE_MATRIX_SCALAR_PREDICATE(
operator>=, std::greater_equal)
410PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator<, std::less)
411PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator>, std::greater)
412PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator==, std::equal_to)
413PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator!=, std::not_equal_to)
414PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator<=, std::less_equal)
415PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator>=, std::greater_equal)
417PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator&&, std::logical_and)
418PBAT_MINI_DEFINE_MATRIX_MATRIX_PREDICATE(
operator||, std::logical_or)
Definition BinaryOperations.h:185
Definition BinaryOperations.h:146
Definition BinaryOperations.h:84
Definition BinaryOperations.h:119
Definition BinaryOperations.h:56
Mini linear algebra related functionality.
Definition Assign.h:12
Linear Algebra related functionality.
Definition FilterEigenvalues.h:7
Math related functionality.
Definition Concepts.h:19
The main namespace of the library.
Definition Aliases.h:15