11#include "clustering/math/detail/pairwise_argmin_outer.h"
36[[nodiscard]]
inline std::array<std::size_t, 2>
39 return {rows == 0 ? std::size_t{1} : rows, k == 0 ? std::size_t{1} : k};
78template <
class T, Layout LX, Layout LC>
83 const std::size_t n = X.
dim(0);
84 const std::size_t k = C.
dim(0);
85 const std::size_t d = X.
dim(1);
86 if (n == 0 || k == 0) {
93 auto runChunk = [&](std::size_t iBase, std::size_t chunkRows,
const auto &xChunk)
noexcept {
99 auto scanRange = [&](std::size_t lo, std::size_t hi)
noexcept {
100 for (std::size_t i = lo; i < hi; ++i) {
101 const T *row = distsView.
data() + (i * k);
103 std::int32_t bestIdx = 0;
104 for (std::size_t j = 1; j < k; ++j) {
108 bestIdx =
static_cast<std::int32_t
>(j);
111 outMinSq(iBase + i) = bestVal;
112 labels(iBase + i) = bestIdx;
118 ->submit_blocks(std::size_t{0}, chunkRows,
119 [&](std::size_t lo, std::size_t hi) { scanRange(lo, hi); })
122 scanRange(0, chunkRows);
126 for (std::size_t iBase = 0; iBase < n; iBase += chunkCap) {
127 const std::size_t chunkRows = (iBase + chunkCap <= n) ? chunkCap : (n - iBase);
130 runChunk(iBase, chunkRows, xChunk);
132 auto xChunk = X.
slice(0, iBase, iBase + chunkRows);
133 runChunk(iBase, chunkRows, xChunk);
145template <
class T, Layout LX, Layout LC>
149 const std::size_t n = X.
dim(0);
150 const std::size_t k = C.
dim(0);
151 if (n == 0 || k == 0) {
172template <
class T, Layout LX, Layout LC>
175#ifdef CLUSTERING_USE_AVX2
177 const std::size_t n = X.dim(0);
178 const std::size_t k = C.dim(0);
179 const std::size_t d = X.dim(1);
180 if (n == 0 || k == 0 || d == 0) {
186 if (!X.template isAligned<32>() || !C.template isAligned<32>()) {
189 if (!cSqNorms.isContiguous()) {
231template <
class T, Layout LX = Layout::Contig, Layout LC = Layout::Contig>
235 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
236 "pairwiseArgminSqEuclidean<T> requires T to be float or double");
245 const std::size_t n = X.
dim(0);
246 const std::size_t k = C.
dim(0);
247 if (n == 0 || k == 0) {
251#ifdef CLUSTERING_USE_AVX2
254 detail::pairwiseArgminOuterAvx2F32(X, C, cSqNorms, labels, outMinDistSq, pool);
273template <
class T, Layout LX = Layout::Contig, Layout LC = Layout::Contig>
279 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
280 "pairwiseArgminSqEuclideanWithDispatchInfo<T> requires T to be float or double");
289 const std::size_t n = X.
dim(0);
290 const std::size_t k = C.
dim(0);
291 if (n == 0 || k == 0) {
295#ifdef CLUSTERING_USE_AVX2
298 pairwiseArgminOuterAvx2F32(X, C, cSqNorms, labels, outMinDistSq, pool);
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
NDArray< T, N, Layout::MaybeStrided > slice(std::size_t axis, std::size_t begin, std::size_t end) noexcept
Borrowed half-open slice along a single axis.
static NDArray borrow(T *ptr, std::array< std::size_t, N > shape) noexcept
Borrows a contiguous buffer as an NDArray without taking ownership.
const T * data() const noexcept
Provides read-only access to the internal data array.
bool isMutable() const noexcept
Reports whether writes through operator(), Accessor, or flatIndex are allowed.
constexpr std::size_t pairwiseArgminMaxD
Maximum feature dimension for which the fused pairwiseArgminSqEuclidean driver is used.
ArgminPath
Tag identifying which outer driver executed for a pairwiseArgminSqEuclidean request.
void pairwiseArgminMaterializedWithScratch(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinSq, NDArray< T, 2 > &distsScratch, Pool pool)
Compute per-row argmin + minimum squared distance over n in 256-row strips using a caller-owned dista...
void pairwiseArgminMaterialized(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinSq, Pool pool)
Compute per-row argmin and minimum squared distance via the materialized two-step.
bool canUseFusedArgmin(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms) noexcept
Runtime predicate: true when the fused AVX2 path is eligible for this call.
ArgminPath pairwiseArgminSqEuclideanWithDispatchInfo(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinDistSq, Pool pool)
Test-only: runs the same dispatch as pairwiseArgminSqEuclidean and reports which outer driver fired.
std::array< std::size_t, 2 > chunkedMaterializedScratchShape(std::size_t n, std::size_t k) noexcept
Required shape for the chunked materialized argmin scratch buffer.
void pairwiseArgminSqEuclidean(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LC > &C, const NDArray< T, 1 > &cSqNorms, NDArray< std::int32_t, 1 > &labels, NDArray< T, 1 > &outMinDistSq, Pool pool)
Per-row argmin and minimum squared distance of rows of X against rows of C.
constexpr std::size_t pairwiseArgminChunkRows
Chunk height used by the materialized argmin path when striping over n.
void pairwiseSqEuclidean(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Pairwise squared Euclidean distances between rows of two matrices.
Thin injection wrapper around a BS::light_thread_pool.
BS::light_thread_pool * pool
Underlying pool, or nullptr to force serial execution.
bool shouldParallelize(std::size_t totalWork, std::size_t minChunk, std::size_t minTasksPerWorker=2) const noexcept
Decide whether totalWork warrants parallel dispatch.