先睹为快—pytorch lightning

252 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情

安装

在安装 pytorch lighting 之前需要先安装一下 pytorch

PyTorch 和 Pytorch Lighting 之间关系

现在模型变得越来越复杂,那么如果用 Pytorch 来实现这样复杂模型,也变得比较麻烦,例如如何使用 pytorch 来实现多 GPU 训练,16 位精度预测和如何在 TPU 上进行训练。这些都需要开发人员去花费心思去设计和调试。

PyTorch Lightning 的出现目的就是为了解决这些问题,而且帮助你有效地组织你的代码,屏蔽一些繁琐工作,让你的开发更有效,今天我们就去体验一下。

Pytorch Lightning 是由 Facebook AI research 的研究人员设计开发的,框架设计具有良好扩展性,而且支持当下一些前沿的技术,例如 TPU 训练。

分享目的

本次分享以构建一个简单用于 MNIST 数据集分类器,通过完整代码来解释说明如何通过 pytorch lightning 来快速构建这个分类器

import os

import pandas as pd
import seaborn as sn

import torch

from IPython.core.display import display

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
PATH_DATASETS = os.environ.get("PATH_DATASETS",".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

定义模型

这里模型是一个1 层全连接神经网络,输入是 28x28 输出一个 10 概率分布。模型定义了一个计算图,输入是 MNIST 数据图像数据,输出是这个图像数字可能是 0 -9 概率分布。

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28,10)
    
    def forward(self,x):
        return torch.relu(self.l1(x.view(x.size(0),-1)))
    
    def training_step(self, batch,batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x),y)
        return loss
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=0.02)
        

这里定义了一个类 MNISTModel 没有继承 nn.Module 而是继承了 LightningModule 这个 Pytorch Lightning 提供模,这个模块提供结构化代码,帮助你规范对模型定义

  • train_step 用于定义训练每次迭代对参数的更新
  • configure_optimizers 用于指定训练时采用优化器

将这些与model相关操作都可以在定义 model 进行指定