flash-attention代码解读:自定义激活函数和反向传播
这段代码实现了几种激活函数及其反向传播的逻辑。让我们逐步解读每个部分的实现和理论基础。
1. bias_gelu 和 bias_gelu_back 函数
@torch.jit.script
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
@torch.jit.script
def bias_gelu_back(g, y, bias):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
-
bias_gelu(y, bias):
- 这是一个实现了带有偏置的GELU激活函数的函数。GELU(Gaussian Error Linear Unit)是一种激活函数,通常用于神经网络的隐藏层。
- 函数首先将输入 和偏置 相加得到 。
- 然后计算 ,其中 函数的参数是一个多项式,用于近似GELU函数的形状。
- 返回值是经过GELU激活的结果,类型与输入 的数据类型相同。
-
bias_gelu_back(g, y, bias):
- 这个函数实现了
bias_gelu的反向传播。 - 输入参数包括梯度 、输入 和偏置 。
- 首先计算 ,然后根据GELU激活函数的导数公式计算梯度。
- 返回 ,这是输入 的梯度,以及 的梯度。
- 这个函数实现了
2. GeLUFunction 类和 FastGeLUFunction 类
class GeLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, input, bias)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
class FastGeLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return gelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
- GeLUFunction 类和 FastGeLUFunction 类都是基于 PyTorch 的
torch.autograd.Function实现的自定义激活函数类。 - GeLUFunction:
forward方法保存输入 和偏置 ,然后调用bias_gelu函数得到输出。backward方法根据保存的 和 ,以及输入的梯度 ,调用bias_gelu_back函数进行反向传播计算。
- FastGeLUFunction:
forward方法只保存输入 ,然后调用gelu_fwd函数得到输出。backward方法根据保存的 和输入的梯度 ,调用gelu_bwd函数进行反向传播计算。
3. 其他激活函数
除了 GELU 相关的实现外,代码还包含了其他激活函数及其反向传播的实现:
- sqrelu_fwd(x) 和 sqrelu_bwd(g, x): 实现了平方ReLU激活函数及其反向传播。
- relu_bwd(g, x): 实现了ReLU激活函数的反向传播。
这些激活函数在神经网络的训练过程中起着重要作用,通过非线性变换来增强神经网络的表示能力,并在反向传播过程中传递梯度信息。
4. SwiGLUFunction 类
class SwiGLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return swiglu_fwd(x, y)
@staticmethod
def backward(ctx, dout):
x, y = ctx.saved_tensors
return swiglu_bwd(x, y, dout)
swiglu = SwiGLUFunction.apply
- SwiGLUFunction 类实现了一种名为 SwiGLU 的激活函数及其反向传播。
forward方法保存输入 和 ,然后调用swiglu_fwd函数计算输出。backward方法根据保存的 和 ,以及输入的梯度 ,调用swiglu_bwd函数进行反向传播计算。
公式
以下是各种激活函数的数学公式表示:
- GELU:
- Fast GELU:
- SwiGLU:
- ReLU:
- 平方ReLU:
这些公式描述了每个激活函数在输入 下的计算方式,并且反向传播时计算梯度以便于神经网络参数的更新。这些激活函数不仅仅是数学函数,它们通过非线性变换为神经网络的学习能力增加了更多的可能性。