【13】变分自编码器(VAE)的原理介绍与pytorch实现

2,465 阅读13分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

1.VAE的设计思路

VAE作为一个生成模型,其基本思路是很容易理解的:把一堆真实样本通过编码器网络变换成一个理想的数据分布,然后这个数据分布再传递给一个解码器网络,得到一堆生成样本,生成样本与真实样本足够接近的话,就训练出了一个自编码器模型。那VAE(变分自编码器)就是在自编码器模型上做进一步变分处理,使得编码器的输出结果能对应到目标分布的均值和方差,如下图所示,具体的方法和思想在后文会介绍:

在这里插入图片描述

VAE最想解决的问题是什么?当然是如何构造编码器和解码器,使得图片能够编码成易于表示的形态,并且这一形态能够尽可能无损地解码回原真实图像。

这似乎听起来与PCA(主成分分析)有些相似,而PCA本身是用来做矩阵降维的:

在这里插入图片描述

如图,X本身是一个矩阵,通过一个变换W变成了一个低维矩阵c,因为这一过程是线性的,所以再通过一个变换就能还原出一个,现在我们要找到一种变换W,使得矩阵X与能够尽可能地一致,这就是PCA做的事情。在PCA中找这个变换W用到的方法是SVD(奇异值分解)算法,这是一个纯数学方法,不再细述,因为在VAE中不再需要使用SVD,直接用神经网络代替。

回顾上述介绍,我们会发现PCA与我们想要构造的自编码器的相似之处是在于,如果把矩阵X视作输入图像,W视作一个编码器,低维矩阵c视作图像的编码,然后和分别视作解码器和生成图像,PCA就变成了一个自编码器网络模型的雏形。

在这里插入图片描述

现在我们需要对这一雏形进行改进。首先一个最明显能改进的地方是用神经网络代替W变换和变换,就得到了如下Deep Auto-Encoder模型:

在这里插入图片描述

这一替换的明显好处是,引入了神经网络强大的拟合能力,使得编码(Code)的维度能够比原始图像(X)的维度低非常多。在一个手写数字图像的生成模型中,Deep Auto-Encoder能够把一个784维的向量(28*28图像)压缩到只有30维,并且解码回的图像具备清楚的辨认度(如下图)。

在这里插入图片描述

至此我们构造出了一个重构图像比较清晰的自编码模型,但是这并没有达到我们真正想要构造的生成模型的标准,因为,对于一个生成模型而言,解码器部分应该是单独能够提取出来的,并且对于在规定维度下任意采样的一个编码,都应该能通过解码器产生一张清晰且真实的图片。

我们先来分析一下现有模型无法达到这一标准的原因。

在这里插入图片描述

如上图所示,假设有两张训练图片,一张是全月图,一张是半月图,经过训练我们的自编码器模型已经能无损地还原这两张图片。接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。一个比较合理的解释是,因为编码和解码的过程使用了深度神经网络,这是一个非线性的变换过程,所以在code空间上点与点之间的迁移是非常没有规律的。

如何解决这个问题呢?我们可以引入噪声,使得图片的编码区域得到扩大,从而掩盖掉失真的空白编码点。

在这里插入图片描述

如上图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内,于是在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。

由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点。为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

在这里插入图片描述

那么上述的这种将图像编码由离散变为连续的方法,就是变分自编码的核心思想。下面就会介绍VAE的模型架构,以及解释VAE是如何实现上述构思的。


2.VAE的模型架构

在这里插入图片描述

上面这张图就是VAE的模型架构,我们先粗略地领会一下这个模型的设计思想。

在auto-encoder中,编码器是直接产生一个编码的,但是在VAE中,为了给编码添加合适的噪音,编码器会输出两个编码,一个是原有编码(m1,m2,m3),另外一个是控制噪音干扰程度的编码(σ1,σ2,σ3),第二个编码其实很好理解,就是为随机噪音码(e1,e2,e3)分配权重,然后加上exp(σi)的目的是为了保证这个分配的权重是个正值,最后将原编码与噪音编码相加,就得到了VAE在code层的输出结果(c1,c2,c3)。其它网络架构都与Deep Auto-encoder无异。

损失函数方面,除了必要的重构损失外,VAE还增添了一个损失函数(见上图Minimize2内容),这同样是必要的部分,因为如果不加的话,整个模型就会出现问题:为了保证生成图片的质量越高,编码器肯定希望噪音对自身生成图片的干扰越小,于是分配给噪音的权重越小,这样只需要将(σ1,σ2,σ3)赋为接近负无穷大的值就好了。所以,第二个损失函数就有限制编码器走这样极端路径的作用,这也从直观上就能看出来,exp(σi)-(1+σi)在σi=0处取得最小值,于是(σ1,σ2,σ3)就会避免被赋值为负无穷大。

