74 std::is_same_v<T, float>,
75 "NnDescentMstBackend<T> supports only float; a double specialization is out of scope.");
106 const std::size_t n = X.
dim(0);
111 out.
edges.reserve(n - 1);
118 const std::size_t k = minSamples + m_config.kExtra;
120 m_index.emplace(k, m_config.maxIter, m_config.delta, m_config.seed);
121 m_index->build(X, pool);
123 const auto &neighbors = m_index->neighbors();
129 for (std::size_t i = 0; i < n; ++i) {
131 coreDistData[i] = neighbors[i][minSamples - 1].sqDist;
142 std::vector<Edge> edges;
143 edges.reserve(n * k);
144 for (std::size_t i = 0; i < n; ++i) {
145 const T coreI = coreDistData[i];
146 for (
const auto &e : neighbors[i]) {
147 const auto j =
static_cast<std::size_t
>(e.idx);
151 const T coreJ = coreDistData[j];
159 const auto u32 =
static_cast<std::int32_t
>(std::min(i, j));
160 const auto v32 =
static_cast<std::int32_t
>(std::max(i, j));
161 edges.push_back(Edge{w, u32, v32});
168 std::sort(edges.begin(), edges.end(), [](
const Edge &a,
const Edge &b) {
169 if (a.weight != b.weight) {
170 return a.weight < b.weight;
181 for (
const Edge &e : edges) {
182 if (uf.unite(
static_cast<std::uint32_t
>(e.u),
static_cast<std::uint32_t
>(e.v))) {
183 out.edges.push_back(
MstEdge<T>{e.u, e.v, e.weight});
184 if (out.edges.size() + 1 == n) {
195 if (uf.countComponents() > 1) {
196 resolveDisconnectedComponents(X, coreDistData, pool, uf, out);
215 void resolveDisconnectedComponents(
const NDArray<T, 2> &X,
const T *coreDistData, math::Pool pool,
217 const std::size_t n = X.dim(0);
218 const std::size_t d = X.dim(1);
223 std::vector<std::vector<std::uint32_t>> members;
224 std::vector<std::uint32_t> rootToSlot(n, std::numeric_limits<std::uint32_t>::max());
225 for (std::uint32_t i = 0; i < static_cast<std::uint32_t>(n); ++i) {
226 const std::uint32_t r = uf.find(i);
227 std::uint32_t slot = rootToSlot[r];
228 if (slot == std::numeric_limits<std::uint32_t>::max()) {
229 slot =
static_cast<std::uint32_t
>(members.size());
230 rootToSlot[r] = slot;
231 members.emplace_back();
233 members[slot].push_back(i);
236 while (uf.countComponents() > 1) {
255 std::vector<PairIdx> pairs;
256 pairs.reserve(members.size() * (members.size() - 1) / 2);
257 for (std::uint32_t a = 0; a < static_cast<std::uint32_t>(members.size()); ++a) {
258 if (members[a].empty()) {
261 const std::uint32_t rootA = uf.find(members[a][0]);
262 for (std::uint32_t b = a + 1; b < static_cast<std::uint32_t>(members.size()); ++b) {
263 if (members[b].empty()) {
266 if (uf.find(members[b][0]) == rootA) {
269 pairs.push_back(PairIdx{a, b});
273 std::vector<Bridge> bridges(pairs.size(), Bridge{std::numeric_limits<T>::infinity(), 0, 0});
275 auto computePair = [&](std::size_t pairIdx)
noexcept {
276 const auto &p = pairs[pairIdx];
277 const auto &memA = members[p.a];
278 const auto &memB = members[p.b];
279 T bestW = std::numeric_limits<T>::infinity();
280 std::int32_t bestU = 0;
281 std::int32_t bestV = 0;
282 for (
const std::uint32_t ia : memA) {
283 const T coreI = coreDistData[ia];
284 const T *rowI = X.data() + (
static_cast<std::size_t
>(ia) * d);
285 for (
const std::uint32_t jb : memB) {
286 const T coreJ = coreDistData[jb];
287 const T *rowJ = X.data() + (
static_cast<std::size_t
>(jb) * d);
288 const T sq = math::detail::sqEuclideanRowPtr(rowI, rowJ, d);
298 bestU =
static_cast<std::int32_t
>(ia);
299 bestV =
static_cast<std::int32_t
>(jb);
303 bridges[pairIdx] = Bridge{bestW, bestU, bestV};
309 std::size_t maxPairOps = 0;
310 for (
const auto &p : pairs) {
311 const std::size_t ops = members[p.a].size() * members[p.b].size() * d;
312 maxPairOps = std::max(maxPairOps, ops);
314 const std::size_t totalPairOps = maxPairOps * pairs.size();
315 if (pool.pool !=
nullptr && pool.shouldParallelizeWork(totalPairOps)) {
317 ->submit_blocks(std::size_t{0}, pairs.size(),
318 [&](std::size_t lo, std::size_t hi) {
319 for (std::size_t i = lo; i < hi; ++i) {
325 for (std::size_t i = 0; i < pairs.size(); ++i) {
332 bridges.erase(std::remove_if(bridges.begin(), bridges.end(),
333 [](
const Bridge &br) {
334 return br.weight == std::numeric_limits<T>::infinity();
340 std::sort(bridges.begin(), bridges.end(),
341 [](
const Bridge &a,
const Bridge &b) { return a.weight < b.weight; });
342 bool progress =
false;
343 for (
const Bridge &br : bridges) {
344 if (uf.unite(
static_cast<std::uint32_t
>(br.u),
static_cast<std::uint32_t
>(br.v))) {
345 out.edges.push_back(
MstEdge<T>{br.u, br.v, br.weight});
347 if (out.edges.size() + 1 == n) {
361 if (uf.countComponents() > 1) {
362 std::vector<std::vector<std::uint32_t>> nextMembers;
363 std::fill(rootToSlot.begin(), rootToSlot.end(), std::numeric_limits<std::uint32_t>::max());
364 for (
const auto &comp : members) {
365 for (
const std::uint32_t i : comp) {
366 const std::uint32_t r = uf.find(i);
367 std::uint32_t slot = rootToSlot[r];
368 if (slot == std::numeric_limits<std::uint32_t>::max()) {
369 slot =
static_cast<std::uint32_t
>(nextMembers.size());
370 rootToSlot[r] = slot;
371 nextMembers.emplace_back();
373 nextMembers[slot].push_back(i);
376 members = std::move(nextMembers);
382 std::optional<index::NnDescentIndex<T>> m_index;