
scikit-learn .predict() default threshold


There is no threshold parameter to set directly; instead, apply your own cut-off to the probabilities returned by clf.predict_proba()

for example:

from sklearn.tree import DecisionTreeClassifier

# X_train, y_train and X_test are assumed to be defined elsewhere
clf = DecisionTreeClassifier(random_state=2)
clf.fit(X_train, y_train)

# y_pred = clf.predict(X_test)  # default threshold is 0.5
y_pred = (clf.predict_proba(X_test)[:, 1] >= 0.3).astype(bool)  # set threshold to 0.3


Is scikit-learn's classifier.predict() using 0.5 by default?

In probabilistic classifiers, yes. It's the only sensible threshold from a mathematical viewpoint, as others have explained.
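A quick way to verify this is to compare predict() with a manual 0.5 cut-off. Here is a minimal sketch, using LogisticRegression on synthetic data purely for illustration:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

proba = clf.predict_proba(X)[:, 1]
# In the binary case, predict() agrees with thresholding the
# positive-class probability at 0.5 (ties at exactly 0.5 aside);
# in the multiclass case it takes the argmax over predict_proba.
print(np.array_equal(clf.predict(X), (proba >= 0.5).astype(int)))  # True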

What would be the way to do this in a classifier like MultinomialNB that doesn't support class_weight?

You can set the class_prior, which is the prior probability P(y) per class y. Since naive Bayes predicts the class that maximizes P(y)·P(x|y), raising a class's prior lowers the evidence needed to predict it, which effectively shifts the decision boundary. E.g.

>>> from sklearn.naive_bayes import MultinomialNB
# minimal dataset
>>> X = [[1, 0], [1, 0], [0, 1]]
>>> y = [0, 0, 1]
# use the empirical prior, learned from y
>>> MultinomialNB().fit(X, y).predict([[1, 1]])
array([0])
# use a custom prior to make class 1 more likely
>>> MultinomialNB(class_prior=[.1, .9]).fit(X, y).predict([[1, 1]])
array([1])


The threshold in scikit-learn is 0.5 for binary classification; for multiclass classification, the class with the greatest probability is predicted. In many problems a much better result can be obtained by adjusting the threshold. However, this must be done with care: NOT on the held-out test data, but by cross-validation on the training data. If you adjust the threshold on your test data, you are simply overfitting the test data.
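A sketch of what that cross-validated tuning might look like; the use of cross_val_predict, F1 as the selection metric, and the synthetic data are all illustrative choices, not the only way to do it:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict, train_test_split

X, y = make_classification(n_samples=500, weights=[0.8, 0.2], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression(max_iter=1000)
# out-of-fold probabilities, computed on the training data only
oof_proba = cross_val_predict(clf, X_train, y_train, cv=5,
                              method="predict_proba")[:, 1]

# choose the threshold that maximizes F1 on the out-of-fold predictions
grid = np.linspace(0.05, 0.95, 91)
best_t = grid[np.argmax([f1_score(y_train, oof_proba >= t) for t in grid])]

# only now touch the test set, with the threshold already fixed
clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)[:, 1] >= best_t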

Most methods of adjusting the threshold are based on the receiver operating characteristic (ROC) curve and Youden's J statistic, but it can also be done with other methods, such as a search with a genetic algorithm.
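For the ROC/Youden approach, scikit-learn's roc_curve gives everything needed; a minimal sketch, with toy labels and probabilities standing in for real model output:

import numpy as np
from sklearn.metrics import roc_curve

# toy stand-ins for true 0/1 labels and predicted positive-class probabilities
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])
proba = np.array([0.1, 0.4, 0.35, 0.2, 0.8, 0.7, 0.55, 0.9])

fpr, tpr, thresholds = roc_curve(y_true, proba)
j = tpr - fpr  # Youden's J = sensitivity + specificity - 1
best_threshold = thresholds[np.argmax(j)]
print(best_threshold)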

Here is a peer-reviewed journal article describing how this is done in medicine:

http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2515362/

As far as I know there is no package for doing it in Python, but it is relatively simple (though inefficient) to find it with a brute-force search in Python.
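A brute-force sketch along those lines, with randomly generated stand-ins for validation-fold labels and probabilities, and balanced accuracy as an arbitrary example metric:

import numpy as np
from sklearn.metrics import balanced_accuracy_score

# random stand-ins for out-of-fold labels and probabilities
rng = np.random.default_rng(0)
proba_val = rng.random(200)
y_val = (proba_val + rng.normal(scale=0.3, size=200)) > 0.5

# brute force: every distinct predicted probability is a candidate cut-off
candidates = np.unique(proba_val)
best_t = max(candidates,
             key=lambda t: balanced_accuracy_score(y_val, proba_val >= t))
print(best_t)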

This is some R code that does it.

## load data
DD73OP <- read.table("/my_probabilites.txt", header=T, quote="\"")

library("pROC")

# No smoothing
roc_OP <- roc(DD73OP$tc, DD73OP$prob)
auc_OP <- auc(roc_OP)
auc_OP
Area under the curve: 0.8909
plot(roc_OP)

# Best threshold
# Method: Youden
# Youden's J statistic (Youden, 1950) is employed. The optimal cut-off is the
# threshold that maximizes the distance to the identity (diagonal) line. Can
# be shortened to "y".
# The optimality criterion is:
#   max(sensitivities + specificities)
coords(roc_OP, "best", ret=c("threshold", "specificity", "sensitivity"), best.method="youden")
# threshold specificity sensitivity
# 0.7276835   0.9092466   0.7559022