How Thread Indices Are Computed in CUDA

First, a quick refresher on elementary-school arithmetic:

$$\text{dividend} = \text{divisor} \times \text{quotient} + \text{remainder}$$

Applied to thread indexing, the same relation becomes:

$$\text{global thread Id} = \text{blockId} \times \text{blockSize} + \text{threadId}$$
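  • blockId: the linear index of the current block within the grid, derived from its coordinates (the grid can be 1D, 2D, or 3D)
  • blockSize: the size of a block, i.e. the number of threads it contains
  • threadId: the linear index of the current thread within its block (the block can also be 1D, 2D, or 3D); it always lies in [0, blockSize), which is why it plays the role of the remainder while blockId plays the role of the quotient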
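Because threadId is always smaller than blockSize, the relation can also be run in reverse: integer division of the global Id by blockSize recovers blockId (the quotient), and the remainder recovers threadId. Below is a minimal host-side C++ sketch of both directions. It is plain C++ with no GPU required. The names `Split`, `split_global_id` and `join_global_id` are hypothetical, and ordinary `unsigned` values stand in for CUDA's built-in variables.

```cpp
#include <cassert>

// Hypothetical helper type (host-side only, not part of CUDA).
struct Split {
    unsigned blockId;   // quotient:  globalId / blockSize
    unsigned threadId;  // remainder: globalId % blockSize
};

// dividend = globalId, divisor = blockSize
Split split_global_id(unsigned globalId, unsigned blockSize) {
    return Split{globalId / blockSize, globalId % blockSize};
}

// dividend = divisor * quotient + remainder
unsigned join_global_id(Split s, unsigned blockSize) {
    return s.blockId * blockSize + s.threadId;
}
```

For example, with blockSize = 256 the global Id 1000 splits into blockId = 3 and threadId = 232, since 3 × 256 + 232 = 1000.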

Key concepts

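  • Grid: contains a number of blocks. The number of blocks along each dimension is given by gridDim.x/y/z, and the coordinate of a given block within the grid is given by blockIdx.x/y/z.
  • Block: contains a number of threads. The number of threads along each dimension is given by blockDim.x/y/z, and the coordinate of a given thread within its block is given by threadIdx.x/y/z.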

Converting multi-dimensional coordinates to a one-dimensional index

How can a multi-dimensional coordinate be expressed as a single number? Consider how two- and three-digit numbers are written:

$$\text{number} = \text{hundreds digit} \times 100 + \text{tens digit} \times 10 + \text{units digit}$$

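In the same way, once the size of each dimension is known, this positional scheme converts a 3D coordinate into a 1D index. For a coordinate (x, y, z) in an extent of size (Dx, Dy, Dz), z is conventionally treated as the most significant dimension, then y, with x the least significant. Just as each decimal place is worth 10 times the one below it, one step in y skips a whole row of Dx elements, and one step in z skips a whole plane of Dx × Dy elements, so: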

$$\text{id} = Dx \times Dy \times z + Dx \times y + x$$
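The positional scheme can be sketched in plain C++ (host-side, no GPU needed), together with its inverse. The inverse peels the "digits" off with % and /, starting from the least significant dimension. `Dim3`, `linearize` and `delinearize` are hypothetical names. `Dim3` only mimics the field layout of CUDA's `dim3`.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for CUDA's dim3 (fields x, y, z).
struct Dim3 {
    unsigned x, y, z;
};

// 3D coordinate c inside extent d -> 1D index.
// z is the most significant "digit", x the least significant.
// d.z is not needed here; it only bounds the valid range of c.z.
unsigned linearize(Dim3 c, Dim3 d) {
    return d.x * d.y * c.z + d.x * c.y + c.x;
}

// Inverse: recover (x, y, z) from the 1D index, least significant first.
Dim3 delinearize(unsigned id, Dim3 d) {
    Dim3 c;
    c.x = id % d.x;
    c.y = (id / d.x) % d.y;
    c.z = id / (d.x * d.y);
    return c;
}
```

With Dx = Dy = Dz = 10 this reproduces the decimal analogy exactly: (x, y, z) = (7, 4, 3) maps to 10 × 10 × 3 + 10 × 4 + 7 = 347. Dz never appears in the formula. It only limits the valid range of z.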

Worked examples

1. 1D Grid, 1D Block

  • blockSize = blockDim.x
  • blockId = blockIdx.x
  • threadId = threadIdx.x
  • Id = blockIdx.x * blockDim.x + threadIdx.x
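The following host-side C++ emulation of a 1D launch checks that this formula gives every thread a distinct Id. It is plain C++ with no GPU needed. The function `simulate_1d` is hypothetical, and its parameter names stand in for the built-in variables `gridDim.x`, `blockDim.x`, `blockIdx.x` and `threadIdx.x`. In a real kernel the two loops do not exist, because the GPU runs every (blockIdx.x, threadIdx.x) pair as its own thread.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side emulation of a 1D grid of 1D blocks.
// Each (block, thread) pair computes its Id with the case-1 formula and
// increments a counter; a correct formula hits every slot exactly once.
std::vector<int> simulate_1d(unsigned gridDimX, unsigned blockDimX) {
    std::vector<int> hits(gridDimX * blockDimX, 0);
    for (unsigned blockIdxX = 0; blockIdxX < gridDimX; ++blockIdxX) {
        for (unsigned threadIdxX = 0; threadIdxX < blockDimX; ++threadIdxX) {
            unsigned id = blockIdxX * blockDimX + threadIdxX;  // case 1 formula
            ++hits[id];
        }
    }
    return hits;
}
```

For example, `simulate_1d(4, 32)` returns 128 counters, each equal to 1. This means Ids 0 through 127 are each produced exactly once.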

2. 3D Grid, 1D Block

  • blockSize = blockDim.x (size of a 1D block)
  • blockId = gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x (id of the block within a 3D grid)
  • threadId = threadIdx.x (id of the thread within a 1D block)
  • Id = (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x
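Case 2 can be written as a small host-side C++ function (no GPU needed). `Dim3` and `id_3d_grid_1d_block` are hypothetical names. The snake_case parameters stand in for CUDA's `gridDim`, `blockIdx`, `blockDim.x` and `threadIdx.x`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side stand-in for CUDA's dim3 / uint3.
struct Dim3 { unsigned x, y, z; };

// Case 2: 3D grid, 1D block.
unsigned id_3d_grid_1d_block(Dim3 grid_dim, Dim3 block_idx,
                             unsigned block_dim_x, unsigned thread_idx_x) {
    // blockId: linearize the block's (x, y, z) coordinate inside the grid
    unsigned block_id = grid_dim.x * grid_dim.y * block_idx.z
                      + grid_dim.x * block_idx.y
                      + block_idx.x;
    // blockSize is just blockDim.x for a 1D block
    return block_id * block_dim_x + thread_idx_x;
}
```

For gridDim = (2, 3, 4) and blockDim.x = 8, thread 5 of block (1, 2, 3) gets blockId = 2 × 3 × 3 + 2 × 2 + 1 = 23 and Id = 23 × 8 + 5 = 189.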

3. 1D Grid, 2D Block

  • blockSize = blockDim.x * blockDim.y (size of a 2D block)
  • blockId = blockIdx.x (id of the block within a 1D grid)
  • threadId = blockDim.x * threadIdx.y + threadIdx.x (id of the thread within a 2D block)
  • Id = blockIdx.x * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x
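The detail to watch in case 3 is that blockIdx.x must be multiplied by the full block size blockDim.x * blockDim.y. Using blockDim.x alone would make consecutive blocks overlap. Here is a host-side C++ sketch (no GPU needed). `Dim3` and `id_1d_grid_2d_block` are hypothetical names standing in for CUDA's types and built-ins.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side stand-in for CUDA's dim3 / uint3.
struct Dim3 { unsigned x, y, z; };

// Case 3: 1D grid, 2D block.
unsigned id_1d_grid_2d_block(unsigned block_idx_x, Dim3 block_dim, Dim3 thread_idx) {
    unsigned block_size = block_dim.x * block_dim.y;                  // threads per 2D block
    unsigned thread_id  = block_dim.x * thread_idx.y + thread_idx.x;  // row-major inside the block
    return block_idx_x * block_size + thread_id;
}
```

With blockDim = (16, 8), thread (3, 5) of block 2 gets threadId = 16 × 5 + 3 = 83 and Id = 2 × 128 + 83 = 339.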

4. 3D Grid, 3D Block

  • blockSize = blockDim.x * blockDim.y * blockDim.z (size of a 3D block)
  • blockId = gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x (id of the block within a 3D grid)
  • threadId = blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x (id of the thread within a 3D block)
  • Id = (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x
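In case 4 both blockId and threadId use the same 3D-to-1D conversion, so a single helper can serve both. Here is a host-side C++ sketch (no GPU needed). `Dim3`, `linear`, `id_3d_grid_3d_block` and `total_threads` are hypothetical names, and the snake_case parameters stand in for CUDA's `gridDim`, `blockDim`, `blockIdx` and `threadIdx`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side stand-in for CUDA's dim3 / uint3.
struct Dim3 { unsigned x, y, z; };

// Linear index of coordinate c inside extent d (z most significant, x least).
unsigned linear(Dim3 c, Dim3 d) {
    return d.x * d.y * c.z + d.x * c.y + c.x;
}

// Case 4: 3D grid, 3D block.
unsigned id_3d_grid_3d_block(Dim3 grid_dim, Dim3 block_dim,
                             Dim3 block_idx, Dim3 thread_idx) {
    unsigned block_size = block_dim.x * block_dim.y * block_dim.z;
    unsigned block_id   = linear(block_idx, grid_dim);    // block within grid
    unsigned thread_id  = linear(thread_idx, block_dim);  // thread within block
    return block_id * block_size + thread_id;
}

// Total number of threads in the launch: the valid Ids are [0, total).
unsigned total_threads(Dim3 grid_dim, Dim3 block_dim) {
    return grid_dim.x * grid_dim.y * grid_dim.z
         * block_dim.x * block_dim.y * block_dim.z;
}
```

For gridDim = (2, 3, 4) and blockDim = (4, 4, 2), thread (3, 1, 1) of block (1, 2, 3) gets blockId = 23, blockSize = 32 and threadId = 4 × 4 × 1 + 4 × 1 + 3 = 23, so Id = 23 × 32 + 23 = 759.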

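The examples above skip 2D grids, which are just as common in practice. The complete table below covers every combination of 1D, 2D, and 3D grids and blocks, including the 2D-grid cases.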

Complete thread-index table

| Case | blockSize | blockId | threadId | Id |
|---|---|---|---|---|
| 1D Grid, 1D Block | blockDim.x | blockIdx.x | threadIdx.x | blockIdx.x * blockDim.x + threadIdx.x |
| 2D Grid, 1D Block | blockDim.x | gridDim.x * blockIdx.y + blockIdx.x | threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x |
| 3D Grid, 1D Block | blockDim.x | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x |
| 1D Grid, 2D Block | blockDim.x * blockDim.y | blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | blockIdx.x * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 2D Grid, 2D Block | blockDim.x * blockDim.y | gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 3D Grid, 2D Block | blockDim.x * blockDim.y | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 1D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | blockIdx.x * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
| 2D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
| 3D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
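Every row of the table is a special case of the last row, 3D Grid, 3D Block. An unused dimension has size 1, which is what a `dim3` component defaults to when it is not specified, and the matching index is always 0, so those terms vanish. The following host-side C++ sketch (no GPU needed) checks this for the 2D Grid, 2D Block row. `Dim3`, `id_general` and `id_2d_grid_2d_block` are hypothetical names standing in for CUDA's types and built-ins.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for CUDA's dim3 / uint3.
struct Dim3 { unsigned x, y, z; };

// Last table row: 3D Grid, 3D Block.
unsigned id_general(Dim3 grid_dim, Dim3 block_dim, Dim3 block_idx, Dim3 thread_idx) {
    unsigned block_id  = grid_dim.x * grid_dim.y * block_idx.z
                       + grid_dim.x * block_idx.y + block_idx.x;
    unsigned thread_id = block_dim.x * block_dim.y * thread_idx.z
                       + block_dim.x * thread_idx.y + thread_idx.x;
    return block_id * (block_dim.x * block_dim.y * block_dim.z) + thread_id;
}

// Table row 2D Grid, 2D Block, written out on its own for comparison.
unsigned id_2d_grid_2d_block(Dim3 grid_dim, Dim3 block_dim, Dim3 block_idx, Dim3 thread_idx) {
    return (grid_dim.x * block_idx.y + block_idx.x) * (block_dim.x * block_dim.y)
         + block_dim.x * thread_idx.y + thread_idx.x;
}
```

With gridDim.z = blockDim.z = 1 and blockIdx.z = threadIdx.z = 0, both functions agree for every thread. So in practice the general formula can be used whatever the launch dimensions are.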