Cudnn 算子融合

116 阅读3分钟

这里我们以视觉网络中常见的 conv+ bn 说起

一、推理融合

截屏2025-08-16 18.13.46.png


1. BatchNorm 的公式

对于通道 cc 的输入特征图 xx,BatchNorm 在训练时公式为:

y=γxμσ2+ϵ+βy = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

其中:

  • μ,σ2\mu, \sigma^2:均值和方差(训练后会存成推理时固定的 running_mean 和 running_var
  • γ,β\gamma, \beta:可学习的缩放和偏移参数
  • ϵ\epsilon:防止除零的小常数

2. Conv 的公式

假设卷积层的权重为 W,偏置为 b,卷积计算为:

z=W*x+b

其中 * 表示卷积操作。


3. Conv + BN 融合原理

把 BN 的线性变换合并进卷积:

  1. BN 先把卷积结果标准化:
y=γ(Wx+b)μσ2+ϵ+βy = \gamma \cdot \frac{(W*x + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  1. 重新整理成:
y=(γσ2+ϵ)(Wx)+(γσ2+ϵ(bμ)+β)y = \left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) (W*x) + \left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} (b - \mu) + \beta \right)
  1. 定义新的卷积参数:

    • 新权重:
    W=Wγσ2+ϵW' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
    • 新偏置:
    b=γσ2+ϵ(bμ)+βb' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} (b - \mu) + \beta

最终得到等效的 Conv-only 层

y=W′∗x+b′y = W' * x + b'


4. 融合的好处

  • 减少推理开销:少了一层 BN 计算和访存。
  • 简化计算图:利于部署到 TensorRT、TVM、ONNXRuntime、NCNN 等推理框架。

二、训练融合

前向融合

我们继续以该图分析

截屏2025-08-16 18.13.46.png


bn 训练前向

给定的某一通道内:

1、批统计:求均值 标准差

μB=1mi=1mxi,σB2=1mi=1m(xiμB)2\mu_B=\frac{1}{m}\sum_{i=1}^m x_i,\qquad \sigma_B^2=\frac{1}{m}\sum_{i=1}^m (x_i-\mu_B)^2

2、求归一化

x^i=xiμBσB2+ε,yi=γx^i+β\hat{x}_i=\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\varepsilon}},\qquad y_i=\gamma\,\hat{x}_i+\beta

以上两步都需要遍历当前通道内NHW个点

卷积前向:

假设卷积层的权重为 W,偏置为 b,卷积计算为:

z=W*x+b

其中 * 表示卷积操作。

这里我们以cuda 编程模型来考虑,为了追求性能,卷积和 bn 训练的滑块无法做到统一。所以对于训练融合我们换个思路,将 bn 训练第一步和第二步在计算图上拆为两个计算节点,归一化节点结合 relu,conv 进行融合。第一部分不进行融合,这样计算图可以粗略概括为,conv+bn+relu+bn==>conv+stats+(scale+bias+relu+conv)

ConvBNfprop.png

下载.png

反向融合

反向融合首先要明白反向传播的原理,其本质就是链式求导

bn 的反向传播也主要分为两部分

第一步求 db ds

第二步求 A B C 三个系数并进行运算得出 dinput

推导如下

首先明白最终目标是求得 Lxi\displaystyle \frac{\partial L}{\partial x_i}

回顾前向公式

μ=1mj=1mxj,σ2=1mj=1m(xjμ)2,x^i=xiμσ2+ε,yi=γx^i+β.\mu=\tfrac{1}{m}\sum_{j=1}^m x_j,\quad \sigma^2=\tfrac{1}{m}\sum_{j=1}^m (x_j-\mu)^2,\quad \hat{x}_i=\frac{x_i-\mu}{\sqrt{\sigma^2+\varepsilon}},\quad y_i=\gamma\hat{x}_i+\beta.

所以反向公式分为三部分

Lxi=Lx^ix^ixi+Lμμxi+Lσ2σ2xi\frac{\partial L}{\partial x_i} = {\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i}} + {\frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_i}} + {\frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_i}}

1Lx^ix^ixi=Lyiyix^ix^ixi=Lyiγ 1、 \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial L}{\partial y_i}\gamma

2Lμμxi2、\frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_i}
对于Lμ,由两部分组成,一个是x^iμ,另一部分是σ2中的μ,所以对于\frac{\partial L}{\partial \mu},由两部分组成,一个是\hat{x}_i 中\mu,另一部分是\sigma^2中的\mu,所以
Lμ=i=1mLx^ix^iμ+i=1mLx^ix^iσ2σ2μ,\frac{\partial L}{\partial \mu}=\sum_{i=1}^m\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \mu}+\sum_{i=1}^m\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial \mu},

