Pytorch模型定义

1,268 阅读1分钟

我报名参加金石计划1期挑战——瓜分10万奖池,这是我的第6篇文章,点击查看活动详情

一、PyTorch模型定义

image.png

  1. Sequential适用于快速验证结果,因为已经明确了要用哪些层,直接写一下就好了,不需要同时写__init__forward

  2. ModuleList和ModuleDict在某个完全相同的层需要重复出现多次时,非常方便实现,可以”一行顶多行“;

  3. 当我们需要之前层的信息的时候,比如 ResNets 中的残差计算,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList/ModuleDict 比较方便。

import torch.nn as nn
net = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net)
import collections
import torch.nn as nn
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)),
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))
print(net2)
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)

二、利用模型块快速搭建复杂网络

5.2.1unet.png

分块定义,组装。

三、PyTorch修改模型

  • 可以根据名称指定层进行修改
  • 可以添加外部输入
  • 添加额外输出

四、PyTorch模型保存与读取

PyTorch存储模型主要采用pkl,pt,pth三种格式。

1.普通模式


from torchvision import models
model = models.resnet152(pretrained=True)

# 保存整个模型
torch.save(model, save_dir)
# 保存模型权重
torch.save(model.state_dict, save_dir)

2.多卡模式

os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
model = model.cuda()  # 单卡
model = torch.nn.DataParallel(model).cuda()  # 多卡

五、思考

这一节概念较多,动手较少,需要注意的是要读懂常见的几种网络(非常重要),做到一通百通。