Transformer训练与生成背后的数学基础

16 阅读9分钟

Transformer的训练与生成能力,本质是以概率论为核心的序列建模框架,结合神经网络通用逼近定理可学习的注意力核函数基于梯度下降的经验风险最小化构建的完整数学体系。以下从核心数学假设、训练目标的支撑公式、核心公式的严谨推导三个维度展开。

一、Transformer训练的核心数学假设

Transformer的所有训练逻辑都建立在4个可证明的数学前提之上,是其能够拟合数据、生成信息的底层支撑:

  1. 自回归序列建模的链式法则假设:任意长度为 TT 的token序列 X=(x1,x2,,xT)X=(x_1,x_2,\dots,x_T) ,其联合概率分布可通过概率论链式法则,拆解为每个位置token基于前文上下文的条件概率乘积,彻底规避RNN类模型的串行计算缺陷,是生成式Transformer的底层数学逻辑。

  2. 神经网络通用逼近定理:包含足够多隐藏单元的前馈网络(Transformer的FFN层)与多头注意力模块,能够以任意精度逼近紧集上的任意连续序列到序列映射函数,证明了Transformer的拟合能力上限。

  3. 自注意力的核平滑建模假设:自注意力本质是可学习的内积核加权求和,能够建模序列中任意两个位置的全局依赖关系,突破了RNN类模型的长距离依赖瓶颈,其数学本质是通过相似度权重实现序列信息的最优聚合。

  4. 经验风险最小化与最大似然估计的一致性假设:Transformer的训练目标等价于在训练集上最大化数据的似然概率,即最小化模型分布与真实数据分布的差异,这是监督训练的核心数学准则。

二、支撑Transformer训练目标的核心数学模型与公式

Transformer的训练目标,本质是找到最优模型参数 θ\theta ,让模型分布尽可能逼近真实数据的分布,核心由以下公式体系支撑:

1. 序列建模的概率分解(生成逻辑的核心)

对于训练集中任意序列样本 X(i)=(x1(i),x2(i),,xTi(i))X^{(i)}=(x_1^{(i)},x_2^{(i)},\dots,x_{T_i}^{(i)}) ,依据概率论链式法则,其联合概率可拆解为自回归条件概率乘积形式,这也是Transformer生成文本的核心逻辑:

Pθ(X(i))=t=1TiPθ(xt(i)x1:t1(i))P_\theta(X^{(i)}) = \prod_{t=1}^{T_i} P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

其中:

  • Pθ(xt(i)x1:t1(i))P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) :Transformer建模的、给定前 t1t-1 个token时,第 tt 个位置token的条件概率;

  • x1:t1x_{1:t-1} :前 t1t-1 个token组成的上下文序列, t=1t=1 时为空序列。

2. 训练目标:最大似然估计(MLE)

最大似然估计是Transformer监督训练的核心数学准则,目标是找到最优参数 θ\theta^* ,让模型生成整个训练集样本的概率最大化:

θ=argmaxθL(θ;D)\theta^* = \arg\max_\theta \mathcal{L}(\theta; \mathcal{D})

其中 D={X(1),X(2),,X(N)}\mathcal{D}=\{X^{(1)},X^{(2)},\dots,X^{(N)}\} 代表完整训练集,对数似然函数 L(θ;D)\mathcal{L}(\theta; \mathcal{D}) 具体表达式为:

L(θ;D)=i=1NlogPθ(X(i))=i=1Nt=1TilogPθ(xt(i)x1:t1(i))\mathcal{L}(\theta; \mathcal{D}) = \sum_{i=1}^N \log P_\theta(X^{(i)}) = \sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

取对数是为了避免概率连乘导致的数值下溢,且对数函数为单调递增函数,最大化对数似然与最大化原始似然完全等价。

3. 损失函数:交叉熵损失(与MLE完全等价)

实际训练中,我们通过最小化损失函数实现参数优化,而最小化交叉熵损失等价于最大化对数似然

对于单个token预测任务,设真实标签为one-hot向量 yty_t (仅真实token位置为1,其余位置全为0),模型输出概率分布为 y^t=Pθ(x1:t1)\hat{y}_t=P_\theta(\cdot \mid x_{1:t-1}) ,单位置交叉熵损失公式为:

LCE(yt,y^t)=v=1Vyt(v)logy^t(v)\mathcal{L}_{CE}(y_t, \hat{y}_t) = -\sum_{v=1}^V y_t(v) \cdot \log \hat{y}_t(v)

其中 VV 为词表大小。由于 yty_t 是one-hot向量,公式可简化为:

