导入相应的库和模块:
%matplotlib inline #让图表显示在单元格下方
import torch
import torchvision #导入pytorch的视觉工具箱
from torch.utils import data #导入数据处理模块
from torchvision import transforms #导入图像预处理模块
from d2l import torch as d2l
d2l.use_svg_display() #将图表设置为svg格式
获取数据源:
- 创建一个把图片转成张量的转换器
- 下载并加载Fashion-MNIST训练集和测试集
- 显示训练集和测试集的长度
- 检查训练集第一张图片的形状
trans = transforms.ToTensor() #创建一个能把图片转成张量的转换器
mnist_train = torchvision.datasets.FashionMNIST( #下载并加载Fashion-MNIST训练集
root = "./data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root = "./data",train=False,transform=trans,download=True)
len(mnist_train),len(mnist_test)
mnist_train[0][0].shape #查看训练集第一张图像的形状
定义图像处理工具:
- get_fashion_mnist_labels函数:
- 作用:把数字标签转换成文本标签
- 实现过程:定义类别列表,遍历数字标签并转换成整型,在类别列表中找到相应的文本标签,把它返回。
- show_images函数:
- 作用:在网格中显示多张图片
- 传入参数:图片、行数、列数、图像比例
- 实现过程:
- 计算并设定整个网格图像的尺寸
- 创建子图并返回子图列表
- 把多维数组平铺成一维列表
- 每对打包成一张子图和一张图片,并编上序号
- 若图片是张量,就转成numpy数组并显示;若不是张量,直接显示
- 隐藏子图的x,y轴
- 把子图标题设置成文本标签
- 返回绘制好的子图和图片的网格画布
def get_fashion_mnist_labels(labels):
text_labels=[
't-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']
return [text_labels[int(i)] for i in labels] #把数字标签转换成文本标签
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5): #定义一个用来在网格里显示多张图片的函数
figsize = (num_cols * scale,num_rows * scale) #计算并设定整个网格图像的尺寸
_ , axes = d2l.plt.subplots(num_rows,num_cols,figsize = figsize) #创建子图并返回子图列表
axes = axes.flatten() #把多维数组平铺成一维列表
for i ,(ax,img) in enumerate(zip(axes,imgs)): #每对打包成一张子图和一张图片,并编上序号
if torch.is_tensor(img): #若图片是张量,就转成numpy数组并显示
ax.imshow(img.numpy())
else: #若不是张量,直接显示
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False) #隐藏子图的x,y轴
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i]) #把子图标题设置成文本标签
return axes #返回绘制好的子图和图片的网格画布
图像批处理:
- 取出第一批18张图像存为X,标签存为y
- 按2行9列的方式画出图片,尺寸是28×28,标题是文本标签
X,y = next(iter(data.DataLoader(mnist_train,batch_size = 18))) #取出第一批18张图像存为X,标签存为y
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y)) #按2行9列的方式画出图片,标题是文本标签
并行处理并计时:
- get_dataloader_workers函数,定义并行处理的进程数是4
- 把训练数据分批打乱,用4个进程并行加速
- 启动计时器,用空循环把数据集完整的遍历,预热数据并计算完整的耗时
- 打印完整耗时
batch_size = 256
def get_dataloader_workers():
return 4 #4个并行处理
train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()) #把训练数据分批打乱,用4个进程并行加速
timer = d2l.Timer() #启动计时器
for X,y in train_iter: #用空循环把数据集完整的遍历,预热数据并计算完整的耗时
continue
f'{timer.stop():.2f} sec' #打印完整耗时
数据的正式处理:
- load_data_fashion_mnist函数:
- 作用:加载和打包训练数据和测试数据
- 实现过程:
- 创建只包含张量的列表,若传入resize参数,则缩放图片
- 把张量列表变成流水线,以流水线的方式加载训练集和测试集
- 一次性创建好训练和测试两个迭代器,返回给调用者
def load_data_fashion_mnist(batch_size,resize=None): #定义能加载和打包训练和测试数据的函数
trans = [transforms.ToTensor()] #创建只包含张量的列表
if resize: #若传入resize,则缩放图片
trans.insert(0,transforms.Resize(resize))
trans = transforms.Compose(trans) #把张量列表变成流水线
mnist_train = torchvision.datasets.FashionMNIST( #以流水线的方式加载训练集和测试集
root="./data",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="./data",train=False,transform=trans,download=True)
return (data.DataLoader(mnist_train,batch_size,shuffle=True, #一次性创建好训练和测试两个迭代器,返回给调用者
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test,batch_size,shuffle=True,
num_workers=get_dataloader_workers()))