本文已参与「新人创作礼」活动,一起开启掘金创作之路。
站在巨人的肩膀上看世界
参考文献
- 知乎-深入理解 ViT(2):www.zhihu.com/zvideo/1455…
- ViT论文逐段精读【论文精读】:作者-跟李沐学AI,主讲bryanyzhu
- VIT (Vision Transformer) 模型论文+代码(源码)从零详细解读,看不懂来打我:作者-NLP从入门到放弃
- Transformer中
Self-Attention
以及Multi-Head Attention
详解:作者-霹雳吧啦Wz - 11.1 Vision Transformer(vit)网络详解:作者-霹雳吧啦Wz
代码
11.2 使用pytorch搭建Vision Transformer(vit)模型:作者-霹雳吧啦Wz- 博客-详解Transformer中
Self-Attention
以及Multi-Head Attention
:太阳花的小绿豆 - 原文:arxiv.org/abs/2010.11…
2022年02月14日14:31:34学习完上述Vision Transformer教程,站在巨人的肩膀上可以让我们更快的前进。完成学习后对内容进行一个总结,以及对ViT模型进行一个简单介绍,让才接触的同学快速了解ViT模型。
Vision Transformer
我们以原文中的一张图,来快速了解
ViT这个模型,我们把整体结构分成五个部分:
- 将输入的图像进行patch的划分
- Linear Projection of Flatted Patches,将patch拉平并进行线性映射
- 生成CLS token特殊字符*,生成Position Embedding,Patch+Position Embedding相加作为inputs token
- Transformer Encoder编码,特征提取
- MLP Head进行分类输出结果
第一部分:图像划分Patch
名称 | 说明 |
---|---|
输入图像维度 | 224×224×3 |
Patch块大小 | 16×16 |
token个数 | 196 |
token维度 | 768 |
将图像分成16×16
的patch(小方块),每个patch块可以看做是一个token(词向量),共有(224/16=14)14×14=196
个token,每个token的维度为16×16×3=768
。
第二部分:Linear Projection of Flatted Patches
一个patch块它的维度是16×16×3=768
,我们把它flatten拉平成行向量它的长度就为768,一共有14×14=196
个patch,所以输入的维度是[196, 768]
,我们经过一个Linear Projection(映射)到指定的维度,比如1024或2048,我们用全连接层来实现,但映射的维度我们任然选择为768,那么此时的输出是[196, 768]
。
整个过程我们把它称作patch embedding;输出的结果维度是[196, 768]
但是
有大佬就不服啦,反正都是线性映射,为啥我要这么复杂,把它flatten后在用全连接层呢?聪明的大佬就想到,我直接使用卷积
来实现,用16×16
大小的卷积核,步长stride=16,维度设为768,输入[3, 224, 224]->[768, 14, 14]
,然后交换并合并一下维度不就得到结果了吗?[196, 768]
第三部分:Patch+Position Embedding
patch embedding的维度为[196, 768]
1.首先==生成==一个cls token它的维度为[1, 768],然后拼接到输入的path embedding,得到的维度为[197, 768]
2.对197个patch都==生成==一个位置信息,它的维度同patch维度为[197, 768]
3.Patch+Position Embedding,直接相加作为新的
输入token
疑惑解答
- cls token和位置信息编码是如何来的呢?随机生成的可学习参数,可以全零初始化,也可以0-1随机初始
cls_token = nn.Parameter(torch.zeros(1, 768))
pos_embedding = nn.Parameter(torch.zeros(197, 768))
- cls token的作用是为了同NLP领域的Transformer保持一致,后面直接用cls token作为最后网络提取特征,作为输入用于分类
- 不使用cls token也是可以的,对196个维度为768的patch使用全局均值池化,得到结果同cls token[1, 768]后用于分类
- 位置信息编码是为了给patch加上相对位置,不然在后面的特征提取中丢掉了位置信息可就不好了
- 位置信息编码有可学习参数类型,有通过公式计算的方法,可以是一维、二维;但不使用性能会差点
第四部分:Transformer Encoder
整个Transformer Encoder结构如下,详细代码实现和Multi-Head Attention的理解请看开头的视频讲解。
名称 | 说明 |
---|---|
输入token(Embedded Patches) | 维度为[197, 768] |
编码后的输出 | [197, 768] |
L× | 重复EncoderL次 |
Norm | Layer Norm归一化 |
Multi-Head Attention | 多头注意力机制 |
+ | 跳跃连接 |
MLP | Linear全连接层 |
Transformer Encoder网络结构如下:
Encoder输入的维度为[197, 768],输出的维度为[197, 768],可以把中间过程简单的理解成为特征提取的过程
其中
的Multi-Head Attention多头注意力机制,看完开头 太阳花的小绿豆的博客后,做一个简单的认识:
Self-Attention:
Multi-Head Attention:
第五步:MLP Head
分类头的意思就是:特征提取工作已经全部完成,现在你要做分类就加上对应的操作。(同CNN特征提取层后的Linear层+Softmax进行分类预测)。
整个Encoder的输出为[197, 768]我们仅仅保留最前面的CLS token作为全连接的输入[1, 768],然后接上全连接层及分类数n_class,使用交叉熵损失函数计算损失,反向传播更新网络的权重和参数。
2022年02月15日15:26:37,over。