// Compile-time guard: this fused-GEMM Lloyd k-means engine is only
// instantiable for single- or double-precision element types.
100 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
101 "LloydFusedGemm<T> requires T to be float or double");
// Constructor member-initializer list (the constructor's signature line,
// original line 103, is not visible in this excerpt).  Every scratch
// NDArray member is constructed with a zero/empty shape; the real sizes
// are allocated lazily by ensureShape() on the first run.
104 : m_centroidsOld({0, 0}), m_cSqNorms({0}), m_sums({0, 0}), m_counts({0}), m_minDistSq({0}),
105 m_shiftSq({0}), m_partialSums({0}), m_partialComps({0}), m_partialCounts({0}),
106 m_foldComp({0}), m_packedB({0}), m_packedCSqNorms({0}), m_distsChunk({0, 0}),
107 m_gemmApArena({0}), m_xNormsSq({0}), m_varSum({0}), m_varSumSq({0}), m_u({0}), m_l({0}),
108 m_shiftEuclidean({0}), m_halfDistToNearestOther({0}), m_elkanBounds({0, 0}),
109 m_centerDist({0, 0}) {}
// Build-time override for the sample-count threshold; judging by the name
// and by run()'s choice between scatterAndFoldKahan / scatterAndFoldPlain,
// this presumably selects when Kahan-compensated accumulation kicks in —
// the selecting condition itself is on lines not visible here.  Only the
// #ifdef branch is visible; the default (#else) value is not shown.
111#ifdef CLUSTERING_KMEANS_KAHAN_N_THRESHOLD
120 static constexpr std::size_t
kahanNThreshold = CLUSTERING_KMEANS_KAHAN_N_THRESHOLD;
// Main Lloyd iteration loop.  The head of the signature (return type, name,
// and the X / centroids / outLabels parameters) falls on lines not visible
// in this excerpt.  Visible tail parameters:
//   k            — number of clusters
//   maxIter      — iteration cap
//   tol          — tolerance, made scale-invariant below by multiplying
//                  with the mean per-column variance of X
//   pool         — worker pool for the parallel phases
//   outConverged — receives the convergence flag before returning
147 std::size_t k, std::size_t maxIter, T tol,
math::Pool pool,
149 bool &outConverged) {
150 const std::size_t n = X.
dim(0);
151 const std::size_t d = X.
dim(1);
// Degenerate-input early-out (the branch body is on hidden lines).
159 if (n == 0 || d == 0) {
166 const std::size_t workerCount = pool.
workerCount();
// (Re)allocate all scratch buffers for the current (n, d, k, workers).
167 ensureShape(n, d, k, workerCount);
// Convergence threshold: tol relative to the data's mean column variance,
// so the stopping rule is invariant to uniform rescaling of X.
175 const T shiftSqThreshold = tol * meanColumnVariance(X);
// Per-row precomputation (loop body hidden; presumably fills m_xNormsSq
// with squared row norms — TODO confirm against the full file).
181 for (std::size_t i = 0; i < n; ++i) {
185 refreshCentroidSqNorms(centroids);
187 std::size_t iter = 0;
188 bool converged =
false;
// Elkan eligibility: bounded k, bounded O(n*k) bound storage, k >= 2.
// (The left-hand side of this assignment is on a hidden line.)
201 (k <= kElkanMaxK) && (n * k <= kElkanNKLimit) && (k >= 2);
203 while (iter < maxIter) {
// After iteration 0, prefer the bound-based assignments (Hamerly first,
// then Elkan); otherwise run the direct/fused assignment.
204 if (hamerlyEligible && iter > 0) {
205 runHamerlyAssignment(X, centroids, outLabels, pool);
206 }
else if (elkanEligible && iter > 0) {
207 runElkanAssignment(X, centroids, outLabels, pool);
213 runAssignment(X, centroids, outLabels, pool);
// On iteration 0 the AVX2 fused-argmin path does not produce Hamerly
// bounds, so seed u/l from the labels it just wrote.
214 if (hamerlyEligible && iter == 0 && assignmentUsesFusedArgmin(X, centroids)) {
215 seedHamerlyBoundsFromLabels(X, centroids, outLabels, pool);
// Snapshot the centroids before the update step (used for shift tracking).
219 std::memcpy(m_centroidsOld.data(), centroids.
data(),
220 centroids.
dim(0) * centroids.
dim(1) *
sizeof(T));
// Accumulate per-cluster sums/counts; the condition choosing between the
// Kahan-compensated and plain variants is on hidden lines (presumably the
// kahanNThreshold test — TODO confirm).
223 scatterAndFoldKahan(X, outLabels, k, pool);
225 scatterAndFoldPlain(X, outLabels, k, pool);
// Reseed clusters that ended up empty before dividing sums by counts.
234 (void)::clustering::kmeans::detail::reseedEmptyClusters<T>(X, centroids, m_sums, m_counts,
237 finalizeMeans(centroids);
238 refreshCentroidSqNorms(centroids);
// Converged when the Kahan-summed total squared centroid shift falls
// below the scaled tolerance.
241 const T totalShift = ::clustering::kmeans::detail::totalShiftSqKahan<T>(m_shiftSq);
244 if (totalShift <= shiftSqThreshold) {
// Final assignment pass so labels / m_minDistSq match the final centroids.
// NOTE(review): the bound-based paths are called here with a default
// math::Pool{} rather than `pool` — looks intentional (serial finish?) but
// worth confirming against the surrounding code.
256 if (hamerlyEligible && iter > 0) {
257 runHamerlyAssignment(X, centroids, outLabels,
math::Pool{});
258 }
else if (elkanEligible && iter > 0) {
259 runElkanAssignment(X, centroids, outLabels,
math::Pool{});
261 runAssignment(X, centroids, outLabels, pool);
// If the chosen assignment path does not write exact squared distances,
// recompute them directly before summing the inertia.
263 if (!assignmentProducesDirectMinDistSq(X, centroids)) {
264 recomputeMinDistSqDirect(X, centroids, outLabels, pool);
// Kahan-compensated inertia accumulation in double precision
// (sum/comp are declared on hidden lines).
271 for (std::size_t i = 0; i < n; ++i) {
272 const auto addend =
static_cast<double>(m_minDistSq(i));
273 const double y = addend - comp;
274 const double t = sum + y;
275 comp = (t - sum) - y;
281 outConverged = converged;
// meanColumnVariance: mean of the per-column (population) variances of X,
// used by run() to scale the convergence tolerance.  The signature line
// (original line 293) is not visible in this excerpt.
294 const std::size_t n = X.
dim(0);
295 const std::size_t d = X.
dim(1);
// Degenerate input (branch body hidden; presumably returns 0).
296 if (n == 0 || d == 0) {
299 const T *xData = X.
data();
// Lazily (re)allocate the two d-length accumulator columns.
305 if (m_varSum.
dim(0) != d) {
306 m_varSum = NDArray<T, 1>({d});
307 m_varSumSq = NDArray<T, 1>({d});
309 T *colSum = m_varSum.data();
310 T *colSumSq = m_varSumSq.data();
// Zero the accumulators (loop body on a hidden line).
311 for (std::size_t t = 0; t < d; ++t) {
// Single pass over rows: per-column sum and sum-of-squares.
315 for (std::size_t i = 0; i < n; ++i) {
316 const T *row = xData + (i * d);
317 math::detail::columnwiseAccumSumSq<T>(row, d, colSum, colSumSq);
319 const auto nInv =
static_cast<T
>(1) /
static_cast<T
>(n);
// Per-column population variance E[x^2] - (E[x])^2, summed into acc
// (acc is declared on a hidden line).
321 for (std::size_t t = 0; t < d; ++t) {
322 const T mean = colSum[t] * nInv;
323 acc += (colSumSq[t] * nInv) - (mean * mean);
// Average over the d columns.
325 return acc /
static_cast<T
>(d);
// (Re)allocate every scratch buffer for problem shape (n, d, k) and the
// given worker count; no-op when nothing changed since the previous call.
328 void ensureShape(std::size_t n, std::size_t d, std::size_t k, std::size_t workerCount) {
329 const bool shapeChanged = (n != m_n) || (d != m_d) || (k != m_k);
330 const bool workerChanged = (workerCount != m_workerCount);
// Fast path: everything already sized (branch body hidden).
331 if (!shapeChanged && !workerChanged) {
// Number of per-worker partial-reduction blocks; at least 1 when serial.
337 const std::size_t blocks = workerCount == 0 ? std::size_t{1} : workerCount;
// Core per-shape buffers (needsChunk/chunkCap are computed on hidden lines).
340 m_centroidsOld = NDArray<T, 2, Layout::Contig>({k, d});
341 m_cSqNorms = NDArray<T, 1>({k});
342 m_sums = NDArray<T, 2, Layout::Contig>({k, d});
343 m_counts = NDArray<std::int32_t, 1>({k});
344 m_minDistSq = NDArray<T, 1>({n});
345 m_xNormsSq = NDArray<T, 1>({n});
346 m_shiftSq = NDArray<T, 1>({k});
347 m_foldComp = NDArray<T, 1>({k * d});
// Hamerly buffers: per-point upper/lower bounds, per-centroid Euclidean
// shifts, and half the distance to each centroid's nearest other centroid.
351 m_u = NDArray<T, 1>({n});
352 m_l = NDArray<T, 1>({n});
353 m_shiftEuclidean = NDArray<T, 1>({k});
354 m_halfDistToNearestOther = NDArray<T, 1>({k});
// Elkan buffers are only allocated under the memory caps; otherwise they
// collapse to empty, which disables runElkanAssignment for this shape.
358 if (n * k <= kElkanNKLimit && k <= kElkanMaxK) {
359 m_elkanBounds = NDArray<T, 2, Layout::Contig>({n, k});
360 m_centerDist = NDArray<T, 2, Layout::Contig>({k, k});
362 m_elkanBounds = NDArray<T, 2, Layout::Contig>({0, 0});
363 m_centerDist = NDArray<T, 2, Layout::Contig>({0, 0});
// GEMM B-packing scratch: tiled layout when the chunked path will be used.
// Sizes are clamped to >= 1 element so .data() always returns valid storage.
369 const std::size_t packedBSize = needsChunk
370 ? math::detail::packedBScratchSizeFloatsTiled<T>(k, d)
371 : math::detail::packedBScratchSizeFloats(k, d);
372 const std::size_t packedNormsSize = math::detail::packedCSqNormsScratchSizeFloats(k);
373 m_packedB = NDArray<T, 1>({packedBSize == 0 ? std::size_t{1} : packedBSize});
374 m_packedCSqNorms = NDArray<T, 1>({packedNormsSize == 0 ? std::size_t{1} : packedNormsSize});
// Per-worker distance scratch: (blocks * chunkCap) rows by k columns.
377 const std::size_t distRows = needsChunk ? (blocks * chunkCap) : std::size_t{1};
378 const std::size_t safeK = (k == 0) ? std::size_t{1} : k;
379 const std::size_t distCols = needsChunk ? safeK : std::size_t{1};
380 m_distsChunk = NDArray<T, 2, Layout::Contig>({distRows, distCols});
381 }
// Only the worker count changed: regrow just the worker-scaled buffers.
else if (workerChanged) {
384 const std::size_t distRows = blocks * chunkCap;
385 const std::size_t distCols = (k == 0 ? std::size_t{1} : k);
386 m_distsChunk = NDArray<T, 2, Layout::Contig>({distRows, distCols});
// Per-block partial sums / Kahan compensations / counts for scatterAndFold*.
393 m_partialSums = NDArray<T, 1>({blocks * k * d});
394 m_partialComps = NDArray<T, 1>({blocks * k * d});
395 m_partialCounts = NDArray<std::int32_t, 1>({blocks * k});
// Per-worker A-panel arena (Mc x Kc floats each) for the prepacked GEMM.
399 const std::size_t apSize = blocks * math::detail::kMc<T> * math::detail::kKc<T>;
400 m_gemmApArena = NDArray<T, 1>({needsChunk ? apSize : std::size_t{1}});
// Record the configuration these allocations correspond to
// (m_n/m_d/m_k updates fall on hidden lines).
405 m_workerCount = workerCount;
// Recompute the cached squared L2 norm of every centroid row into
// m_cSqNorms.  The accumulator's initialization (original line 413) and
// the store into m_cSqNorms (lines 416+) are not visible in this excerpt.
// NOTE(review): "¢roids" below is a mojibake of "&centroids"
// ("&cen" -> "¢") introduced by the extraction — fix in the real file.
408 void refreshCentroidSqNorms(
const NDArray<T, 2, Layout::Contig> ¢roids)
noexcept {
409 const std::size_t k = centroids.dim(0);
410 const std::size_t d = centroids.dim(1);
411 for (std::size_t c = 0; c < k; ++c) {
412 const T *row = centroids.data() + (c * d);
414 for (std::size_t t = 0; t < d; ++t) {
415 s += row[t] * row[t];
// Divide each cluster's accumulated sum (m_sums) by its member count
// (m_counts) to produce the new centroid means.  The guard for empty
// clusters (original lines 426-428, presumably `if (cnt <= 0) continue;`)
// is not visible in this excerpt — confirm against the full file.
// NOTE(review): "¢roids" is mojibake for "&centroids".
421 void finalizeMeans(NDArray<T, 2, Layout::Contig> ¢roids)
noexcept {
422 const std::size_t k = centroids.dim(0);
423 const std::size_t d = centroids.dim(1);
424 for (std::size_t c = 0; c < k; ++c) {
425 const std::int32_t cnt = m_counts(c);
// Multiply by the reciprocal once instead of dividing d times per row.
429 const T inv = T{1} /
static_cast<T
>(cnt);
430 const T *src = m_sums.data() + (c * d);
431 T *dst = centroids.data() + (c * d);
432 for (std::size_t t = 0; t < d; ++t) {
433 dst[t] = src[t] * inv;
// Direct assignment dispatch.  On AVX2 builds with T == float, choose an
// AVX2 kernel (fused small-d argmin, or the packed outer kernel with
// scratch) when alignment/shape allow; otherwise fall through to the
// generic chunked GEMM-materialized path.  The conditions splitting the
// two AVX2 kernels, and several closing braces, are on hidden lines.
445 void runAssignment(
const NDArray<T, 2, Layout::Contig> &X,
446 const NDArray<T, 2, Layout::Contig> ¢roids,
447 NDArray<std::int32_t, 1> &labels, math::Pool pool) {
448#ifdef CLUSTERING_USE_AVX2
449 if constexpr (std::is_same_v<T, float>) {
450 const std::size_t d = X.dim(1);
// Both AVX2 kernels require 32-byte-aligned X and centroids.
451 if (X.template isAligned<32>() && centroids.template isAligned<32>() && d != 0) {
453 math::detail::pairwiseArgminDirectSmallDF32(X, centroids, labels, m_minDistSq, pool);
457 math::detail::pairwiseArgminOuterAvx2F32WithScratch(X, centroids, m_cSqNorms, labels,
458 m_minDistSq, m_packedB.data(),
459 m_packedCSqNorms.data(), pool);
// Generic fallback: materialize distances chunk-by-chunk via GEMM.
465 runChunkedMaterializedAssignment(X, centroids, labels, pool);
// Predicate: would runAssignment's chosen path write exact squared
// distances into m_minDistSq?  run() uses this to decide whether a direct
// recompute is needed before summing inertia.  The declaration head
// (return type / [[nodiscard]]) and the tail of the returned condition
// fall on lines not visible in this excerpt.
474 assignmentProducesDirectMinDistSq(
const NDArray<T, 2, Layout::Contig> &X,
475 const NDArray<T, 2, Layout::Contig> &C)
noexcept {
476#ifdef CLUSTERING_USE_AVX2
477 if constexpr (std::is_same_v<T, float>) {
478 const std::size_t d = X.dim(1);
// Mirrors runAssignment's AVX2 gate: 32-byte alignment and d != 0.
479 return X.template isAligned<32>() && C.template isAligned<32>() && d != 0 &&
// Predicate: would runAssignment take the fused-argmin AVX2 kernel (which
// does not populate the Hamerly u/l bounds)?  run() uses this on iteration
// 0 to decide whether seedHamerlyBoundsFromLabels must run.  The tail of
// the returned condition is on hidden lines; it mirrors runAssignment's
// alignment gate for the float/AVX2 path.
493 [[nodiscard]]
bool assignmentUsesFusedArgmin(
const NDArray<T, 2, Layout::Contig> &X,
494 const NDArray<T, 2, Layout::Contig> &C)
noexcept {
495#ifdef CLUSTERING_USE_AVX2
496 if constexpr (std::is_same_v<T, float>) {
497 const std::size_t d = X.dim(1);
498 return X.template isAligned<32>() && C.template isAligned<32>() &&
// Pack the transposed centroid matrix (C^T) into m_packedB using the
// GEMM's blocked layout (Nc-wide column panels, Kc-deep K panels, panel
// width rounded up to the kernel register width kNr), so the chunked
// assignment can reuse the packed panels across every row chunk.  The
// closing braces (original lines 538/540-541) are not visible here.
// NOTE(review): "¢roids" is mojibake for "&centroids".
520 void packCentroidsTiled(
const NDArray<T, 2, Layout::Contig> ¢roids)
noexcept {
521 constexpr std::size_t kNr = math::detail::kKernelNr<T>;
522 constexpr std::size_t kKcVal = math::detail::kKc<T>;
523 constexpr std::size_t kNcVal = math::detail::kNc<T>;
524 const std::size_t k = centroids.dim(0);
525 const std::size_t d = centroids.dim(1);
526 const auto cTransposed = centroids.t();
527 const auto cDesc = ::clustering::detail::describeMatrix(cTransposed);
528 T *bp = m_packedB.data();
529 std::size_t jcBase = 0;
// Tile over centroid columns (jc, width nc) and the shared dimension
// (pc, depth kc); each panel is rounded up to a multiple of kNr.
530 for (std::size_t jc = 0; jc < k; jc += kNcVal) {
531 const std::size_t nc = (jc + kNcVal <= k) ? kNcVal : (k - jc);
532 const std::size_t roundedNc = ((nc + kNr - 1) / kNr) * kNr;
533 std::size_t pcOffInJc = 0;
534 for (std::size_t pc = 0; pc < d; pc += kKcVal) {
535 const std::size_t kc = (pc + kKcVal <= d) ? kKcVal : (d - pc);
536 math::detail::packB<T>(cDesc, pc, kc, jc, nc, bp + jcBase + pcOffInJc);
537 pcOffInJc += kc * roundedNc;
// Advance to the next jc panel group: d rows of roundedNc packed columns.
539 jcBase += d * roundedNc;
// GEMM-materialized assignment: pack the centroids once, then per row
// chunk compute -2 * Xchunk * C^T with the prepacked GEMM, add
// ||x||^2 + ||c||^2 to recover squared distances, and take a fused
// best/second-best scan so labels, m_minDistSq, the Hamerly bounds (u, l)
// and (when allocated) the Elkan lower-bound rows are all produced in one
// pass.  Several lines (chunkCap, the worker index w, xChunk/distsView
// construction, clamping of v, some closing braces) are hidden.
// NOTE(review): "¢roids" is mojibake for "&centroids".
543 void runChunkedMaterializedAssignment(
const NDArray<T, 2, Layout::Contig> &X,
544 const NDArray<T, 2, Layout::Contig> ¢roids,
545 NDArray<std::int32_t, 1> &labels,
546 math::Pool pool)
noexcept {
547 const std::size_t n = X.dim(0);
548 const std::size_t k = centroids.dim(0);
549 const std::size_t d = X.dim(1);
// Nothing to assign (branch body hidden).
550 if (n == 0 || k == 0) {
// One-time pack of C^T into the tiled GEMM layout for this call.
554 packCentroidsTiled(centroids);
556 constexpr std::size_t kMcVal = math::detail::kMc<T>;
557 constexpr std::size_t kKcVal = math::detail::kKc<T>;
559 const std::size_t numChunks = (n + chunkCap - 1) / chunkCap;
// Raw pointers hoisted out of the per-chunk lambda.
560 const T *bp = m_packedB.data();
561 T *apArena = m_gemmApArena.data();
562 T *distsBase = m_distsChunk.data();
563 const T *cNormsBase = m_cSqNorms.data();
564 T *minDistBase = m_minDistSq.data();
565 std::int32_t *labelsBase = labels.data();
566 const T *xBase = X.data();
569 T *uBase = m_u.data();
570 T *lBase = m_l.data();
// Elkan bound rows are written only if the matrix was allocated for this n.
574 T *elkanBoundsBase = m_elkanBounds.dim(0) == n ? m_elkanBounds.data() :
nullptr;
576 auto runOneChunk = [&](std::size_t chunkIdx)
noexcept {
577 const std::size_t iBase = chunkIdx * chunkCap;
578 const std::size_t chunkRows = (iBase + chunkCap <= n) ? chunkCap : (n - iBase);
// Per-worker scratch slices (w, the worker index, is derived on hidden
// lines) — each worker owns a disjoint dists/A-panel region.
580 T *distsChunk = distsBase + (w * chunkCap * k);
581 T *apSlice = apArena + (w * kMcVal * kKcVal);
586 const auto xDesc = ::clustering::detail::describeMatrix(xChunk);
587 auto distsDesc = ::clustering::detail::describeMatrixMut(distsView);
// dists = -2 * Xchunk * C^T, beta = 0, using the prepacked B panels.
589 math::detail::gemmRunPrepacked<T>(xDesc, bp, d, k, distsDesc, T{-2}, T{0}, apSlice,
592 const T *xNormsChunk = m_xNormsSq.data() + iBase;
593 for (std::size_t i = 0; i < chunkRows; ++i) {
594 const T xn = xNormsChunk[i];
595 const T *row = distsChunk + (i * k);
596 T *elkanRow = elkanBoundsBase !=
nullptr ? elkanBoundsBase + ((iBase + i) * k) : nullptr;
597 T bestVal = std::numeric_limits<T>::infinity();
598 T secondVal = std::numeric_limits<T>::infinity();
599 std::int32_t bestIdx = 0;
600 for (std::size_t j = 0; j < k; ++j) {
// ||x - c||^2 = ||x||^2 + ||c||^2 - 2<x,c>; any clamp of small
// negative rounding error would be on hidden lines — confirm.
601 T v = row[j] + xn + cNormsBase[j];
605 if (elkanRow !=
nullptr) {
// Elkan lower bounds are stored as Euclidean (sqrt), not squared.
606 elkanRow[j] = std::sqrt(v);
611 bestIdx =
static_cast<std::int32_t
>(j);
612 }
else if (v < secondVal) {
// Commit: exact min squared distance, label, Hamerly upper bound
// (sqrt of best) and lower bound (sqrt of second best).
616 minDistBase[iBase + i] = bestVal;
617 labelsBase[iBase + i] = bestIdx;
618 uBase[iBase + i] = std::sqrt(bestVal);
619 lBase[iBase + i] = std::sqrt(secondVal);
// Chunks are independent; dispatch to the pool when it pays off.
623 if (pool.shouldParallelize(numChunks, 1, 2) && pool.pool !=
nullptr) {
625 ->submit_blocks(std::size_t{0}, numChunks,
626 [&](std::size_t lo, std::size_t hi) {
627 for (std::size_t c = lo; c < hi; ++c) {
// Serial fallback over all chunks.
633 for (std::size_t c = 0; c < numChunks; ++c) {
// Plain (uncompensated) scatter-reduce: each partition block accumulates
// per-cluster sums and counts into a private slab, then the slabs are
// folded into m_sums / m_counts.  Several loop bodies and closing braces
// fall on hidden lines.
639 void scatterAndFoldPlain(
const NDArray<T, 2, Layout::Contig> &X,
640 const NDArray<std::int32_t, 1> &labels, std::size_t k, math::Pool pool) {
641 const std::size_t n = X.dim(0);
642 const std::size_t d = X.dim(1);
644 T *partialSums = m_partialSums.data();
645 std::int32_t *partialCounts = m_partialCounts.data();
// Zero the global accumulators (loop bodies hidden).
647 for (std::size_t c = 0; c < k; ++c) {
649 for (std::size_t t = 0; t < d; ++t) {
653 if (n == 0 || d == 0) {
// Parallelize only when both the total work and the row count justify it.
657 const bool willParallelize = pool.shouldParallelizeWork(n * d) &&
658 pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr;
659 const std::size_t desiredBlocks = willParallelize ? pool.workerCount() : std::size_t{1};
660 const detail::BlockPartition part(0, n, desiredBlocks);
661 const std::size_t numBlocks = part.num_blocks == 0 ? std::size_t{1} : part.num_blocks;
// Zero each block's private slab before scattering into it.
663 for (std::size_t b = 0; b < numBlocks; ++b) {
664 T *slab = partialSums + (b * k * d);
665 std::int32_t *cslab = partialCounts + (b * k);
666 for (std::size_t e = 0; e < k * d; ++e) {
669 for (std::size_t c = 0; c < k; ++c) {
// Scatter rows [lo, hi) into the owning block's slab.  blockIndexOf(lo)
// is used for the whole range — this assumes each submitted range lies
// entirely within one partition block (confirm the submit granularity).
674 auto scatterRange = [&](std::size_t lo, std::size_t hi)
noexcept {
675 const std::size_t b = part.blockIndexOf(lo);
676 T *slab = partialSums + (b * k * d);
677 std::int32_t *cslab = partialCounts + (b * k);
678 for (std::size_t i = lo; i < hi; ++i) {
679 const std::int32_t lbl = labels(i);
// Defensive: skip out-of-range labels (the continue is on a hidden line).
680 if (lbl < 0 || std::cmp_greater_equal(lbl, k)) {
683 const auto row =
static_cast<std::size_t
>(lbl);
684 const T *xRow = X.data() + (i * d);
685 T *dst = slab + (row * d);
686 for (std::size_t t = 0; t < d; ++t) {
693 if (willParallelize) {
696 std::size_t{0}, n, [&](std::size_t lo, std::size_t hi) { scatterRange(lo, hi); },
// Fold every block's slab into the global m_sums / m_counts.
705 for (std::size_t b = 0; b < numBlocks; ++b) {
706 const T *slab = partialSums + (b * k * d);
707 const std::int32_t *cslab = partialCounts + (b * k);
708 for (std::size_t c = 0; c < k; ++c) {
709 m_counts(c) += cslab[c];
710 const T *src = slab + (c * d);
711 T *dstRow = &m_sums(c, 0);
712 for (std::size_t t = 0; t < d; ++t) {
// Kahan-compensated scatter-reduce: same block structure as
// scatterAndFoldPlain, but each block also keeps a compensation slab
// (updated row-wise by math::detail::kahanAddRow), and the final fold
// into m_sums is itself Kahan-compensated via m_foldComp.
719 void scatterAndFoldKahan(
const NDArray<T, 2, Layout::Contig> &X,
720 const NDArray<std::int32_t, 1> &labels, std::size_t k, math::Pool pool) {
721 const std::size_t n = X.dim(0);
722 const std::size_t d = X.dim(1);
724 T *partialSums = m_partialSums.data();
725 T *partialComps = m_partialComps.data();
726 std::int32_t *partialCounts = m_partialCounts.data();
727 T *foldComp = m_foldComp.data();
// Zero the global accumulators and the fold compensation (bodies hidden).
729 for (std::size_t c = 0; c < k; ++c) {
731 for (std::size_t t = 0; t < d; ++t) {
735 for (std::size_t e = 0; e < k * d; ++e) {
738 if (n == 0 || d == 0) {
// Same parallelization gate as the plain variant.
742 const bool willParallelize = pool.shouldParallelizeWork(n * d) &&
743 pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr;
744 const std::size_t desiredBlocks = willParallelize ? pool.workerCount() : std::size_t{1};
745 const detail::BlockPartition part(0, n, desiredBlocks);
746 const std::size_t numBlocks = part.num_blocks == 0 ? std::size_t{1} : part.num_blocks;
// Zero each block's sum, compensation, and count slabs (bodies hidden).
748 for (std::size_t b = 0; b < numBlocks; ++b) {
749 T *slab = partialSums + (b * k * d);
750 T *cslab = partialComps + (b * k * d);
751 std::int32_t *nslab = partialCounts + (b * k);
752 for (std::size_t e = 0; e < k * d; ++e) {
756 for (std::size_t c = 0; c < k; ++c) {
// Scatter rows [lo, hi) with Kahan compensation per cluster row.
// As in the plain variant, blockIndexOf(lo) is assumed valid for the
// whole submitted range — confirm the submit granularity.
761 auto scatterRange = [&](std::size_t lo, std::size_t hi)
noexcept {
762 const std::size_t b = part.blockIndexOf(lo);
763 T *slab = partialSums + (b * k * d);
764 T *cslab = partialComps + (b * k * d);
765 std::int32_t *nslab = partialCounts + (b * k);
766 for (std::size_t i = lo; i < hi; ++i) {
767 const std::int32_t lbl = labels(i);
// Defensive: skip out-of-range labels (continue on a hidden line).
768 if (lbl < 0 || std::cmp_greater_equal(lbl, k)) {
771 const auto row =
static_cast<std::size_t
>(lbl);
772 const T *xRow = X.data() + (i * d);
773 T *sumRow = slab + (row * d);
774 T *compRow = cslab + (row * d);
775 math::detail::kahanAddRow<T>(xRow, d, sumRow, compRow);
780 if (willParallelize) {
783 std::size_t{0}, n, [&](std::size_t lo, std::size_t hi) { scatterRange(lo, hi); },
// Compensated fold of every block's slab into m_sums / m_counts.
790 for (std::size_t b = 0; b < numBlocks; ++b) {
791 const T *slab = partialSums + (b * k * d);
792 const T *cslab = partialComps + (b * k * d);
793 const std::int32_t *nslab = partialCounts + (b * k);
794 for (std::size_t c = 0; c < k; ++c) {
795 m_counts(c) += nslab[c];
796 const T *src = slab + (c * d);
797 const T *comp = cslab + (c * d);
798 T *dstRow = &m_sums(c, 0);
799 T *foldRow = foldComp + (c * d);
// Kahan step: addend carries the block's residual (sum - comp);
// foldRow tracks the rounding error of the global accumulation.
// (The final `dstRow[t] = tVal;` line is hidden.)
800 for (std::size_t t = 0; t < d; ++t) {
801 const T addend = src[t] - comp[t];
802 const T y = addend - foldRow[t];
803 const T tVal = dstRow[t] + y;
804 foldRow[t] = (tVal - dstRow[t]) - y;
// Recompute the exact squared distance from each point to its assigned
// centroid.  Called by run() when the assignment path did not already
// write exact values into m_minDistSq.
// NOTE(review): "¢roids" is mojibake for "&centroids".
811 void recomputeMinDistSqDirect(
const NDArray<T, 2, Layout::Contig> &X,
812 const NDArray<T, 2, Layout::Contig> ¢roids,
813 const NDArray<std::int32_t, 1> &labels, math::Pool pool)
noexcept {
814 const std::size_t n = X.dim(0);
815 const std::size_t d = X.dim(1);
816 const std::size_t k = centroids.dim(0);
817 if (n == 0 || d == 0 || k == 0) {
821 auto runRowRange = [&](std::size_t lo, std::size_t hi)
noexcept {
822 for (std::size_t i = lo; i < hi; ++i) {
823 const std::int32_t lbl = labels(i);
// Invalid label: record zero rather than indexing out of range
// (the continue is on a hidden line).
824 if (lbl < 0 || std::cmp_greater_equal(lbl, k)) {
825 m_minDistSq(i) = T{0};
828 const T *xRow = X.data() + (i * d);
829 const T *cRow = centroids.data() + (
static_cast<std::size_t
>(lbl) * d);
830 m_minDistSq(i) = math::detail::sqEuclideanRowPtr<T>(xRow, cRow, d);
// Rows are independent; run on the pool when worthwhile, else serially
// (the serial branch is on hidden lines).
834 if (pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr) {
836 ->submit_blocks(std::size_t{0}, n,
837 [&](std::size_t lo, std::size_t hi) { runRowRange(lo, hi); })
// Seed the Hamerly upper bound m_u (and the exact m_minDistSq) from labels
// produced by the fused-argmin path, which does not emit bounds itself.
// The lower-bound (m_l) seeding, if any, falls on lines not visible here.
// NOTE(review): "¢roids" is mojibake for "&centroids".
844 void seedHamerlyBoundsFromLabels(
const NDArray<T, 2, Layout::Contig> &X,
845 const NDArray<T, 2, Layout::Contig> ¢roids,
846 const NDArray<std::int32_t, 1> &labels,
847 math::Pool pool)
noexcept {
848 const std::size_t n = X.dim(0);
849 const std::size_t d = X.dim(1);
850 const std::size_t k = centroids.dim(0);
851 if (n == 0 || d == 0 || k == 0) {
855 auto seedRange = [&](std::size_t lo, std::size_t hi)
noexcept {
856 for (std::size_t i = lo; i < hi; ++i) {
857 const std::int32_t lbl = labels(i);
// Invalid label: u = +inf means every pruning test fails, forcing a
// full recompute for this point on the next Hamerly pass.
858 if (lbl < 0 || std::cmp_greater_equal(lbl, k)) {
859 m_minDistSq(i) = T{0};
860 m_u(i) = std::numeric_limits<T>::infinity();
864 const T *xRow = X.data() + (i * d);
865 const T *cRow = centroids.data() + (
static_cast<std::size_t
>(lbl) * d);
866 const T tightSq = math::detail::sqEuclideanRowPtr<T>(xRow, cRow, d);
867 m_minDistSq(i) = tightSq;
// Upper bound is exact (Euclidean) at seeding time.
868 m_u(i) = std::sqrt(tightSq);
// Rows are independent; pool dispatch (serial branch on hidden lines).
873 if (pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr) {
875 ->submit_blocks(std::size_t{0}, n,
876 [&](std::size_t lo, std::size_t hi) { seedRange(lo, hi); })
// Algorithm-selection caps.
// kHamerlyMaxK:  Hamerly assignment keeps a fixed-size std::array<T, 64>
//                distance buffer on the stack, so k must not exceed it.
// kElkanMaxK / kElkanNKLimit:  Elkan needs an O(n*k) lower-bound matrix
//                and an O(k*k) center-distance matrix; cap k at 4096 and
//                n*k at 32 Mi elements before paying for that memory.
889 static constexpr std::size_t kHamerlyMaxK = 64;
897 static constexpr std::size_t kElkanMaxK = 4096;
906 static constexpr std::size_t kElkanNKLimit = std::size_t{32} << 20;
// Hamerly-accelerated assignment: one upper bound u(i) to the assigned
// centroid and one lower bound l(i) toward the rest.  Bounds are loosened
// by the per-centroid shifts since last iteration; a point is re-examined
// only when the pruning tests fail.  Several lines of the pruning logic
// and the distBuf fill loop are hidden in this excerpt.
// NOTE(review): "¢roids" is mojibake for "&centroids".
917 void runHamerlyAssignment(
const NDArray<T, 2, Layout::Contig> &X,
918 const NDArray<T, 2, Layout::Contig> ¢roids,
919 NDArray<std::int32_t, 1> &labels, math::Pool pool)
noexcept {
920 const std::size_t n = X.dim(0);
921 const std::size_t d = X.dim(1);
922 const std::size_t k = centroids.dim(0);
// Guard: the fixed-size stack buffer below requires k <= kHamerlyMaxK.
923 if (n == 0 || d == 0 || k == 0 || k > kHamerlyMaxK) {
926 const T *xData = X.data();
927 const T *cData = centroids.data();
928 T *uData = m_u.data();
929 T *lData = m_l.data();
930 T *minDistData = m_minDistSq.data();
931 std::int32_t *labelsData = labels.data();
// Find the largest centroid shift (sMax at index argMax) and the second
// largest (s2Max); declarations of sMax/s2Max are on hidden lines.
// l(i) can be loosened by only s2Max when the point's own centroid is
// the one that moved the most.
938 std::size_t argMax = 0;
939 T *shiftData = m_shiftEuclidean.data();
940 for (std::size_t c = 0; c < k; ++c) {
941 const T s = std::sqrt(m_shiftSq(c));
947 }
else if (s > s2Max) {
// For each centroid, half the Euclidean distance to its nearest other
// centroid — the classic Hamerly pruning radius.
957 T *halfDistData = m_halfDistToNearestOther.data();
958 for (std::size_t c = 0; c < k; ++c) {
959 T nearestSq = std::numeric_limits<T>::infinity();
960 const T *caRow = cData + (c * d);
961 for (std::size_t cp = 0; cp < k; ++cp) {
965 const T dsq = math::detail::sqEuclideanRowPtr<T>(caRow, cData + (cp * d), d);
966 if (dsq < nearestSq) {
970 halfDistData[c] = T{0.5} * std::sqrt(nearestSq);
973 auto processRange = [&](std::size_t lo, std::size_t hi)
noexcept {
// Stack buffer for the full distance row when a point needs a rescan.
974 std::array<T, kHamerlyMaxK> distBuf{};
975 for (std::size_t i = lo; i < hi; ++i) {
976 const std::int32_t a = labelsData[i];
977 if (a < 0 || std::cmp_greater_equal(a, k)) {
980 const auto au =
static_cast<std::size_t
>(a);
// Loosen the bounds by the centroid movement since last iteration.
981 T ui = uData[i] + shiftData[au];
982 T li = lData[i] - ((au == argMax) ? s2Max : sMax);
// First pruning test: point safely inside its centroid's half-radius
// (the corresponding u<=l test and its bookkeeping are on hidden lines).
995 if (ui <= halfDistData[au]) {
// Tighten u to the exact distance, then (on hidden lines) retry the
// pruning tests before falling into the full rescan.
1001 const T *xi = xData + (i * d);
1002 const T *caRow = cData + (au * d);
1003 const T tightSq = math::detail::sqEuclideanRowPtr<T>(xi, caRow, d);
1004 ui = std::sqrt(tightSq);
1009 minDistData[i] = tightSq;
// Full rescan: best and second-best over distBuf (the loop filling
// distBuf with exact distances is on hidden lines).
1014 T best = std::numeric_limits<T>::infinity();
1015 T second = std::numeric_limits<T>::infinity();
1016 std::int32_t bestIdx = 0;
1017 for (std::size_t j = 0; j < k; ++j) {
1018 const T v = distBuf[j];
1022 bestIdx =
static_cast<std::int32_t
>(j);
1023 }
else if (v < second) {
// Commit label, exact min distance, and refreshed u/l bounds.
1027 labelsData[i] = bestIdx;
1028 minDistData[i] = best;
1029 uData[i] = std::sqrt(best);
1030 lData[i] = std::sqrt(second);
// Rows are independent; pool dispatch (serial branch on hidden lines).
1034 if (pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr) {
1036 ->submit_blocks(std::size_t{0}, n,
1037 [&](std::size_t lo, std::size_t hi) { processRange(lo, hi); })
// Elkan-accelerated assignment: keeps a full n x k lower-bound matrix
// (m_elkanBounds, Euclidean) and the symmetric k x k center-distance
// matrix (m_centerDist), and skips exact distance computations via the
// triangle-inequality tests  u <= l(c)  and  u <= d(a, c) / 2.
// Bails out (body hidden) when the Elkan buffers were not allocated for
// this shape; several pruning/bookkeeping lines are hidden throughout.
// NOTE(review): "¢roids" is mojibake for "&centroids".
1054 void runElkanAssignment(
const NDArray<T, 2, Layout::Contig> &X,
1055 const NDArray<T, 2, Layout::Contig> ¢roids,
1056 NDArray<std::int32_t, 1> &labels, math::Pool pool)
noexcept {
1057 const std::size_t n = X.dim(0);
1058 const std::size_t d = X.dim(1);
1059 const std::size_t k = centroids.dim(0);
1060 if (n == 0 || d == 0 || k == 0 || m_elkanBounds.dim(0) != n || m_elkanBounds.dim(1) != k) {
1063 const T *xData = X.data();
1064 const T *cData = centroids.data();
1065 T *uData = m_u.data();
1066 T *boundsData = m_elkanBounds.data();
1067 T *minDistData = m_minDistSq.data();
1068 std::int32_t *labelsData = labels.data();
// Euclidean per-centroid shifts, used to relax the stored bounds.
1071 T *shiftData = m_shiftEuclidean.data();
1072 for (std::size_t c = 0; c < k; ++c) {
1073 shiftData[c] = std::sqrt(m_shiftSq(c));
// Pairwise center distances (filled symmetrically; only pairs not yet
// computed are evaluated — the cp<c vs cp>c split is on hidden lines)
// plus half the distance to each center's nearest other center.
1078 T *centerDistData = m_centerDist.data();
1079 T *halfDistData = m_halfDistToNearestOther.data();
1080 for (std::size_t c = 0; c < k; ++c) {
1081 centerDistData[(c * k) + c] = T{0};
1082 T nearest = std::numeric_limits<T>::infinity();
1083 for (std::size_t cp = 0; cp < k; ++cp) {
1089 const T dsq = math::detail::sqEuclideanRowPtr<T>(cData + (c * d), cData + (cp * d), d);
1090 dist = std::sqrt(dsq);
// Store both symmetric entries so the later lookup branch can reuse them.
1091 centerDistData[(c * k) + cp] = dist;
1092 centerDistData[(cp * k) + c] = dist;
1094 dist = centerDistData[(c * k) + cp];
1096 if (dist < nearest) {
1100 halfDistData[c] = T{0.5} * nearest;
1103 auto processRange = [&](std::size_t lo, std::size_t hi)
noexcept {
1104 for (std::size_t i = lo; i < hi; ++i) {
1105 std::int32_t a = labelsData[i];
1106 if (a < 0 || std::cmp_greater_equal(a, k)) {
1109 auto au =
static_cast<std::size_t
>(a);
// Relax the upper bound by the assigned centroid's shift, and every
// stored lower bound by that centroid's shift.
1110 T u = uData[i] + shiftData[au];
1111 T *lRow = boundsData + (i * k);
1114 for (std::size_t c = 0; c < k; ++c) {
1115 T lnew = lRow[c] - shiftData[c];
// Global prune: point safely within half-distance of its own center.
1122 if (u <= halfDistData[au]) {
// Per-candidate pass; u starts untight (stale bound) and is tightened
// lazily on the first candidate that survives the cheap tests.
1127 bool uTight =
false;
1128 const T *xi = xData + (i * d);
1129 for (std::size_t c = 0; c < k; ++c) {
// Cheap rejects: lower bound or half center distance dominates u.
1133 const T lc = lRow[c];
1134 const T half = T{0.5} * centerDistData[(au * k) + c];
1135 if (u <= lc || u <= half) {
// Tighten u to the exact distance to the current center, then retest.
1139 const T tightSq = math::detail::sqEuclideanRowPtr<T>(xi, cData + (au * d), d);
1140 u = std::sqrt(tightSq);
1141 minDistData[i] = tightSq;
1143 if (u <= lc || u <= half) {
// Exact candidate distance; reassign when closer (the comparison
// and the lRow update are on hidden lines).
1147 const T dSq = math::detail::sqEuclideanRowPtr<T>(xi, cData + (c * d), d);
1148 const T dEuc = std::sqrt(dSq);
1152 a =
static_cast<std::int32_t
>(c);
1154 minDistData[i] = dSq;
// Rows are independent; pool dispatch (serial branch on hidden lines).
1162 if (pool.shouldParallelize(n, 64, 2) && pool.pool !=
nullptr) {
1164 ->submit_blocks(std::size_t{0}, n,
1165 [&](std::size_t lo, std::size_t hi) { processRange(lo, hi); })
// --- Persistent scratch state, sized by ensureShape() ---
// Centroid snapshot taken before each update step (for shift tracking).
1172 NDArray<T, 2, Layout::Contig> m_centroidsOld;
// Cached squared L2 norms of the current centroids (refreshCentroidSqNorms).
1173 NDArray<T, 1> m_cSqNorms;
// Per-cluster coordinate sums and member counts from the scatter phase.
1174 NDArray<T, 2, Layout::Contig> m_sums;
1175 NDArray<std::int32_t, 1> m_counts;
// Per-point squared distance to the assigned centroid.
1176 NDArray<T, 1> m_minDistSq;
// Per-centroid squared shift since the previous iteration.
1177 NDArray<T, 1> m_shiftSq;
// Per-block partial sums / Kahan compensations / counts (scatterAndFold*),
// plus the compensation row used by the final Kahan fold into m_sums.
1178 NDArray<T, 1> m_partialSums;
1179 NDArray<T, 1> m_partialComps;
1180 NDArray<std::int32_t, 1> m_partialCounts;
1181 NDArray<T, 1> m_foldComp;
// Prepacked GEMM B panels (C^T) and packed centroid norms scratch.
1182 NDArray<T, 1> m_packedB;
1183 NDArray<T, 1> m_packedCSqNorms;
// Per-worker materialized distance chunk and GEMM A-panel arena.
1184 NDArray<T, 2, Layout::Contig> m_distsChunk;
1185 NDArray<T, 1> m_gemmApArena;
// Cached squared row norms of X.
1186 NDArray<T, 1> m_xNormsSq;
// Column sum / sum-of-squares scratch for meanColumnVariance.
1187 NDArray<T, 1> m_varSum;
1188 NDArray<T, 1> m_varSumSq;
// (m_u / m_l, the Hamerly upper/lower bounds used throughout this file,
// are declared on original lines 1189-1194, not visible in this excerpt.)
// Per-centroid Euclidean shift and half-distance-to-nearest-other-centroid.
1195 NDArray<T, 1> m_shiftEuclidean;
1199 NDArray<T, 1> m_halfDistToNearestOther;
// Elkan scratch: n x k lower bounds and k x k pairwise center distances.
1203 NDArray<T, 2, Layout::Contig> m_elkanBounds;
1206 NDArray<T, 2, Layout::Contig> m_centerDist;
// Shape/worker configuration the current allocations correspond to.
1208 std::size_t m_n = 0;
1209 std::size_t m_d = 0;
1210 std::size_t m_k = 0;
1211 std::size_t m_workerCount = 0;