LCE(yt,y^t)=logPθ(xtx1:t1)\mathcal{L}_{CE}(y_t, \hat{y}_t) = -\log P_\theta(x_t \mid x_{1:t-1})

整个训练集的总损失为所有样本所有位置的交叉熵之和:

Ltotal(θ;D)=i=1Nt=1TiLCE(yt(i),y^t(i))=L(θ;D)\mathcal{L}_{total}(\theta; \mathcal{D}) = \sum_{i=1}^N \sum_{t=1}^{T_i} \mathcal{L}_{CE}(y_t^{(i)}, \hat{y}_t^{(i)}) = -\mathcal{L}(\theta; \mathcal{D})

至此,训练目标从“最大化对数似然”转化为“最小化交叉熵损失”,这是Transformer训练的核心优化目标。

4. 核心组件:缩放点积注意力的数学公式

自注意力是Transformer的核心结构,其前向计算的数学公式为:

Attention(Q,K,V)=Softmax(QKdk)V\text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

其中:

  • Q=XWQ,  K=XWK,  V=XWVQ=XW_Q,\; K=XW_K,\; V=XW_V :输入序列 XX 经过可学习投影矩阵变换得到的查询、键、值矩阵, WQWKWVW_Q、W_K、W_V 是模型核心训练参数;

  • dk\sqrt{d_k} :缩放因子,避免内积值过大导致Softmax梯度消失;

  • Softmax:将注意力得分转化为和为1的权重,实现对值矩阵的加权求和。

5. 优化的数学基础:梯度下降与反向传播

Transformer的参数更新基于梯度下降法,核心是通过反向传播的链式法则,计算损失函数对每个参数的梯度 θLtotal\nabla_\theta \mathcal{L}_{total} ,并沿梯度反方向更新参数:

θt+1=θtηθLtotal(θt)\theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta \mathcal{L}_{total}(\theta_t)

其中 η\eta 为学习率,实际训练中通常使用AdamW优化器实现自适应梯度更新。

三、核心公式的详细推导

推导1:交叉熵损失与最大似然估计的等价性推导

这是Transformer训练目标最核心的数学证明,完整推导如下:

前置定义与前提

  1. 训练集 D={X(1),X(2),,X(N)}\mathcal{D}=\{X^{(1)},X^{(2)},\dots,X^{(N)}\} 为独立同分布(i.i.d.)样本,每个样本服从真实数据分布 Pdata(X)P_{\text{data}}(X)

  2. 模型分布 Pθ(X)P_\theta(X) 由Transformer参数化,训练核心目标是缩小 Pθ(X)P_\theta(X)Pdata(X)P_{\text{data}}(X) 的分布差异;

步骤1:写出似然函数

  1. 对数函数为严格单调递增函数,因此 argmaxθL(θ)=argmaxθlogL(θ)\arg\max_\theta \mathcal{L}(\theta) = \arg\max_\theta \log \mathcal{L}(\theta) ,最大化似然等价于最大化对数似然。

由于样本独立同分布,模型生成整个训练集的联合概率(似然函数)为各样本概率的乘积:

L(θ;D)=Pθ(D;θ)=i=1NPθ(X(i);θ)\mathcal{L}(\theta; \mathcal{D}) = P_\theta(\mathcal{D}; \theta) = \prod_{i=1}^N P_\theta(X^{(i)}; \theta)

步骤2:转化为对数似然函数

对似然函数取自然对数,将乘积转化为求和,避免数值下溢:

logL(θ;D)=log(i=1NPθ(X(i);θ))=i=1NlogPθ(X(i);θ)\log \mathcal{L}(\theta; \mathcal{D}) = \log \left( \prod_{i=1}^N P_\theta(X^{(i)}; \theta) \right) = \sum_{i=1}^N \log P_\theta(X^{(i)}; \theta)

MLE的目标转化为:

θMLE=argmaxθi=1NlogPθ(X(i);θ)\theta^*_{\text{MLE}} = \arg\max_\theta \sum_{i=1}^N \log P_\theta(X^{(i)}; \theta)

步骤3:代入自回归概率分解

根据概率论链式法则,序列的联合概率可分解为条件概率的乘积,取对数后得到:

logPθ(X(i))=log(t=1TiPθ(xt(i)x1:t1(i)))=t=1TilogPθ(xt(i)x1:t1(i))\log P_\theta(X^{(i)}) = \log \left( \prod_{t=1}^{T_i} P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) \right) = \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

代入对数似然函数,得到:

