深挖flash attention框架代码

529 阅读1分钟

flash-attention代码解读:自定义激活函数和反向传播

这段代码实现了几种激活函数及其反向传播的逻辑。让我们逐步解读每个部分的实现和理论基础。

1. bias_gelu 和 bias_gelu_back 函数

@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)

@torch.jit.script
def bias_gelu_back(g, y, bias):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
  • bias_gelu(y, bias):

    • 这是一个实现了带有偏置的GELU激活函数的函数。GELU(Gaussian Error Linear Unit)是一种激活函数,通常用于神经网络的隐藏层。
    • 函数首先将输入 (y)( y ) 和偏置 (bias)( bias ) 相加得到 (x)( x )
    • 然后计算 (0.5×(1.0+tanh(...)))( 0.5 \times (1.0 + \tanh(...)) ),其中 (tanh)( \tanh ) 函数的参数是一个多项式,用于近似GELU函数的形状。
    • 返回值是经过GELU激活的结果,类型与输入 (y)( y ) 的数据类型相同。
  • bias_gelu_back(g, y, bias):

    • 这个函数实现了 bias_gelu 的反向传播。
    • 输入参数包括梯度 (g)( g )、输入 (y)( y ) 和偏置 (bias)( bias )
    • 首先计算 (x)( x ),然后根据GELU激活函数的导数公式计算梯度。
    • 返回 (grady)( grad_y ),这是输入 (y)( y ) 的梯度,以及 (bias)( bias ) 的梯度。

2. GeLUFunction 类和 FastGeLUFunction 类

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, input, bias)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply

class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp

fast_gelu_impl = FastGeLUFunction.apply
  • GeLUFunction 类和 FastGeLUFunction 类都是基于 PyTorch 的 torch.autograd.Function 实现的自定义激活函数类。
  • GeLUFunction:
    • forward 方法保存输入 (input)( input ) 和偏置 (bias)( bias ),然后调用 bias_gelu 函数得到输出。
    • backward 方法根据保存的 (input)( input )(bias)( bias ),以及输入的梯度 (gradoutput)( grad_output ),调用 bias_gelu_back 函数进行反向传播计算。
  • FastGeLUFunction:
    • forward 方法只保存输入 (input)( input ),然后调用 gelu_fwd 函数得到输出。
    • backward 方法根据保存的 (input)( input ) 和输入的梯度 (gradoutput)( grad_output ),调用 gelu_bwd 函数进行反向传播计算。

3. 其他激活函数

除了 GELU 相关的实现外,代码还包含了其他激活函数及其反向传播的实现:

  • sqrelu_fwd(x)sqrelu_bwd(g, x): 实现了平方ReLU激活函数及其反向传播。
  • relu_bwd(g, x): 实现了ReLU激活函数的反向传播。

这些激活函数在神经网络的训练过程中起着重要作用,通过非线性变换来增强神经网络的表示能力,并在反向传播过程中传递梯度信息。

4. SwiGLUFunction 类

class SwiGLUFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)

swiglu = SwiGLUFunction.apply
  • SwiGLUFunction 类实现了一种名为 SwiGLU 的激活函数及其反向传播。
  • forward 方法保存输入 (x)( x )(y)( y ),然后调用 swiglu_fwd 函数计算输出。
  • backward 方法根据保存的 (x)( x )(y)( y ),以及输入的梯度 (dout)( dout ),调用 swiglu_bwd 函数进行反向传播计算。

公式

以下是各种激活函数的数学公式表示:

  • GELU:
GELU(x)=x0.5(1+tanh(0.79788456x(1+0.044715x2))) \text{GELU}(x) = x \cdot 0.5 \cdot \left(1 + \tanh\left(0.79788456 \cdot x \cdot \left(1 + 0.044715 \cdot x^2\right)\right)\right)
  • Fast GELU:
FastGELU(x)=x0.5(1+tanh(0.79788456x(1+0.044715x2))) \text{FastGELU}(x) = x \cdot 0.5 \cdot \left(1 + \tanh\left(0.79788456 \cdot x \cdot \left(1 + 0.044715 \cdot x^2\right)\right)\right)
  • SwiGLU:
SwiGLU(x,y)=xy1+ex \text{SwiGLU}(x, y) = \frac{x \cdot y}{1 + e^{-x}}
  • ReLU:
ReLU(x)=max(0,x) \text{ReLU}(x) = \max(0, x)
  • 平方ReLU:
SquareReLU(x)=(ReLU(x))2 \text{SquareReLU}(x) = (\text{ReLU}(x))^2

这些公式描述了每个激活函数在输入 (x)( x ) 下的计算方式,并且反向传播时计算梯度以便于神经网络参数的更新。这些激活函数不仅仅是数学函数,它们通过非线性变换为神经网络的学习能力增加了更多的可能性。