上述我们只是粗略地理解了VAE的构造机理,但是还有一些更深的原理需要挖掘,例如第二个损失函数为何选用这样的表达式,以及VAE是否真的能实现我们的预期设想,即“图片能够编码成易于表示的形态,并且这一形态能够尽可能无损地解码回原真实图像,是否有相应的理论依据。

下面我们会从理论上深入地分析一下VAE的构造依据以及作用原理。


3.VAE的作用原理

我们知道,对于生成模型而言,主流的理论模型可以分为隐马尔可夫模型HMM、朴素贝叶斯模型NB和高斯混合模型GMM,而VAE的理论基础就是高斯混合模型。

什么是高斯混合模型呢?就是说,任何一个数据的分布,都可以看作是若干高斯分布的叠加。

在这里插入图片描述

如图所示,如果P(X)代表一种分布的话,存在一种拆分方法能让它表示成图中若干浅蓝色曲线对应的高斯分布的叠加。有意思的是,这种拆分方法已经证明出,当拆分的数量达到512时,其叠加的分布相对于原始分布而言,误差是非常非常小的了。

于是我们可以利用这一理论模型去考虑如何给数据进行编码。一种最直接的思路是,直接用每一组高斯分布的参数作为一个编码值实现编码。

在这里插入图片描述

如上图所示,m代表着编码维度上的编号,譬如实现一个512维的编码,m的取值范围就是1,2,3……512。m会服从于一个概率分布P(m)(多项式分布)。现在编码的对应关系是,每采样一个m,其对应到一个小的高斯分布N(μm,∑m),P(X)就可以等价为所有的这些高斯分布的叠加,即: 在这里插入图片描述 其中 在这里插入图片描述

上述的这种编码方式是非常简单粗暴的,它对应的是我们之前提到的离散的、有大量失真区域的编码方式。于是我们需要对目前的编码方式进行改进,使得它成为连续有效的编码。

在这里插入图片描述

现在我们的编码换成一个连续变量z,我们规定z服从正态分布N(0,1)(实际上并不一定要选N(0,1)用,其他的连续分布都是可行的)。每对于一个采样z,会有两个函数μ和σ,分别决定z对应到的高斯分布的均值和方差,然后在积分域上所有的高斯分布的累加就成为了原始分布P(X),即:

在这里插入图片描述

其中 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述 在这里插入图片描述


4.VAE的Pytorch实现

1)参考代码

model.py

# 定义变分自编码器VAE
class Variable_AutoEncoder(nn.Module):

    def __init__(self):

        super(Variable_AutoEncoder, self).__init__()

        # 定义编码器
        self.Encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU()
        )

        # 定义解码器
        self.Decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.fc_m = nn.Linear(64, 20)
        self.fc_sigma = nn.Linear(64, 20)

    def forward(self, input):

        code = input.view(input.size(0), -1)
        code = self.Encoder(code)

        # m, sigma = code.chunk(2, dim=1)
        m = self.fc_m(code)
        sigma = self.fc_sigma(code)

        e = torch.randn_like(sigma)

        c = torch.exp(sigma) * e + m
        # c = sigma * e + m

        output = self.Decoder(c)
        output = output.view(input.size(0), 1, 28, 28)

        return output, m, sigma

train.py

import torch
import torchvision
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from model import Auto_Encoder, Variable_AutoEncoder
import os

# 定义超参数
learning_rate = 1e-3
batch_size = 64
epochsize = 30
root = 'E:/学习/机器学习/数据集/MNIST'
sample_dir = "image5"

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# 图像相关处理操作
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5], std=[0.5])   # 一定要去掉这句,不需要Normalize操作
])

# 训练集下载
mnist_train = datasets.MNIST(root=root, train=True, transform=transform, download=False)
mnist_train = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)

# 测试集下载
mnist_test = datasets.MNIST(root=root, train=False, transform=transform, download=False)
mnist_test = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)

# image,_ = iter(mnist_test).next()
# print("image.shape:",image.shape)   # torch.Size([64, 1, 28, 28])

device = torch.device('cuda')

# 定义并导入网络结构
VAE = Variable_AutoEncoder()
VAE = VAE.to(device)
# VAE.load_state_dict(torch.load('VAE.ckpt'))

criteon = nn.MSELoss()
optimizer = optim.Adam(VAE.parameters(), lr=learning_rate)

