扩散模型论文阅读笔记

323 阅读7分钟

扩散模型(Diffusion Model)是继GAN、VAE后的一种生成式模型。《Denoising Diffusion Probabilistic Models》是UC Berkeley于2020年发表的一篇论文,其中引入扩散模型(Diffusion Model)进行图片生成。而目前在文生图领域比较流行的工具,如DALL-E2、Imagen、Stable Diffusion等,均是以上述扩散模型为基础,不断进行算法优化、迭代,取得了令人惊艳的效果。 本文是对于博客《What are Diffusion Models?》部分内容的翻译,主要是对于论文《Denoising Diffusion Probabilistic Models》中扩散模型的解读,后续会不断补充关于扩散模型算法迭代的论文阅读笔记。

扩散模型

正向扩散过程

令原始图片样本为x0\mathbf{x}_0,其满足分布x0q(x0)\mathbf{x}_0 \sim q(\mathbf{x}_0)。定义前向扩散过程,在TT步内,每步给样本增加一个小的满足高斯分布的噪声,从而产生TT个带噪声的样本x1,...,xT\mathbf{x}_1,...,\mathbf{x}_T,整个过程为一个一阶马尔可夫过程,xt\mathbf{x}_t只与xt1\mathbf{x}_{t-1}有关,可用以下公式表示:

q(xtxt1)=N(xt;1βtxt1,βtI)(1)q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I}) \tag{1}

其中,q(xtxt1)q(\mathbf{x}_t|\mathbf{x}_{t-1})表示给定xt1\mathbf{x}_{t-1}时,xt\mathbf{x}_{t}的条件概率,即均值为1βtxt1\sqrt{1-\beta_t}\mathbf{x}_{t-1}、方差为βtI\beta_t\mathbf{I}的高斯分布,集合{βt(0,1)}t=1T\{\beta_t \in (0,1)\}_{t=1}^{T}用于控制每步的噪声大小。进一步给定x0\mathbf{x}_0时,整个马尔科夫过程的条件概率为各步条件概率的连乘,可用以下公式表示:

q(x1:Tx0)=t=1Tq(xtxt1)(2)q(\mathbf{x}_{1:T}|\mathbf{x}_0)=\prod_{t=1}^{T}{q(\mathbf{x}_t|\mathbf{x}_{t-1})} \tag{2}

正向扩散过程可由图1从右到左的过程表示,其中x0\mathbf{x}_0为原始图片,随着每步增加噪声,图片逐渐变得模糊。

图1 从右到左为正向扩散过程

对于上述正向扩散过程,可进一步令αt=1βt\alpha_t=1-\beta_t,且αˉt=i=1tαi\bar{\alpha}_t=\prod_{i=1}^{t}{\alpha_i},则xt\mathbf{x}_t可用以下公式表示:

xt=αtxt1+1αtϵt1;where ϵt1,ϵt2,...N(0,I)=αtαt1xt2+1αtαt1ϵˉt2;where ϵˉt2 merge two Gaussians ().=...=αˉtx0+1αˉtϵ(3)\begin{aligned} \mathbf{x}_t&=\sqrt{\alpha_t}\mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1} &;\text{where }\epsilon_{t-1},\epsilon_{t-2},...\sim\mathcal{N}(0,\mathbf{I})\\ &=\sqrt{\alpha_t \alpha_{t-1}}\mathbf{x}_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}_{t-2} &;\text{where }\bar{\epsilon}_{t-2}\text{ merge two Gaussians }(*). \\ &=... \\ &=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon \end{aligned} \tag{3}

