Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
nn_descent.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <cstddef>
5#include <cstdint>
6#include <optional>
7#include <type_traits>
8#include <utility>
9#include <vector>
10
12#include "clustering/index/nn_descent/detail/join_step.h"
13#include "clustering/index/nn_descent/detail/neighbor_heap.h"
14#include "clustering/index/nn_descent/detail/rp_tree_init.h"
16#include "clustering/ndarray.h"
17
18namespace clustering::index {
19
47template <class T> class NnDescentIndex {
48 static_assert(std::is_same_v<T, float>,
49 "NnDescentIndex<T> supports only float; add a specialization to extend.");
50
51public:
53 struct KnnEntry {
55 std::int32_t idx;
58 };
59
74 explicit NnDescentIndex(std::size_t k, std::size_t maxIter = 10, T delta = T{0.001},
75 std::uint64_t seed = 0)
76 : m_k(k), m_maxIter(maxIter), m_delta(delta), m_seed(seed) {
78 }
79
80 NnDescentIndex(const NnDescentIndex &) = delete;
86 ~NnDescentIndex() = default;
87
98 void build(const NDArray<T, 2> &X, math::Pool pool) {
99 const std::size_t n = X.dim(0);
100 const std::size_t d = X.dim(1);
101 const bool kFitsN = (n == 0) || (m_k < n);
103
104 const bool sameShape =
105 (m_lastN == n) && (m_lastD == d) && (m_lastK == m_k) && m_bank.has_value();
106
107 m_lastIterations = 0;
108
109 if (!sameShape) {
110 // Cold start: discard any prior bank and rebuild from RP-tree init.
111 m_bank.emplace(n, m_k);
112 if (n == 0 || m_k == 0) {
113 captureShape(X, n, d);
114 m_neighborsView.assign(n, {});
115 return;
116 }
117 // Leaf limit sized to roughly @c 2k so the leaf-pair enumeration is `O(k^2)` per leaf,
118 // matching Dong 2011's recommendation. A minimum of `max(2k, 8)` keeps very small @c k
119 // from degenerating into singleton leaves.
120 const std::size_t leafLimit = std::max<std::size_t>(2 * m_k, 8);
121 nn_descent::detail::RpTreeInit<T>::build(X, leafLimit, m_seed, *m_bank);
122 } else {
123 // Warm start: reload the bank from the previously published sorted view so the heap
124 // invariant is restored after the destructive @c toSortedLists pass, then re-flag every
125 // slot as "new" so the next join iteration revisits them.
126 reloadBankFromView();
127 }
128
129 if (n == 0 || m_k == 0) {
130 captureShape(X, n, d);
131 m_neighborsView.assign(n, {});
132 return;
133 }
134
135 // Join loop. Each iteration counts updated slots; once the fraction drops below @p delta the
136 // graph is considered converged.
137 const std::size_t total = n * m_k;
138 for (std::size_t iter = 0; iter < m_maxIter; ++iter) {
139 const std::size_t updates = nn_descent::detail::JoinStep<T>::run(X, *m_bank, pool);
140 ++m_lastIterations;
141 if (total == 0) {
142 break;
143 }
144 const double ratio = static_cast<double>(updates) / static_cast<double>(total);
145 if (ratio < static_cast<double>(m_delta)) {
146 break;
147 }
148 }
149
150 captureShape(X, n, d);
151
152 // Materialize the public neighbor view, sorted ascending by squared distance. This
153 // destroys the bank's heap order; a subsequent warm-start call reheapifies via
154 // @c reloadBankFromView below.
155 m_neighborsView = m_bank->template toSortedLists<KnnEntry>();
156 m_bankOrderValid = false;
157 }
158
161 [[nodiscard]] const std::vector<std::vector<KnnEntry>> &neighbors() const noexcept {
162 return m_neighborsView;
163 }
164
171 [[nodiscard]] bool isConnected() const {
172 const std::size_t n = m_neighborsView.size();
173 if (n <= 1) {
174 return true;
175 }
176 std::vector<std::vector<std::int32_t>> adj(n);
177 for (std::size_t u = 0; u < n; ++u) {
178 for (const KnnEntry &e : m_neighborsView[u]) {
179 if (e.idx < 0 || std::cmp_greater_equal(e.idx, n)) {
180 continue;
181 }
182 adj[u].push_back(e.idx);
183 adj[static_cast<std::size_t>(e.idx)].push_back(static_cast<std::int32_t>(u));
184 }
185 }
186 std::vector<std::uint8_t> visited(n, 0);
187 std::vector<std::int32_t> stack;
188 stack.reserve(n);
189 stack.push_back(0);
190 visited[0] = 1;
191 std::size_t visitedCount = 1;
192 while (!stack.empty()) {
193 const std::int32_t u = stack.back();
194 stack.pop_back();
195 for (const std::int32_t v : adj[static_cast<std::size_t>(u)]) {
196 if (visited[static_cast<std::size_t>(v)] == 0U) {
197 visited[static_cast<std::size_t>(v)] = 1;
198 ++visitedCount;
199 stack.push_back(v);
200 }
201 }
202 }
203 return visitedCount == n;
204 }
205
207 [[nodiscard]] std::size_t k() const noexcept { return m_k; }
208
211 [[nodiscard]] std::size_t lastIterations() const noexcept { return m_lastIterations; }
212
213private:
223 void reloadBankFromView() {
224 if (!m_bank.has_value()) {
225 return;
226 }
227 if (m_bankOrderValid) {
228 // Bank is already in heap order; we only need to re-arm the "new" flag so the join loop
229 // revisits every slot.
230 m_bank->rearmAllAsNew();
231 return;
232 }
233 std::vector<std::pair<T, std::int32_t>> buf;
234 buf.reserve(m_k);
235 for (std::size_t i = 0; i < m_neighborsView.size(); ++i) {
236 buf.clear();
237 for (const KnnEntry &e : m_neighborsView[i]) {
238 buf.emplace_back(e.sqDist, e.idx);
239 }
240 m_bank->loadFromSorted(static_cast<std::int32_t>(i), buf);
241 }
242 m_bankOrderValid = true;
243 }
244
245 void captureShape(const NDArray<T, 2> & /*X*/, std::size_t n, std::size_t d) noexcept {
246 m_lastN = n;
247 m_lastD = d;
248 m_lastK = m_k;
249 }
250
251 std::size_t m_k;
252 std::size_t m_maxIter;
253 T m_delta;
254 std::uint64_t m_seed;
255
256 std::optional<nn_descent::detail::NeighborHeapBank<T>> m_bank;
257 std::vector<std::vector<KnnEntry>> m_neighborsView;
258
259 std::size_t m_lastN = 0;
260 std::size_t m_lastD = 0;
261 std::size_t m_lastK = 0;
262 std::size_t m_lastIterations = 0;
265 bool m_bankOrderValid = false;
266};
267
268} // namespace clustering::index
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
Definition ndarray.h:461
const std::vector< std::vector< KnnEntry > > & neighbors() const noexcept
Per-node k-NN neighbor list, sorted ascending by squared distance.
Definition nn_descent.h:161
NnDescentIndex & operator=(const NnDescentIndex &)=delete
NnDescentIndex(NnDescentIndex &&)=default
Defaulted move constructor; transfers the neighbor bank and scratch.
NnDescentIndex(const NnDescentIndex &)=delete
void build(const NDArray< T, 2 > &X, math::Pool pool)
Build (or rebuild) the approximate kNN graph for X.
Definition nn_descent.h:98
std::size_t k() const noexcept
k specified at construction.
Definition nn_descent.h:207
bool isConnected() const
Whether the undirected k-NN graph covers every node in a single connected component.
Definition nn_descent.h:171
NnDescentIndex & operator=(NnDescentIndex &&)=default
Defaulted move assignment; transfers the neighbor bank and scratch.
std::size_t lastIterations() const noexcept
Number of join iterations actually executed during the most recent build.
Definition nn_descent.h:211
NnDescentIndex(std::size_t k, std::size_t maxIter=10, T delta=T{0.001}, std::uint64_t seed=0)
Construct an index targeting k neighbors per node.
Definition nn_descent.h:74
Per-node kNN entry returned by neighbors. Squared Euclidean distance carried as T.
Definition nn_descent.h:53
std::int32_t idx
Neighbor point index.
Definition nn_descent.h:55
T sqDist
Squared Euclidean distance from the query to idx.
Definition nn_descent.h:57
Thin injection wrapper around a BS::light_thread_pool.
Definition thread.h:63