多模态论文笔记——dVAE(DALL·E的核心部件)

535 阅读1分钟

大家好,这里是好评笔记,本文为试读,查看全文请移步公主号:Goodnote。详细介绍DALL·E的核心部件之一——dVAE,在VQ-VAE的基础上使用Gumbel-Softmax实现采样,用于图像生成。

6.png

@[toc]


前情提要

VAE

深度学习——AE、VAE

AE 和 VAE 在结构、目的和优化方式上存在多个重要区别:

特性AEVAE
编码器输出固定的低维向量(确定性的表示)隐藏变量的均值 μ\mu 和方差 σ2\sigma^2(表示潜在空间的分布
潜在空间没有明确的分布假设假设潜在空间遵循某种概率分布(通常为正态分布)
解码器从固定低维向量生成输入数据的近似使用 重参数化技巧 从潜在变量的分布中采样,再通过解码器生成输入数据的近似
损失函数仅有重构损失,最小化输入数据与重构数据的差异重构损失 + KL 散度,既保证数据重构效果,又保证潜在空间的分布合理
目的数据降维、特征提取或数据去噪生成新数据(如图像生成、文本生成等),同时保留对输入数据的重构能力
生成新数据的能力无法直接生成新数据可以通过在潜在空间中采样生成与训练数据相似的全新数据

VQ-VAE

万字长文解读深度学习——VQ-VAE和VQ-VAE-2

VAE vs. VQ-VAE

区别

需要明白的是,VAE的主要作用是生成数据;而VQ-VAE的主要作用是压缩、重建数据(与AE一样),如果需要生成新数据,则需要结合 PixelCNN 等生成模型。

  • VAE 的核心思想是通过编码器学习潜在变量的连续分布(通常是高斯分布,非离散),并从该分布中采样潜在变量 z,然后由解码器生成数据
  • VQ-VAE模型的目标是学习如何将输入数据编码为离散潜在表示,并通过解码器重建输入数据,量化过程通过最近邻搜索确定嵌入向量,是一个确定性操作,这一过程并不涉及离散采样
  • 如果需要生成新数据,则需要在离散潜在空间中随机采样嵌入向量。VQ-VAE 本身没有内置采样机制,通常需要结合 PixelCNN 或PixelSNAIL 等模型来完成离散采样。

不可导问题及解决方法

  • VAE 通过连续潜在空间重参数化技巧避免了采样操作的不可导问题。
  • VQ-VAE潜在空间是离散的,量化过程是不可导的,通过在最近邻搜索中使用停止梯度传播来解决不可导问题(dVAE中引入Gumbel-Softmax 替代停止梯度)。原本的VQ-VAE不涉及生成数据,所以不需要采样,如果需要生成数据,则需要结合 PixelCNN 等生成模型。

VAE 和 VQ-VAE 的不可导问题及解决方法:

特性VAEVQ-VAE
潜在空间连续空间离散空间
不可导问题来源采样操作不可导最近邻搜索不可导
解决方法重参数化技巧停止梯度传播
实现方式分离随机性,直接优化 μ,σ\mu, \sigma解码器损失绕过量化过程优化编码器
适用场景平滑采样和连续潜在变量建模离散特征学习和高分辨率生成

重新参数化梯度是一种常用于训练变分自编码器(VAE)等生成模型的技术。它依赖于连续分布的可分解性,而 VQ-VAE 的离散分布(通过 one-hot 编码或 Codebook 表示)无法通过这种方式重新参数化。


VAE 的不可导问题及解决方法

不可导问题

  • 在训练VAE时,我们希望从一个分布中采样出一些隐变量,以生成模型的输出。然而,由于采样操作是不可导的,因此通常不能直接对采样操作求梯度。为了解决这个问题,我们可以使用重新参数化技术。
  • 在 VAE 中,潜在变量 zz 是通过从编码器输出的分布 q(zx)q(z|x) 中采样得到的: zN(μ,σ2) z \sim \mathcal{N}(\mu, \sigma^2)
    • μ\muσ\sigma 是编码器生成的分布参数。
    • 采样操作引入随机性,而随机采样本身不可导,因此无法通过梯度反向传播来优化编码器参数。

解决方法:重参数化技巧

  • 重新参数化技术的基本思想是,将采样过程拆分为两步:首先从一个固定的分布中采样一些固定的随机变量,然后通过一个确定的函数将这些随机变量转换为我们所需的随机变量。这样,我们就可以对这个确定的函数求导,从而能够计算出采样操作对于损失函数的梯度。

VAE 通过 重参数化技巧(Reparameterization Trick) 将采样过程分解为可导部分和不可导部分:

  1. 分离随机性:
    • 采样公式改写为:
      z=μ+σϵ,ϵN(0,1)z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)
    • ϵ\epsilon 是标准正态分布的随机噪声,采样只发生在 ϵ\epsilon 中。
    • μ\muσ\sigma 是由编码器网络直接输出的,可导。
  2. 作用:
    • 随机性仅由不可导的 ϵ\epsilon 控制,而 μ\muσ\sigma 的梯度可以正常计算,从而实现端到端训练。

VQ-VAE 的不可导问题及解决方法

不可导问题

  • 在 VQ-VAE 中,潜在变量是通过将编码器输出 ze(x)z_e(x) 映射到最近的嵌入向量(codebook 中的向量)得到的: zq(x)=argminekze(x)ek2 z_q(x) = \arg\min_{e_k} \|z_e(x) - e_k\|_2
    • 最近邻搜索是一个不可导操作,因为argmin或argmax 是一个离散操作,涉及离散索引 kk,因此不可导。

解决方法:停止梯度传播(Stop Gradient)

VQ-VAE 使用 停止梯度传播(Stop Gradient) 技巧来解决不可导问题:

  1. 停止梯度:
    • 在计算量化操作时,不允许梯度传播到最近邻搜索的部分。
    • 假设 zq(x)z_q(x) 是量化后的嵌入向量,VQ-VAE 中的梯度计算会直接将解码器损失作用到编码器输出 ze(x)z_e(x),而不会涉及量化过程。
  2. 公式:
    • zq(x)z_q(x) 的生成:
      zq(x)=ek,k=argminize(x)ei2z_q(x) = e_k, \quad k = \arg\min_{i} \|z_e(x) - e_i\|_2
    • 在优化过程中,损失的梯度会通过以下方式传播:
      zq(x)=ze(x)+(ekze(x)).detach()z_q(x) = z_e(x) + (e_k - z_e(x)).detach()
      • (ekze(x)).detach()(e_k - z_e(x)).detach() 表示停止梯度传播,仅用 ze(x)z_e(x) 来优化编码器。

详细全文请移步公主号:Goodnote。

参考:欢迎来到好评笔记(Goodnote)!