How to give sns.clustermap a precomputed distance matrix?


You can compute a linkage from the precomputed distance matrix and pass it to clustermap() as row_linkage and col_linkage:

    import pandas as pd, seaborn as sns
    import scipy.spatial as sp, scipy.cluster.hierarchy as hc
    from sklearn.datasets import load_iris
    sns.set(font="monospace")

    iris = load_iris()
    X, y = iris.data, iris.target
    DF = pd.DataFrame(X, index=["iris_%d" % i for i in range(X.shape[0])], columns=iris.feature_names)

    DF_corr = DF.T.corr()   # correlation between samples (rows of DF)
    DF_dism = 1 - DF_corr   # distance matrix
    # squareform() converts the square matrix to the condensed form that linkage() expects
    linkage = hc.linkage(sp.distance.squareform(DF_dism), method='average')
    sns.clustermap(DF_dism, row_linkage=linkage, col_linkage=linkage)
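One possible caveat with the snippet above: squareform() validates its input by default, so a distance matrix whose diagonal is not exactly zero, or which is not exactly symmetric because of floating-point rounding, can be rejected. The following is a minimal sketch of one way to guard against that, not part of the original answer. It uses made-up random data and np.corrcoef in place of the iris DataFrame, and the variable names are only illustrative.

```python
import numpy as np
import scipy.cluster.hierarchy as hc
from scipy.spatial.distance import squareform

# Made-up data standing in for the real observations (8 samples, 5 features).
rng = np.random.default_rng(42)
X = rng.normal(size=(8, 5))

corr = np.corrcoef(X)   # correlation between rows (samples)
dism = 1.0 - corr       # correlation distance, as a square matrix

# Rounding can leave a diagonal that is not exactly zero or tiny asymmetries,
# so force both properties before handing the matrix to squareform().
dism = (dism + dism.T) / 2.0
np.fill_diagonal(dism, 0.0)

condensed = squareform(dism)   # condensed vector of length n * (n - 1) / 2
linkage = hc.linkage(condensed, method="average")
print(condensed.shape, linkage.shape)
```

The resulting linkage can then be passed as row_linkage and col_linkage exactly as in the snippet above.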

If you call clustermap(distance_matrix) without passing a linkage, the linkage is calculated internally from the pairwise distances between the rows and columns of the distance matrix (see the note below for details), rather than from the elements of the distance matrix themselves, which is the correct approach. As a result, the plot produced by the code above differs somewhat from the one in the question:

[Image: clustermap]

Note: if no row_linkage is passed to clustermap(), the row linkage is determined internally by treating each row as a "point" (observation) and calculating the pairwise distances between those points. The row dendrogram therefore reflects row similarity. The same applies to col_linkage, where each column is treated as a point. This explanation should probably be added to the docs. Here is the first example from the docs, modified to make the internal linkage calculation explicit:

    import seaborn as sns; sns.set()
    import scipy.spatial as sp, scipy.cluster.hierarchy as hc

    flights = sns.load_dataset("flights")
    # keyword arguments are required by recent pandas versions
    flights = flights.pivot(index="month", columns="year", values="passengers")
    row_linkage, col_linkage = (hc.linkage(sp.distance.pdist(x), method='average')
                                for x in (flights.values, flights.values.T))
    g = sns.clustermap(flights, row_linkage=row_linkage, col_linkage=col_linkage)
    # note: this produces the same plot as "sns.clustermap(flights)", where
    # clustermap() calculates the row and column linkages internally
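
The difference between the two linkages discussed above can also be checked numerically, without plotting. This is a small sketch under stated assumptions: the data is random and made up, and the "internal" linkage is emulated with pdist() using the Euclidean metric and average linkage, as in the flights example, rather than by calling clustermap() itself.

```python
import numpy as np
import scipy.cluster.hierarchy as hc
from scipy.spatial.distance import pdist, squareform

# Made-up data: 6 observations with 4 features each.
rng = np.random.default_rng(0)
data = rng.normal(size=(6, 4))

# A precomputed distance matrix (correlation distance), in square form.
dist_square = squareform(pdist(data, metric="correlation"))

# Direct approach: the matrix entries themselves are the distances.
linkage_direct = hc.linkage(squareform(dist_square), method="average")

# Emulated internal approach: each row of the distance matrix is treated
# as a point, and Euclidean distances between those rows are clustered.
linkage_rows = hc.linkage(pdist(dist_square, metric="euclidean"), method="average")

# Column 2 of a linkage matrix holds the merge heights.
heights_direct = linkage_direct[:, 2]
heights_rows = linkage_rows[:, 2]
print("direct:   ", np.round(heights_direct, 3))
print("from rows:", np.round(heights_rows, 3))
```

With the direct approach, the first merge height equals the smallest off-diagonal entry of the distance matrix. The row-based linkage works on a different scale, because it measures distances between rows of distances.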