43 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
44 "AfkMc2Seeder<T> requires T to be float or double");
46#ifdef CLUSTERING_KMEANS_AFKMC2_K_FLOOR
54 static constexpr std::size_t
kFloor = CLUSTERING_KMEANS_AFKMC2_K_FLOOR;
57 static constexpr std::size_t
kFloor = 100;
60#ifdef CLUSTERING_KMEANS_AFKMC2_CHAIN_LENGTH
69 static constexpr std::size_t
chainLengthDefault = CLUSTERING_KMEANS_AFKMC2_CHAIN_LENGTH;
94 const std::size_t n = X.
dim(0);
95 const std::size_t d = X.
dim(1);
111 const T *xData = X.
data();
112 T *centroidsData = outCentroids.
data();
113 T *qData = m_q.
data();
117 std::memcpy(centroidsData, xData + (first * d), d *
sizeof(T));
126 const T *firstRow = centroidsData;
128 for (std::size_t i = 0; i < n; ++i) {
129 const T d2 = sqEuclideanRowPtr(xData + (i * d), firstRow, d);
134 const T invN = T{1} /
static_cast<T
>(n);
136 const T invSum = T{1} / sumD2;
137 for (std::size_t i = 0; i < n; ++i) {
138 qData[i] = (T{0.5} * qData[i] * invSum) + (T{0.5} * invN);
143 for (std::size_t i = 0; i < n; ++i) {
150 T *qCumData = m_qCum.data();
152 for (std::size_t i = 0; i < n; ++i) {
154 qCumData[i] = running;
156 const T qTotal = qCumData[n - 1];
158 auto sampleFromQ = [&]()
noexcept -> std::size_t {
163 const std::size_t mid = lo + ((hi - lo) / 2);
164 if (qCumData[mid] > u) {
170 return lo < n ? lo : n - 1;
176 auto distToChosen = [&](std::size_t pointIdx, std::size_t chosenCount)
noexcept -> T {
177 const T *row = xData + (pointIdx * d);
178 T best = sqEuclideanRowPtr(row, centroidsData, d);
179 for (std::size_t c = 1; c < chosenCount; ++c) {
180 const T cand = sqEuclideanRowPtr(row, centroidsData + (c * d), d);
188 for (std::size_t c = 1; c < k; ++c) {
189 std::size_t xIdx = sampleFromQ();
190 T xDist = distToChosen(xIdx, c);
193 for (std::size_t step = 0; step < m; ++step) {
194 const std::size_t yIdx = sampleFromQ();
195 const T yDist = distToChosen(yIdx, c);
196 const T yQ = qData[yIdx];
204 const T numer = yDist * xQ;
205 const T denom = xDist * yQ;
207 const bool accept = (denom <= T{0}) || ((u * denom) < numer);
216 std::memcpy(centroidsData + (c * d), xData + (xIdx * d), d *
sizeof(T));
220 void ensureShape(std::size_t n) {
221 if (m_q.dim(0) != n) {
222 m_q = NDArray<T, 1>({n});
224 if (m_qCum.dim(0) != n) {
225 m_qCum = NDArray<T, 1>({n});
230 NDArray<T, 1> m_qCum;