// Restrict the element type T to single or double precision floats.
// NOTE(review): fragment — the enclosing template whose parameter T this
// constrains begins before this view; condition text is visible and complete.
41 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
// Constructor (name/parameters on orig. line 55, outside this view): caches the
// dimensions of B (k x n), records the pool's worker count, and eagerly packs B
// into panel-major storage (m_Bp) so later multiplies reuse the packed form.
// NOTE(review): extraction fragment — several interior lines (orig. 61, 63,
// 65-72, 78, 80, 82, 87, 93, 95) are missing here, so braces appear unbalanced;
// code below is left byte-identical, comments only.
56 : m_kDim(B.dim(0)), m_nDim(B.dim(1)), m_workerCount(pool.workerCount()), m_pool(pool) {
// Blocking parameters for element type T. kNr is the micro-kernel's register
// width in N (nc panels are rounded up to a multiple of it below); kNc/kKc/kMc
// are presumably BLIS-style cache block sizes — TODO confirm against detail::.
57 constexpr std::size_t kNr = detail::kKernelNr<T>;
58 constexpr std::size_t kNcVal = detail::kNc<T>;
59 constexpr std::size_t kKcVal = detail::kKc<T>;
60 constexpr std::size_t kMcVal = detail::kMc<T>;
// One Mc x Kc scratch region per worker, zero-initialized (presumably used as
// per-thread A-packing buffers during multiply — confirm against gemmRunPrepacked).
62 m_scratch.assign(m_workerCount * kMcVal * kKcVal, T{0});
// Degenerate B: nothing to pack (early-return body is on the missing orig. lines 65-72).
64 if (m_kDim == 0 || m_nDim == 0) {
// First pass: compute the total packed size. Each jc-panel of width nc is
// rounded up to a multiple of kNr, so total >= m_kDim * m_nDim; the excess is
// zero padding supplied by the assign() below.
73 std::size_t total = 0;
74 for (std::size_t jc = 0; jc < m_nDim; jc += kNcVal) {
// nc = width of this panel, clipped at the right edge of B.
75 const std::size_t nc = (jc + kNcVal <= m_nDim) ? kNcVal : (m_nDim - jc);
// Round nc up to the next multiple of kNr (integer ceil-divide then re-multiply).
76 const std::size_t roundedNc = ((nc + kNr - 1) / kNr) * kNr;
77 total += m_kDim * roundedNc;
// Allocate and zero-fill the packed buffer; the padding columns beyond nc in
// each panel therefore stay zero and contribute nothing to the product.
79 m_Bp.assign(total, T{0});
// Non-owning descriptor of B handed to the packing kernel.
81 auto Bd = ::clustering::detail::describeMatrix(B);
// Second pass: pack each (pc, jc) tile of B. jcBase tracks the byte-for-byte
// same panel sizes as the sizing pass above, so offsets line up exactly.
83 std::size_t jcBase = 0;
84 for (std::size_t jc = 0; jc < m_nDim; jc += kNcVal) {
85 const std::size_t nc = (jc + kNcVal <= m_nDim) ? kNcVal : (m_nDim - jc);
86 const std::size_t roundedNc = ((nc + kNr - 1) / kNr) * kNr;
// pcOffInJc = running offset of the current K-block inside this jc-panel.
88 std::size_t pcOffInJc = 0;
89 for (std::size_t pc = 0; pc < m_kDim; pc += kKcVal) {
// kc = depth of this K-block, clipped at the bottom edge of B.
90 const std::size_t kc = (pc + kKcVal <= m_kDim) ? kKcVal : (m_kDim - pc);
// Pack the kc x nc tile starting at (pc, jc) into its slot of m_Bp.
91 detail::packB<T>(Bd, pc, kc, jc, nc, m_Bp.data() + jcBase + pcOffInJc);
// Advance by the tile's packed footprint (kc rows x roundedNc padded columns).
92 pcOffInJc += kc * roundedNc;
94 jcBase += m_kDim * roundedNc;
// Multiply entry point (signature begins on the missing orig. lines 97-109):
// invokes the prepacked GEMM driver with the cached B panels, i.e. something of
// the form C = alpha * A * B + beta * C — presumed from the alpha/beta/Cd
// arguments below; TODO confirm against gemmRunPrepacked's declaration.
// NOTE(review): fragment — the statement at orig. 110-111 is split mid-token by
// the extraction, and the call runs past the end of this view; code left
// byte-identical, comments only.
110 T beta = T{0})
const noexcept {
// Empty A (no rows) or empty packed B (no columns): nothing to compute.
111 if (A.
dim(0) == 0 || m_nDim == 0) {
// Non-owning descriptors: A read-only, C mutable (written by the driver).
117 auto Ad = ::clustering::detail::describeMatrix(A);
118 auto Cd = ::clustering::detail::describeMatrixMut(C);
// Hand off to the shared driver with the prepacked panels and the per-worker
// scratch buffers sized in the constructor; trailing arguments lie past this view.
119 detail::gemmRunPrepacked<T>(Ad, m_Bp.data(), m_kDim, m_nDim, Cd, alpha, beta, m_scratch.data(),