层与块|深度学习计算|动手学深度学习

13 阅读3分钟

1. 如果将 MySequential 中存储块的方式更改为 Python 列表,会出现什么样的问题?

在PyTorch中,nn.Module的子类通常使用OrderedDict来存储模块的子模块,这样做是为了确保在序列模型中,子模块能够按照它们被添加的顺序进行迭代。如果将MySequential中存储块的方式从OrderedDict更改为Python原生的列表(list),可能会出现以下几个问题:

  1. 迭代顺序:使用列表确实可以保持块的添加顺序,因此在forward方法中迭代时,块还是会按照它们被添加到列表中的顺序进行处理。这一点与OrderedDict是一致的。

  2. 名称唯一性:在nn.Module中,子模块的名称必须是唯一的。如果使用列表,你需要确保列表中的模块没有名称冲突,因为列表不支持使用非数字作为索引。而OrderedDict可以使用字符串作为键,这样可以给每个模块一个有意义的名称,有助于调试和代码的可读性。

  3. 模块访问:使用OrderedDict时,可以通过模块的名称来访问特定的子模块,这在调试或者需要单独处理某个特定模块时非常有用。如果使用列表,你只能通过索引来访问模块,这限制了灵活性。

  4. 模块移除和更新OrderedDict提供了方便的方式来添加、移除和更新模块,而列表在这方面的操作不是那么直观和方便。例如,如果你想替换或移除列表中的一个模块,你需要使用popremove方法,这可能会导致其他模块的索引发生变化,而在OrderedDict中,你可以直接使用del语句来移除一个模块。

  5. 序列化OrderedDict可以更好地支持序列化和反序列化操作,因为它们保证了对象的状态(包括子模块的顺序)在保存和恢复时的一致性。

  6. 性能:虽然对于小型模型,使用列表可能在性能上没有太大的差异,但是对于大型或复杂的模型,OrderedDict可能提供更优的性能,特别是当涉及到频繁的模块添加、删除或更新操作时。

  7. API一致性:PyTorch的API设计中使用OrderedDict是为了保持一致性。如果在自定义模块中使用列表,可能会导致与PyTorch的预期API行为不一致,从而使得用户在使用自定义模块时感到困惑。

总的来说,虽然理论上可以使用列表来存储模块,但使用OrderedDict是更推荐的做法,因为它提供了更好的灵活性、可读性以及与PyTorch框架的兼容性。

2. 实现一个块,它以两个块为参数,例如 net1net2,并返回前向传播中两个网络的串联输出。这也被称为平行块。

class ParallelBlock(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self._modules[str(idx)] = module

    def forward(self, X):
        outs = []
        for block in self._modules.values():
            outs.append(block(X))
        out = torch.cat(outs, dim=1)
        return out
net1 = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
net2 = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
net = ParallelBlock(net1, net2)
net(X)
tensor([[0.0000, 0.2955, 0.0652, 0.0000, 0.0000, 0.0000, 0.3279, 0.0000, 0.1225,
         0.2088, 0.1095, 0.0000, 0.0000, 0.0000, 0.0000, 0.2383, 0.0000, 0.2530,
         0.0826, 0.1164, 0.0387, 0.0000, 0.1693, 0.0000, 0.0000, 0.1104, 0.0030,
         0.0000, 0.2451, 0.2041, 0.1935, 0.0350, 0.0981, 0.0565, 0.0000, 0.1866,
         0.0000, 0.0000, 0.0000, 0.0933, 0.0899, 0.0000, 0.3394, 0.1120, 0.0000,
         0.0000, 0.0718, 0.0275, 0.0000, 0.0943, 0.0000, 0.1015, 0.0000, 0.2515,
         0.0017, 0.0537, 0.2551, 0.0000, 0.1681, 0.2545, 0.0000, 0.0000, 0.0000,
         0.1079],
        [0.0000, 0.2858, 0.0641, 0.0000, 0.0000, 0.0000, 0.3474, 0.0918, 0.1000,
         0.2191, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2357, 0.0000, 0.2782,
         0.1247, 0.2183, 0.1063, 0.0384, 0.1456, 0.0000, 0.0000, 0.1282, 0.0404,
         0.0000, 0.2798, 0.1584, 0.1291, 0.0342, 0.1341, 0.0547, 0.0000, 0.1452,
         0.0000, 0.0000, 0.0000, 0.2182, 0.0986, 0.0000, 0.3971, 0.0607, 0.0000,
         0.0000, 0.0198, 0.0421, 0.0000, 0.0363, 0.0000, 0.0834, 0.0546, 0.2563,
         0.0095, 0.0000, 0.3408, 0.0000, 0.0884, 0.1978, 0.0000, 0.0000, 0.0000,
         0.1258]], grad_fn=<CatBackward0>)

3. 假设我们想要连接同一网络的多个实例。实现一个函数,该函数生成同一个块的多个实例,并在此基础上构建更大的网络。