Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
pairwise.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <concepts>
5#include <cstddef>
6#include <cstdint>
7#include <type_traits>
8
11#include "clustering/math/detail/pairwise_threshold_outer.h"
14#include "clustering/ndarray.h"
15
16#ifdef CLUSTERING_USE_AVX2
17#include <immintrin.h>
18#endif
19
// The dispatch metric n*m*d must not wrap. Realistic clustering sizes stay well inside 2^63 on
// any LP64 / LLP64 platform we target; a 32-bit size_t would overflow the metric long before it
// overflows an allocation. Pin the platform expectation so a stray cross-compile flags instead of
// silently under-counting.
static_assert(sizeof(std::size_t) >= 8, "pairwise dispatch assumes a 64-bit std::size_t");
25
26namespace clustering::math {
27
28namespace detail {
29
37enum class PairwisePath : std::uint8_t { Simd, Gemm };
38
39#ifdef CLUSTERING_USE_AVX2
40
// Reduce the eight float lanes of @p v to their scalar sum. Folding the high
// 128-bit half onto the low half first, then hadd-ing twice, performs the
// additions in the same pairing order as a cross-lane permute + 256-bit hadds.
inline float horizontalSumAvx2(__m256 v) noexcept {
  const __m128 high = _mm256_extractf128_ps(v, 1);
  const __m128 low = _mm256_castps256_ps128(v);
  const __m128 quads = _mm_add_ps(low, high);
  const __m128 duals = _mm_hadd_ps(quads, quads);
  const __m128 total = _mm_hadd_ps(duals, duals);
  return _mm_cvtss_f32(total);
}
48
// Reduce the four double lanes of @p v to their scalar sum; same pairing order
// as the permute-based form (fold halves, then one horizontal add).
inline double horizontalSumAvx2(__m256d v) noexcept {
  const __m128d high = _mm256_extractf128_pd(v, 1);
  const __m128d low = _mm256_castpd256_pd128(v);
  const __m128d pairs = _mm_add_pd(low, high);
  const __m128d total = _mm_hadd_pd(pairs, pairs);
  return _mm_cvtsd_f64(total);
}
55
56inline float sqEuclideanRowAvx2(const float *xRow, const float *yRow, std::size_t d) noexcept {
57 __m256 acc = _mm256_setzero_ps();
58 const bool xAligned = (reinterpret_cast<std::uintptr_t>(xRow) % 32) == 0;
59 const bool yAligned = (reinterpret_cast<std::uintptr_t>(yRow) % 32) == 0;
60 std::size_t k = 0;
61 for (; k + 8 <= d; k += 8) {
62 const __m256 vx = xAligned ? _mm256_load_ps(xRow + k) : _mm256_loadu_ps(xRow + k);
63 const __m256 vy = yAligned ? _mm256_load_ps(yRow + k) : _mm256_loadu_ps(yRow + k);
64 const __m256 diff = _mm256_sub_ps(vx, vy);
65 acc = _mm256_add_ps(acc, _mm256_mul_ps(diff, diff));
66 }
67 float tail = 0.0F;
68 for (; k < d; ++k) {
69 const float diff = xRow[k] - yRow[k];
70 tail += diff * diff;
71 }
72 return horizontalSumAvx2(acc) + tail;
73}
74
75inline double sqEuclideanRowAvx2(const double *xRow, const double *yRow, std::size_t d) noexcept {
76 __m256d acc = _mm256_setzero_pd();
77 const bool xAligned = (reinterpret_cast<std::uintptr_t>(xRow) % 32) == 0;
78 const bool yAligned = (reinterpret_cast<std::uintptr_t>(yRow) % 32) == 0;
79 std::size_t k = 0;
80 for (; k + 4 <= d; k += 4) {
81 const __m256d vx = xAligned ? _mm256_load_pd(xRow + k) : _mm256_loadu_pd(xRow + k);
82 const __m256d vy = yAligned ? _mm256_load_pd(yRow + k) : _mm256_loadu_pd(yRow + k);
83 const __m256d diff = _mm256_sub_pd(vx, vy);
84 acc = _mm256_add_pd(acc, _mm256_mul_pd(diff, diff));
85 }
86 double tail = 0.0;
87 for (; k < d; ++k) {
88 const double diff = xRow[k] - yRow[k];
89 tail += diff * diff;
90 }
91 return horizontalSumAvx2(acc) + tail;
92}
93
94#endif // CLUSTERING_USE_AVX2
95
96template <class T> constexpr std::size_t kAvx2Lanes = std::is_same_v<T, float> ? 8 : 4;
97
98template <class T, Layout LX, Layout LY>
99inline T sqEuclideanRow(const NDArray<T, 2, LX> &X, std::size_t i, const NDArray<T, 2, LY> &Y,
100 std::size_t j) noexcept {
101 const std::size_t d = X.dim(1);
102#ifdef CLUSTERING_USE_AVX2
103 if constexpr (LX == Layout::Contig && LY == Layout::Contig) {
104 if (d >= kAvx2Lanes<T>) {
105 const T *xRow = X.data() + (i * d);
106 const T *yRow = Y.data() + (j * d);
107 return sqEuclideanRowAvx2(xRow, yRow, d);
108 }
109 }
110#endif
111 T sum = T{0};
112 for (std::size_t k = 0; k < d; ++k) {
113 const T diff = X(i, k) - Y(j, k);
114 sum += diff * diff;
115 }
116 return sum;
117}
118
119#ifdef CLUSTERING_USE_AVX2
120
121inline float sqNormRowAvx2(const float *xRow, std::size_t d) noexcept {
122 __m256 acc = _mm256_setzero_ps();
123 const bool aligned = (reinterpret_cast<std::uintptr_t>(xRow) % 32) == 0;
124 std::size_t k = 0;
125 for (; k + 8 <= d; k += 8) {
126 const __m256 v = aligned ? _mm256_load_ps(xRow + k) : _mm256_loadu_ps(xRow + k);
127 acc = _mm256_add_ps(acc, _mm256_mul_ps(v, v));
128 }
129 float tail = 0.0F;
130 for (; k < d; ++k) {
131 tail += xRow[k] * xRow[k];
132 }
133 return horizontalSumAvx2(acc) + tail;
134}
135
136inline double sqNormRowAvx2(const double *xRow, std::size_t d) noexcept {
137 __m256d acc = _mm256_setzero_pd();
138 const bool aligned = (reinterpret_cast<std::uintptr_t>(xRow) % 32) == 0;
139 std::size_t k = 0;
140 for (; k + 4 <= d; k += 4) {
141 const __m256d v = aligned ? _mm256_load_pd(xRow + k) : _mm256_loadu_pd(xRow + k);
142 acc = _mm256_add_pd(acc, _mm256_mul_pd(v, v));
143 }
144 double tail = 0.0;
145 for (; k < d; ++k) {
146 tail += xRow[k] * xRow[k];
147 }
148 return horizontalSumAvx2(acc) + tail;
149}
150
151#endif // CLUSTERING_USE_AVX2
152
153template <class T, Layout LX>
154inline T sqNormRow(const NDArray<T, 2, LX> &X, std::size_t i) noexcept {
155 const std::size_t d = X.dim(1);
156#ifdef CLUSTERING_USE_AVX2
157 if constexpr (LX == Layout::Contig) {
158 if (d >= kAvx2Lanes<T>) {
159 const T *xRow = X.data() + (i * d);
160 return sqNormRowAvx2(xRow, d);
161 }
162 }
163#endif
164 T sum = T{0};
165 for (std::size_t k = 0; k < d; ++k) {
166 const T v = X(i, k);
167 sum += v * v;
168 }
169 return sum;
170}
171
186template <class T, Layout LX>
187void rowNormsSq(const NDArray<T, 2, LX> &X, NDArray<T, 1> &norms, Pool pool) {
188 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
189 "rowNormsSq<T> requires T to be float or double");
190
192 CLUSTERING_ALWAYS_ASSERT(norms.dim(0) == X.dim(0));
193
194 const std::size_t n = X.dim(0);
195 if (n == 0) {
196 return;
197 }
198
199 auto runRowRange = [&](std::size_t lo, std::size_t hi) noexcept {
200 for (std::size_t i = lo; i < hi; ++i) {
201 norms(i) = sqNormRow<T, LX>(X, i);
202 }
203 };
204
205 if (pool.shouldParallelize(n, 4, 2) && pool.pool != nullptr) {
206 pool.pool
207 ->submit_blocks(std::size_t{0}, n,
208 [&](std::size_t lo, std::size_t hi) { runRowRange(lo, hi); })
209 .wait();
210 } else {
211 runRowRange(0, n);
212 }
213}
214
232template <class T, Layout LX, Layout LY>
234 NDArray<T, 2> &out, Pool pool) {
235 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
236 "pairwiseSqEuclideanGemm<T> requires T to be float or double");
237
239 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == Y.dim(1));
240 CLUSTERING_ALWAYS_ASSERT(out.dim(0) == X.dim(0));
241 CLUSTERING_ALWAYS_ASSERT(out.dim(1) == Y.dim(0));
242
243 const std::size_t n = X.dim(0);
244 const std::size_t m = Y.dim(0);
245 if (n == 0 || m == 0) {
246 return;
247 }
248
249 NDArray<T, 1> xNorms({n});
250 NDArray<T, 1> yNorms({m});
251 rowNormsSq(X, xNorms, pool);
252 rowNormsSq(Y, yNorms, pool);
253
254 gemm(X, Y.t(), out, pool, T{-2}, T{0});
255
256 auto runBroadcastRange = [&](std::size_t lo, std::size_t hi) noexcept {
257 for (std::size_t i = lo; i < hi; ++i) {
258 const T xi = xNorms(i);
259 for (std::size_t j = 0; j < m; ++j) {
260 // Cancellation in ||x||^2 + ||y||^2 - 2 x . y can produce tiny negatives when x ~= y;
261 // squared distance is non-negative by definition, so clamp.
262 const T v = (out(i, j) + xi) + yNorms(j);
263 out(i, j) = std::max(v, T{0});
264 }
265 }
266 };
267
268 const std::size_t totalCells = n * m;
269 if (pool.shouldParallelize(totalCells, 64, 2) && pool.pool != nullptr) {
270 pool.pool
271 ->submit_blocks(std::size_t{0}, n,
272 [&](std::size_t lo, std::size_t hi) { runBroadcastRange(lo, hi); })
273 .wait();
274 } else {
275 runBroadcastRange(0, n);
276 }
277}
278
295template <class T, Layout LX, Layout LY>
297 NDArray<T, 2> &out, Pool pool) {
298 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
299 "pairwiseSqEuclideanSimd<T> requires T to be float or double");
300
302 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == Y.dim(1));
303 CLUSTERING_ALWAYS_ASSERT(out.dim(0) == X.dim(0));
304 CLUSTERING_ALWAYS_ASSERT(out.dim(1) == Y.dim(0));
305
306 const std::size_t n = X.dim(0);
307 const std::size_t m = Y.dim(0);
308 if (n == 0 || m == 0) {
309 return;
310 }
311
312 auto runRowRange = [&](std::size_t lo, std::size_t hi) noexcept {
313 for (std::size_t i = lo; i < hi; ++i) {
314 for (std::size_t j = 0; j < m; ++j) {
315 out(i, j) = sqEuclideanRow<T, LX, LY>(X, i, Y, j);
316 }
317 }
318 };
319
320 if (pool.shouldParallelize(n, 4, 2) && pool.pool != nullptr) {
321 pool.pool
322 ->submit_blocks(std::size_t{0}, n,
323 [&](std::size_t lo, std::size_t hi) { runRowRange(lo, hi); })
324 .wait();
325 } else {
326 runRowRange(0, n);
327 }
328}
329
330} // namespace detail
331
349template <class T, Layout LX = Layout::Contig, Layout LY = Layout::Contig>
351 Pool pool) {
352 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
353 "pairwiseSqEuclidean<T> requires T to be float or double");
354
356 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == Y.dim(1));
357 CLUSTERING_ALWAYS_ASSERT(out.dim(0) == X.dim(0));
358 CLUSTERING_ALWAYS_ASSERT(out.dim(1) == Y.dim(0));
359
360 const std::size_t n = X.dim(0);
361 const std::size_t m = Y.dim(0);
362 if (n == 0 || m == 0) {
363 return;
364 }
365
366 const std::size_t work = n * m * X.dim(1);
368 detail::pairwiseSqEuclideanGemm(X, Y, out, pool);
369 } else {
370 detail::pairwiseSqEuclideanSimd(X, Y, out, pool);
371 }
372}
373
374namespace detail {
375
393template <class T, Layout LX = Layout::Contig, Layout LY = Layout::Contig>
395 const NDArray<T, 2, LY> &Y, NDArray<T, 2> &out,
396 Pool pool) {
397 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
398 "pairwiseSqEuclideanWithDispatchInfo<T> requires T to be float or double");
399
401 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == Y.dim(1));
402 CLUSTERING_ALWAYS_ASSERT(out.dim(0) == X.dim(0));
403 CLUSTERING_ALWAYS_ASSERT(out.dim(1) == Y.dim(0));
404
405 const std::size_t n = X.dim(0);
406 const std::size_t m = Y.dim(0);
407 if (n == 0 || m == 0) {
408 return PairwisePath::Simd;
409 }
410
411 const std::size_t work = n * m * X.dim(1);
413 pairwiseSqEuclideanGemm(X, Y, out, pool);
414 return PairwisePath::Gemm;
415 }
416 pairwiseSqEuclideanSimd(X, Y, out, pool);
417 return PairwisePath::Simd;
418}
419
431template <class T, Layout LX, Layout LY>
433#ifdef CLUSTERING_USE_AVX2
434 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig && LY == Layout::Contig) {
435 const std::size_t n = X.dim(0);
436 const std::size_t m = Y.dim(0);
437 const std::size_t d = X.dim(1);
438 if (n == 0 || m == 0 || d == 0) {
439 return false;
440 }
441 if (d < 8 || d > kThresholdMaxD) {
442 return false;
443 }
444 if (!X.template isAligned<32>() || !Y.template isAligned<32>()) {
445 return false;
446 }
447 return true;
448 } else {
449 (void)X;
450 (void)Y;
451 return false;
452 }
453#else
454 (void)X;
455 (void)Y;
456 return false;
457#endif
458}
459
470template <class T, Layout LX, Layout LY, class Emit>
471 requires std::invocable<Emit &, std::size_t, std::size_t>
473 const NDArray<T, 2, LY> &Y, T radiusSq, Pool pool,
474 Emit &&emit) {
475 const std::size_t n = X.dim(0);
476 const std::size_t m = Y.dim(0);
477 if (n == 0 || m == 0) {
478 return;
479 }
480
481 auto runRowRange = [&](std::size_t lo, std::size_t hi) {
482 for (std::size_t i = lo; i < hi; ++i) {
483 for (std::size_t j = 0; j < m; ++j) {
484 const T distSq = sqEuclideanRow<T, LX, LY>(X, i, Y, j);
485 if (distSq <= radiusSq) {
486 emit(i, j);
487 }
488 }
489 }
490 };
491
492 // Only fan out across rows: column emit within a row is order-sensitive and consumers rely
493 // on the per-row contract. Parallelism at the seed (row) level is safe because each row's
494 // emits land in a distinct key space, but the caller owns thread-safety of @p emit.
495 if (pool.shouldParallelize(n * m, 64, 2) && pool.pool != nullptr) {
496 pool.pool
497 ->submit_blocks(std::size_t{0}, n,
498 [&](std::size_t lo, std::size_t hi) { runRowRange(lo, hi); })
499 .wait();
500 } else {
501 runRowRange(0, n);
502 }
503}
504
505} // namespace detail
506
529template <class T, Layout LX = Layout::Contig, Layout LY = Layout::Contig, class Emit>
530 requires std::invocable<Emit &, std::size_t, std::size_t>
532 T radiusSq, Pool pool, Emit &&emit) {
533 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
534 "pairwiseSqEuclideanThresholded<T> requires T to be float or double");
535 CLUSTERING_ALWAYS_ASSERT(X.dim(1) == Y.dim(1));
536
537 const std::size_t n = X.dim(0);
538 const std::size_t m = Y.dim(0);
539 if (n == 0 || m == 0) {
540 return;
541 }
542
543#ifdef CLUSTERING_USE_AVX2
544 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig && LY == Layout::Contig) {
546 NDArray<T, 1> xNorms({n});
547 NDArray<T, 1> yNorms({m});
548 detail::rowNormsSq(X, xNorms, pool);
549 detail::rowNormsSq(Y, yNorms, pool);
550 detail::pairwiseThresholdOuterAvx2F32(X, Y, xNorms, yNorms, radiusSq, pool, emit);
551 return;
552 }
553 }
554#endif
555
556 detail::pairwiseSqEuclideanThresholdedMaterialized(X, Y, radiusSq, pool, emit);
557}
558
581template <class T, Layout LX = Layout::Contig, class Emit>
582 requires std::invocable<Emit &, std::size_t, std::size_t>
584 Emit &&emit) {
585 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
586 "pairwiseSqEuclideanThresholdedSymmetric<T> requires T to be float or double");
587
588 const std::size_t n = X.dim(0);
589 if (n == 0) {
590 return;
591 }
592
593#ifdef CLUSTERING_USE_AVX2
594 if constexpr (std::is_same_v<T, float> && LX == Layout::Contig) {
596 NDArray<T, 1> xNorms({n});
597 detail::rowNormsSq(X, xNorms, pool);
598 detail::pairwiseThresholdOuterAvx2F32Symmetric(X, xNorms, radiusSq, pool, emit);
599 return;
600 }
601 }
602#endif
603
604 // Scalar fallback: walk only j >= i and forward each surviving upper-triangular cell to the
605 // caller. Mirrors the @c pairwiseSqEuclideanThresholdedMaterialized contract for the
606 // non-symmetric case; the caller's emit is responsible for any adj-side mirror push.
607 auto runRowRange = [&](std::size_t lo, std::size_t hi) {
608 for (std::size_t i = lo; i < hi; ++i) {
609 for (std::size_t j = i; j < n; ++j) {
610 const T distSq = detail::sqEuclideanRow<T, LX, LX>(X, i, X, j);
611 if (distSq <= radiusSq) {
612 emit(i, j);
613 }
614 }
615 }
616 };
617
618 if (pool.shouldParallelize(n * n / 2, 64, 2) && pool.pool != nullptr) {
619 pool.pool
620 ->submit_blocks(std::size_t{0}, n,
621 [&](std::size_t lo, std::size_t hi) { runRowRange(lo, hi); })
622 .wait();
623 } else {
624 runRowRange(0, n);
625 }
626}
627
628} // namespace clustering::math
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
NDArray< T, 2, Layout::MaybeStrided > t() noexcept
Transposes a rank-2 NDArray into a borrowed view with swapped axes.
Definition ndarray.h:683
bool isMutable() const noexcept
Reports whether writes through operator(), Accessor, or flatIndex are allowed.
Definition ndarray.h:488
constexpr std::size_t pairwiseGemmThreshold
Workload threshold at which pairwiseSqEuclidean switches from the per-pair SIMD kernel to the GEMM-identity kernel.
Definition defaults.h:52
PairwisePath pairwiseSqEuclideanWithDispatchInfo(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Test-only: runs the same dispatch as pairwiseSqEuclidean and reports which kernel fired.
Definition pairwise.h:394
float horizontalSumAvx2(__m256 v) noexcept
Definition pairwise.h:41
float sqNormRowAvx2(const float *xRow, std::size_t d) noexcept
Definition pairwise.h:121
PairwisePath
Tag identifying which inner kernel executed for a pairwise distance request.
Definition pairwise.h:37
float sqEuclideanRowAvx2(const float *xRow, const float *yRow, std::size_t d) noexcept
Definition pairwise.h:56
T sqEuclideanRow(const NDArray< T, 2, LX > &X, std::size_t i, const NDArray< T, 2, LY > &Y, std::size_t j) noexcept
Definition pairwise.h:99
constexpr std::size_t kAvx2Lanes
Definition pairwise.h:96
void rowNormsSq(const NDArray< T, 2, LX > &X, NDArray< T, 1 > &norms, Pool pool)
Row-wise sum of squares: norms(i) = sum_k X(i, k)^2.
Definition pairwise.h:187
void pairwiseSqEuclideanThresholdedMaterialized(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, T radiusSq, Pool pool, Emit &&emit)
Materialized fallback for the thresholded-emit API: compute each pair's squared distance via sqEuclid...
Definition pairwise.h:472
bool canUseFusedThreshold(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y) noexcept
Runtime predicate: true when the fused AVX2 threshold path is eligible.
Definition pairwise.h:432
void pairwiseSqEuclideanSimd(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Small-path pairwise squared Euclidean via SIMD accumulation per (i, j) pair.
Definition pairwise.h:296
void pairwiseSqEuclideanGemm(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Large-path pairwise squared Euclidean via the GEMM identity.
Definition pairwise.h:233
T sqNormRow(const NDArray< T, 2, LX > &X, std::size_t i) noexcept
Definition pairwise.h:154
void gemm(const NDArray< T, 2, LA > &A, const NDArray< T, 2, LB > &B, NDArray< T, 2 > &C, Pool pool, T alpha=T{1}, T beta=T{0})
One-shot dense matrix-matrix multiply: C := alpha * A * B + beta * C.
Definition gemm.h:31
void pairwiseSqEuclideanThresholded(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, T radiusSq, Pool pool, Emit &&emit)
Emit every row pair (i, j) whose squared Euclidean distance is at most radiusSq.
Definition pairwise.h:531
void pairwiseSqEuclideanThresholdedSymmetric(const NDArray< T, 2, LX > &X, T radiusSq, Pool pool, Emit &&emit)
Symmetric variant of pairwiseSqEuclideanThresholded for the X == Y case.
Definition pairwise.h:583
void pairwiseSqEuclidean(const NDArray< T, 2, LX > &X, const NDArray< T, 2, LY > &Y, NDArray< T, 2 > &out, Pool pool)
Pairwise squared Euclidean distances between rows of two matrices.
Definition pairwise.h:350
T sum(const NDArray< T, 1, L > &x) noexcept
Naive single-pass sum of a rank-1 array.
Definition reduce.h:25
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63
BS::light_thread_pool * pool
Underlying pool, or nullptr to force serial execution.
Definition thread.h:65
bool shouldParallelize(std::size_t totalWork, std::size_t minChunk, std::size_t minTasksPerWorker=2) const noexcept
Decide whether totalWork warrants parallel dispatch.
Definition thread.h:98