20 using LhsNestedType = TLhsMatrix;
21 using RhsNestedType = TRhsMatrix;
23 using ScalarType =
typename LhsNestedType::ScalarType;
24 using SelfType = Product<LhsNestedType, RhsNestedType>;
26 static auto constexpr kRows = LhsNestedType::kRows;
27 static auto constexpr kCols = RhsNestedType::kCols;
28 static bool constexpr bRowMajor =
false;
30 PBAT_HOST_DEVICE Product(LhsNestedType
const& _A, RhsNestedType
const& _B) : A(_A), B(_B) {}
32 PBAT_HOST_DEVICE ScalarType operator()(
auto i,
auto j)
const
34 using IntegerType = std::remove_const_t<
decltype(LhsNestedType::kRows)>;
35 auto contract = [
this, i, j]<IntegerType... K>(std::integer_sequence<IntegerType, K...>) {
36 return ((A(i, K) * B(K, j)) + ...);
38 return contract(std::make_integer_sequence<IntegerType, LhsNestedType::kCols>());
42 PBAT_HOST_DEVICE ScalarType operator()(
auto i)
const {
return (*
this)(i % kRows, i / kRows); }
43 PBAT_HOST_DEVICE ScalarType operator[](
auto i)
const {
return (*
this)(i); }
45 PBAT_MINI_READ_API(SelfType)
48 LhsNestedType
const& A;
49 RhsNestedType
const& B;