@@ -31,7 +31,7 @@ use rand::Rng;
31
31
/// ```rust
32
32
/// use rand_distr::{Pert, Distribution};
33
33
///
34
- /// let d = Pert::new(0., 5., 2.5).unwrap();
34
+ /// let d = Pert::new(0., 5.).with_mode( 2.5).unwrap();
35
35
/// let v = d.sample(&mut rand::thread_rng());
36
36
/// println!("{} is from a PERT distribution", v);
37
37
/// ```
@@ -82,35 +82,75 @@ where
82
82
Exp1 : Distribution < F > ,
83
83
Open01 : Distribution < F > ,
84
84
{
85
- /// Set up the PERT distribution with defined `min`, `max` and `mode`.
85
+ /// Construct a PERT distribution with defined `min`, `max`
86
86
///
87
- /// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
87
+ /// # Example
88
+ ///
89
+ /// ```
90
+ /// use rand_distr::Pert;
91
+ /// let pert_dist = Pert::new(0.0, 10.0)
92
+ /// .with_shape(3.5)
93
+ /// .with_mean(3.0)
94
+ /// .unwrap();
95
+ /// # let _unused: Pert<f64> = pert_dist;
96
+ /// ```
97
+ #[ allow( clippy:: new_ret_no_self) ]
98
+ #[ inline]
99
+ pub fn new ( min : F , max : F ) -> PertBuilder < F > {
100
+ let shape = F :: from ( 4.0 ) . unwrap ( ) ;
101
+ PertBuilder { min, max, shape }
102
+ }
103
+ }
104
+
105
+ /// Struct used to build a [`Pert`]
106
+ #[ derive( Debug ) ]
107
+ pub struct PertBuilder < F > {
108
+ min : F ,
109
+ max : F ,
110
+ shape : F ,
111
+ }
112
+
113
+ impl < F > PertBuilder < F >
114
+ where
115
+ F : Float ,
116
+ StandardNormal : Distribution < F > ,
117
+ Exp1 : Distribution < F > ,
118
+ Open01 : Distribution < F > ,
119
+ {
120
+ /// Set the shape parameter
121
+ ///
122
+ /// If not specified, this defaults to 4.
123
+ #[ inline]
124
+ pub fn with_shape ( mut self , shape : F ) -> PertBuilder < F > {
125
+ self . shape = shape;
126
+ self
127
+ }
128
+
129
+ /// Specify the mean
88
130
#[ inline]
89
- pub fn new ( min : F , max : F , mode : F ) -> Result < Pert < F > , PertError > {
90
- Pert :: new_with_shape ( min, max, mode, F :: from ( 4. ) . unwrap ( ) )
131
+ pub fn with_mean ( self , mean : F ) -> Result < Pert < F > , PertError > {
132
+ let two = F :: from ( 2.0 ) . unwrap ( ) ;
133
+ let mode = ( ( self . shape + two) * mean - self . min - self . max ) / self . shape ;
134
+ self . with_mode ( mode)
91
135
}
92
136
93
- /// Set up the PERT distribution with defined `min`, `max`, ` mode` and
94
- /// `shape`.
95
- pub fn new_with_shape ( min : F , max : F , mode : F , shape : F ) -> Result < Pert < F > , PertError > {
96
- if !( max > min) {
137
+ /// Specify the mode
138
+ # [ inline ]
139
+ pub fn with_mode ( self , mode : F ) -> Result < Pert < F > , PertError > {
140
+ if !( self . max > self . min ) {
97
141
return Err ( PertError :: RangeTooSmall ) ;
98
142
}
99
- if !( mode >= min && max >= mode) {
143
+ if !( mode >= self . min && self . max >= mode) {
100
144
return Err ( PertError :: ModeRange ) ;
101
145
}
102
- if !( shape >= F :: from ( 0. ) . unwrap ( ) ) {
146
+ if !( self . shape >= F :: from ( 0. ) . unwrap ( ) ) {
103
147
return Err ( PertError :: ShapeTooSmall ) ;
104
148
}
105
149
150
+ let ( min, max, shape) = ( self . min , self . max , self . shape ) ;
106
151
let range = max - min;
107
- let mu = ( min + max + shape * mode) / ( shape + F :: from ( 2. ) . unwrap ( ) ) ;
108
- let v = if mu == mode {
109
- shape * F :: from ( 0.5 ) . unwrap ( ) + F :: from ( 1. ) . unwrap ( )
110
- } else {
111
- ( mu - min) * ( F :: from ( 2. ) . unwrap ( ) * mode - min - max) / ( ( mode - mu) * ( max - min) )
112
- } ;
113
- let w = v * ( max - mu) / ( mu - min) ;
152
+ let v = F :: from ( 1.0 ) . unwrap ( ) + shape * ( mode - min) / range;
153
+ let w = F :: from ( 1.0 ) . unwrap ( ) + shape * ( max - mode) / range;
114
154
let beta = Beta :: new ( v, w) . map_err ( |_| PertError :: RangeTooSmall ) ?;
115
155
Ok ( Pert { min, range, beta } )
116
156
}
@@ -136,17 +176,38 @@ mod test {
136
176
#[ test]
137
177
fn test_pert ( ) {
138
178
for & ( min, max, mode) in & [ ( -1. , 1. , 0. ) , ( 1. , 2. , 1. ) , ( 5. , 25. , 25. ) ] {
139
- let _distr = Pert :: new ( min, max, mode) . unwrap ( ) ;
179
+ let _distr = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
140
180
// TODO: test correctness
141
181
}
142
182
143
183
for & ( min, max, mode) in & [ ( -1. , 1. , 2. ) , ( -1. , 1. , -2. ) , ( 2. , 1. , 1. ) ] {
144
- assert ! ( Pert :: new( min, max, mode) . is_err( ) ) ;
184
+ assert ! ( Pert :: new( min, max) . with_mode ( mode) . is_err( ) ) ;
145
185
}
146
186
}
147
187
148
188
#[ test]
149
- fn pert_distributions_can_be_compared ( ) {
150
- assert_eq ! ( Pert :: new( 1.0 , 3.0 , 2.0 ) , Pert :: new( 1.0 , 3.0 , 2.0 ) ) ;
189
+ fn distributions_can_be_compared ( ) {
190
+ let ( min, mode, max, shape) = ( 1.0 , 2.0 , 3.0 , 4.0 ) ;
191
+ let p1 = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
192
+ let mean = ( min + shape * mode + max) / ( shape + 2.0 ) ;
193
+ let p2 = Pert :: new ( min, max) . with_mean ( mean) . unwrap ( ) ;
194
+ assert_eq ! ( p1, p2) ;
195
+ }
196
+
197
+ #[ test]
198
+ fn mode_almost_half_range ( ) {
199
+ assert ! ( Pert :: new( 0.0f32 , 0.48258883 ) . with_mode( 0.24129441 ) . is_ok( ) ) ;
200
+ }
201
+
202
+ #[ test]
203
+ fn almost_symmetric_about_zero ( ) {
204
+ let distr = Pert :: new ( -10f32 , 10f32 ) . with_mode ( f32:: EPSILON ) ;
205
+ assert ! ( distr. is_ok( ) ) ;
206
+ }
207
+
208
+ #[ test]
209
+ fn almost_symmetric ( ) {
210
+ let distr = Pert :: new ( 0f32 , 2f32 ) . with_mode ( 1f32 + f32:: EPSILON ) ;
211
+ assert ! ( distr. is_ok( ) ) ;
151
212
}
152
213
}
0 commit comments