为什么说Softmax是访存密集型算子?

0 阅读5分钟

Softmax 被称为访存密集型算子(memory-bound operator),原因如下:

1.softmax计算流程

以最常见的按行 softmax(如在 NLP 中的 attention)为例,它的计算主要包含以下几步:

  1. 最大值计算(为了数值稳定性):

image.png

  1. 指数运算并求和

image.png

  1. 归一化

image.png

2.为什么访存密集(Memory-bound)

这些操作共同特点是:

特性原因
频繁访问内存多次读写输入 x,保存中间值(如 m_i, s_i, exp(x_{ij})
计算量低于数据传输量每个元素最多几次浮点操作(加法、指数、除法),但必须完整读写内存
缺乏重用性中间结果不能长期缓存,数据局部性差
缓存不友好行维度可能较短,容易导致 cache miss,特别在 batch 较小时

结果是:计算不复杂,但大量的数据移动消耗了带宽,这使得 softmax 性能受限于内存带宽,而不是算力。

3.对比计算密集型算子

类型示例特点
计算密集型矩阵乘法、卷积每读取一次数据,可重复使用多次进行运算(高 FLOPs / Byte)
访存密集型Softmax、LayerNorm每读取的数据只用一次,FLOPs / Byte 比低

4.举个例子

对于一个二维矩阵输入 X ∈ [B, N](batch size 为 B,每行 N 个元素),典型 softmax 实现分为三步:

4.1 Softmax 计算步骤

image.png

4.2 定量分析:访存量 vs 运算量

假设输入:

  • Batch size B=1024B = 1024B=1024
  • 每行元素数 N=1024N = 1024N=1024
  • 数据类型为 float32(每个元素 4 字节) 1. 访存量(Memory Access)

image.png 2.计算量(FLOPs)

image.png

3. FLOP / Byte Ratio

  • 总计算量 ≈ 5.2 MFLOPs
  • 总访存 ≈ 20 MB = 20 × 2²⁰ ≈ 21 × 10⁶ Bytes

结果:

image.png

这个比值很低(相比典型的计算密集型操作如 matmul,FLOP/Byte 往往能到 5~20),说明这是内存受限的操作

4.3 对比:矩阵乘法(MatMul)FLOP/Byte

比如做:

image.png 这是计算密集型操作的典型特征。

4.4 总结

操作类型FLOP / Byte(越大越计算密集)访存瓶颈?
Softmax≈ 0.25✅ 是
矩阵乘法≈ 167❌ 否

因此,Softmax 是访存密集型的(memory-bound),主要瓶颈在内存带宽,不是计算性能

3.优化启示

  • 优化 softmax 通常需要减少访存次数(比如 fused kernel 把多个操作合并一次访存完成)
  • 在 GPU 上优化 softmax 也强调 warp-level 的高带宽读写而非算力利用

3.1 fused kernel

3.1.1 在 GPU 或 HPC 系统上,有两个关键挑战:

❶ 内存访问次数多、读写分散:
  • 典型实现需要:

    • 多次读取输入
    • 多次写入/读取中间变量(如 exp(x), sum, buffer, output
    • 数据写回 global memory

这种高频访问 global memory 会浪费大量带宽。

❷ 每个阶段独立 Kernel 调用 → 调度 & latency 浪费
  • 默认实现中:

    • max, exp, sum, div 各自是独立的 CUDA kernel
    • 4 个 kernel 启动 → 启动开销 + 中间结果写回 + 同步开销大

3.1.2 解决方案:Kernel Fusion

核心思想:将多个 softmax 步骤合并为一个 kernel,减少 memory 访问和 kernel 启动开销

3.1.3 例子:Fused Softmax kernel 包含以下逻辑

global void fused_softmax(float* x, float* out, int N) { // 使用 shared memory 缓存一行 shared float row[1024];

int tid = threadIdx.x;
int bid = blockIdx.x;

// Step 1: 加载、找最大值
float max_val = -INFINITY;
for (int i = tid; i < N; i += blockDim.x) {
    float val = x[bid * N + i];
    row[i] = val;
    max_val = max(max_val, val);
}
__syncthreads();

// Step 2: 计算 exp 和 sum
float sum = 0.0f;
for (int i = tid; i < N; i += blockDim.x) {
    row[i] = expf(row[i] - max_val);
    sum += row[i];
}
__syncthreads();

// Step 3: 写回归一化值
for (int i = tid; i < N; i += blockDim.x) {
    out[bid * N + i] = row[i] / sum;
}

}

🎯 优化效果

优化方式说明效果
Kernel Fusion减少 kernel 调用、内存读写次数运行时间减少 30%~50%
Shared Memory 缓存减少 global memory 带宽压力进一步提升效率
Vectorized Load每次加载多个元素(如 float4提高吞吐
Warp-wise Reduce使用 warp shuffle 优化 max/sum减少同步开销

3.1.4 案例:深度学习框架中对 softmax 的优化

框架优化版本使用的技术
TensorFlowtf.nn.softmaxXLA JIT 编译 + kernel fusion
PyTorchtorch.nn.functional.softmaxFused kernel in CUDA
ONNXonnxruntime + kernel fusionTVM / triton fused kernels

3.1.5 总结:为什么要做 kernel fusion for softmax?

问题Kernel Fusion 如何缓解
多次访问内存开销大合并计算步骤,减少 load/store
多个 kernel 启动开销大一个 kernel 完成所有步骤
多个中间变量占用 cache用 register/shared memory 保持局部性
总体算力利用率低提升 GPU occupancy,降低 latency

3.2 warp-level 的高带宽读写

在 GPU 上优化 softmax 时,核心关注点是如何高效地进行数据的读写(特别是 warp-level 的内存访问),而不是提升计算单元的利用率(算力/FLOPs)

3.2.1 Warp-level 优化是关键

Warp = 一组 32 个线程组成一个 GPU 最小调度单元

GPU 的访存效率依赖于 warp 中线程的 内存访问是否对齐(coalesced access):

  • 优化目标:让 warp 内线程读写连续地址,避免分散、非对齐、bank conflict
  • 例子:warp 内 thread0 读 x[0],thread1 读 x[1],…thread31 读 x[31],则是最佳对齐

在 softmax 中:

  • 如果每一行的数据量不能正好被 warp 整除,或者数据分布零散,就会造成访问浪费
  • 因此,softmax 的优化 更像是访存调度优化,而不是算力优化

3.2.2 举例说明:

比如你要对一个矩阵 [batch_size=128, dim=96] 做 softmax:

  • 你可以把每一行分配给一个 warp
  • 让 warp 的 32 个线程分别加载一段数据
  • 在 shared memory 中进行 reduce(max/sum)操作
  • 最终用 warp shuffle 实现并行归一化
  • → 全程避免 global memory 多次读写、只读一次写一次

这种方式比用多个 kernel、反复访存更高效。

3.2.3 总结

传统优化重点Softmax 优化重点
提升算力利用率提升访存效率(coalesced memory access)
增加并行线程提高 warp 内内存对齐、高带宽读写
更大 tile/block更高效地加载/缓存数据(如 shared memory)