对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

3,824 阅读9分钟

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

写在前面

Hello,大家好,我是小苏🧒🏽🧒🏽🧒🏽

在前面的文章中,我已经介绍过挺多种GAN网络了,感兴趣的可以关注一下我的专栏:深度学习网络原理与实战 。目前专栏主要更新了GAN系列文章、Transformer系列和语义分割系列文章,都有理论详解和代码实战,文中的讲解都比较通俗易懂,如果你希望丰富这方面的知识,建议你阅读试试,相信你会有蛮不错的收获。🍸🍸🍸

在阅读本篇教程之前,你非常有必要阅读下面两篇文章:

其实啊,我相信大家来看这篇文章的时候,一定是对上文提到的文章有所了解了,因此大家要是觉得自己对GAN和WGAN了解的已经足够透彻了,那么完全没有必要再浪费时间阅读了。如果你还对它们有一些疑惑或者过了很久已经忘了希望回顾一下的话,那么文章[1]和文章[2]获取对你有所帮助。

大家准备好了嘛,我们这就开始准备学习Spectral Normalization啦!🚖🚖🚖

 

Spectral Normalization原理详解

​  首先,让我们简单的回顾一下WGAN。🌞🌞🌞由于原始GAN网络存在训练不稳定的现象,究其本质,是因为它的损失函数实际上是JS散度,而JS散度不会随着两个分布的距离改变而改变(这句不严谨,细节参考WGAN中的描述),这就会导致生成器的梯度会一直不变,从而导致模型训练效果很差。WGAN为了解决原始GAN网络训练不稳定的现象,引入了EM distance代替原有的JS散度,这样的改变会使生成器梯度一直变化,从而使模型得到充分训练。但是WGAN的提出伴随着一个难点,即如何让判别器的参数矩阵满足Lipschitz连续条件。

​  如何解决上述所说的难点呢?在WGAN中,我们采用了一种简单粗暴的方式来满足这一条件,即直接对判别器的权重参数进行剪裁,强制将权重限制在[-c,c]范围内。大家可以动动我们的小脑瓜想想这种权重剪裁的方式有什么样的问题——(滴,揭晓答案🍍🍍🍍)如果权重剪裁的参数c很大,那么任何权重可能都需要很长时间才能达到极限,从而使训练判别器达到最优变得更加困难;如果权重剪裁的参数c很小,这又容易导致梯度消失。因此,如何确定权重剪裁参数c是重要的,同时这也是困难的。WGAN提出之后,又提出了WGAN-GP来实现Lipschitz 连续条件,其主要通过添加一个惩罚项来实现。【关于WGAN-GP我没有做相关教程,如果不明白的可以评论区留言】那么本文提出了一种归一化的手段Spectral Normalization来实现Lipschitz连续条件,这种归一化具体是怎么实现的呢,下面听我慢慢道来。🍻🍻🍻


我们还是来先回顾一下Lipschitz连续条件,如下:

​            f(x1)f(x2)Kx1x2|f(x_1)-f(x_2)| \le K|x_1-x_2|

这个式子限制了函数f(){\rm{f}}( \cdot )的导数,即其导数的绝对值小于K,f(x1)f(x2)x1x2K\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|} \le K。 🍋🍋🍋

本文介绍的Spectral Normalization的K=1,让我们一起来看看怎么实现的吧!!!


  上文提到,WGAN的难点是如何让判别器的参数矩阵满足Lipschitz连续条件。那么我们就从判别器入手和大家唠一唠。实际上,判别器也是由多层卷积神经网络构成的,我们用下式表示第n层网络输出和第n-1层输入的关系:

​            Xn=an(WnXn1+bn)X_n=a_n(W_n \cdot X_{n-1}+b_n)

  其中an()a_n(\cdot)表示激活函数,WnW_n表示权重参数矩阵。为了方便起见,我们不设置偏置项bnb_n,即bn=0b_n=0。那么上式变为:

​            Xn=an(WnXn1)X_n=a_n(W_n \cdot X_{n-1})

  再为了方便起见🤸🏽‍♂️🤸🏽‍♂️🤸🏽‍♂️,我们设an()a_n(\cdot),即激活函数为Relu。Relu函数在大于0时为y=x,小于0时为y=0,函数图像如下图所示:

image-20221114112127282

