Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
gemm_plan.h
Go to the documentation of this file.
1#pragma once
2
3#include <cstddef>
4#include <cstdint>
5#include <type_traits>
6#include <vector>
7
8#include "clustering/math/detail/gemm_outer_prepacked.h"
9#include "clustering/math/detail/gemm_pack.h"
10#include "clustering/math/detail/matrix_desc.h"
11#include "clustering/math/detail/reference_gemm.h"
13#include "clustering/ndarray.h"
14
15namespace clustering::math {
16
40template <class T, class Backend = detail::ReferenceGemm> class GemmPlan {
41 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
42 "GemmPlan: T must be float or double");
43
44public:
54 template <Layout LB>
56 : m_kDim(B.dim(0)), m_nDim(B.dim(1)), m_workerCount(pool.workerCount()), m_pool(pool) {
57 constexpr std::size_t kNr = detail::kKernelNr<T>;
58 constexpr std::size_t kNcVal = detail::kNc<T>;
59 constexpr std::size_t kKcVal = detail::kKc<T>;
60 constexpr std::size_t kMcVal = detail::kMc<T>;
61
62 m_scratch.assign(m_workerCount * kMcVal * kKcVal, T{0});
63
64 if (m_kDim == 0 || m_nDim == 0) {
65 // No Bp storage needed: execute() treats K==0 as the BLAS C<-beta*C identity without
66 // reading Bp, and N==0 is a no-op.
67 return;
68 }
69
70 // Two-pass: first compute total Bp size, then pack. Keeping the sizing pass separate lets
71 // us call reserve/resize exactly once; the pack loop walks the same offset arithmetic that
72 // gemmRunPrepacked uses, so the two stay structurally locked.
73 std::size_t total = 0;
74 for (std::size_t jc = 0; jc < m_nDim; jc += kNcVal) {
75 const std::size_t nc = (jc + kNcVal <= m_nDim) ? kNcVal : (m_nDim - jc);
76 const std::size_t roundedNc = ((nc + kNr - 1) / kNr) * kNr;
77 total += m_kDim * roundedNc;
78 }
79 m_Bp.assign(total, T{0});
80
81 auto Bd = ::clustering::detail::describeMatrix(B);
82
83 std::size_t jcBase = 0;
84 for (std::size_t jc = 0; jc < m_nDim; jc += kNcVal) {
85 const std::size_t nc = (jc + kNcVal <= m_nDim) ? kNcVal : (m_nDim - jc);
86 const std::size_t roundedNc = ((nc + kNr - 1) / kNr) * kNr;
87
88 std::size_t pcOffInJc = 0;
89 for (std::size_t pc = 0; pc < m_kDim; pc += kKcVal) {
90 const std::size_t kc = (pc + kKcVal <= m_kDim) ? kKcVal : (m_kDim - pc);
91 detail::packB<T>(Bd, pc, kc, jc, nc, m_Bp.data() + jcBase + pcOffInJc);
92 pcOffInJc += kc * roundedNc;
93 }
94 jcBase += m_kDim * roundedNc;
95 }
96 }
97
108 template <Layout LA>
109 void execute(const NDArray<T, 2, LA> &A, NDArray<T, 2> &C, T alpha = T{1},
110 T beta = T{0}) const noexcept {
111 if (A.dim(0) == 0 || m_nDim == 0) {
112 return;
113 }
114
115 // Pass the full scratch base -- gemmRunPrepacked slices per-worker inside its Mc dispatch via
116 // Pool::workerIndex(). On the serial path workerIndex() returns 0, so slice 0 is used.
117 auto Ad = ::clustering::detail::describeMatrix(A);
118 auto Cd = ::clustering::detail::describeMatrixMut(C);
119 detail::gemmRunPrepacked<T>(Ad, m_Bp.data(), m_kDim, m_nDim, Cd, alpha, beta, m_scratch.data(),
120 m_pool);
121 }
122
124 [[nodiscard]] std::size_t kDim() const noexcept { return m_kDim; }
125
127 [[nodiscard]] std::size_t nDim() const noexcept { return m_nDim; }
128
131 [[nodiscard]] const T *debugBpData() const noexcept { return m_Bp.data(); }
132
135 [[nodiscard]] std::size_t debugScratchSize() const noexcept { return m_scratch.size(); }
136
137 GemmPlan(const GemmPlan &) = delete;
138 GemmPlan &operator=(const GemmPlan &) = delete;
140 GemmPlan(GemmPlan &&) noexcept = default;
142 GemmPlan &operator=(GemmPlan &&) noexcept = default;
143 ~GemmPlan() = default;
144
145private:
146 std::size_t m_kDim = 0;
147 std::size_t m_nDim = 0;
148 std::size_t m_workerCount = 1;
149 Pool m_pool{};
150 std::vector<T, ::clustering::detail::AlignedAllocator<T, 32>> m_Bp;
151 // mutable: execute() is const on the plan's observable shape but the scratch is a per-call
152 // mutation surface sliced by worker index.
153 mutable std::vector<T, ::clustering::detail::AlignedAllocator<T, 32>> m_scratch;
154};
155
156} // namespace clustering::math
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
const T * debugBpData() const noexcept
Debug accessor exposing the packed B pointer so tests can pin alignment.
Definition gemm_plan.h:131
void execute(const NDArray< T, 2, LA > &A, NDArray< T, 2 > &C, T alpha=T{1}, T beta=T{0}) const noexcept
Execute the plan: compute C := alpha * A * B + beta * C against the pre-packed B captured at construction.
Definition gemm_plan.h:109
GemmPlan(GemmPlan &&) noexcept=default
Defaulted move constructor; transfers the packed B panel and scratch.
std::size_t nDim() const noexcept
Column count captured at construction (B.cols).
Definition gemm_plan.h:127
GemmPlan & operator=(const GemmPlan &)=delete
std::size_t debugScratchSize() const noexcept
Debug accessor exposing the scratch capacity so tests can pin the sizing formula.
Definition gemm_plan.h:135
GemmPlan(const NDArray< T, 2, LB > &B, Pool pool)
Construct the plan and fully pre-pack B into m_Bp.
Definition gemm_plan.h:55
std::size_t kDim() const noexcept
Inner dimension captured at construction (B.rows).
Definition gemm_plan.h:124
GemmPlan(const GemmPlan &)=delete
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63