Introduction to Clustergram

Note

You can try this notebook in your browser: Binder

When we want to do cluster analysis to identify groups in our data, we often use algorithms like K-Means, which require us to specify the number of clusters in advance. The problem is that we usually don’t know how many clusters there are.

There are many methods for determining the correct number, such as silhouette scores or the elbow plot. But they usually don’t give much insight into what happens between the different options, so the numbers remain a bit abstract.

Matthias Schonlau proposed another approach: the clustergram. A clustergram is a two-dimensional plot capturing the flows of observations between classes as you add more clusters. It tells you how your data reshuffles and how good your splits are. Tal Galili later implemented the clustergram for K-Means in R. I took Tal’s implementation, ported it to Python and created clustergram, a Python package for making clustergrams.

clustergram currently supports K-Means, using either scikit-learn (including the Mini-Batch implementation) or RAPIDS.AI cuML (if you have a CUDA-enabled GPU), the Gaussian Mixture Model (scikit-learn only) and hierarchical clustering based on scipy.cluster.hierarchy. Alternatively, we can create a clustergram from labels and data produced by a custom clustering algorithm. It provides a scikit-learn-like API and plots the clustergram using matplotlib, which gives it a wide range of styling options to match your publication style.

Install

You can install clustergram from conda or pip:

conda install clustergram -c conda-forge

or

pip install clustergram

In any case, you still need to install your selected backend (scikit-learn and scipy or cuML).

from clustergram import Clustergram
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale
sns.set(style='whitegrid')

Let us look at some examples to understand what a clustergram looks like and what to do with it.

Iris flower data set

The first example we analyse with clustergram is the famous Iris flower data set. It contains measurements of sepal length and width and petal length and width for three species of Iris flowers. We can start with some exploration:

iris = sns.load_dataset("iris")
g = sns.pairplot(iris, hue="species")
g.fig.suptitle("Iris flowers", y=1.01)
[Figure: Iris flowers, pairplot of the four measurements coloured by species]

It seems that setosa is a relatively well-defined group, while the difference between versicolor and virginica is smaller as they partially overlap (or entirely in the case of sepal width).

Okay, so we know what the data look like. Now we can check what the clustergram looks like. Remember that we know there are three clusters, and ideally we should be able to recognise this from the clustergram. I say ideally because known labels do not guarantee that our data or our clustering method can distinguish between those classes.

Let’s start with K-Means clustering. To get a stable result, we can run a clustergram with 100 initialisations. Feel free to run even more.
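For intuition about what `n_init` does: K-Means is restarted from several random sets of initial centroids, and the run with the lowest inertia (the within-cluster sum of squared distances) is kept, which makes the result less dependent on a lucky or unlucky start. The sketch below illustrates that idea in plain NumPy with a hypothetical helper, `kmeans_best_of`. It is a simplified Lloyd's algorithm with random initial centroids, not scikit-learn's implementation (which defaults to k-means++ initialisation).

```python
import numpy as np


def _assign(data, centers):
    """Return the nearest-centroid label and squared distance for every observation."""
    d2 = ((data[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1), d2.min(axis=1)


def kmeans_best_of(data, k, n_init=10, max_iter=100, seed=0):
    """Hypothetical helper: run Lloyd's K-Means n_init times, keep the lowest-inertia run."""
    rng = np.random.default_rng(seed)
    best = (None, None, np.inf)  # (labels, centers, inertia)
    for _ in range(n_init):
        # Random start: k distinct observations become the initial centroids.
        centers = data[rng.choice(len(data), size=k, replace=False)]
        for _ in range(max_iter):
            labels = _assign(data, centers)[0]
            # Move each centroid to the mean of its members (keep it if the cluster emptied).
            new_centers = np.array([
                data[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                for j in range(k)
            ])
            if np.allclose(new_centers, centers):
                break
            centers = new_centers
        # Final assignment and inertia for this restart.
        labels, d2 = _assign(data, centers)
        inertia = d2.sum()
        if inertia < best[2]:
            best = (labels, centers, inertia)
    return best


# Usage: two well-separated blobs; the best of 20 restarts should recover them.
rng = np.random.default_rng(42)
blobs = np.vstack([rng.normal(0, 0.3, (50, 2)), rng.normal(5, 0.3, (50, 2))])
labels, centers, inertia = kmeans_best_of(blobs, k=2, n_init=20)
```

More restarts cost proportionally more time but lower the chance of keeping a poor local optimum, which is why a stable clustergram benefits from a large `n_init`.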

data = scale(iris.drop(columns=['species']))
cgram = Clustergram(range(1, 10), n_init=100, verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.yaxis.grid(False)
sns.despine(offset=10)
ax.set_title('K-Means (scikit-learn)')
[Figure: K-Means (scikit-learn) clustergram of the Iris data, k = 1 to 9]

On the x axis, we can see the number of clusters. Each point represents the centre of a cluster, by default weighted by the first principal component (which helps with the diagram’s readability). The lines connecting the points represent observations moving between clusters, and their thickness reflects how many observations move. Therefore, we can read when a new cluster is formed by splitting a single existing cluster and when it is formed from observations of two clusters.
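The vertical position of each point can be sketched as follows. This assumes, as a simplification and not as clustergram's internal code, that "weighted by the first principal component" means projecting each cluster's mean onto the loadings of the first principal component of the data. The helper `pc1_positions` is hypothetical. Because the sign of a principal component is arbitrary, the whole plot may come out flipped vertically, which does not change how it reads.

```python
import numpy as np


def pc1_positions(data, labels):
    """Hypothetical helper: y-position of each cluster as its mean projected onto PC1 loadings."""
    data = np.asarray(data, dtype=float)
    labels = np.asarray(labels)
    # The first right-singular vector of the column-centred data holds the PC1 loadings.
    _, _, vt = np.linalg.svd(data - data.mean(axis=0), full_matrices=False)
    loadings = vt[0]
    return {
        int(label): float(data[labels == label].mean(axis=0) @ loadings)
        for label in np.unique(labels)
    }


# Usage: three groups spread along the diagonal, labelled with k = 3.
data = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.2, 4.9], [9.8, 10.0], [10.0, 10.1]])
positions = pc1_positions(data, [0, 0, 1, 1, 2, 2])
# The three positions are ordered along PC1 (increasing or decreasing, depending on its sign).
```

Projecting onto PC1 places clusters that differ along the main direction of variation far apart on the y axis, which is what makes the branches easy to tell apart.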

We’re looking for the separation, i.e., did an additional cluster bring any meaningful split? The step from one cluster to two is a big one - nice and clear separation. From two to three, we also have quite a nice split in the top branch. But from 3 to 4, there is no visible difference because the new fourth cluster is almost the same as the existing bottom branch. Although it is now separated into two, this split does not give us much information. Therefore, we could conclude that the ideal number of clusters for Iris data is three.
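The flows behind the lines, and whether a new cluster is a clean split of one existing cluster or draws on two of them, can also be tabulated. The sketch below uses a hypothetical helper, `split_report`, which is not part of clustergram. It counts observations moving from each cluster at k to each cluster at k + 1 (the quantity that sets line thickness) and, for every cluster at k + 1, reports its main parent and the share of its members coming from that parent. A share of 1.0 marks a pure split, but whether that split is informative still depends on how far apart the new branches sit in the plot.

```python
from collections import Counter, defaultdict


def split_report(labels_k, labels_k1):
    """Hypothetical helper: main parent (at k) and its share for each cluster at k + 1."""
    # Observations moving from cluster a (at k) to cluster b (at k + 1).
    flows = Counter(zip(labels_k, labels_k1))
    incoming = defaultdict(dict)
    for (parent, child), count in flows.items():
        incoming[child][parent] = count
    report = {}
    for child, parents in incoming.items():
        main_parent = max(parents, key=parents.get)
        report[child] = (main_parent, parents[main_parent] / sum(parents.values()))
    return report


# Usage: toy labels for k = 3 and k = 4 (not the Iris result).
labels_3 = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
labels_4 = [0, 0, 0, 1, 1, 3, 2, 2, 3, 3]
report = split_report(labels_3, labels_4)
# Clusters 0, 1 and 2 keep a single parent (share 1.0);
# cluster 3 takes one observation from cluster 1 and two from cluster 2.
```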

We can also check additional measures, such as the silhouette score or the Calinski-Harabasz score.
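The Calinski-Harabasz score is the ratio of between-cluster dispersion to within-cluster dispersion, each divided by its degrees of freedom: (tr(B) / (k - 1)) / (tr(W) / (n - k)) for n observations in k clusters. Higher values indicate denser, better-separated clusters. Below is a minimal NumPy version of that formula, written as a hypothetical `calinski_harabasz` function. scikit-learn ships the same measure as `sklearn.metrics.calinski_harabasz_score` and silhouettes as `sklearn.metrics.silhouette_score`.

```python
import numpy as np


def calinski_harabasz(data, labels):
    """Between-cluster over within-cluster dispersion, each divided by its degrees of freedom.

    Assumes the within-cluster dispersion is non-zero.
    """
    data = np.asarray(data, dtype=float)
    labels = np.asarray(labels)
    n, k = len(data), len(np.unique(labels))
    if not 1 < k < n:
        raise ValueError("need at least 2 clusters and fewer clusters than observations")
    overall_mean = data.mean(axis=0)
    between = within = 0.0
    for label in np.unique(labels):
        members = data[labels == label]
        centroid = members.mean(axis=0)
        # tr(B): cluster size times squared distance of its centroid to the overall mean.
        between += len(members) * ((centroid - overall_mean) ** 2).sum()
        # tr(W): squared distances of the members to their own centroid.
        within += ((members - centroid) ** 2).sum()
    return (between / (k - 1)) / (within / (n - k))


score = calinski_harabasz([[0.0], [2.0], [10.0], [12.0]], [0, 0, 1, 1])
# tr(B) = 2 * 5**2 + 2 * 5**2 = 100 and tr(W) = 4, so the score is (100 / 1) / (4 / 2) = 50.0
```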

fig, axs = plt.subplots(2, figsize=(10, 10), sharex=True)
cgram.silhouette_score().plot(xlabel="Number of clusters (k)", ylabel="Silhouette score", ax=axs[0])
cgram.calinski_harabasz_score().plot(xlabel="Number of clusters (k)", ylabel="Calinski-Harabasz score", ax=axs[1])
sns.despine(offset=10)
[Figure: silhouette score (top) and Calinski-Harabasz score (bottom) against the number of clusters (k)]

These plots would suggest 3-4 clusters, similarly to clustergram, but they are not very conclusive.

Palmer penguins data set

Now let’s try different data, where the clusters are more difficult to assess. The Palmer penguins data set is similar to the Iris example, but it contains measurements of several attributes of three penguin species.

penguins = sns.load_dataset("penguins")
g = sns.pairplot(penguins, hue="species")
g.fig.suptitle("Palmer penguins", y=1.01)
[Figure: Palmer penguins, pairplot coloured by species]

Looking at the plots, we see that the overlap between species is much greater than before, so it will likely be much harder to identify them. Again, we know that there are three clusters, but that does not mean the data have the power to distinguish between them. In this case, it may be especially tricky to differentiate between Adelie and Chinstrap penguins.

data = scale(penguins.drop(columns=['species', 'island', 'sex']).dropna())
cgram = Clustergram(range(1, 10), n_init=100, verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.yaxis.grid(False)
sns.despine(offset=10)
ax.set_title('K-Means (scikit-learn)')
[Figure: K-Means (scikit-learn) clustergram of the Palmer penguins data, k = 1 to 9]

We’re looking for separations, and this clustergram shows plenty. It is actually quite complicated to determine the optimal number of clusters. However, since we know what happens between the different options, we can work with that. If we have a reason to be conservative, we can go with 4 clusters (I know, that is already more than the number of species). But further splits are also reasonable, which indicates that even finer granularity may reveal further meaningful groups.

Can we say it is three? We know it should be three… Well, not really. The difference between the split from 2 to 3 and the split from 3 to 4 is slight. However, the culprit here is K-Means, not clustergram. It simply cannot cluster these data correctly due to the overlaps and the overall structure.

Let’s have a look at how the Gaussian Mixture Model does.

cgram = Clustergram(range(1, 10), n_init=100, method="gmm", verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.yaxis.grid(False)
sns.despine(offset=10)
ax.set_title('Gaussian Mixture Model (scikit-learn)')
[Figure: Gaussian Mixture Model (scikit-learn) clustergram of the Palmer penguins data, k = 1 to 9]

The result is very similar, though the difference between the third and fourth splits is more pronounced. Even here, I would probably go with a four-cluster solution.

A situation like this happens very often. The ideal case does not exist, and we ultimately need to make a decision on the optimal number of clusters. Clustergram gives us additional insight into what happens between the different options and how the data split. We can tell that the four-cluster option in the Iris data is not helpful. We can tell that Palmer penguins may be tricky to cluster using K-Means and that there is no decisive right solution. Clustergram does not give an easy answer, but it gives us additional insight, and it is up to us how we interpret it.

If you want to play with the examples used in this documentation, the Jupyter notebook is on GitHub. You can also run it in an interactive binder environment in your browser.

For more information, check Tal Galili’s blog post and original papers by Matthias Schonlau.

Give it a go!