llm.c Source Code Analysis: The GELU Activation Function



GELU (Gaussian Error Linear Unit) is an activation function first proposed by Dan Hendrycks and Kevin Gimpel in 2016.

(Figure: abstract of the GELU paper)

gelu_forward

The C-language GELU in Andrej Karpathy's llm.c uses an approximate formulation (the boxed expression in the figure below).

(Figure: the GELU definition from the paper, with the tanh approximation highlighted in a red box)
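For reference, the exact definition of GELU (the input times the standard normal CDF Φ) and the tanh approximation used by llm.c are:

$$\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)$$

$$\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)$$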

void gelu_forward(float* out, float* inp, int N) {
    float s = sqrtf(2.0f / M_PI); // scaling factor sqrt(2/pi) inside tanh
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        // tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        out[i] = 0.5f * x * (1.0f + tanhf(s * (x + cube)));
    }
}
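As a quick sanity check, here is a minimal standalone program (not part of llm.c, just a sketch using the same formulas) that evaluates both the tanh approximation computed by gelu_forward and the exact erf-based definition on a few inputs; the two agree closely across the range.

#include <math.h>
#include <stdio.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

int main(void) {
    float xs[] = {-3.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 3.0f};
    int n = (int)(sizeof(xs) / sizeof(xs[0]));
    float s = sqrtf(2.0f / M_PI);
    for (int i = 0; i < n; i++) {
        float x = xs[i];
        // tanh approximation, exactly as gelu_forward computes it
        float approx = 0.5f * x * (1.0f + tanhf(s * (x + 0.044715f * x * x * x)));
        // exact definition: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        float exact = 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
        printf("x=%+.2f  approx=%+.6f  exact=%+.6f  diff=%+.1e\n",
               x, approx, exact, approx - exact);
    }
    return 0;
}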

The PyTorch version of GELU, on the other hand, uses the approximate argument to choose between the exact formula ("none") and the tanh approximation ("tanh").

In the code below, M_2_SQRTPI is the constant 2/sqrt(π).

# Excerpt from PyTorch's Python reference implementation of gelu;
# TensorLike and TensorLikeType are PyTorch-internal type aliases.
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    if not isinstance(a, TensorLike):
        raise RuntimeError("expected a tensor input")  # error message abridged here
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        a_cube = a * a * a
        inner = kBeta * (a + kKappa * a_cube)
        return 0.5 * a * (1 + torch.tanh(inner))
    elif approximate == "none":
        kAlpha = M_SQRT1_2
        return a * 0.5 * (1 + torch.erf(a * kAlpha))
    else:
        raise RuntimeError("approximate argument must be either none or tanh.")
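The constants line up with the scaling factor s = sqrtf(2.0f / M_PI) in the C code:

$$k_{\text{Beta}} = \sqrt{2}\cdot\frac{2}{\sqrt{\pi}}\cdot\frac{1}{2} = \sqrt{\frac{2}{\pi}}, \qquad k_{\text{Alpha}} = \frac{1}{\sqrt{2}}$$

So the "tanh" branch uses the same sqrt(2/π) scaling as gelu_forward, while the "none" branch computes a * kAlpha = x/√2 inside erf, i.e. the exact x·Φ(x).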

gelu_backward

void gelu_backward(float* dinp, float* inp, float* dout, int N) {
    float s = sqrtf(2.0f / M_PI);
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = s * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out); // sech^2(u) = 1 / cosh^2(u)
        // derivative of the tanh approximation with respect to x (derived below)
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * s * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i]; // chain rule: multiply by upstream gradient and accumulate
    }
}
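A simple way to verify the backward pass is a central-difference check. Here is a minimal sketch, assuming gelu_forward and gelu_backward from above are compiled into the same file; the analytic and numeric gradients should agree to several decimal places.

#include <stdio.h>

// assumes gelu_forward and gelu_backward (defined above) are in the same file
int main(void) {
    float inp[3]  = {-1.0f, 0.3f, 2.0f};
    float dout[3] = {1.0f, 1.0f, 1.0f};   // upstream gradient of ones
    float dinp[3] = {0.0f, 0.0f, 0.0f};   // gelu_backward accumulates into this
    float outp[3], outm[3];
    float eps = 1e-3f;

    gelu_backward(dinp, inp, dout, 3);    // analytic gradient

    for (int i = 0; i < 3; i++) {
        float saved = inp[i];
        inp[i] = saved + eps; gelu_forward(outp, inp, 3);
        inp[i] = saved - eps; gelu_forward(outm, inp, 3);
        inp[i] = saved;
        float numeric = (outp[i] - outm[i]) / (2.0f * eps);  // central difference
        printf("i=%d  analytic=%.6f  numeric=%.6f\n", i, dinp[i], numeric);
    }
    return 0;
}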

Local gradient local_grad
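Writing $u(x) = \sqrt{2/\pi}\,(x + 0.044715\,x^{3})$, the tanh approximation is $\mathrm{GELU}(x) \approx \tfrac{x}{2}(1 + \tanh u(x))$, and the product and chain rules give

$$\frac{d\,\mathrm{GELU}(x)}{dx} \approx \frac{1}{2}\bigl(1 + \tanh u\bigr) + \frac{x}{2}\,\operatorname{sech}^{2}(u)\,\sqrt{\frac{2}{\pi}}\bigl(1 + 3\cdot 0.044715\,x^{2}\bigr)$$

which is exactly what the local_grad expression computes; dinp[i] += local_grad * dout[i] then multiplies by the upstream gradient and accumulates it into dinp.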


References

mp.weixin.qq.com/s/qb0dhdFnX…

github.com/karpathy/ll…

paperswithcode.com/method/gelu