DeepSeek之MLA

1,021 阅读9分钟

DeepSeek之MLA

一、产生背景

  • MHA(Multi-Head Attention)

    《Attention is all you need》中经典的多头注意力机制。分多个头,每个头一组Wq Wk Wv矩阵。

    问题:kv cache占用显存太多,以下都是都kv cache进行优化

  • MQA(Multi-Query Attention)

    《Fast Transformer Decoding: One Write-Head is All You Need》中首次提到。分多个头,每个头一个Wq矩阵,全部头共享一组Wk和Wv矩阵

    问题:对kv cache压缩的太多,仅为原来的1/h,h为头数。

  • GQA(Grouped-Query Attention)

    《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》中首次提到。分多个头,每个头一个Wq矩阵,把头划分n组,每组头共享一组Wk和Wv矩阵。

    问题:MHA和MQA之间的优化,可以灵活转换为MHA或者MQA。虽然此时kv cache已经可以灵活的进行压缩,但是推理计算时的计算量和MHA一样的,当然MQA也是一样的。推理时都会转成MHA进行推理计算。

二、解决方案

思考一下GQA和MQA,可以把他们理解为,在推理的时候,把一个小的矩阵放入kv cache中,然后推理计算的时候把小的矩阵扩大成一个大的矩阵,进而和q进行相关性的计算。

这边的扩大实际上是简单的分割和复制,而他们都是线性变换。此时我们是不是可以自定义存储到kv cache中的内容,存储一个很小矩阵,之后需要用的时候再进行升维即可。这样就解决了kv cache显存占用太高的问题,不过这里会增加推理时的计算量,相当于时间换空间了。

三、MLA

MLA(Multi-head Latent Attention)多头潜在注意力机制,和上述解决方案提到的一样,不过MLA可以对推理时的计算量进行显著的优化。

  • 公式推导

    根据第二步的解决方案我们计算q和k相关性的时候可以有如下公式:

    qt(s)ki(s)T=(xtWq(s))(ciWk(s))T=xt(Wq(s)Wk(s)T)ciTq_t^{(s)}k_i^{(s)T} = (x_tW_q^{(s)})(c_iW_k^{(s)})^T = x_t(W_q^{(s)}W_k^{(s)T})c_i^T

    此公式表示在第s个头上,第t个token的q和第i个token的k。的相关性如何计算的。

    t表示第t个token,(s)表示第s个头,i表示第i个token,c表示x经过线性转换后的低维度矩阵。

    根据公式,我们计算q和k相关性的时候,可以把(Wq(s)Wk(s)T)(W_q^{(s)}W_k^{(s)T})当成Wq矩阵,cic_i当成k。推理时,把cic_i放入cache,当下一个token进入网络后,直接算此token经过(Wq(s)Wk(s)T)(W_q^{(s)}W_k^{(s)T})此矩阵后的结果作为q,然后计算q和k的相关性即可。此时就不会增加推理过程中的计算量了。同理,除了k可以这样进行,v也可以这样。ciWv(s)c_iW_v^{(s)}为我们的v,当进行加权合并后,还有一个WoW_o矩阵相乘,这里就可以进行吸收Wv(s)W_v^{(s)}矩阵。

  • RoPE接入

    上述公式推导,如果加入旋转位置编码,则公式变为这样:

    qt(s)ki(s)T=(xtWq(s)Rt)(ciWk(s)Ri)T=xt(Wq(s)RtiWk(s)T)ciTq_t^{(s)}k_i^{(s)T} = (x_tW_q^{(s)}R_t)(c_iW_k^{(s)}R_i)^T = x_t(W_q^{(s)}R_{t-i}W_k^{(s)T})c_i^T

    此时(Wq(s)RtiWk(s)T)(W_q^{(s)}R_{t-i}W_k^{(s)T})无法提取出来成公共的矩阵,因为其中位置信息(t-i),所以只能考虑其他变种。

    在DeepSeek-v2中,是给q和k新增drd_r个维度,用于表示位置RoPE位置编码,而在cache中,额外把k对应的RoPE的信息加入进来了。

