// --- k-means++ centroid seeding with local trials (greedy k-means++) -------
// Each new centroid is chosen by D^2-weighted sampling: nLocalTrials candidate
// rows are drawn proportionally to the current squared distance to the nearest
// chosen centroid, each candidate is scored by the total potential
// sum_i min(dist(x_i, cand), minSq[i]) it would leave, and the lowest-scoring
// candidate is committed.  Several scoring back ends are selected at runtime:
// AVX2 over transposed candidates, GEMM-based, AVX2 SoA, and a scalar scan.

// Problem dimensions: n data rows (points) x d columns (features).
339 const std::size_t n = X.
dim(0);
340 const std::size_t d = X.
dim(1);
// Size the member scratch buffers for this (n, d, trials, workers) shape.
351 ensureShape(n, d, nLocalTrials, pool.
workerCount());
// Raw pointers into the input, output, and scratch arenas; the hot loops
// below index these directly rather than going through the matrix wrappers.
356 const T *xData = X.
data();
357 T *centroidsData = outCentroids.
data();
// minSq[i]: squared distance from point i to its nearest chosen centroid.
358 T *minSq = m_minSq.data();
// Gathered candidate rows, laid out nLocalTrials x d, unit stride per row.
359 T *candRowsData = m_candRows.data();
// Inclusive prefix sums of minSq, searched during D^2 sampling.
360 T *cumDistSq = m_cumDistSq.data();
// Per-point distance to each candidate; row stride is transposedWidth.
361 T *candDistSqData = m_candDistSq.data();
362#ifdef CLUSTERING_USE_AVX2
// Feature-major (transposed) copy of the candidate rows for the AVX2 path.
363 T *candRowsTData = m_candRowsT.data();
// Route candidate scoring through the GEMM kernel only when the feature and
// trial counts are large enough to amortize its setup; kKernelNr<float> is
// the GEMM micro-kernel's register-block width.
369 constexpr std::size_t kNrF = math::detail::kKernelNr<float>;
370 const bool useGemmScoring = (d >= 32) && (nLocalTrials >= kNrF);
371 if (useGemmScoring) {
// Precompute per-row squared norms of X, consumed later by the expansion
// ||x - c||^2 = (-2 x.c) + ||x||^2 + ||c||^2.
372 T *xNormsData = m_xNormsSq.data();
373 for (std::size_t i = 0; i < n; ++i) {
// First centroid: copy row `first` of X verbatim.
// NOTE(review): `first` is chosen above this point — presumably a uniformly
// sampled row index; verify against the preceding code.
381 std::memcpy(centroidsData, xData + (first * d), d *
sizeof(T));
// Initialize every point's potential as its distance to the first centroid.
383 for (std::size_t i = 0; i < n; ++i) {
384 minSq[i] = detail::sqEuclideanRowPtr(xData + (i * d), centroidsData, d);
// Per-iteration scratch: sampled candidate row indices and their scores
// (score == total potential if that candidate were committed; lower wins).
391 std::vector<std::size_t> candidates(nLocalTrials, 0);
392 std::vector<T> scores(nLocalTrials, T{0});
// Choose the remaining k - 1 centroids one at a time.
394 for (std::size_t c = 1; c < k; ++c) {
// Build the inclusive prefix sum of minSq; `total` is the current potential.
398 for (std::size_t i = 0; i < n; ++i) {
399 runningSum += minSq[i];
400 cumDistSq[i] = runningSum;
402 const T total = runningSum;
// Degenerate case: no residual potential (this negated form also catches
// NaN).  NOTE(review): `pick` used below is assigned inside this branch
// above the memcpy — presumably an arbitrary/random fallback row; verify.
406 if (!(total > T{0})) {
408 std::memcpy(centroidsData + (c * d), xData + (pick * d), d *
sizeof(T));
// Keep minSq consistent with the fallback centroid before continuing.
409 for (std::size_t i = 0; i < n; ++i) {
410 const T cand = detail::sqEuclideanRowPtr(xData + (i * d), centroidsData + (c * d), d);
411 if (cand < minSq[i]) {
// D^2 sampling: map each threshold `u` through the prefix sums with
// upper_bound; the cumEnd clamp to n - 1 guards the u >= total edge case.
// NOTE(review): `u` is drawn per trial just above the search — presumably
// uniform in [0, total); confirm the RNG call.
421 const T *cumBegin = cumDistSq;
422 const T *cumEnd = cumDistSq + n;
423 for (std::size_t t = 0; t < nLocalTrials; ++t) {
425 const T *it = std::upper_bound(cumBegin, cumEnd, u);
426 const std::size_t pick = (it == cumEnd) ? (n - 1) :
static_cast<std::size_t
>(it - cumBegin);
427 candidates[t] = pick;
// Gather the sampled candidate rows contiguously so every scoring back end
// can stream them with unit stride.
433 for (std::size_t t = 0; t < nLocalTrials; ++t) {
434 std::memcpy(candRowsData + (t * d), xData + (candidates[t] * d), d *
sizeof(T));
// Reset the per-candidate score accumulators for this centroid iteration.
437 for (std::size_t t = 0; t < nLocalTrials; ++t) {
// Upper bound on trials supported by the fixed-size scratch paths below
// (see the std::array<T, 32> row buffer in the scalar scan).
440 constexpr std::size_t kMaxLocalTrials = 32;
// --- scoring back end 1: AVX2 over feature-major (transposed) candidates --
444 bool scoredViaTransposed =
false;
445#ifdef CLUSTERING_USE_AVX2
451 if constexpr (std::is_same_v<T, float>) {
// Transpose candidates so lane t of each feature-vector holds candidate t;
// lanes nLocalTrials..transposedWidth-1 are padded (value set in the loop
// at original line 458 — presumably a harmless filler for unused lanes).
453 for (std::size_t kk = 0; kk < d; ++kk) {
454 float *dstK = candRowsTData + (kk * transposedWidth);
455 for (std::size_t t = 0; t < nLocalTrials; ++t) {
456 dstK[t] = candRowsData[(t * d) + kk];
458 for (std::size_t t = nLocalTrials; t < transposedWidth; ++t) {
// Specialized path: 16 padded lanes -> two 8-float accumulators.
462 if (transposedWidth == 16) {
463 __m256 scoresLoAcc = _mm256_setzero_ps();
464 __m256 scoresHiAcc = _mm256_setzero_ps();
465 for (std::size_t i = 0; i < n; ++i) {
466 const float *xi = xData + (i * d);
// Broadcast the current potential of point i to all lanes so the vector
// min() implements min(dist(i, cand_t), minSq[i]) per candidate lane.
467 const __m256 miVec = _mm256_set1_ps(minSq[i]);
// dstRow also persists the per-candidate distances for the winner update
// at the bottom of the centroid loop.
468 float *dstRow = candDistSqData + (i * transposedWidth);
470 const __m256 dLo = _mm256_loadu_ps(dstRow);
471 const __m256 dHi = _mm256_loadu_ps(dstRow + 8);
472 scoresLoAcc = _mm256_add_ps(scoresLoAcc, _mm256_min_ps(dLo, miVec));
473 scoresHiAcc = _mm256_add_ps(scoresHiAcc, _mm256_min_ps(dHi, miVec));
// Spill the lane accumulators and fold the live lanes into scores[].
475 std::array<float, 16> tmp{};
476 _mm256_storeu_ps(tmp.data(), scoresLoAcc);
477 _mm256_storeu_ps(tmp.data() + 8, scoresHiAcc);
478 for (std::size_t t = 0; t < nLocalTrials; ++t) {
481 }
// Specialized path: 8 padded lanes -> one accumulator register.
else if (transposedWidth == 8) {
482 __m256 scoresAcc = _mm256_setzero_ps();
483 for (std::size_t i = 0; i < n; ++i) {
484 const float *xi = xData + (i * d);
485 const __m256 miVec = _mm256_set1_ps(minSq[i]);
486 float *dstRow = candDistSqData + (i * transposedWidth);
488 const __m256 dv = _mm256_loadu_ps(dstRow);
489 scoresAcc = _mm256_add_ps(scoresAcc, _mm256_min_ps(dv, miVec));
491 std::array<float, 8> tmp{};
492 _mm256_storeu_ps(tmp.data(), scoresAcc);
493 for (std::size_t t = 0; t < nLocalTrials; ++t) {
// Generic transposed path: compute distances 8 lanes at a time, then do the
// min/accumulate scalar per live candidate.
499 for (std::size_t i = 0; i < n; ++i) {
500 const float *xi = xData + (i * d);
501 const float mi = minSq[i];
502 float *dstRow = candDistSqData + (i * transposedWidth);
503 for (std::size_t base = 0; base < transposedWidth; base += 8) {
505 transposedWidth, dstRow + base);
507 for (std::size_t t = 0; t < nLocalTrials; ++t) {
508 scores[t] += (dstRow[t] < mi) ? dstRow[t] : mi;
512 scoredViaTransposed =
true;
517 if (!scoredViaTransposed) {
// --- scoring back end 2: GEMM-based -----------------------------------
// dists = -2 * X * Cand^T (alpha = -2, beta = 0); squared distances are then
// reconstructed per element as dist + ||x_i||^2 + ||cand_t||^2.
522 if (useGemmScoring) {
526 auto candT = candView.t();
529 const auto xDesc = ::clustering::detail::describeMatrix(xView);
530 const auto candDesc = ::clustering::detail::describeMatrix(candT);
531 auto distsDesc = ::clustering::detail::describeMatrixMut(distsView);
532 math::detail::gemmRunReference<T>(xDesc, candDesc, distsDesc, T{-2}, T{0},
533 m_gemmApArena.data(), m_gemmBpArena.data(), pool);
// Candidate squared norms (filled in the loop at original line 536).
535 T *candNorms = m_candNormsSq.data();
536 for (std::size_t t = 0; t < nLocalTrials; ++t) {
539 const T *xNorms = m_xNormsSq.data();
540 const T *distsFlat = m_distsFlat.data();
541 for (std::size_t i = 0; i < n; ++i) {
542 const T mi = minSq[i];
543 const T xn = xNorms[i];
// GEMM output row for point i (tight stride nLocalTrials), copied out to the
// persistent candDistSq buffer (stride transposedWidth).
544 const T *distRowI = distsFlat + (i * nLocalTrials);
545 T *dstRow = candDistSqData + (i * transposedWidth);
546 for (std::size_t t = 0; t < nLocalTrials; ++t) {
// v == squared distance from point i to candidate t (norm expansion).
547 T v = distRowI[t] + xn + candNorms[t];
552 scores[t] += (v < mi) ? v : mi;
// --- scoring back end 3: AVX2 structure-of-arrays kernels --------------
563 bool scoredViaSoa =
false;
564#ifdef CLUSTERING_USE_AVX2
565 if constexpr (std::is_same_v<T, float>) {
// The templated SoA kernels only cover 1..6 candidates and need d >= 8.
572 const bool soaEligible = (d >= 8) && (nLocalTrials >= 1) && (nLocalTrials <= 6);
// Score points [lo, hi) into localScores; each worker gets a private score
// slab so the parallel branch below is race-free.
574 auto soaRange = [&](std::size_t lo, std::size_t hi, T *localScores)
noexcept {
575 const std::size_t rangeN = hi - lo;
576 const float *xSlice = xData + (lo * d);
577 const float *minSlice = minSq + lo;
578 float *distSlice = candDistSqData + (lo * transposedWidth);
// Dispatch to the kernel specialization matching the trial count so the
// candidate loop is fully unrolled at compile time.
579 switch (nLocalTrials) {
581 math::detail::kmppScoreSoaRowsAvx2F32<1>(xSlice, rangeN, d, candRowsData,
582 minSlice, distSlice, transposedWidth,
586 math::detail::kmppScoreSoaRowsAvx2F32<2>(xSlice, rangeN, d, candRowsData,
587 minSlice, distSlice, transposedWidth,
591 math::detail::kmppScoreSoaRowsAvx2F32<3>(xSlice, rangeN, d, candRowsData,
592 minSlice, distSlice, transposedWidth,
596 math::detail::kmppScoreSoaRowsAvx2F32<4>(xSlice, rangeN, d, candRowsData,
597 minSlice, distSlice, transposedWidth,
601 math::detail::kmppScoreSoaRowsAvx2F32<5>(xSlice, rangeN, d, candRowsData,
602 minSlice, distSlice, transposedWidth,
606 math::detail::kmppScoreSoaRowsAvx2F32<6>(xSlice, rangeN, d, candRowsData,
607 minSlice, distSlice, transposedWidth,
// Parallel branch: zero one score slab per worker, score disjoint row
// ranges, then reduce the slabs into scores[].
// NOTE(review): `willParallelize` and `workers` are defined above this
// fragment — presumably derived from n and pool.workerCount(); verify.
615 if (willParallelize) {
617 T *localScores = m_localScores.data();
618 for (std::size_t e = 0; e < workers * nLocalTrials; ++e) {
619 localScores[e] = T{0};
624 [&](std::size_t lo, std::size_t hi) {
626 soaRange(lo, hi, localScores + (w * nLocalTrials));
630 for (std::size_t w = 0; w < workers; ++w) {
631 const T *row = localScores + (w * nLocalTrials);
632 for (std::size_t t = 0; t < nLocalTrials; ++t) {
// Serial branch: score the full range directly into scores[].
637 soaRange(0, n, scores.data());
// --- scoring back end 4: scalar fallback scan --------------------------
// Same contract as soaRange: per-point distances to all candidates go into a
// local row buffer (bounded by kMaxLocalTrials == 32), are persisted into
// candDistSq, and min(dist, minSq) is accumulated into localScores.
644 auto scanRange = [&](std::size_t lo, std::size_t hi, T *localScores)
noexcept {
645 std::array<T, 32> distRowLocal{};
646 for (std::size_t i = lo; i < hi; ++i) {
647 const T *xi = xData + (i * d);
648 const T mi = minSq[i];
650 distRowLocal.data());
651 T *dstRow = candDistSqData + (i * transposedWidth);
652 for (std::size_t t = 0; t < nLocalTrials; ++t) {
653 dstRow[t] = distRowLocal[t];
654 localScores[t] += (distRowLocal[t] < mi) ? distRowLocal[t] : mi;
// Same parallel/serial split and slab reduction as the SoA path above.
659 if (willParallelize) {
661 T *localScores = m_localScores.data();
662 for (std::size_t e = 0; e < workers * nLocalTrials; ++e) {
663 localScores[e] = T{0};
668 [&](std::size_t lo, std::size_t hi) {
670 scanRange(lo, hi, localScores + (w * nLocalTrials));
674 for (std::size_t w = 0; w < workers; ++w) {
675 const T *row = localScores + (w * nLocalTrials);
676 for (std::size_t t = 0; t < nLocalTrials; ++t) {
681 scanRange(0, n, scores.data());
// Select the trial with the minimal resulting potential (ties keep the
// earliest trial, since only strictly smaller scores replace bestScore).
687 std::size_t bestT = 0;
688 T bestScore = scores[0];
689 for (std::size_t t = 1; t < nLocalTrials; ++t) {
690 if (scores[t] < bestScore) {
691 bestScore = scores[t];
695 const std::size_t bestCandidate = candidates[bestT];
// Commit the winning row as centroid c, then refresh minSq from the cached
// per-candidate distances (column bestT of candDistSq) — no recomputation.
699 const T *winnerRow = xData + (bestCandidate * d);
700 std::memcpy(centroidsData + (c * d), winnerRow, d *
sizeof(T));
701 for (std::size_t i = 0; i < n; ++i) {
702 const T cand = candDistSqData[(i * transposedWidth) + bestT];
703 if (cand < minSq[i]) {