static constexpr int NWarpsPerSM = NWarpsPerSM_; // 4
static constexpr int NumThreads = NWarpsPerSM * 32; // 128
using GmemCopyAtom = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<sizeof(uint128_t)>, T>;
static constexpr int GmemValsPerLoad = sizeof(uint128_t) / sizeof(T); // 每次vector加载128位,可以加载几个元素。128/16=8
static constexpr int GmemThreadsPerRow = HeadDim / GmemValsPerLoad; // each thread reads 128 bit,计算得到需要几个线程 128/8=16
using TiledCopyQKVO = decltype(make_tiled_copy( // https://zhuanlan.zhihu.com/p/703560147
GmemCopyAtom{}, // MMA_Atom
make_layout( // ThrLayout
Shape<Int<NumThreads / GmemThreadsPerRow>, Int<GmemThreadsPerRow>>{},
GenRowMajor{}),
make_layout(Shape<_1, Int<GmemValsPerLoad>>{}, GenRowMajor{}))); // ValLa
原理学习:zhuanlan.zhihu.com/p/703560147
zhuanlan.zhihu.com/p/702818267
zhuanlan.zhihu.com/p/697228676
zhuanlan.zhihu.com/p/934430036
首先定义如下的copytiled,定义如下
- Copy_Atom为AutoVectorizingCopyWithAssumedAlignment<sizeof(uint128_t)>,即一个独立拷贝的操作为向量化拷贝128位,ld.global.vec.u32
- ValLayout定义为Shape<_1, Int>{},本例中为<1, 8>。正好对应Atom的拷贝8个元素(每个元素为f16),即一个线程会拷贝八个元素
- ThrLayout定义为Shape<Int<NumThreads / GmemThreadsPerRow>, Int>{},表示将拷贝操作横向扩展GmemThreadsPerRow次,纵向扩展NumThreads / GmemThreadsPerRow次。
因此,打印出来的TiledCopy内容如下:
tiled_copy: TiledCopy
Tiler_MN: (_8,_128)
TiledLayout_TV: ((_16,_8),_8):((_64,_1),_8)
-
首先,Tiler_MN指的是TiledCopy执行一次copy时,操作的Src/Dst Tensor的Shape,是ValLayout与ThrLayout的乘积。即一次拷贝会拷贝(8, 128)的矩阵大小,同时里面每个线程会拷贝(1, 8)个元素,一共有(8, 16)个线程
-
然后TiledLayout_TV是用来计算每个线程具体拷贝的范围的。例如,ID为19的Thread,它拷贝的Tensor分块中,坐标为(0, 2)的元素,对应Tiler_MN中的坐标是多少?(注:以下都是col major)
-
首先,将Thread ID 19转换为Shape(16, 8)的坐标:(3, 1),坐标(0, 2)是Shape(1, 8)的坐标,这个Shape可以简化为(8,),因此坐标也可以简化为(2,)。因此,输入的坐标为((3, 1), 2)。
-
计算offset = 3 x 64 + 1 x 1 + 2 x 8 = 209(坐标乘以步长,然后求和)
-
将209转换为Shape(8, 128)的坐标,(8, 128)即Tiler_MN,结果为:(1, 26)
因此,对于Thread 9这个线程来说,它拷贝的Tensor分块中坐标为(0, 2)的元素位于Tiler_MN的(1, 26)位置处。事实上,Thread 9负责Tiler_MN上(1, 24:32)这个分块。
验证环节
首先通过代码定义拆分
// B=1, H=32, N_QO=2048, HeadDim=128
auto Q = make_tensor(make_gmem_ptr(pQ), make_layout(make_shape(B, H, N_QO, HeadDim), GenRowMajor{}));
auto gQ = local_tile(Q, make_shape(_1{}, _1{}, Int<BlockQO>{}, Int<HeadDim>{}),
make_coord(bx, by, bz, 0))(0, 0, _, _);
__shared__ T psQ[BlockQO * HeadDim]; // 64x128
auto sQ = make_tensor(
make_smem_ptr(psQ),
make_layout(make_shape(Int<BlockQO>{}, Int<HeadDim>{}), GenRowMajor{}));
TiledCopy tiled_copy;
auto thr_copy = tiled_copy.get_slice(tx);
auto tQgQ = thr_copy.partition_S(gQ);
auto tQsQ = thr_copy.partition_D(sQ);
打印出来的内容如下
- gQ: (_64,_128):(128,_1)
- tQgQ: ((_1,_8),_8,_1):((_0,_1),1024,_0)
- tQsQ: ((_1,_8),_8,_1):((_0,_1),_1024,_0)
对于gQ,很正常,就是将[2048, 128]矩阵拆分为[64, 128]
对于tQgQ,首先(_1,_8)代表了当前线程的copy规模为8个元素,然后是第三个8,其stride为1024,表示当前线程会拷贝8次,每次跨1024个元素,正好就是Tiler_MN的大小。而整个gQ为(_64,_128),正好对应了8个Tiler_MN(8, 128),也就是128个线程每个线程都需要拷贝8次。因此tQsQ的大小为1x8x8x1=64。
同理,tQsQ也是一样。
由于tQgQ和tQsQ的layout一样,因此可以直接调用copy函数
// copy Q into smem
copy(tiled_copy, tQgQ, tQsQ);
接下来我们利用打印来看看,线程中tQsQ是如何拷贝的。
打印出tid=19的线程的tQsQ,内容如下:
tQsQ(0): 152.000000
tQsQ(1): 153.000000
tQsQ(2): 154.000000
tQsQ(3): 155.000000
tQsQ(4): 156.000000
tQsQ(5): 157.000000
tQsQ(6): 158.000000
tQsQ(7): 159.000000
----------------------------------------
tQsQ(8): 1176.000000
tQsQ(9): 1177.000000
tQsQ(10): 1178.000000
tQsQ(11): 1179.000000
tQsQ(12): 1180.000000
tQsQ(13): 1181.000000
tQsQ(14): 1182.000000
tQsQ(15): 1183.000000
----------------------------------------
tQsQ(16): 2200.000000
tQsQ(17): 2200.000000
tQsQ(18): 2202.000000
tQsQ(19): 2204.000000
tQsQ(20): 2204.000000
tQsQ(21): 2204.000000
tQsQ(22): 2206.000000
tQsQ(23): 2208.000000
----------------------------------------
tQsQ(24): 3224.000000
tQsQ(25): 3224.000000
tQsQ(26): 3226.000000
tQsQ(27): 3228.000000
tQsQ(28): 3228.000000
tQsQ(29): 3228.000000
tQsQ(30): 3230.000000
tQsQ(31): 3232.000000
----------------------------------------
tQsQ(32): 4248.000000
tQsQ(33): 4248.000000
tQsQ(34): 4248.000000
tQsQ(35): 4252.000000
tQsQ(36): 4252.000000
tQsQ(37): 4252.000000
tQsQ(38): 4256.000000
tQsQ(39): 4256.000000
----------------------------------------
tQsQ(40): 5272.000000
tQsQ(41): 5272.000000
tQsQ(42): 5272.000000
tQsQ(43): 5276.000000
tQsQ(44): 5276.000000
tQsQ(45): 5276.000000
tQsQ(46): 5280.000000
tQsQ(47): 5280.000000
----------------------------------------
tQsQ(48): 6296.000000
tQsQ(49): 6296.000000
tQsQ(50): 6296.000000
tQsQ(51): 6300.000000
tQsQ(52): 6300.000000
tQsQ(53): 6300.000000
tQsQ(54): 6304.000000
tQsQ(55): 6304.000000
----------------------------------------
tQsQ(56): 7320.000000
tQsQ(57): 7320.000000
tQsQ(58): 7320.000000
tQsQ(59): 7324.000000
tQsQ(60): 7324.000000
tQsQ(61): 7324.000000
tQsQ(62): 7328.000000
tQsQ(63): 7328.000000
根据我们前面用TiledLayout_TV计算得到的thread19在Tiler_MN上的坐标为(1, 24:32),实际可以看到thread19对应的tQsQ起始地址为152,即1x16x8+3x8,正好对应了坐标(1, 24:32)。可以看到,是符合我们预期的。
然后是从smem-->rmem的过程
// Smem to Rmem config
// SM75_U32x4_LDSM_N指令对应的是ldmatrix.sync.aligned.x4.m8n8.shared.b16,即拷贝连续的四个8x8的bf16矩阵
// 具体指令拷贝可以参考https://zhuanlan.zhihu.com/p/621855199,https://zhuanlan.zhihu.com/p/697228676
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, T>; // LDSM will fit in the MMA_Atom, note that
// we do not handle bank conflict here
// MMA config
// reed大佬:https://zhuanlan.zhihu.com/p/663092747
// ldmatrix: https://zhuanlan.zhihu.com/p/702818267
static_assert(std::is_same_v<T, half_t> || std::is_same_v<T, bfloat16_t>);
// For simplicity, mnk == (16, 8, 8) is used: two MMAs will have the same
// layout so that we don't need to adjust tSrS to fit in tOrS
using MMA_Atom = std::conditional_t<std::is_same_v<T, half_t>,
MMA_Atom<SM80_16x8x8_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x8_F32BF16BF16F32_TN>>; //mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32
using TiledMMA = decltype(make_tiled_mma(
MMA_Atom{},
make_layout(Shape<Int<NWarpsPerSM>, _1, _1>{}, GenRowMajor{}), // MMAThrLayout,决定扩展MMA_Atom的维度,即增加几个线程处理
Tile<Int<16 * NWarpsPerSM>, _16, _16>{} // Permutations,决定扩展后的重复次数,即线程处理次数
// for SM75_U32x4_LDSM_N, we need at least 4 * 8x8 matrix, which is 16x16
));
// 在CuTe中,make_tiled_copy_A/B函数会自动帮助我们计算列起始地址等信息,我们需要做仅仅是选择合适Copy_Operation
// 简单理解,make_tiled_copy_A会根据我们定义的tiled_mma,将我们配置的SmemCopyAtom适配成需要的数据拷贝形式
auto tiled_s2r_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
auto thr_s2r_copy_Q = tiled_s2r_copy_Q.get_slice(tx);
// partition表示对一个大的逻辑Tensor进行划分得到当前线程的拷贝所需要的源Tensor和目标Tensor,
// 而retile系列的函数表示其输入的数据已经是当前的线程的私有的数据了,但是其可能不满足拷贝所要求的形状,
// 需要将其变换到拷贝所支持的形状
auto tXsQ = thr_s2r_copy_Q.partition_S(sQ);
auto tXrQ = thr_s2r_copy_Q.retile_D(tSrQ); // (CPY, MMA_QO, MMA_HEAD)
首先要了解的是,此时我们已经将数据拷贝到shared memory sQ中了,大小为[64, 128]。此时我们需要做的事是利用ldmatrix指令和wmma指令进行mma计算,因此重点在于如何使用ldmatrix来加载sQ中的内容,以符合所选的mma的计算要求。
SmemCopyAtom选用的是指令SM75_U32x4_LDSM_N,对应的ptx为ldmatrix.sync.aligned.x4.m8n8.shared.b16,具体指令介绍可以看上述提供文档,此处不再赘述。
MMA_Atom选用的指令是SM80_16x8x8_F32BF16BF16F32_TN,对应的ptx为mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32。
首先看TiledMMA
zhuanlan.zhihu.com/p/663092747
using TiledMMA = decltype(make_tiled_mma(
MMA_Atom{},
make_layout(Shape<Int<NWarpsPerSM>, _1, _1>{}, GenRowMajor{}), // MMAThrLayout,决定扩展MMA_Atom的维度,即增加几个线程处理
Tile<Int<16 * NWarpsPerSM>, _16, _16>{} // Permutations,决定扩展后的重复次数,即线程处理次数
// for SM75_U32x4_LDSM_N, we need at least 4 * 8x8 matrix, which is 16x16
));
打印如下
tiled_mma: TiledMMA
ThrLayoutVMNK: (_32,_4,_1,_1):(_1,_32,_0,_0)
PermutationMNK: (_64,_16,_16)
MMA_Atom
ThrID: _32:_1
Shape_MNK: (_16,_8,_8)
LayoutA_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))
LayoutB_TV: ((_4,_8),_2):((_16,_1),_8)
LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))
这里的TiledMMA中有两个重要的layout:
- ThrLayoutVMNK: 也就是AtomLayoutMNK,表示重复MMA_Atom来扩张TiledMMA,即增加MMA_Atom数,这里设置了<4, 1, 1>,说明在m方向上增加了4倍的MMA_Atom,即128个线程(4个warp)
- PermutationMNK: 参考github.com/NVIDIA/cutl…, 8, 8),而PermutationMNK为(_64,_16,_16),说明在n和k方向上都增加了2倍。
最终我们得到的整个TiledMMA示意如下:
接下来我们做thread纬度的切分