12#include "clustering/index/nn_descent/detail/join_step.h"
13#include "clustering/index/nn_descent/detail/neighbor_heap.h"
14#include "clustering/index/nn_descent/detail/rp_tree_init.h"
// Compile-time guard: instantiating this template with any element type T
// other than float is rejected, and the compiler reports the message below.
48 static_assert(std::is_same_v<T, float>,
49 "NnDescentIndex<T> supports only float; add a specialization to extend.");
74 explicit NnDescentIndex(std::size_t
k, std::size_t maxIter = 10, T delta = T{0.001},
75 std::uint64_t seed = 0)
76 : m_k(
k), m_maxIter(maxIter), m_delta(delta), m_seed(seed) {
99 const std::size_t n = X.
dim(0);
100 const std::size_t d = X.
dim(1);
101 const bool kFitsN = (n == 0) || (m_k < n);
104 const bool sameShape =
105 (m_lastN == n) && (m_lastD == d) && (m_lastK == m_k) && m_bank.has_value();
107 m_lastIterations = 0;
111 m_bank.emplace(n, m_k);
112 if (n == 0 || m_k == 0) {
113 captureShape(X, n, d);
114 m_neighborsView.assign(n, {});
120 const std::size_t leafLimit = std::max<std::size_t>(2 * m_k, 8);
121 nn_descent::detail::RpTreeInit<T>::build(X, leafLimit, m_seed, *m_bank);
126 reloadBankFromView();
129 if (n == 0 || m_k == 0) {
130 captureShape(X, n, d);
131 m_neighborsView.assign(n, {});
137 const std::size_t total = n * m_k;
138 for (std::size_t iter = 0; iter < m_maxIter; ++iter) {
139 const std::size_t updates = nn_descent::detail::JoinStep<T>::run(X, *m_bank, pool);
144 const double ratio =
static_cast<double>(updates) /
static_cast<double>(total);
145 if (ratio <
static_cast<double>(m_delta)) {
150 captureShape(X, n, d);
155 m_neighborsView = m_bank->template toSortedLists<KnnEntry>();
156 m_bankOrderValid =
false;
161 [[nodiscard]]
const std::vector<std::vector<KnnEntry>> &
neighbors() const noexcept {
162 return m_neighborsView;
172 const std::size_t n = m_neighborsView.size();
176 std::vector<std::vector<std::int32_t>> adj(n);
177 for (std::size_t u = 0; u < n; ++u) {
178 for (
const KnnEntry &e : m_neighborsView[u]) {
179 if (e.idx < 0 || std::cmp_greater_equal(e.idx, n)) {
182 adj[u].push_back(e.idx);
183 adj[
static_cast<std::size_t
>(e.idx)].push_back(
static_cast<std::int32_t
>(u));
186 std::vector<std::uint8_t> visited(n, 0);
187 std::vector<std::int32_t> stack;
191 std::size_t visitedCount = 1;
192 while (!stack.empty()) {
193 const std::int32_t u = stack.back();
195 for (
const std::int32_t v : adj[
static_cast<std::size_t
>(u)]) {
196 if (visited[
static_cast<std::size_t
>(v)] == 0U) {
197 visited[
static_cast<std::size_t
>(v)] = 1;
203 return visitedCount == n;
/// Returns the neighbor count stored in m_k, i.e. the k passed to the
/// constructor. Marked [[nodiscard]] because the call has no other effect.
207 [[nodiscard]] std::size_t
k() const noexcept {
return m_k; }
/// Returns m_lastIterations: the number of join iterations actually executed
/// during the most recent build. Marked [[nodiscard]] because the call has
/// no other effect.
211 [[nodiscard]] std::size_t
lastIterations() const noexcept {
return m_lastIterations; }
223 void reloadBankFromView() {
224 if (!m_bank.has_value()) {
227 if (m_bankOrderValid) {
230 m_bank->rearmAllAsNew();
233 std::vector<std::pair<T, std::int32_t>> buf;
235 for (std::size_t i = 0; i < m_neighborsView.size(); ++i) {
237 for (
const KnnEntry &e : m_neighborsView[i]) {
238 buf.emplace_back(e.sqDist, e.idx);
240 m_bank->loadFromSorted(
static_cast<std::int32_t
>(i), buf);
242 m_bankOrderValid =
true;
245 void captureShape(
const NDArray<T, 2> & , std::size_t n, std::size_t d)
noexcept {
252 std::size_t m_maxIter;
254 std::uint64_t m_seed;
256 std::optional<nn_descent::detail::NeighborHeapBank<T>> m_bank;
257 std::vector<std::vector<KnnEntry>> m_neighborsView;
259 std::size_t m_lastN = 0;
260 std::size_t m_lastD = 0;
261 std::size_t m_lastK = 0;
262 std::size_t m_lastIterations = 0;
265 bool m_bankOrderValid =
false;
#define CLUSTERING_ALWAYS_ASSERT(cond)
Release-active assertion: evaluates cond in every build configuration.
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
size_t dim(std::size_t index) const noexcept
Returns the size of a specific dimension of the NDArray.
const std::vector< std::vector< KnnEntry > > & neighbors() const noexcept
Per-node k-NN neighbor list, sorted ascending by squared distance.
NnDescentIndex & operator=(const NnDescentIndex &)=delete
NnDescentIndex(NnDescentIndex &&)=default
Defaulted move constructor; transfers the neighbor bank and scratch.
NnDescentIndex(const NnDescentIndex &)=delete
void build(const NDArray< T, 2 > &X, math::Pool pool)
Build (or rebuild) the approximate kNN graph for X.
std::size_t k() const noexcept
k specified at construction.
~NnDescentIndex()=default
bool isConnected() const
Whether the undirected k-NN graph covers every node in a single connected component.
NnDescentIndex & operator=(NnDescentIndex &&)=default
Defaulted move assignment; transfers the neighbor bank and scratch.
std::size_t lastIterations() const noexcept
Number of join iterations actually executed during the most recent build.
NnDescentIndex(std::size_t k, std::size_t maxIter=10, T delta=T{0.001}, std::uint64_t seed=0)
Construct an index targeting k neighbors per node.
Per-node kNN entry returned by neighbors(). Squared Euclidean distance carried as T.
std::int32_t idx
Neighbor point index.
T sqDist
Squared Euclidean distance from the query to idx.
Thin injection wrapper around a BS::light_thread_pool.