ROC for multiclass classification (tags: python)

ROC for multiclass classification


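As people mentioned in the comments, you have to convert your problem into several binary ones using the One-vs-All (One-vs-Rest) approach, so you will get `n_classes` ROC curves — one per class, each treating that class as positive and all the others as negative.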

A simple example:

from sklearn.metrics import roc_curve, auc
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Binarize the labels: one indicator column per class (One-vs-Rest).
# The class list is derived from the data rather than hard-coded, so
# n_classes always equals the number of indicator columns.
classes = sorted(set(y))
n_classes = len(classes)
y = label_binarize(y, classes=classes)

# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=0
)

# classifier: OneVsRestClassifier fits one LinearSVC per indicator column
clf = OneVsRestClassifier(LinearSVC(random_state=0))
y_score = clf.fit(X_train, y_train).decision_function(X_test)

# Compute ROC curve and ROC area for each class
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot the ROC curve of each class in its own figure. The class name goes
# into the title and legend so the figures can be told apart.
for i in range(n_classes):
    class_name = iris.target_names[classes[i]]
    plt.figure()
    plt.plot(fpr[i], tpr[i], label=f'ROC curve (area = {roc_auc[i]:0.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver operating characteristic: class {class_name!r} vs rest')
    plt.legend(loc="lower right")

# Show all figures at once instead of blocking after each one.
plt.show()

[image: ROC curve, class 1 of 3] [image: ROC curve, class 2 of 3] [image: ROC curve, class 3 of 3] — the three per-class ROC plots produced by the code above.


This works for me and is handy if you want all the curves on the same plot. It is similar to @omdv's answer, but a little more succinct.

def plot_multiclass_roc(clf, X_test, y_test, n_classes, figsize=(17, 6)):
    """Plot one-vs-rest ROC curves for several classes on a single axes.

    Parameters
    ----------
    clf : fitted classifier or pipeline
        Must expose ``classes_`` and either ``decision_function`` (used when
        available, as in the original version) or ``predict_proba``.
    X_test : array-like
        Features to score.
    y_test : array-like
        True class labels for ``X_test`` (raw labels, not one-hot encoded).
    n_classes : int
        Number of classes to plot, taken in the order of ``clf.classes_``.
    figsize : tuple of (float, float), default (17, 6)
        Size of the matplotlib figure.
    """
    if hasattr(clf, "decision_function"):
        y_score = clf.decision_function(X_test)
    else:
        y_score = clf.predict_proba(X_test)
    classes = clf.classes_

    # One indicator column per class, in the same order as the score columns
    # (clf.classes_). A plain pd.get_dummies(y_test) only creates columns for
    # labels that occur in y_test, so a class missing from the test set would
    # shift every later column out of alignment with y_score.
    y_test_dummies = pd.get_dummies(
        pd.Categorical(y_test, categories=classes), dtype=int
    ).values

    def class_scores(i):
        """Return the score vector used to rank samples for class ``i``."""
        if y_score.ndim == 2:
            return y_score[:, i]
        # Binary decision_function returns a single 1-D score oriented towards
        # classes[1]; negate it to rank samples for classes[0].
        return y_score if i == 1 else -y_score

    # roc for each class, all on one axes
    _, ax = plt.subplots(figsize=figsize)
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic example')
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve(y_test_dummies[:, i], class_scores(i))
        roc_auc = auc(fpr, tpr)
        # Label with the actual class value, not its column index.
        ax.plot(fpr, tpr,
                label=f'ROC curve (area = {roc_auc:0.2f}) for label {classes[i]}')
    ax.legend(loc="best")
    ax.grid(alpha=.4)
    sns.despine()
    plt.show()


plot_multiclass_roc(full_pipeline, X_test, y_test, n_classes=16, figsize=(16, 10))