首先我们看下

σ2μ=2mj=1m(xjμ)这个导数恒等于0,所以Lμ的值由前半部分i=1mLx^ix^iμ决定\frac{\partial \sigma^2}{\partial \mu}=\tfrac{-2}{m}\sum_{j=1}^m (x_j-\mu) 这个导数恒等于 0,所以\frac{\partial L}{\partial \mu}的值由前半部分\sum_{i=1}^m\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \mu}决定

i=1mLx^ix^iμ=i=1mLx^i1σ2+ε\sum_{i=1}^m\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \mu}=\sum_{i=1}^m\frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sqrt{\sigma^2+\varepsilon}}
又由于μxi=1m,所以Lμμxi=1mLx^i1σ 又由于 \frac{\partial \mu}{\partial x_i}=\frac{1}{m},所以\frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_i}=\frac{1}{m}\sum\frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sigma}

3、继续分析Lσ2σ2xi3、继续分析{\frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_i}}

x^i=(xiμ)inv_std,把inv_std看作1σ2+ε\hat{x}_i=(x_i-\mu)\text{inv\_std} ,把 {\text {inv\_std}} 看作 \frac{1}{\sqrt{\sigma^2+\varepsilon}} 看作 x^i\hat{x}_iσ2\sigma^2 的函数:

x^iσ2=(xiμ)inv_stdσ2=(xiμ)(12)(σ2+ε)3/2.\frac{\partial \hat{x}_i}{\partial \sigma^2} =(x_i-\mu)\cdot\frac{\partial \text{inv\_std}}{\partial \sigma^2} =(x_i-\mu)\cdot\Big(-\tfrac{1}{2}\Big)(\sigma^2+\varepsilon)^{-3/2}.

于是

Lσ2=idx^ix^iσ2=(12)(σ2+ε)3/2idx^i(xiμ).\frac{\partial L}{\partial \sigma^2} =\sum_i d\hat{x}_i\frac{\partial \hat{x}_i}{\partial \sigma^2} =\Big(-\tfrac{1}{2}\Big)(\sigma^2+\varepsilon)^{-3/2}\sum_i d\hat{x}_i(x_i-\mu).

又由于

σ2xi=2m(xiμ)\frac{\partial \sigma^2}{\partial x_i}=\frac{2}{m}(x_i-\mu)

所以

Lσ2σ2xi=2m(xiμ)(12)(σ2+ε)3/2jdx^j(xjμ){\frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_i}}=\frac{2}{m}(x_i-\mu)*\Big(-\tfrac{1}{2}\Big)(\sigma^2+\varepsilon)^{-3/2}\sum_j d\hat{x}_j(x_j-\mu)

将以上求得的三部分相加

  Lxi=inv_stdm(mdx^iS1x^iS2)  \boxed{\;\frac{\partial L}{\partial x_i} =\frac{\text{inv\_std}}{m}\Big(m\,d\hat{x}_i - S_1 - \hat{x}_i\,S_2\Big)\;}
S1(db)=jdx^j,S2(ds)=jdx^jx^j.S_1(db)=\sum_j d\hat{x}_j,\qquad S_2(ds)=\sum_j d\hat{x}_j\,\hat{x}_j.

relu 的反向传播 relu 前向公式

y=ReLU(x)=max(0,x)y = \text{ReLU}(x) = \max(0, x)

对输入 xx 求导:

yx={1如果 x>00如果 x<0\frac{\partial y}{\partial x} = \begin{cases} 1 & \text{如果 } x > 0 \\ 0 & \text{如果 } x < 0 \end{cases}

假设我们有损失函数 LL,它依赖于 ReLU 的输出 yy,我们要求

Lx\frac{\partial L}{\partial x}

由链式法则:

Lx=Lyyx\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x}

其中:

  • Ly\frac{\partial L}{\partial y} 来自上游(前一层传下来的梯度)
  • yx\frac{\partial y}{\partial x} 是 ReLU 的导数(0 或 1)

所以:

Lx={Ly如果 x>00如果 x0\frac{\partial L}{\partial x} = \begin{cases} \frac{\partial L}{\partial y} & \text{如果 } x > 0 \\ 0 & \text{如果 } x \leq 0 \end{cases}

在实际框架里,比如 PyTorch,ReLU backward 常写成:

grad_input = grad_output * (x > 0)

所以整个bn+relu+卷积的反向流程如图所示,这里类似前向分为两个算子,卷积 relu 为一个融合算子,求db ds为另一个算子

image.png

三、融合方式

预编译模式 不灵活但是算子可定制优化

动态生成编译模式 灵活但是算子不好定制优化性能略差