print("start train...")
for epoch in range(epochsize):

    # 训练网络
    for batchidx, (realimage, _) in enumerate(mnist_train):

        realimage = realimage.to(device)

        # 生成假图像
        fakeimage, m, sigma = VAE(realimage)

        # 计算KL损失与MSE损失
        # KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2)) / (input.size(0)*28*28)
        # KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2))
        # 此公式是直接根据KL Div公式化简,两个分布分别是(0-1)分布与(m,sigma)分布
        # 最后根据像素点与样本批次求平均,既realimage.size(0)*28*28
        KLD = 0.5 * torch.sum(
            torch.pow(m, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (realimage.size(0)*28*28)

        # 计算均方差损失
        # MSE = criteon(fakeimage, realimage)
        MSE = torch.sum(torch.pow(fakeimage - realimage, 2)) / (realimage.size(0)*28*28)

        # 总的损失函数
        loss = MSE + KLD

        # 更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batchidx%300 == 0:
            print("epoch:{}/{}, batchidx:{}/{}, loss:{}, MSE:{}, KLD:{}"
                  .format(epoch, epochsize, batchidx, len(mnist_train), loss, MSE, KLD))

    # 生成图像
    realimage, _ = iter(mnist_test).next()
    realimage = realimage.to(device)
    fakeimage, _, _ = VAE(realimage)

    # 真假图像何必成一张
    image = torch.cat([realimage[0:32], fakeimage[0:32]], dim=0)

    # 保存图像
    save_image(image, os.path.join(sample_dir, 'image-{}.png'.format(epoch + 1)), nrow=8, normalize=True)

    torch.save(VAE.state_dict(), 'VAE.ckpt')
2)训练结果展示

在这里插入图片描述

Epoch1生成的图像

在这里插入图片描述

Epoch10生成的图像

在这里插入图片描述

Epoch30生成的图像

一开始的时候VAE的效果比AE的差,但是训练次数多了之后效果会变好。

3)生成结果展示

test.py

import torch
from torchvision.utils import save_image
from model import Variable_AutoEncoder
import os

epochsize = 20
batch_size = 64
sample_dir = "vae_val_result"
#seed = 0

#torch.manual_seed(seed)

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

VAE = Variable_AutoEncoder()
VAE.load_state_dict(torch.load('VAE.ckpt'))


for epoch in range(epochsize):

    z = torch.randn(batch_size, 20)
    # code = VAE.Encoder(z)
    fakeimage = VAE.Decoder(z)
    fakeimage = fakeimage.view(z.size(0), 1, 28, 28)

    # print("fakeimage.shape:",fakeimage.shape)

    save_image(fakeimage, os.path.join(sample_dir, 'image-{}.png'.format(epoch + 1)), nrow=8, normalize=True)
    print("generate success:",epoch)

我们根据之前训练好的VAE网络来随机的从0-1分布中sample一些噪声出来到Decoder中,结果如下,可以看见能够正常是随机生成图像。

在这里插入图片描述

但是,就结果而言,生成的图像比我们之前利用GAN生成出来的图像要模糊。


5.实现VAE中出现的问题

  • 问题1:训练结果中长时间生成的图像只有少量白点或者全部都是同一个模糊图像

在这里插入图片描述

损失函数出现的问题,在这次实验总,根据随机从0-1高斯分布sample出来的噪声与网络生成出来的N(μ,σ),这两者的KL分布直接计算出来,而不是单纯的使用paper给出的公式。

在这里插入图片描述

计算结果:

在这里插入图片描述

代码表示:

# KLD = torch.sum(torch.exp(sigma) - (1 + sigma) + torch.pow(m, 2))

KLD = 0.5 * torch.sum(
    torch.pow(m, 2) +
    torch.pow(sigma, 2) -
    torch.log(1e-8 + torch.pow(sigma, 2)) - 1
) / (realimage.size(0)*28*28)
  • 问题2:训练结过中开始时生成一个近似成功的图像,但是长期训练后只能得到一些类似的结果 在这里插入图片描述

此问题的原因是使用了Normalize操作数据集,不使用即可

transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5], std=[0.5])   
])
  • 问题3:生成结果时出现了全部为0的状况

对于这种情况,主要是VAE的结构设计得不对,以下为原始结构。

# 定义变分自编码器VAE
class Variable_AutoEncoder(nn.Module):

    def __init__(self):

        super(Variable_AutoEncoder, self).__init__()

        # 定义编码器
        self.Encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )

        # 定义解码器
        self.Decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, input):

        code = input.view(input.size(0), -1)
        code = self.Encoder(code)

        m, sigma = code.chunk(2, dim=1)
        e = torch.randn_like(sigma)
        c = torch.exp(sigma) * e + m

        output = self.Decoder(c)
        output = output.view(input.size(0), 1, 28, 28)

        return output, m, sigma

对于这种结构,可以正常训练出一个结果,见下图。

在这里插入图片描述

但是当使用此训练好的网络来随机输出一些服从0-1分布噪声进Decoder的时候,结果全为0。

在这里插入图片描述

个人猜测是code的维度不够,不足以保留过多的信息,随后见上述的参考代码model.py,改进过之后便可以正常的生成图像。


参考资料: 本文的理论部分摘抄至:www.gwylab.com/note-vae.ht… 讲解的十分的详细。