​  这样的话式Xn=an(WnXn1)X_n=a_n(W_n \cdot X_{n-1})可以写成Xn=DnWnXn1X_n=D_n \cdot W_n \cdot X_{n-1},其中DnD_n为对角矩阵。【大家这里能否理解呢?如果我们的输入为正数时,通过Relu函数值是不变的,那么此时DnD_n对应的对角元素应该为1;如果我们的输入为负数时,通过Relu函数值将变成0,那么此时DnD_n对应的对角元素应该为0。也就是说我们将XnX_n改写成DnWnXn1D_n \cdot W_n \cdot X_{n-1}形式是可行的。】

​  接着我们做一些简单的推理,得到判别器第n层输出和原始输入的关系,如下图所示:

image-20221114144627781

  最后一层的输出XnX_n即为判别器的输出,接下来我们用f(x)f(x)表示;原始输入数据x0x_0我们接下来用xx表示。则判别器最终输入输出的关系式如下:

​   f(x)=DnWnDn1Wn1D3W3D2W2D1W1xf(x) = {D_n} \cdot {W_n} \cdot {D_{n - 1}} \cdot {W_{n - 1}} \cdots {D_3} \cdot {W_3} \cdot {D_2} \cdot {W_2} \cdot {D_1} \cdot {W_1} \cdot x

  上文说到Lipschitz连续条件本质上就是限制函数f(){\rm{f}}( \cdot )的导数变化范围,其实就是对f(x)f(x)梯度提出限制,如下:

xf(x)2=DnWnDn1Wn1D3W3D2W2D1W12Dn2Wn2Dn12Wn12D12W12||{\nabla _x}f(x)|{|_2} = ||{D_n} \cdot {W_n} \cdot {D_{n - 1}} \cdot {W_{n - 1}} \cdots {D_3} \cdot {W_3} \cdot {D_2} \cdot {W_2} \cdot {D_1} \cdot {W_1}|{|_2} \le ||{D_n}|{|_2} \cdot ||{W_n}|{|_2} \cdot ||{D_{n - 1}}|{|_2} \cdot ||{W_{n - 1}}|{|_2} \cdots ||{D_1}|{|_2} \cdot ||{W_1}|{|_2}

  其中A2||A||_2表示矩阵A的2范数,也叫谱范数,它的值为λ1\sqrt {{\lambda _1}}λ1{\lambda _1}AHA{{\rm{A}}^H}{\rm{A}}的最大特征值。λ1\sqrt {{\lambda _1}}又称作矩阵A的奇异值【注:奇异值是AHA{{\rm{A}}^H}{\rm{A}}的特征值的开根号,也就是说λ1\sqrt {{\lambda _1}} 为A的其中一个奇异值或谱范数是最大的奇异值】,这里我们将谱范数,即最大的奇异值记作σ(A)=λ1\sigma {(A)} = \sqrt {{\lambda _1}}。由于D是对角矩阵且由0、1构成,其奇异值总是小于等于1,故有下式:

image-20221114155057050

  即xf(x)2=Dn2Wn2D12W12Π1nσ(Wi){\nabla _x}f(x)|{|_2}= ||{D_n}|{|_2}\cdot ||{W_n}|{|_2} \cdots ||{D_1}|{|_2} \cdot ||{W_1}|{|_2} \le \mathop \Pi \limits_1^{\rm{n}} \sigma ({W_i})。为满足Lipschitz连续条件,我们应该让xf(x)2K||{\nabla _x}f(x)|{|_2} \le K ,这里的K设置为1。那具体要怎么做呢,其实就是对上式做一个归一化处理,让每一层参数矩阵除以该层参数矩阵的谱范数,如下:

​  xf(x)2=Dn2Wn2σ(Wn)D12W12σ(W1)Π1nσ(Wi)σ(Wi)=1||{\nabla _x}f(x)|{|_2} = |{D_n}|{|_2} \cdot \frac{{||{W_n}|{|_2}}}{{\sigma ({W_n})}} \cdots ||{D_1}|{|_2} \cdot \frac{{||{W_1}|{|_2}}}{{\sigma ({W_1})}} \le \mathop \Pi \limits_1^{\rm{n}} \frac{{\sigma ({W_i})}}{{\sigma ({W_i})}} = 1

  这样,其实我们的Spectral Normalization原理就讲的差不多了,最后我们要做的就是求得每层参数矩阵的谱范数,然后再进行归一化操作。要想求矩阵的谱范数,首先得求矩阵的奇异值,具体求法我放在附录部分。

  但是按照正常求奇异值的方法会消耗大量的计算资源,因此论文中使用了一种近似求解谱范数的方法,伪代码如下图所示:

