Quantization and Training of Neural Networks for Efficient Integer Inference

54 阅读2分钟

Introduction

Google发表的论文,主要介绍了一种量化方案,允许使用integer-only算法进行推理。内容包括:

  • quantization scheme,将weight、activation量化为8 bit int,将少部分参数(bias)量化为32 bit int。
  • quantized inference framework,可以在纯整数算术硬件上有效实现。
  • quantized training framework,最大限度减少量化对真实模型的精度损失。

Quantized Inference

r=S(qZ)r=S(q-Z)
q=round(r/S+Z)q=round(r/S+Z)

量化后的整数q到浮点实数r的affine mapping,

S(scale)是任意正实数,表示实数和整数之间的比例关系;

Z(zero-point)是和q同一类型的整数,是实数为0时经过量化后的整数,这使得能够满足r=0时的量化精确表示。

他们的计算方法为:

S=rmaxrminqmaxqminS=\frac{r_{max}-r_{min}}{q_{max}-q_{min}}
Z=round(qmaxrmax/S)Z=round(q_{max}-r_{max}/S)

C++ Struct:

template<typename QType> // e.g. QType=uint8
struct QuantizedBuffer {
    vector<QType> q; // the quantized values
    float S; // the scale
    QType Z; // the zero-point
};

Integer-arithmetic-only matrix multiplication

通过量化使得矩阵浮点计算转化为只涉及整数运算。

考虑两个N×NN\times N的实数矩阵r1,r2r_1,r_2,且r3=r1r2r_3=r_1r_2,展开即:

r3i,k=j=1Nr1i,jr2j,kr_3^{i,k}=\sum_{j=1}^Nr_1^{i,j}r_2^{j,k}

假设S1,S2,S3S_1,S_2,S_3分别是三个矩阵对应的scale,Z1,Z2,Z3Z_1,Z_2,Z_3分别是三个矩阵对应的zero point,对α=1,2,3\alpha=1,2,31i,jN1 \le i,j \le N,有:

rα(i,j)=Sα(qα(i,j)Zα)r^{(i,j)}_\alpha = S_\alpha (q^{(i,j)}_\alpha -Z_\alpha)

代入上式可以推出:

S3(q3(i,k)Z3)=j=1NS1(q1(i,j)Z1)S2(q2(j,k)Z2)S_3(q^{(i,k)}_3-Z_3)=\sum_{j=1}^NS_1(q^{(i,j)}_1-Z_1)S_2(q^{(j,k)}_2-Z_2)

上式也能被写为:

q3(i,k)=Z3+Mj=1N(q1(i,j)Z1)(q2(j,k)Z2)q^{(i,k)}_3=Z_3+M\sum_{j=1}^N(q^{(i,j)}_1-Z_1)(q^{(j,k)}_2-Z_2)
M=S1S2S3M=\frac{S_1S_2}{S_3}

上式中,唯一的非整数是MM,其只依赖于三个矩阵的scale,可以离线计算,并且经验表明MM总是处于(0,1)的区间内,因此MM可以表示为

M=2nM0M=2^{-n}M_0

此时,M0M_0可以表示为定点乘数(int16或者int32),定点数并不一定是整数,所谓定点,指的是小数点的位置是固定的,即小数位数是固定的。因此,如果存在 M=2nM0M=2^{-n}M_0,那我们就可以通过 M0M_0 的 bit 位移操作实现 2nM02^{-n}M_0,这样整个过程就都在定点上计算了。

Efficient handling of zero-points

在上面的公式中因为两个矩阵都需要减去各自的零点值,减法运算后得到的值可能会突破int8范围,到时候就需要int16来存储,但整个运算为了控制在int8的类型下计算,论文做了下面的变换:

q3(i,k)=Z3+M(NZ1Z2Z1a2(k)Z2a1(i)+j=1N(q1(i,j)q2(j,k)))q^{(i,k)}_3=Z_3+M(NZ_1Z_2-Z_1a_2^{(k)}-Z_2\overline a_1^{(i)}+\sum_{j=1}^N(q^{(i,j)}_1q^{(j,k)}_2))
a2(k)=j=1N(q2(j,k))a_2^{(k)}=\sum_{j=1}^N(q^{(j,k)}_2)
a1(i)=j=1N(q1(i,j))\overline a_1^{(i)}=\sum_{j=1}^N(q^{(i,j)}_1)

每个a2(k)a_2^{(k)}a1(i)\overline a_1^{(i)}只需要N个算术运算,所以总共需要2N22N^2个加法运算。其余的主要计算都在矩阵乘法j=1N(q1(i,j)q2(j,k))\sum_{j=1}^N(q^{(i,j)}_1q^{(j,k)}_2)上,需要2N32N^3算术运算。

通过这样的变换,可以有效的避免计算过程中的值溢出int8范围,使得可以低开销的处理任意zero-points。

Implementation of a typical fused layer

前面描述了权重矩阵的量化,但在神经网络中还有偏置bias和激活函数的映射,因为int8类型的运算完之后的值应该是在int32之内的,所以bias选择int32的类型。

这样的选择首先因为bias在整个神经网络中只占据极少的一部分,此外bias的作用其实非常重要,高精度的bias可以降低模型的偏差。因此加上bias之后就变成了int32,我们需要再次转换成int8类型(反量化),之后再进入到激活中。

image.png

假设图中的权重weights为ww,biases为bb,输入input为xx,输出的待激活值为aa,图中的运算可以表示为:

a=iNwixi+ba=\sum_i^Nw_ix_i+b

代入量化公式可得

Sa(qaZa)=iNSw(qwZw)Sx(qxZx)+Sb(qbZb)S_a(q_a-Z_a)=\sum_{i}^NS_w(q_w-Z_w)S_x(q_x-Z_x)+S_b(q_b-Z_b)
qa=SwSxSaiN(qwZw)(qxZx)+SbSa(qbZb)+Zaq_a=\frac{S_wS_x}{S_a}\sum_{i}^N(q_w-Z_w)(q_x-Z_x)+\frac{S_b}{S_a}(q_b-Z_b)+Z_a

由于iN(qwZw)(qxZx)\sum_{i}^N(q_w-Z_w)(q_x-Z_x)的结果通常以int32存储,所以偏置向量bb使用int32作为量化类型,zero-point选为0,scale与accumulators的scale相同,为weight scale与input activation scale的乘积。

Sb=SwSx,Zb=0S_{b}=S_wS_x, Z_{b}=0

因此,运算公式调整为:

qa=SwSxSa(iN(qwZw)(qxZx)+qb)+Za=M(iNqwqxiNqwZxiNqxZw+iNZwZx+qb)+ZaM=SwSxSaq_a=\frac{S_wS_x}{S_a}(\sum_{i}^N(q_w-Z_w)(q_x-Z_x)+q_b)+Z_a\\=M(\sum_{i}^Nq_wq_x-\sum_{i}^Nq_wZ_x-\sum_{i}^Nq_xZ_w+\sum_{i}^NZ_wZ_x+q_b)+Z_a\\M=\frac{S_wS_x}{S_a}

上式可以完全通过定点运算计算。由于ZwqwZxqbZ_w、q_w、Z_x、q_b都是可以事先计算的,因此 iNqwZxiNZwZx+qb\sum_i^N q_wZ_x、\sum_i^NZ_wZ_x+q_b也可以事先计算好,实际 inference 的时候,只需要计算 iNqwqx\sum_{i}^N q_wq_xiNqxZw\sum_i^N q_xZ_w 即可。

获得int32累加值后还要做三件事:
1.scale down(将累加输出值缩放到8bit)

2.cast down(将1步处理生成的8bit转换到uint8)

3.activation function(利用激活函数产生最终的8bit输出)

Training with simulated quantization

image.png

