量化压缩论文:Quantization and Training of Neural Networks for Efficient Integer-Arithm

165 阅读4分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

1. 背景与介绍

随着今年深度学习的发展,基于神经网络的模型在许多任务上取得了比传统方法高很多的准确率指标。但是随之而来的问题是模型越来越庞大,这给它们部署在移动端平台(例如手机、AR/VR)上带来了诸多不便。因此,如何减小模型大小和提高推理速度,成为了一个新的热门研究方向。

目前的模型压缩方法主要分为两大类:1. 设计更高效的网络结构;2.将网络的权重和激活函数从float32量化成更低的bit。

之前的研究有两个问题:

  1. 缺乏一个准确的基线。以往的研究都是采用AlexNet、VGG、GoogleNet之类的模型作为对比基线,但是这些模型在结构上本身就是存在设计冗余的,因此难以评价一个模型压缩方法的有效程度。本文采用的基线是MobileNet。
  2. 很多量化方法部署到实际的硬件平台上可能没法真正加快推理速度。例如,有的方法只对权重进行量化,激活函数还是浮点,所以只能减少存储空间,没法显著减少计算量。有的方法,比如三元/二元量化,采用位移来代替乘法运算。但是在有些硬件上位移并不一定比乘法加法快,而且只有当位数较大时乘加运算的开销才较大。当权重和激活函数量化后,位数较少了,减少乘加运算的需求也没那么旺盛了。有些1-bit的方法,还可能导致准确率显著下降。

2. 量化方案介绍

2.1. 量化参数介绍

所谓量化,就是在原始的浮点数r和量化后的整数q之间寻找一个仿射变换,使得:r=S(qZ)r=S(q-Z), 其中S和Z是参数。对于每一层权重(weights)和对应激活函数(activation),S和Z是相同的。q的bit数常见的有8,4,2,1等,我们常说的8bit量化就是指q的bit数为8,其他bit数类似。bias一般量化为32bit整数。 S称为放缩因子(scale),Z称为零点(zero-point)。Z的存在很有必要,这样能保证浮点数中的零能跟一个量化后的整数对应上。在神经网络中,经常有用零进行padding的情况,因此如果找不到一个整数对应的话会大大损失量化后的精度。

2.2. 量化后的全整数矩阵乘法

考虑两个均为N×NN\times N的方阵r1r_1r2r_2,我们需要通过它们的矩阵乘法获得方阵r3r_3。用rα(i,j),α=1,2,or,3,1i,jNr_{\alpha}^{(i,j)}, \alpha=1,2,or,3,1\leq i,j\leq N表示第α\alpha个矩阵的第ii行第jj列的元素。S,Z,qS,Z,q的表示方法类似。则有:

S3(q3(i,k))=j=1NS1(q1(i,j)Z1)S2(q2(j,k)Z2)(1-1)S_3(q_3^{(i,k)})=\sum_{j=1}^{N}S_1(q_1^{(i,j)}-Z_1)S_2(q_2^{(j,k)}-Z_2)\tag{1-1}

进而得到$$ q_3^{(i,k)}=Z_3+M\sum_{j=1}^{N}(q_1^{(i,j)}-Z_1)(q_2^{(j,k)}-Z_2)\tag{1-2}

其中$M=\frac{S_1S_2}{S_3}$ 在(1-2)中,除了M外,都是整数。 对于M,经过大量实验统计表明,它总是在区间(0,1]中。因此可以表示为:

M=2^{-n}M_0\tag{1-3}

其中$M_0\in [0.5,1)$,n是一个非负整数。可以将$M_0$表示成一个定点数,也即如果硬件平台采用的是int32,则可以找出离$2^{31}M_0$最近的整数来代替$M_0$(最后记得再还原回去就行),这样$M_0$的精度至少有30bit。这样的话,(1-3)就可以转换为整数的右移计算,大大提高效率。 ## 2.3. 关于减少减法操作的说明 相比原始的浮点矩阵乘,(1-2)中似乎额外增加了$2N^3$次减法。其实可以简化为:

q_3^{(i,k)}=Z_3+M\left( NZ_1Z_2-Z_1a_2^{(k)}-Z_2a_1^{(i)}+\sum_{j=1}^Nq_1^{(i,j)}q_2^{(j,k)}\right)\tag{1-4}

其中$a_2^{(k)}=\sum_{j=1}^Nq_2^{(j,k)}, a_1^{(i)}=\sum_{j=1}^Nq_2^{(i,j)}$ 可见$a_2^{(k)},a_1^{(i)}$分别只需要N次加法,因此总的加法只需要$2N^2$(注意当i变化时,$a_2^{(k)}$不变;k变化时,$a_1^{(i)}$不变) 因此,(1-4)的主要复杂度在于$\sum_{j=1}^Nq_1^{(i,j)}q_2^{(j,k)}$,它总共具有$2N$次的算数运算(乘和加),为了得到结果矩阵中的所有元素则需要的复杂度为$2N^3$,这在原始的浮点乘以及其他形式的量化计算下都是无法避免的。其他的复杂度是$O(N^2)$,带有一个小的常数,可以忽略。