Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
pairwise_argmin.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <array>
5#include <cstddef>
6#include <cstdint>
7#include <type_traits>
8
11#include "clustering/math/detail/pairwise_argmin_outer.h"
14#include "clustering/ndarray.h"
15
16namespace clustering::math {
17
/// Chunk height used by the materialized argmin path when striping over n.
inline constexpr std::size_t pairwiseArgminChunkRows = 256;

/// Required shape for the chunked materialized argmin scratch buffer.
///
/// @param n Number of rows that will be processed (callers in this file pass X.dim(0)).
/// @param k Number of candidate rows (callers in this file pass C.dim(0)).
/// @return {rows, cols} with rows = min(n, pairwiseArgminChunkRows) and cols = k, each raised
///         to at least 1 so that a zero-sized request still yields a non-empty shape.
[[nodiscard]] inline std::array<std::size_t, 2>
chunkedMaterializedScratchShape(std::size_t n, std::size_t k) noexcept {
  const std::size_t rows = std::min(n, pairwiseArgminChunkRows);
  return {std::max(rows, std::size_t{1}), std::max(k, std::size_t{1})};
}
41
42namespace detail {
43
/// Tag identifying which outer driver executed for a pairwiseArgminSqEuclidean request.
enum class ArgminPath : std::uint8_t {
  Fused,        ///< The fused AVX2 driver ran.
  Materialized, ///< The materialized two-step driver ran.
};
52
78template <class T, Layout LX, Layout LC>
81 NDArray<T, 1> &outMinSq, NDArray<T, 2> &distsScratch,
82 Pool pool) {
83 const std::size_t n = X.dim(0);
84 const std::size_t k = C.dim(0);
85 const std::size_t d = X.dim(1);
86 if (n == 0 || k == 0) {
87 return;
88 }
89 const std::size_t chunkCap = pairwiseArgminChunkRows;
90 CLUSTERING_ALWAYS_ASSERT(distsScratch.dim(1) >= k);
91 CLUSTERING_ALWAYS_ASSERT(distsScratch.dim(0) >= (n < chunkCap ? n : chunkCap));
92
93 auto runChunk = [&](std::size_t iBase, std::size_t chunkRows, const auto &xChunk) noexcept {
94 // The scratch tile may be wider than k when the caller pre-sized with k padded up; the
95 // dispatch view narrows it back to (chunkRows, k) so the callee sees exactly one chunk.
96 NDArray<T, 2> distsView = NDArray<T, 2>::borrow(distsScratch.data(), {chunkRows, k});
97 pairwiseSqEuclidean(xChunk, C, distsView, pool);
98
99 auto scanRange = [&](std::size_t lo, std::size_t hi) noexcept {
100 for (std::size_t i = lo; i < hi; ++i) {
101 const T *row = distsView.data() + (i * k);
102 T bestVal = row[0];
103 std::int32_t bestIdx = 0;
104 for (std::size_t j = 1; j < k; ++j) {
105 const T v = row[j];
106 if (v < bestVal) {
107 bestVal = v;
108 bestIdx = static_cast<std::int32_t>(j);
109 }
110 }
111 outMinSq(iBase + i) = bestVal;
112 labels(iBase + i) = bestIdx;
113 }
114 };
115
116 if (pool.shouldParallelize(chunkRows, 4, 2) && pool.pool != nullptr) {
117 pool.pool
118 ->submit_blocks(std::size_t{0}, chunkRows,
119 [&](std::size_t lo, std::size_t hi) { scanRange(lo, hi); })
120 .wait();
121 } else {
122 scanRange(0, chunkRows);
123 }
124 };
125
126 for (std::size_t iBase = 0; iBase < n; iBase += chunkCap) {
127 const std::size_t chunkRows = (iBase + chunkCap <= n) ? chunkCap : (n - iBase);
128 if constexpr (LX == Layout::Contig) {
129 auto xChunk = NDArray<T, 2, Layout::Contig>::borrow(X.data() + (iBase * d), {chunkRows, d});
130 runChunk(iBase, chunkRows, xChunk);
131 } else {
132 auto xChunk = X.slice(0, iBase, iBase + chunkRows);
133 runChunk(iBase, chunkRows, xChunk);
134 }
135 }
136}
137
145template <class T, Layout LX, Layout LC>
147 NDArray<std::int32_t, 1> &labels, NDArray<T, 1> &outMinSq,
148 Pool pool) {
149 const std::size_t n = X.dim(0);
150 const std::size_t k = C.dim(0);
151 if (n == 0 || k == 0) {
152 return;
153 }
154
155 const auto shape = chunkedMaterializedScratchShape(n, k);
156 NDArray<T, 2> distsScratch({shape[0], shape[1]});
157 pairwiseArgminMaterializedWithScratch(X, C, labels, outMinSq, distsScratch, pool);
158}
159
// Runtime predicate: true when the fused AVX2 path is eligible for this call.
//
// Always false unless CLUSTERING_USE_AVX2 is defined, T is float, and both layouts are
// Layout::Contig. In that configuration every guard below must pass.
//
// NOTE(review): listing line 173 (the start of the signature) is absent from this extract. Per
// the page index the declaration is
//   bool canUseFusedArgmin(const NDArray<T, 2, LX> &X, const NDArray<T, 2, LC> &C,
//                          const NDArray<T, 1> &cSqNorms) noexcept
172template <class T, Layout LX, Layout LC>
174 const NDArray<T, 1> &cSqNorms) noexcept {
175#ifdef CLUSTERING_USE_AVX2
176 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig && LC == Layout::Contig) {
177 const std::size_t n = X.dim(0);
178 const std::size_t k = C.dim(0);
179 const std::size_t d = X.dim(1);
// Any empty extent disqualifies the fused path.
180 if (n == 0 || k == 0 || d == 0) {
181 return false;
182 }
// NOTE(review): listing line 183 (the `if (...) {` opening this guard) is absent from this
// extract. Presumably it bounds d by pairwiseArgminMaxD, which the page index lists for this
// file - confirm against the original header.
184 return false;
185 }
// Both matrices must report 32-byte alignment.
186 if (!X.template isAligned<32>() || !C.template isAligned<32>()) {
187 return false;
188 }
// The centroid squared-norm vector must be contiguous.
189 if (!cSqNorms.isContiguous()) {
190 return false;
191 }
192 return true;
193 } else {
// Non-float or non-contiguous instantiation: arguments are unused, cast to void to say so.
194 (void)X;
195 (void)C;
196 (void)cSqNorms;
197 return false;
198 }
199#else
// Built without CLUSTERING_USE_AVX2: the fused path is never eligible.
200 (void)X;
201 (void)C;
202 (void)cSqNorms;
203 return false;
204#endif
205}
206
207} // namespace detail
208
// Per-row argmin and minimum squared distance of rows of X against rows of C.
//
// Dispatch: when CLUSTERING_USE_AVX2 is defined, T is float, both layouts are Layout::Contig and
// detail::canUseFusedArgmin accepts the inputs, the fused detail::pairwiseArgminOuterAvx2F32
// driver runs; otherwise detail::pairwiseArgminMaterialized runs. cSqNorms is passed only to the
// fused driver and to the eligibility check.
//
// Asserted preconditions visible here: outMinDistSq is mutable, X and C have equal dim(1), labels
// and outMinDistSq have X.dim(0) entries, and cSqNorms has C.dim(0) entries. Returns without
// writing when X or C has zero rows.
//
// NOTE(review): listing line 232 (the start of the signature) is absent from this extract. Per
// the page index the declaration is
//   void pairwiseArgminSqEuclidean(const NDArray<T, 2, LX> &X, const NDArray<T, 2, LC> &C,
//                                  const NDArray<T, 1> &cSqNorms,
//                                  NDArray<std::int32_t, 1> &labels,
//                                  NDArray<T, 1> &outMinDistSq, Pool pool)
231template <class T, Layout LX = Layout::Contig, Layout LC = Layout::Contig>
233 const NDArray<T, 1> &cSqNorms, NDArray<std::int32_t, 1> &labels,
234 NDArray<T, 1> &outMinDistSq, Pool pool) {
235 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
236 "pairwiseArgminSqEuclidean<T> requires T to be float or double");
237
// NOTE(review): listing line 238 is absent from this extract; presumably another precondition
// assert (e.g. on labels) - confirm against the original header.
239 CLUSTERING_ALWAYS_ASSERT(outMinDistSq.isMutable());
240 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == C.dim(1));
241 CLUSTERING_ALWAYS_ASSERT(labels.dim(0) == X.dim(0));
242 CLUSTERING_ALWAYS_ASSERT(outMinDistSq.dim(0) == X.dim(0));
243 CLUSTERING_ALWAYS_ASSERT(cSqNorms.dim(0) == C.dim(0));
244
245 const std::size_t n = X.dim(0);
246 const std::size_t k = C.dim(0);
// Nothing to assign when either side has no rows.
247 if (n == 0 || k == 0) {
248 return;
249 }
250
251#ifdef CLUSTERING_USE_AVX2
252 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig && LC == Layout::Contig) {
// Compile-time eligibility is necessary but not sufficient; the runtime predicate decides.
253 if (detail::canUseFusedArgmin(X, C, cSqNorms)) {
254 detail::pairwiseArgminOuterAvx2F32(X, C, cSqNorms, labels, outMinDistSq, pool);
255 return;
256 }
257 }
258#endif
259
// Fallback for every case the fused driver did not take.
260 detail::pairwiseArgminMaterialized(X, C, labels, outMinDistSq, pool);
261}
262
263namespace detail {
264
// Test-only: runs the same dispatch as pairwiseArgminSqEuclidean and reports which outer driver
// fired.
//
// The asserts, the empty-input check and the fused/materialized selection mirror the statements
// of pairwiseArgminSqEuclidean; the visible difference is the ArgminPath return value.
//
// NOTE(review): listing lines 274 and 277 (parts of the signature) are absent from this extract.
// Per the page index the declaration is
//   ArgminPath pairwiseArgminSqEuclideanWithDispatchInfo(
//       const NDArray<T, 2, LX> &X, const NDArray<T, 2, LC> &C, const NDArray<T, 1> &cSqNorms,
//       NDArray<std::int32_t, 1> &labels, NDArray<T, 1> &outMinDistSq, Pool pool)
273template <class T, Layout LX = Layout::Contig, Layout LC = Layout::Contig>
275 const NDArray<T, 2, LC> &C,
276 const NDArray<T, 1> &cSqNorms,
278 NDArray<T, 1> &outMinDistSq, Pool pool) {
279 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
280 "pairwiseArgminSqEuclideanWithDispatchInfo<T> requires T to be float or double");
281
// NOTE(review): listing line 282 is absent from this extract; presumably another precondition
// assert (e.g. on labels) - confirm against the original header.
283 CLUSTERING_ALWAYS_ASSERT(outMinDistSq.isMutable());
284 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == C.dim(1));
285 CLUSTERING_ALWAYS_ASSERT(labels.dim(0) == X.dim(0));
286 CLUSTERING_ALWAYS_ASSERT(outMinDistSq.dim(0) == X.dim(0));
287 CLUSTERING_ALWAYS_ASSERT(cSqNorms.dim(0) == C.dim(0));
288
289 const std::size_t n = X.dim(0);
290 const std::size_t k = C.dim(0);
// NOTE(review): listing line 292 (the early-return statement for empty input) is absent from
// this extract; which ArgminPath value it returns cannot be determined here.
291 if (n == 0 || k == 0) {
293 }
294
295#ifdef CLUSTERING_USE_AVX2
296 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig && LC == Layout::Contig) {
// Fused driver taken: report it to the caller.
297 if (canUseFusedArgmin(X, C, cSqNorms)) {
298 pairwiseArgminOuterAvx2F32(X, C, cSqNorms, labels, outMinDistSq, pool);
299 return ArgminPath::Fused;
300 }
301 }
302#endif
303
// NOTE(review): listing line 305 (the final return after the materialized call) is absent from
// this extract; presumably it returns ArgminPath::Materialized - confirm.
304 pairwiseArgminMaterialized(X, C, labels, outMinDistSq, pool);
306}
307
308} // namespace detail
309
310} // namespace clustering::math
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
NDArray< T, N, Layout::MaybeStrided > slice(std::size_t axis, std::size_t begin, std::size_t end) noexcept
Borrowed half-open slice along a single axis.
Definition ndarray.h:773
static NDArray borrow(T *ptr, std::array< std::size_t, N > shape) noexcept
Borrows a contiguous buffer as an NDArray without taking ownership.
Definition ndarray.h:570
const T * data() const noexcept
Provides read-only access to the internal data array.
Definition ndarray.h:503
bool isMutable() const noexcept
Reports whether writes through operator(), Accessor, or flatIndex are allowed.
Definition ndarray.h:488
constexpr std::size_t pairwiseArgminMaxD
Maximum feature dimension for which the fused pairwiseArgminSqEuclidean driver is used.
Definition defaults.h:76
ArgminPath
Tag identifying which outer driver executed for a pairwiseArgminSqEuclidean request.
void pairwiseArgminMaterializedWithScratch(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinSq, NDArray< T, 2 > &distsScratch, Pool pool)
Compute per-row argmin + minimum squared distance over n in 256-row strips using a caller-owned distance scratch buffer.
void pairwiseArgminMaterialized(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinSq, Pool pool)
Compute per-row argmin and minimum squared distance via the materialized two-step.
bool canUseFusedArgmin(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms) noexcept
Runtime predicate: true when the fused AVX2 path is eligible for this call.
ArgminPath pairwiseArgminSqEuclideanWithDispatchInfo(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinDistSq, Pool pool)
Test-only: runs the same dispatch as pairwiseArgminSqEuclidean and reports which outer driver fired.
std::array< std::size_t, 2 > chunkedMaterializedScratchShape(std::size_t n, std::size_t k) noexcept
Required shape for the chunked materialized argmin scratch buffer.
void pairwiseArgminSqEuclidean(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinDistSq, Pool pool)
Per-row argmin and minimum squared distance of rows of X against rows of C.
constexpr std::size_t pairwiseArgminChunkRows
Chunk height used by the materialized argmin path when striping over n.
void pairwiseSqEuclidean(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Pairwise squared Euclidean distances between rows of two matrices.
Definition pairwise.h:350
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63
BS::light_thread_pool * pool
Underlying pool, or nullptr to force serial execution.
Definition thread.h:65
bool shouldParallelize(std::size_t totalWork, std::size_t minChunk, std::size_t minTasksPerWorker=2) const noexcept
Decide whether totalWork warrants parallel dispatch.
Definition thread.h:98