记一次transpose+add算子的优化

132 阅读5分钟

原文位于:zhuanlan.zhihu.com/p/305126524… 要加速的pytorch原代码:

torch_output = input1.transpose(0, 1) + input2

加速的特定shape:

input1 = torch.randn([24300, 11520], device="cuda:0", dtype=torch.bfloat16) 
input2 = torch.randn([11520, 24300], device="cuda:0", dtype=torch.bfloat16)

第一版cuda代码:


#include <algorithm>
#include <iostream>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <torch/python.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>

#define BLOCK_DIM 32

//idata_1:[height, width]
//idata_2:[width, height]
__global__ void add_two_transpose(__nv_bfloat16 *odata, __nv_bfloat16 *idata_1,  __nv_bfloat16 *idata_2, int width, int height)
{
	// __shared__ __nv_bfloat16 block[BLOCK_DIM][BLOCK_DIM+1];
	__shared__ float4 block[BLOCK_DIM][BLOCK_DIM+1];

	// read the matrix tile into shared memory
    // load one element per thread from device memory (idata) and store it
    // in transposed order in block[][]
	unsigned int xIndex = (blockIdx.x * BLOCK_DIM + threadIdx.x)* 8;
	unsigned int yIndex = blockIdx.y * BLOCK_DIM + threadIdx.y;
	if((xIndex < width) && (yIndex < height))
	{
		unsigned int index_in = yIndex * width + xIndex;
		block[threadIdx.y][threadIdx.x] = reinterpret_cast<float4*>(idata_1 + index_in)[0];
	}

        // synchronise to ensure all writes to block[][] have completed
	__syncthreads();

	// write the transposed matrix tile to global memory (odata) in linear order
	xIndex = blockIdx.y * BLOCK_DIM + threadIdx.x;
	yIndex = (blockIdx.x * BLOCK_DIM + threadIdx.y) * 8;
	if((xIndex < height) && (yIndex < width))
	{
		unsigned int index_out = yIndex * height + xIndex;
		// odata[index_out] = block[threadIdx.x][threadIdx.y];// + idata_2[index_out];
        float4 temp = block[threadIdx.x][threadIdx.y];
        __nv_bfloat16 *temp_ptr = reinterpret_cast<__nv_bfloat16*>(&temp);
#pragma unroll
        for(int i=0; i<8; i++){
            odata[index_out] = temp_ptr[i] + idata_2[index_out];
            index_out += height;
        }
	}
}


torch::Tensor torch_modulate_attn(const torch::Tensor& buf0, const torch::Tensor& buf1) {
    int height = buf0.size(0);
    int width = buf0.size(1);
    torch::Tensor output = torch::empty({buf0.size(1), buf0.size(0)}, torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA));
    dim3 dimBlock(BLOCK_DIM, BLOCK_DIM);
    dim3 dimGrid((width + BLOCK_DIM*8 - 1) / (BLOCK_DIM*8), (height + BLOCK_DIM - 1) / BLOCK_DIM);
    add_two_transpose<<<dimGrid, dimBlock>>>(
        (__nv_bfloat16*)output.data_ptr(),
        (__nv_bfloat16*)buf0.data_ptr(),
        (__nv_bfloat16*)buf1.data_ptr(),
        width,
        height
    );

    return output;
}

测试脚本:

import torch
import pytest
from torch import nn
from ops import torch_modulate_attn
import random

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
torch.cuda.manual_seed(42)

@torch.compile
def compile_fn(input1, input2):
    return (input1.transpose(0, 1) + input2).contiguous()

if __name__ == "__main__":
    import time

    # 设置随机种子,确保结果可复现
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    
    # 定义测试输入数据 
    input1 = torch.randn([24300, 11520], device="cuda:0", dtype=torch.bfloat16)
    input2 = torch.randn([11520, 24300], device="cuda:0", dtype=torch.bfloat16)
    torch_time_list, custom_time_list, compile_time_list = [], [], []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    for _ in range(10):
        custom_output = torch_modulate_attn(input1, input2)
        torch_output = input1.transpose(0, 1) + input2
        compile_output = compile_fn(input1, input2)
        assert torch.equal(custom_output, torch_output)

    for _ in range(10):
        start.record()
        torch_output = input1.transpose(0, 1) + input2
        end.record()
        torch.cuda.synchronize()        
        torch_time_list.append(start.elapsed_time(end))

        start.record()
        custom_output = torch_modulate_attn(input1, input2)
        end.record()
        torch.cuda.synchronize()
        custom_time_list.append(start.elapsed_time(end))

        start.record()
        compile_output = compile_fn(input1, input2)
        end.record()
        torch.cuda.synchronize()
        compile_time_list.append(start.elapsed_time(end))

    print(f"torch: {sum(torch_time_list)/len(torch_time_list)} ms")
    print(f"custom: {sum(custom_time_list)/len(custom_time_list)} ms")
    print(f"compile: {sum(compile_time_list)/len(compile_time_list)} ms")