logL(θ;D)=i=1Nt=1TilogPθ(xt(i)x1:t1(i))\log \mathcal{L}(\theta; \mathcal{D}) = \sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

步骤4:转化为最小化负对数似然

机器学习中习惯最小化损失函数,因此将最大化对数似然转化为最小化负对数似然(NLL):

θ=argminθi=1Nt=1TilogPθ(xt(i)x1:t1(i))负对数似然NLL\theta^* = \arg\min_\theta \underbrace{ -\sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) }_{\text{负对数似然NLL}}

步骤5:证明NLL与交叉熵损失完全等价

首先给出离散分布的交叉熵定义:对于真实分布 PP 和模型分布 QQ ,交叉熵为

H(P,Q)=xP(x)logQ(x)H(P,Q) = -\sum_{x} P(x) \log Q(x)

其物理意义是用模型分布 QQ 编码真实分布 PP 的样本所需的平均比特数, H(P,Q)H(P,Q) 越小, QQ 越接近 PP

对于序列中第 tt 个位置的预测任务:

  • 真实分布 PdataP_{\text{data}} 是one-hot分布:仅在真实token xt(i)x_t^{(i)} 处取值为1,其余位置为0;

  • 模型分布 QQPθ(x1:t1(i))P_\theta(\cdot \mid x_{1:t-1}^{(i)}) ,即Transformer输出的token概率分布。

代入交叉熵定义,该位置的交叉熵为:

H(Pdata,Pθ)=v=1VPdata(vx1:t1(i))logPθ(vx1:t1(i))H(P_{\text{data}}, P_\theta) = -\sum_{v=1}^V P_{\text{data}}(v \mid x_{1:t-1}^{(i)}) \cdot \log P_\theta(v \mid x_{1:t-1}^{(i)})

由于 PdataP_{data} 是one-hot分布,仅 v=xt(i)v=x_t^{(i)} 时项非零,因此公式简化为:

H(Pdata,Pθ)=logPθ(xt(i)x1:t1(i))H(P_{\text{data}}, P_\theta) = -\log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

这恰好就是该位置的负对数似然。

步骤6:总损失的等价性结论

整个训练集的总交叉熵损失,就是所有样本所有位置的交叉熵之和:

LCE=i=1Nt=1TiH(Pdata(i,t),Pθ(i,t))=i=1Nt=1TilogPθ(xt(i)x1:t1(i))\mathcal{L}_{CE} = \sum_{i=1}^N \sum_{t=1}^{T_i} H(P_{\text{data}}^{(i,t)}, P_\theta^{(i,t)}) = -\sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)})

该式与负对数似然完全一致。

最终结论:Transformer训练中最小化交叉熵损失,等价于在训练集上执行最大似然估计,核心目标是让模型分布无限逼近真实数据的分布,这就是Transformer能够学习数据规律、生成符合语义信息的核心数学支撑。


推导2:缩放点积注意力的反向传播梯度推导

自注意力是Transformer的核心组件,其参数训练依赖反向传播的梯度计算,完整推导如下:

前置定义

简化符号便于推导,设缩放点积注意力输入为 Q,K,VRT×dQ,K,V \in \mathbb{R}^{T \times d} ,统一令 dk=dv=dd_k=d_v=d ,前向计算核心流程如下:

  1. 注意力得分矩阵: S=QKdRT×TS = \frac{Q K^\top}{\sqrt{d}} \in \mathbb{R}^{T \times T}

  2. 注意力权重矩阵: A=Softmax(S)RT×TA = \text{Softmax}(S) \in \mathbb{R}^{T \times T} (按行做Softmax归一化,每行元素和为1)

反向传播核心目标:已知损失对输出矩阵 OO 的梯度 dO=OLRT×ddO = \nabla_O \mathcal{L} \in \mathbb{R}^{T \times d} ,逐层推导损失对查询矩阵 QQ 、键矩阵 KK 、值矩阵 VV 的梯度 dQdKdVdQ、dK、dV ,用于后续参数更新。

  1. 注意力输出矩阵: O=AVRT×dO = A \cdot V \in \mathbb{R}^{T \times d}

步骤1:计算对 VV 的梯度 dVdV

O=AVO = A V ,根据矩阵求导法则,对 VV 求偏导:

LVj,k=i=1TLOi,kOi,kVj,k=i=1TdOi,kAi,j\frac{\partial \mathcal{L}}{\partial V_{j,k}} = \sum_{i=1}^T \frac{\partial \mathcal{L}}{\partial O_{i,k}} \cdot \frac{\partial O_{i,k}}{\partial V_{j,k}} = \sum_{i=1}^T dO_{i,k} \cdot A_{i,j}

