Skip to content

Commit

Permalink
fix!: return an error if vec_y contains negative value (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
tasshi-me authored Dec 29, 2022
1 parent 423d739 commit 93b4019
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 32 deletions.
14 changes: 0 additions & 14 deletions src/error.rs

This file was deleted.

4 changes: 4 additions & 0 deletions src/gaussian.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod deprecated;
mod error;
mod gaussian;
mod operations;

Expand All @@ -7,3 +8,6 @@ pub use self::deprecated::*;

#[doc(inline)]
pub use self::gaussian::Gaussian;

#[doc(inline)]
pub use self::error::GaussianError;
4 changes: 2 additions & 2 deletions src/gaussian/deprecated.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::Error;
use crate::gaussian::operations;
use crate::gaussian::GaussianError;
use crate::linalg::Float;
use ndarray::Array1;

Expand Down Expand Up @@ -80,6 +80,6 @@ pub fn val<F: Float>(x: F, mu: F, sigma: F, a: F) -> F {
since = "0.4.0",
note = "Please use the Gaussian::fit function instead"
)]
pub fn fit<F: Float>(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<(F, F, F), Error> {
pub fn fit<F: Float>(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<(F, F, F), GaussianError> {
operations::fitting_guos(x_vec, y_vec)
}
16 changes: 16 additions & 0 deletions src/gaussian/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::linalg::LinalgError;
use thiserror::Error;

#[derive(Error, Debug, Eq, PartialEq)]
pub enum GaussianError {
/// Given y_vec contains a negative value
#[error("Given y_vec contains a negative value")]
GivenYVecContainsNegativeValue,
/// Given y_vec contains a negative value
#[error("Given x_vec has no element")]
/// Given x_vec has no element
GivenXVecHasNoElement,
/// Error from [`crate::linalg::LinalgError`]
#[error("Linalg error: {0:?}")]
Linalg(#[from] LinalgError),
}
14 changes: 11 additions & 3 deletions src/gaussian/gaussian.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::convert::TryFrom;

use crate::error::Error;
use crate::gaussian::operations;
use crate::gaussian::{operations, GaussianError};
use crate::linalg::Float;
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
use ndarray::{array, Array1};
Expand Down Expand Up @@ -149,7 +148,7 @@ impl<F: Float> Gaussian<F> {
///
/// # References
/// \[1\] [E. Pastuchov ́a and M. Z ́akopˇcan, ”Comparison of Algorithms for Fitting a Gaussian Function used in Testing Smart Sensors”, Journal of Electrical Engineering, vol. 66, no. 3, pp. 178-181, 2015.](https://www.researchgate.net/publication/281907940_Comparison_of_Algorithms_For_Fitting_a_Gaussian_Function_Used_in_Testing_Smart_Sensors)
pub fn fit(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<Gaussian<F>, Error> {
pub fn fit(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<Gaussian<F>, GaussianError> {
let (mu, sigma, a) = operations::fitting_guos(x_vec, y_vec)?;
Ok(Gaussian::<F>::new(mu, sigma, a))
}
Expand Down Expand Up @@ -307,6 +306,15 @@ mod tests {
assert_abs_diff_eq!(gaussian, estimated, epsilon = 1e-9);
}

#[test]
fn fit_with_negative_value() {
let x_vec: Array1<f64> = array![1., 2., 3.];
let y_vec: Array1<f64> = array![-1., 1., -1.];

let err = Gaussian::fit(x_vec, y_vec).unwrap_err();
assert_eq!(err, GaussianError::GivenYVecContainsNegativeValue);
}

#[test]
fn as_tuple() {
let (mu, sigma, a): (f64, f64, f64) = (1., 2., 3.);
Expand Down
28 changes: 24 additions & 4 deletions src/gaussian/operations.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::error::Error;
use crate::gaussian::GaussianError;
use crate::linalg;
use crate::linalg::Float;
use ndarray::{array, Array1};
Expand All @@ -12,8 +12,11 @@ pub fn values<F: Float>(x_vec: Array1<F>, mu: F, sigma: F, a: F) -> Array1<F> {
}

#[allow(dead_code)]
pub fn fitting_caruanas<F: Float>(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<(F, F, F), Error> {
let len_x_vec = F::from(x_vec.len()).ok_or(Error::Optional)?;
pub fn fitting_caruanas<F: Float>(
x_vec: Array1<F>,
y_vec: Array1<F>,
) -> Result<(F, F, F), GaussianError> {
let len_x_vec = F::from(x_vec.len()).ok_or(GaussianError::GivenXVecHasNoElement)?;
let sum_x = x_vec.sum();
let sum_x_pow2 = x_vec.iter().map(|x| x.powi(2)).sum();
let sum_x_pow3 = x_vec.iter().map(|x| x.powi(3)).sum();
Expand Down Expand Up @@ -47,7 +50,14 @@ pub fn fitting_caruanas<F: Float>(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<
Ok((mu, sigma, a))
}

pub fn fitting_guos<F: Float>(x_vec: Array1<F>, y_vec: Array1<F>) -> Result<(F, F, F), Error> {
pub fn fitting_guos<F: Float>(
x_vec: Array1<F>,
y_vec: Array1<F>,
) -> Result<(F, F, F), GaussianError> {
// Guo's algorithm doesn't support negative value in y[]
if y_vec.iter().any(|y| y.is_sign_negative()) {
return Err(GaussianError::GivenYVecContainsNegativeValue);
}
let sum_y_pow2: F = y_vec.iter().map(|y| y.powi(2)).sum();
let sum_x_y_pow2 = y_vec
.iter()
Expand Down Expand Up @@ -161,4 +171,14 @@ mod tests {
epsilon = 1e-9
);
}

#[test]
fn gaussian_fit_guos_y_vec_contains_negative_value() {
let (mu, sigma, a): (f64, f64, f64) = (5., 3., 1.);
let x_vec: Array1<f64> = Array::range(1., 10., 1.);
let y_vec: Array1<f64> = values(x_vec.clone(), mu, sigma, a).map(|y| y - 0.5);

let err = fitting_guos(x_vec, y_vec).unwrap_err();
assert_eq!(err, GaussianError::GivenYVecContainsNegativeValue);
}
}
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
//! ```
//!
pub mod error;
pub mod gaussian;
pub mod linalg;

Expand Down
27 changes: 19 additions & 8 deletions src/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::error::Error;
use approx::{abs_diff_eq, abs_diff_ne};
use ndarray::{s, Array1, Array2, Axis};

Expand All @@ -9,6 +8,18 @@ use std::iter::Sum;
pub trait Float: NdFloat + Sum + AbsDiffEq + RelativeEq + UlpsEq {}
impl<F: NdFloat + Sum + AbsDiffEq + RelativeEq + UlpsEq> Float for F {}

use thiserror::Error;

#[derive(Error, Debug, Eq, PartialEq)]
pub enum LinalgError {
/// Equations have no solutions
#[error("Equations have no solutions")]
EquationsHaveNoSolutions,
/// Equations have infinite solutions
#[error("Equations have infinite solutions")]
EquationsHaveInfSolutions,
}

/// Solves a system of linear equations.
///
/// This function implements the Gaussian elimination.
Expand All @@ -25,7 +36,7 @@ impl<F: NdFloat + Sum + AbsDiffEq + RelativeEq + UlpsEq> Float for F {}
/// let x = linalg::solve(a, b).unwrap();
/// assert_abs_diff_eq!(x, array![1., -2., -2.], epsilon = 1e-9);
/// ```
pub fn solve<F: Float>(a: Array2<F>, b: Array1<F>) -> Result<Array1<F>, Error> {
pub fn solve<F: Float>(a: Array2<F>, b: Array1<F>) -> Result<Array1<F>, LinalgError> {
let mut a = a;
let mut b = b;

Expand Down Expand Up @@ -87,12 +98,12 @@ pub fn solve<F: Float>(a: Array2<F>, b: Array1<F>) -> Result<Array1<F>, Error> {

// no solutions
if rank_coef != rank_aug {
return Err(Error::LinalgSolveNoSolutions);
return Err(LinalgError::EquationsHaveNoSolutions);
}

// infinite solutions
if rank_coef != a.ncols() {
return Err(Error::LinalgSolveInfSolutions);
return Err(LinalgError::EquationsHaveInfSolutions);
}

// backward substitution
Expand Down Expand Up @@ -179,7 +190,7 @@ mod tests {
let a = array![[2., 1., -3., -2.], [2., -1., -1., 3.], [1., -1., -2., 2.]];
let b = array![4., 1., -3.];
let err = solve(a, b).unwrap_err(); //panic
assert_eq!(err, Error::LinalgSolveInfSolutions);
assert_eq!(err, LinalgError::EquationsHaveInfSolutions);
}

#[test]
Expand All @@ -195,7 +206,7 @@ mod tests {
];
let b = array![2., -6. / 5., -1., 1.];
let err = solve(a, b).unwrap_err(); //panic
assert_eq!(err, Error::LinalgSolveInfSolutions);
assert_eq!(err, LinalgError::EquationsHaveInfSolutions);
}

#[test]
Expand All @@ -206,7 +217,7 @@ mod tests {
let a = array![[-2., 3.], [4., 1.], [1., -3.],];
let b = array![1., 5., -1.];
let err = solve(a, b).unwrap_err(); //panic
assert_eq!(err, Error::LinalgSolveNoSolutions);
assert_eq!(err, LinalgError::EquationsHaveNoSolutions);
}

#[test]
Expand All @@ -217,6 +228,6 @@ mod tests {
let a = array![[1., 3., -2.], [-1., 2., -3.], [2., -1., 3.],];
let b = array![2., -2., 3.];
let err = solve(a, b).unwrap_err(); //panic
assert_eq!(err, Error::LinalgSolveNoSolutions);
assert_eq!(err, LinalgError::EquationsHaveNoSolutions);
}
}

0 comments on commit 93b4019

Please # to comment.