测试GPU:A800

测试结果:

torch: 0.4977471947669983 ms 
custom: 0.1826815977692604 ms 
compile: 0.16609280109405516 ms

速度比torch.compile慢,但是如果仅保留kernel中transpose的部分,则速度比torch.compile快,猜测原因系该kernel仅对idata_1进行float4向量化读取加速,而没有对idata_2进行向量化加速。

由于idate_2的width为24300,只能被4和2整除,因此考虑采用float2或者float进行向量化(原始数据类型为bf16),先尝试float2,结果报错使用的共享内存太多,因此改为使用float,代码如下:

#include <algorithm>
#include <iostream>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <torch/python.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>

#define BLOCK_DIM 32

//idata_1:[height, width]
//idata_2:[width, height]
__global__ void add_two_transpose(__nv_bfloat16 *odata, __nv_bfloat16 *idata_1,  __nv_bfloat16 *idata_2, int width, int height)
{
	__shared__ float4 block[BLOCK_DIM * 2][BLOCK_DIM+1];

#pragma unroll    
    for(int i=0; i<2; i++){
        unsigned int xIndex = blockIdx.x * BLOCK_DIM * 8 + threadIdx.x * 8;
	    unsigned int yIndex = blockIdx.y * BLOCK_DIM * 2 + threadIdx.y + 32 * i;
        unsigned int index_in = yIndex * width + xIndex;
        if((xIndex < width) && (yIndex < height))
        {
            block[threadIdx.y + 32 * i][threadIdx.x] = reinterpret_cast<float4*>(idata_1 + index_in)[0];
        }
    }

    // synchronise to ensure all writes to block[][] have completed
	__syncthreads();
    __nv_bfloat16* block_bf16 = reinterpret_cast<__nv_bfloat16*>(block);
	
#pragma unroll    
    for(int i = 0; i < 8; i++){
        unsigned int xIndex = blockIdx.y * BLOCK_DIM * 2 + threadIdx.x * 2;
        unsigned int yIndex = blockIdx.x * BLOCK_DIM * 8 + threadIdx.y + 32 * i;
        if((xIndex < height) && (yIndex < width))
        {
            unsigned int index_out = yIndex * height + xIndex;
            float temp_2 = reinterpret_cast<float*>(idata_2 + index_out)[0];
            __nv_bfloat16* temp_2_ptr = reinterpret_cast<__nv_bfloat16*>(&temp_2);
#pragma unroll
            for(int j = 0; j < 2; j++){
                temp_2_ptr[j] = temp_2_ptr[j] + block_bf16[(threadIdx.x * 2 + j) * 264 + (threadIdx.y + 32 * i)];
            }
            reinterpret_cast<float*>(odata + index_out)[0] = temp_2;
        }
    }
}


torch::Tensor torch_modulate_attn(const torch::Tensor& buf0, const torch::Tensor& buf1) {
    int height = buf0.size(0);
    int width = buf0.size(1);
    torch::Tensor output = torch::empty({buf0.size(1), buf0.size(0)}, torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA));
    dim3 dimBlock(BLOCK_DIM, BLOCK_DIM);
    dim3 dimGrid((width + BLOCK_DIM*8 - 1) / (BLOCK_DIM*8), (height + BLOCK_DIM*2 - 1) / (BLOCK_DIM*2));
    add_two_transpose<<<dimGrid, dimBlock>>>(
        (__nv_bfloat16*)output.data_ptr(),
        (__nv_bfloat16*)buf0.data_ptr(),
        (__nv_bfloat16*)buf1.data_ptr(),
        width,
        height
    );

    return output;
}

测试结果:

torch: 0.4959743946790695 ms 
custom: 0.13992959931492804 ms 
compile: 0.16209919899702072 ms

感触:

在CUDA编程中,优化性能主要围绕内存访问和计算效率两个核心方向展开。在内存优化方面,关键策略包括实现内存访问的合并(coalesced memory access)以及引入向量化(vectorization)技术,以最大化内存带宽的利用率。相比之下,计算优化方面的可操作空间相对有限,除非开发者具备设计全新算法的高级能力,否则计算优化通常受限于硬件架构和现有算法的固有特性。

具体而言,在处理矩阵转置(transpose)等操作时,优先考虑利用共享内存(shared memory)来实现内存访问的合并,从而减少全局内存(global memory)访问的延迟。此外,通过向量化技术可以进一步提升内存访问的效率。然而,CUDA编程中的许多性能瓶颈难以通过理论模型精确预测,通常需要依赖实验驱动的迭代优化方法,即通过实际运行代码并基于性能分析结果进行逐步调整。

这种基于实验的优化方法(empirical optimization)在CUDA编程中尤为常见,因为硬件行为、内存层次结构以及线程调度等因素的复杂性使得理论建模往往无法完全捕捉实际运行时的性能特征。因此,开发者需要在理论分析与实验验证之间找到平衡,以实现最优的性能提升。