How to get the samples in each cluster?


I had a similar requirement, and I am using pandas to create a new DataFrame with the index of the dataset and the cluster labels as columns.

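```python
import pandas as pd
from sklearn.cluster import KMeans

data = pd.read_csv('filename')
km = KMeans(n_clusters=5).fit(data)

# One row per sample: its position in the original data and its cluster label
cluster_map = pd.DataFrame()
cluster_map['data_index'] = data.index.values
cluster_map['cluster'] = km.labels_
```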

Once the DataFrame is available, it is easy to filter. For example, to select all data points in cluster 3:

```python
cluster_map[cluster_map.cluster == 3]
```
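If you need every cluster at once rather than filtering one at a time, a single `groupby` on the same `cluster_map` layout can collect the indices for each cluster. This is a minimal sketch that assumes a hypothetical hard-coded `labels` list in place of `km.labels_`, so it runs without fitting a model. The name `members` is also just illustrative.

```python
import pandas as pd

# Hypothetical stand-in for km.labels_ from a fitted KMeans model.
labels = [0, 2, 1, 2, 0, 1]

# Same layout as cluster_map above: sample position plus its cluster label.
cluster_map = pd.DataFrame({'data_index': range(len(labels)), 'cluster': labels})

# One groupby pass: map each cluster id to the list of data indices in it.
members = cluster_map.groupby('cluster')['data_index'].apply(list).to_dict()
print(members)
```

Each value is a plain list of row positions that you can pass to `data.iloc[...]`.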


If you have a large dataset and you need to extract clusters on-demand you'll see some speed-up using numpy.where. Here is an example on the iris dataset:

```python
from sklearn.cluster import KMeans
from sklearn import datasets
import numpy as np

iris = datasets.load_iris()
X = iris.data
y = iris.target

km = KMeans(n_clusters=3)
km.fit(X)
```

Define a function that extracts the indices of the samples in the cluster you specify. Two versions are given here for benchmarking; both return the same values:

```python
def ClusterIndicesNumpy(clustNum, labels_array):  # numpy
    return np.where(labels_array == clustNum)[0]

def ClusterIndicesComp(clustNum, labels_array):  # list comprehension
    return np.array([i for i, x in enumerate(labels_array) if x == clustNum])
```

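Let's say you want all samples that are in cluster 2. (KMeans assigns cluster ids arbitrarily, and they can change between runs unless you fix `random_state`, so your indices may differ from the ones shown here.)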

```python
>>> ClusterIndicesNumpy(2, km.labels_)
array([ 52,  77, 100, 102, 103, 104, 105, 107, 108, 109, 110, 111, 112,
       115, 116, 117, 118, 120, 122, 124, 125, 128, 129, 130, 131, 132,
       134, 135, 136, 137, 139, 140, 141, 143, 144, 145, 147, 148])
```

NumPy wins the benchmark by roughly two orders of magnitude:

```
%timeit ClusterIndicesNumpy(2,km.labels_)
100000 loops, best of 3: 4 µs per loop
%timeit ClusterIndicesComp(2,km.labels_)
1000 loops, best of 3: 479 µs per loop
```
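`%timeit` is an IPython magic. In a plain Python script, the standard-library `timeit` module gives a comparable measurement. The sketch below repeats the two functions so it is self-contained. It uses a hypothetical array of 150 random labels in place of `km.labels_`, so the absolute timings will not match the numbers above.

```python
import timeit
import numpy as np

def ClusterIndicesNumpy(clustNum, labels_array):  # numpy
    return np.where(labels_array == clustNum)[0]

def ClusterIndicesComp(clustNum, labels_array):  # list comprehension
    return np.array([i for i, x in enumerate(labels_array) if x == clustNum])

# Hypothetical stand-in for km.labels_: 150 random labels in {0, 1, 2}.
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=150)

# Both versions must agree before it makes sense to time them.
assert np.array_equal(ClusterIndicesNumpy(2, labels), ClusterIndicesComp(2, labels))

for fn in (ClusterIndicesNumpy, ClusterIndicesComp):
    # Best of 3 repeats of 1000 calls each, reported as time per call.
    best = min(timeit.repeat(lambda: fn(2, labels), number=1000, repeat=3))
    print(f"{fn.__name__}: {best / 1000 * 1e6:.1f} µs per call")
```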

Now you can extract all of your cluster 2 data points like so:

```python
>>> X[ClusterIndicesNumpy(2, km.labels_)]
array([[ 6.9,  3.1,  4.9,  1.5],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6.3,  3.3,  6. ,  2.5],
       ...  # truncated
```

Double-check the first three rows of the truncated array above against their indices (52, 77 and 100) and labels:

```python
>>> print(X[52], km.labels_[52])
[ 6.9  3.1  4.9  1.5] 2
>>> print(X[77], km.labels_[77])
[ 6.7  3.   5.   1.7] 2
>>> print(X[100], km.labels_[100])
[ 6.3  3.3  6.   2.5] 2
```
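When you only need the rows and not the indices, a boolean mask on the labels selects the same rows as indexing with the `np.where` result. This is a minimal sketch on hypothetical toy arrays standing in for `X` and `km.labels_`.

```python
import numpy as np

# Hypothetical toy data and labels standing in for X and km.labels_.
X = np.array([[6.9, 3.1], [5.1, 3.5], [6.7, 3.0], [4.9, 3.0]])
labels = np.array([2, 0, 2, 1])

by_index = X[np.where(labels == 2)[0]]  # fancy indexing with the index array
by_mask = X[labels == 2]                 # boolean mask, no index array needed

# Both select rows 0 and 2, in their original order.
assert np.array_equal(by_index, by_mask)
print(by_mask)
```

The `np.where` version is still handy when you want the positions themselves, for example to look up the matching targets in `y`.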


Actually, a very simple way to do this is to index the DataFrame with a boolean mask on the labels:

```python
clusters = KMeans(n_clusters=5).fit(df)  # labels_ only exists after fitting
df[clusters.labels_ == 0]
```

The second line returns all rows of `df` that belong to cluster 0. You can select the members of any other cluster in the same way.
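The mask works because `labels_` holds one label per row, in the order the rows were passed to `fit`. To get every cluster at once, a dict comprehension over the distinct labels builds one sub-DataFrame per cluster. This is a sketch with a hypothetical toy `df` and a hard-coded `labels` array standing in for `clusters.labels_`. The name `by_cluster` is illustrative.

```python
import numpy as np
import pandas as pd

# Hypothetical toy DataFrame; labels stands in for clusters.labels_ after
# clusters = KMeans(...).fit(df), so labels[i] belongs to row i of df.
df = pd.DataFrame({'x': [1.0, 1.2, 8.0, 8.1, 0.9],
                   'y': [2.0, 1.8, 9.0, 9.2, 2.1]})
labels = np.array([0, 0, 1, 1, 0])

# One sub-DataFrame per cluster, keyed by the (Python int) cluster id.
by_cluster = {k: df[labels == k] for k in np.unique(labels).tolist()}
print(by_cluster[0])
```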