Skip to content

Commit 2584f48

Browse files
authored
Fix pert for mode approx eq mean; use builder pattern (#1452)
- Fix #1311 (mode close to mean) - Use a builder pattern, allowing specification via mode OR mean
1 parent d17ce4e commit 2584f48

File tree

4 files changed

+86
-25
lines changed

4 files changed

+86
-25
lines changed

rand_distr/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## Unreleased
8-
98
### Added
109
- Add plots for `rand_distr` distributions to documentation (#1434)
10+
- Add `PertBuilder`, fix case where mode ≅ mean (#1452)
1111

1212
## [0.5.0-alpha.1] - 2024-03-18
1313
- Target `rand` version `0.9.0-alpha.1`

rand_distr/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ pub use self::normal_inverse_gaussian::{
117117
Error as NormalInverseGaussianError, NormalInverseGaussian,
118118
};
119119
pub use self::pareto::{Error as ParetoError, Pareto};
120-
pub use self::pert::{Pert, PertError};
120+
pub use self::pert::{Pert, PertBuilder, PertError};
121121
pub use self::poisson::{Error as PoissonError, Poisson};
122122
pub use self::skew_normal::{Error as SkewNormalError, SkewNormal};
123123
pub use self::triangular::{Triangular, TriangularError};

rand_distr/src/pert.rs

+83-22
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use rand::Rng;
3131
/// ```rust
3232
/// use rand_distr::{Pert, Distribution};
3333
///
34-
/// let d = Pert::new(0., 5., 2.5).unwrap();
34+
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
3535
/// let v = d.sample(&mut rand::thread_rng());
3636
/// println!("{} is from a PERT distribution", v);
3737
/// ```
@@ -82,35 +82,75 @@ where
8282
Exp1: Distribution<F>,
8383
Open01: Distribution<F>,
8484
{
85-
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
85+
/// Construct a PERT distribution with defined `min`, `max`
8686
///
87-
/// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
87+
/// # Example
88+
///
89+
/// ```
90+
/// use rand_distr::Pert;
91+
/// let pert_dist = Pert::new(0.0, 10.0)
92+
/// .with_shape(3.5)
93+
/// .with_mean(3.0)
94+
/// .unwrap();
95+
/// # let _unused: Pert<f64> = pert_dist;
96+
/// ```
97+
#[allow(clippy::new_ret_no_self)]
98+
#[inline]
99+
pub fn new(min: F, max: F) -> PertBuilder<F> {
100+
let shape = F::from(4.0).unwrap();
101+
PertBuilder { min, max, shape }
102+
}
103+
}
104+
105+
/// Struct used to build a [`Pert`]
106+
#[derive(Debug)]
107+
pub struct PertBuilder<F> {
108+
min: F,
109+
max: F,
110+
shape: F,
111+
}
112+
113+
impl<F> PertBuilder<F>
114+
where
115+
F: Float,
116+
StandardNormal: Distribution<F>,
117+
Exp1: Distribution<F>,
118+
Open01: Distribution<F>,
119+
{
120+
/// Set the shape parameter
121+
///
122+
/// If not specified, this defaults to 4.
123+
#[inline]
124+
pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
125+
self.shape = shape;
126+
self
127+
}
128+
129+
/// Specify the mean
88130
#[inline]
89-
pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
90-
Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
131+
pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
132+
let two = F::from(2.0).unwrap();
133+
let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
134+
self.with_mode(mode)
91135
}
92136

93-
/// Set up the PERT distribution with defined `min`, `max`, `mode` and
94-
/// `shape`.
95-
pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
96-
if !(max > min) {
137+
/// Specify the mode
138+
#[inline]
139+
pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
140+
if !(self.max > self.min) {
97141
return Err(PertError::RangeTooSmall);
98142
}
99-
if !(mode >= min && max >= mode) {
143+
if !(mode >= self.min && self.max >= mode) {
100144
return Err(PertError::ModeRange);
101145
}
102-
if !(shape >= F::from(0.).unwrap()) {
146+
if !(self.shape >= F::from(0.).unwrap()) {
103147
return Err(PertError::ShapeTooSmall);
104148
}
105149

150+
let (min, max, shape) = (self.min, self.max, self.shape);
106151
let range = max - min;
107-
let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap());
108-
let v = if mu == mode {
109-
shape * F::from(0.5).unwrap() + F::from(1.).unwrap()
110-
} else {
111-
(mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min))
112-
};
113-
let w = v * (max - mu) / (mu - min);
152+
let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
153+
let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
114154
let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
115155
Ok(Pert { min, range, beta })
116156
}
@@ -136,17 +176,38 @@ mod test {
136176
#[test]
137177
fn test_pert() {
138178
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
139-
let _distr = Pert::new(min, max, mode).unwrap();
179+
let _distr = Pert::new(min, max).with_mode(mode).unwrap();
140180
// TODO: test correctness
141181
}
142182

143183
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
144-
assert!(Pert::new(min, max, mode).is_err());
184+
assert!(Pert::new(min, max).with_mode(mode).is_err());
145185
}
146186
}
147187

148188
#[test]
149-
fn pert_distributions_can_be_compared() {
150-
assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0));
189+
fn distributions_can_be_compared() {
190+
let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
191+
let p1 = Pert::new(min, max).with_mode(mode).unwrap();
192+
let mean = (min + shape * mode + max) / (shape + 2.0);
193+
let p2 = Pert::new(min, max).with_mean(mean).unwrap();
194+
assert_eq!(p1, p2);
195+
}
196+
197+
#[test]
198+
fn mode_almost_half_range() {
199+
assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok());
200+
}
201+
202+
#[test]
203+
fn almost_symmetric_about_zero() {
204+
let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
205+
assert!(distr.is_ok());
206+
}
207+
208+
#[test]
209+
fn almost_symmetric() {
210+
let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
211+
assert!(distr.is_ok());
151212
}
152213
}

rand_distr/tests/value_stability.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ fn pert_stability() {
250250
// mean = 4, var = 12/7
251251
test_samples(
252252
860,
253-
Pert::new(2., 10., 3.).unwrap(),
253+
Pert::new(2., 10.).with_mode(3.).unwrap(),
254254
&[
255255
4.908681667460367,
256256
4.014196196158352,

0 commit comments

Comments
 (0)