Clustering
C++20 header-only: DBSCAN, HDBSCAN, k-means.
Loading...
Searching...
No Matches
rng.h
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <array>
5#include <cassert>
6#include <cmath>
7#include <cstddef>
8#include <cstdint>
9#include <span>
10#include <type_traits>
11#include <utility>
12#include <vector>
13
14#include "clustering/ndarray.h"
15
16#ifndef __SIZEOF_INT128__
17#error "clustering::math::rng requires a compiler with __uint128_t support (GCC/Clang)."
18#endif
19
20namespace clustering::math {
21
/// 128-bit state for the PCG-XSL-RR 64-bit output generator (Melissa O'Neill).
///
/// The state evolves as a 128-bit linear congruential generator; the XSL-RR output permutation is
/// applied by advanceState(pcg64 &).
struct pcg64 {
  /// 128-bit generator state; advanced by every advanceState call.
  __uint128_t m_state = 0;
  /// Stream-encoded odd increment mixed into the LCG step.
  __uint128_t m_inc = 0;

  /// Initialize the generator per PCG's canonical seeding procedure.
  ///
  /// Sequence: state = 0, one LCG step, add the seed, one more LCG step.
  ///
  /// @param seedValue Offset added to the state between the two warm-up steps (zero-extended).
  /// @param stream    Stream selector; encoded as the odd increment `2 * stream + 1`.
  void seed(std::uint64_t seedValue, std::uint64_t stream = 0) noexcept {
    // The full 128-bit LCG multiplier, assembled from its high and low 64-bit halves.
    static constexpr __uint128_t kMultiplier =
        (static_cast<__uint128_t>(2549297995355413924ULL) << 64) | 4865540595714422341ULL;
    const auto step = [this]() noexcept { m_state = (m_state * kMultiplier) + m_inc; };

    // The increment must be odd for the LCG to reach its full period.
    m_inc = (static_cast<__uint128_t>(stream) << 1U) | 1U;
    m_state = 0;
    step();
    m_state += seedValue;
    step();
  }
};
56
63inline std::uint64_t advanceState(pcg64 &rng) noexcept {
64 static constexpr __uint128_t kMult =
65 (static_cast<__uint128_t>(2549297995355413924ULL) << 64) | 4865540595714422341ULL;
66 const __uint128_t old = rng.m_state;
67 rng.m_state = (old * kMult) + rng.m_inc;
68 const auto rot = static_cast<std::uint64_t>(old >> 122);
69 const auto xorshifted = static_cast<std::uint64_t>(old ^ (old >> 64));
70 return (xorshifted >> rot) | (xorshifted << ((-rot) & 63U));
71}
72
/// 256-bit state for Vigna & Blackman's xoshiro256** generator.
struct xoshiro256ss {
  /// Four 64-bit state words; SplitMix64-diffused at seed time.
  std::array<std::uint64_t, 4> m_s{0, 0, 0, 0};

  /// Initialize the four state words via SplitMix64 diffusion of a single 64-bit seed.
  ///
  /// Each word is one SplitMix64 output: a Weyl counter stepped by the odd golden-ratio constant,
  /// then passed through an invertible xor-shift/multiply finalizer. The four counter values are
  /// distinct and the finalizer is a bijection, so the four words are pairwise distinct. At most
  /// one word can therefore be zero, which rules out the all-zero xoshiro state.
  ///
  /// @param seedValue Any 64-bit value (0 included).
  void seed(std::uint64_t seedValue) noexcept {
    std::uint64_t z = seedValue;
    for (auto &word : m_s) {
      z += 0x9E3779B97F4A7C15ULL;
      std::uint64_t x = z;
      x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
      x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
      x = x ^ (x >> 31);
      word = x;
    }
  }
};
104
111inline std::uint64_t advanceState(xoshiro256ss &rng) noexcept {
112 const auto rotl = [](std::uint64_t x, int k) -> std::uint64_t {
113 return (x << k) | (x >> (64 - k));
114 };
115 const std::uint64_t result = rotl(rng.m_s[1] * 5U, 7) * 9U;
116 const std::uint64_t t = rng.m_s[1] << 17U;
117 rng.m_s[2] ^= rng.m_s[0];
118 rng.m_s[3] ^= rng.m_s[1];
119 rng.m_s[1] ^= rng.m_s[2];
120 rng.m_s[0] ^= rng.m_s[3];
121 rng.m_s[2] ^= t;
122 rng.m_s[3] = rotl(rng.m_s[3], 45);
123 return result;
124}
125
/// Draw a 32-bit unsigned integer uniformly at random from the full u32 range.
///
/// Consumes one generator step and keeps the upper 32 bits of its 64-bit output.
template <class Rng> inline std::uint32_t randUniformU32(Rng &rng) noexcept {
  const std::uint64_t draw = advanceState(rng);
  constexpr unsigned kDiscardedLowBits = 32U;
  return static_cast<std::uint32_t>(draw >> kDiscardedLowBits);
}
135
/// Draw a 64-bit unsigned integer uniformly at random from the full u64 range.
///
/// Thin forwarding wrapper: one generator step, its output returned unchanged.
template <class Rng> inline std::uint64_t randUniformU64(Rng &rng) noexcept {
  const std::uint64_t draw = advanceState(rng);
  return draw;
}
142
/// Draw a uniform variate in the half-open unit interval [0, 1).
///
/// Consumes one generator step. The result is m * 2^-p, where m is the top p bits of the draw
/// (p = 24 for float, 53 for double). Every value is therefore exactly representable, and 1 is
/// never produced.
///
/// @tparam T float or double.
template <class T, class Rng> inline T randUnit(Rng &rng) noexcept {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                "randUnit<T> requires T to be float or double");
  const std::uint64_t bits = advanceState(rng);
  if constexpr (std::is_same_v<T, float>) {
    // Top 24 bits fill a float significand exactly; scale by 2^-24.
    return static_cast<float>(bits >> 40U) * 0x1.0p-24F;
  } else {
    // Top 53 bits fill a double significand exactly; scale by 2^-53.
    return static_cast<double>(bits >> 11U) * 0x1.0p-53;
  }
}
161
178template <class T, Layout L, class Rng>
179inline std::size_t weightedCategorical(const NDArray<T, 1, L> &weights, Rng &rng) noexcept {
180 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
181 "weightedCategorical<T> requires T to be float or double");
182 const std::size_t n = weights.dim(0);
183 assert(n > 0 && "weightedCategorical requires at least one weight");
184
185 T total = T{0};
186 for (std::size_t i = 0; i < n; ++i) {
187 const T w = weights(i);
188 assert(w >= T{0} && "weightedCategorical requires non-negative weights");
189 total += w;
190 }
191 assert(total > T{0} && "weightedCategorical requires at least one positive weight");
192
193 const T u = randUnit<T>(rng) * total;
194 T cumulative = T{0};
195 std::size_t lastPositive = 0;
196 for (std::size_t i = 0; i < n; ++i) {
197 const T w = weights(i);
198 cumulative += w;
199 if (cumulative > u) {
200 return i;
201 }
202 if (w > T{0}) {
203 lastPositive = i;
204 }
205 }
206 // Guard against floating-point drift pushing the final cumulative just below u*total: fall
207 // back to the last index that actually contributed mass so we never return a zero-weight slot.
208 return lastPositive;
209}
210
232template <class T, Layout L, class Rng>
233inline void aExpjReservoir(const NDArray<T, 1, L> &weights, std::size_t k, Rng &rng,
234 std::span<std::size_t> outIdx) noexcept {
235 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
236 "aExpjReservoir<T> requires T to be float or double");
237 const std::size_t n = weights.dim(0);
238 assert(outIdx.size() == k && "aExpjReservoir requires outIdx.size() == k");
239 assert(k <= n && "aExpjReservoir requires k <= weights.dim(0)");
240 if (k == 0) {
241 return;
242 }
243
244 // Generate all keys, partial-sort by key descending, emit the top-k indices. O(n log n) in the
245 // straight path; a size-k min-heap variant would trim this to O(n log k) when k << n.
246 std::vector<std::pair<T, std::size_t>> keyed;
247 keyed.reserve(n);
248 for (std::size_t i = 0; i < n; ++i) {
249 const T w = weights(i);
250 assert(w > T{0} && "aExpjReservoir requires strictly positive weights");
251 // randUnit draws from [0, 1); nudge zero away from the log singularity by resampling. In
252 // double precision a single redraw is sufficient with probability 1 - 2^-53.
253 T u = randUnit<T>(rng);
254 while (u <= T{0}) {
255 u = randUnit<T>(rng);
256 }
257 const T key = std::log(u) / w;
258 keyed.emplace_back(key, i);
259 }
260
261 // Partial sort by key descending: the k largest keys bubble to the front.
262 const auto cmp = [](const std::pair<T, std::size_t> &a,
263 const std::pair<T, std::size_t> &b) noexcept { return a.first > b.first; };
264 std::partial_sort(keyed.begin(), keyed.begin() + static_cast<std::ptrdiff_t>(k), keyed.end(),
265 cmp);
266
267 for (std::size_t i = 0; i < k; ++i) {
268 outIdx[i] = keyed[i].second;
269 }
270}
271
272} // namespace clustering::math
Represents a multidimensional array (NDArray) of a fixed number of dimensions N and element type T.
Definition ndarray.h:136
T randUnit(Rng &rng) noexcept
Draw a uniform variate in the half-open unit interval [0, 1).
Definition rng.h:152
void aExpjReservoir(const NDArray< T, 1, L > &weights, std::size_t k, Rng &rng, std::span< std::size_t > outIdx) noexcept
Efraimidis-Spirakis weighted reservoir sampling (A-Exp variant, log-key form).
Definition rng.h:233
std::uint64_t advanceState(pcg64 &rng) noexcept
Advance a pcg64 one step and return the 64-bit XSL-RR output.
Definition rng.h:63
std::size_t weightedCategorical(const NDArray< T, 1, L > &weights, Rng &rng) noexcept
Sample one category index proportionally to non-negative weights.
Definition rng.h:179
std::uint32_t randUniformU32(Rng &rng) noexcept
Draw a 32-bit unsigned integer uniformly at random from the full u32 range.
Definition rng.h:132
std::uint64_t randUniformU64(Rng &rng) noexcept
Draw a 64-bit unsigned integer uniformly at random from the full u64 range.
Definition rng.h:139
128-bit state for the PCG-XSL-RR 64-bit output generator (Melissa O'Neill).
Definition rng.h:30
__uint128_t m_inc
Stream-encoded odd increment mixed into the LCG step.
Definition rng.h:34
__uint128_t m_state
128-bit generator state; advanced by every advanceState call.
Definition rng.h:32
void seed(std::uint64_t seedValue, std::uint64_t stream=0) noexcept
Initialize the generator per PCG's canonical seeding procedure.
Definition rng.h:46
256-bit state for Vigna & Blackman's xoshiro256** generator.
Definition rng.h:80
void seed(std::uint64_t seedValue) noexcept
Initialize the four state words via SplitMix64 diffusion of a single 64-bit seed.
Definition rng.h:92
std::array< std::uint64_t, 4 > m_s
Four 64-bit state words; SplitMix64-diffused at seed time.
Definition rng.h:82