在N天之前的文章模型的过拟合 - 知乎专栏中,我曾经提到对神经网络的设想,结果发现多年前已经有文章对这个问题进行了比较深入的探讨,于是好奇心使我对这篇文章进行了研究。文章的名称叫做——《Intriguing properties of neural networks》,这篇文章的作者阵容放到今天还是非常强大的,比方说GAN的作者Ian Goodfellow……
在这篇文章之前,就已经有人提出有趣的想法:训练好的深层神经网络是对训练数据空间的泛化先验。这个说法和之前提到的假想比较类似。首先,深层网络通过学习训练数据使得训练数据的识别能够完美拟合,同时它还能保证训练数据附近的区域也能够识别正确。
当然,这个“附近的区域”是一个不够清晰的词汇,什么算作“附近”?没有度量就没有概念。我们能不能真的找到一种精确的度量,帮助我们衡量模型或者某一个模型层的“泛化”能力?文章中并没有给出一个十分精细的能力衡量方式,但是它给出了一个泛化下界的表达方式。
泛化能力的下界
我们知道Lipchitz条件的公式:
如果模型是一个线性模型,,就有
当时,有
因为x,y处于同于同一个空间,而且我们需要考虑临域的函数波动情况,所以
这个r可以想像成图像数据的一个小波动。前面提过,我们需要保证模型在小波动下不改变输出的标签,因此上面的公式最终变成了:
这个公式看上去就比较亲切了。W可以看作由输入空间到输出空间的一个线性算子,那么上式的左边就是在求W的范数:
于是问题变成了做线性算子的谱分析,当然,这是泛函分析里面的称呼,我们可以换一个更通俗的称呼,对于上面这个问题,我们要求的是矩阵W的绝对值最大的特征值。
为什么是绝对值最大的特征值?
这个问题回到了线性代数上,我们都知道矩阵的特征值公式:
当矩阵W与自己的特征向量相乘时,其效果相当于对特征向量的“长度”做变化,如果与一个非特征向量相乘,那么我们可以想象,这个向量可以被分解成几个特征向量的线性组合,最终还是可以通过特征值的公式表示回来:
所以它的最终公式变为
到了这一步可以看出,如果W不乘以一个特征值,那么它基本不可能达到最大值,所以最大值一定出现在特征向量上:
这么看来,我们选择绝对值最大的特征值即可,所以有:
所以绝对值最大的特征值代表了这个线性算子的波动的最大值,这个数字越大,说明算子的波动程度越大,在深层模型的累积下,局部的泛化性越有可能弱。
站在整体网络的角度
上面我们看完了一层网络的计算方式,下面就来看看一个网络整体的泛化下界。我们假设一个网络由K层组成,由于上面我们推导的矩阵可以直接套用到全连接层上,我们接下来先假设所有的线性运算都是全连接层,不包括卷积层,后面再将卷积层加入。这个网络的计算公式为:
其中的表示了线性部分的全连接层。如果每一层都满足:
那么对于整个网络,就有:
并且
从这个公式可以看出,一般来说,每多一层网络,这个L都会变大。如果小于1,那么模型的震荡就会小一些(但是小于1很难)。
卷积层的泛化下界
全连接层的计算相对容易些,因为矩阵乘法是大家比较熟悉的,但是卷积层呢?卷积操作看上去比较复杂,似乎运算比较困难了,论文给出的解答也比较晦涩,这里可以详细提一下。
我们假设输入的数据维度为C*N*N,卷积参数为C*D*K*K,其中C代表输入的feature map数,D代表输出的feature map数,N为图像的维度,K为图像卷积核的维度。我们不考虑stride和pad的问题,因此stride=1,pad保持图像不变。
卷积操作的计算和全连接层的方法不同,因此我们需要让他们的运算形式相近,然后利用上面的方式求解。这时候我们就需要利用图像处理中的2维傅立叶变换了。我们用表示一个K*K的卷积核,我们对每一个卷积核进行傅立叶变换,这样每个K*K的卷积核变成了N*N的频率矩阵,我们用
表示。
这个时候,对于每个输出的像素,它等于输入的feature map的相同位置的特征值与对应的所在位置的权重相乘求和,如果图像输入设为x,图像输出为y,那么有:
进一步泛化,将输出维度d去掉,就有:
这时候右边第一项是一个矩阵,实际上代表了参数,第二项是一个向量,是输入数据,通过这样的变换,卷积的操作形式同全连接一样了,下面的目标就是求
求完了这个,还没有完,我们发现每一个h,w位置都有一个参数矩阵W,最后还要在这些W的范数中取出一个最大的,用来代表卷积层的泛化下界:
这样最终的求解结果才算完成。
当然,这其中没有推导stride不等于1的情况,大家可以套路情况自己计算下
总结
通过这一串理论推导,我们真的估计出了一个网络层对应的Lipchitz条件数,文章中作者用这个方法计算了AlexNet各个网络的条件数,其中第一层卷积最小,它的参数维度为3*96*11*11,L值为2.75,第5层卷积最大,参数维度为384*256*3*3,L值为11。这里面也可以看出一些模型泛化的效果来。
我们花了大力气求出了这些数字,最后还是要回到论文提出的核心点来:模型的网络层的L值比较高(斜率上界为11的非线性模型),那么在训练数据的附近,很可能出现一些盲点:将训练数据经过微小的调整,就可以让模型判断错误。关于这个实验,我们也在前面的文章寻找CNN的弱点 - 知乎专栏中提到过,下一回我们就沿着这篇文章,看看对抗这种盲点数据的训练方法。