cublasSgemm 用法详解

2,376 阅读4分钟

介绍

cublasSgemm是CUDA的cublas库的矩阵相乘函数,与cblas中矩阵相乘函数不同,cublas中矩阵的存储是列优先的,所以使用时容易造成误解。 官方链接: docs.nvidia.com/cuda/cublas…

行优先与列优先

1 2 3                
4 5 6  
7 8 9   

行优先存储

1 2 3 4 5 6 7 8 9

列优先存储

1 4 7 2 5 8 3 6 9

参数解析

假设输入矩阵为A,B,C,返回矩阵为C,则本函数主要的运算为

C=alphaAB+betaCC = alpha * A * B + beta * C

其中 A 满足 (m,k)(m,k) ,B 满足 (n,k)(n, k), C满足 (m,n)(m,n)

cublasStatus_t cublasSgemm(cublasHandle_t handle,
                           cublasOperation_t transa, cublasOperation_t transb,
                           int m, int n, int k,
                           const float *alpha,
                           const float *A, int lda,
                           const float *B, int ldb,
                           const float *beta,
                           float *C, int ldc)
参数数据位置输入/输出含义
handle输入cublas库上下文句柄
transa输入矩阵A是否需要转置成行优先
transb输入矩阵B是否需要转置成行优先
m输入矩阵A和矩阵C的行数
n输入矩阵B和矩阵C的列数
k输入矩阵A的列数,矩阵B的行数
alpha宿主机器/设备输入AB相乘的系数
A设备输入矩阵A
lda输入如果不转置,为A的行数即m, 否则为A的列数k
B设备输入矩阵B
ldb输入如果不转置,为B的行数即k, 否则为B的列数n
beta宿主机器/设备输入与C相加的系数
C设备输入/输出矩阵C,无论AB是否转置,都为列优先
ldc输入C的行数m

示例详解

已知

// m = 2, n = 2, k = 3
std::vector<float> vA{1,2,3,4,5,6};  //2x3
std::vector<float> vB{1,2,3,4,5,6};  //3x2

情况1

AB均为列存储

/*
A           B          C
1 3 5       1 4        22 49
2 4 6       2 5        28 64
            3 6
*/
cublasSgemm(pCuBlas, 
         CUBLAS_OP_N,  // 列存储 不需要转置
         CUBLAS_OP_N,  // 列存储 不需要转置
         2, 2, 3,  // m, n, k
         &fAlpha, 
         pCudaA, 2,  // lda  = m , 没有转置,所以是A的行数
         pCudaB, 3,  // ldb  = k ,没有转置,所以是B的行数
         &fBeta, 
         pCudaC, 2   // ldc  = m
    );

情况2

A为行存储, B为列存储

/*
A           B          C
1 2 3       1 4        14 32
4 5 6       2 5        32 77
            3 6
*/
cublasSgemm(pCuBlas, 
         CUBLAS_OP_T,  // 行存储 需要转置
         CUBLAS_OP_N,  // 列存储 不需要转置
         2, 2, 3,  // m, n, k
         &fAlpha, 
         pCudaA, 3,  // lda  = k , 转置,所以是A的列数
         pCudaB, 3,  // ldb  = k ,没有转置,所以是B的行数
         &fBeta, 
         pCudaC, 2   // ldc  = m
    );

情况3

A为列存储, B为行存储

/*
A           B          C
1 3 5       1 2        35 44
2 4 6       3 4        44 56
            5 6
*/
cublasSgemm(pCuBlas, 
         CUBLAS_OP_N,  // 列存储 不需要转置
         CUBLAS_OP_T,  // 行存储 需要转置
         2, 2, 3,  // m, n, k
         &fAlpha, 
         pCudaA, 2,  // lda  = m , 没有转置,所以是A的行数
         pCudaB, 2,  // ldb  = n ,转置,所以是B的列数
         &fBeta, 
         pCudaC, 2   // ldc  = m
    );

情况4

AB均为行存储

/*
A           B          C
1 2 3       1 2        22 28
4 5 6       3 4        49 64
            5 6
*/
cublasSgemm(pCuBlas, 
         CUBLAS_OP_T,  // 行存储 需要转置
         CUBLAS_OP_T,  // 行存储 需要转置
         2, 2, 3,  // m, n, k
         &fAlpha, 
         pCudaA, 3,  // lda  = k , 转置,所以是A的列数
         pCudaB, 2,  // ldb  = n ,转置,所以是B的列数
         &fBeta, 
         pCudaC, 2   // ldc  = m
    );

