Go容器化微服务系统实战|完结无密

270 阅读3分钟

Go容器化微服务系统实战|完结无密

超清原画 完整无密 包括所有视频课件以及源码 MP4格式 获取资料:https://www.zxit666.com/1854/

莫凡Pytorch教程(五):Pytorch模型保管与提取

Pytorch模型保管与提取

本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何运用Pytorch保管和提取我们的神经网络。

在Pytorch中,网络的存储主要运用torch.save函数来完成。

我们将经过两种方式展现模型的保管和提取。 第一种保管方式是保管整个模型,在重新提取时直接加载整个模型。第二种保管办法是只保管模型的参数,这种方式只保管了参数,而不会保管模型的构造等信息。

两种方式各有优缺陷。

  • 保管完好模型不需求晓得网络的构造,一次性保管一次性读入。缺陷是模型比拟大时耗时较长,保管的文件也大。
  • 而只保管参数的方式存储快捷,保管的文件也小一些,但缺陷是丧失了网络的构造信息,恢复模型时需求提早树立一个特定构造的网络再读入参数。

以下运用代码展现。

数据生成与展现

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
复制代码

这里还是生成一组带有噪声的

y=x^{2}

y=x2数据停止回归拟合。

# torch.manual_seed(1)    # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
复制代码

根本网络搭建与保管

我们运用nn.Sequential模块来快速搭建一个网络完成回归操作,网络由两层Linear层和中间的激活层ReLU组成。我们设置输入输出的维度为1,中间躲藏层变量的维度为10,以加快锻炼。

这里运用两种方式停止保管。

def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for step in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_save.png")
        
    torch.save(net1, 'net.pkl')                        # entire network
    torch.save(net1.state_dict(), 'net_params.pkl')    # parameters
复制代码

在这个save函数中,我们首先运用nn.Sequential模块构建了一个根底的两层神经网络。然后对其停止锻炼,展现锻炼结果。之后运用两种方式停止保管。

第一种方式直接保管整个网络,代码为

torch.save(net1, 'net.pkl')                        # entire network
复制代码

第二种方式只保管网络参数,代码为

torch.save(net1.state_dict(), 'net_params.pkl')    # parameters
复制代码

对保管的模型停止提取恢复

这里我们为两种不同存储方式保管的模型分别定义恢复提取的函数 首先是对整个网络的提取。直接运用torch.load就能够,无需其他额外操作。

def restore_net():
    # 提取神经网络
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_net.png")
复制代码

而关于参数的读取,我们首先需求先搭建好一个与之前保管的模型相同架构的网络,然后运用这个网络的load_state_dict办法停止参数读取和恢复。以下展现了运用参数方式读取网络的示例:

def restore_params():
    # 提取神经网络
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_para.png")
    plt.show()
复制代码

比照不同提取办法的效果

接下来我们比照一下这两种办法的提取效果

# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()