持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第20天,点击查看活动详情
之前讲了一下pytorch怎么读入数据集,怎么使用DataLoader加载数据集。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
上边这个代码是读取FashionMNIST数据集的训练集和测试集。
root设置了一个data文件,代码运行之后数据集就会存到该文件夹下;
download是如果本地data指定的文件夹中没有数据集是否要自动下载;
重点就是这里的transform字段,是对数据集转化为张量(Tensor),本节我们就详细说一下transform
详细内容可以回去看看这个文章:PyTorch数据集处理 - 掘金 (juejin.cn)
训练机器学习算法时候,我们的数据读进来不一定是模型可以直接处理的类型,这里讲一下我们如何使用Transforms来对数据执行一些操作,并使其适合于模型训练过程。
所有的TorchVision的数据集都有两个参数:
-
transform:用于修改数据集的特征(features) -
target_transform:用于修改数据集的标签(labels)
torchvision.transforms为我们提供了一些开箱即用的常用转换方法。
我们还是用FashionMNIST接着说,FashionMNIST的数据类型读进来默认是PIL格式的,标签是整形。在训练过程中我们需要将PIL格式的图片转化为张量(tensor),将其标签转化为one-hot向量。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
transform
transform=ToTensor() 将PIL格式的图或者ndarray转化为FloatTensor,在这里会做一个scale,将图像的像素值转化到[0. ,1.]。
target_transform
对于标签转化我们可以使用 Lambda,使用该方法可以进行用户自定义的格式转化。这里我们要做的就是将整数转化为one-hot向量。
target_transform = Lambda(
lambda y: torch.zeros(10, dtype=torch.float)
.scatter_(0, torch.tensor(y), value=1))
我们的数据集是分为是个类别的,所以第一步声明向量长度为10的零向量torch.zeros。
然后使用scatter_方法,根据label y将零向量的对应位置的0设置为1。