创新互联www.cdcxhl.cn八线动态BGP香港云服务器提供商,新人活动买多久送多久,划算不套路!
这篇文章主要介绍pytorch随机采样SubsetRandomSampler()的方法,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!
这篇文章记录一个采样器都随机地从原始的数据集中抽样数据。抽样数据采用permutation。 生成任意一个下标重排,从而利用下标来提取dataset中的数据的方法
需要的库
import torch
使用方法
这里以MNIST举例
train_dataset = dsets.MNIST(root='./data', #文件存放路径 train=True, #提取训练集 transform=transforms.ToTensor(), #将图像转化为Tensor download=True) sample_size = len(train_dataset) sampler1 = torch.utils.data.sampler.SubsetRandomSampler( np.random.choice(range(len(train_dataset)), sample_size))