Transform一行代码抢救数据集数据类型

162 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第20天,点击查看活动详情


之前讲了一下pytorch怎么读入数据集,怎么使用DataLoader加载数据集。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

上边这个代码是读取FashionMNIST数据集的训练集和测试集。

root设置了一个data文件,代码运行之后数据集就会存到该文件夹下;

download是如果本地data指定的文件夹中没有数据集是否要自动下载;

重点就是这里的transform字段,是对数据集转化为张量(Tensor),本节我们就详细说一下transform

详细内容可以回去看看这个文章:PyTorch数据集处理 - 掘金 (juejin.cn)


训练机器学习算法时候,我们的数据读进来不一定是模型可以直接处理的类型,这里讲一下我们如何使用Transforms来对数据执行一些操作,并使其适合于模型训练过程。

所有的TorchVision的数据集都有两个参数:

  • transform:用于修改数据集的特征(features)

  • target_transform:用于修改数据集的标签(labels)

torchvision.transforms为我们提供了一些开箱即用的常用转换方法。

我们还是用FashionMNIST接着说,FashionMNIST的数据类型读进来默认是PIL格式的,标签是整形。在训练过程中我们需要将PIL格式的图片转化为张量(tensor),将其标签转化为one-hot向量。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

transform

transform=ToTensor() 将PIL格式的图或者ndarray转化为FloatTensor,在这里会做一个scale,将图像的像素值转化到[0. ,1.]。

target_transform

对于标签转化我们可以使用 Lambda,使用该方法可以进行用户自定义的格式转化。这里我们要做的就是将整数转化为one-hot向量。

target_transform = Lambda(
                            lambda y: torch.zeros(10, dtype=torch.float)
                            .scatter_(0, torch.tensor(y), value=1))

我们的数据集是分为是个类别的,所以第一步声明向量长度为10的零向量torch.zeros。 然后使用scatter_方法,根据label y将零向量的对应位置的0设置为1。