GAIN文章阅读以及pytorch实现

601 阅读4分钟

GAIN

别人缺陷

基于深度学习算法包括去噪自动编码器(DAE)和生成对抗网络(GAN)。DAE在训练时需要使用完整的数据集。也有一些DAE方法不需要完整数据集,但是只用了能够观察到的数据来进行训练。有方法使用深度卷积GAN来做图像修复,也是需要完整的数据来训练鉴别器D。

问题主要出在需要完整数据集,没有的话就只使用数据中完整的那一部分。

创新点

  1. 没有完整数据集,也可以使用
  2. 设计了一个提示矩阵,为鉴别器提供额外的信息(后续说明),这种提示确保了生成器根据真实的底层数据分布生成样本。

修复问题定义

d维输入(数据向量 data vector)X,X=X_1×X_2...×X_d=(X_1,...,X_d)X=X\_1×X\_2...×X\_d=(X\_1,...,X\_d),其分布为P(X)。掩码矩阵(mask vector) M=(M_1,...,M_d)M=(M\_1,...,M\_d),元素的值随机为0或1。再定义一个X~\tilde{X}。M可以表示为X的哪些分量是已被观察的(即非缺失的),且从X~\tilde{X}中可以恢复M。

image-20230524104904673.png

通过对数据的分布进行建模,而不是仅仅对期望建模,这样的话可以进行多次绘制,也就是多次估算,使得我们能够捕获估算值的不确定性。大意就是对数据分布建模可以进行多次估算确实的值。

模型设计

image-20230524140139917.png

生成器G

X~\tilde{X}, M 和噪音Z(d维噪声)作为输入,输出为Xˉ\bar{X}。定义G:

Xˉ=G(X~,M,(1M)Z)X^=MX~+(1M)Xˉ\bar{X}=G(\tilde{X},M,(1-M)⊙Z)\\ \hat{X}=M⊙\tilde{X}+(1-M)⊙\bar{X}

X^\hat{X}是修复后的完整数据。

生成器就是简单的全连接层

鉴别器D

标准的GAN,鉴别器是判断生成器数据是真还是假,然而这里的生成器的数据是真与假的结合。

现在这里的鉴别器D识别的不是整个向量是真是假,而是区分一个向量里哪些元素是真,哪些是假,相当于预测出掩码矩阵。即如果鉴别器很强大,那么预测出的掩码矩阵就是我们原先定义的M。

为什么这样设计鉴别器:如果生成器很强大,生成了满足原先分布的数据,鉴别器完全鉴别不出来,则鉴别器鉴别出的掩码矩阵就是全1。我们希望鉴别器能够鉴别出哪些是生成的,如果鉴别出来,就会越来越趋近于M。

鉴别器也是简单的全连接层。

提示矩阵Hint

揭示了原始数据中缺失部分的某些信息,让D更加关注它所提示的部分,同时也逼迫G生成更加接近真实的数据用来填补。提示矩阵是基于掩码矩阵来自定义的,提示矩阵和G生成的数据共同作为输入。为什么可以产生这种效果,详见论文推导。

训练

训练D以最大化正确预测M的概率(训练D来让D的输出接近M),训练G来最小化D预测M的概率(训练G来让D的输出远离M) 。通俗来讲,D为了判别生成的数据是假,所以对于缺失位置,尽可能判断为0,即逐渐趋向于M。定义V(D, G):

image-20230524145545507.png

先训练D,设D的输出为D_prob(注意,输入D中的数据new_x,未缺失位置的数据使用原值,缺失位置的数据使用生成值)。D的损失如下:

D_loss = -torch.mean(m * torch.log(D_prob + 1e-8) + (1 - m) * torch.log(1. - D_prob + 1e-8))

如果D判别未缺失位置的数接近1,则torch.mean(m * torch.log(D_prob + 1e-8)越大,那么D_loss越小。

如果D判别缺失位置的数据接近1,则(1 - m) * torch.log(1. - D_prob + 1e-8))越大,那么D_loss越大。

再训练G,设G的输出为G_sample。G的损失如下

G_loss = -torch.mean((1 - m) * torch.log(D_prob + 1e-8)) + 
alpha * torch.mean((m * new_x - m * G_sample) ** 2) / torch.mean(m)

torch.mean((1 - m) * torch.log(D_prob + 1e-8))上面解释过。

torch.mean((m * new_x - m * G_sample) ** 2) / torch.mean(m)表示未缺失位置的MSE

总体而言

  1. G让生成的未缺失数据逼近于真实数据分布
  2. D是被出哪些是缺失数据,哪些是未缺失数据
  3. 没有使用到缺失位置的loss,因为实际场景下数据是缺失的,确实位置是无法计算loss的

pytorch实现代码:github.com/RadishVeget…

一些参考:

juejin.cn/post/715058…

github.com/dhanajitb/G…