Understanding Large Model Internals: Llama 2 Principles and Source Code Walkthrough


[Figure: Llama 1 vs. Llama 2 comparison (only the 70B model uses GQA)]

1. RoPE positional encoding

  • end is self.params.max_seq_len * 2, i.e. 4096, which is also the maximum number of tokens Llama 2 can handle.
  • Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
  • Using this multiplier instead of hard-coding 4096 allows the token length to vary dynamically during training or fine-tuning.
  • precompute_freqs_cis therefore returns a tensor of shape (4096, 64); at runtime it is sliced as freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen].
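For intuition before reading the code: the rotation frequency of the j-th element pair is freqs[j] = 1 / theta^(2j / dim) (j = 0, ..., 63), so the first pair rotates fastest and the last pair slowest. A quick standalone check (dim = 128, theta = 10000; values chosen to match the comments in precompute_freqs_cis below):

import torch

dim, theta = 128, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # 64 frequencies
print(freqs.shape, freqs[0].item(), freqs[-1].item())
# torch.Size([64]) 1.0 ~1.15e-4   (first pair: high frequency; last pair: low frequency)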
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # dim = 128  # hidden_size / num_heads
    # end = 4096 # self.params.max_seq_len * 2 = 4096
    # Compute the rotation angle for each pair of adjacent embedding elements.
    # torch.arange(0, dim, 2) -> [0, 2, 4, 6, 8, 10, ..., 124, 126], 64 values in total
    # torch.arange(0, dim, 2)[: (dim // 2)] guarantees exactly 64 values
    # torch.Size([64])
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # t = [0, ..., end - 1], torch.Size([4096])
    t = torch.arange(end, device=freqs.device)  # type: ignore
    # Outer product of the column vector t and the row vector freqs
    # freqs.shape = (t.len(), freqs.len())  # shape (end, dim // 2)
    freqs = torch.outer(t, freqs).float()  # torch.Size([4096, 64])
    # Build the complex rotations
    # torch.polar(abs, angle) -> abs*cos(angle) + abs*sin(angle)*j
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # freqs_cis.shape = (end, dim // 2)
    return freqs_cis
​
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # ndim is the number of dimensions of x; here it should be 4
    # freqs_cis.shape = [1024, 64]
    # x.shape = [2, 1024, 32, 64]
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    # (1, x.shape[1], 1, x.shape[-1]) = (1, 1024, 1, 64)
    return freqs_cis.view(*shape)
​
def apply_rotary_emb(
    xq: torch.Tensor,         # [2, seqlen, 32, 128]
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # [seqlen, 64]
) -> Tuple[torch.Tensor, torch.Tensor]:  # [2, seqlen, 32, 128]
    # xq.shape  = [bsz, seqlen, self.n_local_heads, self.head_dim] -> [2, 1024, 32, 128]
    # xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim // 2] -> torch.Size([2, 1024, 32, 64])
    # torch.view_as_complex turns pairs of reals into complex numbers: [x, y] -> (x + yj)
    # So after view_as_complex, xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim // 2]
    # (bsz, 1024, 32, 128) -> (bsz, 1024, 32, 64)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis -> (4096, 64)
    # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] -> (1024, 64)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # freqs_cis.shape = (1, x.shape[1], 1, x.shape[-1]) = (1, 1024, 1, 64)
    # Broadcast element-wise (Hadamard) product between xq_ and freqs_cis:
    # [bsz, seqlen, self.n_local_heads, self.head_dim // 2] * [1, seqlen, 1, self.head_dim // 2]
    # torch.view_as_real converts the complex numbers back to real pairs, then flatten(3) merges the last two dims:
    # [bsz, seqlen, n_local_heads, head_dim // 2] -> [bsz, seqlen, n_local_heads, head_dim // 2, 2] -> [bsz, seqlen, n_local_heads, head_dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
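
To sanity-check the three functions above, here is a minimal standalone usage example (batch size 2, seqlen 16, 32 heads, head_dim 128 are illustrative values, not tied to any particular checkpoint):

import torch

freqs_cis = precompute_freqs_cis(dim=128, end=4096)        # (4096, 64), complex64
xq = torch.randn(2, 16, 32, 128)                           # (bsz, seqlen, n_heads, head_dim)
xk = torch.randn(2, 16, 32, 128)
# Slice the precomputed table to the current sequence, as the model does with
# self.freqs_cis[start_pos : start_pos + seqlen]
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis[0:16])
print(xq_rot.shape, xk_rot.shape)                          # torch.Size([2, 16, 32, 128]) twice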

2. Attention

Large models are usually trained in a distributed setting, which introduces a few concepts. n_heads is the total number of attention heads; because of model parallelism, each process holds n_local_heads of them (for simplicity, you can assume there is only one process). Since the attention score at the current position depends on all previous keys and values, the kv must be cached. To reduce memory, the number of kv heads n_kv_heads can be lowered; it is at most n_heads, and n_heads must be an integer multiple of n_kv_heads, with the multiple being n_rep. Correspondingly, each process holds n_local_kv_heads kv heads. Each head has dimension head_dim = dim // n_heads.

For example: n_heads = 32, model_parallel_size (degree of parallelism) = 4, n_kv_heads = 8, so n_local_heads = 32 / 4 = 8, n_local_kv_heads = 8 / 4 = 2, and n_rep = 32 / 8 = 4.

Understanding the kv_cache

  • The kv_cache stores key/value pairs: during generation, the K and V of the positions already processed are kept so they do not have to be recomputed.

As an example, consider the following generation task (ref: zhuanlan.zhihu.com/p/649756898):

In  [1]: {prompt:"将进酒:"}
Out [1]: 将进酒:人
​
In  [2]: 将进酒:人
Out [2]: 将进酒:人生
​
In  [3]: 将进酒:人生
Out [3]: 将进酒:人生得
​
In  [4]: 将进酒:人生得
Out [4]: 将进酒:人生得意
​
In  [5]: 将进酒:人生得意
Out [5]: 将进酒:人生得意需
​
​
In  [6]: 将进酒:人生得意需
Out [6]: 将进酒:人生得意需尽
​
In  [7]: 将进酒:人生得意需尽
Out [7]: 将进酒:人生得意需尽欢

In the third step, "将进酒:人生" is used to predict the next character "得", so "将进酒:人生" has to be tokenized and fed through the attention computation:

It is easy to see that the Q, K, V corresponding to "将进酒:人" were already computed in the second step, so there is no need to run attention over them again; skipping this saves most of the compute. This is exactly the problem the KV Cache solves: the K and V computed at each step are cached, so when a new token arrives, the previous K and V are simply read from the KV Cache instead of being recomputed.


Moreover, for Q there is no need to recompute every Q_i of the sequence either; only the Q of the newest token is needed, i.e. the sequence length at this point is 1 (seqlen = 1). The same applies to the new k and v; see lines 108-119 of the code.

  • Note: the code only caches k and v, not q. The current token's q only interacts with the k/v of earlier tokens and is never used again afterwards, so there is no point in storing it. Even if you kept the old Q and computed attention between it and the newly added K, the result would be masked out anyway, because this is causal (unidirectional) attention.
  • In the initial prompt phase seqlen >= 1; during subsequent generation seqlen == 1.
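
To make the point concrete, here is a small self-contained check (plain PyTorch, not the llama code) that feeding only the newest token's q against cached k/v reproduces full causal attention for that token:

import math
import torch

torch.manual_seed(0)
seqlen, head_dim = 5, 8                       # single head, no batch, for clarity
q = torch.randn(seqlen, head_dim)
k = torch.randn(seqlen, head_dim)
v = torch.randn(seqlen, head_dim)

# Full causal attention over the whole sequence, then keep the newest token's output
scores = (q @ k.T) / math.sqrt(head_dim)
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
full_out = torch.softmax(scores + mask, dim=-1) @ v

# Incremental step: only the newest q, attending to the cached k/v.
# No mask is needed because every cached position lies in the past.
cache_k, cache_v = k, v                       # what cache_k / cache_v would hold
q_new = q[-1:]                                # seqlen == 1 during generation
inc_out = torch.softmax((q_new @ cache_k.T) / math.sqrt(head_dim), dim=-1) @ cache_v

print(torch.allclose(full_out[-1:], inc_out, atol=1e-6))  # True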

About grouped query attention

MQA (Multi-Query Attention) keeps multiple query heads but uses only a single K and V that all of them share. GQA (Grouped-Query Attention) sits in between: the query heads are divided into groups, and each group shares one K/V head. Since the grouping must determine which query heads share which K/V head, the code implements it by repeating the K/V heads, so that every self.n_rep adjacent query heads share the same K/V values.

  • When self.n_rep = 1 this reduces to ordinary multi-head attention; the behaviour is controlled by args.n_kv_heads in the configuration file.
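For reference, repeat_kv (used later in the forward pass) simply repeats each kv head n_rep times along the head dimension; a sketch consistent with the shape contract used below:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )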
class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.
        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        # Suppose the current x is (1, 1, dim), i.e. the token predicted in the previous step
        bsz, seqlen, _ = x.shape
        
        # Compute q, k, v for the current token
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        
        # Apply rotary positional encoding to the current token's q and k
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        
        # Cache the current token's k and v
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
         
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        
        # Read the previous and current k/v from the cache
        # previous entries: [: start_pos]
        # current entries:  [start_pos : start_pos + seqlen]
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        
        # grouped query attention
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        
        # When seqlen is 1, mask is None
        if mask is not None:
            # Add the mask so that scores from a token to later (future) tokens become -inf and are masked out (0 after softmax)
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

About the mask

Format of the mask (built around line 132 of the source):

  • In the generation phase seqlen = 1 and mask == None.
  • In the initial prompt phase seqlen >= 1. Where the mask is 0, adding it to the score leaves it unchanged; where it is -inf, the sum is -inf, which becomes 0 after softmax. The mask is built as follows:
mask = None
# assume seqlen = 10
if seqlen > 1:
    # a [10, 10] matrix filled with -inf
    mask = torch.full(
        (seqlen, seqlen), float("-inf"), device=tokens.device
    )
    # keep only the strict upper triangle as -inf; the lower triangle and diagonal become 0
    mask = torch.triu(mask, diagonal=1) # (10, 10)

    # When performing key-value caching, we compute the attention scores
    # only for the new sequence. Thus, the matrix of scores is of size
    # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
    # j > cache_len + i, since row i corresponds to token cache_len + i.
    # Prepend zero columns for the positions already stored in the kv cache
    mask = torch.hstack([
        torch.zeros((seqlen, start_pos), device=tokens.device),
        mask
    ]).type_as(h)  # e.g. start_pos = 8: (10, 8) hstack (10, 10) -> (10, 18)
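
As a concrete example (values chosen for illustration): with seqlen = 3 and start_pos = 2, each new token may attend to the 2 cached positions plus the new positions up to and including itself, so the mask has shape (3, 5):

import torch

seqlen, start_pos = 3, 2
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0.,   0., -inf],
#         [0., 0., 0.,   0.,   0.]])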

3. FeedForward

Activation function (SwiGLU): [figure]

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
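        # Example (7B): dim = 4096, hidden_dim comes in as 4 * dim = 16384,
        # multiple_of = 256, ffn_dim_multiplier = None:
        #   16384 * 2 / 3 = 10922 -> rounded up to a multiple of 256 -> 11008,
        # which is the FFN width of the 7B checkpoint.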

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        # SwiGLU: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Line-by-line annotations of the transformers source

The code in transformers is nested in a relatively complex way; when the source needs to be modified, the changes usually land in the code below. Here we take a second pass at understanding the attention part in particular. github.com/huggingface…

Topics to study further:

  • rope_scaling
  • max_position_embeddings = 4096 corresponds to end and is the maximum number of tokens Llama 2 can handle.
  • pretraining_tp = 1 in the config file. The code comment describes it as an experimental feature: the tensor-parallelism rank used during pretraining. Setting it to a value other than 1 activates a more accurate but slower computation of the linear layers, which should better match the original logits (a small equivalence check follows the attention code below).
  • position_ids is the position index of the current input sequence. For a prompt, past_key_values_length is 0 and position_ids covers the whole prompt; taking "将进酒:" above as an example, seq_length = len(prompt) = 4 (strictly speaking tokenizer.encode would give more than 4 tokens; assume they are equal for simplicity), so position_ids = [0, 1, 2, 3]. During generation, when predicting e.g. the 10th character, past_key_values_length = 9, seq_length = 1 and position_ids = [9]:
position_ids= torch.arange(
                past_key_values_length, 
                seq_length + past_key_values_length,
                dtype=torch.long, device=device)
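
A quick check of the two cases described above:

import torch

# Prompt phase ("将进酒:", assumed to be 4 tokens): past_key_values_length = 0, seq_length = 4
print(torch.arange(0, 4 + 0, dtype=torch.long))   # tensor([0, 1, 2, 3])
# Generation phase, predicting the 10th character: past_key_values_length = 9, seq_length = 1
print(torch.arange(9, 1 + 9, dtype=torch.long))   # tensor([9])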

The tensor shapes below use the 7B model as an example.

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size  # 4096
        self.num_heads = config.num_attention_heads #32
        self.head_dim = self.hidden_size // self.num_heads  # 128
        # For 70B: num_heads = 64, num_key_value_heads = 8, num_key_value_groups = 64 // 8 = 8
        # For 7B:  num_heads = 32, num_key_value_heads = 32, num_key_value_groups = 1
        self.num_key_value_heads = config.num_key_value_heads # 32; the 7B model does not use GQA
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 1
        self.max_position_embeddings = config.max_position_embeddings # 4096
        self.rope_theta = config.rope_theta # not set in the config; defaults to 10000.0
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # [4096,4096]
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        # [4096,4096]
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # [4096,4096]
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # [4096,4096]
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        # Initialize the RoPE variant: basic RoPE, linear-scaling RoPE, or dynamic NTK-scaling RoPE
        self._init_rope()
    def _init_rope(self):
        if self.config.rope_scaling is None: #null
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        # [1,q_len,hidden_size]
        bsz, q_len, _ = hidden_states.size()
        
        if self.config.pretraining_tp > 1: 
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            # The input goes through linear projections to obtain query_states, key_states, value_states
            # [1, q_len, hidden_size] -> [1, q_len, hidden_size]
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
            
        # Reshape and transpose into the per-head layout
        # [1, q_len, 128*32] -> [1, 32, q_len, 128]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [1, q_len, 128*32] -> [1, 32, q_len, 128]  (num_key_value_heads = 32 for 7B)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  
        kv_seq_len = key_states.shape[-2] # q_len
        # For a prompt, past_key_value is None
        # During generation, past_key_value is not None
        # past_key_value: a tuple of two elements, the cached key and the cached value
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        
        # 给q和k添加rope 位置编码
        # cos (`torch.Tensor`): The cosine part of the rotary embedding.
        # sin (`torch.Tensor`): The sine part of the rotary embedding.
        # value_states [1, 32, seq_len=q_len, 128]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        
        #  如果有kv缓存, 就把之前缓存的和现在的进行拼接
        # past_key_value[0]:(1,32,past_key_value_length,128)
        # key_states:(1,32,q_len,128)
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            
        # After concatenation key_states: (1, 32, kv_seq_len, 128)
        # kv_seq_len = past_key_value_length + q_len
        past_key_value = (key_states, value_states) if use_cache else None
        
        # Grouped-query attention shares kv heads across query heads; num_key_value_groups = 1 means plain multi-head attention
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        
        # Compute attention
        # (1, 32, q_len, 128) @ (1, 32, 128, kv_seq_len) -> (1, 32, q_len, kv_seq_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # attention_mask enforces causal (unidirectional) attention and tells the model which attention scores are valid
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
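
The pretraining_tp > 1 branches above slice the projection weights along the output dimension and concatenate (or, for o_proj, sum) the partial results. A small standalone check that slicing a linear layer this way reproduces the unsliced computation (toy shapes, not the real model sizes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2
w = torch.randn(8, 4)                       # (out_features, in_features)
x = torch.randn(1, 3, 4)                    # (bsz, q_len, in_features)

slices = w.split(8 // tp, dim=0)            # split the output rows, as the code above does
out = torch.cat([F.linear(x, s) for s in slices], dim=-1)

print(torch.allclose(out, F.linear(x, w)))  # True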

References:

Llama 2 explained (in Chinese): zhuanlan.zhihu.com/p/649756898