介绍
cublasSgemm是CUDA的cublas库的矩阵相乘函数,与cblas中矩阵相乘函数不同,cublas中矩阵的存储是列优先的,所以使用时容易造成误解。 官方链接: docs.nvidia.com/cuda/cublas…
行优先与列优先
1 2 3
4 5 6
7 8 9
行优先存储
1 2 3 4 5 6 7 8 9
列优先存储
1 4 7 2 5 8 3 6 9
参数解析
假设输入矩阵为A,B,C,返回矩阵为C,则本函数主要的运算为
其中 A 满足 ,B 满足 , C满足
cublasStatus_t cublasSgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc)
| 参数 | 数据位置 | 输入/输出 | 含义 |
|---|---|---|---|
| handle | 输入 | cublas库上下文句柄 | |
| transa | 输入 | 矩阵A是否需要转置成行优先 | |
| transb | 输入 | 矩阵B是否需要转置成行优先 | |
| m | 输入 | 矩阵A和矩阵C的行数 | |
| n | 输入 | 矩阵B和矩阵C的列数 | |
| k | 输入 | 矩阵A的列数,矩阵B的行数 | |
| alpha | 宿主机器/设备 | 输入 | AB相乘的系数 |
| A | 设备 | 输入 | 矩阵A |
| lda | 输入 | 如果不转置,为A的行数即m, 否则为A的列数k | |
| B | 设备 | 输入 | 矩阵B |
| ldb | 输入 | 如果不转置,为B的行数即k, 否则为B的列数n | |
| beta | 宿主机器/设备 | 输入 | 与C相加的系数 |
| C | 设备 | 输入/输出 | 矩阵C,无论AB是否转置,都为列优先 |
| ldc | 输入 | C的行数m |
示例详解
已知
// m = 2, n = 2, k = 3
std::vector<float> vA{1,2,3,4,5,6}; //2x3
std::vector<float> vB{1,2,3,4,5,6}; //3x2
情况1
AB均为列存储
/*
A B C
1 3 5 1 4 22 49
2 4 6 2 5 28 64
3 6
*/
cublasSgemm(pCuBlas,
CUBLAS_OP_N, // 列存储 不需要转置
CUBLAS_OP_N, // 列存储 不需要转置
2, 2, 3, // m, n, k
&fAlpha,
pCudaA, 2, // lda = m , 没有转置,所以是A的行数
pCudaB, 3, // ldb = k ,没有转置,所以是B的行数
&fBeta,
pCudaC, 2 // ldc = m
);
情况2
A为行存储, B为列存储
/*
A B C
1 2 3 1 4 14 32
4 5 6 2 5 32 77
3 6
*/
cublasSgemm(pCuBlas,
CUBLAS_OP_T, // 行存储 需要转置
CUBLAS_OP_N, // 列存储 不需要转置
2, 2, 3, // m, n, k
&fAlpha,
pCudaA, 3, // lda = k , 转置,所以是A的列数
pCudaB, 3, // ldb = k ,没有转置,所以是B的行数
&fBeta,
pCudaC, 2 // ldc = m
);
情况3
A为列存储, B为行存储
/*
A B C
1 3 5 1 2 35 44
2 4 6 3 4 44 56
5 6
*/
cublasSgemm(pCuBlas,
CUBLAS_OP_N, // 列存储 不需要转置
CUBLAS_OP_T, // 行存储 需要转置
2, 2, 3, // m, n, k
&fAlpha,
pCudaA, 2, // lda = m , 没有转置,所以是A的行数
pCudaB, 2, // ldb = n ,转置,所以是B的列数
&fBeta,
pCudaC, 2 // ldc = m
);
情况4
AB均为行存储
/*
A B C
1 2 3 1 2 22 28
4 5 6 3 4 49 64
5 6
*/
cublasSgemm(pCuBlas,
CUBLAS_OP_T, // 行存储 需要转置
CUBLAS_OP_T, // 行存储 需要转置
2, 2, 3, // m, n, k
&fAlpha,
pCudaA, 3, // lda = k , 转置,所以是A的列数
pCudaB, 2, // ldb = n ,转置,所以是B的列数
&fBeta,
pCudaC, 2 // ldc = m
);
规律总结
- AB列优先不需要转置,行优先需要转置
- C始终为列优先
- m,n,k 无论是否转置始终不变
- lda A不转置时为行数m,转置时为列数k
- ldb B不转置时为行数k,转置时为列数n
- ldc 无论是否转置,都为m
完整代码
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include <vector>
#include <cstdio>
int main(int argc, char* argv[])
{
cublasHandle_t pCuBlas = nullptr;
float *pCudaB = nullptr;
float *pCudaA = nullptr;
float *pCudaC = nullptr;
cudaError_t cudaStat = cudaSetDevice(0);
cudaStat = cudaMalloc((void**)&pCudaA, 2 * 3 * sizeof(float));
cudaStat = cudaMalloc((void**)&pCudaB, 3 * 2 * sizeof(float));
cudaStat = cudaMalloc((void**)&pCudaC, 2 * 2 * sizeof(float));
cublasStatus_t stat = cublasCreate(&pCuBlas);
std::vector<float> vA{1,2,3,4,5,6};
std::vector<float> vB{1,2,3,4,5,6};
std::vector<float> vC1(4,0);
std::vector<float> vC2(4,0);
std::vector<float> vC3(4,0);
std::vector<float> vC4(4,0);
cudaMemcpy(pCudaA, vA.data(), sizeof(float)*6, cudaMemcpyHostToDevice);
cudaMemcpy(pCudaB, vB.data(), sizeof(float)*6, cudaMemcpyHostToDevice);
float fAlpha = 1, fBeta = 0;
stat = cublasSgemm(pCuBlas,
CUBLAS_OP_N, CUBLAS_OP_N,
2, 2, 3, &fAlpha, pCudaA,
2, pCudaB,
3, &fBeta, pCudaC,
2
);
cudaMemcpy(vC1.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
stat = cublasSgemm(pCuBlas,
CUBLAS_OP_T, CUBLAS_OP_N,
2, 2, 3, &fAlpha, pCudaA,
3, pCudaB,
3, &fBeta, pCudaC,
2
);
cudaMemcpy(vC2.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
stat = cublasSgemm(pCuBlas,
CUBLAS_OP_N, CUBLAS_OP_T,
2, 2, 3, &fAlpha, pCudaA,
2, pCudaB,
2, &fBeta, pCudaC,
2
);
cudaMemcpy(vC3.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
stat = cublasSgemm(pCuBlas,
CUBLAS_OP_T, CUBLAS_OP_T,
2, 2, 3, &fAlpha, pCudaA,
3, pCudaB,
2, &fBeta, pCudaC,
2
);
cudaMemcpy(vC4.data(), pCudaC, sizeof(float)*4, cudaMemcpyDeviceToHost);
for(int i = 0 ;i < vA.size(); i++)
printf("A %d: %f\n", i, vA[i]);
for(int i = 0 ;i < vB.size(); i++)
printf("B %d: %f\n", i, vB[i]);
for(int i = 0 ;i < vC1.size(); i++)
printf("C1 %d: %f\n", i, vC1[i]);
for(int i = 0 ;i < vC2.size(); i++)
printf("C2 %d: %f\n", i, vC2[i]);
for(int i = 0 ;i < vC3.size(); i++)
printf("C3 %d: %f\n", i, vC3[i]);
for(int i = 0 ;i < vC4.size(); i++)
printf("C4 %d: %f\n", i, vC4[i]);
if(pCudaA) cudaFree(pCudaA);
if(pCudaB) cudaFree(pCudaB);
if(pCudaC) cudaFree(pCudaC);
if(pCuBlas) cublasDestroy(pCuBlas);
return 0;
}
输出
A 0: 1.000000
A 1: 2.000000
A 2: 3.000000
A 3: 4.000000
A 4: 5.000000
A 5: 6.000000
B 0: 1.000000
B 1: 2.000000
B 2: 3.000000
B 3: 4.000000
B 4: 5.000000
B 5: 6.000000
C1 0: 22.000000
C1 1: 28.000000
C1 2: 49.000000
C1 3: 64.000000
C2 0: 14.000000
C2 1: 32.000000
C2 2: 32.000000
C2 3: 77.000000
C3 0: 35.000000
C3 1: 44.000000
C3 2: 44.000000
C3 3: 56.000000
C4 0: 22.000000
C4 1: 49.000000
C4 2: 28.000000
C4 3: 64.000000