Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
nn_descent_mst_backend.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <array>
5#include <cstddef>
6#include <cstdint>
7#include <limits>
8#include <optional>
9#include <type_traits>
10#include <utility>
11#include <vector>
12
16#include "clustering/math/detail/avx2_helpers.h"
17#include "clustering/math/dsu.h"
19#include "clustering/ndarray.h"
20
21namespace clustering::hdbscan {
22
  // NOTE(review): the opening of this struct (`struct NnDescentMstConfig {`) and its
  // leading doc comment were dropped by this extraction — confirm against the source header.

  /// Extra neighbours per node on top of minSamples; the kNN graph is built with
  /// k = minSamples + kExtra.
  std::size_t kExtra = 10;
  /// Iteration cap on the NN-Descent join loop; forwarded to NnDescentIndex.
  std::size_t maxIter = 10;
  /// Convergence threshold on the update fraction; forwarded to NnDescentIndex.
  float delta = 0.001F;
  /// PRNG seed for RP-tree partition choices; forwarded to NnDescentIndex.
  std::uint64_t seed = 0;
};
44
/// Approximate mutual-reachability-distance (MRD) MST backend for HDBSCAN.
///
/// Pipeline per run(): build an NN-Descent kNN graph, derive per-point core
/// distances, weight the directed kNN edges by MRD, Kruskal the sorted edge
/// list into a spanning forest, and — when the approximate graph leaves the
/// forest disconnected — stitch components together with exact minimum-weight
/// bridges. All distances throughout are *squared* Euclidean distances.
template <class T> class NnDescentMstBackend {
  // Only the float instantiation is supported; enforced at compile time.
  static_assert(
      std::is_same_v<T, float>,
      "NnDescentMstBackend<T> supports only float; a double specialization is out of scope.");

public:
  // NOTE(review): the original header declares a defaulted constructor here
  // (`NnDescentMstBackend() = default`, per the generated docs); the line did
  // not survive this extraction — confirm against the source header.

  /// Construct with an explicit tuning config; the config is copied by value.
  explicit NnDescentMstBackend(NnDescentMstConfig config) : m_config(config) {}

  /// Build the approximate MRD-weighted minimum spanning tree of @p X.
  ///
  /// @param X          Row-major (n, d) input matrix.
  /// @param minSamples Core-distance rank; must satisfy 1 <= minSamples < n
  ///                   (both bounds asserted below in every build configuration).
  /// @param pool       Thread-pool injection wrapper; pool.pool may be nullptr,
  ///                   in which case all work runs serially.
  /// @param out        Receives the n - 1 MST edges (insertion order) and the
  ///                   length-n per-point core distances.
  void run(const NDArray<T, 2> &X, std::size_t minSamples, math::Pool pool, MstOutput<T> &out) {
    const std::size_t n = X.dim(0);
    CLUSTERING_ALWAYS_ASSERT(minSamples >= 1);
    CLUSTERING_ALWAYS_ASSERT(minSamples < n);

    out.edges.clear();
    out.edges.reserve(n - 1);
    out.coreDistances = NDArray<T, 1>(std::array<std::size_t, 1>{n});

    // Phase 1: rebuild the NN-Descent kNN graph per fit. The kNN graph encodes the current
    // input's neighbour structure; caching it across fits would risk silent failure under
    // in-place mutation of the caller's buffer. The index is reconstructed on every call so
    // the graph always reflects the data @c run was invoked with.
    const std::size_t k = minSamples + m_config.kExtra;
    m_index.emplace(k, m_config.maxIter, m_config.delta, m_config.seed);
    m_index->build(X, pool);

    const auto &neighbors = m_index->neighbors();

    // Phase 2: core distance per point is the minSamples-th nearest squared distance. The
    // NN-Descent graph returns entries sorted ascending by squared distance; index minSamples - 1
    // picks the correct slot. We assume the graph's k is at least minSamples so the slot exists
    // (asserted per point below).
    T *coreDistData = out.coreDistances.data();
    for (std::size_t i = 0; i < n; ++i) {
      CLUSTERING_ALWAYS_ASSERT(neighbors[i].size() >= minSamples);
      coreDistData[i] = neighbors[i][minSamples - 1].sqDist;
    }

    // Phase 3: MRD-weighted edge list. The kNN graph is directed (i -> j when j is a neighbour of
    // i); to form an undirected edge set we emit (min(i,j), max(i,j), weight) and dedupe by
    // sorting plus a forward sweep. The weight takes the max of core[i], core[j], and sqDist(i,j).
    struct Edge {
      T weight;
      std::int32_t u;
      std::int32_t v;
    };
    std::vector<Edge> edges;
    edges.reserve(n * k);
    for (std::size_t i = 0; i < n; ++i) {
      const T coreI = coreDistData[i];
      for (const auto &e : neighbors[i]) {
        const auto j = static_cast<std::size_t>(e.idx);
        // Self-edges can appear in the neighbour lists; they carry no MST information.
        if (j == i) {
          continue;
        }
        const T coreJ = coreDistData[j];
        // Mutual reachability weight: max(core[i], core[j], sqDist(i, j)).
        T w = e.sqDist;
        if (coreI > w) {
          w = coreI;
        }
        if (coreJ > w) {
          w = coreJ;
        }
        // Canonicalize endpoint order so the same undirected edge seen from both
        // directions sorts adjacently.
        const auto u32 = static_cast<std::int32_t>(std::min(i, j));
        const auto v32 = static_cast<std::int32_t>(std::max(i, j));
        edges.push_back(Edge{w, u32, v32});
      }
    }

    // Sort ascending by weight then by endpoints so Kruskal consumes a deterministic order and
    // duplicate (u, v) entries from the directed kNN view become adjacent. We do not explicitly
    // dedupe: Kruskal's union-find check rejects a second edge between the same pair implicitly.
    std::sort(edges.begin(), edges.end(), [](const Edge &a, const Edge &b) {
      if (a.weight != b.weight) {
        return a.weight < b.weight;
      }
      if (a.u != b.u) {
        return a.u < b.u;
      }
      return a.v < b.v;
    });

    // Phase 4: Kruskal on the sorted edge list. The UnionFind is keyed on the signed int32 MST
    // index type; cast via uint32 for the union-find that requires an unsigned index.
    // NOTE(review): the original header constructs the UnionFind instance `uf` here
    // (e.g. `UnionFind<std::uint32_t> uf(...)`, sized for n elements); the declaration
    // line did not survive this extraction — confirm against the source header.
    for (const Edge &e : edges) {
      if (uf.unite(static_cast<std::uint32_t>(e.u), static_cast<std::uint32_t>(e.v))) {
        out.edges.push_back(MstEdge<T>{e.u, e.v, e.weight});
        // A spanning tree over n points has exactly n - 1 edges; stop early once reached.
        if (out.edges.size() + 1 == n) {
          break;
        }
      }
    }

    // Phase 5: connectivity fallback. When the kNN-graph MRD-MST leaves more than one component,
    // enumerate minimum-weight bridges between every pair of disjoint components and Kruskal them
    // in until a single spanning tree remains. Each bridge candidate is the (i, j) MRD edge of
    // minimum weight where i and j live in distinct components; weight uses the true squared
    // Euclidean distance computed row-by-row, then lifted by max with both core distances.
    if (uf.countComponents() > 1) {
      resolveDisconnectedComponents(X, coreDistData, pool, uf, out);
    }
  }

private:
  /// Reconnect a multi-component spanning forest by inserting minimum-weight MRD bridges.
  ///
  /// @param X            Same row-major (n, d) input run() was called with.
  /// @param coreDistData Per-point core distances (length n), filled by run() Phase 2.
  /// @param pool         Thread pool for the all-pairs bridge scans; may be serial.
  /// @param uf           Union-find over the current components; mutated in place.
  /// @param out          MST output; accepted bridges are appended to out.edges.
  void resolveDisconnectedComponents(const NDArray<T, 2> &X, const T *coreDistData, math::Pool pool,
                                     UnionFind<std::uint32_t> &uf, MstOutput<T> &out) {
    const std::size_t n = X.dim(0);
    const std::size_t d = X.dim(1);

    // Materialize the current component membership as a root -> member list so subsequent
    // all-pairs scans iterate only over the members of each component rather than the full
    // point set per pair.
    std::vector<std::vector<std::uint32_t>> members;
    std::vector<std::uint32_t> rootToSlot(n, std::numeric_limits<std::uint32_t>::max());
    for (std::uint32_t i = 0; i < static_cast<std::uint32_t>(n); ++i) {
      const std::uint32_t r = uf.find(i);
      std::uint32_t slot = rootToSlot[r];
      // First sighting of this root: allocate a fresh member slot for it.
      if (slot == std::numeric_limits<std::uint32_t>::max()) {
        slot = static_cast<std::uint32_t>(members.size());
        rootToSlot[r] = slot;
        members.emplace_back();
      }
      members[slot].push_back(i);
    }

    while (uf.countComponents() > 1) {
      // For each pair of surviving components (a, b) find the minimum-MRD bridge. Compare every
      // member of a to every member of b; track the minimum-weight edge. When all pairs are
      // scanned, Kruskal the resulting bridges into the MST. The outer loop runs c - 1 times in
      // the worst case but typically terminates after one round because inserting a single bridge
      // collapses components transitively.
      struct Bridge {
        T weight;
        std::int32_t u;
        std::int32_t v;
      };

      // Materialise the live (a, b) work list before fan-out so the parallel body indexes a flat
      // array. Skipping empty slots and slots that already share a root keeps the worker count
      // honest after the first Kruskal round collapses components transitively.
      struct PairIdx {
        std::uint32_t a;
        std::uint32_t b;
      };
      std::vector<PairIdx> pairs;
      pairs.reserve(members.size() * (members.size() - 1) / 2);
      for (std::uint32_t a = 0; a < static_cast<std::uint32_t>(members.size()); ++a) {
        if (members[a].empty()) {
          continue;
        }
        const std::uint32_t rootA = uf.find(members[a][0]);
        for (std::uint32_t b = a + 1; b < static_cast<std::uint32_t>(members.size()); ++b) {
          if (members[b].empty()) {
            continue;
          }
          // Already merged by an earlier round: no bridge needed between a and b.
          if (uf.find(members[b][0]) == rootA) {
            continue;
          }
          pairs.push_back(PairIdx{a, b});
        }
      }

      // Infinity sentinel marks slots no worker wrote a finite bridge into.
      std::vector<Bridge> bridges(pairs.size(), Bridge{std::numeric_limits<T>::infinity(), 0, 0});

      // Exhaustive min-weight scan for one component pair; each invocation writes only
      // bridges[pairIdx], so concurrent invocations on distinct indices do not race.
      auto computePair = [&](std::size_t pairIdx) noexcept {
        const auto &p = pairs[pairIdx];
        const auto &memA = members[p.a];
        const auto &memB = members[p.b];
        T bestW = std::numeric_limits<T>::infinity();
        std::int32_t bestU = 0;
        std::int32_t bestV = 0;
        for (const std::uint32_t ia : memA) {
          const T coreI = coreDistData[ia];
          const T *rowI = X.data() + (static_cast<std::size_t>(ia) * d);
          for (const std::uint32_t jb : memB) {
            const T coreJ = coreDistData[jb];
            const T *rowJ = X.data() + (static_cast<std::size_t>(jb) * d);
            const T sq = math::detail::sqEuclideanRowPtr(rowI, rowJ, d);
            // Lift the exact squared distance to the MRD weight.
            T w = sq;
            if (coreI > w) {
              w = coreI;
            }
            if (coreJ > w) {
              w = coreJ;
            }
            if (w < bestW) {
              bestW = w;
              bestU = static_cast<std::int32_t>(ia);
              bestV = static_cast<std::int32_t>(jb);
            }
          }
        }
        bridges[pairIdx] = Bridge{bestW, bestU, bestV};
      };

      // Cost per pair scales with `|memA| * |memB| * d`; the gate uses the largest pair as a
      // proxy because per-pair work is wildly heterogeneous and underestimating any single big
      // pair would leave one worker stuck while others finish.
      std::size_t maxPairOps = 0;
      for (const auto &p : pairs) {
        const std::size_t ops = members[p.a].size() * members[p.b].size() * d;
        maxPairOps = std::max(maxPairOps, ops);
      }
      const std::size_t totalPairOps = maxPairOps * pairs.size();
      if (pool.pool != nullptr && pool.shouldParallelizeWork(totalPairOps)) {
        pool.pool
            ->submit_blocks(std::size_t{0}, pairs.size(),
                            [&](std::size_t lo, std::size_t hi) {
                              for (std::size_t i = lo; i < hi; ++i) {
                                computePair(i);
                              }
                            })
            .wait();
      } else {
        for (std::size_t i = 0; i < pairs.size(); ++i) {
          computePair(i);
        }
      }

      // Drop the infinity sentinel before the sort so Kruskal never sees an unset slot. In
      // practice every pair admits at least one finite bridge (cross-component points exist).
      bridges.erase(std::remove_if(bridges.begin(), bridges.end(),
                                   [](const Bridge &br) {
                                     return br.weight == std::numeric_limits<T>::infinity();
                                   }),
                    bridges.end());

      // Sort bridges ascending by weight and Kruskal them into the MST. A bridge whose endpoints
      // are already merged (via a prior bridge in the same round) is skipped by union-find.
      std::sort(bridges.begin(), bridges.end(),
                [](const Bridge &a, const Bridge &b) { return a.weight < b.weight; });
      bool progress = false;
      for (const Bridge &br : bridges) {
        if (uf.unite(static_cast<std::uint32_t>(br.u), static_cast<std::uint32_t>(br.v))) {
          out.edges.push_back(MstEdge<T>{br.u, br.v, br.weight});
          progress = true;
          if (out.edges.size() + 1 == n) {
            break;
          }
        }
      }

      // Guard against an infinite loop: if no bridge was accepted the graph cannot reach
      // connectivity. In practice the precondition (cross-component points exist) guarantees a
      // finite bridge and this branch is never taken.
      CLUSTERING_ALWAYS_ASSERT(progress);

      // Rebuild the component membership for the next iteration: a merge may collapse multiple
      // slots into one root, so the next round's all-pairs scan should only visit surviving
      // components. Slots whose root changed get their members absorbed into the winning slot.
      if (uf.countComponents() > 1) {
        std::vector<std::vector<std::uint32_t>> nextMembers;
        std::fill(rootToSlot.begin(), rootToSlot.end(), std::numeric_limits<std::uint32_t>::max());
        for (const auto &comp : members) {
          for (const std::uint32_t i : comp) {
            const std::uint32_t r = uf.find(i);
            std::uint32_t slot = rootToSlot[r];
            if (slot == std::numeric_limits<std::uint32_t>::max()) {
              slot = static_cast<std::uint32_t>(nextMembers.size());
              rootToSlot[r] = slot;
              nextMembers.emplace_back();
            }
            nextMembers[slot].push_back(i);
          }
        }
        members = std::move(nextMembers);
      }
    }
  }

  /// Tuning knobs captured at construction; read (never written) by run().
  NnDescentMstConfig m_config{};
  /// NN-Descent index; rebuilt from scratch on every run() call (see Phase 1 note),
  /// so a default-constructed backend carries no index until first use.
  std::optional<index::NnDescentIndex<T>> m_index;
};
384
385} // namespace clustering::hdbscan
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
const T * data() const noexcept
Provides read-only access to the internal data array.
Definition ndarray.h:503
Disjoint-set-union with iterative path compression and union-by-rank.
Definition dsu.h:22
NnDescentMstBackend(NnDescentMstConfig config)
Construct with an explicit tuning config.
NnDescentMstBackend()=default
Default-construct with the default NnDescentMstConfig.
void run(const NDArray< T, 2 > &X, std::size_t minSamples, math::Pool pool, MstOutput< T > &out)
Build the approximate MRD-weighted minimum spanning tree of X.
One edge of the minimum spanning tree of mutual-reachability distances.
Definition mst_output.h:22
Frozen output contract of every MST backend.
Definition mst_output.h:41
NDArray< T, 1 > coreDistances
Per-point core distance (length N; self-excluded kNN distance at minSamples).
Definition mst_output.h:45
std::vector< MstEdge< T > > edges
The N - 1 MST edges, in insertion order.
Definition mst_output.h:43
Tuning knobs for the NN-Descent MST backend.
std::size_t kExtra
Extra neighbours per node on top of minSamples; the kNN graph is built with k = minSamples + kExtra.
float delta
Convergence threshold on the update fraction; forwarded to NnDescentIndex.
std::size_t maxIter
Iteration cap on the NN-Descent join loop; forwarded to NnDescentIndex.
std::uint64_t seed
PRNG seed for RP-tree partition choices; forwarded to NnDescentIndex.
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63