image.png

  • 代码实现

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    
    
    # rms归一化
    class RMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps
    
        """
            layernorm是减去均值后再除方差
            rmsnorm是之间除以方差(注意这里的方差算的时候不减均值)
            本质就是进行缩放。
        """
    
        def forward(self, hidden_states):
            hidden_states = hidden_states.float()
            variance = hidden_states.pow(2).mean(-1, keepdim=True)  # 求平方,再求均值
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  # rsqrt为平方根的倒数。把hs除以平方根
            return self.weight * hidden_states.float()  # 线性转换
    
    
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    
    
    def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
    
        return q_embed, k_embed
    
    
    # 旋转位置编码
    class RotaryEmbedding(nn.Module):
        def __init__(self, dim, max_seq_len=1024):
            super(RotaryEmbedding, self).__init__()
            self.dim = dim
            self.max_seq_len = max_seq_len
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_seq_len).float().unsqueeze(1)
            freqs = t @ inv_freq.unsqueeze(0)
            freqs = torch.cat((freqs, freqs), dim=-1)
    
            self.register_buffer("cos_cached", freqs.cos())
            self.register_buffer("sin_cached", freqs.sin())
    
        def forward(self, q, k):
            cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
            sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
            return apply_rotate_pos_emb(q, k, cos, sin)
    
    
    class MLA(nn.Module):
        def __init__(self,
                     dim,
                     n_heads,
                     q_lora_rank,
                     kv_lora_rank,
                     qk_nope_head_dim,
                     qk_rope_head_dim,
                     v_head_dim,
                     max_seq_len,
                     max_batch_size,
                     mode):
            super().__init__()
            self.dim = dim  # 隐藏层维度
            self.n_heads = n_heads  # 总头数
            self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度
            self.kv_lora_rank = kv_lora_rank  # kv低秩压缩到的维度
            self.qk_nope_head_dim = qk_nope_head_dim  # 表示qk的维度一样
            self.qk_rope_head_dim = qk_rope_head_dim  # 表示qk的维度一样
            self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # qk的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度
            self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度
            self.mode = mode
            self.max_seq_len = max_seq_len
            self.max_batch_size = max_batch_size
    
            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵
            # 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944
    
            self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # kv的降维矩阵
            # nn.Linear(self.dim, self.kv_lora_rank)
            # nn.Linear(self.dim, self.qk_rope_head_dim)
            self.kv_norm = RMSNorm(self.kv_lora_rank)
            self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (
                        self.qk_nope_head_dim + self.v_head_dim))  # kv的升维矩阵 升的维度是头数 乘以 k的维度和v的维度和
    
            self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
    
            self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转旋转位置编码
    
            if self.mode == 'naive':
                self.register_buffer('k_cache',
                                     torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),
                                     persistent=False)
                self.register_buffer('v_cache',
                                     torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),
                                     persistent=False)
    
            else:
                self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                                     persistent=False)
                self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                                     persistent=False)
    
        def forward(self, x, mask=None):
    
            bs, seq_len, _ = x.shape  # x为[N,T,E]
    
            # -------对q进行处理,初始化-------
            q = self.wq_a(x)  # 计算q,这里的q的降维后的 [N,T,q_lora_rank]
            q = self.q_norm(q)  # 对q进行标准化
            q = self.wq_b(q)  # 对q进行升维[N,T,H*(qk_nope_E+qk_rope_E)] 这里的维度大小是两部分拼接来的,一部分是用于计算的,一部分后续是需要经过旋转位置编码的,
            q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [N,T,H,qk_nope_E+qk_rope_E]
            q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                                       dim=-1)  # 切分no_pe和rope [N,T,H,qk_nope_E] [N,T,H,qk_rope_E]
    
            # -------对kv进行处理,初始化-------
            kv = self.wkv_a(x)  # 计算kv,这里的kv是进行降维后的。[N, T, kv_lora_rank + qk_rope_E]
            kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)  # [N, T, kv_lora_rank] 和 [N, T, qk_rope_E]
    
            k_pe = k_pe.unsqueeze(2)  # [N, T, 1, qk_rope_E] 后续进行扩容到H
            q_pe, k_pe = self.rotary_emb(q_pe, k_pe)  # 进行旋转位置编码 [N, T, H, qk_rope_E],[N, T, 1, qk_rope_E]
    
            # -------计算相关性-------
            if self.mode == 'naive':  # 此方案没有降低kv cache
    
                q = torch.cat([q_nope, q_pe], dim=-1)  # [N, T, H, (qk_nope_E+qk_rope_E)]
    
                kv = self.kv_norm(kv)  # [N, T, kv_lora_rank]
                kv = self.wkv_b(kv)  # 进行升维 [N, T, H * (qk_nope_E + v_E)]
                kv = kv.view(bs, seq_len, self.n_heads,
                             self.qk_nope_head_dim + self.v_head_dim)  # 把H拿出来,[N,T,H,qk_nope_E + v_E]
                k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim],
                                        dim=-1)  # 切分k和v [N,T,H,qk_nope_E]和[N,T,H,v_E]
    
                k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)  # [N, T, H, (qk_nope_E+qk_rope_E)]
    
                self.k_cache[:bs, :seq_len, :, :] = k  # k放入到k_cache中
                self.v_cache[:bs, :seq_len, :, :] = v  # v放入到v_cache中
                # scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
                # 计算相关性,q乘k的转置,除以根号下dk。 计算前先把q和k的维度变成[N,H,T,E]
                scores = torch.matmul(q.transpose(1, 2),
                                      self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(
                                          self.qk_nope_head_dim + self.qk_rope_head_dim))
                # 把维度从[N,H,T,E] 再转换成 [N,T,H,E]
                scores = scores.transpose(1, 2)
    
            else:  # 此方案降低了kv cache
                k_pe = k_pe.squeeze(2)  # [N, T, 1, qk_rope_E] 再变回 [N, T, qk_rope_E]
    
                wkv_b = self.wkv_b.weight  # 升维矩阵的权重,和正常是相反的:[输出,输入],[H * (qk_nope_E + v_E), kv_lora_rank]
                wkv_b = wkv_b.view(self.n_heads, -1,
                                   self.kv_lora_rank)  # [H, qk_nope_E+v_E, kv_lora_rank]
                # q和kv的升维矩阵提前相乘了。
                q_nope = torch.einsum("bshd,hdc->bshc", q_nope,
                                      wkv_b[:,
                                      :self.qk_nope_head_dim])  # [N, T, H, qk_nope_E] 和  [H, qk_nope_E, kv_lora_rank] ---> [N, T, H, kv_lora_rank]
    
                # q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的kv
                # wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的kv计算注意力了,kv_caceh时也只需存储压缩后的kv
    
                """
                q*k(T)  = x * wq* (c*wkv_b[:, :self.qk_nope_head_dim])(T) 
                        = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的kv
                      
                        c*wkv_b[:, :self.qk_nope_head_dim] c为压缩有的kv, wkv_b为升维矩阵,  
                      
                      
                einsum  
                    基本概念:
                        出现在箭头右边的是自由索引,只出现在箭头左边的是求和索引(表示中间计算结果需要在这个维度上求和才能得到输出)  
                    三个规则:
                        - 箭头左边,不同输入之间重复出现的索引表示两个输入需要在该维度上进行乘法操作。
                        - 和求和索引的解释一样
                        - 箭头右边的顺序可以是任意的,自定义好后,会自动进行相应的转置操作
              
                eg:"bshc,btc->bsht" bsht是自由索引,c是求和索引
              
                """
    
                kv = self.kv_norm(kv)  # 进行标准化
                self.kv_cache[:bs, :seq_len, :] = kv  # kv缓存,为[N,T,kv_lora_rank]
                self.pe_cache[:bs, :seq_len, :] = k_pe  # 旋转位置编码缓存 [N, T, qk_rope_E]
    
                # q_nope和k_nope计算相关性。这里的k是一个头。 "bshc,btc->bsht"表示在c这个维度上矩阵相乘。
                scores_nope = torch.einsum("bshc,btc->bsht", q_nope,
                                           self.kv_cache[:bs, :seq_len,
                                           :])  # [N, T, H, kv_lora_rank] 和 [N, T, kv_lora_rank] --->[N, T, H, T]   # bshc btc -> bshc bct -> bsht
                # q_pe和k_pe计算相关性。这里的k是一个头。 "bshr,btr->bsht"表示在r这个维度上矩阵相乘
                scores_pe = torch.einsum("bshr,btr->bsht", q_pe,
                                         self.pe_cache[:bs, :seq_len,
                                         :])  # [N, T, H, qk_rope_E] 和 [N, T, qk_rope_E] --->[N, T, H, T]  bshr btr -> bshr bt1r -> bshr bthr -> bsht
                # 相关性加起来除以根号下dk [N, T, H, T]  + [N, T, H, T]  = [N, T, H, T]
                scores = (scores_nope + scores_pe) / math.sqrt(
                    self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]
    
            if mask is not None:
                # mask shape:[bs, seq_len, seq_len]
                scores += mask.unsqueeze(2)
    
            scores = scores.softmax(dim=-1)
    
            if self.mode == 'naive':
                x = torch.einsum("bsht,bthd->bshd", scores,
                                 self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshd
            else:
    
                # scores * v = scores * c * wkv_b[:, -self.v_head_dim:]  加权合并value。  "bsht,btc->bshc"  在t维度上进行相乘
                # scores为[N, T, H, T]
                # self.kv_cache[:bs, :seq_len]为[N, T, kv_lora_rank]
                # 结果是[N,T,H,kv_lora_rank] # 加权合并value。 value为 kv低秩矩阵和低秩矩阵的升维矩阵相乘后的结果。不过这里是分开乘了。
                x = torch.einsum("bsht,btc->bshc", scores,
                                 self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]
    
                # "bshc,hdc->bshd" 在c维度上进行相乘。 乘以低秩矩阵的升维矩阵
                # [N,T,H,kv_lora_rank],[H, v_E, kv_lora_rank] ---> [N,T,H,v_E]
                x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshd
    
            # [N,T,H,v_E] ---> [N,T,H*v_E]
            x = x.contiguous().view(bs, seq_len, -1)
            # [N,T,H*v_E] ---> [N,T,E] Wo矩阵,
            x = self.wo(x)
    
            return x
    
    
    if __name__ == '__main__':
        x = torch.randn(4, 100, 4096)
    
        dim = 4096  # 输入的维度
        n_heads = 16  # 头数
        q_lora_rank = 128  # q映射到低纬度为128
        kv_lora_rank = 64  # kv映射到低纬度为64
        qk_nope_head_dim = 256  # 不带旋转位置编码的维度  256 * 16 = 4096
        qk_rope_head_dim = 48  # 旋转位置编码的维度
        v_head_dim = 256  # 值维度
        max_seq_len = 512  # 最大序列长度
        max_batch_size = 16  # 最大批次大小
        mode = 'none'
    
        mla = MLA(dim=dim,
                  n_heads=n_heads,
                  q_lora_rank=q_lora_rank,
                  kv_lora_rank=kv_lora_rank,
                  qk_nope_head_dim=qk_nope_head_dim,
                  qk_rope_head_dim=qk_rope_head_dim,
                  v_head_dim=v_head_dim,
                  max_seq_len=max_seq_len,
                  max_batch_size=max_batch_size,
                  mode=mode)
    
        print(mla(x))
        print(mla.kv_cache)