规律总结

  • AB列优先不需要转置,行优先需要转置
  • C始终为列优先
  • m,n,k 无论是否转置始终不变
  • lda A不转置时为行数m,转置时为列数k
  • ldb B不转置时为行数k,转置时为列数n
  • ldc 无论是否转置,都为m

完整代码

#include "cuda_runtime.h"
#include "cublas_v2.h"

#include <vector>
#include <cstdio>

int main(int argc, char* argv[])
{
    cublasHandle_t pCuBlas = nullptr;
    float *pCudaB = nullptr;
    float *pCudaA = nullptr;
    float *pCudaC = nullptr;

    cudaError_t cudaStat = cudaSetDevice(0);
    cudaStat = cudaMalloc((void**)&pCudaA, 2 * 3 * sizeof(float));
    cudaStat = cudaMalloc((void**)&pCudaB, 3 * 2 * sizeof(float));
    cudaStat = cudaMalloc((void**)&pCudaC, 2 * 2 * sizeof(float));
    cublasStatus_t stat = cublasCreate(&pCuBlas);

    std::vector<float> vA{1,2,3,4,5,6};
    std::vector<float> vB{1,2,3,4,5,6};
    std::vector<float> vC1(4,0);
    std::vector<float> vC2(4,0);
    std::vector<float> vC3(4,0);
    std::vector<float> vC4(4,0);

    cudaMemcpy(pCudaA, vA.data(), sizeof(float)*6, cudaMemcpyHostToDevice);
    cudaMemcpy(pCudaB, vB.data(), sizeof(float)*6, cudaMemcpyHostToDevice);
    float fAlpha = 1, fBeta = 0;

    stat = cublasSgemm(pCuBlas, 
         CUBLAS_OP_N, CUBLAS_OP_N, 
         2, 2, 3, &fAlpha, pCudaA, 
         2, pCudaB,
         3, &fBeta, pCudaC,
         2
    );
    cudaMemcpy(vC1.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);

    stat = cublasSgemm(pCuBlas, 
         CUBLAS_OP_T, CUBLAS_OP_N, 
         2, 2, 3, &fAlpha, pCudaA, 
         3, pCudaB,
         3, &fBeta, pCudaC,
         2
    );

    cudaMemcpy(vC2.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
    stat = cublasSgemm(pCuBlas, 
         CUBLAS_OP_N, CUBLAS_OP_T, 
         2, 2, 3, &fAlpha, pCudaA, 
         2, pCudaB,
         2, &fBeta, pCudaC,
         2
    );
    cudaMemcpy(vC3.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
    stat = cublasSgemm(pCuBlas, 
         CUBLAS_OP_T, CUBLAS_OP_T, 
         2, 2, 3, &fAlpha, pCudaA, 
         3, pCudaB,
         2, &fBeta, pCudaC,
         2
    );
    cudaMemcpy(vC4.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);


    for(int i = 0 ;i < vA.size(); i++)
        printf("A %d: %f\n", i, vA[i]);
    for(int i = 0 ;i < vB.size(); i++)
        printf("B %d: %f\n", i, vB[i]);
    for(int i = 0 ;i < vC1.size(); i++)
        printf("C1 %d: %f\n", i, vC1[i]);
    for(int i = 0 ;i < vC2.size(); i++)
        printf("C2 %d: %f\n", i, vC2[i]);
    for(int i = 0 ;i < vC3.size(); i++)
        printf("C3 %d: %f\n", i, vC3[i]);
    for(int i = 0 ;i < vC4.size(); i++)
        printf("C4 %d: %f\n", i, vC4[i]);
        
    if(pCudaA) cudaFree(pCudaA);
    if(pCudaB) cudaFree(pCudaB);
    if(pCudaC) cudaFree(pCudaC);
    if(pCuBlas) cublasDestroy(pCuBlas);

    return 0;

}
    

输出

A 0: 1.000000
A 1: 2.000000
A 2: 3.000000
A 3: 4.000000
A 4: 5.000000
A 5: 6.000000
B 0: 1.000000
B 1: 2.000000
B 2: 3.000000
B 3: 4.000000
B 4: 5.000000
B 5: 6.000000
C1 0: 22.000000
C1 1: 28.000000
C1 2: 49.000000
C1 3: 64.000000
C2 0: 14.000000
C2 1: 32.000000
C2 2: 32.000000
C2 3: 77.000000
C3 0: 35.000000
C3 1: 44.000000
C3 2: 44.000000
C3 3: 56.000000
C4 0: 22.000000
C4 1: 49.000000
C4 2: 28.000000
C4 3: 64.000000