推理流程
从输入文本,到推理输出文本,LLama2模型处理流程如下:
step1 Tokenization
输入数据:一个句子或一段话。通常表示成单词或字符序列。
Tokenization即对文本按单词或字符序列切分,形成Token序列。Token序列再转换为整数索引序列(索引即是单子或字符在语料库中index)。
# input text
[Where is Shanxi]
# 切分tokens序列
['Where' 'is' 'Shanxi']
# 转换为语料库index
['BOS' '8' '21' '33' 'EOS']
step2 Embedding
Embedding将每个Token映射为一个实数向量,为Embeding Vector。
'BOS' -> [p_{00},p_{01},p_{02},...,p_{0d-1}]
'8' -> [p_{10},p_{11},p_{12},...,p_{1d-1}]
...
'EOS'-> [p_{n0},p_{n1},p_{n2},...,p_{nd-1}]
step3 位置编码
位置编码(Positional Encoding)向量提供Token在序列中位置的信息。位置编码是为了区分不同位置的Token,并为模型提供上下文关系的信息。
[p_{00},p_{01},p_{02},...,p_{0d-1}] [pe_{00},pe_{01},pe_{02},...,pe_{0d-1}]
[p_{10},p_{11},p_{12},...,p_{1d-1}] [pe_{10},pe_{11},pe_{12},...,pe_{1d-1}]
[p_{20},p_{21},p_{22},...,p_{2d-1}] + [pe_{20},pe_{21},pe_{22},...,pe_{2d-1}]
... ...
[p_{n0},p_{n1},p_{n2},...,p_{nd-1}] [pe_{n0},pe_{n1},pe_{n2} ,...,pe_{nd-1}]
step4 自回归生成
在生成任务中,使用自回归(Autoregressive)方式,即逐个生成输出序列中的每个Token,Decoder-Only。在解码过程中,每次生成一个Token时,使用前面已生成的内容作为上下文,来帮助预测下一个Token
自回归生成demo如下:
model = LLaMA2()
def generate(inputs, n_tokens_to_generate):
for _ in range(n_tokens_to_generate): # auto-regressive decode loop
output = model(inputs) # model forward pass
next = np.argmax(output[-1]) # greedy sampling
inputs.append(next) # append prediction to input
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated tokens
input = [p0, p1,p2] #对应['BOS','where','is', 'shanxi']
output_ids = generate(input, 3) # 生成 ['p3','p4','p5']
output_ids = decode(output_ids) # 通过Tokenization解码
output_tokens = [vocab[i] for i in output_ids] # "Shanxi" "is" "a" "province"
step5 输出处理
生成的Token序列通过一个输出层,通常是线性变换加上Softmax函数,将每个位置的概率分布转换为对应Token的概率。根据概率,选择概率最高的Token或者作为模型的预测结果。
模型结构
LLama2模型结构与标 准Transformer Decoder结构基本一致,由32个Transfomer Block组成,不同点如下:
- 1 前置RMSNorm层
- 2 RoPE位置编码
- 3 KVCache
- 4 FeedForward层
1 RMSNorm
Transformer中的Normalization层采用LayerNorm来对Tensor进行归一化,RMSNorm就是LayerNorm的变体。 公式图:
# RMSNorm
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps # ε
self.weight = nn.Parameter(torch.ones(dim)) #可学习参数γ
def _norm(self, x):
# RMSNorm
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
2 RoPE旋转位置编码
什么是绝对位置编码,什么是相对位置编码?
位置编码向量生成方法有很多,常见绝对位置编码是使用三角函数对位置进行编码,公式如下:
绝对位置编码具有实现简单、计算速度快等优点,而相对位置编码则直接地体现了相对位置信号,跟直观理解吻合,实际性能往往也更好。
RoPE解决了一个什么问题?通过绝对位置编码的方式实现相对位置编码。
详细公式推理不展开,可参考LLaMA中ROPE位置编码实现源码解析
# 精简版Attention
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wq = Linear(...)
self.wk = Linear(...)
self.wv = Linear(...)
self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)
def forward(self, x: torch.Tensor):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# attention 操作之前,应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
#...
# 进行后续Attention计算
scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, xv) # (batch_size, seq_len, dim)
# ......
3 KV Cache
KV cache的峰值显存占用大小: ,输入序列长度s,输出序列长度n,fp16占用2个字节,transformer模型的层数为l,隐藏层维度为h。
def mha(x, c_attn, c_proj, n_head, kvcache=None): # [n_seq, n_embd] -> [n_seq, n_embd]
# qkv projection
# when we pass kvcache, n_seq = 1. so we will compute new_q, new_k and new_v
x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
# split into qkv
qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
if kvcache:
# qkv
new_q, new_k, new_v = qkv # new_q, new_k, new_v = [1, n_embd]
old_k, old_v = kvcache
k = np.vstack([old_k, new_k]) # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
v = np.vstack([old_v, new_v]) # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
qkv = [new_q, k, v]
4 FeedForward
SiLu激活函数: