生成对抗网络GAN损失函数loss的简单理解

640 阅读1分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

原始的公式长这样:

minGmaxDV(D,G)=Expdata (x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]

首先可以明确一点,这种公式肯定是从里面算到外面的,也就是可以先看这一部分:

maxDV(D,G)=Expdata (x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]

我们知道,在每个epoch中,GAN的生成器与判别器是分别训练的,即先固定生成器GG,去训练判别器DD,那么上面这个式子实际上就是判别器的"损失函数"。继续拆分上面这个式子,可以发现主要就是加号左右两个部分。

先看左边。左边这一部分的作用是保证判别器的基础判断能力:对于Expdata (x)[logD(x)]\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]x\boldsymbol{x}为从真实数据分布pdata p_{\text {data }}中采样得到的样本。Expdata (x)[logD(x)]\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]越大,相当于意味着D(x)D(\boldsymbol{x})越大,即判别器越能准确地将真实样本识别为真实样本;因此有maxD\max _{D}

再看右边。右边这一部分的作用是保证判别器能够区分出虚假样本:对于Ezpz(z)[log(1D(G(z)))]\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]z\boldsymbol{z}为从某一特定分布pzp_{\boldsymbol{z}}中得到的采样,G(z)G(\boldsymbol{z})为生成器生成的虚假样本。Ezpz(z)[log(1D(G(z)))]\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]越大,相当于意味着D(G(z))D(G(\boldsymbol{z}))越小,即判别器越能够正确区分虚假样本,将其标为False;因此有maxD\max _{D}

再来看生成器G的"损失函数"。到了训练生成器G的阶段,此时判别器D固定。如果G更强,那么判别器会进行误判,此时D(G(z))D(G(\boldsymbol{z}))会变大,Ezpz(z)[log(1D(G(z)))]\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]更接近于零,即整个式子的值会更小;因此有minG\min _{G}