Classes such as nn.ModuleList, nn.Sequential, and nn.ModuleDict (all subclasses of nn.Module) are called containers.
1. nn.ModuleList
ModuleList stores submodules in a list. The modules it holds have no connection to each other and no implied execution order (so adjacent entries' input/output dimensions need not match), and it does not implement forward — you have to write forward yourself.
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x
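As a quick sanity check (my own usage sketch, not from the original): every layer in MyModule maps 10 features to 10 features, so any input whose last dimension is 10 passes through unchanged in shape.

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

net = MyModule()
x = torch.randn(4, 10)
print(net(x).shape)  # torch.Size([4, 10])
```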
class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(5, 10)])

    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x)
        return x
The structure of net2:
net2(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): Linear(in_features=20, out_features=30, bias=True)
    (2): Linear(in_features=5, out_features=10, bias=True)
  )
)
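Because forward — not the order of the list — decides execution order, the input here must match linears[2], not linears[0]. A runnable sketch I added to make this concrete:

```python
import torch
import torch.nn as nn

class net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(5, 10)])

    def forward(self, x):
        # execution order is defined here, not by the list order:
        x = self.linears[2](x)   # 5  -> 10
        x = self.linears[0](x)   # 10 -> 20
        x = self.linears[1](x)   # 20 -> 30
        return x

net = net2()
x = torch.randn(8, 5)   # last dim must match linears[2].in_features
print(net(x).shape)     # torch.Size([8, 30])
```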
2. nn.Sequential
nn.Sequential is an ordered container: modules are added in the order they are passed to the constructor, and forward is already implemented internally. Because the modules run in that fixed order, the output size of each module must match the input size of the next.
import torch
import torch.nn as nn

model_1 = nn.Sequential(
    nn.Conv1d(1, 20, 5),
    nn.ReLU(),
    nn.Conv1d(20, 64, 5),
    nn.ReLU())

model_2 = nn.Sequential(*[
    nn.Conv1d(1, 20, 5),
    nn.ReLU(),
    nn.Conv1d(20, 64, 5),
    nn.ReLU()])
x = torch.rand(32, 1, 48)
output = model_1(x)
print(output.shape)
print(model_1)
Output:
torch.Size([32, 64, 40])
Sequential(
  (0): Conv1d(1, 20, kernel_size=(5,), stride=(1,))
  (1): ReLU()
  (2): Conv1d(20, 64, kernel_size=(5,), stride=(1,))
  (3): ReLU()
)
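A small aside I added: nn.Sequential can also be read back by integer index, and slicing returns a new Sequential containing the selected layers.

```python
import torch
import torch.nn as nn

model_1 = nn.Sequential(
    nn.Conv1d(1, 20, 5),
    nn.ReLU(),
    nn.Conv1d(20, 64, 5),
    nn.ReLU())

print(model_1[0])            # first submodule: Conv1d(1, 20, ...)
x = torch.rand(32, 1, 48)
head = model_1[:2]           # slicing yields a new Sequential (conv + relu)
print(head(x).shape)         # torch.Size([32, 20, 44])
```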
An OrderedDict can be used to give each module a name:
import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())
]))
print(model)
Output:
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
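Once named this way, the submodules can be read back by attribute, or (in recent PyTorch versions, my addition) via `get_submodule`:

```python
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())
]))

# named children become attributes of the Sequential:
print(model.conv1)
# get_submodule looks a module up by its dotted name (PyTorch >= 1.9):
print(model.get_submodule('conv2'))
```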
3. nn.ModuleDict
ModuleDict can be indexed like a regular Python dictionary, but the modules it contains are properly registered and are visible to all Module methods.
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
class My_Model(nn.Module):
    def __init__(self):
        super(My_Model, self).__init__()
        self.layers = nn.ModuleDict({'linear_1': nn.Linear(32, 64), 'act_1': nn.ReLU(),
                                     'linear_2': nn.Linear(64, 128), 'act_2': nn.ReLU()})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

net = My_Model()
x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)
Output:
torch.Size([8, 3, 128])
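The first ModuleDict example (MyModule) selects a branch by string key at call time. A runnable usage sketch I added, with input sizes of my own choosing:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

net = MyModule()
x = torch.randn(1, 10, 12, 12)
print(net(x, 'conv', 'lrelu').shape)   # 3x3 conv: torch.Size([1, 10, 10, 10])
print(net(x, 'pool', 'prelu').shape)   # 3x3 max-pool: torch.Size([1, 10, 4, 4])
```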
4. Differences between ModuleDict and ModuleList
- ModuleDict lets you give each layer a name; ModuleList does not.
- ModuleList is read by integer index and grows with append.
- ModuleDict is read by key and grows with dict-style assignment.
import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64),
                     nn.ReLU()])
net.append(nn.Linear(64, 10))
print(net)

net = nn.ModuleDict({'linear1': nn.Linear(32, 64),
                     'act': nn.ReLU()})
net['linear2'] = nn.Linear(64, 128)
print(net)
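Beyond `append` and key assignment, both containers expose further list/dict-like methods — ModuleList has `extend` and `insert`, ModuleDict has `update` and `keys`. A short sketch of my own:

```python
import torch.nn as nn

net_list = nn.ModuleList([nn.Linear(32, 64)])
net_list.append(nn.ReLU())               # append a single module
net_list.extend([nn.Linear(64, 10)])     # extend with an iterable of modules
print(len(net_list))                     # 3

net_dict = nn.ModuleDict({'linear1': nn.Linear(32, 64)})
net_dict['act'] = nn.ReLU()              # dict-style assignment
net_dict.update({'linear2': nn.Linear(64, 128)})
print(list(net_dict.keys()))             # insertion order is preserved
```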
5. Differences between Sequential, ModuleDict, and ModuleList
- ModuleList is merely a list that stores modules; the modules have no connection and no implied order (so adjacent entries' input/output dimensions need not match), and forward is not implemented — you write it yourself.
- Like ModuleList, a ModuleDict instance is just a dictionary that stores modules; it defines no forward either, so you must define your own.
- In Sequential, the modules must be arranged in execution order and adjacent layers' input/output sizes must match; forward is implemented internally, so a model written as a Sequential can be called directly, without writing any forward.
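For example (a minimal sketch of my own), a Sequential model is callable immediately because its forward already exists:

```python
import torch
import torch.nn as nn

# No forward needs to be written: Sequential chains the modules in order.
net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
x = torch.randn(8, 32)
print(net(x).shape)   # torch.Size([8, 10])
```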
6. Converting nn.ModuleList to nn.Sequential
import torch
import torch.nn as nn
module_list = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net = nn.Sequential(*module_list)
x = torch.randn(8, 3, 32)
print(net(x).shape)
7. Converting nn.ModuleDict to nn.Sequential
import torch
import torch.nn as nn
module_dict = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})
net = nn.Sequential(*module_dict.values())
x = torch.randn(8, 3, 32)
print(net(x).shape)
8. Differences between nn.ModuleList / nn.ModuleDict and a plain Python list / dict
The parameters of every module added to a ModuleList or ModuleDict are automatically registered with the whole network.
import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
for name, param in net.named_parameters():
    print(name, param.size())

net = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})
for name, param in net.named_parameters():
    print(name, param.size())
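By contrast (my own counter-example), modules stored in a plain Python list are not registered, so their parameters are invisible to the parent Module:

```python
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # plain Python list, NOT nn.ModuleList: nothing gets registered
        self.layers = [nn.Linear(32, 64), nn.ReLU()]

net = PlainList()
print(list(net.named_parameters()))   # [] - the Linear's weights are missing
```

This is why an optimizer built from `net.parameters()` would silently skip such layers.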