Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add KS tests for weighted sampling #1530

Merged
merged 21 commits into from
Nov 26, 2024
Merged
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add KS test for IndexedRandom::choose_weighted
dhardy committed Nov 18, 2024
commit 9e03a1571064cf5428675169cb95329c833ffd92
24 changes: 23 additions & 1 deletion distr_test/tests/weighted.rs
Original file line number Diff line number Diff line change
@@ -8,7 +8,8 @@

mod ks;
use ks::test_discrete;
use rand::distr::WeightedIndex;
use rand::distr::{Distribution, WeightedIndex};
use rand::seq::IndexedRandom;
use rand_distr::{WeightedAliasIndex, WeightedTreeIndex};

fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 {
@@ -75,3 +76,24 @@ fn weighted_tree_index() {
test_weights(100, |i| (i as f64).powi(3));
test_weights(100, |i| 1.0 / ((i + 1) as f64));
}

#[test]
fn choose_weighted_indexed() {
struct Adapter<F: Fn(i64) -> f64>(Vec<i64>, F);
impl<F: Fn(i64) -> f64> Distribution<i64> for Adapter<F> {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 {
*IndexedRandom::choose_weighted(&self.0[..], rng, |i| (self.1)(*i)).unwrap()
}
}

fn test_weights(num: usize, weight: impl Fn(i64) -> f64) {
let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight);
test_discrete(0, distr, make_cdf(num, |i| weight(i)));
}

test_weights(100, |_| 1.0);
test_weights(100, |i| ((i + 1) as f64).ln());
test_weights(100, |i| i as f64);
test_weights(100, |i| (i as f64).powi(3));
test_weights(100, |i| 1.0 / ((i + 1) as f64));
}