参与拿奖:本文已参与「新人创作礼」活动,一起开启掘金创作之路
1.下载数据集
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 转换方式,将图片数据转换为张量
dataset_transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 如果本地不存在数据集就下载数据集,由参数download决定
tran_set=torchvision.datasets.CIFAR10(root="./data/02 data",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./data/02 data",train=False,transform=dataset_transform,download=True)
# print(test_set[0])
writer=SummaryWriter("logs01")
for i in range(10):
img,target=test_set[i]
writer.add_image("test_set",img,i)
writer.close()
在终端找到logs01所在文件的目录,运行:tensorboard --logdir=logs1
点开生成的链接之后,如下图所示:
1.1 查看数据
# 查看数据
print(test_set.classes)
img,target=test_set[0]
print(img)
print(target,test_set.classes[target])
2.使用DataLoader加载数据集
2.1 导入相关包
import torch
import torchvision
# dataset 原始数据集
# batch_size 每次抓取数据的大小
from torch.utils.data import DataLoader,RandomSampler,SequentialSampler
# 获取数据
test_data=torchvision.datasets.CIFAR10('./data/02 data',train=False,transform=torchvision.transforms.ToTensor())
2.2 使用DataLoader加载数据集
这里主要参数有dataset(需要加载的数据)、batch_size(每一次加载多少大小的数据)、sampler(采样的方法,主要分顺序采样、随机采样)、shuffle(随机抽取)、num_workrs(线程数目)、drop_last(是否去掉余数),其中顺序采样(SequentialSampler)为每次按照顺序加载batch_size大小的数据,随机采样(RandomSampler)每次随机加载batch_size大小的数据,两者都保证数据会被全部加载。
test_loader=DataLoader(dataset=test_data #数据集
,batch_size=4 #每次抓取数据的大小
,sampler=RandomSampler(test_data) # 随机选取
# ,shuffle=True # 随机选取
,num_workers=0
,drop_last=False #是否去掉余数
)
2.3 循环打印数据集
for data in test_loader: #每次取出batch_size大小的数据进行打包
imgs,targets=data
print(imgs.shape,targets)
2.4 使用tensorboard查看
from torch.utils.tensorboard import SummaryWriter
writer=SummaryWriter("dataloader")
step=0
for data in test_loader:
imgs,targets=data
writer.add_images("test_data",imgs,step)
step+=1
writer.close()
writer=SummaryWriter("dataloader")
for epoch in range(2): #抓取次数
step=0
for data in test_loader:
imgs,targets=data
writer.add_images("Epoch: {}".format(epoch),imgs,step)
step+=1
writer.close()
在终端找到dataloader所在文件的目录,运行:tensorboard --logdir=dataloader
打开链接得到如下图所示的结果:
参考资料
[1] b站课程链接
[2] 手敲代码链接