扩散模型之噪声条件分数网络详解

388 阅读5分钟

本文所涉及所有资源均在传知代码平台可获取。

概述

除了常用的DDPM之外,还有一种新的扩散模型,即Noise Conditional Score Networks,由Yang Song等人在2019年提出。与传统Diffusion model不同,该模型使用了另一种方法来处理生成过程中的噪声,从而提高了模型的生成效果。Noise Conditional Score Networks被视为Diffusion model的一个分支,目前也受到了广泛的关注和研究。总之,Diffusion model及其相关扩展已经成为图像生成领域的研究热点,并在很多应用场景中取得了显著的成功。

演示效果

● 小等级噪声

● 大等级噪声

噪声条件分数网络

噪声条件分数网络(Noise Conditional Score Networks,以下简称NCSN)是一种典型的隐式生成模型,该模型的核心在于如何计算分数(score)以及通过分数进行采样。Score是对数概率密度关于输入变量的梯度,即

NCSN成功的将分数融入进了生成模型中,其采样方法为Langevin dynamics,这是一种受统计物理中布朗运动模型启发的采样方法。

Langevin dynamics可以仅利用分数从p(x)p(x)中采样。给定一个特定的步长ϵ>0ϵ>0和一个初始数据样本x0∼π(x)x0∼π(x),Langevin dynamics方法可以不断迭代,使得xtx t服从的分布接近p(x)p(x)

其中zt∼N(0,I)z t∼N(0,I).

基于Langevin dynamics采样方法,我们可以通过设置一个参数化分数网络sθ(x)s θ(x)对数据样本的分数进行学习,从而实现从x0x 0中学习并采样类似结构数据的目的,且不需要学习数据样本的概率密度。这个过程被成为分数匹配(score matching),我们需要优化的损失函数为

直接计算这个损失函数是非常困难的,因为需要对数据变量xx的每一个分量进行反向传播。对于较复杂的问题,这几乎是不可能的。为了解决这个问题,可以通过对数据点添加高斯噪声,将损失函数转换为

这个公式实际上涉及到了机器学习中非常有趣的一个领域——score match.

然而,这个结果最后匹配的是加噪后数据分布qσ(x~)q σ(x)的分数。因此,只有在噪声比较小的时候采集的样本才符合q(x0)q(x0).

这个分数网络在对数据样本的分数进行学习的时候会遇到许多的问题。其中最主要的问题之一是,我们的原始样本数据无法对全空间进行足够的覆盖,会使得低密度区域不会有足够的样本量来正确的评估分数。实际上,高密度区域总是非常少的。当我们在利用Langevin dynamics方法采样时,初始值大概率会落在低密度区域,这会导致最终的采样结果不准确。

解决上述问题的方法是在原始数据样本x0x0中添加高斯噪声,使得加噪后的数据x~x能尽可能在空间中广泛的分布,以填补空间中的低密度区域。考虑到这个过程需要的噪声等级很大,而前面要求最后结果正确的噪声等级却必须很小。因此,可以设置多个噪声等级{σi}i=1L{σ i}i=1L,且满足

在这种情况下,我们选择通过分数网络sθ(x,σ)=∇xlog⁡qσ(x)s θ(x,σ)=∇xlogq σ(x),就可以使用对应的损失函数来学习在每个噪声等级下的加噪样本的分数。

其中x∼qσ(x∣x0)=N(x~;x0,σ2I)xq σ(xx 0)=N(x;x 0,σ2I),令ϵ∼N(0,I)ϵ∼N(0,I),则

在此基础上,我们可以将损失函数化解为

基于这个损失函数以及Langevin Dynamics方法就是我们熟知的NCSN了。采样过程可以理解为首先从最大的噪声等级σ1σ1中抽样,然后以此结果为初始值从σ2σ2下的噪声等级中采样,使得样本的分布趋于σ2σ2下的分布,不断迭代下最终将趋于最小的噪声等级σLσ L的分布,因此只要σLσ L设置的比较小就可以得到符合x0x0的分布。

这种采样方式就相当于有一个力一样,SCore就是力,

实现

采样过程

首先是这个模型的采样过程,根据Langevin采样过程,我们的核心代码为

@torch.no_grad()
    def sample(self, n):
        x = torch.rand(n, 2) * 16 - 8
        for i in range( len(self.sigma) ):
            alpha_i = self.epsilion * self.sigma[i]**2 / self.sigma_min**2
            for t in range(0, self.T):
                sigma_torch_list = self.expand_sigma( self.sigma[i], n )
                score = self.net( x, sigma_torch_list )
                z = torch.randn_like(x)
                x += score * alpha_i / 2 + np.sqrt(alpha_i) * z
            # print(x)

        return x.numpy()

遍历每个噪声等级,在每个噪声等级的循环步下,在前一个噪声等级的采样结果基础上继续采样郎之万采样。

分数网络

分数网络采样比较简单的受噪声等级控制的全连接神经网络

class ScoreNet(nn.Module):
    def __init__(self):
        super(ScoreNet, self).__init__()
        num_units = 32
        self.x_embeddings = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        self.sigma_embeddings = nn.ModuleList(
            [
                nn.Linear(1, num_units),
                nn.Linear(1, num_units),
                nn.Linear(1, num_units),
            ]
        )

    def forward(self, x, sigma):
        for idx, embedding_layer in enumerate(self.sigma_embeddings):
            sigma_embedding = embedding_layer(sigma)
            # print(t_embedding.shape)
            x = self.x_embeddings[2 * idx](x)
            x += sigma_embedding
            x = self.x_embeddings[2 * idx + 1](x)

        x = self.x_embeddings[-1](x)
        return x

训练过程

其训练过程与DDPM非常一致,随机选取一个噪声等级,根据损失函数进行训练

sigma = model.random_sigma(data)
    loss = model.loss(data, sigma)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.net.parameters(), 1.)
    optimizer.step()
    return loss

而其中的损失函数的实现核心代码为

def loss(self, x, sigma):
        # x.size == sigma.size

        sigma = sigma.unsqueeze(dim = -1)
        z = torch.randn_like(x)
        perturbed_x = x + sigma * z
        score = self.net( perturbed_x, sigma )
        return torch.square( sigma * score + z ).mean() / 2

通过这种训练过程和采样过程就能够生成出我们想要的样本了。

训练结果

采样结果可看演示效果。

使用方式

● 终端输入python main.py

安装依赖

● Python 3.11.4

● torch 2.0.1

感觉不错,点击我,立即使用