2017年Google《Attention Is All You Need》一文横空出世,彻底抛弃RNN框架,以‘注意力机制’为核心打造Transformer,凭借并行计算能力和全局依赖捕捉能力,奠定了ChatGPT、BERT等大模型的技术基石。
摘要:【深度学习Day14】我们用LSTM攻克了长序列情感分类,但它“逐词串行计算”的本质的仍是效率与能力的瓶颈。2017年Google《Attention Is All You Need》一文横空出世,彻底抛弃RNN框架,以“注意力机制”为核心打造Transformer,凭借并行计算能力和全局依赖捕捉能力,奠定了ChatGPT、BERT等大模型的技术基石。本文将以MATLAB老鸟的矩阵思维,拆解Transformer的两大核心——自注意力机制(Self-Attention)与位置编码(Positional Encoding),用“图书馆检索”类比理清QKV逻辑,通过矩阵公式推导吃透计算本质,最终用PyTorch手写自注意力层与Transformer Block,同时对比MATLAB与PyTorch的实现差异,让你从原理到代码彻底搞懂Transformer!
关键词:Transformer、Self-Attention、QKV、位置编码、PyTorch、多头注意力、残差连接、LayerNorm、矩阵运算
1. 开篇对决:为什么Transformer要“淘汰”RNN/LSTM?
LSTM虽缓解了梯度消失,但本质上仍未跳出RNN的“串行思维”——就像工厂流水线,必须等前一个工序完成,下一个才能启动。这种模式在长序列任务中暴露了两大致命短板,而Transformer的出现,正是为了彻底解决这些问题。
1.1 RNN/LSTM的两大“绝症”
- 效率绝症:无法并行训练,硬件算力浪费 普通RNN和LSTM都需按时间步顺序计算,当前时刻的隐藏状态必须依赖上一时刻结果。哪怕你有8张4090显卡,也只能让一张显卡工作,其余全部闲置——就像排队结账,再多收银台也只能排一队,效率极低。在长序列(如千词文本、小时级时间序列)任务中,训练耗时会呈指数级增长。
- 记忆绝症:长序列仍会“失忆”,全局依赖捕捉薄弱 LSTM的门控机制虽能延长记忆,但仍属于“递进式记忆”——信息沿时间步逐步传递,每经过一个时刻就会有损耗。若序列长度达到上千(比如一篇论文、一本小说),开头的关键信息经过层层传递后,依然会被稀释甚至丢失,无法与结尾信息建立有效关联。
1.2 Transformer的“降维打击”:全局并行的“上帝视角”
Transformer的核心逻辑是“打破串行枷锁,实现全局并行”:它不按时间步逐词计算,而是将一整句话的所有词向量同时输入,让每个词瞬间“看到”其他所有词,直接计算词与词之间的关联强度——这就像你站在山顶看全景,而非沿着山路逐步摸索,既能捕捉全局依赖,又能让所有计算并行执行,显卡算力被充分利用。
核心差异对比(矩阵思维视角):RNN是“逐元素递推计算”(时间步维度循环),Transformer是“全矩阵一次性运算”(时间步维度并行),这也是MATLAB用户对Transformer更易共情的原因——本质都是矩阵操作,只是维度处理逻辑不同。
| 对比维度 | RNN/LSTM | Transformer |
|---|---|---|
| 计算模式 | 串行计算,依赖时间步顺序 | 并行计算,全局信息同步捕捉 |
| 依赖捕捉能力 | 递进式依赖,长序列衰减 | 全局依赖,任意词对直接关联 |
| 硬件利用率 | 低,多显卡无法并行 | 高,全矩阵运算适配并行加速 |
| 核心依赖 | 门控机制(LSTM)缓解梯度消失 | 自注意力机制+位置编码 |
2. 核心心脏:自注意力机制(Self-Attention)——让词与词“互认”
自注意力机制是Transformer的灵魂,核心作用是让句子中的每个词“自主判断与其他词的关联程度”,进而生成融合上下文信息的词向量。比如区分“Apple is red”(苹果是红色的)和“Apple Inc. released iPhone”(苹果公司发布iPhone)中的“Apple”,靠的就是自注意力捕捉到的不同上下文关联。
对于MATLAB老鸟来说,自注意力的本质就是“矩阵变换+相似度计算+加权求和”的组合操作,全程无复杂循环,纯矩阵运算即可实现。
2.1 灵魂三问:、、到底是什么?(图书馆检索类比)
自注意力机制将每个词向量通过线性变换,生成三个“分身”(、、),三者分工明确,类比图书馆检索场景就能秒懂:
- (Query,查询向量) :每个词的“检索需求”。就像你去图书馆找书时提出的需求——“我要找深度学习相关的书”,每个词都会基于自身信息生成专属查询。
- (Key,键向量) :每个词的“身份标签”。就像书脊上的分类标签——“计算机科学/AI/深度学习”,用于和匹配,判断两者是否相关。
- (Value,值向量) :每个词的“核心内容”。就像书里的具体内容,当和匹配成功后,就会提取对应的作为上下文信息。
核心逻辑:用每个词的,去匹配所有词的,计算相似度(关联强度),再根据相似度对所有词的进行加权求和,最终得到融合上下文的词向量——每个词的最终表示,都包含了自身信息和与之相关的其他词的信息。
2.2 一步到位:自注意力的矩阵公式与拆解(MATLAB友好版)
自注意力的完整计算可浓缩为一个公式,全程矩阵运算,对MATLAB用户而言极其解压:
我们按步骤拆解,结合矩阵维度变化(假设词向量维度为,序列长度为),彻底搞懂每一步的意义:
- 、、的生成(线性变换) 首先对输入词向量矩阵(维度)做三次线性变换,分别得到、、矩阵(维度均为): ,,,其中、、是可学习的权重矩阵(维度)。
关键疑问:为什么要做线性变换?不能直接用输入词向量当、、吗?
答:线性变换能将词向量投影到三个不同的特征空间,让模型从“查询、键、值”三个视角理解词的含义——就像观察一个物体,从正面()、侧面()、底面()分别观测,能捕捉更全面的信息。若直接用原始词向量,三个角色重叠,无法实现精准的“检索-匹配-提取”逻辑。
-
计算相似度矩阵() 将矩阵与矩阵的转置(,维度)做矩阵乘法,得到相似度矩阵(维度): 。相似度矩阵的每个元素,表示第个词的与第个词的的相似度,即两个词的关联强度——这就是自注意力的“关联图谱”,一次性算出所有词对的关系,对应MATLAB中的
Q * K'操作。 -
缩放操作(除以) 当较大时,的结果数值会很大,经过softmax后,大概率出现“部分值接近1,其余接近0”的极端分布,导致梯度落在softmax的饱和区(梯度几乎为0),引发梯度消失。 除以能将相似度矩阵的方差控制在1附近,避免数值过大,保证softmax后梯度正常传播——这一步相当于MATLAB中的
S / sqrt(d_k),是稳定训练的关键。 -
Softmax归一化对相似度矩阵的每一行做softmax运算,将相似度转化为概率分布(维度): 。每行元素之和为1,表示第个词对第个词的“贡献权重”——权重越高,说明第个词对第个词的上下文影响越大。
-
加权求和() 将概率矩阵与矩阵(维度)做矩阵乘法,得到最终的自注意力输出矩阵(维度): 每个词的输出向量,都是所有词的向量按权重加权求和的结果,完美融合了上下文信息——这一步对应MATLAB中的
A * V,完成上下文信息的聚合。
2.3 进阶:多头注意力(Multi-Head Attention)——多视角融合
上述基础自注意力是“单头”,而Transformer实际用的是“多头注意力”——将自注意力机制重复次(为头数),每个头独立计算,最后将结果拼接并做线性变换,得到最终输出。
核心优势:每个头可以捕捉不同类型的关联(比如一个头关注语法结构,一个头关注语义关联),多视角融合能让模型理解更全面。比如句子“他喜欢吃苹果”,有的头会关注“他”与“喜欢”的主谓关联,有的头会关注“吃”与“苹果”的动宾关联。
多头注意力的矩阵逻辑(简化版):
①将、、按头数拆分(每个头的维度为);
②每个头独立计算基础自注意力;
③所有头的输出拼接(维度);
④线性变换整合,得到多头注意力输出。
3. 关键补丁:位置编码(Positional Encoding)——给词加“座位号”
Transformer的致命缺陷:它是“并行全量输入”,本身没有时间顺序概念——在它眼里,“我爱你”和“你爱我”的词向量集合完全一致,无法区分语序,这会导致语义理解错误(类似Bag of Words模型的问题)。
为了解决这个问题,Google工程师提出了“位置编码”:给每个位置的词向量加一个“位置特征”(座位号),让模型能识别词的顺序。这个“座位号”不是简单的1、2、3,而是用正弦和余弦函数生成的周期性特征。
3.1 位置编码公式与原理
对于序列中位置为(从0开始)的词,其位置编码的第维(为词向量维度索引)计算公式如下:
其中是词向量总维度,和表示偶数维和奇数维——偶数维用正弦函数,奇数维用余弦函数,形成周期性的位置特征。
3.2 为什么不用简单整数编码?(MATLAB数值分析视角)
若直接用1、2、3...作为位置编码,会存在两个严重问题:
- 数值无限增长:序列越长,位置编码数值越大,会严重干扰词向量的语义信息,导致模型训练不稳定——就像在MATLAB中处理数据时,量级差异过大引发的数值溢出问题。
- 无相对位置信息:整数编码只能表示绝对位置,无法体现“相对距离”——比如位置1和2的距离,与位置100和101的距离,在整数编码中都是1,但语义上的关联强度可能完全不同。
正弦余弦编码的优势(准确的说是多维周期函数):
① 数值范围固定在,不会干扰词向量;
② 具有相对位置性质:可由通过正弦和角公式推导得到,模型能自动学习到词之间的相对距离。
3.3 位置编码的使用方式
位置编码不单独作为输入,而是与词向量“相加”后输入Transformer(而非拼接):。 原因:相加能让词向量同时包含“语义信息”和“位置信息”,且维度保持不变,后续矩阵运算无需调整维度——这是一种简洁高效的融合方式,对应MATLAB中的X_embedding + PE。
MATLAB快速实现(简化版):假设词向量维度,序列长度,可通过以下代码生成位置编码矩阵,直接复用至模型输入:
d_model = 512;
seq_len = 100;
PE = zeros(seq_len, d_model);
pos = 0:seq_len-1; % 位置索引(从0开始)
for i = 1:d_model/2
denom = 10000^(2*(i-1)/d_model);
PE(:, 2*i-1) = sin(pos' / denom); % 偶数维(MATLAB索引为奇数)
PE(:, 2*i) = cos(pos' / denom); % 奇数维(MATLAB索引为偶数)
end
4. PyTorch实战:手写自注意力层与Transformer Block
原理懂了,动手实现才是关键。我们不直接调用PyTorch的nn.MultiheadAttention,而是手写简版自注意力层和Transformer Block,从代码层面吃透内部逻辑——同时结合MATLAB视角,解读高维矩阵运算的技巧。
4.1 手写自注意力层(支持多头)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size # 词向量总维度
self.heads = heads # 头数
self.head_dim = embed_size // heads # 每个头的维度
# 断言:确保词向量维度能被头数整除,避免维度不匹配
assert (self.head_dim * heads == embed_size), "Embedding size must be divisible by heads"
# 定义Q、K、V的线性变换层(每个头独立,此处拼接计算提升效率)
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
# 多头结果拼接后的输出线性层,将维度还原为embed_size
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask=None):
# N: 批次大小 (Batch Size)
# value_len/key_len/query_len: 序列长度(Encoder中三者一致,Decoder中可能不同)
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 1. 线性变换 + 拆分多头(维度调整:[N, Seq_Len, Embed_Size] → [N, Seq_Len, Heads, Head_Dim])
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# 2. 计算Q与K的相似度(爱因斯坦求和约定,替代复杂的permute+matmul)
# 维度说明:queries[N, Q_len, H, D] → 与 keys[N, K_len, H, D] 做点积,结果为[N, H, Q_len, K_len]
# 对应MATLAB操作:permute(queries, [1,3,2,4]) * permute(keys, [1,3,4,2])
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# 3. 缩放操作:除以根号下每个头的维度(避免数值过大导致softmax梯度消失)
energy = energy / math.sqrt(self.head_dim) # 注意:此处用head_dim,而非embed_size
# 4. Masking(可选,Decoder专用,遮挡未来时刻的词,避免信息泄露)
if mask is not None:
# 将mask为0的位置设为极小值,softmax后概率接近0
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 5. Softmax归一化:在Key维度(第3维)做归一化,得到注意力权重
attention = torch.softmax(energy, dim=3)
# 6. 注意力权重与V加权求和(再次用einsum简化高维矩阵乘法)
# 维度说明:attention[N, H, Q_len, K_len] × values[N, V_len, H, D] → [N, Q_len, H, D]
# 对应MATLAB操作:permute(attention, [1,3,2,4]) * permute(values, [1,3,4,2]) → 再reshape
out = torch.einsum("nhqk,nvhd->nqhd", [attention, values])
# 7. 多头结果拼接:[N, Q_len, H, D] → [N, Q_len, H*D] = [N, Q_len, Embed_Size]
out = out.reshape(N, query_len, self.heads * self.head_dim)
# 8. 输出线性层:整合多头信息
out = self.fc_out(out)
return out
4.2 代码解析
这部分重点解读高维矩阵运算的技巧,对比MATLAB与PyTorch的实现差异,帮你快速上手:
- 爱因斯坦求和约定(einsum):高维矩阵运算神器 代码中
torch.einsum是核心,它能通过字符串指定维度索引,自动完成转置、矩阵乘法和维度调整——这在MATLAB中需要用permute调整维度顺序,再用matmul做矩阵乘法,步骤繁琐且易出错。 例:torch.einsum("nqhd,nkhd->nhqk", [queries, keys])含义:(批次)、(查询序列长度)、(头数)、(头维度)→ 固定和,让与做点积,得到的相似度矩阵,最终维度为,完美契合注意力权重矩阵的需求。 - 多头拆分与拼接:维度对齐是关键 多头注意力的核心是“拆分维度→独立计算→拼接整合”,代码中通过
reshape实现维度拆分(将Embed_Size拆分为Heads×Head_Dim),计算完成后再拼接回原维度——这和MATLAB中的reshape逻辑完全一致,只是PyTorch的维度顺序(Batch在前)与MATLAB(Batch在后)略有差异,需注意调整。 - Mask的作用与实现 Mask主要用于Decoder(比如翻译任务中,遮挡未来时刻的词,避免模型看到未预测的内容),代码中通过
masked_fill将无效位置设为极小值,softmax后概率接近0,相当于“屏蔽”这些位置的信息——这对应MATLAB中的mask = zeros(size(energy)); energy(mask==0) = -1e20;。
4.3 拼装Transformer Block(Encoder核心模块)
单一自注意力层无法完成复杂任务,Transformer的基本单元是“Transformer Block”——将自注意力层、前馈神经网络、残差连接和LayerNorm组合而成,形成完整的特征提取模块。
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super(TransformerBlock, self).__init__()
# 自注意力层(多头)
self.attention = SelfAttention(embed_size, heads)
# LayerNorm:Transformer的标配归一化方式(区别于CNN的BatchNorm)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
# 前馈神经网络(Feed Forward Network):对每个词向量独立做非线性变换
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size), # 升维
nn.ReLU(), # 非线性激活
nn.Linear(forward_expansion * embed_size, embed_size) # 降维回原维度
)
# Dropout:防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask):
# 1. 自注意力 + 残差连接 + LayerNorm
attention = self.attention(value, key, query, mask)
# 残差连接:将输入query与注意力输出相加,缓解梯度消失
x = self.dropout(self.norm1(attention + query))
# 2. 前馈网络 + 残差连接 + LayerNorm
forward = self.feed_forward(x)
# 残差连接:将前馈网络输入x与输出forward相加
out = self.dropout(self.norm2(forward + x))
return out
4.4 核心组件解析:残差连接与LayerNorm
- 残差连接(Residual Connection) 核心作用是缓解梯度消失——将模块的输入直接与输出相加,让梯度能“绕开”复杂的网络层,直接反向传播到浅层。代码中
attention + query和forward + x就是残差连接,这和ResNet的残差逻辑完全一致,是深层Transformer能稳定训练的关键。 - LayerNorm vs BatchNorm(面试高频考点) Transformer用LayerNorm而非CNN常用的BatchNorm,核心原因是NLP任务的序列长度不固定,Batch维度的统计量不稳定:
-
- BatchNorm(BN):在Batch维度归一化(对一批样本的同一特征维度求均值方差),适合图像等固定维度数据,但NLP中句子长度不一,Batch统计量波动大,效果差。
-
- LayerNorm(LN):在Feature维度归一化(对每个样本的所有特征维度求均值方差),与序列长度无关,能保证每个词的向量分布稳定,更适配NLP任务。 对应MATLAB实现:LN可通过
(x - mean(x, 2)) ./ std(x, 0, 2)实现(按行归一化,即每个样本的特征维度归一化)。
- LayerNorm(LN):在Feature维度归一化(对每个样本的所有特征维度求均值方差),与序列长度无关,能保证每个词的向量分布稳定,更适配NLP任务。 对应MATLAB实现:LN可通过
5. 面试避坑指南:Transformer专场
Q1:为什么、、要通过线性变换生成?不能直接用输入词向量吗?
线性变换能将词向量投影到三个相互独立的特征空间,实现“查询-键-值”的功能分工,让模型能自主学习不同视角的语义关联;若直接用原始词向量,三者特征空间重叠,无法形成有效的“检索-匹配-提取”机制,模型表达能力受限。
就像你找东西时,需要明确“要找什么(Q)、东西标了什么(K)、东西里有什么(V)”,三个角色分开才能高效检索;若三者都是同一个东西,相当于既当裁判又当选手,逻辑混乱,效果自然差。
Q2:为什么要做Scaled操作(除以)?
当词向量维度较大时,的点积结果方差会随增大而增大,导致数值过大,经过softmax后进入饱和区(梯度趋近于0),引发梯度消失;除以可将点积结果的方差控制在1附近,避免数值极端化,保证梯度正常传播。
就像音量太大导致喇叭失真,Scaled操作相当于“调小音量”,让信号保持在合理范围,避免模型“听不清”(梯度消失)。
Q3:LayerNorm和BatchNorm的核心区别及适用场景?
① 归一化维度不同:BN在Batch维度归一化,LN在Feature维度归一化;② 统计量计算方式不同:BN依赖批次内所有样本的统计量,LN仅依赖单个样本的统计量;③ 适用场景不同:BN适合图像等固定维度、Batch统计量稳定的数据;LN适合NLP等序列长度不固定、Batch统计量波动大的数据。
BN是“全班同学按同一标准打分”,适合人数固定的班级;LN是“每个人按自己的标准打分”,适合人数不固定的群体,更适配NLP的句子长度差异。
Q4:多头注意力的优势是什么?头数越多越好吗?
优势是多视角捕捉关联信息,不同头可聚焦语法、语义、位置等不同类型的依赖,提升模型表达能力;头数并非越多越好,过多头会增加参数数量和计算量,导致过拟合,且边际收益递减,实际中常用8头或16头(如BERT用12头)。
多头就像多个人一起分析问题,每个人关注一个角度,最后汇总结论更全面;但人太多(头数过多)会增加沟通成本(计算量),且观点重复,反而效率低。
Q5:位置编码为什么用正弦余弦函数?
① 数值范围固定在,不会干扰词向量语义;② 具有相对位置不变性,可由线性表示,模型能学习到词之间的相对距离;③ 无需额外参数,计算高效,适配任意长度序列。
正弦余弦是周期性函数,能给每个位置一个“循环标签”,既区分了顺序,又不会因序列太长导致标签数值溢出,还能让模型知道“两个词隔了多少位置”。
📌 下期预告
Transformer虽强,却是为1D序列数据(文本、时间序列)量身打造的王者。但现实世界中,很多数据并非“排成一排”——社交网络的人际关系、分子的化学键结构、知识图谱的实体关联,都是复杂的网状结构(图结构)。
CNN无法处理不规则的图结构(卷积核无法对齐),Transformer也难以捕捉图中的拓扑关系(无明确位置顺序)。那么,如何让神经网络“看懂”图结构?
下一篇,我们将进入全新领域——图神经网络(GNN)基础。我们会拆解图卷积网络(GCN)的核心原理,把卷积的概念从“网格”推广到“图”,教你如何用神经网络建模人际关系、分子结构,打通“序列建模→图建模”的技术脉络!