pytorch 之 DataSets、DataLoaders & Build Model

158 阅读2分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第1天,点击查看活动详情

本文是 pytorch 官方教程学习笔记,包括 DataSets&DataLoaders 和 Build Model 部分

DataSets&DataLoaders官方教程
Build Model官方教程

DataSets & DataLoaders

下载数据集

from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="E:\Downloads\train",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="E:\Downloads\test",
    train=False,
    download=True,
    transform=ToTensor()
)

root 表示数据集存放路径,train 标识训练集/测试集,download 表示是否从网上下载,transform 指数据集格式转换方式

DataSet

针对现有数据集,往往需要先定义一个 DataSet 类,该类必须包含 __init__() __len__() __getitem__ 三个函数,分别负责加载数据集、返回数据条数、根据索引返回数据

DataLoader

DataSet 每次返回一组数据和标签,但是在训练模型时往往采用 minibatch 的方式,每个 epoch 重洗(shuffle)数据防止过拟合,因此便引入 DataLoader
样例如下,DataLoader 的第一个参数为 DataSet 类的对象

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Build Model

定义神经网络类

from torch import nn
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_relu_stack(x)
        return x

自己定义的网络模型必须是 nn.Module 的子类,在 __init__ 中定义网络结构,在 forward 中定义一次前向传播应执行的操作
上述代码中,nn.Flatten() 将图像每个通道的二维数据拉平成一维,第 0 维数据即通道数保持不变;nn.Sequential() 定义对数据的按序一连串操作,其中 nn.Linear() 执行线性变换,nn.ReLU() 是激活函数,负责添加非线性因子

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model = NeuralNetwork().to(device)
print(model)

创建一个模型对象,输出其结构如下图所示

image.png

测试网络

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
print(logits)
pred_probab = nn.Softmax(dim=1)(logits)
print(pred_probab)
y_pred = pred_probab.argmax(1)
print(y_pred)

上述测试程序输出结果如下图所示

image.png 模型输入 X 结构为 1*28*28,输出 logits 结构为 1*10
nn.Softmax(dim=1) 表示在第一维做 Softmax,tensor.argmax(1) 同理,表示在第一维求 argmax