108 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
109 "topk<T> requires T to be float or double");
110 assert(outIdx.size() == k &&
"topk requires outIdx.size() == k");
114 const std::size_t n = x.dim(0);
115 assert(k <= n &&
"topk requires k <= x.dim(0)");
117 std::vector<std::pair<T, std::size_t>> staged;
119 for (std::size_t i = 0; i < n; ++i) {
120 staged.emplace_back(x(i), i);
123 std::vector<std::pair<T, std::size_t>> top(k);
124 const auto cmp = [](
const std::pair<T, std::size_t> &a,
125 const std::pair<T, std::size_t> &b)
noexcept {
128 if (a.first != b.first) {
129 return a.first > b.first;
131 return a.second < b.second;
133 std::partial_sort_copy(staged.begin(), staged.end(), top.begin(), top.end(), cmp);
135 for (std::size_t i = 0; i < k; ++i) {
136 outIdx[i] = top[i].second;
void topk(const NDArray< T, 1, L > &x, std::size_t k, std::span< std::size_t > outIdx) noexcept
Writes the indices of the k largest values of x into outIdx, ordered by descending value; equal values are ordered by ascending index.