小白跟着 NanoGPT 学习 Transformer

88 阅读4分钟

Function describes the world.

arxiv.org/abs/1706.03…

核心机制

Attention(Q,K,V)=softmax(QKd)Vwhere Q=WQ(i)φi(zi),  K=WK(i)τθ(y),  V=WV(i)τθ(y)and WQ(i)Rd×dϵi,  WK(i),WV(i)Rd×dτ,  φi(zi)RN×dϵi,  τθ(y)RM×dτ\begin{aligned} &\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\Big(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\Big) \cdot \mathbf{V} \\ &\text{where }\mathbf{Q} = \mathbf{W}^{(i)}_Q \cdot \varphi_i(\mathbf{z}_i),\; \mathbf{K} = \mathbf{W}^{(i)}_K \cdot \tau_\theta(y),\; \mathbf{V} = \mathbf{W}^{(i)}_V \cdot \tau_\theta(y) \\ &\text{and } \mathbf{W}^{(i)}_Q \in \mathbb{R}^{d \times d^i_\epsilon},\; \mathbf{W}^{(i)}_K, \mathbf{W}^{(i)}_V \in \mathbb{R}^{d \times d_\tau},\; \varphi_i(\mathbf{z}_i) \in \mathbb{R}^{N \times d^i_\epsilon},\; \tau_\theta(y) \in \mathbb{R}^{M \times d_\tau} \end{aligned}

每个 token 定义三个语义 embedding:query / key / value,对输入序列每个 token query embedding 和(输入序列中)每个 token key embedding 计算点积 ( 计算相似度 ),经 softmax 转换成概率后与 ( 输入序列中 ) 每个 token value embedding 向量相乘得到代表当前 token 的新向量。如此,不包含上下文的 static embedding 就成为包含上下文的 dynamic embedding,实现并行计算和解决长距离依赖问题。

But 为什么是 query / key / value,说下我浅薄的理解:在不同的序列 ( 上下文 ) 中同一 token 或有不同的语义,如此就需要通过某种方式引入上下文信息,一个自然想法:引入其余 token 信息,如将序列中所有 token embedding ( 加权 ) 求和代表当前 token 在当前序列中包含上下文的 embedding?这不正是对应 query 和 key 的点积 ( 代表的相似性 ) ?那干嘛不直接学习点积结果?在单独的 encoder 和 decoder 或许效果不错,但是在 encoder 和 decoder 配合使用的情况下 ( 如翻译 self-attention 和 cross-attention 的区别 ) ,进一步拆解点积结果成 query 和 key 是一个更好的设计:qeury 来自目标语言序列、query 和 value 来自原文语言序列。

此外,如何处理序列位置问题,毕竟同样的 token 序列但位置不同也会有不同的语义。一个直觉的想法便是引入 position embedding 与 token embedding 拼接,如此这个新的 embedding 就引入了位置信息,这种拼接方式可以达到目的,但也带来了明显的纬度膨胀和计算爆炸。另一种方式则是将 position embedding 和 token embedding 相加来引入位置信息,这种方式可以类比波段的叠加,不同的波段叠加在一起可以通过变化进行解析 ( 可以 embedding 是以 one hot 为输入的全链接层 )。

NanoGPT 对应核心代码如下:

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        # query / key / value 线性变换
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        # 输出线性变换
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        # 并行处理 B 个 sample,每个 sample 包含 T 个 token,每个 token 包含 C 个维度
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # query、key、value 以及 mulit head 在一个线性变换中学习
        # 通过 split 分割出 query、key、value
        # 通过 view 改变矩阵形状拆 multi head
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            # 计算 qeury 和 key 相关性
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            # 将相关性转换成 0 - 1 概率值
            att = F.softmax(att, dim=-1)
            # 随机将权重值 0,避免训练过拟合,增加模型鲁棒性
            att = self.attn_dropout(att)
            # value 权重和作为新 token embedding(如此包含上下文信息)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        
        # 拼接 multi head
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        # 拼接后 head 进行一次线性变化(类同多个卷积核不同的权重,学习不同的特征)
        y = self.resid_dropout(self.c_proj(y))
        return y

此处安利一篇解读文:Attention Is All You Need (Transformer) 论文精读

网络架构

# Nanogpt 和论文略微不同的细节:
# 1. 使用可学习位置编码(Positional Embedding);
# 2. 先进行层归一化、再进行权重学习(Norm -> Mulit-Head Attention / Feed Forward -> ADD);
# 3. 仅实现 decoder:encoder 用于理解、decoder 用于续写、encoder + decoder 用于翻译;
class GPT.__init__:
    self.transformer = nn.ModuleDict(dict(
        # Output Embedding
        wte = nn.Embedding(config.vocab_size, config.n_embd),
        # Positional Embedding
        wpe = nn.Embedding(config.block_size, config.n_embd),
        drop = nn.Dropout(config.dropout),
        # Nx
        h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        class Block__init__:
            # Norm
            self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
            # Multi-Head Attention
            self.attn = CausalSelfAttention(config)
            class CausalSelfAttention__init__:
                self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
                self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
                self.attn_dropout = nn.Dropout(config.dropout)
                self.resid_dropout = nn.Dropout(config.dropout)
            # Norm
            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
            # Feed Forward
            self.mlp = MLP(config)
            class MLP__init__:
                self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
                self.gelu    = nn.GELU()
                self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
                self.dropout = nn.Dropout(config.dropout)
        ln_f = LayerNorm(config.n_embd, bias=config.bias),
    ))
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

shape:

  wte: vocab_size -> n_embd

  self-attention: n_embd -> 3 * n_embd -> n_embd

  mlp: n_embd -> 4 * n_embd -> n_embd

  output: n_embd -> vocab_size

vocab_size -> n_embd -> 3 * n_embd -> n_embd -> 4 * n_embd -> n_embd -> vocab_size

前向传播

def forward(class GPT):
    # Output Embedding
    tok_emb = self.transformer.wte(idx) 
    # Positional Embedding
    pos_emb = self.transformer.wpe(pos) 
    x = self.transformer.drop(tok_emb + pos_emb)
    # Nx
    for block in self.transformer.h:
        x = block(x)
        def forward(class Block):
            # Norm -> Multi-Head Attention -> Add
            x = x + self.attn(self.ln_1(x)) 
            def forward(class CausalSelfAttention):
                B, T, C = x.size() 
                q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
                k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
                q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
                v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
                att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
                att = F.softmax(att, dim=-1)
                att = self.attn_dropout(att)
                y = att @ v 
                y = y.transpose(1, 2).contiguous().view(B, T, C)
                y = self.resid_dropout(self.c_proj(y))
            # Norm -> Feed Forward -> Add
            x = x + self.mlp(self.ln_2(x)) 
            def forward(class MLP):
                x = self.c_fc(x)
                x = self.gelu(x)
                x = self.c_proj(x)
                x = self.dropout(x)
    # Linear
    x = self.transformer.ln_f(x)
    # 计算损失函数(交叉熵)
    logits = self.lm_head(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)