pytorch中Dataset和DataLoader的使用

154 阅读2分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第1天,点击查看活动详情

1.datasets下载数据集

image.png

root :代表着路径,表示现存或者准备存储的地方。

train :代表是否下载训练数据集,如果否的话就下载测试数据集

transform: 如果想对数据集进行什么变化,在这里进行操作

target_transform:跟上面的一样

download:如果是True就会从网上下载数据集

这里使用datasets进行下载操作

import torchvision
train_set=torchvision.datasets.CIFAR10(root='../dataset',train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root='../dataset',train=False,download=True)

image.png 也可以将网址复制下来后在迅雷等软件上下载,但是在导入的时候需要注意一点:导入的是解压前的文件,解压后再导入不仅读取不了,而且还会报错然后重新下载。

image.png 除了这个数据集之外,datasets还提供很多其他数据集。

2.dataloader

首先了解官方文档的描述。

image.png

dataset: 这里就是指将要进行操作的数据集,也可以直接使用上面dataset下载好的数据集进行使用。

batch_size :这里指每次操作的数据个数,每次下载两个或者多个等都可以进行设置,默认情况为1

shuffle;指打乱的意思将数据进行打乱

num_workers:指使用多进程还是单进程,多进程自然会比较快

import torchvision
from torch.utils.data import DataLoader

test_set=torchvision.datasets.CIFAR10(root='./datasets',train=False,transform=torchvision.transforms.ToTensor())
test_loader=DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
img,target=test_set[0]
print(img.shape)
print(target)

image.png 我们使用for循环来对data_loader中的数据进行抓取。

import torchvision
from torch.utils.data import DataLoader

test_set=torchvision.datasets.CIFAR10(root='./datasets',train=False,transform=torchvision.transforms.ToTensor())
test_loader=DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
img,target=test_set[0]
print(img.shape)
print(target)
for data in test_loader:
    imgs,targets=data
    print(imgs.shape)
    print(targets)

这里需注意,没轮抓取的size是一样的,但是由于shuffle设定为True,所以每次抓取的图片是随机不一样的。

image.png

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_set=torchvision.datasets.CIFAR10(root='./datasets',train=False,transform=torchvision.transforms.ToTensor())
test_loader=DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
img,target=test_set[0]
print(img.shape)
print(target)
writer = SummaryWriter("666")
step=0
for data in test_loader:
    imgs,targets=data
    #print(imgs.shape)
    #print(targets)
    writer.add_images("test_set",imgs,step)
    step+=1
writer.close()

3.总结

了解了各个参数的含义后,就可以更加深入的理解官方文档给的案例操作.

image.png

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

熟练下载并打开几个案例后就可以更好的去掌握Dataset和Dataloader的使用了。