Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
greedy_kmpp_seeder.h
#pragma once

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

#include "clustering/math/detail/avx2_helpers.h"
#include "clustering/math/detail/gemm_outer.h"
#include "clustering/math/detail/matrix_desc.h"
#include "clustering/math/rng.h"
#include "clustering/ndarray.h"

#ifdef CLUSTERING_USE_AVX2
#include <immintrin.h>

#include "clustering/math/detail/kmpp_score_avx2.h"
#endif

namespace clustering::kmeans {

namespace detail {

using math::detail::sqEuclideanRowPtr;

[[nodiscard]] inline std::size_t greedyKmppLocalTrials(std::size_t k) noexcept {
  if (k <= 1) {
    return 1;
  }
  const auto lnK = std::log(static_cast<double>(k));
  return 2 + static_cast<std::size_t>(lnK);
}
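// Growth illustration for 2 + floor(ln k) -- the same local-trials heuristic sklearn's
// k-means++ uses: k = 2 -> 2, k = 8 -> 4, k = 100 -> 6, k = 1000 -> 8 trials per pick.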

[[nodiscard]] constexpr std::size_t greedyKmppTransposedWidth(std::size_t L) noexcept {
  constexpr std::size_t kChunk = 8;
  return ((L + kChunk - 1) / kChunk) * kChunk;
}
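// Examples: greedyKmppTransposedWidth(1) == 8, (6) == 8, (9) == 16, (16) == 16. Lanes
// beyond L are zero-filled at pack time, so the 8-wide kernels never read garbage.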

#ifdef CLUSTERING_USE_AVX2

template <std::size_t B>
[[gnu::always_inline]] inline void
sqEuclideanRowToBatchAvx2Fixed(const float *x, const float *candData, std::size_t d,
                               float *out) noexcept {
  static_assert(B >= 1 && B <= 8, "B must lie in [1, 8] -- 8 ymm regs hold the batch");
  // Double accumulator set (2 * B YMMs) over a 2x-unrolled K loop. Halves the per-iter fmadd
  // dependency chain so Zen5's 4-FMA-per-cycle throughput isn't latency-bound on the 4-cycle
  // fmadd round-trip; also gives the register allocator enough explicit live ranges to keep
  // accumulators in YMM registers rather than spilling to the stack (measured: 8 GFLOPS with
  // the original 1x loop, ~2x post-unroll on the seeder's B=4 hot path).
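  // (Back-of-envelope, taking the latency/throughput figures above as given: hiding a
  // 4-cycle fmadd latency at 4 FMAs per cycle needs ~16 independent chains in flight;
  // the two accumulator sets provide 2*B, so B = 8 saturates the pipes and even B = 4
  // quadruples the ILP of a single-accumulator loop.)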
  std::array<__m256, B> acc0{};
  std::array<__m256, B> acc1{};
  for (std::size_t t = 0; t < B; ++t) {
    acc0[t] = _mm256_setzero_ps();
    acc1[t] = _mm256_setzero_ps();
  }
  std::size_t k = 0;
  for (; k + 16 <= d; k += 16) {
    const __m256 vx0 = _mm256_loadu_ps(x + k);
    const __m256 vx1 = _mm256_loadu_ps(x + k + 8);
    for (std::size_t t = 0; t < B; ++t) {
      const __m256 vc0 = _mm256_loadu_ps(candData + (t * d) + k);
      const __m256 vc1 = _mm256_loadu_ps(candData + (t * d) + k + 8);
      const __m256 diff0 = _mm256_sub_ps(vx0, vc0);
      const __m256 diff1 = _mm256_sub_ps(vx1, vc1);
      acc0[t] = _mm256_fmadd_ps(diff0, diff0, acc0[t]);
      acc1[t] = _mm256_fmadd_ps(diff1, diff1, acc1[t]);
    }
  }
  // 8-lane tail.
  for (; k + 8 <= d; k += 8) {
    const __m256 vx = _mm256_loadu_ps(x + k);
    for (std::size_t t = 0; t < B; ++t) {
      const __m256 vc = _mm256_loadu_ps(candData + (t * d) + k);
      const __m256 diff = _mm256_sub_ps(vx, vc);
      acc0[t] = _mm256_fmadd_ps(diff, diff, acc0[t]);
    }
  }
  std::array<float, B> tail{};
  for (std::size_t t = 0; t < B; ++t) {
    tail[t] = 0.0F;
  }
  for (std::size_t kt = k; kt < d; ++kt) {
    const float xk = x[kt];
    for (std::size_t t = 0; t < B; ++t) {
      const float diff = xk - candData[(t * d) + kt];
      tail[t] += diff * diff;
    }
  }
  for (std::size_t t = 0; t < B; ++t) {
    const __m256 sum = _mm256_add_ps(acc0[t], acc1[t]);
    out[t] = math::detail::horizontalSumAvx2(sum) + tail[t];
  }
}

template <std::size_t B>
inline void sqEuclideanRowToBatchAvx2Fixed(const double *x, const double *candData, std::size_t d,
                                           double *out) noexcept {
  static_assert(B >= 1 && B <= 8, "B must lie in [1, 8] -- 8 ymm regs hold the batch");
  std::array<__m256d, B> acc{};
  for (std::size_t t = 0; t < B; ++t) {
    acc[t] = _mm256_setzero_pd();
  }
  std::size_t k = 0;
  for (; k + 4 <= d; k += 4) {
    const __m256d vx = _mm256_loadu_pd(x + k);
    for (std::size_t t = 0; t < B; ++t) {
      const __m256d vc = _mm256_loadu_pd(candData + (t * d) + k);
      const __m256d diff = _mm256_sub_pd(vx, vc);
      acc[t] = _mm256_fmadd_pd(diff, diff, acc[t]);
    }
  }
  std::array<double, B> tail{};
  for (std::size_t t = 0; t < B; ++t) {
    tail[t] = 0.0;
  }
  for (std::size_t kt = k; kt < d; ++kt) {
    const double xk = x[kt];
    for (std::size_t t = 0; t < B; ++t) {
      const double diff = xk - candData[(t * d) + kt];
      tail[t] += diff * diff;
    }
  }
  for (std::size_t t = 0; t < B; ++t) {
    out[t] = math::detail::horizontalSumAvx2(acc[t]) + tail[t];
  }
}

template <class T>
inline void sqEuclideanRowToBatchAvx2(const T *x, const T *candData, std::size_t L, std::size_t d,
                                      T *out) noexcept {
  std::size_t base = 0;
  while (base + 8 <= L) {
    sqEuclideanRowToBatchAvx2Fixed<8>(x, candData + (base * d), d, out + base);
    base += 8;
  }
  switch (L - base) {
  case 0:
    break;
  case 1:
    sqEuclideanRowToBatchAvx2Fixed<1>(x, candData + (base * d), d, out + base);
    break;
  case 2:
    sqEuclideanRowToBatchAvx2Fixed<2>(x, candData + (base * d), d, out + base);
    break;
  case 3:
    sqEuclideanRowToBatchAvx2Fixed<3>(x, candData + (base * d), d, out + base);
    break;
  case 4:
    sqEuclideanRowToBatchAvx2Fixed<4>(x, candData + (base * d), d, out + base);
    break;
  case 5:
    sqEuclideanRowToBatchAvx2Fixed<5>(x, candData + (base * d), d, out + base);
    break;
  case 6:
    sqEuclideanRowToBatchAvx2Fixed<6>(x, candData + (base * d), d, out + base);
    break;
  case 7:
    sqEuclideanRowToBatchAvx2Fixed<7>(x, candData + (base * d), d, out + base);
    break;
  default:
    break;
  }
}

inline void sqEuclideanRowAgainst8Transposed(const float *x, const float *candData, std::size_t d,
                                             float *out) noexcept {
  __m256 acc = _mm256_setzero_ps();
  for (std::size_t k = 0; k < d; ++k) {
    const __m256 cv = _mm256_load_ps(candData + (k * 8));
    const __m256 xv = _mm256_set1_ps(x[k]);
    const __m256 diff = _mm256_sub_ps(xv, cv);
    acc = _mm256_fmadd_ps(diff, diff, acc);
  }
  _mm256_storeu_ps(out, acc);
}
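// Layout note for the transposed kernels: candData[(k * 8) + t] holds feature k of
// candidate t, so each _mm256_load_ps above fetches the same feature of all 8 candidates
// at once and the 8 running sums stay lane-resident; the final store emits all 8
// distances with no horizontal reduction at all.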

inline void sqEuclideanRowAgainst16Transposed(const float *x, const float *candData, std::size_t d,
                                              float *out) noexcept {
  __m256 accLo = _mm256_setzero_ps();
  __m256 accHi = _mm256_setzero_ps();
  for (std::size_t k = 0; k < d; ++k) {
    const __m256 cLo = _mm256_load_ps(candData + (k * 16));
    const __m256 cHi = _mm256_load_ps(candData + (k * 16) + 8);
    const __m256 xv = _mm256_set1_ps(x[k]);
    const __m256 diffLo = _mm256_sub_ps(xv, cLo);
    const __m256 diffHi = _mm256_sub_ps(xv, cHi);
    accLo = _mm256_fmadd_ps(diffLo, diffLo, accLo);
    accHi = _mm256_fmadd_ps(diffHi, diffHi, accHi);
  }
  _mm256_storeu_ps(out, accLo);
  _mm256_storeu_ps(out + 8, accHi);
}

inline void sqEuclideanRowAgainst8TransposedStrided(const float *x, const float *candData,
                                                    std::size_t d, std::size_t rowStride,
                                                    float *out) noexcept {
  __m256 acc = _mm256_setzero_ps();
  for (std::size_t k = 0; k < d; ++k) {
    const __m256 cv = _mm256_loadu_ps(candData + (k * rowStride));
    const __m256 xv = _mm256_set1_ps(x[k]);
    const __m256 diff = _mm256_sub_ps(xv, cv);
    acc = _mm256_fmadd_ps(diff, diff, acc);
  }
  _mm256_storeu_ps(out, acc);
}

#endif // CLUSTERING_USE_AVX2

template <class T>
inline void sqEuclideanRowToBatch(const T *x, const T *candData, std::size_t L, std::size_t d,
                                  T *out) noexcept {
#ifdef CLUSTERING_USE_AVX2
  if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
    sqEuclideanRowToBatchAvx2(x, candData, L, d, out);
    return;
  }
#endif
  for (std::size_t t = 0; t < L; ++t) {
    out[t] = sqEuclideanRowPtr(x, candData + (t * d), d);
  }
}

} // namespace detail

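// Usage sketch (illustrative only, not from the original header): seed k centroids from
// an (n, d) float matrix. Assumes the NDArray shape constructor that ensureShape() below
// already relies on, and a default-constructible math::Pool whose null pool pointer
// forces serial execution.
//
//   clustering::kmeans::GreedyKmppSeeder<float> seeder;
//   clustering::NDArray<float, 2, clustering::Layout::Contig> X({n, d});        // caller fills rows
//   clustering::NDArray<float, 2, clustering::Layout::Contig> centroids({k, d});
//   clustering::math::Pool pool{};               // pool.pool == nullptr -> serial
//   seeder.run(X, k, /*seed=*/42, pool, centroids);
//   // The scratch buffers are shape-stable, so reusing one seeder instance across
//   // repeated runs with the same (n, d, k) avoids reallocations.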
template <class T> class GreedyKmppSeeder {
public:
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                "GreedyKmppSeeder<T> requires T to be float or double");

  GreedyKmppSeeder()
      : m_candRows({0, 0}), m_candRowsT({0, 0}), m_candDistSq({0, 0}), m_cumDistSq({0}),
        m_minSq({0}), m_distsFlat({0, 0}), m_xNormsSq({0}), m_candNormsSq({0}), m_gemmApArena({0}),
        m_gemmBpArena({0}), m_localScores({0}) {}

  void run(const NDArray<T, 2, Layout::Contig> &X, std::size_t k, std::uint64_t seed,
           math::Pool pool, NDArray<T, 2, Layout::Contig> &outCentroids) {
    const std::size_t n = X.dim(0);
    const std::size_t d = X.dim(1);

    CLUSTERING_ALWAYS_ASSERT(outCentroids.isMutable());
    CLUSTERING_ALWAYS_ASSERT(outCentroids.dim(0) == k);
    CLUSTERING_ALWAYS_ASSERT(outCentroids.dim(1) == d);

    (void)pool;

    const std::size_t nLocalTrials = detail::greedyKmppLocalTrials(k);
    ensureShape(n, d, nLocalTrials, pool.workerCount());

    math::pcg64 rng;
    rng.seed(seed);

    const T *xData = X.data();
    T *centroidsData = outCentroids.data();
    T *minSq = m_minSq.data();
    T *candRowsData = m_candRows.data();
    T *cumDistSq = m_cumDistSq.data();
    T *candDistSqData = m_candDistSq.data();
#ifdef CLUSTERING_USE_AVX2
    T *candRowsTData = m_candRowsT.data();
#endif

    // GEMM scoring wins only when the candidate width L is >= one kNr panel (6). Below that
    // the 8x6 kernel's fixed 48-FMA body over-computes the 8xL useful tile; the per-row
    // streaming kernel with L parallel accumulators is tighter. Gate on L >= kNr<float>.
    // This predicate must mirror ensureShape's gemmScoringUsed exactly (including the
    // float-only test): the GEMM scratch buffers are only sized when it holds.
    constexpr std::size_t kNrF = math::detail::kKernelNr<float>;
    const bool useGemmScoring = std::is_same_v<T, float> && (d >= 32) && (nLocalTrials >= kNrF);
    if (useGemmScoring) {
      T *xNormsData = m_xNormsSq.data();
      for (std::size_t i = 0; i < n; ++i) {
        xNormsData[i] = math::detail::sqNormRow<T, Layout::Contig>(X, i);
      }
    }

    // Step 1: first centroid uniformly. randUniformU64 is the deterministic primitive; the
    // modulo map carries a tiny bias for very large n but is the standard sklearn convention.
    const auto first = static_cast<std::size_t>(math::randUniformU64(rng) % n);
    std::memcpy(centroidsData, xData + (first * d), d * sizeof(T));

    for (std::size_t i = 0; i < n; ++i) {
      minSq[i] = detail::sqEuclideanRowPtr(xData + (i * d), centroidsData, d);
    }

    if (k == 1) {
      return;
    }

    std::vector<std::size_t> candidates(nLocalTrials, 0);
    std::vector<T> scores(nLocalTrials, T{0});

    for (std::size_t c = 1; c < k; ++c) {
      // Build the cumulative-distance array in a single pass. The cum array is only used as a
      // probability normalizer, so no Kahan compensation.
      T runningSum = T{0};
      for (std::size_t i = 0; i < n; ++i) {
        runningSum += minSq[i];
        cumDistSq[i] = runningSum;
      }
      const T total = runningSum;

      // Degenerate guard: when every chosen centroid coincides with every remaining point the
      // total collapses to ~0; pick the next centroid uniformly so the routine cannot stall.
      if (!(total > T{0})) {
        const auto pick = static_cast<std::size_t>(math::randUniformU64(rng) % n);
        std::memcpy(centroidsData + (c * d), xData + (pick * d), d * sizeof(T));
        for (std::size_t i = 0; i < n; ++i) {
          const T cand = detail::sqEuclideanRowPtr(xData + (i * d), centroidsData + (c * d), d);
          if (cand < minSq[i]) {
            minSq[i] = cand;
          }
        }
        continue;
      }

      // Draw nLocalTrials candidates by inverse-CDF sampling on the cumulative array via
      // std::upper_bound. Determinism: an identical seed over identical input produces
      // identical candidate sets because the @c randUnit draw sequence is the same and the
      // cum array is identical.
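      // Worked example of one draw: minSq = {1, 2, 3} gives cumDistSq = {1, 3, 6}; a draw
      // u = 2.5 lands at upper_bound index 1, i.e. point 1 is chosen whenever u falls in
      // (1, 3], with probability 2/6 -- proportional to its D^2 mass, the k-means++ rule.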
      const T *cumBegin = cumDistSq;
      const T *cumEnd = cumDistSq + n;
      for (std::size_t t = 0; t < nLocalTrials; ++t) {
        const T u = math::randUnit<T>(rng) * total;
        const T *it = std::upper_bound(cumBegin, cumEnd, u);
        const std::size_t pick = (it == cumEnd) ? (n - 1) : static_cast<std::size_t>(it - cumBegin);
        candidates[t] = pick;
      }

      // Pack the L candidate rows into a contiguous (L, d) buffer so the batched scoring kernel
      // can stream x once across L accumulators. The L*d pack is negligible against the n-pass
      // scoring it amortizes.
      for (std::size_t t = 0; t < nLocalTrials; ++t) {
        std::memcpy(candRowsData + (t * d), xData + (candidates[t] * d), d * sizeof(T));
      }

      for (std::size_t t = 0; t < nLocalTrials; ++t) {
        scores[t] = T{0};
      }
      constexpr std::size_t kMaxLocalTrials = 32;
      CLUSTERING_ALWAYS_ASSERT(nLocalTrials <= kMaxLocalTrials);

      const std::size_t transposedWidth = detail::greedyKmppTransposedWidth(nLocalTrials);
      bool scoredViaTransposed = false;
#ifdef CLUSTERING_USE_AVX2
      // Low-d hot path: at d <= kAvx2Lanes the (L, d) row-batched kernel either falls into the
      // scalar K-tail (d < 8) or pays @c L horizontal-sum reductions for one K-iter of work
      // (d == 8). The transposed `(d, W)` layout puts the same-feature components of every
      // candidate in consecutive 8-lane YMM registers, so each broadcast-of-x[k] + FMA pair
      // folds 8 (or 16, for the 16-lane unroll) distances at once.
      if constexpr (std::is_same_v<T, float>) {
        if (d > 0 && d <= math::detail::kAvx2Lanes<float>) {
          for (std::size_t kk = 0; kk < d; ++kk) {
            float *dstK = candRowsTData + (kk * transposedWidth);
            for (std::size_t t = 0; t < nLocalTrials; ++t) {
              dstK[t] = candRowsData[(t * d) + kk];
            }
            for (std::size_t t = nLocalTrials; t < transposedWidth; ++t) {
              dstK[t] = 0.0F;
            }
          }
          if (transposedWidth == 16) {
            __m256 scoresLoAcc = _mm256_setzero_ps();
            __m256 scoresHiAcc = _mm256_setzero_ps();
            for (std::size_t i = 0; i < n; ++i) {
              const float *xi = xData + (i * d);
              const __m256 miVec = _mm256_set1_ps(minSq[i]);
              float *dstRow = candDistSqData + (i * transposedWidth);
              detail::sqEuclideanRowAgainst16Transposed(xi, candRowsTData, d, dstRow);
              const __m256 dLo = _mm256_loadu_ps(dstRow);
              const __m256 dHi = _mm256_loadu_ps(dstRow + 8);
              scoresLoAcc = _mm256_add_ps(scoresLoAcc, _mm256_min_ps(dLo, miVec));
              scoresHiAcc = _mm256_add_ps(scoresHiAcc, _mm256_min_ps(dHi, miVec));
            }
            std::array<float, 16> tmp{};
            _mm256_storeu_ps(tmp.data(), scoresLoAcc);
            _mm256_storeu_ps(tmp.data() + 8, scoresHiAcc);
            for (std::size_t t = 0; t < nLocalTrials; ++t) {
              scores[t] = tmp[t];
            }
          } else if (transposedWidth == 8) {
            __m256 scoresAcc = _mm256_setzero_ps();
            for (std::size_t i = 0; i < n; ++i) {
              const float *xi = xData + (i * d);
              const __m256 miVec = _mm256_set1_ps(minSq[i]);
              float *dstRow = candDistSqData + (i * transposedWidth);
              detail::sqEuclideanRowAgainst8Transposed(xi, candRowsTData, d, dstRow);
              const __m256 dv = _mm256_loadu_ps(dstRow);
              scoresAcc = _mm256_add_ps(scoresAcc, _mm256_min_ps(dv, miVec));
            }
            std::array<float, 8> tmp{};
            _mm256_storeu_ps(tmp.data(), scoresAcc);
            for (std::size_t t = 0; t < nLocalTrials; ++t) {
              scores[t] = tmp[t];
            }
          } else {
            // Generic chunked path for L > 16 (very high k). Walk the transposed layout 8 lanes
            // at a time so each chunk stays on the fully unrolled 8-wide kernel.
            for (std::size_t i = 0; i < n; ++i) {
              const float *xi = xData + (i * d);
              const float mi = minSq[i];
              float *dstRow = candDistSqData + (i * transposedWidth);
              for (std::size_t base = 0; base < transposedWidth; base += 8) {
                detail::sqEuclideanRowAgainst8TransposedStrided(xi, candRowsTData + base, d,
                                                                transposedWidth, dstRow + base);
              }
              for (std::size_t t = 0; t < nLocalTrials; ++t) {
                scores[t] += (dstRow[t] < mi) ? dstRow[t] : mi;
              }
            }
          }
          scoredViaTransposed = true;
        }
      }
#endif

      if (!scoredViaTransposed) {
        // GEMM-based batch distance for moderate-to-high d: compute X * cand^T via the core
        // GEMM (alpha=-2, beta=0), then add pre-computed per-row ||x||^2 and per-candidate
        // ||c||^2 in one min+sum fold. BLAS-style GEMM is the decisive win at d >= ~16 where
        // the per-row streaming kernel bottlenecks on L1/L2 bandwidth.
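        // The identity behind the fold: ||x - c||^2 = ||x||^2 - 2*x.c + ||c||^2, so with
        // the GEMM producing D = -2 * X * cand^T each distance is D[i][t] + xNorms[i] +
        // candNorms[t]; the clamp-at-zero below soaks up the small negative residue that
        // floating-point cancellation leaves when x is nearly equal to a candidate.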
        if (useGemmScoring) {
          auto candView = NDArray<T, 2, Layout::Contig>::borrow(candRowsData, {nLocalTrials, d});
          auto xView = NDArray<T, 2, Layout::Contig>::borrow(const_cast<T *>(xData), {n, d});
          auto distsView = NDArray<T, 2>::borrow(m_distsFlat.data(), {n, nLocalTrials});
          auto candT = candView.t();
          // Direct gemmRunReference with caller-owned scratch so the seeder's per-pick GEMM
          // leaves the shape-stable allocation footprint in place (no per-call arena alloc).
          const auto xDesc = ::clustering::detail::describeMatrix(xView);
          const auto candDesc = ::clustering::detail::describeMatrix(candT);
          auto distsDesc = ::clustering::detail::describeMatrixMut(distsView);
          math::detail::gemmRunReference<T>(xDesc, candDesc, distsDesc, T{-2}, T{0},
                                            m_gemmApArena.data(), m_gemmBpArena.data(), pool);
          // Candidate norms once per pick.
          T *candNorms = m_candNormsSq.data();
          for (std::size_t t = 0; t < nLocalTrials; ++t) {
            candNorms[t] = math::detail::sqNormRow<T, Layout::Contig>(candView, t);
          }
          const T *xNorms = m_xNormsSq.data();
          const T *distsFlat = m_distsFlat.data();
          for (std::size_t i = 0; i < n; ++i) {
            const T mi = minSq[i];
            const T xn = xNorms[i];
            const T *distRowI = distsFlat + (i * nLocalTrials);
            T *dstRow = candDistSqData + (i * transposedWidth);
            for (std::size_t t = 0; t < nLocalTrials; ++t) {
              T v = distRowI[t] + xn + candNorms[t];
              if (v < T{0}) {
                v = T{0};
              }
              dstRow[t] = v;
              scores[t] += (v < mi) ? v : mi;
            }
          }
        } else {
          // Fused scoring: for each x row, compute L distances against the candidate pack and
          // update L parallel running sums in one pass. The single-x-stream path is the load-
          // bearing win at envelope shapes where n*d far exceeds L2 -- one stream pays the
          // memory-bandwidth toll once rather than L times over. Parallelized over X rows via
          // per-worker score slabs reduced at the end; candDistSqData writes are row-local so
          // no aliasing across workers.
          const bool willParallelize = pool.shouldParallelize(n, 1024, 2) && pool.pool != nullptr;
          bool scoredViaSoa = false;
#ifdef CLUSTERING_USE_AVX2
          if constexpr (std::is_same_v<T, float>) {
            // SoA 8-row M-tile kernel: streams X AoS through an in-register 8x8 transpose so 8
            // rows' features land in feature-major YMM accumulators, folds L distances per row
            // without per-row horizontal reductions, writes the per-(row, cand) distances to
            // @c outDist, and accumulates min-capped scores. The kernel handles arbitrary row
            // counts, so per-worker row ranges slot in under the same parallel fan-out that
            // feeds the fallback path.
            const bool soaEligible = (d >= 8) && (nLocalTrials >= 1) && (nLocalTrials <= 6);
            if (soaEligible) {
              auto soaRange = [&](std::size_t lo, std::size_t hi, T *localScores) noexcept {
                const std::size_t rangeN = hi - lo;
                const float *xSlice = xData + (lo * d);
                const float *minSlice = minSq + lo;
                float *distSlice = candDistSqData + (lo * transposedWidth);
                switch (nLocalTrials) {
                case 1:
                  math::detail::kmppScoreSoaRowsAvx2F32<1>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                case 2:
                  math::detail::kmppScoreSoaRowsAvx2F32<2>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                case 3:
                  math::detail::kmppScoreSoaRowsAvx2F32<3>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                case 4:
                  math::detail::kmppScoreSoaRowsAvx2F32<4>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                case 5:
                  math::detail::kmppScoreSoaRowsAvx2F32<5>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                case 6:
                  math::detail::kmppScoreSoaRowsAvx2F32<6>(xSlice, rangeN, d, candRowsData,
                                                           minSlice, distSlice, transposedWidth,
                                                           localScores);
                  break;
                default:
                  break;
                }
              };

              if (willParallelize) {
                const std::size_t workers = pool.workerCount();
                T *localScores = m_localScores.data();
                for (std::size_t e = 0; e < workers * nLocalTrials; ++e) {
                  localScores[e] = T{0};
                }
                pool.pool
                    ->submit_blocks(
                        std::size_t{0}, n,
                        [&](std::size_t lo, std::size_t hi) {
                          const std::size_t w = math::Pool::workerIndex();
                          soaRange(lo, hi, localScores + (w * nLocalTrials));
                        },
                        workers)
                    .wait();
                for (std::size_t w = 0; w < workers; ++w) {
                  const T *row = localScores + (w * nLocalTrials);
                  for (std::size_t t = 0; t < nLocalTrials; ++t) {
                    scores[t] += row[t];
                  }
                }
              } else {
                soaRange(0, n, scores.data());
              }
              scoredViaSoa = true;
            }
          }
#endif
          if (!scoredViaSoa) {
            auto scanRange = [&](std::size_t lo, std::size_t hi, T *localScores) noexcept {
              std::array<T, 32> distRowLocal{};
              for (std::size_t i = lo; i < hi; ++i) {
                const T *xi = xData + (i * d);
                const T mi = minSq[i];
                detail::sqEuclideanRowToBatch<T>(xi, candRowsData, nLocalTrials, d,
                                                 distRowLocal.data());
                T *dstRow = candDistSqData + (i * transposedWidth);
                for (std::size_t t = 0; t < nLocalTrials; ++t) {
                  dstRow[t] = distRowLocal[t];
                  localScores[t] += (distRowLocal[t] < mi) ? distRowLocal[t] : mi;
                }
              }
            };

            if (willParallelize) {
              const std::size_t workers = pool.workerCount();
              T *localScores = m_localScores.data();
              for (std::size_t e = 0; e < workers * nLocalTrials; ++e) {
                localScores[e] = T{0};
              }
              pool.pool
                  ->submit_blocks(
                      std::size_t{0}, n,
                      [&](std::size_t lo, std::size_t hi) {
                        const std::size_t w = math::Pool::workerIndex();
                        scanRange(lo, hi, localScores + (w * nLocalTrials));
                      },
                      workers)
                  .wait();
              for (std::size_t w = 0; w < workers; ++w) {
                const T *row = localScores + (w * nLocalTrials);
                for (std::size_t t = 0; t < nLocalTrials; ++t) {
                  scores[t] += row[t];
                }
              }
            } else {
              scanRange(0, n, scores.data());
            }
          }
        }
      }

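      // scores[t] now equals sum_i min(minSq[i], dist^2(x_i, cand_t)) -- the clustering
      // potential if candidate t were committed -- so the argmin below is the greedy
      // k-means++ choice: the trial centroid that lowers the potential the most.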
      std::size_t bestT = 0;
      T bestScore = scores[0];
      for (std::size_t t = 1; t < nLocalTrials; ++t) {
        if (scores[t] < bestScore) {
          bestScore = scores[t];
          bestT = t;
        }
      }
      const std::size_t bestCandidate = candidates[bestT];

      // Commit best candidate: copy its row into outCentroids and refresh @c minSq from the
      // cached candidate-distance plane, skipping a fresh O(n*d) scan against the winner row.
      const T *winnerRow = xData + (bestCandidate * d);
      std::memcpy(centroidsData + (c * d), winnerRow, d * sizeof(T));
      for (std::size_t i = 0; i < n; ++i) {
        const T cand = candDistSqData[(i * transposedWidth) + bestT];
        if (cand < minSq[i]) {
          minSq[i] = cand;
        }
      }
    }
  }

private:
  void ensureShape(std::size_t n, std::size_t d, std::size_t L, std::size_t workers) {
    const std::size_t w = detail::greedyKmppTransposedWidth(L == 0 ? std::size_t{1} : L);
    if (m_candRows.dim(0) != L || m_candRows.dim(1) != d) {
      m_candRows = NDArray<T, 2, Layout::Contig>({L, d});
    }
    if (m_candRowsT.dim(0) != d || m_candRowsT.dim(1) != w) {
      m_candRowsT = NDArray<T, 2, Layout::Contig>({d == 0 ? std::size_t{1} : d, w});
    }
    if (m_candDistSq.dim(0) != n || m_candDistSq.dim(1) != w) {
      m_candDistSq = NDArray<T, 2, Layout::Contig>({n == 0 ? std::size_t{1} : n, w});
    }
    if (m_cumDistSq.dim(0) != n) {
      m_cumDistSq = NDArray<T, 1>({n});
    }
    if (m_minSq.dim(0) != n) {
      m_minSq = NDArray<T, 1>({n});
    }
    // GEMM-scoring-only scratch (distsFlat, xNormsSq, candNormsSq, gemmApArena, gemmBpArena).
    // The GEMM path fires only for float at `d >= 32` && L >= kKernelNr<float>; outside that
    // envelope we keep unit-sized placeholders so @c .data() stays dereferenceable without
    // paying the @c kKc*kNc envelope tax (@c Bp alone is several MB).
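    // (Scale illustration with hypothetical tile constants: kKc = 512 and kNc = 3072 would
    // make the Bp arena 512 * 3072 floats = 6 MiB; the real kKc/kNc values come from the
    // GEMM kernel headers. Hence the unit-sized placeholders whenever GEMM scoring is off.)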
    constexpr std::size_t kNrForGemm = math::detail::kKernelNr<float>;
    const bool gemmScoringUsed = std::is_same_v<T, float> && (d >= 32) && (L >= kNrForGemm);
    const std::size_t nSafe = (n == 0) ? std::size_t{1} : n;
    const std::size_t lSafe = (L == 0) ? std::size_t{1} : L;
    const std::size_t distsFlatRows = gemmScoringUsed ? nSafe : std::size_t{1};
    const std::size_t distsFlatCols = gemmScoringUsed ? lSafe : std::size_t{1};
    if (m_distsFlat.dim(0) != distsFlatRows || m_distsFlat.dim(1) != distsFlatCols) {
      m_distsFlat = NDArray<T, 2, Layout::Contig>({distsFlatRows, distsFlatCols});
    }
    const std::size_t xNormsLen = gemmScoringUsed ? nSafe : std::size_t{1};
    if (m_xNormsSq.dim(0) != xNormsLen) {
      m_xNormsSq = NDArray<T, 1>({xNormsLen});
    }
    const std::size_t candNormsLen = gemmScoringUsed ? lSafe : std::size_t{1};
    if (m_candNormsSq.dim(0) != candNormsLen) {
      m_candNormsSq = NDArray<T, 1>({candNormsLen});
    }
    const std::size_t workersClamped = workers == 0 ? std::size_t{1} : workers;
    // @c gemmRunReference parallelizes the Mc-tile loop, with each worker owning a per-worker
    // slice of the A-pack arena at offset `(worker * kMc * kKc)`. Sizing the arena for just
    // one worker was fine while the seeder's envelope kept the GEMM path off (k=16, L=4 fell
    // into the SoA kernel), but the Elkan-eligible shapes push L >= kNrF where the GEMM scoring
    // activates and multiple workers collide into the same slice.
    const std::size_t apSize = gemmScoringUsed
                                   ? (workersClamped * math::detail::kMc<T> * math::detail::kKc<T>)
                                   : std::size_t{1};
    const std::size_t bpSize =
        gemmScoringUsed ? (math::detail::kKc<T> * math::detail::kNc<T>) : std::size_t{1};
    if (m_gemmApArena.dim(0) != apSize) {
      m_gemmApArena = NDArray<T, 1>({apSize});
    }
    if (m_gemmBpArena.dim(0) != bpSize) {
      m_gemmBpArena = NDArray<T, 1>({bpSize});
    }
    const std::size_t lsLen = workersClamped * (L == 0 ? std::size_t{1} : L);
    if (m_localScores.dim(0) != lsLen) {
      m_localScores = NDArray<T, 1>({lsLen});
    }
  }

  NDArray<T, 2, Layout::Contig> m_candRows;
  NDArray<T, 2, Layout::Contig> m_candRowsT;
  NDArray<T, 2, Layout::Contig> m_candDistSq;
  NDArray<T, 1> m_cumDistSq;
  NDArray<T, 1> m_minSq;
  NDArray<T, 2, Layout::Contig> m_distsFlat;
  NDArray<T, 1> m_xNormsSq;
  NDArray<T, 1> m_candNormsSq;
  NDArray<T, 1> m_gemmApArena;
  NDArray<T, 1> m_gemmBpArena;
  NDArray<T, 1> m_localScores;
};

} // namespace clustering::kmeans
Symbol cross-reference (consolidated from the generated page's hover tooltips):

CLUSTERING_ALWAYS_ASSERT(cond): release-active assertion; evaluates cond in every build configuration.
NDArray<T, N>: multidimensional array of a fixed number of dimensions N and element type T (ndarray.h:136).
NDArray::dim(index): returns the size of a specific dimension (ndarray.h:461).
NDArray::borrow(ptr, shape): borrows a contiguous buffer as an NDArray without taking ownership (ndarray.h:570).
NDArray::data(): read-only access to the internal data array (ndarray.h:503).
NDArray::isMutable(): reports whether writes through operator(), Accessor, or flatIndex are allowed (ndarray.h:488).
GreedyKmppSeeder::run(X, k, seed, pool, outCentroids): seed k centroids from X into outCentroids.
detail::greedyKmppLocalTrials(k): the local-trials count used by greedy k-means++.
detail::greedyKmppTransposedWidth(L): rounds L up to the nearest multiple of 8 used by the transposed scoring layout.
detail::sqEuclideanRowToBatch(x, candData, L, d, out): squared Euclidean distance from one x row to a batch of L candidate rows.
detail::sqEuclideanRowToBatchAvx2(x, candData, L, d, out): L squared distances against an (L, d) row-batched candidate layout in a single stream over x.
detail::sqEuclideanRowToBatchAvx2Fixed<B>(x, candData, d, out): compile-time batched scoring kernel; streams x once across B parallel AVX2 accumulators.
detail::sqEuclideanRowAgainst8Transposed(x, candData, d, out): 8 squared distances against a (d, 8) transposed candidate layout in one streaming pass over x.
detail::sqEuclideanRowAgainst16Transposed(x, candData, d, out): two 8-way squared-distance slabs against a (d, 16) transposed candidate layout in one stream.
detail::sqEuclideanRowAgainst8TransposedStrided(x, candData, d, rowStride, out): one 8-way squared-distance slab against a (d, W) transposed candidate layout with an explicit row stride.
math::detail::horizontalSumAvx2(v): pairwise.h:41.
math::detail::kAvx2Lanes: pairwise.h:96.
math::detail::sqNormRow(X, i): pairwise.h:154.
math::randUnit(rng): uniform variate in the half-open unit interval [0, 1) (rng.h:152).
math::randUniformU64(rng): 64-bit unsigned integer drawn uniformly from the full u64 range (rng.h:139).
math::Pool: thin injection wrapper around a BS::light_thread_pool (thread.h:63).
math::Pool::workerIndex(): stable index of the calling worker thread within the owning pool (thread.h:82).
math::Pool::workerCount(): number of worker threads available, or 1 in serial mode (thread.h:72).
math::Pool::pool: underlying pool, or nullptr to force serial execution (thread.h:65).
math::Pool::shouldParallelize(totalWork, minChunk, minTasksPerWorker = 2): decides whether totalWork warrants parallel dispatch (thread.h:98).
math::pcg64: 128-bit state for the PCG-XSL-RR 64-bit output generator (Melissa O'Neill) (rng.h:30).
math::pcg64::seed(seedValue, stream = 0): initializes the generator per PCG's canonical seeding procedure (rng.h:46).