xt\mathbf{x}_t是在xt1\mathbf{x}_{t-1}的基础上,增加一个满足高斯分布的噪声ϵt1\epsilon_{t-1},循环递归,即xt\mathbf{x}_t是在x0\mathbf{x}_0的基础上,增加一个满足高斯分布的噪声ϵ\epsilon。这里使用了高斯分布的一个特性,即两个高斯分布合并后仍是一个高斯分布,例如分布N(0,σ12I)\mathcal{N}(0,\sigma_1^2\mathbf{I})N(0,σ22I)\mathcal{N}(0,\sigma_2^2\mathbf{I}),合并后的分布为N(0,(σ12+σ22)I)\mathcal{N}(0,(\sigma_1^2+\sigma_2^2)\mathbf{I})。因此,给定x0\mathbf{x}_0时,xt\mathbf{x}_t的条件概率为均值为αˉtx0\sqrt{\bar{\alpha}_t}\mathbf{x}_0、方差为(1αˉt)I(1-\bar{\alpha}_t)\mathbf{I}的高斯分布,可用以下公式表示:

q(xtx0)=N(xt;αˉtx0,(1αˉt)I)(4)q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_t;\sqrt{\bar{\alpha}_t}\mathbf{x}_0,(1-\bar{\alpha}_t)\mathbf{I}) \tag{4}

反向扩散过程

以上介绍了正向扩散过程,即图1从右到左,对原始图片逐步增加噪声,如果将过程逆向,即图1从左到右,那么就能从满足高斯分布的噪音xTN(0,I)\mathbf{x}_T \sim \mathcal{N}(0,\mathbf{I})逐步还原原始图片样本,这就是基于扩散模型生成图片的基本思想,即从xT\mathbf{x}_Tx0\mathbf{x}_0的每一步,在给定xt\mathbf{x}_t时,根据条件概率q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_t)采样求解xt1\mathbf{x}_{t-1},直至最终得到x0\mathbf{x}_0。 而当正向扩散过程每步增加的噪声很小时,反向扩散过程的条件概率q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_t)也可以认为满足高斯分布,但实际上,我们不能直接求解该条件概率,因为直接求解需要整体数据集合。除直接求解外,另一个方法是训练一个模型pθp_\theta近似预估上述条件概率,可用以下公式表示:

pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))(5)p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1};\mu_\theta(\mathbf{x}_t,t),\Sigma_\theta(\mathbf{x}_t,t)) \tag{5}
pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)(6)p_{\theta}(\mathbf{x}_{0:T})=p(\mathbf{x}_T)\prod_{t=1}^{T}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)} \tag{6}

进一步,论文将求解q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_t)等价于求解q(xt1xt,x0)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0),可用以下公式表示:

q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI)(7)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}(\mathbf{x}_t,\mathbf{x}_0),\tilde{\beta}_t\mathbf{I}) \tag{7}

基于贝叶斯公式,可将q(xt1xt,x0)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)作以下转化:

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)(8)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} \tag{8}

而上述公式中,q(xt1xt,x0)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)q(xt1x0)q(\mathbf{x}_{t-1}|\mathbf{x}_0)q(xtx0)q(\mathbf{x}_t|\mathbf{x}_0)通过正向扩散过程中的分析可知均满足高斯分布,代入高斯分布的公式进一步转化可得:

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)exp(12((xtαtxt1)2βt+(xt1αˉt1x0)21αˉt1(xtαˉtx0)21αˉt))=exp(12(xt22αtxtxt1+αtxt12βt+xt122αˉt1x0xt1+αˉt1x021αˉt1(xtαˉtx0)21αˉt))=exp(12((αtβt+11αˉt1)xt12(2αtβtxt+2αˉt11αˉt1x0)xt1+C(xt,x0)))(9)\begin{aligned} q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)&=q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} \\ &\propto\exp\left(-\frac{1}{2}(\frac{(\mathbf{x}_t-\sqrt{\alpha_t}\mathbf{x}_{t-1})^2}{\beta_t}+\frac{(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0)^2}{1-\bar{\alpha}_t})\right) \\ &=\exp\left(-\frac{1}{2}(\frac{\mathbf{x}_t^2-2\sqrt{\alpha_t}\mathbf{x}_t\mathbf{x}_{t-1}+\alpha_t\mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0\mathbf{x}_{t-1}+\bar{\alpha}_{t-1}\mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0)^2}{1-\bar{\alpha}_t})\right) \\ &=\exp\left(-\frac{1}{2}\left((\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}})\mathbf{x}_{t-1}^2-(\frac{2\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{x}_0)\mathbf{x}_{t-1}+C(\mathbf{x}_t,\mathbf{x}_0)\right)\right) \end{aligned} \tag{9}

