Taking subsets of a PyTorch dataset (tags: python)

Taking subsets of a pytorch dataset


`torch.utils.data.Subset` is simpler: it works with `shuffle=True` and doesn't require writing your own sampler:

import torch
import torchvision

# The __main__ guard matters because num_workers > 0: on platforms where
# DataLoader workers are started with "spawn" (e.g. Windows, macOS), every
# worker re-imports this script, and unguarded top-level code would run
# again inside each worker.
if __name__ == "__main__":
    # ToTensor() is needed for batching: with transform=None the dataset
    # yields PIL images, which the DataLoader's default collate function
    # cannot stack into a batch.
    trainset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Two disjoint halves of the training set: even and odd indices.
    evens = list(range(0, len(trainset), 2))
    odds = list(range(1, len(trainset), 2))
    # Each Subset exposes only the listed indices of trainset.
    trainset_1 = torch.utils.data.Subset(trainset, evens)
    trainset_2 = torch.utils.data.Subset(trainset, odds)

    # shuffle=True randomizes the order within each subset.
    trainloader_1 = torch.utils.data.DataLoader(
        trainset_1, batch_size=4, shuffle=True, num_workers=2
    )
    trainloader_2 = torch.utils.data.DataLoader(
        trainset_2, batch_size=4, shuffle=True, num_workers=2
    )


You can define a custom sampler for the DataLoader, which avoids recreating the dataset: you only create a new loader for each different sampling.

import torch
from torch.utils.data import Sampler


class YourSampler(Sampler):
    """Yield the dataset indices selected by a boolean mask.

    Args:
        mask: one entry per dataset item. Truthy entries are sampled, in
            ascending index order. Accepts a Python list or a tensor
            (bool, or 0/1 integers).
    """

    def __init__(self, mask):
        # Normalize to a bool tensor so lists of bools, 0/1 ints and uint8
        # tensors all select the same way.
        self.mask = torch.as_tensor(mask, dtype=torch.bool)

    def __iter__(self):
        # nonzero(..., as_tuple=True)[0] is a flat tensor of selected
        # positions; .tolist() turns them into plain ints for dataset indexing.
        indices = torch.nonzero(self.mask, as_tuple=True)[0].tolist()
        return iter(indices)

    def __len__(self):
        # Number of indices actually yielded, not the length of the mask.
        return int(self.mask.sum())


# The __main__ guard keeps the demo from re-running in DataLoader worker
# processes (num_workers > 0) on platforms that start them with "spawn".
if __name__ == "__main__":
    import torchvision  # only needed for the demo below

    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    # Example masks (replace with your own): even vs. odd indices.
    your_mask = torch.arange(len(trainset)) % 2 == 0
    your_other_mask = ~your_mask

    sampler1 = YourSampler(your_mask)
    sampler2 = YourSampler(your_other_mask)
    # shuffle must stay False when a sampler is given: the sampler decides
    # the order (DataLoader rejects sampler together with shuffle=True).
    trainloader_sampler1 = torch.utils.data.DataLoader(
        trainset, batch_size=4, sampler=sampler1, shuffle=False, num_workers=2
    )
    trainloader_sampler2 = torch.utils.data.DataLoader(
        trainset, batch_size=4, sampler=sampler2, shuffle=False, num_workers=2
    )

PS: You can find more info here: http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler