
difference between StratifiedKFold and StratifiedShuffleSplit in sklearn


In KFold, the test sets never overlap, even with shuffle=True. With KFold and shuffle=True, the data is shuffled once at the start and then divided into the desired number of splits. Each split is used once as the test data, and the rest is the training data.

In ShuffleSplit, the data is reshuffled for every split and then divided. This means the test sets may overlap between splits.

See this block for an example of the difference. Note the overlap of the elements in the test sets for ShuffleSplit.

```python
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold

splits = 5

tx = range(10)
ty = [0] * 5 + [1] * 5

kfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("KFold")
for train_index, test_index in kfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)
```

Output:

```
KFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]
```
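The two procedures described at the top of this answer can also be sketched from scratch in plain NumPy. This is a simplified illustration, not scikit-learn's actual implementation, and the helper names `kfold_test_sets` and `shuffle_split_test_sets` are hypothetical. The first shuffles once and cuts the permutation into chunks; the second draws a fresh permutation for every split.

```python
import numpy as np


def kfold_test_sets(n_samples, n_splits, seed):
    """Hypothetical KFold-style helper: shuffle ONCE, then cut into n_splits chunks."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    # np.array_split also handles n_samples not divisible by n_splits
    return [np.sort(chunk) for chunk in np.array_split(perm, n_splits)]


def shuffle_split_test_sets(n_samples, n_splits, test_size, seed):
    """Hypothetical ShuffleSplit-style helper: a FRESH permutation for every split."""
    rng = np.random.default_rng(seed)
    return [np.sort(rng.permutation(n_samples)[:test_size]) for _ in range(n_splits)]


kfold_sets = kfold_test_sets(n_samples=10, n_splits=5, seed=0)
shuffle_sets = shuffle_split_test_sets(n_samples=10, n_splits=20, test_size=2, seed=0)

# KFold-style test sets partition the data: every index appears exactly once.
kfold_all = np.concatenate(kfold_sets)
print("KFold-style, every index tested once:",
      np.array_equal(np.sort(kfold_all), np.arange(10)))

# ShuffleSplit-style: 20 splits x 2 test samples = 40 slots for only 10 indices,
# so some indices are necessarily reused across test sets.
shuffle_all = np.concatenate(shuffle_sets)
print("ShuffleSplit-style, test slots:", shuffle_all.size,
      "distinct indices:", np.unique(shuffle_all).size)
```

The large n_splits in the second call is deliberate: with more test slots than samples, overlap is guaranteed rather than just likely.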

As for when to use them, I tend to use KFold for any cross-validation, and I use ShuffleSplit with a split of 2 for my train/test splits. But I'm sure there are other use cases for both.
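One possible way to apply that advice with the stratified variants is sketched below. The dataset (iris), the classifier (`LogisticRegression`) and the settings `n_splits=5`, `test_size=0.2`, and `n_splits=1` for the hold-out split are my own assumptions for illustration, not necessarily the settings the answer's author uses.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import (
    StratifiedKFold,
    StratifiedShuffleSplit,
    cross_val_score,
)

X, y = load_iris(return_X_y=True)

# Cross-validation: each sample lands in exactly one test fold.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=cv)
print("CV accuracy per fold:", scores.round(3))

# A single stratified train/test split: n_splits=1 yields one (train, test) pair.
holdout = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(holdout.split(X, y))
print("train size:", len(train_idx), "test size:", len(test_idx))
print("test class counts:", np.bincount(y[test_idx]))
```

Iris has 50 samples per class, so a 20% stratified hold-out keeps exactly 10 of each class in the test set.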


@Ken Syme already has a very good answer. I just want to add something.

  • StratifiedKFold is a variation of KFold. When shuffle=True, StratifiedKFold first shuffles your data, then splits it into n_splits parts, and that's it. Each part is then used once as the test set. Note that the shuffle happens only once, before splitting.

With shuffle=True, the shuffling is controlled by random_state; if random_state is None, the global np.random state is used. With shuffle=False (the default), the data is not shuffled at all. For example, with n_splits=4 and a y (dependent variable) that has 3 classes (labels), the 4 test sets together cover all the data without any overlap.

[Image: StratifiedKFold test sets, n_splits=4, 3 classes]

  • On the other hand, StratifiedShuffleSplit is a variation of ShuffleSplit. It does not cut the data into n_splits fixed parts. Instead, for each of the n_splits iterations, it reshuffles the data independently and draws a new stratified test set of size test_size; the training set is drawn from the remaining samples (all of them by default, or train_size of them if you set it). Because each iteration is an independent draw, the test sets are not guaranteed to cover all the data, and the same sample can appear in several test sets. Look at the picture below: with the same data, the 4 test sets this time do not cover all the data, i.e. there are overlaps among the test sets.

[Image: StratifiedShuffleSplit test sets on the same data, n_splits=4]

So the difference is this: StratifiedKFold shuffles (if shuffle=True) and splits only once, so its test sets do not overlap, while StratifiedShuffleSplit reshuffles before every split and draws n_splits independent test sets, so its test sets can overlap.
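A quick way to see this difference is to count how many times each sample ends up in a test set. The sketch below uses assumed toy data (3 classes, 4 samples each) and a hypothetical helper `times_in_test`; the splitter settings are illustrative choices that give both methods test sets of 3 samples.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

# Toy data (an assumption for illustration): 3 classes, 4 samples each.
X = np.arange(12).reshape(-1, 1)
y = np.repeat([0, 1, 2], 4)

skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
sss = StratifiedShuffleSplit(n_splits=4, test_size=3, random_state=0)


def times_in_test(splitter):
    """Hypothetical helper: how many times each sample index appears in a test set."""
    counts = np.zeros(len(y), dtype=int)
    for _, test in splitter.split(X, y):
        counts[test] += 1
    return counts


skf_counts = times_in_test(skf)
sss_counts = times_in_test(sss)
print("StratifiedKFold:       ", skf_counts)  # every sample exactly once
print("StratifiedShuffleSplit:", sss_counts)  # may contain 0s and values above 1

# Each StratifiedShuffleSplit test set is still stratified on its own.
for i, (_, test) in enumerate(sss.split(X, y)):
    print(f"SSS split {i}: test={test}, class counts={np.bincount(y[test], minlength=3)}")
```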

  • Note: both methods are stratified (that's why "stratified" appears in both names). This means each part preserves approximately the same percentage of samples of each class (label) as the original data. You can read more in the scikit-learn cross_validation documentation.
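To see what stratification buys you, the sketch below compares the class counts in each test fold for plain KFold and StratifiedKFold (both with shuffle=False) on assumed imbalanced toy labels: six samples of class 0 and three of class 1. `fold_class_counts` is a hypothetical helper.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Imbalanced toy labels (an assumption): a 2:1 ratio of class 0 to class 1.
X = np.arange(9).reshape(-1, 1)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1])


def fold_class_counts(splitter):
    """Hypothetical helper: [count of class 0, count of class 1] for each test fold."""
    return [np.bincount(y[test], minlength=2).tolist() for _, test in splitter.split(X, y)]


kfold_counts = fold_class_counts(KFold(n_splits=3))
skf_counts = fold_class_counts(StratifiedKFold(n_splits=3))
print("KFold:          ", kfold_counts)  # [[3, 0], [3, 0], [0, 3]]
print("StratifiedKFold:", skf_counts)    # [[2, 1], [2, 1], [2, 1]]
```

Without stratification, two of the KFold test folds contain no class 1 samples at all, while every StratifiedKFold test fold keeps the original 2:1 ratio.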


[Image: Output examples of KFold, StratifiedKFold, StratifiedShuffleSplit]

The pictorial output above was produced by an extension of @Ken Syme's code:

```python
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit

SEED = 43
SPLIT = 3

X_train = [0, 1, 2, 3, 4, 5, 6, 7, 8]
y_train = [0, 0, 0, 0, 0, 0, 1, 1, 1]   # note 6, 7, 8 are labelled class '1'

# random_state is only passed when shuffle=True: with shuffle=False it has no
# effect, and scikit-learn >= 0.24 raises a ValueError if it is set anyway.
print("KFold, shuffle=False (default)")
kf = KFold(n_splits=SPLIT)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("KFold, shuffle=True")
kf = KFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedKFold, shuffle=False (default)")
skf = StratifiedKFold(n_splits=SPLIT)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("StratifiedKFold, shuffle=True")
skf = StratifiedKFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedShuffleSplit")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=3)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedShuffleSplit (can customise test_size)")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=2)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
```