Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Re-org with distr::slice, distr::weighted modules #1548

Merged
merged 17 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/benches.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defaults:

jobs:
clippy-fmt:
name: Check Clippy and rustfmt
name: "Benches: Check Clippy and rustfmt"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -33,7 +33,7 @@ jobs:
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
benches:
name: Test benchmarks
name: "Benches: Test"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/distr_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defaults:

jobs:
clippy-fmt:
name: Check Clippy and rustfmt
name: "distr_test: Check Clippy and rustfmt"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -33,7 +33,7 @@ jobs:
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
ks-tests:
name: Run Komogorov Smirnov tests
name: "distr_test: Run Komogorov Smirnov tests"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
toolchain: stable
components: clippy, rustfmt
- name: Check Clippy
run: cargo clippy --all --all-targets -- -D warnings
run: cargo clippy --workspace -- -D warnings
- name: Check rustfmt
run: cargo fmt --all -- --check

Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.

## [0.9.0-beta.3] - 2025-01-03
- Add feature `thread_rng` (#1547)
- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548)
- Rename trait `distr::DistString` -> `distr::SampleString` (#1548)
- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548)
- Move `distr::{Weight, WeightError, WeightedIndex}` -> `distr::weighted::{Weight, Error, WeightedIndex}` (#1548)

## [0.9.0-beta.1] - 2024-11-30
- Bump `rand_core` version
Expand Down
1 change: 1 addition & 0 deletions benches/benches/distr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use criterion_cycles_per_byte::CyclesPerByte;

use rand::prelude::*;
use rand_distr::weighted::*;
use rand_distr::*;

// At this time, distributions are optimised for 64-bit platforms.
Expand Down
2 changes: 1 addition & 1 deletion benches/benches/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// except according to those terms.

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::distr::WeightedIndex;
use rand::distr::weighted::WeightedIndex;
use rand::prelude::*;
use rand::seq::index::sample_weighted;

Expand Down
4 changes: 2 additions & 2 deletions distr_test/tests/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

mod ks;
use ks::test_discrete;
use rand::distr::{Distribution, WeightedIndex};
use rand::distr::Distribution;
use rand::seq::{IndexedRandom, IteratorRandom};
use rand_distr::{WeightedAliasIndex, WeightedTreeIndex};
use rand_distr::weighted::*;

/// Takes the unnormalized pdf and creates the cdf of a discrete distribution
fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 {
Expand Down
6 changes: 6 additions & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.5.0-beta.3] - 2025-01-03
- Bump `rand` version (#1547)
- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548)
- Rename trait `DistString` -> `SampleString` (#1548)
- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548)
- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548)
- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548)
- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548)

## [0.5.0-beta.2] - 2024-11-30
- Bump `rand` version
Expand Down
24 changes: 7 additions & 17 deletions rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
//!
//! The following are re-exported:
//!
//! - The [`Distribution`] trait and [`DistIter`] helper type
//! - The [`Distribution`] trait and [`Iter`] helper type
//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`],
//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions
//! [`Open01`], [`Bernoulli`] distributions
//! - The [`weighted`] module
//!
//! ## Distributions
//!
Expand Down Expand Up @@ -76,9 +77,6 @@
//! - [`UnitBall`] distribution
//! - [`UnitCircle`] distribution
//! - [`UnitDisc`] distribution
//! - Alternative implementations for weighted index sampling
//! - [`WeightedAliasIndex`] distribution
//! - [`WeightedTreeIndex`] distribution
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
Expand All @@ -94,7 +92,7 @@ extern crate std;
use rand::Rng;

pub use rand::distr::{
uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01,
uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01,
StandardUniform, Uniform,
};

Expand Down Expand Up @@ -128,16 +126,13 @@ pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zeta::{Error as ZetaError, Zeta};
pub use self::zipf::{Error as ZipfError, Zipf};
#[cfg(feature = "alloc")]
pub use rand::distr::{WeightError, WeightedIndex};
pub use student_t::StudentT;
#[cfg(feature = "alloc")]
pub use weighted_alias::WeightedAliasIndex;
#[cfg(feature = "alloc")]
pub use weighted_tree::WeightedTreeIndex;

pub use num_traits;

#[cfg(feature = "alloc")]
pub mod weighted;

#[cfg(test)]
#[macro_use]
mod test {
Expand Down Expand Up @@ -189,11 +184,6 @@ mod test {
}
}

#[cfg(feature = "alloc")]
pub mod weighted_alias;
#[cfg(feature = "alloc")]
pub mod weighted_tree;

mod beta;
mod binomial;
mod cauchy;
Expand Down
28 changes: 28 additions & 0 deletions rand_distr/src/weighted/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Weighted (index) sampling
//!
//! This module is a superset of [`rand::distr::weighted`].
//!
//! Multiple implementations of weighted index sampling are provided:
//!
//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction
//! and `O(log N)` sampling over `N` weights.
//! It also supports updating weights with `O(N)` time.
//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high
//! construction time many samples are required to outperform [`WeightedIndex`].
//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and
//! update/insertion/removal of weights with `O(log N)` time.

mod weighted_alias;
mod weighted_tree;

pub use rand::distr::weighted::*;
pub use weighted_alias::*;
pub use weighted_tree::*;
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! This module contains an implementation of alias method for sampling random
//! indices with probabilities proportional to a collection of weights.

use super::WeightError;
use super::Error;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use alloc::{boxed::Box, vec, vec::Vec};
use core::fmt;
Expand Down Expand Up @@ -41,7 +41,7 @@ use serde::{Deserialize, Serialize};
/// # Example
///
/// ```
/// use rand_distr::WeightedAliasIndex;
/// use rand_distr::weighted::WeightedAliasIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
Expand Down Expand Up @@ -85,14 +85,14 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
/// Creates a new [`WeightedAliasIndex`].
///
/// Error cases:
/// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`WeightError::InvalidWeight`] when a weight is not-a-number,
/// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
/// - [`Error::InvalidWeight`] when a weight is not-a-number,
/// negative or greater than `max = W::MAX / weights.len()`.
/// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightError> {
/// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, Error> {
let n = weights.len();
if n == 0 || n > u32::MAX as usize {
return Err(WeightError::InvalidInput);
return Err(Error::InvalidInput);
}
let n = n as u32;

Expand All @@ -103,7 +103,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
.iter()
.all(|&w| W::ZERO <= w && w <= max_weight_size)
{
return Err(WeightError::InvalidWeight);
return Err(Error::InvalidWeight);
}

// The sum of weights will represent 100% of no alias odds.
Expand All @@ -115,7 +115,7 @@ impl<W: AliasableWeight> WeightedAliasIndex<W> {
weight_sum
};
if weight_sum == W::ZERO {
return Err(WeightError::InsufficientNonZero);
return Err(Error::InsufficientNonZero);
}

// `weight_sum` would have been zero if `try_from_lossy` causes an error here.
Expand Down Expand Up @@ -384,23 +384,23 @@ mod test {
// Floating point special cases
assert_eq!(
WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
WeightError::InsufficientNonZero
Error::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand All @@ -418,11 +418,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand All @@ -440,11 +440,11 @@ mod test {
// Signed integer special cases
assert_eq!(
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
assert_eq!(
WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand Down Expand Up @@ -491,15 +491,15 @@ mod test {

assert_eq!(
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
WeightError::InvalidInput
Error::InvalidInput
);
assert_eq!(
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
WeightError::InsufficientNonZero
Error::InsufficientNonZero
);
assert_eq!(
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightError::InvalidWeight
Error::InvalidWeight
);
}

Expand Down
Loading
Loading