矩阵运算维度设计与Attention机制的底层逻辑

4 阅读1分钟

1. 矩阵运算维度设计

深度学习框架选择 xWTxW^T并非随意之举,而是基于以下核心原则的精心设计:

1.1. 内存访问效率优先

现代硬件(CPU/GPU)的性能瓶颈主要在于内存访问速度,而非计算速度。xWTxW^T能够实现连续内存访问:

对于 x的每一行(一个样本),程序可以顺序读取其所有特征,这与数据在内存中的物理存储方式一致。

这种模式极大提高了缓存命中率,从而显著提升计算速度。

相比之下,若采用 WxWx 模式,则需将 xx 转置或将特征维度前置,导致访问模式变为跳跃式(strided access),引发大量缓存缺失(cache miss),显著降低吞吐效率。

1.2. 批处理的天然适配

深度学习框架普遍采用 (B, L, D) 作为标准张量布局,其中:

  • B:批次大小(Batch size)
  • L:序列长度(Sequence length)
  • D:特征维度(Feature dimension)

xWTxW^T 的运算天然适配该结构:

  • 当 x.shape = [B, D] 时,xWTxW^T 直接输出 [B, d_out],批次维度自动保留,无需额外维度变换。
  • 若使用 WxWx,则必须将 xx 转置为 [D, B]\text{[D, B]},破坏了批处理的直观性与编程习惯,增加代码复杂度与出错风险。

1.3. 框架的设计一致性

主流框架(如PyTorch、TensorFlow)中,nn.Linear(D_in, D_out) 的权重矩阵默认形状为 (D_out, D_in),即:

  • 每一行代表一个输出神经元的权重向量。
  • xWTxW^T 实质上是将输入样本与每个输出神经元做内积。

这种设计还支持:

  • SIMD指令集加速:连续内存访问允许CPU/GPU并行处理多个数据点。
  • 自动微分友好:梯度传播路径清晰,无需复杂转置逻辑。
  • 函数定义的时候,直接传参是输入维度和输出维度,便于代码开发的时候理解
  • 我们实际计算的时候数据是行读取的,x的一行为(D_in),我们希望W的一行也是(D_in)
  • 因此将矩阵的维度定义成(D_out, D_in),这样可以直接取出一行权重(维度D_in)
  • 因此理论上X与W进行相乘的时候,W需要进行转置(D_in, D_out)

2. slef Attention矩阵运算推理

2.1 单头注意力机制

假如有B个输入,每个输入为一系列的token,序列长度为L,每个token转变成D维的embedding,则输入X.shape = [B, L, D]

  1. 原始输入的文本,通过分词器得到tokenid
  2. tokenid通过nn.embedding得到对应的词向量x\mathbf{x}
  3. 词向量与WQW_QWKW_KWVW_V运算得到q\mathbf{q}k\mathbf{k}v\mathbf{v}三个向量
graph BT
    a1[token1]-->|tokenizer|b1[token id1]
    b1-->|nn.embedding|c1[x1]
    c1-->|Wq|d1[q1]
    c1-->|Wk|e1[k1]
    c1-->|Wv|f1[v1]
    
    a2[token2]-->|tokenizer|b2[token id2]
    b2-->|nn.embedding|c2[x2]
    c2-->|Wq|d2[q2]
    c2-->|Wk|e2[k2]
    c2-->|Wv|f2[v2]
    
    a3[token3]-->|tokenizer|b3[token id3]
    b3-->|nn.embedding|c3[x3]
    c3-->|Wq|d3[q3]
    c3-->|Wk|e3[k3]
    c3-->|Wv|f3[v3]

上一个小节讲了,在做计算的时候y=xWTy=xW^T,假如xx的维度为(1, D), 则:

  • WqW_q维度为[d_q, D],WqTW_q^T的维度就是[D,d_q],q\mathbf{q}的维度就是(1,D)x(D,d_q)->(1,d_q)
  • WkW_k维度为[d_k, D],WkTW_k^T的维度就是[D,d_k],k\mathbf{k}的维度就是(1,D)x(D,d_k)->(1,d_k)
  • WvW_v维度为[d_v, D],WvTW_v^T的维度就是[D,d_v],v\mathbf{v}的维度就是(1,D)x(D,d_v)->(1,d_v)

以token1为例,我们可以计算出

  • q1=x1WqTq_1 = x_1W_q^Tq1.shape=[1,d_q]
  • k1=x1WkTk_1 = x_1W_k^Tk1.shape=[1,d_k]
  • v1=x1WvTv_1 = x_1W_v^Tv1.shape=[1,d_v]

每个序列有L个token,因此就有L个向量,可以发现Q,K,V都是一个二维矩阵,例如Q有L行,每行有d_q列

  • Q=[q1,q2,...qL]Q = [q_1,q_2,...q_L],Q.shape = [L, d_q]
  • K=[k1,k2,...kL]K = [k_1,k_2,...k_L],K.shape = [L, d_k]
  • V=[v1,v2,...vL]V = [v_1,v_2,...v_L],V.shape = [L, d_v]

