Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
auto_mst_backend.h
Go to the documentation of this file.
1#pragma once
2
3#include <cstddef>
4#include <type_traits>
5#include <variant>
6
12#include "clustering/ndarray.h"
13
14namespace clustering::hdbscan {
15
48template <class T> class AutoMstBackend {
49 static_assert(std::is_same_v<T, float>,
50 "AutoMstBackend<T> supports only float; a double specialization is out of scope.");
51
52public:
53#ifdef CLUSTERING_HDBSCAN_BORUVKA_LOW_DIM_CEIL
63 static constexpr std::size_t boruvkaLowDimCeil = CLUSTERING_HDBSCAN_BORUVKA_LOW_DIM_CEIL;
64#else
66 static constexpr std::size_t boruvkaLowDimCeil = 16;
67#endif
68
69#ifdef CLUSTERING_HDBSCAN_BORUVKA_DIM_CEIL
76 static constexpr std::size_t boruvkaDimCeil = CLUSTERING_HDBSCAN_BORUVKA_DIM_CEIL;
77#else
80 static constexpr std::size_t boruvkaDimCeil = 60;
81#endif
82
83 static_assert(boruvkaLowDimCeil <= boruvkaDimCeil,
84 "boruvkaLowDimCeil must not exceed boruvkaDimCeil; the dispatch order assumes a "
85 "nested Boruvka regime at low d and a relaxed one at moderate d.");
86
94 static constexpr bool primFitsBudget(std::size_t n) noexcept {
95 if (n == 0) {
96 return true;
97 }
98 constexpr std::size_t kNsqBudget = kPrimMrdMatrixByteBudget / sizeof(T);
99 return n <= kNsqBudget / n;
100 }
101
102 AutoMstBackend() = default;
103
112 void run(const NDArray<T, 2> &X, std::size_t minSamples, math::Pool pool, MstOutput<T> &out) {
113 ensureShape(X.dim(0), X.dim(1));
114 std::visit([&](auto &b) { b.run(X, minSamples, pool, out); }, m_held);
115 }
116
124 [[nodiscard]] std::size_t heldIndex() const noexcept { return m_held.index(); }
125
134 std::size_t peekArm(std::size_t n, std::size_t d) {
135 ensureShape(n, d);
136 return m_held.index();
137 }
138
139private:
140 void ensureShape(std::size_t n, std::size_t d) {
141 if (n == m_lastN && d == m_lastD) {
142 return;
143 }
144 if (d <= boruvkaLowDimCeil) {
145 if (!std::holds_alternative<BoruvkaMstBackend<T>>(m_held)) {
146 m_held.template emplace<BoruvkaMstBackend<T>>();
147 }
148 } else if (d <= boruvkaDimCeil && primFitsBudget(n)) {
149 // Streaming Prim beats both Boruvka and NN-Descent in the `d <= 60` band while @c n
150 // stays inside its quadratic compute budget. Above @c boruvkaDimCeil the dense pairwise
151 // work is dominated by the @c d-wide distance compute and NN-Descent wins; the cap keeps
152 // Prim out of that regime.
153 if (!std::holds_alternative<PrimMstBackend<T>>(m_held)) {
154 m_held.template emplace<PrimMstBackend<T>>();
155 }
156 } else if (d <= boruvkaDimCeil) {
157 if (!std::holds_alternative<BoruvkaMstBackend<T>>(m_held)) {
158 m_held.template emplace<BoruvkaMstBackend<T>>();
159 }
160 } else {
161 if (!std::holds_alternative<NnDescentMstBackend<T>>(m_held)) {
162 m_held.template emplace<NnDescentMstBackend<T>>();
163 }
164 }
165 m_lastN = n;
166 m_lastD = d;
167 }
168
169 // Variant order is fixed: Boruvka (0), Prim (1), NN-Descent (2). @ref heldIndex leans on this.
170 std::variant<BoruvkaMstBackend<T>, PrimMstBackend<T>, NnDescentMstBackend<T>> m_held{
171 std::in_place_type<PrimMstBackend<T>>};
172 std::size_t m_lastN = 0;
173 std::size_t m_lastD = 0;
174};
175
176} // namespace clustering::hdbscan
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
std::size_t peekArm(std::size_t n, std::size_t d)
Resolve the variant arm for (n, d) without running the pipeline.
static constexpr std::size_t boruvkaDimCeil
Dimensional ceiling above which KDTree Boruvka gives way to NN-Descent when Prim is out of byte budget.
void run(const NDArray< T, 2 > &X, std::size_t minSamples, math::Pool pool, MstOutput< T > &out)
Fit a backend arm chosen on the input shape and delegate run to it.
std::size_t heldIndex() const noexcept
Index of the currently held variant arm.
static constexpr std::size_t boruvkaLowDimCeil
Low-dimensional ceiling at or below which Boruvka is preferred regardless of N.
static constexpr bool primFitsBudget(std::size_t n) noexcept
Whether the Prim regime applies at n under the dense-MRD byte budget.
constexpr std::size_t kPrimMrdMatrixByteBudget
Equivalent byte-budget phrasing of kPrimMaxN, kept so callers that gate on n*n*sizeof(T) <= kPrimMrdM...
Frozen output contract of every MST backend.
Definition mst_output.h:41
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63