莫凡Pytorch教程(五):Pytorch模型保存与提取

1,818 阅读4分钟

一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第26天,点击查看活动详情

前言

本文为我之前在CSDN平台上的一篇博客记录。原链接为:blog.csdn.net/u011426236/…

Pytorch模型保存与提取

本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何使用Pytorch保存和提取我们的神经网络。

在Pytorch中,网络的存储主要使用torch.save函数来完成。

我们将通过两种方式展示模型的保存和提取。 第一种保存方式是保存整个模型,在重新提取时直接加载整个模型。第二种保存方法是只保存模型的参数,这种方式只保存了参数,而不会保存模型的结构等信息。

两种方式各有优缺点。

  • 保存完整模型不需要知道网络的结构,一次性保存一次性读入。缺点是模型比较大时耗时较长,保存的文件也大。
  • 而只保存参数的方式存储快捷,保存的文件也小一些,但缺点是丢失了网络的结构信息,恢复模型时需要提前建立一个特定结构的网络再读入参数。

以下使用代码展示。

数据生成与展示

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

这里还是生成一组带有噪声的y=x2y=x^{2}数据进行回归拟合。

# torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

基本网络搭建与保存

我们使用nn.Sequential模块来快速搭建一个网络完成回归操作,网络由两层Linear层和中间的激活层ReLU组成。我们设置输入输出的维度为1,中间隐藏层变量的维度为10,以加快训练。

这里使用两种方式进行保存。

def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for step in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_save.png")
        
    torch.save(net1, 'net.pkl')                        # entire network
    torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

在这个save函数中,我们首先使用nn.Sequential模块构建了一个基础的两层神经网络。然后对其进行训练,展示训练结果。之后使用两种方式进行保存。

第一种方式直接保存整个网络,代码为

torch.save(net1, 'net.pkl')                        # entire network

第二种方式只保存网络参数,代码为

torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

对保存的模型进行提取恢复

这里我们为两种不同存储方式保存的模型分别定义恢复提取的函数 首先是对整个网络的提取。直接使用torch.load就可以,无需其他额外操作。

def restore_net():
    # 提取神经网络
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_net.png")

而对于参数的读取,我们首先需要先搭建好一个与之前保存的模型相同架构的网络,然后使用这个网络的load_state_dict方法进行参数读取和恢复。以下展示了使用参数方式读取网络的示例:

def restore_params():
    # 提取神经网络
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_para.png")
    plt.show()

对比不同提取方法的效果

接下来我们对比一下这两种方法的提取效果

# save net1
save()

# restore entire net (may slow)
restore_net()

# restore only the net parameters
restore_params()

最后,得到的展示输出如下: save 这里一共展示了三种情况下的模型:Net1即我们已经训练好的网络,这里我们使用两种方式保存了Net1。使用第一种方式存储和提取整个模型的所有信息的结果为Net2,使用第二种方式存储和提取模型参数的结果为Net3。

通过对比可以看出,这三个网络一模一样,证明不同的存储提取方式的效果是相同的,不会有差异。只不过第二种方式提取时,我们要预先定义好与之前所有保存网络一致的模型结构。

参考

  1. 莫凡Python:Pytorch动态神经网络,mofanpy.com/tutorials/m…