其中C(xt,x0)C(\mathbf{x}_t,\mathbf{x}_0)是不包含中间状态xt1\mathbf{x}_{t-1}的函数。q(xt1xt,x0)q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)也满足高斯分布,对比公式7和公式9,并代入前向扩散过程中已定义的αt=1βt\alpha_t=1-\beta_t、且αˉt=i=1tαi\bar{\alpha}_t=\prod_{i=1}^{t}{\alpha_i},可将公式7的方差和均值表示为以下公式:

β~t=1/(αtβt+11αˉt1)=1/(αtαˉt+βtβt(1αˉt1))=1αˉt11αˉtβt(10)\tilde{\beta}_t=1/(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}})=1/(\frac{\alpha_t-\bar{\alpha}_t+\beta_t}{\beta_t(1-\bar{\alpha}_{t-1})})=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\cdot\beta_t \tag{10}
μ~t(xt,x0)=(αtβtxt+αˉt11αˉt1x0)/(αtβt+11αˉt1)=(αtβtxt+αˉt11αˉt1x0)1αˉt11αˉtβt=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉtx0(11)\begin{aligned} \tilde{\mu}_t(\mathbf{x}_t,\mathbf{x}_0)&=(\frac{\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{x}_0)/(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}) \\ &=(\frac{\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{x}_0)\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\cdot\beta_t \\ &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 \end{aligned} \tag{11}

在前向扩散过程中已推导xt\mathbf{x}_tx0\mathbf{x}_0的关系,即x0=1αˉt(xt1αˉtϵt)\mathbf{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_t),将其代入上面的公式,可进一步得到:

μ~t=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉt1αˉt(xt1αˉtϵt)=1αt(xt1αt1αˉtϵt)(12)\begin{aligned} \tilde{\mu}_t&=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_t) \\ &=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t) \end{aligned} \tag{12}

模型训练

目标函数

反向扩散过程中已提到通过训练一个模型pθp_\theta近似预估条件概率q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_t),对于该模型的训练,其目标函数即最小化以下交叉熵损失函数,也可以认为对于样本x0\mathbf{x}_0,在满足分布q(x0)q(x_0)下,最大化模型给出的概率pθ(x0)p_\theta(\mathbf{x}_0)的期望:

LCE=Eq(x0)logpθ(x0)=Eq(x0)log(pθ(x0:T)dx1:T)=Eq(x0)log(q(x1:Tx0)pθ(x0:T)q(x1:Tx0)dx1:T)=Eq(x0)log(Eq(x1:Tx0)pθ(x0:T)q(x1:Tx0))Eq(x0:T)logpθ(x0:T)q(x1:Tx0)=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=LVLB(13)\begin{aligned} L_{\text{CE}}&=-\mathbb{E}_{q(\mathbf{x}_0)}\log{p_\theta(\mathbf{x}_0)} \\ &=-\mathbb{E}_{q(\mathbf{x}_0)}\log{\left(\int{p_\theta(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}}\right)} \\ &=-\mathbb{E}_{q(\mathbf{x}_0)}\log{\left(\int{q(\mathbf{x}_{1:T}|\mathbf{x}_0)\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}}d\mathbf{x}_{1:T}\right)} \\ &=-\mathbb{E}_{q(\mathbf{x}_0)}\log{\left(\mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\right)} \\ &\le -\mathbb{E}_{q(\mathbf{x}_{0:T})}\log{\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}} \\ &=\mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log{\frac{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}}\right]=L_{\text{VLB}} \end{aligned} \tag{13}