量化感知训练(QAT)是指在量化过程中,对网络进行训练,让网络能够更好的适应量化带来的信息损失,这一方法准确性一般比后训练量化要高。在PTQ中模型训练和量化是分开的,而QAT则是在模型训练时加入了伪量化节点,用于模拟模型量化时引起的误差。其处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的baseline模型;
  2. 在baseline模型中插入伪量化节点,得到QAT模型,并且在数据集上对QAT模型进行finetune;
  3. 伪量化节点会模拟推理时的量化过程并且保存finetune过程中计算得到的量化参数;
  4. finetune完成后,使用3. 中得到的量化参数对QAT模型进行量化得到INT8模型,并部署至推理框架中进行推理

Fake quantize

伪量化实际上是quantization+dequantization的结合,实际上就是模拟量化round引起的误差,其公式为

clamp(r;a,b)=min(max(x,a),b)s(a,b,n)=ban1q(r;a,b,n)=clamp(r;a,b)as(a,b,n)s(a,b,n)+aclamp(r;a,b)=min(max(x,a),b)\\s(a,b,n)=\frac{b-a}{n-1}\\q(r;a,b,n)=\lfloor\frac{clamp(r;a,b)-a}{s(a,b,n)}\rceil s(a,b,n)+a

上式中,r为待量化的实数;[a,b][a,b]为量化范围;n为量化bit数,常为8;\lfloor \rceil代表四舍五入至最近的整数;s(a,b,n)s(a,b,n)为量化因子scale;q(r;a,b,n)q(r;a,b,n)为量化-反量化之后的值,其类型为float。

注意论文中其定义的量化/反量化公式为q=(rz)/sq=(r-z)/sr=qs+zr=qs+z,因此上述公式与之前定义略有不同。

伪量化的操作看起来输入输出没变,但是实际上在其中模拟了量化round操作,将这种误差当做一种训练的噪声,在QAT finetune的同时,模型会去适应这种噪声,从而在最后量化为INT8时,减少精度的损失。

Learning quantization ranges

对于Weights和Activations,量化的方法与量化范围不同,对于weights,在进行conv运算之前进行fake quantize,而activations在激活函数执行完之后进行fake quantize,对应于途中的wt quant与act quant。

伪量化节点会保存finetune过程中的量化参数,伪量化节点的计算公式中 [a,b][a,b] 即为FP32浮点数值的范围,这个值将会在finetune过程中进行估计与更新,上面介绍了伪量化节点分别weight quantization以及activation quantizaion:

  • 对于weight quantization的量化节点,直接将[a,b][a,b]设置为weights的最大值与最小值即可,即 a=min(w)a=min(w) , b=max(w)b=max(w) ;
  • 对于activation quantizaion,处理方式类似于batch norm,使用了指数移动平均(EMA),在finetune的每个batch动态地更新[a,b][a,b] ;

最后量化模型的时候,只需设置 S=s(a,b,n)Z=z(a,b,n) S=s(a,b,n)、Z=z(a,b,n) 即可。

另外,偏差bias没有被量化,且其在推理过程中被表示为32位整数。

Batch normalization folding

BatchNorm是Google提出的一种加速神经网络训练的技术,在每一层输出时做了一遍归一化的操作。在inference过程中,BN的公式为:

BatchNorm(x)=γ(xμσk2+ϵ)+βBatchNorm(x)=\gamma (\frac{x-\mu}{\sqrt{\sigma_k^2+\epsilon}})+\beta

其中μ\muσ\sigma是批量统计的指数移动平均值计算的均值与方差,γ\gammaβ\beta是学习到的超参数。若将其融合至前一个线性层,则可以重写为: image.png 注意折叠之后的权重为

wfold=γwσk2+ϵw_{fold}=\frac{\gamma w}{\sqrt{\sigma_k^2+\epsilon}}

带有BatchNorm的Conv层训练和推理流程如下图所示。

屏幕截图 2024-01-02 104954.png

屏幕截图 2024-01-02 105003.png

对于Batch normalization层,在训练时作为一个单独的层,而在推理时为了提升效率通常融合到前一个或者下一个全连接层或者卷积层中。因此,在QAT中需要模拟这种folding,如下图所示

image.png

由于实际 inference 的时候,BN 是 folding 到 Conv 中的,因此在QAT中也需要模拟这个操作,得到新的 weight 和 bias,并用新的 Conv 估计量化误差来回传梯度。