109 : m_allocator(calculatePoolSize(points.
dim(0))), m_points(points), m_dim(points.
dim(1)) {
111 const std::size_t n = points.
dim(0);
113 std::iota(m_indices.begin(), m_indices.end(), 0);
114 m_root = build(0, n, 0);
120 m_points_reordered.resize(n * m_dim);
121 const T *src = points.
data();
122 T *dst = m_points_reordered.data();
123 for (std::size_t k = 0; k < n; ++k) {
124 const std::size_t src_row = m_indices[k];
125 const T *s = src + (src_row * m_dim);
126 T *d = dst + (k * m_dim);
127 for (std::size_t j = 0; j < m_dim; ++j) {
137 m_nodeBounds.assign(
static_cast<std::size_t
>(m_nextNodeId) * 2 * m_dim, T{});
138 if (m_root !=
nullptr) {
139 populateBounds(m_root);
156 std::int64_t limit = -1)
const {
157 std::vector<std::size_t> indices;
158 const T radius_sq = radius * radius;
159 std::vector<KDTreeNode *> stack;
160 stack.reserve(kDefaultStackReserve);
163 std::vector<T> qbuf(m_dim);
164 for (std::size_t k = 0; k < m_dim; ++k) {
165 qbuf[k] = query_point[k];
167 queryImpl(m_root, qbuf.data(), radius_sq, indices, stack, limit);
// All-points radius query.
//
// For every row i of m_points, runs queryImpl from m_root with the squared
// radius and appends the original indices of the hits to adj[i]. The limit
// argument is -1, i.e. the per-point result count is unbounded.
//
// radius: search radius in the same units as the stored coordinates; it is
//         squared once up front and only squared distances are compared.
// pool:   NOTE(review): the receiver of `->submit_blocks` is on a line that is
//         not visible in this excerpt; presumably it is this pool argument.
//
// Returns one index list per input row (adj has n entries).
// NOTE(review): queryImpl shows no self-exclusion, so adj[i] presumably
// contains i itself -- confirm against callers.
183 [[nodiscard]] std::vector<std::vector<std::int32_t>>
query(T radius,
math::Pool pool)
const {
184 const std::size_t n = m_points.dim(0);
185 std::vector<std::vector<std::int32_t>> adj(n);
// Compare squared distances everywhere so no square root is needed.
190 const T radius_sq = radius * radius;
// Worker for the half-open row range [lo, hi). Each invocation owns its own
// traversal stack and only writes adj[i] for i inside its range.
192 auto runRange = [&](std::size_t lo, std::size_t hi) {
196 std::vector<KDTreeNode *> stack;
197 stack.reserve(kDefaultStackReserve);
// Row i is addressed as data() + i * m_dim, i.e. this assumes m_points.data()
// is a dense row-major buffer with m_dim values per row.
198 const T *sourceData = m_points.data();
199 for (std::size_t i = lo; i < hi; ++i) {
201 const T *qp = sourceData + (i * m_dim);
202 queryImpl(m_root, qp, radius_sq, adj[i], stack, -1);
// Hand the whole index range [0, n) to submit_blocks, which calls back with
// sub-ranges that are forwarded to runRange.
208 ->submit_blocks(std::size_t{0}, n,
209 [&](std::size_t lo, std::size_t hi) { runRange(lo, hi); })
243 const std::size_t n = m_points.dim(0);
247 const auto kSz =
static_cast<std::size_t
>(k);
251 auto runRange = [&](std::size_t lo, std::size_t hi) {
254 std::vector<KDTreeNode *> stack;
255 stack.reserve(kDefaultStackReserve);
256 math::detail::TopKNeighbors<T, std::int32_t> topK(kSz);
257 const T *sourceData = m_points.data();
258 std::int32_t *idxOut = indices.data();
259 T *distOut = sqDists.data();
260 for (std::size_t i = lo; i < hi; ++i) {
261 const auto iOriginal =
static_cast<std::int32_t
>(i);
262 const T *qp = sourceData + (i * m_dim);
264 knnQueryImpl(m_root, qp, iOriginal, topK, stack);
265 topK.drainAscending(distOut + (i * kSz), idxOut + (i * kSz));
271 ->submit_blocks(std::size_t{0}, n,
272 [&](std::size_t lo, std::size_t hi) { runRange(lo, hi); })
277 return {std::move(indices), std::move(sqDists)};
// Returns read-only views of the bounds stored for `node` as a pair of spans,
// each m_dim long: the first covers the m_dim values at the start of the
// node's slot, the second covers the m_dim values right after them.
// populateBounds writes the per-dimension minimum into the first half and the
// per-dimension maximum into the second half, so the pair is (min, max).
//
// A node's slot starts at node->m_id * 2 * m_dim inside m_nodeBounds.
// The spans point into m_nodeBounds and are only valid while that vector is
// neither reallocated nor destroyed.
292 [[nodiscard]] std::pair<std::span<const T>, std::span<const T>>
// Null nodes are a programming error; this check is active only in builds
// where assert is enabled.
294 assert(node !=
nullptr &&
"KDTree::nodeBounds on null node");
295 const std::size_t base =
static_cast<std::size_t
>(node->m_id) * 2 * m_dim;
296 const T *bounds = m_nodeBounds.data() + base;
297 return {std::span<const T>(bounds, m_dim), std::span<const T>(bounds + m_dim, m_dim)};
311 return {m_indices.data(), m_indices.size()};
326 return {m_points_reordered.data(), m_points_reordered.size()};
336 return static_cast<std::size_t
>(m_nextNodeId);
// Returns m_dim, the number of coordinates per point. The constructor
// initializes m_dim from points.dim(1).
343 [[nodiscard]] std::size_t
dim() const noexcept {
return m_dim; }
353 if (m_allocator.isDeallocSupported()) {
354 doRecDealloc(m_root);
// Computes the value passed to m_allocator's constructor from the number of
// points; the constructor calls it with points.dim(0).
// NOTE(review): only the empty-input guard is visible in this excerpt; the
// values actually returned are on lines not shown here.
369 static size_t calculatePoolSize(
size_t numPoints) {
// Zero points is handled as a special case.
370 if (numPoints == 0) {
// Recursive walk over the subtree rooted at `node`: left subtree first, then
// right subtree. The destructor-side caller invokes it on m_root only when
// m_allocator.isDeallocSupported() is true.
// NOTE(review): the statement that releases `node` itself is not visible in
// this excerpt.
387 static void doRecDealloc(KDTreeNode *node) {
// A null pointer marks an absent child and ends the recursion on this path.
388 if (node ==
nullptr) {
392 doRecDealloc(node->m_left);
393 doRecDealloc(node->m_right);
// Recursively builds the subtree that covers slots [start, end) of m_indices.
//
// start, end: half-open slot range inside m_indices.
// depth:      recursion depth; selects the split axis cyclically.
//
// Side effects: permutes m_indices within [start, end), takes nodes from
// m_allocator and advances m_nextNodeId by one per created node.
// NOTE(review): the return statements are not visible in this excerpt;
// presumably the freshly filled node is returned.
412 KDTreeNode *build(std::size_t start, std::size_t end, std::size_t depth) {
// Small ranges become leaves. For a leaf the node fields are reused:
// m_index holds the first slot of the range and m_dim holds the number of
// points in it (not a split axis).
418 if (end - start <= LeafSize) {
419 KDTreeNode *node = m_allocator.allocate();
420 *node = {.m_index = start,
421 .m_dim = end - start,
424 .m_id = m_nextNodeId++};
// Split axis cycles through the coordinates with depth.
429 const std::size_t
dim = depth % m_points.dim(1);
// Lower median slot of the range; it becomes this node's pivot.
430 const std::size_t median = start + (((end - start) - 1) / 2);
432 using diff_t = std::vector<std::size_t>::difference_type;
// Partial sort: afterwards the index at the median slot refers to the point
// whose `dim` coordinate is the median of the range, entries before it are
// not greater and entries after it are not smaller along that axis.
// The comparator looks coordinates up in the original array m_points.
433 std::nth_element(m_indices.begin() +
static_cast<diff_t
>(start),
434 m_indices.begin() +
static_cast<diff_t
>(median),
435 m_indices.begin() +
static_cast<diff_t
>(end),
436 [
this,
dim](std::size_t lhs, std::size_t rhs) {
437 return m_points[lhs][dim] < m_points[rhs][dim];
// Internal node: m_index is the pivot slot; the children cover the slots on
// either side of the pivot. The initializers of a braced list are evaluated
// in order, so both subtrees are built (and numbered) before this node takes
// its own id.
// NOTE(review): the initializer on the line between m_index and m_left is
// not visible here; queryImpl reads m_dim of internal nodes as the split
// axis, so it presumably stores `dim`.
440 KDTreeNode *node = m_allocator.allocate();
441 *node = {.m_index = median,
443 .m_left = build(start, median, depth + 1),
444 .m_right = build(median + 1, end, depth + 1),
445 .m_id = m_nextNodeId++};
// Iterative radius search using an explicit stack.
//
// OutIdx:    element type of the output vector; every hit is cast to it.
// root:      subtree to search; a null root is handled by the guard below.
// qp:        pointer to the m_dim coordinates of the query point.
// radius_sq: squared search radius; a point matches when its squared distance
//            is <= radius_sq (see the pivot test below).
// indices:   output; original point indices (looked up through m_indices) are
//            appended. Entries already present count toward `limit`, because
//            the checks use indices.size().
// stack:     scratch storage supplied by the caller so its allocation can be
//            reused across queries.
// limit:     maximum number of results, or -1 for no limit.
//
// NOTE(review): the statement that pops the stack and the bodies of the
// early-exit branches are on lines not visible in this excerpt.
471 template <
class OutIdx>
472 void queryImpl(KDTreeNode *
root,
const T *qp, T radius_sq, std::vector<OutIdx> &indices,
473 std::vector<KDTreeNode *> &stack, std::int64_t limit = -1)
const {
474 if (
root ==
nullptr) {
479 stack.push_back(
root);
// Coordinates are read from the reordered copy, where row k belongs to the
// point stored in slot k of m_indices.
481 const T *reorderedBase = m_points_reordered.data();
483 while (!stack.empty()) {
484 const KDTreeNode *node = stack.back();
487 if (node ==
nullptr) {
// Result budget reached (only checked when a limit was given).
491 if (limit != -1 && indices.size() ==
static_cast<std::size_t
>(limit)) {
// Leaf: a node without children. m_index is the first slot of the leaf and
// m_dim is the number of points in it.
500 if (node->m_left ==
nullptr && node->m_right ==
nullptr) {
501 const std::size_t base = node->m_index;
502 const std::size_t count = node->m_dim;
503 const T *leafPts = reorderedBase + (base * m_dim);
504 const bool isBounded = (limit != -1);
// Without a limit the cap is the largest size_t, so it never triggers.
505 const std::size_t cap =
506 isBounded ?
static_cast<std::size_t
>(limit) : std::numeric_limits<std::size_t>::max();
// radiusScan calls the lambda with the leaf-local index i of a point;
// presumably only for points within radius_sq of qp -- confirm in
// math::detail::radiusScan. The local index is translated to the original
// point index through m_indices.
// NOTE(review): the lambda is declared noexcept but calls push_back, which
// can throw on allocation failure; that would terminate the program.
507 math::detail::radiusScan(qp, leafPts, count, m_dim, radius_sq, [&](std::size_t i)
noexcept {
508 if (indices.size() >= cap) {
511 indices.push_back(
static_cast<OutIdx
>(m_indices[base + i]));
// Re-check the budget after the leaf has been scanned.
513 if (isBounded && indices.size() >= cap) {
// Internal node: m_index is the pivot slot, m_dim is the split axis.
523 const std::size_t pivotSlot = node->m_index;
524 const std::size_t splitDim = node->m_dim;
525 const T *pivotRow = reorderedBase + (pivotSlot * m_dim);
// The pivot is itself a stored point, so test it like any other candidate.
526 const T dist_sq = math::detail::sqEuclideanRowPtr(qp, pivotRow, m_dim);
527 if (dist_sq <= radius_sq) {
528 indices.push_back(
static_cast<OutIdx
>(m_indices[pivotSlot]));
// Signed offset of the query from the splitting plane along the split axis.
531 const T pivotCoord = pivotRow[splitDim];
532 const T diff = qp[splitDim] - pivotCoord;
// Two mirrored arrangements follow. In each, one child is always pushed and
// the other only when the splitting plane lies within the radius
// (diff * diff <= radius_sq). The condition that selects between the two
// arrangements is on a line not visible here (presumably the sign of diff).
// Arrangement 1: left always, right only if the plane is in range.
534 if (node->m_left !=
nullptr) {
535 stack.push_back(node->m_left);
537 if (diff * diff <= radius_sq && node->m_right !=
nullptr) {
538 stack.push_back(node->m_right);
// Arrangement 2: right always, left only if the plane is in range.
541 if (node->m_right !=
nullptr) {
542 stack.push_back(node->m_right);
544 if (diff * diff <= radius_sq && node->m_left !=
nullptr) {
545 stack.push_back(node->m_left);
// Iterative nearest-neighbour search that feeds candidates into `topK`.
//
// root:      subtree to search; a null root is handled by the guard below.
// qp:        pointer to the m_dim coordinates of the query point.
// selfIndex: original index of a point to exclude from the results.
// topK:      accumulator; receives (squared distance, original index) pairs
//            via push() and exposes its current pruning bound via boundKey().
// stack:     scratch storage supplied by the caller so its allocation can be
//            reused across queries.
//
// NOTE(review): the statement that pops the stack, the computation of gapSq
// and the bodies of the pruning branches are on lines not visible here.
568 void knnQueryImpl(KDTreeNode *
root,
const T *qp, std::int32_t selfIndex,
569 math::detail::TopKNeighbors<T, std::int32_t> &topK,
570 std::vector<KDTreeNode *> &stack)
const {
571 if (
root ==
nullptr) {
576 stack.push_back(
root);
// Coordinates are read from the reordered copy, where row k belongs to the
// point stored in slot k of m_indices.
578 const T *reorderedBase = m_points_reordered.data();
580 while (!stack.empty()) {
581 const KDTreeNode *node = stack.back();
584 if (node ==
nullptr) {
// Pruning bound sampled once per visited node; it is not refreshed after the
// pushes further down in the same iteration.
590 const T bound = topK.boundKey();
// The gap test is skipped while the bound is still numeric_limits<T>::max().
// gapSq is computed on a line not visible here; presumably the node is
// skipped when gapSq >= bound.
597 if (bound != std::numeric_limits<T>::max()) {
600 if (gapSq >= bound) {
// Leaf: a node without children. m_index is the first slot of the leaf and
// m_dim is the number of points in it.
605 if (node->m_left ==
nullptr && node->m_right ==
nullptr) {
610 const std::size_t base = node->m_index;
611 const std::size_t count = node->m_dim;
612 const T *leafPts = reorderedBase + (base * m_dim);
// Fixed-size buffer for the leaf's squared distances; build() only creates
// leaves with at most LeafSize points. Value-initialized, so entries beyond
// `count` read as T{}.
613 std::array<T, LeafSize> dsqBuf{};
614 math::detail::sqDistancesAosBlock<T>(qp, leafPts, count, m_dim, dsqBuf.data());
// Find the smallest distance in the leaf so the whole leaf can be compared
// against the bound in one test.
622 T minDsq = dsqBuf[0];
623 for (std::size_t i = 1; i < count; ++i) {
624 if (dsqBuf[i] < minDsq) {
628 if (minDsq >= bound) {
// Offer every leaf point except the excluded one to the accumulator.
632 for (std::size_t i = 0; i < count; ++i) {
633 const auto pointIdx =
static_cast<std::int32_t
>(m_indices[base + i]);
634 if (pointIdx == selfIndex) {
637 topK.push(dsqBuf[i], pointIdx);
// Internal node: m_index is the pivot slot, m_dim is the split axis.
647 const std::size_t pivotSlot = node->m_index;
648 const std::size_t splitDim = node->m_dim;
649 const T *pivotRow = reorderedBase + (pivotSlot * m_dim);
650 const auto pivotIdx =
static_cast<std::int32_t
>(m_indices[pivotSlot]);
651 const T pivotCoord = pivotRow[splitDim];
652 const T diff = qp[splitDim] - pivotCoord;
653 const T diffSq = diff * diff;
// diffSq is the squared offset along one axis only, hence a lower bound on
// the full squared distance to the pivot; the full distance is computed only
// when that lower bound does not already exceed the bound.
654 if (pivotIdx != selfIndex && diffSq <= bound) {
655 const T dist_sq = math::detail::sqEuclideanRowPtr(qp, pivotRow, m_dim);
656 topK.push(dist_sq, pivotIdx);
// Two mirrored arrangements follow; the condition selecting between them is
// on a line not visible here (presumably the sign of diff). The loop reads
// stack.back(), so the child pushed last is examined first.
// Arrangement 1: right only if the plane is within the bound, then left.
660 if (diffSq <= bound && node->m_right !=
nullptr) {
661 stack.push_back(node->m_right);
663 if (node->m_left !=
nullptr) {
664 stack.push_back(node->m_left);
// Arrangement 2: left only if the plane is within the bound, then right.
667 if (diffSq <= bound && node->m_left !=
nullptr) {
668 stack.push_back(node->m_left);
670 if (node->m_right !=
nullptr) {
671 stack.push_back(node->m_right);
// Recursively fills m_nodeBounds for `node` and every node below it.
//
// Each node owns a slot of 2 * m_dim values starting at m_id * 2 * m_dim:
// the first m_dim values receive the per-dimension minimum, the next m_dim
// values the per-dimension maximum of the points under the node.
// Coordinates are read from m_points_reordered.
//
// node: must not be null (it is dereferenced immediately); the constructor
//       calls this only when m_root != nullptr.
687 void populateBounds(KDTreeNode *node)
noexcept {
688 T *minOut = m_nodeBounds.data() + (
static_cast<std::size_t
>(node->m_id) * 2 * m_dim);
689 T *maxOut = minOut + m_dim;
// Leaf: m_index is the first slot of the leaf, m_dim the number of points.
691 if (node->m_left ==
nullptr && node->m_right ==
nullptr) {
693 const std::size_t base = node->m_index;
694 const std::size_t count = node->m_dim;
695 const T *leafPts = m_points_reordered.data() + (base * m_dim);
// Seed both bounds with the first point of the leaf.
// NOTE(review): this reads row 0 unconditionally, i.e. it assumes the leaf
// holds at least one point unless a guard exists on a line not shown here.
696 for (std::size_t j = 0; j < m_dim; ++j) {
697 minOut[j] = leafPts[j];
698 maxOut[j] = leafPts[j];
// Widen the bounds with the remaining points of the leaf.
700 for (std::size_t i = 1; i < count; ++i) {
701 const T *row = leafPts + (i * m_dim);
702 for (std::size_t j = 0; j < m_dim; ++j) {
703 if (row[j] < minOut[j]) {
706 if (row[j] > maxOut[j]) {
// Internal node: fill the children first so their boxes can be merged.
718 if (node->m_left !=
nullptr) {
719 populateBounds(node->m_left);
721 if (node->m_right !=
nullptr) {
722 populateBounds(node->m_right);
// Start from the box of whichever child exists (left preferred).
725 const KDTreeNode *
const seed = (node->m_left !=
nullptr) ? node->m_left : node->m_right;
726 const T *seedMin = m_nodeBounds.data() + (
static_cast<std::size_t
>(seed->m_id) * 2 * m_dim);
727 const T *seedMax = seedMin + m_dim;
728 for (std::size_t j = 0; j < m_dim; ++j) {
729 minOut[j] = seedMin[j];
730 maxOut[j] = seedMax[j];
// With two children the seed was the left one, so merge the right box in.
// (The declaration that names this pointer otherMin is on a line not
// visible in this excerpt.)
733 if (node->m_left !=
nullptr && node->m_right !=
nullptr) {
735 m_nodeBounds.data() + (
static_cast<std::size_t
>(node->m_right->m_id) * 2 * m_dim);
736 const T *otherMax = otherMin + m_dim;
737 for (std::size_t j = 0; j < m_dim; ++j) {
738 if (otherMin[j] < minOut[j]) {
739 minOut[j] = otherMin[j];
741 if (otherMax[j] > maxOut[j]) {
742 maxOut[j] = otherMax[j];
// The pivot point is stored in the node itself, not in either child, so it
// has to be folded into the box separately.
748 const std::size_t pivotSlot = node->m_index;
749 const T *pivotRow = m_points_reordered.data() + (pivotSlot * m_dim);
750 for (std::size_t j = 0; j < m_dim; ++j) {
751 if (pivotRow[j] < minOut[j]) {
752 minOut[j] = pivotRow[j];
754 if (pivotRow[j] > maxOut[j]) {
755 maxOut[j] = pivotRow[j];
// Capacity reserved up front for every traversal stack created by the query
// entry points.
763 static constexpr std::size_t kDefaultStackReserve = 64;
// Root of the tree; assigned from build(0, n, 0) in the constructor.
767 KDTreeNode *m_root =
nullptr;
// Reference to the caller's point array (not a copy). The referenced array
// must stay alive for as long as this tree is used.
768 const NDArray<T, 2> &m_points;
// Number of coordinates per point; initialized from points.dim(1).
769 std::size_t m_dim = 0;
// Permutation of original row indices: slot k holds the original row of the
// point stored at position k of the tree layout. Rearranged by build().
770 std::vector<std::size_t> m_indices;
// Point coordinates laid out in m_indices order (row k corresponds to
// original row m_indices[k]), n * m_dim values.
772 std::vector<T> m_points_reordered;
// Id handed to the next node created by build(); after construction it
// equals the number of nodes and sizes m_nodeBounds.
777 std::uint32_t m_nextNodeId = 0;
// Per-node bounds, 2 * m_dim values per node id: m_dim minima followed by
// m_dim maxima. Filled by populateBounds().
780 std::vector<T> m_nodeBounds;