大多数从业者在第一次学习卷积神经网络(CNN)架构时--了解到它是由三个基本部分组成的:
- 卷积层
- 池化层
- 完全连接的层
大多数资源都对这种细分有一些变化,包括我自己的书。特别是在网上--完全连接的层指的是一个扁平化层和(通常)多个密集层。
这曾经是常态,著名的架构,如VGGNets使用这种方法,并会在:
model = keras.Sequential([
#...
keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
keras.layers.Flatten(),
keras.layers.Dropout(0.5),
keras.layers.Dense(4096, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(4096, activation='relu'),
keras.layers.Dense(n_classes, activation='softmax')
])
不过,由于某些原因--人们常常忘记VGGNet实际上是最后一个使用这种方法的架构,因为它造成了明显的计算瓶颈。当ResNets在VGGNets之后的第二年(也就是7年前)发表时,所有的主流架构都以模型定义结束。
model = keras.Sequential([
#...
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(n_classes, activation='softmax')
])
CNN中的扁平化已经坚持了7年。7年了!而且似乎没有足够的人在谈论它对你的学习经验和你所使用的计算资源的破坏性影响。
全局平均池在很多方面都比扁平化要好。如果你在做一个小型CNN的原型--使用全局池化。如果你在教别人关于CNN的知识--使用全局池化。如果你要做一个MVP--使用全局池化。在其他实际需要的情况下使用扁平化层。
案例研究--扁平化与全局池化
全局池化(Global Pooling)将所有的特征图浓缩为一张,将所有的相关信息汇集到一张图中,通过一个密集的分类层而不是多个分类层就可以轻松理解。它通常被应用为平均池化(GlobalAveragePooling2D)或最大池化(GlobalMaxPooling2D),也可用于一维和三维输入。
与其将一个特征图(如(7, 7, 32) )平铺成一个长度为1536的向量,并训练一个或多个层来从这个长向量中辨别模式:我们可以将它浓缩成一个(7, 7) 向量,并直接从那里进行分类。就是这么简单!
请注意,像ResNets这样的网络的瓶颈层是以数万个特征计算的,而不是仅仅1536个。扁平化时,你在折磨你的网络,以一种非常低效的方式从奇特的向量中学习。想象一下,一张二维图像在每个像素行上都被切开,然后串联成一个平面向量。过去垂直方向上相距0像素的两个像素在水平方向上没有feature_map_width 。虽然这对分类算法来说可能没有太大关系,因为分类算法倾向于空间不变性--但对于计算机视觉的其他应用来说,这甚至在概念上都不是好事。
让我们定义一个小型的示范性网络,它使用一个带有几个密集层的扁平化层:
model = keras.Sequential([
keras.layers.Input(shape=(224, 224, 3)),
keras.layers.Conv2D(32, (3, 3), activation='relu'),
keras.layers.Conv2D(32, (3, 3), activation='relu'),
keras.layers.MaxPooling2D((2, 2), (2, 2)),
keras.layers.BatchNormalization(),
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.MaxPooling2D((2, 2), (2, 2)),
keras.layers.BatchNormalization(),
keras.layers.Flatten(),
keras.layers.Dropout(0.3),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.summary()
摘要是什么样子的?
...
dense_6 (Dense) (None, 10) 330
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________
一个玩具网络的11.5M的参数--看着参数在更大的输入下爆炸开来。11.5M的参数。EfficientNets是有史以来性能最好的网络之一,其工作参数约为600万,在实际性能和从数据中学习的能力方面无法与这个简单模型相比。
我们可以通过使网络更深入来大大减少这个数字,这将引入更多的最大池(和潜在的分层卷积)来减少特征图,然后再把它们压扁。然而,考虑到我们将使网络变得更加复杂,以使其计算成本更低,所有这些都是为了一个单一的层,它在计划中扔了一个扳手。
深化层应该是为了提取数据点之间更有意义的非线性关系,而不是为了迎合扁平化层而减少输入大小。
这是一个带有全局池的网络:
model = keras.Sequential([
keras.layers.Input(shape=(224, 224, 3)),
keras.layers.Conv2D(32, (3, 3), activation='relu'),
keras.layers.Conv2D(32, (3, 3), activation='relu'),
keras.layers.MaxPooling2D((2, 2), (2, 2)),
keras.layers.BatchNormalization(),
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.Conv2D(64, (3, 3), activation='relu'),
keras.layers.MaxPooling2D((2, 2), (2, 2)),
keras.layers.BatchNormalization(),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dropout(0.3),
keras.layers.Dense(10, activation='softmax')
])
model.summary()
总结?
dense_8 (Dense) (None, 10) 650
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________
好多了!如果我们对这个模型进行深入研究,参数数会增加,我们可能会用新的层来捕捉更复杂的数据模式。但如果天真地去做,就会出现约束VGGNets的同样问题。
更进一步--手持式端到端项目
你好奇的天性让你想更进一步?我们建议查看我们的 指导性项目: "卷积神经网络--超越基本架构".
我将带你进行一次时间旅行--从1998年到2022年,强调这些年来开发的决定性架构,它们的独特之处,它们的缺点是什么,并从头实现那些值得注意的架构。谈到这些,没有什么比手上有一些污垢更好。
你可以在不知道发动机是4个还是8个气缸,以及发动机内阀门的位置是什么的情况下驾驶一辆汽车。然而--如果你想设计和欣赏一个发动机(计算机视觉模型),你会想要更深入一点。即使你不想花时间设计架构,而想建造产品,这也是大多数人想做的事--你也会在本课中找到重要的信息。你会了解到为什么使用VGGNet这样过时的架构会损害你的产品和性能,以及为什么如果你要构建任何现代产品,你应该跳过它们,你还会了解到你可以去使用哪些架构来解决实际问题,以及每种架构的优点和缺点是什么。
如果你想将计算机视觉应用于你的领域,使用本课的资源--你将能够找到最新的模型,了解它们是如何工作的,以及通过哪些标准你可以比较它们,并决定使用哪一个。
你不必在谷歌上寻找架构及其实现方式--它们通常在论文中得到了非常清晰的解释,而像Keras这样的框架使这些实现方式比以往更容易。这个指导项目的关键收获是教你如何寻找、阅读、实现和理解架构和论文。世界上没有任何资源能够跟上所有最新的发展。我在这里包括了最新的论文--但在几个月后,新的论文会冒出来,这是不可避免的。知道在哪里可以找到可靠的实现,将它们与论文进行比较,并对它们进行调整,可以为你可能想要建立的许多计算机视觉产品提供必要的竞争优势。
总结
在这个简短的指南中,我们看了一下CNN架构设计中扁平化的替代方案。尽管很短--该指南解决了设计原型或MVP时的一个常见问题,并建议你使用一个更好的替代扁平化的方法。
任何经验丰富的计算机视觉工程师都会知道并应用这一原则,而且这一做法被认为是理所当然的。不幸的是,它似乎没有被正确地转达给刚刚进入这个领域的新从业者,而且可能会产生需要一段时间才能摆脱的粘性习惯。
如果你要进入计算机视觉领域--帮你自己一个忙,在你的学习旅程中,不要在分类前使用扁平化层。
