Pytorch lightning 快速入门(上)

592 阅读3分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第31天,点击查看活动详情

pytorch lightning

之前接触过 pytorch lightning,不过没有时间深入了解一下,最近做一些项目看别人用到 pytorch lightning 感觉不错,所以也想学学。Pytorch-Lighting 的一大特点是把模型和系统分开来看。 在正式开始介绍 pytorch lightning 之前,我们先用 pytorch 实现 MNIST 识别,然后接下来将这个项目用 pytorch lightning 重新写一下,来体验一下 pytorch ligntning 给我们带来开发便利性和规范性

import torch
from torch import nn
from torch import optim

from tqdm.notebook import tqdm

定义模型

# 定义模型
model = nn.Sequential(
    nn.Linear(28*28,64),
    nn.ReLU(),
    nn.Linear(64,64),
    nn.ReLU(),
    nn.Linear(64,10)
)

这里我们简单定义一个网络,主要是为了介绍 pytorch lightning ,这个网络主要包含 2 个隐藏层,都是使用了全连接,然后添加 ReLU 激活函数来做非线性变换。

优化器

  • 优化器这里选择的事 SGD 优化器,优化器接受模型参数和学习率作为参数
# 定义优化器
params = model.parameters()
optimizer = optim.SGD(params,lr=1e-2)

定义损失函数

# 定义损失函数
loss_fn = nn.CrossEntropyLoss()

定义训练

epochs = 5

for epoch in range(epochs):
    losses = []
    for batch in tqdm(train_loader):
        x,y = batch
#         print(x.shape,y.shape)
        # x:(batch_size x 1 x 28 x 28)
        b = x.size(0)
        # 这里 -1 表示在 batch_size 后面的几个维度进行合并 1 x 28 x 28
        x = x.view(b,-1)
        
        # logits forward
        logits = model(x)
        
        # 2 计算目标函数
        loss = loss_fn(logits,y)
        
        # 3 初始化梯度 这里也可以 optimizer 来将梯度进行清空
        model.zero_grad()
        #params.grad.zero_()
        
        # 4 反向传播参数梯度
        loss.backward()
        
        # 更新参数
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch {epoch+1},train loss:{torch.tensor(losses).mean():.2f}")
    
    losses = []
    for batch in tqdm(val_loader):
        x,y = batch
        
        b = x.size(0)
        x = x.view(b,-1)

        
        with torch.no_grad():
            logits = model(x)
        
        # 2 计算目标函数
        loss = loss_fn(logits,y)
        
        
        losses.append(loss.item())
    print(f"Epoch {epoch+1},validation loss:{torch.tensor(losses).mean():.2f}")
        

加载数据

from torchvision import datasets,transforms
from torch.utils.data import random_split,DataLoader
train_data = datasets.MNIST('data',train=True,download=True,transform=transforms.ToTensor())
train, val = random_split(train_data,[55000,5000])
train_loader = DataLoader(train,batch_size=2)
val_loader = DataLoader(val,batch_size=2)

ResNet

为了提升网络准确度,这里引入残差结构,并且引入 dropout 来解决过拟合的问题。

class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28*28,64)
        self.l2 = nn.Linear(64,64)
        self.l3 = nn.Linear(64,10)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self,x):
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        dropout = self.dropout(h2 + h1)
        logits = self.l3(dropout)
        return logits
    
model = ResNet()
print(f"Epoch {epoch+1}",end=",")
print(f"validation loss:{torch.tensor(losses).mean():.2f}",end=",")
print(f"validation accuracy:{torch.tensor(accs).mean():.2f)}")

引入 pytorch lighting

import pytorch_lightning as pl
  • 定义模型
  • 定义优化器
  • 数据
  • 定义训练
  • 验证

定义模型

class ResNet(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28*28,64)
        self.l2 = nn.Linear(64,64)
        self.l3 = nn.Linear(64,10)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self,x):
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        dropout = self.dropout(h2 + h1)
        logits = self.l3(dropout)
        return logits
    
model = ResNet()

注意这里是继承于 pl.LightningDataModule 不再是 nn.Module 这个模块类

定义优化器

在刚刚定义 ResNet 类中可以定义 configure_optimizer 方法然后这里实例化一个优化器,并且将其返回,注意因为是在 ResNet 类中定义该方法,所以可以直接通过 self 获取参数

    def configure_optimizer(self):
        optimizer = optim.SGD(self.parameters(),lr=1e-2)
        return optimizer

定义训练

在数值函数中可以初始化损失函数

self.loss_fn = nn.CrossEntropyLoss()
    def training_step(self,batch,batch_idx):
        x,y = batch
        # x:(batch_size x 1 x 28 x 28)
        b = x.size(0)
        # 这里 -1 表示在 batch_size 后面的几个维度进行合并 1 x 28 x 28
        x = x.view(b,-1)
        
        # logits forward
        logits = model(x)
        
        # 2 计算目标函数
        loss = self.loss_fn(logits,y)

加载数据

    def train_dataloader(self):
        train_data = datasets.MNIST('data',train=True,download=True,transform=transforms.ToTensor())
        train_loader = DataLoader(train_data,batch_size=2)
        
        return train_loader
    def backward(self, trainer, loss, optimizer, optimizer_idx):
        loss.backward()