把扩散模型迁移到文本领域 | 读论文

524 阅读3分钟

模型设计

  • First, we must define an embedding function that maps discrete text into a continuous space.
  • Second, we require a rounding method to map vectors in embedding space back to words.

image.png

端到端训练

为了将离散文本应用到连续扩散模型上,这里设置了一个embedding函数EMB(wi)\mathrm E_{\mathrm{MB}}(w_i)将每个词映射到词向量Rd\mathbb{R}^d。对于长度为nn的序列w\mathbf w

EMB(w)=[EMB(w1),,EMB(wn)]Rnd\mathrm E_{\mathrm{MB}}(\mathbf{w})=\left[\mathrm E_{\mathrm{MB}}\left(w_1\right), \ldots, \mathrm E_{\mathrm{MB}}\left(w_n\right)\right] \in \mathbb{R}^{n d}

作者在实验中发现,使用预训练的word embedding效果不如使用随机高斯噪声初始化之后训练出来的embedding,所以加了这样一个网络,实现离散词序列w\mathbf wx0\mathbf x_0的马尔科夫转换,参数化为:

qϕ(x0w)=N(EMB(w),σ0I)q_\phi\left(\mathbf{x}_0 \mid \mathbf{w}\right)=\mathcal{N} (\mathrm E_{\mathrm{MB}} (\mathbf{w}), \sigma_0 I)

与之对应的反向过程添加了一个可训练的舍入步骤,参数化为:

pθ(wx0)=i=1npθ(wixi)p_\theta\left(\mathbf{w} \mid \mathbf{x}_0\right)=\prod_{i=1}^n p_\theta\left(w_i \mid x_i\right)

其中pθ(wixi)p_\theta\left(w_i \mid x_i\right)是一个softmax分布。

因为增加了一个嵌入步骤和一个舍入步骤,添加的这两个网络是和扩散模型进行联合训练的,因此要对训练目标函数改进:

Lvlbe2e(w)=Eqϕ(x0w)[Lvlb(x0)+logqϕ(x0w)logpθ(wx0)]],Lsimple e2e(w)=Eqϕ(x0:Tw)[Lsimple (x0)+EMB(w)μθ(x1,1)2logpθ(wx0)].\begin{aligned} \mathcal{L}_{\mathrm{vlb}}^{\mathrm{e2e}}(\mathbf{w}) & \left.=\underset{q_\phi\left(\mathbf{x}_0 \mid \mathbf{w}\right)}{\mathbb{E}}\left[\mathcal{L}_{\mathrm{vlb}}\left(\mathbf{x}_0\right)+\log q_\phi\left(\mathbf{x}_0 \mid \mathbf{w}\right)-\log p_\theta\left(\mathbf{w} \mid \mathbf{x}_0\right)\right]\right], \\ \\ \mathcal{L}_{\text {simple }}^{\mathrm{e2e}}(\mathbf{w}) & =\underset{q_\phi\left(\mathbf{x}_{0: T} \mid \mathbf{w}\right)}{\mathbb{E}}\left[\mathcal{L}_{\text {simple }}\left(\mathbf{x}_0\right)+\left\|\mathrm E_{\mathrm{MB}} (\mathbf{w})-\mu_\theta\left(\mathbf{x}_1, 1\right)\right\|^2-\log p_\theta\left(\mathbf{w} \mid \mathbf{x}_0\right)\right] .\end{aligned}

Lsimplee2e(w)\mathcal{L}_{\mathrm{simple}}^{\mathrm{e} 2 \mathrm{e}}(\mathbf{w})Lvlbe2e(w)\mathcal{L}_{\mathrm{vlb}}^{\mathrm{e} 2 \mathrm{e}}(\mathbf{w})的简化版,是根据DDPM设计的目标函数。

DDPM的目标函数:

Lvlb(x0)=Eq(x1:Tx0)[logq(xTx0)pθ(xT)+t=2Tlogq(xt1x0,xt)pθ(xt1xt)logpθ(x0x1)]\mathcal{L}_{\mathrm{vlb}}\left(\mathbf{x}_0\right)=\underset{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{\mathbb{E}}\left[\log \frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_T\right)}+\sum_{t=2}^T \log \frac{q\left(\mathrm{x}_{t-1} \mid \mathbf{x}_0, \mathbf{x}_t\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right]
Lsimple (x0)=t=1TEq(xtx0)μθ(xt,t)μ^(xt,x0)2\mathcal{L}_{\text {simple }}\left(\mathbf{x}_0\right)=\sum_{t=1}^T \underset{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}{\mathbb{E}}\left\|\mu_\theta\left(\mathbf{x}_t, t\right)-\hat{\mu}\left(\mathbf{x}_t, \mathbf{x}_0\right)\right\|^2

可视化之后可以看到学到的embedding是有意义的。

image.png

减少舍入误差

embedding是将离散文本映射到连续的x0\mathbf x_0上,那与之对应的反向过程就应该是将模型预测出来的x0\mathbf x_0转换回离散的文本。

舍入步骤是根据argmaxpθ(wx0)=i=1npθ(wixi)\operatorname{argmax} p_\theta\left(\mathbf{w} \mid \mathbf{x}_0\right)=\prod_{i=1}^n p_\theta\left(w_i \mid x_i\right)选择每个位置上可能性最大的词。理想状态下通过这个舍入步骤就可以将模型输出的x0\mathbf x_0映射回离散文本,因为去噪步骤应该能让x0\mathbf x_0恰好回到某个单词的embedding上。

然而实际情况是不行的,模型输出是不会精确到某个单词的embedding。

作者认为造成上述问题的原因,是在目标函数中Lsimplee2e(x0)\mathcal{L}_{\mathrm{simple}}^{\mathrm{e} 2 \mathrm{e}}(\mathbf{\mathbf x_0})x0\mathbf x_0结构的建模不够重视。

Lsimple (x0)=t=1TExtμθ(xt,t)μ^(xt,x0)2\mathcal{L}_{\text {simple }}(\mathbf{x_0}) = \sum_{t=1}^T \mathbb{E}_{\mathbf{x}_t}\left\|\mu_\theta\left(\mathrm{x}_t, t\right)-\hat{\mu}\left(\mathrm{x}_t, \mathbf{x}_0\right)\right\|^2

其中的μθ(xt,t)\mu_\theta\left(\mathrm{x}_t, t\right)网络直接去预测时间步tt去噪的pθ(xt1xt)p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)的均值。x0\mathbf x_0对单词的约束只会出现在tt接近于0的项中。因此需要对目标函数进行调整,去强调x0\mathbf x_0