上述公式转化中,先将pθ(x0)p_\theta(\mathbf{x}_0)转化为基于x1:T\mathbf{x}_{1:T}pθ(x0:T)p_\theta(\mathbf{x}_{0:T})的积分,再将pθ(x0:T)p_\theta(\mathbf{x}_{0:T})转化为q(x1:Tx0)pθ(x0:T)q(x1:Tx0)q(\mathbf{x}_{1:T}|\mathbf{x}_0)\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)},最后通过变分下届舍弃Eq(x1:Tx0)\mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_0)},将目标函数由最小化LCEL_\text{CE}转化为最小化LVLBL_{\text{VLB}}。 进一步对LVLBL_{\text{VLB}}进行拆解转化,如下所示:

LVLB=Eq(x0:T)[logq(x1:Tx0)pθ(x0:T)]=Eq[logt=1Tq(xtxt1)pθ(xT)t=1Tpθ(xt1xt)]=Eq[logpθ(xT)+t=1Tlogq(xtxt1)pθ(xt1xt)]=Eq[logpθ(xT)+t=2Tlogq(xtxt1)pθ(xt1xt)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlog(q(xt1xt,x0)pθ(xt1xt)q(xtx0)q(xt1x0))+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+t=2Tlogq(xtx0)q(xt1x0)+logq(x1x0)pθ(x0x1)]=Eq[logpθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)+logq(xTx0)q(x1x0)+logq(x1x0)pθ(x0x1)]=Eq[logq(xTx0)pθ(xT)+t=2Tlogq(xt1xt,x0)pθ(xt1xt)logpθ(x0x1)]=Eq[DKL(q(xTx0)pθ(xT))LT+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0](14)\begin{aligned} L_{\text{VLB}}&=\mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log{\frac{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}}\right] \\ &=\mathbb{E}_q\left[\log{\frac{\prod_{t=1}^T{q(\mathbf{x}_t|\mathbf{x}_{t-1})}}{p_\theta(\mathbf{x}_T)\prod_{t=1}^T{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}\right] \\ &=\mathbb{E}_q\left[-\log{p_\theta(\mathbf{x}_T)}+\sum_{t=1}^T{\log{\frac{q(\mathbf{x}_t|\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}\right] \\ &=\mathbb{E}_q\left[-\log{p_\theta(\mathbf{x}_T)}+\sum_{t=2}^T{\log{\frac{q(\mathbf{x}_t|\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}+\log{\frac{q(\mathbf{x}_1|\mathbf{x}_0)}{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}}\right] \\ &=\mathbb{E}_q\left[-\log{p_\theta(\mathbf{x}_T)}+\sum_{t=2}^T{\log{\left(\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}\cdot\frac{q(\mathbf{x}_t|\mathbf{x}_0)}{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}\right)}}+\log{\frac{q(\mathbf{x}_1|\mathbf{x}_0)}{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}}\right] \\ &=\mathbb{E}_q\left[-\log{p_\theta(\mathbf{x}_T)}+\sum_{t=2}^T{\log{\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}+\sum_{t=2}^T{\log{\frac{q(\mathbf{x}_t|\mathbf{x}_0)}{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}}}+\log{\frac{q(\mathbf{x}_1|\mathbf{x}_0)}{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}}\right] \\ &=\mathbb{E}_q\left[-\log{p_\theta(\mathbf{x}_T)}+\sum_{t=2}^T{\log{\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}+\log{\frac{q(\mathbf{x}_T|\mathbf{x}_0)}{q(\mathbf{x}_1|\mathbf{x}_0)}}+\log{\frac{q(\mathbf{x}_1|\mathbf{x}_0)}{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}}\right] \\ &=\mathbb{E}_q\left[\log{\frac{q(\mathbf{x}_T|\mathbf{x}_0)}{p_\theta(\mathbf{x}_T)}}+\sum_{t=2}^T{\log{\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}}}-\log{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}\right] \\ &=\mathbb{E}_q\left[\underbrace{D_{\text{KL}}(q(\mathbf{x}_T|\mathbf{x}_0)\parallel p_\theta(\mathbf{x}_T))}_{L_T}+\sum_{t=2}^T{\underbrace{D_{\text{KL}}(q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)\parallel p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t))}_{L_{t-1}}\underbrace{-\log{p_\theta(\mathbf{x}_0|\mathbf{x}_1)}}_{L_0}}\right] \end{aligned} \tag{14}

其中,先将q(x1:Tx0)q(\mathbf{x}_{1:T}|\mathbf{x_0})pθ(x0:T)p_\theta(\mathbf{x}_{0:T})分别表示为由每步q(xtxt1)q(\mathbf{x}_t|\mathbf{x}_{t-1})pθ(xt1xt)p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)连乘的形式,接着将LVLBL_{\text{VLB}}表示为每步logq(xtxt1)pθ(xt1xt)\log{\frac{q(\mathbf{x}_t|\mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1}|\mathbf{x_t})}}相加的形式,而形如logqpθ\log{\frac{q}{p_\theta}}的表达式在数学上被定义为KL散度,用来衡量真实概率分布qq和近似概率分布pθp_\theta之间的相似度,因此,公式14最终会可通过公式15来表达:

LVLB=LT+LT1++L0where LT=DKL(q(xTx0)pθ(xT))Lt=DKL(q(xtxt+1,x0)pθ(xtxt+1)) for 1tT1L0=logpθ(x0x1)(15)\begin{aligned} L_{\text{VLB}}&=L_T+L_{T-1}+\cdots+L_0 \\ \text{where }L_T&=D_{\text{KL}}(q(\mathbf{x}_T|\mathbf{x}_0)\parallel p_\theta(\mathbf{x}_T)) \\ L_t&=D_{\text{KL}}(q(\mathbf{x}_t|\mathbf{x}_{t+1},\mathbf{x}_0)\parallel p_\theta(\mathbf{x}_t|\mathbf{x}_{t+1}))\text{ for } 1\le t\le T-1 \\ L_0&=-\log{p_\theta(\mathbf{x}_0|\mathbf{x}_1)} \end{aligned} \tag{15}

其中,LTL_T是常量,LtL_t表示每步q(xtxt+1,x0)q(\mathbf{x}_t|\mathbf{x}_{t+1},\mathbf{x}_0)pθ(xtxt+1)p_\theta(\mathbf{x}_t|\mathbf{x}_{t+1})的KL散度。 反向扩散过程中已提到通过训练一个模型pθp_\theta近似预估条件概率q(xt1xt)q(\mathbf{x}_{t-1}|\mathbf{x}_t),其中如公式12所示,预测值μ~t=1αt(xt1αt1αˉtϵt)\tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t),而xt\mathbf{x}_t是输入值,因此可以改为在第tt步,针对输入xt\mathbf{x}_t预测噪声ϵt\epsilon_t

μθ(xt,t)=1αt(xt1αt1αˉtϵθ(xt,t))Thus xt1=N(xt1;1αt(xt1αt1αˉtϵθ(xt,t)),Σθ(xt,t))(16)\begin{aligned} \mu_\theta(\mathbf{x}_t,t)&=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\mathbf{x}_t,t)\right) \\ \text{Thus } \mathbf{x}_{t-1}&=\mathcal{N}(\mathbf{x}_{t-1};\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\mathbf{x}_t,t)\right),\Sigma_\theta(\mathbf{x}_t,t)) \end{aligned} \tag{16}

