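First, a quick refresher on elementary-school arithmetic: dividend = divisor × quotient + remainder, where the remainder is always smaller than the divisor.
Mapped onto thread indexing, the same identity gives the global index of a thread:
$$\text{Id} = \text{blockId} \times \text{blockSize} + \text{threadId}$$
where: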
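- blockId: the linear index of the current block within the grid (the grid itself can have 1 to 3 dimensions)
- blockSize: the size of a block, i.e. the number of threads it contains
- threadId: the linear index of the current thread within its block (the block can also have 1 to 3 dimensions)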
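Because threadId is always smaller than blockSize, this is exactly a quotient-and-remainder decomposition: dividing Id by blockSize gives blockId as the quotient and threadId as the remainder. The host-side C++ sketch below shows both directions; the function names are hypothetical helpers for illustration, not CUDA APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: combine a block's linear id, the threads per block,
// and the thread's linear id inside the block into one global index.
std::uint64_t globalId(std::uint64_t blockId, std::uint64_t blockSize,
                       std::uint64_t threadId) {
    assert(threadId < blockSize);  // threadId plays the role of the remainder
    return blockId * blockSize + threadId;
}

// Inverse direction: the quotient recovers blockId ...
std::uint64_t blockIdOf(std::uint64_t id, std::uint64_t blockSize) {
    return id / blockSize;
}

// ... and the remainder recovers threadId.
std::uint64_t threadIdOf(std::uint64_t id, std::uint64_t blockSize) {
    return id % blockSize;
}
```

For example, with 256 threads per block, thread 5 of block 3 gets globalId(3, 256, 5) = 3 × 256 + 5 = 773, and 773 / 256 = 3 remainder 5 recovers both parts.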
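Key concepts
- Grid: contains a number of blocks; the number of blocks along each dimension is given by gridDim.x/y/z, and a block's coordinate within the grid is given by blockIdx.x/y/z.
- Block: contains a number of threads; the number of threads along each dimension is given by blockDim.x/y/z, and a thread's coordinate within its block is given by threadIdx.x/y/z.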
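The total number of threads in a launch is the block count gridDim.x × gridDim.y × gridDim.z times the threads per block blockDim.x × blockDim.y × blockDim.z. A small host-side C++ sketch follows; Dim3 is a hypothetical plain struct standing in for CUDA's built-in dim3 so the code compiles without the CUDA toolkit, and the helper names are likewise made up for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for CUDA's dim3 (not the real CUDA type).
struct Dim3 {
    unsigned x, y, z;
};

// Product of the three extents; each extent of a real launch is at least 1.
std::uint64_t volume(const Dim3& d) {
    assert(d.x >= 1 && d.y >= 1 && d.z >= 1);
    return std::uint64_t{d.x} * d.y * d.z;
}

// Number of blocks in the grid (from gridDim).
std::uint64_t blocksInGrid(const Dim3& gridDim) { return volume(gridDim); }

// Number of threads in each block (from blockDim), i.e. blockSize.
std::uint64_t threadsPerBlock(const Dim3& blockDim) { return volume(blockDim); }

// Total threads launched by kernel<<<gridDim, blockDim>>>(...).
std::uint64_t totalThreads(const Dim3& gridDim, const Dim3& blockDim) {
    return blocksInGrid(gridDim) * threadsPerBlock(blockDim);
}
```

For instance, a grid of 4 × 2 blocks, each with 32 × 8 threads, launches 8 × 256 = 2048 threads.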
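Converting multidimensional coordinates to a 1D index
How can a multidimensional coordinate be written as a single number? Look at how two- and three-digit numbers work: 45 = 4 × 10 + 5, and 456 = 4 × 100 + 5 × 10 + 6. Each digit is weighted by the product of the sizes of all positions below it (here, powers of 10).
In the same way, once the size of each dimension is known, this positional scheme turns a 3D coordinate into a 1D index. For a coordinate $(x, y, z)$ with dimension sizes $(D_x, D_y, D_z)$, z is conventionally the most significant dimension, then y, with x the least significant, so:
$$\text{index} = z \cdot D_x D_y + y \cdot D_x + x$$
Note that $D_z$ does not appear in the formula; it only bounds the range of z.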
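The same rule can be written as a pair of functions: one that linearizes (x, y, z) and one that recovers the coordinate by repeated remainder and division, least significant digit first. This is a host-side C++ sketch with hypothetical names (Coord3, linearize, delinearize); with Dx = Dy = 10 it reproduces ordinary decimal notation.

```cpp
#include <cassert>
#include <cstdint>

// A 3D coordinate; hypothetical helper type for this sketch.
struct Coord3 {
    std::uint64_t x, y, z;
};

// z is the most significant "digit", then y, then x (which varies fastest).
std::uint64_t linearize(Coord3 c, std::uint64_t Dx, std::uint64_t Dy) {
    assert(c.x < Dx && c.y < Dy);  // each digit must fit in its position
    return c.z * Dx * Dy + c.y * Dx + c.x;
}

// Inverse: peel off x first (remainder by Dx), then y (remainder by Dy);
// what is left is z. Dz is never needed; it only bounds z.
Coord3 delinearize(std::uint64_t index, std::uint64_t Dx, std::uint64_t Dy) {
    assert(Dx > 0 && Dy > 0);
    Coord3 c{};
    c.x = index % Dx;
    index /= Dx;
    c.y = index % Dy;
    c.z = index / Dy;
    return c;
}
```

linearize({6, 5, 4}, 10, 10) gives 4 × 100 + 5 × 10 + 6 = 456, matching the decimal analogy; with Dx = 4 and Dy = 3, (1, 2, 3) maps to 3 × 12 + 2 × 4 + 1 = 45 and back.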
Worked examples
1. 1D Grid, 1D Block
- blockSize = blockDim.x
- blockId = blockIdx.x
- threadId = threadIdx.x
- Id = blockIdx.x * blockDim.x + threadIdx.x
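One way to convince yourself that the 1D formula gives every thread a distinct index is to emulate the launch on the host: loop over every (blockIdx.x, threadIdx.x) pair and record the resulting Id. The C++ sketch below does that with a hypothetical helper; it is only an illustration, and inside a real kernel the built-in variables would be read directly.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side emulation of a 1D grid of 1D blocks (hypothetical helper).
// Returns true if the Ids are exactly 0 .. gridDimX * blockDimX - 1, each once.
bool ids1DGrid1DBlockAreUnique(unsigned gridDimX, unsigned blockDimX) {
    assert(gridDimX > 0 && blockDimX > 0);
    const std::uint64_t total = std::uint64_t{gridDimX} * blockDimX;
    std::vector<bool> seen(total, false);
    for (unsigned blockIdxX = 0; blockIdxX < gridDimX; ++blockIdxX) {
        for (unsigned threadIdxX = 0; threadIdxX < blockDimX; ++threadIdxX) {
            // Kernel equivalent: blockIdx.x * blockDim.x + threadIdx.x
            const std::uint64_t id = std::uint64_t{blockIdxX} * blockDimX + threadIdxX;
            if (id >= total || seen[id]) return false;  // out of range or duplicate
            seen[id] = true;
        }
    }
    // total distinct in-range Ids were produced, so every Id was covered.
    return true;
}
```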
2. 3D Grid, 1D Block
- blockSize = blockDim.x (the size of a 1D block)
- blockId = gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x (the id of a block in a 3D grid)
- threadId = threadIdx.x (the id of a thread in a 1D block)
- Id = (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x
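A worked instance helps check the 3D-grid formula by hand. The host-side C++ sketch below (with a hypothetical Dim3 struct in place of CUDA's dim3/uint3, and a made-up function name) evaluates it for gridDim = (2, 3, 4), blockDim.x = 64, blockIdx = (1, 2, 3), threadIdx.x = 10: blockId = 2 × 3 × 3 + 2 × 2 + 1 = 23, so Id = 23 × 64 + 10 = 1482.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for CUDA's dim3 / uint3 (not the real CUDA types).
struct Dim3 {
    unsigned x, y, z;
};

// 3D grid, 1D block: linearize the block coordinate (z most significant),
// multiply by the 1D block size, then add the 1D thread index.
std::uint64_t id3DGrid1DBlock(Dim3 gridDim, Dim3 blockIdx,
                              Dim3 blockDim, Dim3 threadIdx) {
    assert(blockIdx.x < gridDim.x && blockIdx.y < gridDim.y && blockIdx.z < gridDim.z);
    assert(threadIdx.x < blockDim.x);
    const std::uint64_t blockId = std::uint64_t{gridDim.x} * gridDim.y * blockIdx.z
                                + std::uint64_t{gridDim.x} * blockIdx.y
                                + blockIdx.x;
    return blockId * blockDim.x + threadIdx.x;
}
```

Block (1, 2, 3) is the last of the 24 blocks, so its last thread (threadIdx.x = 63) gets 23 × 64 + 63 = 1535 = 24 × 64 − 1, the largest Id in the launch.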
3. 1D Grid, 2D Block
- blockSize = blockDim.x * blockDim.y (the size of a 2D block)
- blockId = blockIdx.x (the id of a block in a 1D grid)
- threadId = blockDim.x * threadIdx.y + threadIdx.x (the id of a thread in a 2D block)
- Id = blockIdx.x * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x
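For the 2D-block case the thread coordinate is linearized first (threadIdx.y is the more significant digit, weighted by blockDim.x) and then offset by whole blocks. The C++ sketch below, again a host-side emulation with a hypothetical Dim3 struct and made-up function names, computes the formula and brute-forces every thread of a small launch to check that the Ids are a permutation of 0 .. N − 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for CUDA's dim3 / uint3 (not the real CUDA types).
struct Dim3 {
    unsigned x, y, z;
};

// 1D grid, 2D block.
std::uint64_t id1DGrid2DBlock(Dim3 blockIdx, Dim3 blockDim, Dim3 threadIdx) {
    assert(threadIdx.x < blockDim.x && threadIdx.y < blockDim.y);
    const std::uint64_t blockSize = std::uint64_t{blockDim.x} * blockDim.y;
    const std::uint64_t threadId = std::uint64_t{blockDim.x} * threadIdx.y + threadIdx.x;
    return blockIdx.x * blockSize + threadId;
}

// Enumerate all threads of a launch with gridDimX blocks and check that the
// Ids cover 0 .. total - 1 exactly once.
bool idsArePermutation1DGrid2DBlock(unsigned gridDimX, Dim3 blockDim) {
    const std::uint64_t total = std::uint64_t{gridDimX} * blockDim.x * blockDim.y;
    std::vector<bool> seen(total, false);
    for (unsigned bx = 0; bx < gridDimX; ++bx) {
        for (unsigned ty = 0; ty < blockDim.y; ++ty) {
            for (unsigned tx = 0; tx < blockDim.x; ++tx) {
                const std::uint64_t id = id1DGrid2DBlock({bx, 0, 0}, blockDim, {tx, ty, 0});
                if (id >= total || seen[id]) return false;
                seen[id] = true;
            }
        }
    }
    return true;
}
```

With blockDim = (16, 8), thread (5, 3) of block 2 gets 2 × 128 + (16 × 3 + 5) = 309.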
4. 3D Grid, 3D Block
- blockSize = blockDim.x * blockDim.y * blockDim.z (the size of a 3D block)
- blockId = gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x (the id of a block in a 3D grid)
- threadId = blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x (the id of a thread in a 3D block)
- Id = (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x
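The 3D/3D formula is the most general one, so it is worth testing exhaustively on a small launch. The host-side C++ sketch below (hypothetical Dim3 struct and function names) evaluates it for gridDim = (2, 3, 4), blockDim = (4, 4, 2), blockIdx = (1, 2, 3), threadIdx = (3, 1, 1): blockId = 23, blockSize = 32, threadId = 4 × 4 × 1 + 4 × 1 + 3 = 23, so Id = 23 × 32 + 23 = 759. It then checks that all 768 threads of that launch receive distinct Ids in 0 .. 767.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for CUDA's dim3 / uint3 (not the real CUDA types).
struct Dim3 {
    unsigned x, y, z;
};

// 3D grid, 3D block: Id = blockId * blockSize + threadId.
std::uint64_t id3DGrid3DBlock(Dim3 gridDim, Dim3 blockIdx, Dim3 blockDim, Dim3 threadIdx) {
    assert(threadIdx.x < blockDim.x && threadIdx.y < blockDim.y && threadIdx.z < blockDim.z);
    const std::uint64_t blockId = std::uint64_t{gridDim.x} * gridDim.y * blockIdx.z
                                + std::uint64_t{gridDim.x} * blockIdx.y + blockIdx.x;
    const std::uint64_t blockSize = std::uint64_t{blockDim.x} * blockDim.y * blockDim.z;
    const std::uint64_t threadId = std::uint64_t{blockDim.x} * blockDim.y * threadIdx.z
                                 + std::uint64_t{blockDim.x} * threadIdx.y + threadIdx.x;
    return blockId * blockSize + threadId;
}

// Brute force: visit every (blockIdx, threadIdx) of the launch and check that
// the Ids are exactly 0 .. total - 1, each appearing once.
bool idsArePermutation3DGrid3DBlock(Dim3 gridDim, Dim3 blockDim) {
    const std::uint64_t total = std::uint64_t{gridDim.x} * gridDim.y * gridDim.z
                              * blockDim.x * blockDim.y * blockDim.z;
    std::vector<bool> seen(total, false);
    for (unsigned bz = 0; bz < gridDim.z; ++bz)
        for (unsigned by = 0; by < gridDim.y; ++by)
            for (unsigned bx = 0; bx < gridDim.x; ++bx)
                for (unsigned tz = 0; tz < blockDim.z; ++tz)
                    for (unsigned ty = 0; ty < blockDim.y; ++ty)
                        for (unsigned tx = 0; tx < blockDim.x; ++tx) {
                            const std::uint64_t id = id3DGrid3DBlock(
                                gridDim, {bx, by, bz}, blockDim, {tx, ty, tz});
                            if (id >= total || seen[id]) return false;
                            seen[id] = true;
                        }
    return true;
}
```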
2D grids are also common, and the examples above skip them. The table below covers all nine grid/block combinations, including the 2D-grid cases.
Complete thread-index table
| Case | blockSize | blockId | threadId | Id |
|---|---|---|---|---|
| 1D Grid, 1D Block | blockDim.x | blockIdx.x | threadIdx.x | blockIdx.x * blockDim.x + threadIdx.x |
| 2D Grid, 1D Block | blockDim.x | gridDim.x * blockIdx.y + blockIdx.x | threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x |
| 3D Grid, 1D Block | blockDim.x | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * blockDim.x + threadIdx.x |
| 1D Grid, 2D Block | blockDim.x * blockDim.y | blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | blockIdx.x * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 2D Grid, 2D Block | blockDim.x * blockDim.y | gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 3D Grid, 2D Block | blockDim.x * blockDim.y | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y) + blockDim.x * threadIdx.y + threadIdx.x |
| 1D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | blockIdx.x * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
| 2D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
| 3D Grid, 3D Block | blockDim.x * blockDim.y * blockDim.z | gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x | blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x | (gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x) * (blockDim.x * blockDim.y * blockDim.z) + blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x |
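Every row of the table is the 3D Grid, 3D Block formula with the unused extents set to 1 and the unused coordinates set to 0. That is what CUDA supplies when a launch uses fewer dimensions: dim3 fills unspecified components with 1, and a coordinate along an extent of 1 is always 0. The host-side C++ sketch below (hypothetical Dim3 struct and function names) checks this for the 2D Grid, 2D Block row by comparing the general formula with that row written out directly.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for CUDA's dim3 / uint3 (not the real CUDA types).
struct Dim3 {
    unsigned x, y, z;
};

// General formula (last row of the table).
std::uint64_t globalId(Dim3 gridDim, Dim3 blockIdx, Dim3 blockDim, Dim3 threadIdx) {
    assert(blockIdx.x < gridDim.x && blockIdx.y < gridDim.y && blockIdx.z < gridDim.z);
    assert(threadIdx.x < blockDim.x && threadIdx.y < blockDim.y && threadIdx.z < blockDim.z);
    const std::uint64_t blockId = std::uint64_t{gridDim.x} * gridDim.y * blockIdx.z
                                + std::uint64_t{gridDim.x} * blockIdx.y + blockIdx.x;
    const std::uint64_t blockSize = std::uint64_t{blockDim.x} * blockDim.y * blockDim.z;
    const std::uint64_t threadId = std::uint64_t{blockDim.x} * blockDim.y * threadIdx.z
                                 + std::uint64_t{blockDim.x} * threadIdx.y + threadIdx.x;
    return blockId * blockSize + threadId;
}

// The "2D Grid, 2D Block" row, written exactly as in the table.
std::uint64_t id2DGrid2DBlock(Dim3 gridDim, Dim3 blockIdx, Dim3 blockDim, Dim3 threadIdx) {
    return (std::uint64_t{gridDim.x} * blockIdx.y + blockIdx.x)
               * (std::uint64_t{blockDim.x} * blockDim.y)
         + std::uint64_t{blockDim.x} * threadIdx.y + threadIdx.x;
}
```

With gridDim = (5, 7, 1), blockIdx = (3, 4, 0), blockDim = (8, 4, 1) and threadIdx = (6, 2, 0), both give (5 × 4 + 3) × 32 + (8 × 2 + 6) = 758. The practical upshot is that the general expression is always safe to use in a kernel; the shorter rows just drop terms that are known to be zero.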