【技术专题】PyTorch2 深度学习 - 自动微分(Autograd)与梯度优化 & 数据集与数据加载 & transform预处理转换模块

0 阅读5分钟

大家好,我是锋哥。最近连载更新《PyTorch2 深度学习》技术专题。

image.png 本课程主要讲解基于PyTorch2的深度学习核心知识,主要讲解包括PyTorch2框架入门知识,环境搭建,张量,自动微分,数据加载与预处理,模型训练与优化,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。。 同时也配套视频教程 《PyTorch 2 Python深度学习 视频教程》

自动微分(Autograd)与梯度优化

在PyTorch2中, 自动微分(Autograd)机制, 是 PyTorch 的核心功能之一,用于自动计算张量的导数(梯度)。

它的主要用途是:在神经网络反向传播过程中自动计算参数的梯度

在 PyTorch 中,只要一个张量的属性 requires_grad=True,系统就会跟踪它的所有运算,从而可以在反向传播时自动求出梯度。

基本原理

  • 计算图(Computational Graph) : PyTorch 会动态构建一张有向无环图(DAG),图的节点是张量,边是函数(如加法、乘法等)。 反向传播时,PyTorch 会沿着这张图从输出向输入依次计算梯度。
  • 反向传播(Backpropagation) : 调用 loss.backward() 时,PyTorch 会自动计算所有参与计算的 requires_grad=True 张量的梯度。
  • 梯度存储: 计算出的梯度会存放在每个张量的 .grad 属性中。

简单示例

import torch
​
# 创建一个张量并启用自动求导
x = torch.tensor(3.0, requires_grad=True)
​
# 构建一个函数 y = x^2
y = x ** 2# 自动求导(反向传播)
y.backward()
​
# 查看梯度 dy/dx
print(x.grad)  # 输出:tensor(6.)
print(x.grad.item())

运行输出:

tensor(6.)
6.0

神经网络训练中使用 Autograd

import torch
from torch import nn, optim
​
# 1,构造训练数据:y=2x+1
x = torch.linspace(-5, 5, 100).unsqueeze(1)  # 100的样本,维度[100,1]
print(x, x.shape)
y = 2 * x + 1 + torch.randn(x.size())  # 添加噪声# 2,定义简单的线性模型
model = nn.Linear(1, 1)
​
# 3, 定义损失函数与优化器
criterion = nn.MSELoss()  # 均方误差
optimizer = optim.SGD(model.parameters(), lr=0.01)
​
# 4,训练模型
epochs = 2000
for epoch in range(epochs):
    y_pred = model(x)  # 前向传播
    loss = criterion(y_pred, y)  # 计算损失
    optimizer.zero_grad()  # 清空梯度
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
​
    print(f'epoch: {epoch}, loss: {loss.item()}')
​
# 5,查看结果
[w, b] = model.parameters()
print(f'训练结果:w: {w}, b: {b}')

流程说明:

  1. forward() 前向传播,构建计算图
  2. loss.backward() 反向传播,自动求出参数梯度
  3. optimizer.step() 更新模型参数

数据集与数据加载

在 PyTorch 的训练流程中,数据读取与预处理 通常分为两部分:

  1. Dataset(数据集类) 负责定义样本获取方式,即“如何读一条数据”。
  2. DataLoader(数据加载器) 负责批量加载与并行加速,即“如何读多条数据”。

这两者的配合实现了高效的数据输入管线。

PyTorch 领域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集可以子类化torch.utils.data.Dataset并实现特定于特定数据的函数。它们可用于对模型进行原型设计和基准测试。

1,加载数据

以下是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 商品图片的数据集,包含 60,000 个训练样本和 10,000 个测试样本。每个样本包含一张 28×28 的灰度图像以及 10 个类别中对应类别的标签。

Dataset 是一个抽象类。 自定义数据集时需重写两个关键方法:

方法作用
__len__(self)返回数据集中样本数量
__getitem__(self, index)根据索引返回单个样本 (data, label)

示例代码:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
​
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True
)
print('训练集:')
print(training_data.__len__())
print(training_data.__getitem__(0))
print(training_data.targets)
​
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True
)
​
print('测试集:')
print(test_data.__len__())
print(test_data.__getitem__(0))

运行后,下载数据集到相对目录

image.png

运行输出:

训练集:
60000
(<PIL.Image.Image image mode=L size=28x28 at 0x1601FE917D0>, 9)
tensor([9, 0, 0,  ..., 3, 0, 5])
测试集:
10000
(<PIL.Image.Image image mode=L size=28x28 at 0x1602387ED90>, 9)

2,遍历和可视化数据

Datasets我们可以像列表一样手动索引: training_data[index]。我们用它matplotlib来可视化训练数据中的某些样本。

我们先安装下matplotlib,和 jupyter

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install jupyter -i https://pypi.tuna.tsinghua.edu.cn/simple

示例:

import matplotlib.pyplot as plt
​
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img, cmap="gray")
plt.show()

运行输出:

image.png

3,使用 DataLoaders 准备训练

Dataset会检索数据集的特征,并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个周期重新调整数据以减少模型过拟合,并使用 Pythonmultiprocessing来加速数据检索。

DataLoader是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。

DataLoader是PyTorch中用于批量加载数据的工具类,它将training_data数据集按照指定的batch_size=64进行分批处理,并通过shuffle=True参数在每个训练周期开始时随机打乱数据顺序,以提高模型训练效果。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

transform预处理转换模块

PyTorch 2 的 transform 模块 —— 它是图像预处理与增强中非常核心的部分。

🧠 一、transform 是什么?

PyTorch 中,尤其是使用 torchvision 进行图像任务时,数据的输入通常需要经过预处理才能喂入神经网络。 这些预处理操作(如缩放、裁剪、归一化、数据增强等)就是通过 torchvision.transforms 模块实现的。

PyTorch 2 中该模块更加灵活、可组合,支持 PIL 图像、Tensor、NumPy 数组 等多种格式。


🧩 二、torchvision.transforms 的主要功能分类

功能类别常用 Transform作用说明
图像格式转换ToTensor(), ToPILImage()PIL ↔ Tensor 互转
几何变换Resize(), CenterCrop(), RandomCrop(), RandomRotation(), RandomHorizontalFlip()改变图像尺寸、角度、位置等
颜色变换ColorJitter(), Grayscale(), RandomAdjustSharpness()调整亮度、对比度、饱和度等
数据增强RandomResizedCrop(), RandomAffine()随机扰动图像,提高模型泛化能力
数值标准化Normalize(mean, std)将像素值标准化,提升训练稳定性
组合操作transforms.Compose([...])将多个变换按顺序组合

⚙️ 三、基本使用示例

我们把上一节的实例改下:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms.v2 import ToTensor
​
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()  # 将图像的像素强度值缩放到 [0., 1.] 范围内 归一化
)
print('训练集:')
print(training_data.__len__())
print(training_data.__getitem__(0))
print(training_data.targets)

运行输出:

训练集:
60000
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,          0.0000, 0.0039, 0.0039, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,          0.0157, 0.0000, 0.0000, 0.0118],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,          0.0000, 0.0471, 0.0392, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,          0.3020, 0.5098, 0.2824, 0.0588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,          0.5529, 0.3451, 0.6745, 0.2588],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,          0.4824, 0.7686, 0.8980, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,          0.8745, 0.9608, 0.6784, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,          0.8627, 0.9529, 0.7922, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,          0.8863, 0.7725, 0.8196, 0.2039],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,          0.9608, 0.4667, 0.6549, 0.2196],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,          0.8510, 0.8196, 0.3608, 0.0000],
         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,          0.8549, 1.0000, 0.3020, 0.0000],
         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,          0.8784, 0.9569, 0.6235, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,          0.9137, 0.9333, 0.8431, 0.0000],
         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,          0.8627, 0.9098, 0.9647, 0.0000],
         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,          0.8706, 0.8941, 0.8824, 0.0000],
         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,          0.8745, 0.8784, 0.8980, 0.1137],
         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,          0.8627, 0.8667, 0.9020, 0.2627],
         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,          0.7098, 0.8039, 0.8078, 0.4510],
         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,          0.6549, 0.6941, 0.8235, 0.3608],
         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,          0.7529, 0.8471, 0.6667, 0.0000],
         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,          0.3882, 0.2275, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,          0.0000, 0.0000, 0.0000, 0.0000]]]), 9)
tensor([9, 0, 0,  ..., 3, 0, 5])