损失项LtL_t可以表示为最小化μ~\tilde{\mu}的差值,即最小二乘法上述的μ~t\tilde{\mu}_tμθ\mu_\theta

Lt=Ex0,ϵ[12Σθ(xt,t)22μ~t(xt,x0)μθ(xt,t)2]=Ex0,ϵ[12Σθ22 1αt(xt1αt1αˉtϵt)1αt(xt1αt1αˉtϵθ(xt,t))2]=Ex0,ϵ[(1αt)22αt(1αˉt)Σθ22ϵtϵθ(xt,t)2]=Ex0,ϵ[(1αt)22αt(1αˉt)Σθ22ϵtϵθ(αˉtx0+1αˉtϵt,t)2](17)\begin{aligned} L_t&=\mathbb{E}_{\mathbf{x}_{0,\epsilon}}\left[\frac{1}{2\parallel\Sigma_\theta(\mathbf{x}_t,t)\parallel_2^2}\parallel\tilde{\mu}_t(\mathbf{x}_t,\mathbf{x}_0)-\mu_\theta(\mathbf{x}_t,t)\parallel^2\right] \\ &=\mathbb{E}_{\mathbf{x}_{0,\epsilon}}\left[\frac{1}{2\parallel\Sigma_\theta\parallel_2^2}\parallel\ \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t\right) - \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\mathbf{x}_t,t)\right)\parallel^2\right] \\ &=\mathbb{E}_{\mathbf{x}_{0,\epsilon}}\left[\frac{(1-\alpha_t)^2}{2\alpha_t(1-\bar{\alpha}_t)\parallel\Sigma_\theta\parallel_2^2}\parallel\epsilon_t-\epsilon_\theta(\mathbf{x}_t,t)\parallel^2\right] \\ &=\mathbb{E}_{\mathbf{x}_{0,\epsilon}}\left[\frac{(1-\alpha_t)^2}{2\alpha_t(1-\bar{\alpha}_t)\parallel\Sigma_\theta\parallel_2^2}\parallel\epsilon_t-\epsilon_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t,t)\parallel^2\right] \end{aligned} \tag{17}

