Skip to content

Generating a data set with Guassians

EdwardRaff edited this page Oct 2, 2016 · 2 revisions

#Introduction

In the previous example we look at data transforms, which is one tool to improving accuracy when we have a more difficult classification problem. For the braver souls, it may be advantages to create a custom classifier. This could be especially true if there is some expert knowledge about the data you could incorporate into the decision process.

In this example, we look at generating a custom data set. While testing a new classifier on existing data sets is a good idea, sometimes you want to create a data set that has certain known properties, and problems, that you would like to test your classifier against. This example creates a 4 class problem with simple Gaussian distributions. This is for simplicity, by combining Gaussians, you can create many arbitrarily complex data sets.

The GUI code has been removed from JSAT. A small separate project has new GUI code related to JSAT that we will use for this example. To use it, add this repo and dependency to your maven pom file.

#Code

import com.edwardraff.jsatfx.Plot;
import java.util.Random;
import javafx.application.Platform;
import javafx.embed.swing.JFXPanel;
import javafx.scene.Scene;
import javafx.scene.chart.ScatterChart;
import javax.swing.JFrame;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/**
 * It can often be useful to generate synthetic data sets to test out classifiers on, and get a feel for how they work. 
 * 
 * @author Edward Raff
 */
public class GeneratingDataExample
{
    public static void main(String[] args)
    {
        //We create a new data set. This data set will have 2 dimensions so we can visualize it, and 4 target class values
        ClassificationDataSet dataSet = new ClassificationDataSet(2, new CategoricalData[0], new CategoricalData(4));

        //We can generate data from a multivarete normal distribution. The 'M' at the end stands for Multivariate 
        NormalM normal;

        //The normal is specifed by a mean and covariance matrix. The covariance matrix must be symmetric. 
        //We use a simple covariance matrix for each data point for simplicity
        Matrix covariance = new DenseMatrix(new double[][]
        {
            {1.0, 0.0}, //Try altering these values to see the change!
            {0.0, 1.0} //Just make sure its still symetric & a valid covariance matrix! 
        });

        //And we create 3 different means
        Vec mean0 = DenseVector.toDenseVec(0.0, 0.0);
        Vec mean1 = DenseVector.toDenseVec(0.0, 4.0);
        Vec mean2 = DenseVector.toDenseVec(4.0, 0.0);
        Vec mean3 = DenseVector.toDenseVec(4.0, 4.0);

        Vec[] means = new Vec[] {mean0, mean1, mean2, mean3};

        //We now generate out data
        for(int i = 0; i < means.length; i++)
        {
            normal = new NormalM(means[i], covariance);
            for(Vec sample : normal.sample(300, new Random()))
                dataSet.addDataPoint(sample, new int[0], i);
        }
        
        JFrame jFrame = new JFrame("2D Visualization");
        JFXPanel panel = new JFXPanel();
        jFrame.add(panel);
        jFrame.setSize(400, 400);
        jFrame.setVisible(true);
        jFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        Platform.runLater(() -> 
        {
            //this will create the plot that does the visualization
            ScatterChart<Number, Number> plot = Plot.scatterC(dataSet);
            Scene scene = new Scene(plot);
            panel.setScene(scene);
        });
    }
}