- Meta (facebook) source code: github.com/facebookres…
- transformers source code: to stay coupled to the transformers interfaces (so that APIs such as from_pretrained can be used), the implementation is adapted to the class framework below. github.com/huggingface…
Llama 2 basic parameters
Comparison of Llama 1 and Llama 2 (only the 70B model uses GQA)
1. RoPE position encoding
- end is self.params.max_seq_len * 2, i.e. 4096, which is also the maximum number of tokens Llama 2 can handle.
- Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
- Adding this multiplier instead of hard-coding 4096 allows the token length to vary dynamically during training or fine-tuning.
- precompute_freqs_cis therefore produces a tensor of shape (4096, 64):
# in Transformer.forward, the frequencies for the current positions are sliced out:
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# dim = 128 # hidden_size / num_heads
# end = 4096 # self.params.max_seq_len * 2 = 4096
# compute the rotation angle for each pair of adjacent vector elements
# torch.arange(0, dim, 2) -> [0, 2, 4, 6, 8, 10, ..., 124, 126], 64 values in total
# torch.arange(0, dim, 2)[: (dim // 2)] guarantees exactly 64 values
# torch.Size([64])
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# t = [0, ..., end - 1], torch.Size([4096])
t = torch.arange(end, device=freqs.device)  # type: ignore
# outer product: t as a column vector, freqs as a row vector
# freqs.shape = (t.len(), freqs.len()) # shape (end, dim // 2)
freqs = torch.outer(t, freqs).float() # torch.Size([4096, 64])
# build complex numbers on the unit circle
# torch.polar(abs, angle) -> abs*cos(angle) + abs*sin(angle)*j
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
# freqs_cis.shape = (end, dim // 2)
return freqs_cis
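As a quick sanity check of the shapes and of torch.polar, the following minimal snippet (using the 7B-style values dim=128, end=4096 and the function defined above) verifies that freqs_cis is a (4096, 64) tensor of unit-modulus complex numbers:
import torch

freqs_cis = precompute_freqs_cis(dim=128, end=4096)
print(freqs_cis.shape)   # torch.Size([4096, 64])
print(freqs_cis.dtype)   # torch.complex64
# torch.polar with magnitude 1 places every entry on the unit circle
print(torch.allclose(torch.abs(freqs_cis), torch.ones_like(torch.abs(freqs_cis))))  # True
# position 0 is rotated by angle 0, i.e. 1 + 0j for every frequency
print(torch.allclose(freqs_cis[0], torch.ones(64, dtype=torch.complex64)))          # True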
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# ndim is the number of dimensions of x, which should be 4 here
# freqs_cis.shape = [1024, 64]
# x.shape = [2, 1024, 32, 64]
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# shape = (1, x.shape[1], 1, x.shape[-1]) = (1, 1024, 1, 64)
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor, # [2, seqlen, 32, 128]
xk: torch.Tensor,
freqs_cis: torch.Tensor, # [seqlen, 64]
) -> Tuple[torch.Tensor, torch.Tensor]: # [2, seqlen, 32, 128]
# xq.shape = [bsz, seqlen, self.n_local_heads, self.head_dim] -> [2, 1024, 32, 128]
# xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2] -> torch.Size([2, 1024, 32, 64])
# torch.view_as_complex turns pairs of real numbers into complex numbers: [x, y] -> (x + yj)
# so after view_as_complex, xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2]
# (bsz, 1024, 32, 128) -> (bsz, 1024, 32, 64)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# freqs_cis -> (4096, 64)
# freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] -> (1024, 64)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# freqs_cis.shape = (1, x.shape[1], 1, x.shape[-1]) = (1, 1024, 1, 64)
# broadcast element-wise (Hadamard) product of xq_ and freqs_cis
# [bsz, seqlen, self.n_local_heads, self.head_dim//2] * [1, seqlen, 1, self.head_dim//2]
# torch.view_as_real converts the complex numbers back to real pairs, then flatten merges the last dimension
# [bsz, seqlen, self.n_local_heads, self.head_dim//2] -> [bsz, seqlen, self.n_local_heads, self.head_dim//2, 2] -> [bsz, seqlen, self.n_local_heads, self.head_dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
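To see the three functions working together, here is a minimal, self-contained sketch (random tensors with the illustrative shapes bsz=2, seqlen=1024, 32 heads, head_dim=128, reusing the functions defined above). It checks that shapes and norms are preserved, and that the score between a rotated query and key depends only on the relative distance between their positions, which is the point of RoPE:
import torch

bsz, seqlen, n_heads, head_dim = 2, 1024, 32, 128
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

freqs_cis = precompute_freqs_cis(head_dim, 4096)[0:seqlen]   # slice as in Transformer.forward
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
print(xq_out.shape, xk_out.shape)                            # both [2, 1024, 32, 128]
# rotation by unit complex numbers preserves the norm of every head vector
print(torch.allclose(xq_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-4))   # True

# RoPE encodes *relative* position: place the same q and k vector at every position
q, k = torch.randn(head_dim), torch.randn(head_dim)
rq, rk = apply_rotary_emb(q.repeat(1, seqlen, 1, 1), k.repeat(1, seqlen, 1, 1),
                          freqs_cis=freqs_cis)
score = lambda m, n: torch.dot(rq[0, m, 0], rk[0, n, 0])
# equal offsets m - n give (numerically) equal scores
print(torch.allclose(score(5, 3), score(105, 103), atol=1e-3))           # True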
2. Attention
Large models are usually trained in a distributed fashion, which introduces a few concepts. n_heads is the total number of attention heads; because of model parallelism, each process holds n_local_heads of them (for simplicity you can assume there is a single process). Since the attention score at the current position depends on all previous keys and values, the kv have to be cached. To reduce the memory footprint, the number of kv heads n_kv_heads can be made smaller; it is generally less than or equal to n_heads, n_heads is an integer multiple of n_kv_heads, and that multiple is n_rep. Correspondingly, each process holds n_local_kv_heads kv heads. The dimension of each head is head_dim = dim // n_heads.
For example: n_heads = 32, model_parallel_size (degree of parallelism) = 4, n_kv_heads = 8, so n_local_heads = 32/4 = 8, n_local_kv_heads = 8/4 = 2, and n_rep = 32/8 = 4, as checked in the snippet below.
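A tiny sanity check of this bookkeeping, using the illustrative numbers above (not a real config):
# head bookkeeping for the example above (values are illustrative)
dim, n_heads, n_kv_heads, model_parallel_size = 4096, 32, 8, 4

n_local_heads = n_heads // model_parallel_size        # 8 query heads per process
n_local_kv_heads = n_kv_heads // model_parallel_size  # 2 kv heads per process
n_rep = n_local_heads // n_local_kv_heads             # 4 query heads share one kv head
head_dim = dim // n_heads                             # 128

assert n_heads % n_kv_heads == 0 and n_rep == n_heads // n_kv_heads
print(n_local_heads, n_local_kv_heads, n_rep, head_dim)  # 8 2 4 128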
Understanding the kv_cache
- The kv_cache stores key/value pairs: during autoregressive generation, the keys and values of the tokens processed so far are cached (up to max_seq_len positions) so they do not need to be recomputed for each new token.
For example, consider the following generation task (ref: zhuanlan.zhihu.com/p/649756898):
In [1]: {prompt:"将进酒:"}
Out [1]: 将进酒:人
In [2]: 将进酒:人
Out [2]: 将进酒:人生
In [3]: 将进酒:人生
Out [3]: 将进酒:人生得
In [4]: 将进酒:人生得
Out [4]: 将进酒:人生得意
In [5]: 将进酒:人生得意
Out [5]: 将进酒:人生得意需
In [6]: 将进酒:人生得意需
Out [6]: 将进酒:人生得意需尽
In [7]: 将进酒:人生得意需尽
Out [7]: 将进酒:人生得意需尽欢
In the third step, "将进酒:人生" is used to predict the next token "得", so "将进酒:人生" would have to be tokenized and pushed through the attention computation again:
It is easy to see that the Q, K, V for "将进酒:人" were already computed in the second step, so there is no need to run attention over them again; skipping that work saves most of the compute. This is exactly the problem the KV cache solves: the K and V computed at every step are cached, and when a new token arrives, the previous K and V are simply read from the KV cache instead of being recomputed.
Moreover, the query does not need to be recomputed for the whole sequence either: only the query of the newest token is needed, so the sequence length at this step is 1 (seqlen = 1); the same holds for the new k and v, see lines 108-119 of the original code.
- Note: the code only caches k and v, not q. The q of the current token only interacts with the kv of previous tokens and is never used again afterwards, so there is no point in storing it; even if the old queries were kept and attention were computed between them and the newly added keys, those scores would be masked out anyway, because this is causal (one-directional) attention.
- In the initial prompt phase seqlen >= 1; during subsequent generation seqlen == 1. A toy sketch of the cache updates follows below.
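A minimal toy sketch of how the cache indices advance, mimicking the cache_k update in Attention.forward below (made-up sizes and random tensors, not the real model):
import torch

max_batch_size, max_seq_len, n_local_kv_heads, head_dim = 1, 16, 2, 4
cache_k = torch.zeros(max_batch_size, max_seq_len, n_local_kv_heads, head_dim)

prompt_len = 4
# prompt phase: start_pos = 0, seqlen = prompt_len, the whole prompt is written at once
xk = torch.randn(1, prompt_len, n_local_kv_heads, head_dim)
cache_k[:1, 0:prompt_len] = xk

# generation phase: one token at a time, seqlen = 1, start_pos keeps advancing
for start_pos in range(prompt_len, prompt_len + 3):
    xk = torch.randn(1, 1, n_local_kv_heads, head_dim)
    cache_k[:1, start_pos : start_pos + 1] = xk
    keys = cache_k[:1, : start_pos + 1]    # all cached keys plus the current one
    print(start_pos, keys.shape)           # (1, start_pos + 1, 2, 4)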
About grouped-query attention (GQA)
MQA (Multi-Query Attention) keeps multiple query heads but uses only a single K and V head that all query heads share. GQA (Grouped-Query Attention) sits between MHA and MQA: the query heads are divided into groups, and each group shares one KV head. To make it explicit which query heads share which KV head, the code simply repeats the KV heads, so that every group of self.n_rep adjacent query heads sees the same K and V values.
- When self.n_rep = 1 this reduces to ordinary multi-head attention; it is determined by args.n_kv_heads in the configuration. The repeat_kv helper used for the repetition is sketched below.
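The repeat_kv helper called in the forward pass below expands each kv head n_rep times along the head dimension; conceptually it is torch.repeat_interleave(x, repeats=n_rep, dim=2). A sketch matching that shape contract:
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )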
class Attention(nn.Module):
"""Multi-head attention module."""
def __init__(self, args: ModelArgs):
"""
Initialize the Attention module.
Args:
args (ModelArgs): Model configuration parameters.
Attributes:
n_kv_heads (int): Number of key and value heads.
n_local_heads (int): Number of local query heads.
n_local_kv_heads (int): Number of local key and value heads.
n_rep (int): Number of repetitions for local heads.
head_dim (int): Dimension size of each attention head.
wq (ColumnParallelLinear): Linear transformation for queries.
wk (ColumnParallelLinear): Linear transformation for keys.
wv (ColumnParallelLinear): Linear transformation for values.
wo (RowParallelLinear): Linear transformation for output.
cache_k (torch.Tensor): Cached keys for attention.
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
"""
Forward pass of the attention module.
Args:
x (torch.Tensor): Input tensor.
start_pos (int): Starting position for caching.
freqs_cis (torch.Tensor): Precomputed frequency tensor.
mask (torch.Tensor, optional): Attention mask tensor.
Returns:
torch.Tensor: Output tensor after attention.
"""
# suppose the current x is (1, 1, dim), i.e. the token predicted in the previous step
bsz, seqlen, _ = x.shape
# compute q, k, v for the current token(s)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# apply the rotary position encoding to the current q and k
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# cache the k and v of the current token(s)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
# read the previous and the current kv from the cache
# previous positions: [0, start_pos)
# current positions: [start_pos, start_pos + seqlen)
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# grouped query attention
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
# when seqlen is 1 (incremental decoding), mask is None
if mask is not None:
# add the mask so that a token cannot attend to tokens after it: those scores become -inf and turn into 0 after softmax
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
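To see why shrinking n_kv_heads matters, here is a rough back-of-the-envelope estimate of the per-layer kv-cache size implied by the cache_k/cache_v buffers above, assuming fp16 storage and illustrative 70B-style numbers (head_dim = 128, 64 query heads vs 8 kv heads); the exact figures depend on the real config:
# per-layer kv-cache memory: 2 (k and v) * batch * seq_len * n_kv_heads * head_dim * bytes
def kv_cache_bytes(batch, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    return 2 * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

mha = kv_cache_bytes(1, 4096, 64, 128)   # every query head keeps its own kv (MHA)
gqa = kv_cache_bytes(1, 4096, 8, 128)    # GQA with n_kv_heads = 8
print(mha / 2**20, "MiB vs", gqa / 2**20, "MiB")  # 128.0 MiB vs 16.0 MiB per layer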
About the mask
The format of the mask built at line 132 (of the original model.py):
- During incremental decoding (seqlen = 1), mask == None.
- In the initial prompt phase (seqlen >= 1), positions where the mask is 0 leave the score unchanged when added, while positions set to -inf make the sum -inf, which becomes 0 after softmax. The mask is constructed as follows:
mask = None
# suppose seq_len = 10
if seqlen > 1:
# a [10, 10] matrix filled with -inf
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
)
# keep -inf only strictly above the diagonal; the lower triangle (incl. the diagonal) becomes 0
mask = torch.triu(mask, diagonal=1) # (10,10)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
# prepend zero columns for the positions already in the cache
mask = torch.hstack([
torch.zeros((seqlen, start_pos), device=tokens.device),
mask
]).type_as(h) # (10, start_pos) zeros + (10, 10) -> (10, start_pos + 10)
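A concrete, runnable illustration of the resulting mask with small made-up sizes (seqlen = 3 new tokens, start_pos = 2 tokens already in the cache):
import torch

seqlen, start_pos = 3, 2
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0.,   0., -inf],
#         [0., 0., 0.,   0.,   0.]])
# row i is the token at absolute position start_pos + i: it may attend to every cached
# position and to itself, but not to the tokens that come after it.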
3. FeedForward
Activation function: SiLU, silu(x) = x * sigmoid(x); the FFN uses it in the SwiGLU form w2(silu(w1(x)) * w3(x)).
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
Line-by-line annotations of the transformers source
The code in transformers is nested and relatively complex; when the source needs to be modified, the changes usually land in the code below. Here we take a second pass mainly over the attention part. github.com/huggingface…
Items to study further:
- rope_scaling
- max_position_embeddings=4096 corresponds to end above and is the maximum number of tokens Llama 2 can handle.
- pretraining_tp=1 in the config file. The source comment describes it as an experimental feature: the degree of tensor parallelism used during pretraining; setting it to a value other than 1 activates a more accurate but slower computation of the linear layers, which should better match the original logits (a small numerical check of this branch is sketched after the attention code below).
- position_ids are the position indices of the current input sequence. For a prompt, past_key_values_length is 0 and position_ids covers the whole prompt fed in at once; taking the "将进酒:" example above, seq_length = len(prompt) = 4 (strictly speaking, tokenizer.encode would give more than 4 tokens; assume they are equal for simplicity), so position_ids = [0, 1, 2, 3]. During decoding, say when predicting the 10th token, past_key_values_length = 9, seq_length = 1, and position_ids = [9]. A small demo follows the code below.
position_ids= torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long, device=device)
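A quick check of the two cases described above (the lengths are the illustrative ones from the bullet, not real token counts):
import torch

device = "cpu"
# prompt phase: no cache yet, the whole prompt goes in at once
past_key_values_length, seq_length = 0, 4
print(torch.arange(past_key_values_length, seq_length + past_key_values_length,
                   dtype=torch.long, device=device))   # tensor([0, 1, 2, 3])

# decoding phase: 9 tokens already cached, one new token comes in
past_key_values_length, seq_length = 9, 1
print(torch.arange(past_key_values_length, seq_length + past_key_values_length,
                   dtype=torch.long, device=device))   # tensor([9])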
The parameter dimensions below use the 7B model as an example.
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size # 4096
self.num_heads = config.num_attention_heads #32
self.head_dim = self.hidden_size // self.num_heads # 128
# for 70B: num_heads=64, num_key_value_heads=8, num_key_value_groups=8
# for 7B: num_heads=32, num_key_value_heads=32, num_key_value_groups=1
self.num_key_value_heads = config.num_key_value_heads # 32; the 7B model does not use GQA
self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 1
self.max_position_embeddings = config.max_position_embeddings # 4096
self.rope_theta = config.rope_theta # not set in the 7B config; defaults to 10000.0
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# [4096,4096]
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
# [4096,4096]
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
# [4096,4096]
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
# [4096,4096]
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
# initialize the RoPE variant: plain RoPE, linear scaling, or dynamic NTK scaling
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None: #null
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# [1,q_len,hidden_size]
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
# linear projections of the input produce query_states, key_states, value_states
# [1,q_len,hidden_size]->[1,q_len,hidden_size]
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# reshape and transpose into a per-head layout
# [1,q_len,128*32]->[1,32,q_len,128]
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# [1,q_len,128*32]->[1,32,q_len,128]
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] # q_len
# for the prompt, past_key_value is None
# during incremental decoding, past_key_value is not None
# past_key_value: a tuple of two elements, the cached key and the cached value
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# apply the RoPE position encoding to q and k
# cos (`torch.Tensor`): The cosine part of the rotary embedding.
# sin (`torch.Tensor`): The sine part of the rotary embedding.
# value_states [1, 32, seq_len=q_len, 128]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# if there is a kv cache, concatenate the cached kv with the current kv
# past_key_value[0]:(1,32,past_key_value_length,128)
# key_states:(1,32,q_len,128)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# after concatenation key_states: (1,32,kv_seq_len,128)
# kv_seq_len = past_key_value_length+q_len
past_key_value = (key_states, value_states) if use_cache else None
# grouped-query attention repeats the shared kv heads across query groups; when num_key_value_groups == 1 this is plain multi-head attention
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# compute the attention scores
# (1,32,q_len,128)*(1,32,128,kv_seq_len) ->(1,32,q_len,kv_seq_len)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
# the attention_mask enforces causal (one-directional) attention and tells the model which attention scores are valid (e.g. masks out padding)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
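Finally, the pretraining_tp > 1 branches above split the projection weights into slices, apply each slice separately and concatenate (or, for o_proj, sum) the partial results. A minimal numerical sketch of why this matches the single matmul, using a plain nn.Linear with toy sizes rather than the real model:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, num_heads, head_dim, tp = 16, 4, 4, 2
x = torch.randn(1, 3, hidden_size)                    # (bsz, q_len, hidden_size)
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

# single matmul
full = q_proj(x)

# pretraining_tp-style computation: split the weight rows, project each slice, concatenate
slices = q_proj.weight.split((num_heads * head_dim) // tp, dim=0)
sliced = torch.cat([F.linear(x, slices[i]) for i in range(tp)], dim=-1)

print(torch.allclose(full, sliced, atol=1e-6))        # True: same result, just slower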
References:
- Llama 2 详解 (Llama 2 explained): zhuanlan.zhihu.com/p/649756898