feat: make kernels an enum that implements kernel #181

Closed · wants to merge 3 commits
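
The change in a nutshell: `Kernels` stops being an empty marker struct with constructor methods and becomes an enum with one variant per concrete kernel, and the enum itself implements `Kernel` by matching on the variant and forwarding `apply`. Since `Kernel` has a `Clone` supertrait it is not object-safe, so a `Vec<Box<dyn Kernel<_, _>>>` is not an option; the enum provides a single concrete type that can hold mixed kernels in one `Vec` for grid search. A minimal, self-contained sketch of the pattern (simplified `&[f64]` vectors in place of smartcore's `BaseVector`, and only two variants):

```rust
/// Simplified stand-in for smartcore's `Kernel` trait.
trait Kernel: Clone {
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64;
}

#[derive(Debug, Clone, PartialEq)]
struct LinearKernel;

#[derive(Debug, Clone, PartialEq)]
struct RBFKernel {
    gamma: f64,
}

impl Kernel for LinearKernel {
    // Plain dot product.
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64 {
        x_i.iter().zip(x_j).map(|(a, b)| a * b).sum()
    }
}

impl Kernel for RBFKernel {
    // exp(-gamma * ||x_i - x_j||^2)
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64 {
        let d2: f64 = x_i.iter().zip(x_j).map(|(a, b)| (a - b).powi(2)).sum();
        (-self.gamma * d2).exp()
    }
}

/// One concrete type that can hold any kernel, so a `Vec<Kernels>`
/// can mix linear and RBF entries for a parameter search.
#[derive(Debug, Clone, PartialEq)]
enum Kernels {
    Linear(LinearKernel),
    RBF(RBFKernel),
}

impl Kernel for Kernels {
    // Forward to the wrapped kernel; static dispatch, no trait objects.
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64 {
        match self {
            Kernels::Linear(k) => k.apply(x_i, x_j),
            Kernels::RBF(k) => k.apply(x_i, x_j),
        }
    }
}

fn main() {
    let kernels = vec![
        Kernels::Linear(LinearKernel),
        Kernels::RBF(RBFKernel { gamma: 0.7 }),
    ];
    for k in &kernels {
        println!("{:?} -> {}", k, k.apply(&[1.0, 2.0], &[3.0, 4.0]));
    }
}
```

The PR applies exactly this shape to `LinearKernel`, `RBFKernel`, `PolynomialKernel`, and `SigmoidKernel` in the diff below.
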
65 changes: 60 additions & 5 deletions src/model_selection/hyper_tuning/grid_search.rs
@@ -163,13 +163,14 @@ mod tests {
use crate::{
linalg::naive::dense_matrix::DenseMatrix,
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
metrics::accuracy,
model_selection::{
hyper_tuning::grid_search::{self, GridSearchCVParameters},
KFold,
metrics::{accuracy, recall},
model_selection::{hyper_tuning::grid_search, KFold},
svm::{
svc::{SVCSearchParameters, SVC},
Kernels,
},
};
use grid_search::GridSearchCV;
use grid_search::{GridSearchCV, GridSearchCVParameters};

#[test]
fn test_grid_search() {
@@ -233,4 +234,58 @@ mod tests {
let result = grid_search.predict(&x).unwrap();
assert_eq!(result, vec![0.]);
}

#[test]
fn svm_check() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);

let y: Vec<f64> = vec![
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1.,
];

let kernels = vec![Kernels::linear(), Kernels::rbf(0.7), Kernels::rbf(0.9)];
let parameters = SVCSearchParameters {
kernel: kernels,
c: vec![1., 2.],
..Default::default()
};
let cv = KFold {
n_splits: 5,
..KFold::default()
};
let _grid_search = GridSearchCV::fit(
&x,
&y,
GridSearchCVParameters {
estimator: SVC::fit,
score: recall,
cv,
parameters_search: parameters.into_iter(),
_phantom: Default::default(),
},
)
.unwrap();
}
}
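
The new `svm_check` test above is the payoff of the enum: `vec![Kernels::linear(), Kernels::rbf(0.7), Kernels::rbf(0.9)]` is a plain `Vec<Kernels<f64>>`, so kernels with different parameters and shapes can be swept in a single `GridSearchCV::fit` call. With the old design, where each constructor returned a distinct struct type, that vector would not type-check.
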
2 changes: 2 additions & 0 deletions src/model_selection/kfold.rs
@@ -9,6 +9,8 @@ use crate::rand::get_rng_impl
use rand::seq::SliceRandom;

/// K-Folds cross-validator
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KFold {
/// Number of folds. Must be at least 2.
pub n_splits: usize, // cannot exceed std::usize::MAX
3 changes: 2 additions & 1 deletion src/model_selection/mod.rs
@@ -176,7 +176,8 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
}

/// Cross validation results.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CrossValidationResult<T: RealNumber> {
/// Vector with test scores on each cv split
pub test_score: Vec<T>,
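
Both `KFold` and `CrossValidationResult` pick up serde derives plus `PartialEq`/`Eq` in this PR. A quick round-trip sketch of what that enables; this assumes the crate's `serde` feature is enabled and that `serde_json` is available as a dependency (neither is shown in this diff):

```rust
use smartcore::model_selection::KFold;

fn main() {
    let cv = KFold {
        n_splits: 5,
        ..KFold::default()
    };
    // Serialize, deserialize, and compare; the assert_eq! only
    // compiles because this PR derives PartialEq for KFold.
    let json = serde_json::to_string(&cv).unwrap();
    let back: KFold = serde_json::from_str(&json).unwrap();
    assert_eq!(cv, back);
}
```
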
57 changes: 38 additions & 19 deletions src/svm/mod.rs
@@ -39,57 +39,65 @@ pub trait Kernel<T: RealNumber, V: BaseVector<T>>: Clone {
}

/// Pre-defined kernel functions
pub struct Kernels {}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Kernels<T: RealNumber> {
/// Linear kernel
Linear(LinearKernel),
/// Radial basis function kernel (Gaussian)
RBF(RBFKernel<T>),
/// Polynomial kernel
Polynomial(PolynomialKernel<T>),
/// Sigmoid kernel
Sigmoid(SigmoidKernel<T>),
}

impl Kernels {
impl<T: RealNumber> Kernels<T> {
/// Linear kernel
pub fn linear() -> LinearKernel {
LinearKernel {}
pub fn linear() -> Self {
Self::Linear(LinearKernel {})
}

/// Radial basis function kernel (Gaussian)
pub fn rbf<T: RealNumber>(gamma: T) -> RBFKernel<T> {
RBFKernel { gamma }
pub fn rbf(gamma: T) -> Self {
Self::RBF(RBFKernel { gamma })
}

/// Polynomial kernel
/// * `degree` - degree of the polynomial
/// * `gamma` - kernel coefficient
/// * `coef0` - independent term in kernel function
pub fn polynomial<T: RealNumber>(degree: T, gamma: T, coef0: T) -> PolynomialKernel<T> {
PolynomialKernel {
pub fn polynomial(degree: T, gamma: T, coef0: T) -> Self {
Self::Polynomial(PolynomialKernel {
degree,
gamma,
coef0,
}
})
}

/// Polynomial kernel
/// * `degree` - degree of the polynomial
/// * `n_features` - number of features in vector
pub fn polynomial_with_degree<T: RealNumber>(
degree: T,
n_features: usize,
) -> PolynomialKernel<T> {
pub fn polynomial_with_degree(degree: T, n_features: usize) -> Self {
let coef0 = T::one();
let gamma = T::one() / T::from_usize(n_features).unwrap();
Kernels::polynomial(degree, gamma, coef0)
Self::polynomial(degree, gamma, coef0)
}

/// Sigmoid kernel
/// * `gamma` - kernel coefficient
/// * `coef0` - independent term in kernel function
pub fn sigmoid<T: RealNumber>(gamma: T, coef0: T) -> SigmoidKernel<T> {
SigmoidKernel { gamma, coef0 }
pub fn sigmoid(gamma: T, coef0: T) -> Self {
Self::Sigmoid(SigmoidKernel { gamma, coef0 })
}

/// Sigmoid kernel
/// * `gamma` - kernel coefficient
pub fn sigmoid_with_gamma<T: RealNumber>(gamma: T) -> SigmoidKernel<T> {
SigmoidKernel {
pub fn sigmoid_with_gamma(gamma: T) -> Self {
Self::Sigmoid(SigmoidKernel {
gamma,
coef0: T::one(),
}
})
}
}

@@ -128,6 +136,17 @@ pub struct SigmoidKernel<T: RealNumber> {
pub coef0: T,
}

impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for Kernels<T> {
fn apply(&self, x_i: &V, x_j: &V) -> T {
match self {
Self::Linear(k) => k.apply(x_i, x_j),
Self::RBF(k) => k.apply(x_i, x_j),
Self::Polynomial(k) => k.apply(x_i, x_j),
Self::Sigmoid(k) => k.apply(x_i, x_j),
}
}
}

impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for LinearKernel {
fn apply(&self, x_i: &V, x_j: &V) -> T {
x_i.dot(x_j)
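
Since every constructor now returns `Self`, `Kernels::linear()`, `Kernels::rbf(0.7)`, `Kernels::polynomial(...)`, and `Kernels::sigmoid(...)` all share the single type `Kernels<T>`, and the forwarding `impl Kernel<T, V> for Kernels<T>` above lets that type be used anywhere a concrete kernel was accepted before. One point worth a second look: deriving `Eq` here puts an `Eq` bound on the kernel parameters, and float-backed `RealNumber` types such as `f64` do not satisfy it, so `Kernels<f64>` will end up `PartialEq` only.
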
22 changes: 11 additions & 11 deletions src/svm/svc.rs
@@ -85,7 +85,7 @@ use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use crate::svm::{Kernel, Kernels, LinearKernel};
use crate::svm::{Kernel, Kernels};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
@@ -129,10 +129,10 @@ pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowV
pub kernel: Vec<K>,
#[cfg_attr(feature = "serde", serde(default))]
/// Unused parameter.
m: PhantomData<M>,
pub m: PhantomData<M>,
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the pseudo random number generation for shuffling the data for probability estimates
seed: Vec<Option<u64>>,
pub seed: Vec<Option<u64>>,
}

/// SVC grid search iterator
@@ -219,9 +219,9 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
}
}

impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKernel> {
impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, Kernels<T>> {
fn default() -> Self {
let default_params: SVCParameters<T, M, LinearKernel> = SVCParameters::default();
let default_params: SVCParameters<T, M, Kernels<T>> = SVCParameters::default();

SVCSearchParameters {
epoch: vec![default_params.epoch],
@@ -319,7 +319,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
}
}

impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, Kernels<T>> {
fn default() -> Self {
SVCParameters {
epoch: 2,
@@ -879,19 +879,19 @@ mod tests {

#[test]
fn search_parameters() {
let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, Kernels<_>> =
SVCSearchParameters {
epoch: vec![10, 100],
kernel: vec![LinearKernel {}],
kernel: vec![Kernels::linear()],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.epoch, 10);
assert_eq!(next.kernel, LinearKernel {});
assert_eq!(next.kernel, Kernels::linear());
let next = iter.next().unwrap();
assert_eq!(next.epoch, 100);
assert_eq!(next.kernel, LinearKernel {});
assert_eq!(next.kernel, Kernels::linear());
assert!(iter.next().is_none());
}

@@ -1065,7 +1065,7 @@ mod tests {

let svc = SVC::fit(&x, &y, Default::default()).unwrap();

let deserialized_svc: SVC<f64, DenseMatrix<f64>, LinearKernel> =
let deserialized_svc: SVC<f64, DenseMatrix<f64>, Kernels<_>> =
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();

assert_eq!(svc, deserialized_svc);
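
Also note that the `Default` impls for the SVC parameter types are now only provided for the `Kernels<T>` instantiation, so downstream code that spelled out `SVCParameters<T, M, LinearKernel>` and relied on `::default()` will need to move to `Kernels<T>`; the SVR changes below mirror this.
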
18 changes: 9 additions & 9 deletions src/svm/svr.rs
@@ -76,7 +76,7 @@ use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::svm::{Kernel, Kernels, LinearKernel};
use crate::svm::{Kernel, Kernels};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
@@ -183,9 +183,9 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
}
}

impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, Kernels<T>> {
fn default() -> Self {
let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
let default_params: SVRParameters<T, M, Kernels<T>> = SVRParameters::default();

SVRSearchParameters {
eps: vec![default_params.eps],
@@ -272,7 +272,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVRParameters<T, M
}
}

impl<T: RealNumber, M: Matrix<T>> Default for SVRParameters<T, M, LinearKernel> {
impl<T: RealNumber, M: Matrix<T>> Default for SVRParameters<T, M, Kernels<T>> {
fn default() -> Self {
SVRParameters {
eps: T::from_f64(0.1).unwrap(),
@@ -641,19 +641,19 @@ mod tests {

#[test]
fn search_parameters() {
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, Kernels<_>> =
SVRSearchParameters {
eps: vec![0., 1.],
kernel: vec![LinearKernel {}],
kernel: vec![Kernels::linear()],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.eps, 0.);
assert_eq!(next.kernel, LinearKernel {});
assert_eq!(next.kernel, Kernels::linear());
let next = iter.next().unwrap();
assert_eq!(next.eps, 1.);
assert_eq!(next.kernel, LinearKernel {});
assert_eq!(next.kernel, Kernels::linear());
assert!(iter.next().is_none());
}

@@ -721,7 +721,7 @@ mod tests {

let svr = SVR::fit(&x, &y, Default::default()).unwrap();

let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
let deserialized_svr: SVR<f64, DenseMatrix<f64>, Kernels<_>> =
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();

assert_eq!(svr, deserialized_svr);