/**
 * @brief Defines a function template `FunctionName(A, B)` that applies
 * `A(i, j) Operator B(i, j)` to every coefficient of the mini matrices A and B.
 *
 * Both inputs must satisfy CMatrix and have identical compile-time dimensions
 * (kRows, kCols); both requirements are enforced with static_assert. The row and
 * column loops are unrolled at compile time through fold expressions over
 * std::integer_sequence. Coefficients are visited row by row when both A and B are
 * row-major (bRowMajor), and column by column otherwise.
 *
 * @param FunctionName Name of the generated function template
 * @param Operator Assignment operator applied per coefficient (e.g. =, +=, -=)
 */
#define DefineMatrixMatrixAssign(FunctionName, Operator)                                       \
    template <class TLhsMatrix, class TRhsMatrix>                                              \
    PBAT_HOST_DEVICE void FunctionName(TLhsMatrix&& A, TRhsMatrix&& B)                         \
    {                                                                                          \
        using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;                                 \
        using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;                                 \
        static_assert(CMatrix<LhsMatrixType>, "Left input must satisfy concept CMatrix");      \
        static_assert(CMatrix<RhsMatrixType>, "Right input must satisfy concept CMatrix");     \
        static_assert(                                                                         \
            LhsMatrixType::kRows == RhsMatrixType::kRows and                                   \
                LhsMatrixType::kCols == RhsMatrixType::kCols,                                  \
            "A and B must have same dimensions");                                              \
        using IntegerType = std::remove_const_t<decltype(LhsMatrixType::kRows)>;               \
        /* A and B are named references (lvalues): index them directly instead of */          \
        /* re-forwarding the same objects once per coefficient. */                             \
        if constexpr (LhsMatrixType::bRowMajor and RhsMatrixType::bRowMajor)                   \
        {                                                                                      \
            auto fCols = [&]<IntegerType... J>(                                                \
                             IntegerType i,                                                    \
                             std::integer_sequence<IntegerType, J...>) {                       \
                ((A(i, J) Operator B(i, J)), ...);                                             \
            };                                                                                 \
            auto fRows = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) {     \
                (fCols(I, std::make_integer_sequence<IntegerType, LhsMatrixType::kCols>()),    \
                 ...);                                                                         \
            };                                                                                 \
            fRows(std::make_integer_sequence<IntegerType, LhsMatrixType::kRows>());            \
        }                                                                                      \
        else                                                                                   \
        {                                                                                      \
            auto fRows = [&]<IntegerType... I>(                                                \
                             IntegerType j,                                                    \
                             std::integer_sequence<IntegerType, I...>) {                       \
                ((A(I, j) Operator B(I, j)), ...);                                             \
            };                                                                                 \
            auto fCols = [&]<IntegerType... J>(std::integer_sequence<IntegerType, J...>) {     \
                (fRows(J, std::make_integer_sequence<IntegerType, LhsMatrixType::kRows>()),    \
                 ...);                                                                         \
            };                                                                                 \
            fCols(std::make_integer_sequence<IntegerType, LhsMatrixType::kCols>());            \
        }                                                                                      \
    }
/**
 * @brief Defines a function template `FunctionName(A, k)` that applies
 * `A(i, j) Operator k` to every coefficient of the mini matrix A.
 *
 * A must satisfy CMatrix (enforced with static_assert). The scalar k is taken by value
 * as A's ScalarType. Because `typename ...::ScalarType` is a non-deduced context,
 * TLhsMatrix is deduced from A alone and k is converted to ScalarType at the call site.
 * The row and column loops are unrolled at compile time through fold expressions over
 * std::integer_sequence. Coefficients are visited row by row when A is row-major
 * (bRowMajor), and column by column otherwise.
 *
 * @param FunctionName Name of the generated function template
 * @param Operator Assignment operator applied per coefficient (e.g. =, +=, -=, *=, /=)
 */
#define DefineMatrixScalarAssign(FunctionName, Operator)                                       \
    template <class TLhsMatrix>                                                                \
    PBAT_HOST_DEVICE void FunctionName(                                                        \
        TLhsMatrix&& A,                                                                        \
        typename std::remove_cvref_t<TLhsMatrix>::ScalarType k)                                \
    {                                                                                          \
        using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;                                 \
        static_assert(CMatrix<LhsMatrixType>, "Left input must satisfy concept CMatrix");      \
        using IntegerType = std::remove_const_t<decltype(LhsMatrixType::kRows)>;               \
        /* A is a named reference (an lvalue): index it directly instead of */                 \
        /* re-forwarding the same object once per coefficient. */                              \
        if constexpr (LhsMatrixType::bRowMajor)                                                \
        {                                                                                      \
            auto fCols = [&]<IntegerType... J>(                                                \
                             IntegerType i,                                                    \
                             std::integer_sequence<IntegerType, J...>) {                       \
                ((A(i, J) Operator k), ...);                                                   \
            };                                                                                 \
            auto fRows = [&]<IntegerType... I>(std::integer_sequence<IntegerType, I...>) {     \
                (fCols(I, std::make_integer_sequence<IntegerType, LhsMatrixType::kCols>()),    \
                 ...);                                                                         \
            };                                                                                 \
            fRows(std::make_integer_sequence<IntegerType, LhsMatrixType::kRows>());            \
        }                                                                                      \
        else                                                                                   \
        {                                                                                      \
            auto fRows = [&]<IntegerType... I>(                                                \
                             IntegerType j,                                                    \
                             std::integer_sequence<IntegerType, I...>) {                       \
                ((A(I, j) Operator k), ...);                                                   \
            };                                                                                 \
            auto fCols = [&]<IntegerType... J>(std::integer_sequence<IntegerType, J...>) {     \
                (fRows(J, std::make_integer_sequence<IntegerType, LhsMatrixType::kRows>()),    \
                 ...);                                                                         \
            };                                                                                 \
            fCols(std::make_integer_sequence<IntegerType, LhsMatrixType::kCols>());            \
        }                                                                                      \
    }
88DefineMatrixMatrixAssign(Assign, =);
89DefineMatrixMatrixAssign(AddAssign, +=);
90DefineMatrixMatrixAssign(SubtractAssign, -=);
91DefineMatrixScalarAssign(AssignScalar, =);
92DefineMatrixScalarAssign(AddAssignScalar, +=);
93DefineMatrixScalarAssign(SubtractAssignScalar, -=);
94DefineMatrixScalarAssign(MultiplyAssign, *=);
95DefineMatrixScalarAssign(DivideAssign, /=);
97#define PBAT_MINI_ASSIGN_API(SelfType) \
98 template <class TOtherMatrix> \
99 PBAT_HOST_DEVICE SelfType& operator=(TOtherMatrix&& B) \
101 Assign(*this, std::forward<TOtherMatrix>(B)); \