From c8713840ed331a7283b6404bd7ca5a6238883a90 Mon Sep 17 00:00:00 2001 From: Luis Moreno Date: Sat, 1 Oct 2022 18:21:45 -0500 Subject: [PATCH 1/3] feat: make kernels an enum that implements kernel --- .../hyper_tuning/grid_search.rs | 44 +++++++++++++-- src/svm/mod.rs | 56 ++++++++++++------- src/svm/svc.rs | 20 +++---- src/svm/svr.rs | 16 +++--- 4 files changed, 94 insertions(+), 42 deletions(-) diff --git a/src/model_selection/hyper_tuning/grid_search.rs b/src/model_selection/hyper_tuning/grid_search.rs index 1544faf0..741c9477 100644 --- a/src/model_selection/hyper_tuning/grid_search.rs +++ b/src/model_selection/hyper_tuning/grid_search.rs @@ -163,13 +163,14 @@ mod tests { use crate::{ linalg::naive::dense_matrix::DenseMatrix, linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters}, - metrics::accuracy, - model_selection::{ - hyper_tuning::grid_search::{self, GridSearchCVParameters}, - KFold, + metrics::{accuracy, recall}, + model_selection::{hyper_tuning::grid_search, KFold}, + svm::{ + svc::{SVCSearchParameters, SVC}, + Kernels, }, }; - use grid_search::GridSearchCV; + use grid_search::{GridSearchCV, GridSearchCVParameters}; #[test] fn test_grid_search() { @@ -233,4 +234,37 @@ mod tests { let result = grid_search.predict(&x).unwrap(); assert_eq!(result, vec![0.]); } + + #[test] + fn svm_check() { + let breast_cancer = crate::dataset::breast_cancer::load_dataset(); + let y = breast_cancer.target; + let x = DenseMatrix::from_array( + breast_cancer.num_samples, + breast_cancer.num_features, + &breast_cancer.data, + ); + let kernels = vec![Kernels::linear(), Kernels::rbf(0.001), Kernels::rbf(0.0001)]; + let parameters = SVCSearchParameters { + kernel: kernels, + c: vec![0., 10., 100., 1000.], + ..Default::default() + }; + let cv = KFold { + n_splits: 5, + ..KFold::default() + }; + let _grid_search = GridSearchCV::fit( + &x, + &y, + GridSearchCVParameters { + estimator: SVC::fit, + score: recall, + cv, + parameters_search: parameters.into_iter(), + _phantom: Default::default(), + }, + ) + .unwrap(); + } } diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 4c71b3f2..8ef5d456 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -39,57 +39,64 @@ pub trait Kernel>: Clone { } /// Pre-defined kernel functions -pub struct Kernels {} +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Kernels { + /// Linear kernel + Linear(LinearKernel), + /// Radial basis function kernel (Gaussian) + RBF(RBFKernel), + /// Polynomial kernel + Polynomial(PolynomialKernel), + /// Sigmoid kernel + Sigmoid(SigmoidKernel), +} -impl Kernels { +impl Kernels { /// Linear kernel - pub fn linear() -> LinearKernel { - LinearKernel {} + pub fn linear() -> Self { + Self::Linear(LinearKernel {}) } /// Radial basis function kernel (Gaussian) - pub fn rbf(gamma: T) -> RBFKernel { - RBFKernel { gamma } + pub fn rbf(gamma: T) -> Self { + Self::RBF(RBFKernel { gamma }) } /// Polynomial kernel /// * `degree` - degree of the polynomial /// * `gamma` - kernel coefficient /// * `coef0` - independent term in kernel function - pub fn polynomial(degree: T, gamma: T, coef0: T) -> PolynomialKernel { - PolynomialKernel { + pub fn polynomial(degree: T, gamma: T, coef0: T) -> Self { + Self::Polynomial(PolynomialKernel { degree, gamma, coef0, - } + }) } /// Polynomial kernel /// * `degree` - degree of the polynomial /// * `n_features` - number of features in vector - pub fn polynomial_with_degree( - degree: T, - n_features: usize, - ) -> PolynomialKernel { + pub fn polynomial_with_degree(degree: T, n_features: usize) -> Self { let coef0 = T::one(); let gamma = T::one() / T::from_usize(n_features).unwrap(); - Kernels::polynomial(degree, gamma, coef0) + Self::polynomial(degree, gamma, coef0) } /// Sigmoid kernel /// * `gamma` - kernel coefficient /// * `coef0` - independent term in kernel function - pub fn sigmoid(gamma: T, coef0: T) -> SigmoidKernel { - SigmoidKernel { gamma, coef0 } + pub fn sigmoid(gamma: T, coef0: T) -> Self { + Self::Sigmoid(SigmoidKernel { gamma, coef0 }) } /// Sigmoid kernel /// * `gamma` - kernel coefficient - pub fn sigmoid_with_gamma(gamma: T) -> SigmoidKernel { - SigmoidKernel { + pub fn sigmoid_with_gamma(gamma: T) -> Self { + Self::Sigmoid(SigmoidKernel { gamma, coef0: T::one(), - } + }) } } @@ -128,6 +135,17 @@ pub struct SigmoidKernel { pub coef0: T, } +impl> Kernel for Kernels { + fn apply(&self, x_i: &V, x_j: &V) -> T { + match self { + Self::Linear(k) => k.apply(x_i, x_j), + Self::RBF(k) => k.apply(x_i, x_j), + Self::Polynomial(k) => k.apply(x_i, x_j), + Self::Sigmoid(k) => k.apply(x_i, x_j), + } + } +} + impl> Kernel for LinearKernel { fn apply(&self, x_i: &V, x_j: &V) -> T { x_i.dot(x_j) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 3354d0da..9f03a1b4 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -85,7 +85,7 @@ use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::rand::get_rng_impl; -use crate::svm::{Kernel, Kernels, LinearKernel}; +use crate::svm::{Kernel, Kernels}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -129,10 +129,10 @@ pub struct SVCSearchParameters, K: Kernel, #[cfg_attr(feature = "serde", serde(default))] /// Unused parameter. - m: PhantomData, + pub m: PhantomData, #[cfg_attr(feature = "serde", serde(default))] /// Controls the pseudo random number generation for shuffling the data for probability estimates - seed: Vec>, + pub seed: Vec>, } /// SVC grid search iterator @@ -219,9 +219,9 @@ impl, K: Kernel> Iterator } } -impl> Default for SVCSearchParameters { +impl> Default for SVCSearchParameters> { fn default() -> Self { - let default_params: SVCParameters = SVCParameters::default(); + let default_params: SVCParameters> = SVCParameters::default(); SVCSearchParameters { epoch: vec![default_params.epoch], @@ -319,7 +319,7 @@ impl, K: Kernel> SVCParameters> Default for SVCParameters { +impl> Default for SVCParameters> { fn default() -> Self { SVCParameters { epoch: 2, @@ -879,19 +879,19 @@ mod tests { #[test] fn search_parameters() { - let parameters: SVCSearchParameters, LinearKernel> = + let parameters: SVCSearchParameters, Kernels<_>> = SVCSearchParameters { epoch: vec![10, 100], - kernel: vec![LinearKernel {}], + kernel: vec![Kernels::linear()], ..Default::default() }; let mut iter = parameters.into_iter(); let next = iter.next().unwrap(); assert_eq!(next.epoch, 10); - assert_eq!(next.kernel, LinearKernel {}); + assert_eq!(next.kernel, Kernels::linear()); let next = iter.next().unwrap(); assert_eq!(next.epoch, 100); - assert_eq!(next.kernel, LinearKernel {}); + assert_eq!(next.kernel, Kernels::linear()); assert!(iter.next().is_none()); } diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 25326d4c..eb871ca2 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -76,7 +76,7 @@ use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; -use crate::svm::{Kernel, Kernels, LinearKernel}; +use crate::svm::{Kernel, Kernels}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -183,9 +183,9 @@ impl, K: Kernel> Iterator } } -impl> Default for SVRSearchParameters { +impl> Default for SVRSearchParameters> { fn default() -> Self { - let default_params: SVRParameters = SVRParameters::default(); + let default_params: SVRParameters> = SVRParameters::default(); SVRSearchParameters { eps: vec![default_params.eps], @@ -272,7 +272,7 @@ impl, K: Kernel> SVRParameters> Default for SVRParameters { +impl> Default for SVRParameters> { fn default() -> Self { SVRParameters { eps: T::from_f64(0.1).unwrap(), @@ -641,19 +641,19 @@ mod tests { #[test] fn search_parameters() { - let parameters: SVRSearchParameters, LinearKernel> = + let parameters: SVRSearchParameters, Kernels<_>> = SVRSearchParameters { eps: vec![0., 1.], - kernel: vec![LinearKernel {}], + kernel: vec![Kernels::linear()], ..Default::default() }; let mut iter = parameters.into_iter(); let next = iter.next().unwrap(); assert_eq!(next.eps, 0.); - assert_eq!(next.kernel, LinearKernel {}); + assert_eq!(next.kernel, Kernels::linear()); let next = iter.next().unwrap(); assert_eq!(next.eps, 1.); - assert_eq!(next.kernel, LinearKernel {}); + assert_eq!(next.kernel, Kernels::linear()); assert!(iter.next().is_none()); } From 7bc706333175f2247aeb0862adef3d1a363a5c99 Mon Sep 17 00:00:00 2001 From: Luis Moreno Date: Sat, 1 Oct 2022 19:24:23 -0500 Subject: [PATCH 2/3] Update test --- .../hyper_tuning/grid_search.rs | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/model_selection/hyper_tuning/grid_search.rs b/src/model_selection/hyper_tuning/grid_search.rs index 741c9477..68389bfb 100644 --- a/src/model_selection/hyper_tuning/grid_search.rs +++ b/src/model_selection/hyper_tuning/grid_search.rs @@ -237,17 +237,38 @@ mod tests { #[test] fn svm_check() { - let breast_cancer = crate::dataset::breast_cancer::load_dataset(); - let y = breast_cancer.target; - let x = DenseMatrix::from_array( - breast_cancer.num_samples, - breast_cancer.num_features, - &breast_cancer.data, - ); - let kernels = vec![Kernels::linear(), Kernels::rbf(0.001), Kernels::rbf(0.0001)]; + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + + let y: Vec = vec![ + -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., + ]; + + let kernels = vec![Kernels::linear(), Kernels::rbf(0.7), Kernels::rbf(0.9)]; let parameters = SVCSearchParameters { kernel: kernels, - c: vec![0., 10., 100., 1000.], + c: vec![1., 2.], ..Default::default() }; let cv = KFold { From 5e9be367b419f4cf88da9449eb22484db69cec6f Mon Sep 17 00:00:00 2001 From: Luis Moreno Date: Sun, 2 Oct 2022 09:02:40 -0500 Subject: [PATCH 3/3] Fix serde --- src/model_selection/kfold.rs | 2 ++ src/model_selection/mod.rs | 3 ++- src/svm/mod.rs | 1 + src/svm/svc.rs | 2 +- src/svm/svr.rs | 2 +- 5 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/model_selection/kfold.rs b/src/model_selection/kfold.rs index ef48b872..b0d54215 100644 --- a/src/model_selection/kfold.rs +++ b/src/model_selection/kfold.rs @@ -9,6 +9,8 @@ use crate::rand::get_rng_impl; use rand::seq::SliceRandom; /// K-Folds cross-validator +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct KFold { /// Number of folds. Must be at least 2. pub n_splits: usize, // cannot exceed std::usize::MAX diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index f16b9559..22709bf6 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -176,7 +176,8 @@ pub fn train_test_split>( } /// Cross validation results. -#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct CrossValidationResult { /// Vector with test scores on each cv split pub test_score: Vec, diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 8ef5d456..b2e6e544 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -39,6 +39,7 @@ pub trait Kernel>: Clone { } /// Pre-defined kernel functions +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone, PartialEq, Eq)] pub enum Kernels { /// Linear kernel diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 9f03a1b4..42140713 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -1065,7 +1065,7 @@ mod tests { let svc = SVC::fit(&x, &y, Default::default()).unwrap(); - let deserialized_svc: SVC, LinearKernel> = + let deserialized_svc: SVC, Kernels<_>> = serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap(); assert_eq!(svc, deserialized_svc); diff --git a/src/svm/svr.rs b/src/svm/svr.rs index eb871ca2..546bdbd3 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -721,7 +721,7 @@ mod tests { let svr = SVR::fit(&x, &y, Default::default()).unwrap(); - let deserialized_svr: SVR, LinearKernel> = + let deserialized_svr: SVR, Kernels<_>> = serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); assert_eq!(svr, deserialized_svr);