我们要计算token之间的相关性,例如token0与其他所有token的相关性,

  • a11=q1k1Ta_{11} = q_1 k_1^T
  • a12=q1k2Ta_{12} = q_1 k_2^T
  • a13=q1k3Ta_{13} = q_1 k_3^T

q和k的维度要相同,否则无法进行点积d_q=d_k

token0和其他所有token的权重组织一个向量

a1=[q1k1T,q1k2T,...,q1kLT]=q1[k1T,k2T...kLT]=q1KT\begin{aligned} \mathbf{a_1} &= [\mathbf{q_1} * \mathbf{k_1^T},\mathbf{q_1} * \mathbf{k_2^T}, ...,\mathbf{q_1} * \mathbf{k_L^T}] \\ &= \mathbf{q_1} * [\mathbf{k_1^T}, \mathbf{k_2^T} ... \mathbf{k_L^T}] \\ &= \mathbf{q_1} * \mathbf{K^T} \end{aligned}

所有的kiTk_i^T组织起来的矩阵就是KTK^T,同理可以得到a2和其他所有

  • a2=q2KT\mathbf{a_2} = \mathbf{q_2} * \mathbf{K^T}
  • aL=qLKT\mathbf{a_L} = \mathbf{q_L} * \mathbf{K^T}

我们讲这个注意力权重拼接成一个矩阵

A=[a1,a2,...aL]=[q1KT,q2KT,...qLKT]=[q1,q2,...qL]KT=QKT\begin{aligned} A &= [\mathbf{a1},\mathbf{a2},... \mathbf{aL}] \\ &= [\mathbf{q_1} * \mathbf{K^T}, \mathbf{q_2} * \mathbf{K^T}, ... \mathbf{q_L} * \mathbf{K^T}] \\ &= [\mathbf{q_1}, \mathbf{q_2}, ... \mathbf{q_L}] * \mathbf{K^T} \\ &= \mathbf{Q} * \mathbf{K^T} \end{aligned}

就得到了最终的注意力矩阵,也就是每个token对其他token的权重

  • A.shape = [L, L]
  • o = AV,o.shape = [L, o_v]

2.2 多头注意力机制

引入多头机制(Multi-Head Attention)的核心目的:让不同“子空间”学习不同的注意力模式。

1. 单个头内的运算 

假设我们有 H=8H=8 个头,总维度 D=512D=512,那么每个头的维度 dh=64d_{h}=64。  对于其中某一个头(比如 Head 1): 

  • Q1,K1Q_{1},K_{1} 的形状是 [L,64][L,64]

  • 注意力矩阵 A1A_{1}

    A1=softmax(Q1K1T64)A_{1}=\text{softmax}\left(\frac{Q_{1}K_{1}^{T}}{\sqrt{64}}\right)

注意:A1A_{1} 的形状依然是 [L,L][L,L]。它表示的是这个头视角下,Token 之间的相关性。

  • V1V_{1} 的形状是 [L,64][L,64]
  • 相乘 O1=A1V1O_{1}=A_{1}V_{1}

[L,L]×[L,64][L,64][L,L]\times [L,64]\rightarrow [L,64] 所以,Head 1 的输出是一个 [L,64][L,64] 的矩阵。 

2. 为什么说“维度缩小了”? 

在单头注意力中,VV[L,512][L,512],最后得到 [L,512][L,512]。 在多头中,每个头确实只产出了 [L,64][L,64] 的结果。 但是有 8 个头: 

  • Head 1 输出 O1:[L,64]O_{1}:[L,64]
  • Head 2 输出 O2:[L,64]O_{2}:[L,64]
  • ...
  • Head 8 输出 O8:[L,64]O_{8}:[L,64] 

3. 最终的“拼接” (Concatenation)

我们会把这 8 个头的输出在最后一个维度(特征维度)拼起来: Oconcat=[O1,O2,...,O8]O_{concat}=[O_{1},O_{2},...,O_{8}] [L,64] 拼接 8 次 [L,64×8][L,512][L,64]\text{\ 拼接\ }8\text{\ 次\ }\rightarrow [L,64\times 8]\rightarrow [L,512]

  1. 深度理解:为什么这样做?  可以把 VV512512 个维度想象成 512 个不同的特征属性。 
  • 单头: 用同一个注意力权重 AA 去加权所有的 512 个属性。
  • 多头:
    • 前 64 个属性用 A1A_{1} 来加权(可能关注语法);
    • 中间 64 个属性用 A2A_{2} 来加权(可能关注指代消解);
    • ...
    • 最后 64 个属性用 A8A_{8} 来加权。 

结论:

VV 的维度在每个头内部计算时确实是缩小的,但这种缩小是为了分工。每个头只负责处理一部分特征维度。最后拼在一起时,总的特征维度 DD 又回来了。