clustergram API reference


class clustergram.Clustergram(k_range, backend='sklearn', method='kmeans', pca_weighted=True, pca_kwargs={}, verbose=True, **kwargs)

Clustergram class mimicking the interface of clustering classes (e.g. KMeans).

Clustergram is a graph used to examine how cluster members are assigned to clusters as the number of clusters increases. This graph is useful in exploratory analysis for nonhierarchical clustering algorithms such as k-means, and for hierarchical clustering algorithms when the number of observations is large enough to make dendrograms impractical.
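The branches of a clustergram record how observations move from a cluster at k to a cluster at k + 1. The sketch below counts those transitions from two label arrays. `branch_counts` is a hypothetical helper, not part of the clustergram API, and it assumes the labels are plain integer sequences describing the same observations in the same order.

```python
from collections import Counter


def branch_counts(labels_k, labels_k1):
    """Count observations moving from each cluster at k to each cluster at k + 1.

    Hypothetical helper: returns a Counter keyed by (cluster_at_k, cluster_at_k_plus_1).
    """
    if len(labels_k) != len(labels_k1):
        raise ValueError("label arrays must describe the same observations")
    # Each (old, new) pair is one observation travelling along one branch.
    return Counter(zip(labels_k, labels_k1))


# Six observations: two clusters at k=2, three clusters at k=3.
labels_2 = [0, 0, 0, 1, 1, 1]
labels_3 = [0, 0, 2, 1, 1, 2]
print(branch_counts(labels_2, labels_3))
```

Each key corresponds to one branch of the graph, and its count is the number of observations in that branch.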

Clustergram offers two backends for the computation: scikit-learn, which uses the CPU, and RAPIDS.AI cuML, which uses the GPU. Both are optional dependencies, but you will need at least one of them to generate a clustergram.
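Because both backends are optional, it can help to check which one is importable before choosing `backend`. A minimal sketch using only the standard library; `available_backends` is a hypothetical helper, and it assumes scikit-learn is importable as `sklearn` and RAPIDS cuML as `cuml`.

```python
import importlib.util


def available_backends():
    """Return the backend names whose libraries can be imported in this environment."""
    backends = []
    # find_spec returns None when a top-level package is not installed.
    if importlib.util.find_spec("sklearn") is not None:
        backends.append("sklearn")
    if importlib.util.find_spec("cuml") is not None:
        backends.append("cuML")
    return backends


backends = available_backends()
if not backends:
    print("Install scikit-learn or RAPIDS cuML to compute a clustergram.")
else:
    print("Available backends:", backends)
```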


Parameters

k_range : iterable

Iterable of integer values to be tested as k.

backend : string (‘sklearn’ or ‘cuML’, default ‘sklearn’)

Whether to use the scikit-learn or the cuML implementation of KMeans and PCA. scikit-learn computes on the CPU, cuML on the GPU.

method : string (‘kmeans’ or ‘gmm’, default ‘kmeans’)

Clustering method. ‘kmeans’ uses k-means clustering, ‘gmm’ a Gaussian Mixture Model. ‘gmm’ is currently supported only with the ‘sklearn’ backend.

pca_weighted : bool (default True)

Whether to use the PCA-weighted mean of clusters or the standard mean of clusters.

pca_kwargs : dict (default {})

Additional arguments passed to the PCA object, e.g. svd_solver. Applies only if pca_weighted=True.

verbose : bool (default True)

Print progress and time of individual steps.


**kwargs

Additional arguments passed to the KMeans object, e.g. random_state.
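The constraints on `backend` and `method` described above can be expressed as a small validation step. This is a hypothetical `check_options` helper that only encodes the documented rules (two backends, two methods, ‘gmm’ only with ‘sklearn’); clustergram's own checks and error messages may differ.

```python
VALID_BACKENDS = {"sklearn", "cuML"}
VALID_METHODS = {"kmeans", "gmm"}


def check_options(backend="sklearn", method="kmeans"):
    """Reject backend/method combinations the documentation lists as unsupported."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {sorted(VALID_BACKENDS)}, got {backend!r}")
    if method not in VALID_METHODS:
        raise ValueError(f"method must be one of {sorted(VALID_METHODS)}, got {method!r}")
    # 'gmm' is documented as supported only with the 'sklearn' backend.
    if method == "gmm" and backend != "sklearn":
        raise ValueError("method='gmm' requires backend='sklearn'")
    return backend, method


print(check_options("sklearn", "gmm"))
```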


References

The clustergram: A graph for visualizing hierarchical and nonhierarchical cluster analyses.

Tal Galili’s R implementation.


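Examples

>>> import clustergram
>>> c_gram = clustergram.Clustergram(range(1, 9))
>>> c_gram.fit(data)  # data: scaled numpy.array or pandas.DataFrame
>>> c_gram.plot()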

Specifying parameters:

>>> c_gram2 = clustergram.Clustergram(
...     range(1, 9), backend="cuML", pca_weighted=False, random_state=0
... )
>>> c_gram2.fit(data)
>>> c_gram2.plot(figsize=(12, 12))

DataFrame with (weighted) means of clusters.
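Each cluster centre is summarised by a single value, either a PCA-weighted mean or a standard mean. The NumPy sketch below assumes the weighted mean is the dot product of a cluster centre with the loadings of the first principal component, and the standard mean is the average of the centre's coordinates. It illustrates the idea and is not necessarily clustergram's exact computation; `first_component` and `cluster_means` are hypothetical names. The sign of a principal component is arbitrary, so weighted values can come out mirrored.

```python
import numpy as np


def first_component(data):
    """Loadings of the first principal component, via SVD of the centred data."""
    centred = data - data.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    return vt[0]


def cluster_means(data, labels, pca_weighted=True):
    """One summary value per cluster centre: PCA-weighted or standard mean."""
    labels = np.asarray(labels)
    # Cluster centre = mean of the observations assigned to that cluster.
    centres = np.array([data[labels == c].mean(axis=0) for c in np.unique(labels)])
    if pca_weighted:
        # Weight each feature of a centre by its loading on the first component.
        return centres @ first_component(data)
    # Standard mean of each centre's coordinates.
    return centres.mean(axis=1)


rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))
labels = (data[:, 0] > 0).astype(int)
print(cluster_means(data, labels, pca_weighted=True))
print(cluster_means(data, labels, pca_weighted=False))
```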


Methods

fit(data, **kwargs)

Compute (weighted) means of clusters.

plot([ax, size, linewidth, cluster_style, …])

Generate clustergram plot based on cluster centre mean values.

fit(data, **kwargs)

Compute (weighted) means of clusters.


Parameters

data : array-like

Input data to be clustered. The data are expected to be scaled. Can be a numpy.array, a pandas.DataFrame or their RAPIDS counterparts.


**kwargs

Additional arguments passed to the fit method of the clustering model, e.g. sample_weight.
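Since `fit` expects scaled data, a common preparation step is standardising each column to zero mean and unit variance. scikit-learn's `StandardScaler` does this job; the NumPy sketch below shows the same z-score transformation, with `standardize` as a hypothetical helper name.

```python
import numpy as np


def standardize(data):
    """Scale each column to zero mean and unit variance (z-scores)."""
    data = np.asarray(data, dtype=float)
    std = data.std(axis=0)
    # Leave constant columns unscaled instead of dividing by zero.
    std[std == 0] = 1.0
    return (data - data.mean(axis=0)) / std


# Two features on very different scales.
raw = np.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
scaled = standardize(raw)
print(scaled)
```

After scaling, both columns contribute on the same scale, so neither dominates the distance computations of k-means.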


Returns

Fitted clustergram.

plot(ax=None, size=1, linewidth=1, cluster_style=None, line_style=None, figsize=None, k_range=None)

Generate clustergram plot based on cluster centre mean values.

ax : matplotlib.axes.Axes (default None)

Matplotlib axes on which to draw the plot.

size : float (default 1)

Multiplier of the size of a cluster centre marker. Size is determined as 500 / count of observations in a cluster multiplied by size.

linewidth : float (default 1)

Multiplier of the line width of a branch. Line width is determined as 50 / count of observations in a branch multiplied by linewidth.

cluster_style : dict (default None)

Style options to be passed on to the cluster centre plot, such as color, linewidth, edgecolor or alpha.

line_style : dict (default None)

Style options to be passed on to branches, such as color, linewidth, edgecolor or alpha.

figsize : tuple of integers (default None)

Size of the resulting matplotlib.figure.Figure. If ax is given explicitly, figsize is ignored.

k_range : iterable (default None)

Iterable of integer values to be plotted. If None, Clustergram.k_range will be used. Has to be a subset of Clustergram.k_range.
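The rule for the plotting `k_range` (default to every fitted k, otherwise require a subset) can be sketched as a small resolution step. `resolve_k_range` is a hypothetical helper that encodes only the documented behaviour; the library's own validation may differ.

```python
def resolve_k_range(fitted_k_range, k_range=None):
    """Pick the k values to plot: all fitted ones by default, else a validated subset."""
    fitted = list(fitted_k_range)
    if k_range is None:
        # Fall back to every k that was fitted.
        return fitted
    requested = list(k_range)
    # Any requested k that was never fitted has no cluster means to plot.
    missing = sorted(set(requested) - set(fitted))
    if missing:
        raise ValueError(f"k values {missing} were not fitted")
    return requested


print(resolve_k_range(range(1, 9)))
print(resolve_k_range(range(1, 9), [2, 4, 6]))
```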

Returns

ax : matplotlib axis instance