
K Means and Finding "K"


Introduction

K-Means is one of the most commonly used clustering algorithms, but it has a well-known weakness: you have to tell it up front how many clusters "K" you want. This example shows a number of different ways one can search for the "K" in K-Means.

We should be clear that evaluating a clustering algorithm is not easy, and there is no agreed-upon "best" way to do it. This example uses the class labels as a kind of ground truth. We will look at the different values of K found for different data sets, and get an idea of how these tools can be helpful, even if not definitive.
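
Before looking at the full program, it helps to see the basic shape of the manual approach: pick a candidate k, run k-means, and score the result. Below is a minimal sketch of that loop, assuming a ClassificationDataSet named data has already been loaded and preprocessed as in the full example; it uses only the JSAT classes the program itself uses.

KMeans km = new HamerlyKMeans();
ClusterEvaluation eval = new NormalizedMutualInformation();
for (int k = 2; k <= 8; k++) //candidate values of k to try
{
    //cluster with exactly k clusters, storing the per-point assignments
    int[] assignments = km.cluster(data, k, new int[data.getSampleSize()]);
    //lower is better, per JSAT's ClusterEvaluation convention
    System.out.printf("k = %d -> %.3f%n", k, eval.evaluate(assignments, data));
}

The automatic methods used below wrap up this kind of search (or smarter statistics) so you don't have to pick the candidate range yourself.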

Code

import java.io.File;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.IntStream;
import jsat.ARFFLoader;
import jsat.classifiers.ClassificationDataSet;
import jsat.clustering.Clusterer;
import jsat.clustering.GapStatistic;
import jsat.clustering.evaluation.ClusterEvaluation;
import jsat.clustering.evaluation.NormalizedMutualInformation;
import jsat.clustering.kmeans.GMeans;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.clustering.kmeans.KMeans;
import jsat.clustering.kmeans.KMeansPDN;
import jsat.clustering.kmeans.XMeans;
import jsat.datatransform.Imputer;
import jsat.datatransform.LinearTransform;

/**
 * @author Edward Raff
 */
public class KMeansAndK
{
    public static void main(String[] args)
    {
        //the data sets we will use; all have only numeric features and a class label
        String[] dataSetName = new String[]
        {
            "breast-w", "heart-statlog", "ionosphere", "iris", "sonar", 
        };
        ClassificationDataSet[] dataSets = new ClassificationDataSet[dataSetName.length];
        
        for(int i = 0; i < dataSetName.length; i++)
        {
            ClassLoader classloader = Thread.currentThread().getContextClassLoader();
            File file = new File(classloader.getResource(dataSetName[i] + ".arff").getFile());
            //We know that there is only one categorical feature for each of these data sets, and it is the class label. So we use '0' as the argument 
            dataSets[i] = ARFFLoader.loadArffFile(file).asClassificationDataSet(0);
            dataSets[i].applyTransform(new Imputer(dataSets[i]));//impute missing values in the dataset
            dataSets[i].applyTransform(new LinearTransform(dataSets[i]));//scale feature values to [0, 1]
        }
        
        /**
         * This map contains 4 different methods that try to infer the "best"
         * value of k for k-means. These don't always work well, but it's
         * something lots of people are interested in.
         */
        Map<String, Clusterer> methodsToEval = new LinkedHashMap<String, Clusterer>()
        {{
            put("PDN KMeans", new KMeansPDN());
            put("Gap-Means", new GapStatistic());
            put("X-Means", new XMeans());
            put("G-Means", new GMeans());
        }};
        
        /**
         * We will compare these with 3 different values of k that we will
         * explicitly cluster for. Feel free to add/remove values from the list
         */
        int[] kToTest = new int[]{2, 3, 6};
        
        /**
         * We will use the NMI as our evaluation criteria. It compares the
         * clustering results with the class labels. The class labels aren't
         * necessarily the best ground truth for clusters. In fact, how to
         * properly evaluate clustering algorithms is a very open question! But
         * this is a commonly used method.
         *
         * The ClusterEvaluation interface dictates that values near 0 are
         * better, and larger values are worse. NMI is usually the opposite, but
         * obeys the interface. Read the NMI's Javadoc for more details.
         */
        ClusterEvaluation evaluator = new NormalizedMutualInformation();
        
        /**
         * And finally, we will use a normal k-means algorithm to do clustering
         * when we specify the number of clusters we want. JSAT implements a
         * number of different algorithms that all solve the k-means problem,
         * and are better in different scenarios. This one is likely to be the
         * best for most users.
         */
        KMeans simpleKMeans = new HamerlyKMeans();
        
        
        /*
         * Let's print out a simple header. The first two values will be our
         * data set name and the number of classes in that data set. Then for
         * each model in methodsToEval, we will print out the value of k it
         * determined and the evaluation of that clustering. Finally we will
         * print out the evaluation for running k-means with some select
         * values of k.
         */
        System.out.printf("%-20.20s", "Data Set: classes");
        for( String name : methodsToEval.keySet())
            System.out.printf("%-15.15s", "| k , " + name + "");
        for( int k : kToTest)
            System.out.printf("%-15.15s", "| " + k + "-Means");
        System.out.println();
        
        //now we will loop through every data set, and evaluate all of our clustering algorithms
        for(int i = 0; i < dataSets.length; i++)
        {
            ClassificationDataSet data = dataSets[i];
            System.out.printf("%-15s: %2d | ", dataSetName[i], data.getClassSize());
            //print out the number of clusters chosen and the evaluation for each automatic version of k-means
            int[] clusteringResults = new int[data.getSampleSize()];//holds the cluster assignments; we use the version of cluster() that fills this array so we can pass it to NMI
            for( Clusterer clusterer : methodsToEval.values())
            {
                //when we call this method, the algorithm is expected to figure out the number of clusters on its own
                clusterer.cluster(data, clusteringResults);
                //the number of clusters found can be determined from the maximum cluster ID returned, +1 since 0 is a cluster ID
                int kFound = IntStream.of(clusteringResults).max().getAsInt() + 1;
                System.out.printf("%4d , %.3f | ", kFound, evaluator.evaluate(clusteringResults, data));
            }
            //now lets print out our results when we specify the value of k we want to try
            for(int k : kToTest)
            {
                //run k-means with a specific value of k, and keep track of cluster assignments
                clusteringResults = simpleKMeans.cluster(data, k, clusteringResults);
                //now evaluate the cluster assignments and print a score
                System.out.printf("    %.3f    | ",evaluator.evaluate(clusteringResults, data));
            }
            System.out.println();
        }
    }
}

Results

Data Set      : classes | k , PDN KMeans | k , Gap-Means | k , X-Means | k , G-Means | 2-Means | 3-Means | 6-Means
breast-w      :    2    |    3 , 0.330   |    6 , 0.535  |  34 , 0.707 |  34 , 0.706 |  0.270  |  0.335  |  0.429
heart-statlog :    2    |   10 , 0.846   |    8 , 0.876  |  13 , 0.852 |  13 , 0.833 |  0.721  |  0.744  |  0.851
ionosphere    :    2    |    9 , 0.734   |    2 , 0.870  |  16 , 0.750 |  17 , 0.758 |  0.870  |  0.821  |  0.759
iris          :    3    |    3 , 0.287   |    3 , 0.286  |   7 , 0.404 |   2 , 0.266 |  0.266  |  0.258  |  0.358
sonar         :    2    |    4 , 0.977   |   10 , 0.901  |  10 , 0.904 |   1 , 1.000 |  0.989  |  0.996  |  0.937

You should see something like the above output when you run this program. Immediately we can see some common trends. G-Means and X-Means tend to guess a larger number of clusters. While G-Means and X-Means are somewhat related, the fact that they returned almost the same number of clusters is more happenstance than anything - you shouldn't expect that in general. PDN and Gap were closer to the correct number of clusters.

For the scores in the table, closer to zero is better (perfect agreement with the class labels) and 1.0 is worst (the clustering looks random relative to the labels). For breast-w and heart-statlog, we can see that having fewer clusters tended to give a better score, though the agreement between clusters and labels was quite low for heart-statlog.
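
For intuition, NMI is normally defined so that 1 means perfect agreement: the mutual information between the two labelings, normalized by their entropies. JSAT flips this so that 0 is best, to obey its ClusterEvaluation interface. Below is a minimal, self-contained sketch of the common square-root-normalized definition, for reference only; JSAT's exact normalization may differ, so consult the NMI Javadoc.

//illustrative NMI between two labelings (cluster IDs and class labels);
//1 = identical partitions, near 0 = independent. Not JSAT's code.
public static double nmi(int[] a, int[] b)
{
    int n = a.length;
    int ka = java.util.stream.IntStream.of(a).max().getAsInt() + 1;
    int kb = java.util.stream.IntStream.of(b).max().getAsInt() + 1;
    double[][] joint = new double[ka][kb];//joint distribution of (a, b)
    for (int i = 0; i < n; i++)
        joint[a[i]][b[i]] += 1.0 / n;
    double[] pa = new double[ka], pb = new double[kb];//marginal distributions
    for (int i = 0; i < ka; i++)
        for (int j = 0; j < kb; j++)
        {
            pa[i] += joint[i][j];
            pb[j] += joint[i][j];
        }
    double mi = 0, ha = 0, hb = 0;
    for (int i = 0; i < ka; i++)//mutual information I(A;B)
        for (int j = 0; j < kb; j++)
            if (joint[i][j] > 0)
                mi += joint[i][j] * Math.log(joint[i][j] / (pa[i] * pb[j]));
    for (double p : pa) if (p > 0) ha -= p * Math.log(p);//entropy H(A)
    for (double p : pb) if (p > 0) hb -= p * Math.log(p);//entropy H(B)
    return mi / Math.sqrt(ha * hb);//undefined if either labeling is a single cluster
}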

On ionosphere we actually see the opposite, where G-Means and X-Means got better NMI scores by picking considerably more clusters than there are classes! This can easily happen, especially with a simple algorithm like k-means, so you shouldn't necessarily consider it "wrong". Maybe there is something more interesting in the data than just the classes previously chosen?

On sonar we can see that every method had a lot of trouble (G-Means gave up and said it was all one cluster!). We would definitely need to do some more feature engineering (or pick a better clustering algorithm) to make progress on that data.

Conclusions

You've now seen how we can do clustering in JSAT, and even try to find the "best K" for K-Means clustering. You can use this code with other clustering algorithms, and other data sets, to see what happens. Learning how to interpret these kinds of results is part of the challenge and skill involved in applying these methods in practice.
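
For example, any other Clusterer can be dropped into the methodsToEval map, and the scoring criterion can be swapped just as easily. The snippet below is a hedged sketch of both changes, meant to slot into main(...) above; the no-argument constructors for MeanShift and DaviesBouldinIndex are assumptions about JSAT's API, so check their Javadocs before relying on them.

//Mean Shift infers the number of clusters from the density of the data,
//so it can be compared against the k-inferring methods (constructor assumed)
methodsToEval.put("Mean Shift", new jsat.clustering.MeanShift());
//...or replace the NormalizedMutualInformation evaluator with an unsupervised
//criterion that scores cluster geometry alone, with no class labels,
//sidestepping the labels-as-ground-truth caveat (constructor assumed)
ClusterEvaluation evaluator = new jsat.clustering.evaluation.DaviesBouldinIndex();

Unsupervised criteria will often disagree with NMI, which is itself a useful reminder that the "best" k depends on what you ask for.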