image-20221114164733970

  在代码的实战中我们就是按照上图的伪代码求解谱范数的,届时我们会为大家介绍。🍄🍄🍄


注:大家阅读这部分有没有什么难度呢,我觉得可能还是挺难的,你需要一些矩阵分析的知识,我已经尽可能把这个问题描述的简单了,有的文章写的很好,公式推导的也很详尽,我会在参考链接中给出。但是会涉及到最优化的一些理论,估计这就让大家更头疼了,所以大家慢慢消化吧!!!🍚🍚🍚在最后的附录中,我会给出本节内容相关的矩阵分析知识,是我上课时的一些笔记,笔记包含本节的知识点,但针对性可能不是很强,也就是说可能包含一些其它内容,大家可以选择忽略,当然了,你也可以细细的研究研究每个知识点,说不定后面就用到了呢!!!🥝🥝🥝

 

Spectral Normalization源码解析

源码下载地址:Spectral Normalization📥📥📥

  这个代码使用的是CIFAR10数据集,实现的是一般生成对抗网络的图像生成任务。我不打算再对每一句代码进行详细的解释,有不明白的可以先去看看我专栏中的其它GAN网络的文章,都有源码解析,弄明白后再看这篇你会发现非常简单。那么这篇文章我主要来介绍一下Spectral Normalization部分的内容,其相关内容在spectral_normalization.py文件中,我们理论部分提到Spectral Normalization关键的一步是求解每个参数矩阵的谱范数,相关代码如下:

def _update_u_v(self):
    u = getattr(self.module, self.name + "_u")
    v = getattr(self.module, self.name + "_v")
    w = getattr(self.module, self.name + "_bar")
    height = w.data.shape[0]
    for _ in range(self.power_iterations):
        u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))  
        v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))

    sigma = u.dot(w.view(height, -1).mv(v))
    setattr(self.module, self.name, w / sigma.expand_as(w))
    
    
    
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

  对上述代码做一定的解释,6,7,8,9,10行做的就是理论部分伪代码的工作,最后会得到谱范数sigma。11行为使用参数矩阵除以谱范数sigma,以此实现归一化的作用。【torch.mv实现的是矩阵乘法的操作,里面可能还有些函数你没见过,大家百度一下用法就知道了,非常简单】

其实关键的代码就这些,是不是发现特别简单呢🍸🍸🍸每次介绍代码时我都会强调自己动手调试的重要性,很多时候写文章介绍源码都觉得有些力不从心,一些想表达的点总是很难表述,总之,大家要是有什么不明白的就尽情调试叭,或者评论区留言,我天天在线摸鱼滴喔。⭐⭐⭐后期我也打算出一些视频教学了,这样的话就可以带着大家一起调试,我想这样介绍源码彼此都会轻松很多。🛩🛩🛩

 

小结

  Spectral Normalization确实是有一定难度的,我也有许多地方理解的也不是很清楚,对于这种难啃的问题我是这样认为的。我们可以先对其有一个大致的了解,知道整个过程,知道代码怎么实现,能使用代码跑通一些模型,然后考虑能否将其用在自己可能需要使用的地方,如果加入的效果不好,我们就没必要深究原理了,如果发现效果好,这时候我们再回来慢慢细嚼原理也不迟。最后,希望各位都能获取新知识,能够学有所成叭!!!🌹🌹🌹

 

参考链接

GAN — Spectral Normalization 🍁🍁🍁

Spectral Normalization for GAN🍁🍁🍁

详解GAN的谱归一化(Spectral Normalization)🍁🍁🍁

谱归一化(Spectral Normalization)的理解🍁🍁🍁

 

附录

  这部分是我学习矩阵分析这门课程时的笔记,截取一些包含此部分的内容,有需求的感兴趣的可以看一看。🌱🌱🌱

image-20221114213810323

image-20221114213405314

image-20221114213500828

   

如若文章对你有所帮助,那就🛴🛴🛴

         一键三连 (1).gif