写成矩阵形式为:

dV=AdOdV = A^\top \cdot dO

步骤2:计算对注意力权重 AA 的梯度 dAdA

同样由 O=AVO = A V ,对 AA 的元素求偏导:

LAi,j=k=1dLOi,kOi,kAi,j=k=1ddOi,kVj,k\frac{\partial \mathcal{L}}{\partial A_{i,j}} = \sum_{k=1}^d \frac{\partial \mathcal{L}}{\partial O_{i,k}} \cdot \frac{\partial O_{i,k}}{\partial A_{i,j}} = \sum_{k=1}^d dO_{i,k} \cdot V_{j,k}

写成矩阵形式为:

dA=dOVdA = dO \cdot V^\top

步骤3:计算对得分矩阵 SS 的梯度 dSdS

A=Softmax(S)A = \text{Softmax}(S) ,先推导Softmax函数的单元素偏导规则。对于任意行向量 sRTs \in \mathbb{R}^T ,Softmax归一化后输出 ai=esik=1Tesk=esiZa_i = \frac{e^{s_i}}{\sum_{k=1}^T e^{s_k}} = \frac{e^{s_i}}{Z} ,其中 Z=k=1TeskZ=\sum_{k=1}^T e^{s_k} 为该行的归一化常数,保证每行权重和为1。

对Softmax输出求导,分两种核心情况推导单元素偏微分:

  1. 同位置求导(i=j)aisi=esiZesiesiZ2=ai(1ai)\frac{\partial a_i}{\partial s_i} = \frac{e^{s_i} \cdot Z - e^{s_i} \cdot e^{s_i}}{Z^2} = a_i (1 - a_i)

  2. 异位置求导(i≠j)aisj=esiesjZ2=aiaj\frac{\partial a_i}{\partial s_j} = \frac{ - e^{s_i} \cdot e^{s_j} }{Z^2} = - a_i a_j

结合反向传播链式法则,损失对得分矩阵 SS 的梯度,需要通过损失对注意力权重 AA 的梯度 dAdA 递推得到,完整矩阵形式的梯度公式为:

dS=A(dArow_sum(dAA))dS = A \odot \left( dA - \text{row\_sum}(dA \odot A) \right)

公式符号说明: \odot 代表哈达玛积(矩阵对应元素逐点相乘), row_sum\text{row\_sum} 代表对矩阵每一行单独求和,再将结果广播至该行所有列,保持矩阵维度不变,这一步是为了适配Softmax行归一化的梯度特性,避免梯度计算偏差。

步骤4:计算查询矩阵Q与键矩阵K的梯度

得分矩阵 SS 由查询矩阵 QQ 和键矩阵 KK 通过缩放内积得到,核心公式为 S=QKdS = \frac{Q K^\top}{\sqrt{d}} ,基于矩阵微分链式法则,分别推导对 QQKK 的梯度:

  1. 对查询矩阵 QQ 求梯度:缩放因子保持不变,直接关联得分矩阵梯度与键矩阵转置,即 dQ=1ddSKdQ = \frac{1}{\sqrt{d}} \cdot dS \cdot K

  2. 对键矩阵 KK 求梯度:需要先对得分矩阵 SS 转置,再关联梯度与查询矩阵,即 dK=1ddSQdK = \frac{1}{\sqrt{d}} \cdot dS^\top \cdot Q

缩放点积注意力反向传播最终梯度汇总

整合所有梯度推导结果,得到自注意力层完整的反向传播梯度公式,所有公式均适配标准LaTeX渲染规则,无复杂嵌套语法,确保正常显示:

dV=AdOdV = A^\top \cdot dO

dA=dOVdA = dO \cdot V^\top

dS=A(dArow_sum(dAA))dS = A \odot \left( dA - \text{row\_sum}(dA \odot A) \right)

dQ=1ddSKdQ = \frac{1}{\sqrt{d}} \cdot dS \cdot K

dK=1ddSQdK = \frac{1}{\sqrt{d}} \cdot dS^\top \cdot Q

这套完整的梯度推导流程,是Transformer自注意力模块参数更新的核心数学依据,结合反向传播链式法则,可将梯度逐层回传至模型所有可学习参数(投影矩阵 WQ,WK,WVW_Q, W_K, W_V 、前馈网络权重等),最终通过梯度下降完成模型训练,让模型逐步学习序列数据的内在分布规律。