作者对Lsimple \mathcal{L}_{\text {simple }}进行调整,强调模型在目标函数的每一项中都去显式地建模x0\mathbf x_0。 作者推导出一个类似于Lsimple \mathcal{L}_{\text {simple }}的使用x0\mathbf x_0参数化的公式:

Lx0-simple e2e(x0)=t=1TExtfθ(xt,t)x02\mathcal{L}_{\mathbf{x}_0 \text {-simple }}^{\mathrm{e2e}}\left(\mathbf{x}_0\right)=\sum_{t=1}^T \mathbb{E}_{\mathbf{x}_t}\left\|f_\theta\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\right\|^2

其中网络fθ(xt,t)f_\theta\left(\mathbf{x}_t, t\right)直接去预测x0\mathbf x_0,这样会让神经网络去预测每一项的x0\mathbf x_0。 使用修改之后的目标去训练模型,模型很快就会学到让x0\mathbf x_0能回到以word embedding为中心。

clamping trick

作者将其称为clamping trick,在clamping trick中,模型将xt\mathbf x_t降噪为xt1\mathbf x_{t-1}的生成过程:

  1. 通过fθ(xt,t)f_\theta (\mathbf x_t,t)估计出一个x0\mathbf x_0

  2. 在这个估计的条件相爱对xt1\mathbf x_{t-1}进行采样

  3. xt1=αˉfθ(xt,t)+1αˉϵ\mathbf{x}_{t-1}=\sqrt{\bar{\alpha}} f_\theta\left(\mathbf{x}_t, t\right)+\sqrt{1-\bar{\alpha}} \epsilon 其中αˉt=s=0t(1βs)\bar{\alpha}_t=\prod_{s=0}^t\left(1-\beta_s\right)ϵN(0,I)\epsilon \sim \mathcal{N}(0, I)

    这一步就是用的DDPM中的那个,因为都是高斯核,所以xt\mathbf x_t可以由x0\mathbf x_0一步得到。

    xt=αtxt1+1αtϵt1=αt(αt1xt2+1αt1ϵt2)+1αtϵt1=αtαt1xt2+αtαtαt1ϵt2+1αtϵt1=αtαt1xt2+αtαtαt12+1αt2ϵt2=αtαt1xt2+αtαtαt1+1αtϵt2=αtαt1xt2+1αtαt1ϵt2==i=1tαix0+1i=1tαiϵ0=αˉtx0+1αˉtϵ0\begin{aligned} \boldsymbol{x}_t & =\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} \epsilon_{t-1}^* \\ & =\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} \epsilon_{t-2}^*\right)+\sqrt{1-\alpha_t} \epsilon_{t-1}^* \\ & =\sqrt{\alpha_t \alpha_{t-1}} x_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}} \epsilon_{t-2}^*+\sqrt{1-\alpha_t} \epsilon_{t-1}^* \\ & =\sqrt{\alpha_t \alpha_{t-1}} x_{t-2}+\sqrt{{\sqrt{\alpha_t-\alpha_t \alpha_{t-1}}}^2+\sqrt{1-\alpha_t^2}} \epsilon_{t-2} \\ & =\sqrt{\alpha_t \alpha_{t-1}} x_{t-2}+\sqrt{\alpha_t-\alpha_t \alpha_{t-1}+1-\alpha_t} \epsilon_{t-2} \\ & =\sqrt{\alpha_t \alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \epsilon_{t-2} \\ & =\ldots \\ & =\sqrt{\prod_{i=1}^t \alpha_i} x_0+\sqrt{1-\prod_{i=1}^t \alpha_i} \epsilon_0 \\ & =\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_0\end{aligned}

clamping trick 会将网络fθ(xt,t)f_\theta (\mathbf x_t,t)的预测结果映射到接近的word embedding 序列上。 现在采样步骤就变为了:

xt1=αˉClamp(fθ(xt,t))+1αˉϵ\mathbf{x}_{t-1}=\sqrt{\bar{\alpha}} \cdot \operatorname{Clamp}\left(f_\theta\left(\mathbf{x}_t, t\right)\right)+\sqrt{1-\bar{\alpha}} \epsilon

clamping trick 迫使扩散模型降噪过程中每一步都去计算一个word embedding,使向量预测更为准确,以此减少舍入误差。

作者在这里提示将开始使用clamping trick的起始位置设置为超参数。具体原因看论文P5

论文信息

image.png

论文地址:[2205.14217] Diffusion-LM Improves Controllable Text Generation (arxiv.org)

代码地址:XiangLi1999/Diffusion-LM: Diffusion-LM (github.com)


本文正在参加「金石计划」