Softmax 被称为访存密集型算子(memory-bound operator),原因如下:
1.softmax计算流程
以最常见的按行 softmax(如在 NLP 中的 attention)为例,它的计算主要包含以下几步:
- 最大值计算(为了数值稳定性):
- 指数运算并求和:
- 归一化:
2.为什么访存密集(Memory-bound)
这些操作共同特点是:
特性 | 原因 |
---|---|
频繁访问内存 | 多次读写输入 x ,保存中间值(如 m_i , s_i , exp(x_{ij}) ) |
计算量低于数据传输量 | 每个元素最多几次浮点操作(加法、指数、除法),但必须完整读写内存 |
缺乏重用性 | 中间结果不能长期缓存,数据局部性差 |
缓存不友好 | 行维度可能较短,容易导致 cache miss,特别在 batch 较小时 |
结果是:计算不复杂,但大量的数据移动消耗了带宽,这使得 softmax 性能受限于内存带宽,而不是算力。
3.对比计算密集型算子
类型 | 示例 | 特点 |
---|---|---|
计算密集型 | 矩阵乘法、卷积 | 每读取一次数据,可重复使用多次进行运算(高 FLOPs / Byte) |
访存密集型 | Softmax、LayerNorm | 每读取的数据只用一次,FLOPs / Byte 比低 |
4.举个例子
对于一个二维矩阵输入 X ∈ [B, N]
(batch size 为 B,每行 N 个元素),典型 softmax 实现分为三步:
4.1 Softmax 计算步骤
4.2 定量分析:访存量 vs 运算量
假设输入:
- Batch size B=1024B = 1024B=1024
- 每行元素数 N=1024N = 1024N=1024
- 数据类型为
float32
(每个元素 4 字节) 1. 访存量(Memory Access)
2.计算量(FLOPs)
3. FLOP / Byte Ratio
- 总计算量 ≈ 5.2 MFLOPs
- 总访存 ≈ 20 MB = 20 × 2²⁰ ≈ 21 × 10⁶ Bytes
结果:
这个比值很低(相比典型的计算密集型操作如 matmul,FLOP/Byte 往往能到 5~20),说明这是内存受限的操作。
4.3 对比:矩阵乘法(MatMul)FLOP/Byte
比如做:
这是计算密集型操作的典型特征。
4.4 总结
操作类型 | FLOP / Byte(越大越计算密集) | 访存瓶颈? |
---|---|---|
Softmax | ≈ 0.25 | ✅ 是 |
矩阵乘法 | ≈ 167 | ❌ 否 |
因此,Softmax 是访存密集型的(memory-bound),主要瓶颈在内存带宽,不是计算性能。
3.优化启示
- 优化 softmax 通常需要减少访存次数(比如 fused kernel 把多个操作合并一次访存完成)
- 在 GPU 上优化 softmax 也强调 warp-level 的高带宽读写而非算力利用
3.1 fused kernel
3.1.1 在 GPU 或 HPC 系统上,有两个关键挑战:
❶ 内存访问次数多、读写分散:
-
典型实现需要:
- 多次读取输入
- 多次写入/读取中间变量(如
exp(x)
,sum
,buffer
,output
) - 数据写回 global memory
这种高频访问 global memory 会浪费大量带宽。
❷ 每个阶段独立 Kernel 调用 → 调度 & latency 浪费
-
默认实现中:
max
,exp
,sum
,div
各自是独立的 CUDA kernel- 4 个 kernel 启动 → 启动开销 + 中间结果写回 + 同步开销大
3.1.2 解决方案:Kernel Fusion
核心思想:将多个 softmax 步骤合并为一个 kernel,减少 memory 访问和 kernel 启动开销
3.1.3 例子:Fused Softmax kernel 包含以下逻辑
global void fused_softmax(float* x, float* out, int N) { // 使用 shared memory 缓存一行 shared float row[1024];
int tid = threadIdx.x;
int bid = blockIdx.x;
// Step 1: 加载、找最大值
float max_val = -INFINITY;
for (int i = tid; i < N; i += blockDim.x) {
float val = x[bid * N + i];
row[i] = val;
max_val = max(max_val, val);
}
__syncthreads();
// Step 2: 计算 exp 和 sum
float sum = 0.0f;
for (int i = tid; i < N; i += blockDim.x) {
row[i] = expf(row[i] - max_val);
sum += row[i];
}
__syncthreads();
// Step 3: 写回归一化值
for (int i = tid; i < N; i += blockDim.x) {
out[bid * N + i] = row[i] / sum;
}
}
🎯 优化效果
优化方式 | 说明 | 效果 |
---|---|---|
Kernel Fusion | 减少 kernel 调用、内存读写次数 | 运行时间减少 30%~50% |
Shared Memory 缓存 | 减少 global memory 带宽压力 | 进一步提升效率 |
Vectorized Load | 每次加载多个元素(如 float4 ) | 提高吞吐 |
Warp-wise Reduce | 使用 warp shuffle 优化 max/sum | 减少同步开销 |
3.1.4 案例:深度学习框架中对 softmax 的优化
框架 | 优化版本 | 使用的技术 |
---|---|---|
TensorFlow | tf.nn.softmax | XLA JIT 编译 + kernel fusion |
PyTorch | torch.nn.functional.softmax | Fused kernel in CUDA |
ONNX | onnxruntime + kernel fusion | TVM / triton fused kernels |
3.1.5 总结:为什么要做 kernel fusion for softmax?
问题 | Kernel Fusion 如何缓解 |
---|---|
多次访问内存开销大 | 合并计算步骤,减少 load/store |
多个 kernel 启动开销大 | 一个 kernel 完成所有步骤 |
多个中间变量占用 cache | 用 register/shared memory 保持局部性 |
总体算力利用率低 | 提升 GPU occupancy,降低 latency |
3.2 warp-level 的高带宽读写
在 GPU 上优化 softmax 时,核心关注点是如何高效地进行数据的读写(特别是 warp-level 的内存访问),而不是提升计算单元的利用率(算力/FLOPs) 。
3.2.1 Warp-level 优化是关键
Warp = 一组 32 个线程组成一个 GPU 最小调度单元
GPU 的访存效率依赖于 warp 中线程的 内存访问是否对齐(coalesced access):
- ✅ 优化目标:让 warp 内线程读写连续地址,避免分散、非对齐、bank conflict
- 例子:warp 内 thread0 读 x[0],thread1 读 x[1],…thread31 读 x[31],则是最佳对齐
在 softmax 中:
- 如果每一行的数据量不能正好被 warp 整除,或者数据分布零散,就会造成访问浪费
- 因此,softmax 的优化 更像是访存调度优化,而不是算力优化
3.2.2 举例说明:
比如你要对一个矩阵 [batch_size=128, dim=96]
做 softmax:
- 你可以把每一行分配给一个 warp
- 让 warp 的 32 个线程分别加载一段数据
- 在 shared memory 中进行 reduce(max/sum)操作
- 最终用 warp shuffle 实现并行归一化
- → 全程避免 global memory 多次读写、只读一次写一次
这种方式比用多个 kernel、反复访存更高效。
3.2.3 总结
传统优化重点 | Softmax 优化重点 |
---|---|
提升算力利用率 | 提升访存效率(coalesced memory access) |
增加并行线程 | 提高 warp 内内存对齐、高带宽读写 |
更大 tile/block | 更高效地加载/缓存数据(如 shared memory) |