Introduction to Clustergram


You can try this notebook in your browser: Binder

When we want to do some cluster analysis to identify groups in our data, we often use algorithms like K-Means, which require the specification of a number of clusters. But the issue is that we usually don’t know how many clusters there are.

There are many methods for determining the right number, such as silhouette scores or the elbow plot, to name a few. But they usually give little insight into what happens between the different options, so the numbers are a bit abstract.

Matthias Schonlau proposed another approach: a clustergram. A clustergram is a two-dimensional plot capturing the flows of observations between classes as you add more clusters. It tells you how your data reshuffles and how good your splits are. Tal Galili later implemented clustergram for K-Means in R. I took Tal’s implementation, ported it to Python, and created clustergram, a Python package to make clustergrams.

clustergram currently supports K-Means using scikit-learn (including the Mini-Batch implementation) and RAPIDS.AI cuML (if you have a CUDA-enabled GPU), Gaussian Mixture Model (scikit-learn only), and hierarchical clustering based on scipy.cluster.hierarchy. Alternatively, we can create a clustergram from labels and data derived from custom clustering algorithms. It provides a sklearn-like API and plots the clustergram using matplotlib, which gives it a wide range of styling options to match your publication style.


You can install clustergram from conda or pip:

conda install clustergram -c conda-forge


pip install clustergram

In any case, you still need to install your selected backend (scikit-learn and scipy or cuML).
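
Before fitting anything, it can help to confirm which backends are importable in the current environment. The snippet below is a minimal sketch using only the standard library; the helper name available_backends is hypothetical and is not part of clustergram, and the import names (sklearn, scipy, cuml) are the usual module names of those packages.

```python
from importlib.util import find_spec

# Import names of the optional backends mentioned above.
BACKEND_MODULES = {
    "scikit-learn": "sklearn",
    "scipy": "scipy",
    "cuML": "cuml",
}


def available_backends():
    """Return {package name: bool} telling whether each backend can be imported."""
    return {
        name: find_spec(module) is not None
        for name, module in BACKEND_MODULES.items()
    }


for name, found in available_backends().items():
    print(f"{name}: {'installed' if found else 'missing'}")
```

If a backend you plan to use is reported as missing, install it with conda or pip before moving on.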

from clustergram import Clustergram
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale

Let us look at some examples to understand how clustergram looks and what to do with it.

Iris flower data set

The first example we analyse with clustergram is the famous Iris flower data set. It contains measurements of sepal width and length and petal width and length for three species of Iris flowers. We can start with some exploration:

iris = sns.load_dataset("iris")
g = sns.pairplot(iris, hue="species")
g.fig.suptitle("Iris flowers", y=1.01)
Text(0.5, 1.01, 'Iris flowers')

It seems that setosa is a relatively well-defined group, while the difference between versicolor and virginica is smaller as they partially overlap (or entirely in the case of sepal width).

Okay, so we know how the data looks. Now we can check how the clustergram looks. Remember that we know that there are three clusters, and we should ideally be able to recognise this from the clustergram. I am saying ideally because even though there are known labels, it does not mean that our data or clustering method is able to distinguish between those classes.

Let’s start with K-Means clustering. To get a stable result, we can run a clustergram with 100 initialisations. Feel free to run even more.

data = scale(iris.drop(columns=['species']))
cgram = Clustergram(range(1, 10), n_init=100, verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.set_title('K-Means (scikit-learn)')
Text(0.5, 1.0, 'K-Means (scikit-learn)')

On the x axis, we can see the number of clusters. Points represent a centre of each cluster (by default) weighted by the first principal component (that helps with the diagram’s readability). The lines connecting points and their thickness represent observations moving between clusters. Therefore, we can read when new clusters are formed as a split of a single existing class and when they are formed based on observations from two clusters.
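
To make the vertical position of the points more concrete, here is a minimal sketch of the idea under one plausible reading: project every observation onto the first principal component and take the mean score of each cluster. This is an illustration written with NumPy only, not clustergram's actual code; the function names and the toy data are hypothetical.

```python
import numpy as np


def first_principal_component(data):
    """Project observations onto the first principal component (via SVD)."""
    centred = data - data.mean(axis=0)
    # Rows of vt are the principal axes, ordered by explained variance.
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    return centred @ vt[0]


def cluster_positions(data, labels):
    """Mean PC1 score of each cluster: one y-coordinate per cluster label."""
    scores = first_principal_component(data)
    return {
        int(label): float(scores[labels == label].mean())
        for label in np.unique(labels)
    }


# Hypothetical toy data: two well-separated blobs in 2D.
rng = np.random.default_rng(0)
toy = np.vstack([rng.normal(0, 0.5, (50, 2)), rng.normal(5, 0.5, (50, 2))])
toy_labels = np.repeat([0, 1], 50)

positions = cluster_positions(toy, toy_labels)
print(positions)
```

Two clusters that are far apart in the data end up far apart on the y axis, which is why a clean split shows up as two clearly diverging branches. The sign of the principal component is arbitrary, so only the distance between the positions matters.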

We’re looking for the separation, i.e., did an additional cluster bring any meaningful split? The step from one cluster to two is a big one - nice and clear separation. From two to three, we also have quite a nice split in the top branch. But from 3 to 4, there is no visible difference because the new fourth cluster is almost the same as the existing bottom branch. Although it is now separated into two, this split does not give us much information. Therefore, we could conclude that the ideal number of clusters for Iris data is three.
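
The lines between neighbouring values of k can be thought of as a cross-tabulation of cluster labels. The sketch below counts how many observations move from each cluster at k to each cluster at k + 1; the labels are made up for illustration and the helper flows is hypothetical, not a clustergram function.

```python
from collections import Counter


def flows(labels_k, labels_next):
    """Count observations moving from each cluster at k to each cluster at k + 1."""
    return Counter(zip(labels_k, labels_next))


# Hypothetical labels for eight observations at k=2 and k=3.
labels_2 = [0, 0, 0, 0, 1, 1, 1, 1]
labels_3 = [0, 0, 2, 2, 1, 1, 1, 1]

for (source, target), count in sorted(flows(labels_2, labels_3).items()):
    print(f"cluster {source} (k=2) -> cluster {target} (k=3): {count} observations")
```

Here the new cluster 2 draws all of its members from cluster 0, so it is a split of a single existing class. If it had drawn members from both clusters, it would appear in the diagram as a point fed by two lines.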

We can also check some additional information, like a silhouette score or Calinski-Harabasz score.

fig, axs = plt.subplots(2, figsize=(10, 10), sharex=True)
cgram.silhouette_score().plot(xlabel="Number of clusters (k)", ylabel="Silhouette score", ax=axs[0])
cgram.calinski_harabasz_score().plot(xlabel="Number of clusters (k)", ylabel="Calinski-Harabasz score", ax=axs[1])

These plots would suggest 3-4 clusters, similarly to clustergram, but they are not very conclusive.
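
For readers who want to see what the Calinski-Harabasz score measures, the following is a minimal NumPy sketch of the standard definition: between-cluster dispersion divided by within-cluster dispersion, each scaled by its degrees of freedom. It is an illustration rather than the code clustergram or scikit-learn runs, and the toy points are hypothetical.

```python
import numpy as np


def calinski_harabasz(data, labels):
    """Between/within dispersion ratio; defined only for 2 <= k < n clusters."""
    unique = np.unique(labels)
    n, k = len(data), len(unique)
    overall_mean = data.mean(axis=0)
    between = within = 0.0
    for label in unique:
        members = data[labels == label]
        centre = members.mean(axis=0)
        # Dispersion of the cluster centre around the overall mean.
        between += len(members) * np.sum((centre - overall_mean) ** 2)
        # Dispersion of the members around their own centre.
        within += np.sum((members - centre) ** 2)
    return (between / (k - 1)) / (within / (n - k))


# Hypothetical toy data: four points forming two tight pairs.
points = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
good = calinski_harabasz(points, np.array([0, 0, 1, 1]))  # pairs kept together
bad = calinski_harabasz(points, np.array([0, 1, 0, 1]))   # pairs torn apart
print(good, bad)
```

A higher score means tighter, better separated clusters. The score collapses the whole partition into a single number, which is why it says little about what actually changes from one k to the next.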

Palmer penguins data set

Now let’s try different data, one where clusters are a bit more complicated to assess. The Palmer penguins data set is similar to the Iris example, but it measures several attributes of three species of penguins.

penguins = sns.load_dataset("penguins")
g = sns.pairplot(penguins, hue="species")
g.fig.suptitle("Palmer penguins", y=1.01)
Text(0.5, 1.01, 'Palmer penguins')

Looking at the situation, we see that the overlap between species is much higher than before. It will likely be much more complicated to identify them. Again, we know that there are three clusters, but that does not mean that data has the power to distinguish between them. In this case, it may be especially tricky to differentiate between Adelie and Chinstrap penguins.

data = scale(penguins.drop(columns=['species', 'island', 'sex']).dropna())
cgram = Clustergram(range(1, 10), n_init=100, verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.set_title('K-Means (scikit-learn)')
Text(0.5, 1.0, 'K-Means (scikit-learn)')

We’re looking for separations, and this clustergram shows plenty. It is actually quite complicated to determine the optimal number of clusters. However, since we know what happens between different options, we can play with that. If we have a reason to be conservative, we can go with 4 clusters (I know, it is already more than the initial species). But further splits are also reasonable, which indicates that even higher granularity may provide useful insight, that there might be meaningful groups.

Can we say it is three? Since we know it should be three… Well, not really. The difference between the split from 2 to 3 and that from 3 to 4 is slight. However, the culprit here is K-Means, not clustergram. It simply cannot cluster these data correctly due to the overlaps and the overall structure.

Let’s have a look at how the Gaussian Mixture Model does.

cgram = Clustergram(range(1, 10), n_init=100, method="gmm", verbose=False)
cgram.fit(data)
ax = cgram.plot(figsize=(10, 8))
ax.set_title('Gaussian Mixture Model (scikit-learn)')
K=1 fitted in 0.5946810245513916 seconds.
K=2 fitted in 0.25310826301574707 seconds.
K=3 fitted in 0.6201028823852539 seconds.
K=4 fitted in 0.6058881282806396 seconds.
K=5 fitted in 0.6077260971069336 seconds.
K=6 fitted in 0.7751941680908203 seconds.
K=7 fitted in 1.1589140892028809 seconds.
K=8 fitted in 1.319072961807251 seconds.
K=9 fitted in 1.5859079360961914 seconds.
Text(0.5, 1.0, 'Gaussian Mixture Model (scikit-learn)')

The result is very similar, though the difference between the third and fourth split is more pronounced. Even here, I would probably go with a four-cluster solution.

A situation like this happens very often. The ideal case does not exist. We ultimately need to make a decision on the optimal number of clusters. Clustergram gives us additional insight into what happens between different options and how the data splits. We can tell that the four-cluster option in Iris data is not helpful. We can tell that Palmer penguins may be tricky to cluster using K-Means and that there is no decisive right solution. Clustergram does not give an easy answer, but it gives us additional insight, and it is up to us how we interpret it.

If you want to play with the examples used in this documentation, the Jupyter notebook is on GitHub. You can also run it in an interactive binder environment in your browser.

For more information, check Tal Galili’s blog post and original papers by Matthias Schonlau.

Give it a go!