论文中进一步发现,训练扩散模型时,如果舍弃权重项,对损失函数项LtL_t进行简化,效果会更好,简化后的LtL_t如下:

Ltsimple=Et[1,T],x0,ϵt[ϵtϵθ(xt,t)2]=Et[1,T],x0,ϵt[ϵtϵθ(αˉtx0+1αˉtϵt,t)2](18)\begin{aligned} L_t^{\text{simple}}&=\mathbb{E}_{t\sim[1,T],\mathbf{x}_0,\epsilon_t}\left[\parallel\epsilon_t-\epsilon_\theta(\mathbf{x}_t,t)\parallel^2\right] \\ &=\mathbb{E}_{t\sim[1,T],\mathbf{x}_0,\epsilon_t}\left[\parallel\epsilon_t-\epsilon_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t,t)\parallel^2\right] \\ \end{aligned} \tag{18}

最终的损失函数为:

Lsimple=Ltsimple+C(19)L_{\text{simple}}=L_t^{\text{simple}}+C \tag{19}

其中,CC为依赖θ\theta的常量。

训练过程

图2 训练和采样算法

训练和采样算法如图2所示。在训练算法中,对于随机采样的样本x0\mathbf{x}_0和步数tt按高斯分布生成噪声ϵ\epsilon,然后基于噪声计算公式18的梯度,然后再基于梯度更新模型参数,循环上述步骤,直至收敛。在采样算法中,按高斯分布生成xT\mathbf{x}_T,然后分TT步,每步预测噪音ϵθ\epsilon_\theta,按公式12由xt\mathbf{x}_tϵθ\epsilon_\theta计算xt1\mathbf{x}_{t-1},直至最后计算得到x0\mathbf{x}_0,另外,除最后一步外,前序每步计算结果都会再增加一个满足高斯分布的随机噪声z\mathbf{z}