李宏毅-助教Pytorch-快速上手

65 阅读2分钟

一、加载数据

1. 自定义 Dataset 类

class MyDataset(Dataset):
    def __init__(self, file): self.data = ...  # 读取/预处理数据
    def __getitem__(self, index): return self.data[index]  # 按索引取样本
    def __len__(self): return len(self.data)  # 返回数据集长度
  • 核心作用:让 PyTorch 识别自定义数据(如本地 CSV / 图片);

  • 继承原因:遵循 Dataset 抽象接口,兼容 PyTorch 数据加载生态。

2. DataLoader 分组

dataloader = DataLoader(dataset, batch_size=5, shuffle=False)
  • 核心作用:批量加载数据(batch_size 控制每批样本数),shuffle 控制是否打乱;

  • 附加价值:自动多线程加载,提升效率。

3. 张量(Tensors)

代码片段核心作用
torch.tensor/zeros/ones创建多维数组(PyTorch 核心数据结构)
sum/mean/transpose/squeeze/unsqueeze张量数值计算、维度调整
requires_grad=True + backward()开启梯度追踪,反向传播计算参数梯度
x.to('cuda')张量移到 GPU,实现加速计算

二、定义模型

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(10,32), nn.Sigmoid(), nn.Linear(32,1))
    def forward(self, x): return self.net(x)
  • 核心作用:定义神经网络结构(线性层 + 激活函数),forward 指定前向传播逻辑;
  • 继承原因:复用 nn.Module 封装的参数管理、反向传播、设备迁移等核心功能。

三、损失函数 & 优化器

代码片段核心作用
nn.MSELoss()/CrossEntropyLoss()定义损失函数(衡量预测与真实值差距)
torch.optim.SGD(model.parameters(), lr=0.1)优化器:根据梯度调整模型参数,lr 控制学习步长

四、训练 / 验证 / 测试

1. 训练循环

for epoch in range(n_epochs):
    model.train()  # 训练模式(启用Dropout/BatchNorm)
    for x, y in tr_set:
        optimizer.zero_grad()  # 清空梯度(避免累加)
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()  # 反向传播算梯度
        optimizer.step()  # 更新参数
  • 核心逻辑:前向算损失→反向算梯度→优化器调参数,循环迭代。

2. 验证 / 测试

model.eval()  # 评估模式(关闭Dropout/BatchNorm)
with torch.no_grad():  # 禁用梯度计算,加速+省内存
    pred = model(x)  # 仅前向传播,不更新参数
  • 核心目的:验证模型泛化能力,测试最终预测效果。

3. 模型保存 / 加载

torch.save(model.state_dict(), path)  # 存参数
model.load_state_dict(torch.load(path))  # 加载参数
  • 核心作用:保存训练好的模型参数,方便后续复用 / 部署。

总结

  1. 继承核心:Dataset/nn.Module 是 PyTorch 的基础接口,继承可复用生态功能;

  2. 训练核心逻辑:前向传播算损失→反向传播算梯度→优化器更新参数;

  3. 关键细节:训练 / 评估模式切换、梯度清空、设备统一(数据 / 模型同 CPU/GPU)。