Mean Shift with SkLearn

Mean shift clustering uses a sliding window to find dense areas among the data points. It is a centroid-based algorithm: the goal is to locate the centre point of each group/class. It works by repeatedly updating each candidate centre to be the mean of the points within its sliding window. These candidate windows are then filtered in a post-processing stage to eliminate near-duplicates, forming the final set of centre points and their corresponding groups. The underlying assumption is that points are densest at the centre of each cluster.
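To make the update rule concrete, here is a minimal NumPy sketch of a single window being shifted. It assumes a flat (uniform) kernel, where every point inside the radius counts equally. The function name shift_window and the two synthetic groups are hypothetical illustrations, not part of scikit-learn.

```python
import numpy as np

def shift_window(X, start, radius, max_iter=300, tol=1e-6):
    """Hypothetical helper: move one window centre to the mean of the
    points inside it, repeating until the centre stops moving.
    Assumes a flat kernel (all points within `radius` weigh the same)."""
    centre = np.asarray(start, dtype=float)
    for _ in range(max_iter):
        # Points that currently lie inside the window
        inside = X[np.linalg.norm(X - centre, axis=1) <= radius]
        if len(inside) == 0:
            break  # empty window: nothing to average, stop here
        new_centre = inside.mean(axis=0)
        if np.linalg.norm(new_centre - centre) < tol:
            return new_centre  # converged: the window no longer moves
        centre = new_centre
    return centre

rng = np.random.default_rng(0)
# Two synthetic dense groups around (0, 0) and (5, 5)
X = np.vstack([rng.normal(0, 0.5, (100, 2)), rng.normal(5, 0.5, (100, 2))])

# A window started at (1, 1) should climb to the dense centre near (0, 0)
print(shift_window(X, start=[1.0, 1.0], radius=1.5))
```

Each step moves the window towards where more points lie, so it stops at a local density peak rather than at the starting point.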

The mean-shift algorithm starts by selecting a random point in the data space. It then looks at the data within a fixed radius around that point to identify the direction in which the density increases most, and shifts the point in that direction. This continues until it reaches a point where the density decreases on every side, i.e. a local maximum of density. That point is the centre of the group. The process is repeated with many sliding windows until every point lies within some window. When multiple windows overlap, they are merged and the one containing the most points is preserved. The data points are then clustered according to the sliding window in which they reside.
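The whole procedure can be sketched end to end in NumPy. This is a simplified illustration under stated assumptions: a flat kernel, one window seeded at every data point (scikit-learn's MeanShift also seeds at the data points by default, rather than at random points), windows whose centres end up closer than the radius treated as overlapping, and each point labelled by its nearest surviving centre. The function mean_shift below is hypothetical, not scikit-learn's implementation.

```python
import numpy as np

def mean_shift(X, radius, max_iter=300, tol=1e-6):
    """Hypothetical flat-kernel mean shift sketch.

    1. Start one window at every data point.
    2. Shift each window to the mean of the points inside it until it converges.
    3. Merge windows whose centres are closer than `radius`,
       keeping the one that contains more points.
    4. Label every point with its nearest surviving centre.
    """
    def members(c):
        return np.linalg.norm(X - c, axis=1) <= radius

    converged = []
    for seed in X:
        centre = seed.astype(float)
        for _ in range(max_iter):
            new_centre = X[members(centre)].mean(axis=0)
            done = np.linalg.norm(new_centre - centre) < tol
            centre = new_centre
            if done:
                break
        converged.append((int(members(centre).sum()), centre))

    # Most populated windows first, so they survive when windows overlap
    converged.sort(key=lambda item: -item[0])
    centres = []
    for _, c in converged:
        if all(np.linalg.norm(c - kept) >= radius for kept in centres):
            centres.append(c)
    centres = np.array(centres)

    # Assign each point to its nearest centre
    dists = np.linalg.norm(X[:, None, :] - centres[None, :, :], axis=2)
    return centres, dists.argmin(axis=1)

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.5, (100, 2)), rng.normal(5, 0.5, (100, 2))])
centres, labels = mean_shift(X, radius=1.5)
print(len(centres), "clusters found")
```

Note that the number of clusters is never passed in; it falls out of how many distinct density peaks the windows converge to.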

This is different from the K-means algorithm because we do not select the number of clusters; the mean shift algorithm identifies it for us. That adds a lot of value. Of course, it still leaves room for configuration (and hence room for doubt), because we have to select the radius of the sliding window, and that radius in turn determines how many clusters are found. In scikit-learn this radius is the bandwidth parameter of MeanShift; if it is not given, it is estimated from the data with estimate_bandwidth.
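One way to see the effect of the radius is to sweep the bandwidth around scikit-learn's own estimate and count the clusters each value yields. This sketch reuses the make_blobs settings from later in the article; the scaling factors are arbitrary choices for illustration, and no particular cluster counts are implied.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=500, centers=5, n_features=5, random_state=0)

# The estimate MeanShift falls back on when bandwidth is None (quantile=0.3 is the default)
auto_bw = estimate_bandwidth(X, quantile=0.3)
print(f"estimated bandwidth: {auto_bw:.2f}")

# Try a few radii around that estimate (factors chosen arbitrarily)
counts = {}
for factor in (0.5, 0.75, 1.0, 1.5):
    ms = MeanShift(bandwidth=factor * auto_bw).fit(X)
    counts[factor] = len(ms.cluster_centers_)
    print(f"bandwidth={factor * auto_bw:.2f} -> {counts[factor]} clusters")
```

Smaller windows can separate nearby dense regions, while larger ones tend to swallow them into a single cluster.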


Let us now see how this can be implemented in Python using scikit-learn. To generate the training data, scikit-learn provides the make_blobs function.

We can start by importing the required modules:

from sklearn.datasets import make_blobs  # sklearn.datasets.samples_generator was removed in newer releases
from sklearn.cluster import MeanShift

Now we can generate the required data. Note that since this is a clustering application, we do not really need the target output. But we generate it anyway, so that we can check whether our model produces good clusters.

X, y = make_blobs(n_samples=500, centers=5, n_features=5, random_state=0)
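A quick check worth doing here: when n_samples is an integer, make_blobs divides it evenly among the centres, so each of the 5 reference clusters holds exactly 100 points. Knowing the cluster sizes helps when interpreting the mismatch counts later.

```python
import numpy as np
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=500, centers=5, n_features=5, random_state=0)

# Number of points carrying each reference label
print(np.bincount(y))  # [100 100 100 100 100]
```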

We can have a look at the data we got:

>>> X.shape
(500, 5)
>>> X[:2]
array([[ 0.7927672 ,  4.69609538,  2.29690233,  1.00159957, -2.3564752 ],
       [ 5.60928139, -0.0376682 , -0.68774591,  9.60509046, -8.52337837]])
>>> y[:2]
array([0, 2])

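Now, create an instance of MeanShift and fit it to the data. fit_predict fits the model and returns the cluster label of every point:

ms = MeanShift()
y1 = ms.fit_predict(X)  # cluster labels found by MeanShift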
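After fitting, the model can be inspected directly. This sketch is self-contained (it regenerates the same data) and uses only documented MeanShift attributes and methods: cluster_centers_ holds one row per cluster found, labels_ matches what fit_predict returns, and predict assigns new points to the nearest cluster centre.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=500, centers=5, n_features=5, random_state=0)

ms = MeanShift()
y1 = ms.fit_predict(X)

print("clusters found:", len(ms.cluster_centers_))
print("points per cluster:", np.bincount(y1))
# New points are labelled by their nearest cluster centre
print("labels for the first three points:", ms.predict(X[:3]))
```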

Now we need to check how good this clustering is. The values in y are the clusters assigned by make_blobs. They may not match the cluster numbers generated by MeanShift exactly. For example, cluster 1 in the original data could be cluster 2 here. That is not important. What is important is that the grouping is similar. That is, if two elements were in the same cluster in the first set, they should be in the same cluster in the second set as well, and vice versa. To verify this, we can write a small script:

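# Compare every ordered pair of points (i, j)
mismatch1 = 0  # same cluster in y, but different clusters in y1
mismatch2 = 0  # different clusters in y, but the same cluster in y1
for i in range(y1.shape[0]):
    for j in range(y1.shape[0]):
        if y[i] == y[j] and y1[i] != y1[j]:
            mismatch1 = mismatch1 + 1
        if y1[i] == y1[j] and y[i] != y[j]:
            mismatch2 = mismatch2 + 1

print(mismatch1, mismatch2)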
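The double loop can also be written with NumPy broadcasting, which builds an n-by-n "same cluster?" matrix for each labelling. The helper name pair_mismatches is hypothetical. As a cross-check, scikit-learn's pair_confusion_matrix (available in scikit-learn 0.24 and later) counts the same ordered pairs, with C[1, 0] matching the first count and C[0, 1] the second. The toy labels also show that renaming clusters produces no mismatches.

```python
import numpy as np
from sklearn.metrics.cluster import pair_confusion_matrix

def pair_mismatches(y_true, y_pred):
    """Hypothetical helper: count ordered pairs (i, j) on which two labelings disagree.

    Returns (split, merged):
      split  - same cluster in y_true, different clusters in y_pred
      merged - different clusters in y_true, same cluster in y_pred
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    same_true = y_true[:, None] == y_true[None, :]  # n x n boolean matrix
    same_pred = y_pred[:, None] == y_pred[None, :]
    split = int(np.sum(same_true & ~same_pred))
    merged = int(np.sum(~same_true & same_pred))
    return split, merged

# Renaming clusters does not matter: only the grouping is compared
print(pair_mismatches([0, 0, 1, 1], [1, 1, 0, 0]))  # (0, 0)
# Merging both reference clusters shows up only in the second count
print(pair_mismatches([0, 0, 1, 1], [0, 0, 0, 0]))  # (0, 8)

# Cross-check against scikit-learn's pair confusion matrix
C = pair_confusion_matrix([0, 0, 1, 1], [0, 0, 0, 0])
print(C[1, 0], C[0, 1])  # 0 8
```

With the article's data, pair_mismatches(y, y1) should give the same two numbers as the loop, but in a single vectorized pass.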


The output is 0 and 120000! How did that happen? The script measures two types of problems. The first checks that two elements that were in the same cluster in the reference set are also marked as being in the same cluster by the algorithm. The second checks that two elements marked as being in the same cluster were actually in the same cluster. The first test passes without any problem, but the second one fails badly. That means the algorithm has merged some clusters together, so points that were not in the same cluster appear to be. This is a limitation of the algorithm: it identifies clusters based on the density of points, and if an entire cluster is not dense enough relative to the chosen window radius, it is pushed into another, denser cluster. Thus, it may miss some clusters. Still, since the first count is 0, no reference cluster was split, so the clusters it does report are consistent with the original groups.
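The number 120000 can be decoded further. Because the first count is 0, no blob was split, so each MeanShift cluster must be a union of whole blobs, and each blob has exactly 100 points. The sketch below enumerates every way of grouping the 5 blobs and computes the second count for each grouping: a merged group of m blobs contributes (100m)^2 - m * 100^2 ordered pairs. The helper partitions is hypothetical, written just for this check.

```python
def partitions(n, max_part=None):
    """Hypothetical helper: yield the integer partitions of n as non-increasing tuples."""
    if max_part is None:
        max_part = n
    if n == 0:
        yield ()
        return
    for first in range(min(n, max_part), 0, -1):
        for rest in partitions(n - first, first):
            yield (first,) + rest

BLOB_SIZE = 100  # make_blobs puts 500 // 5 = 100 points in each blob

merged_by_grouping = {}
for groups in partitions(5):
    # Ordered pairs from different blobs that end up in the same predicted cluster
    merged = sum((m * BLOB_SIZE) ** 2 - m * BLOB_SIZE ** 2 for m in groups)
    merged_by_grouping[groups] = merged
    print(groups, merged)
```

Only the grouping (4, 1) yields 120000, so under these assumptions MeanShift found two clusters: four of the blobs merged into one, and the fifth kept on its own.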