Go容器化微服务系统实战|完结无密
超清原画 完整无密 包括所有视频课件以及源码 MP4格式 获取资料:https://www.zxit666.com/1854/
莫凡Pytorch教程(五):Pytorch模型保管与提取
Pytorch模型保管与提取
本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何运用Pytorch保管和提取我们的神经网络。
在Pytorch中,网络的存储主要运用torch.save函数来完成。
我们将经过两种方式展现模型的保管和提取。 第一种保管方式是保管整个模型,在重新提取时直接加载整个模型。第二种保管办法是只保管模型的参数,这种方式只保管了参数,而不会保管模型的构造等信息。
两种方式各有优缺陷。
- 保管完好模型不需求晓得网络的构造,一次性保管一次性读入。缺陷是模型比拟大时耗时较长,保管的文件也大。
- 而只保管参数的方式存储快捷,保管的文件也小一些,但缺陷是丧失了网络的构造信息,恢复模型时需求提早树立一个特定构造的网络再读入参数。
以下运用代码展现。
数据生成与展现
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
复制代码
这里还是生成一组带有噪声的
y=x^{2}
y=x2数据停止回归拟合。
# torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
复制代码
根本网络搭建与保管
我们运用nn.Sequential模块来快速搭建一个网络完成回归操作,网络由两层Linear层和中间的激活层ReLU组成。我们设置输入输出的维度为1,中间躲藏层变量的维度为10,以加快锻炼。
这里运用两种方式停止保管。
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for step in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_save.png")
torch.save(net1, 'net.pkl') # entire network
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
在这个save函数中,我们首先运用nn.Sequential模块构建了一个根底的两层神经网络。然后对其停止锻炼,展现锻炼结果。之后运用两种方式停止保管。
第一种方式直接保管整个网络,代码为
torch.save(net1, 'net.pkl') # entire network
复制代码
第二种方式只保管网络参数,代码为
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
对保管的模型停止提取恢复
这里我们为两种不同存储方式保管的模型分别定义恢复提取的函数 首先是对整个网络的提取。直接运用torch.load就能够,无需其他额外操作。
def restore_net():
# 提取神经网络
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_res_net.png")
复制代码
而关于参数的读取,我们首先需求先搭建好一个与之前保管的模型相同架构的网络,然后运用这个网络的load_state_dict办法停止参数读取和恢复。以下展现了运用参数方式读取网络的示例:
def restore_params():
# 提取神经网络
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_res_para.png")
plt.show()
复制代码
比照不同提取办法的效果
接下来我们比照一下这两种办法的提取效果
# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()