CrossValidation
All implemented methods share a common CrossValidation interface. This interface defines four main methods:
public interface CrossValidation {

    /**
     * perform learning and evaluations
     */
    public void run();

    /**
     * Tells the average score of the test
     * @return the average score
     */
    public double getAverageScore();

    /**
     * Tells the standard deviation of the test
     * @return the standard deviation
     */
    public double getStdDevScore();

    /**
     * Tells the scores of the tests, in order of evaluation
     * @return an array with the scores in order
     */
    public double[] getScores();
}
Currently, three cross-validation techniques are available: RandomSplitCrossValidation, LeaveOneOutCrossValidation and NFoldCrossValidation. RandomSplitCrossValidation performs several evaluations of the classifier, each using a random split of the provided sample set. LeaveOneOutCrossValidation implements the well-known leave-one-out protocol. NFoldCrossValidation splits the data into n subsets, using n-1 of them for training and the remaining one for testing. CrossValidation is agnostic regarding the data type (as is the whole library) and also regarding the metric used (please refer to the Evaluator classes for this point).
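To make the n-fold protocol described above concrete, here is a minimal, library-independent sketch of how sample indices could be partitioned. The class NFoldSplitSketch and its methods foldIndices and trainingIndices are hypothetical names used for illustration only. The sketch assumes contiguous, unshuffled folds and that each fold is held out in turn, as in the standard n-fold protocol; the actual NFoldCrossValidation implementation may differ.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch only: not part of the library. */
public class NFoldSplitSketch {

    /** Partitions indices 0..nSamples-1 into nFolds contiguous folds of near-equal size. */
    public static List<List<Integer>> foldIndices(int nSamples, int nFolds) {
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < nFolds; f++) {
            // Integer division spreads the remainder across the folds.
            int start = (int) ((long) f * nSamples / nFolds);
            int end = (int) ((long) (f + 1) * nSamples / nFolds);
            List<Integer> fold = new ArrayList<>();
            for (int i = start; i < end; i++) {
                fold.add(i);
            }
            folds.add(fold);
        }
        return folds;
    }

    /** Returns every index that does not belong to the held-out test fold. */
    public static List<Integer> trainingIndices(List<List<Integer>> folds, int testFold) {
        List<Integer> train = new ArrayList<>();
        for (int f = 0; f < folds.size(); f++) {
            if (f != testFold) {
                train.addAll(folds.get(f));
            }
        }
        return train;
    }

    public static void main(String[] args) {
        // Assumption: 10 samples and 3 folds, each fold used once for testing.
        List<List<Integer>> folds = foldIndices(10, 3);
        for (int f = 0; f < folds.size(); f++) {
            System.out.println("test=" + folds.get(f)
                    + " train=" + trainingIndices(folds, f));
        }
    }
}
```

With n equal to the number of samples, every fold holds exactly one sample, which corresponds to the leave-one-out protocol.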
Suppose you have a Classifier<double[]> c and a List<TrainingSample<double[]>> l. You can then initialize a CrossValidation as follows:
Evaluator<double[]> eval = new AccuracyEvaluator<double[]>();
RandomSplitCrossValidation<double[]> cv = new RandomSplitCrossValidation<double[]>(c, l, eval);
cv.setTrainPercent(0.80);
cv.setNbTest(10);
In this case, we use the accuracy metric computed by the AccuracyEvaluator. We take 80% of the samples for training and the remaining 20% for evaluation. The test is repeated 10 times, with a new random split each time.
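The effect of the 80% training fraction and the repeated random splits can be pictured with the following standalone sketch. RandomSplitSketch, trainSize and split are hypothetical names, not library classes. Shuffling the indices and rounding the training size to the nearest integer are assumptions; the source does not state how RandomSplitCrossValidation draws or rounds its splits.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Illustrative sketch only: not part of the library. */
public class RandomSplitSketch {

    /** Number of training samples for a given fraction (assumed rounding to nearest). */
    public static int trainSize(int nSamples, double trainPercent) {
        return (int) Math.round(nSamples * trainPercent);
    }

    /** Returns two lists: the shuffled training indices, then the remaining test indices. */
    public static List<List<Integer>> split(int nSamples, double trainPercent, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < nSamples; i++) {
            indices.add(i);
        }
        // A seeded generator keeps the sketch reproducible.
        Collections.shuffle(indices, new Random(seed));
        int cut = trainSize(nSamples, trainPercent);
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(indices.subList(0, cut)));
        parts.add(new ArrayList<>(indices.subList(cut, nSamples)));
        return parts;
    }

    public static void main(String[] args) {
        // Ten repetitions with a different seed each time, mirroring setNbTest(10).
        for (int t = 0; t < 10; t++) {
            List<List<Integer>> parts = split(50, 0.80, t);
            System.out.println("run " + t + ": train=" + parts.get(0).size()
                    + " test=" + parts.get(1).size());
        }
    }
}
```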
To run the test, call the run() method:
cv.run();
Results can be obtained by calling the getAverageScore() and getStdDevScore() methods:
debug.println(1, "Accuracy: " + cv.getAverageScore() + " +/- "
        + cv.getStdDevScore());
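For readers who want to see how such summary values relate to the per-test scores returned by getScores(), the sketch below computes a mean and a standard deviation from an array of scores. ScoreStatsSketch is a hypothetical class, and the sample scores are made up for illustration. The population form of the standard deviation (dividing by n) is an assumption; the source does not say which form the library uses.

```java
/** Illustrative sketch only: not part of the library. */
public class ScoreStatsSketch {

    /** Arithmetic mean of the scores. */
    public static double mean(double[] scores) {
        double sum = 0.0;
        for (double s : scores) {
            sum += s;
        }
        return sum / scores.length;
    }

    /** Population standard deviation (divides by n); the library may use another form. */
    public static double stdDev(double[] scores) {
        double m = mean(scores);
        double squaredDiffs = 0.0;
        for (double s : scores) {
            squaredDiffs += (s - m) * (s - m);
        }
        return Math.sqrt(squaredDiffs / scores.length);
    }

    public static void main(String[] args) {
        // Made-up scores standing in for the result of cv.getScores().
        double[] scores = {0.80, 0.90, 0.85, 0.95};
        System.out.println("Accuracy: " + mean(scores) + " +/- " + stdDev(scores));
    }
}
```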