自己写 transformer(1)

242 阅读4分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第26天,点击查看活动详情

导入依赖

import os
import random
import math
import json
from functools import partial
import numpy as np
# 导入绘制
import matplotlib.pyplot as plt
#设置 colormap
plt.set_cmap('cividis')
%matplotlib inline

from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
# 导入 tqdm 作为加载进度条
from tqdm.notebook import tqdm
# pyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

Torchvison 这个包主要提供当下流行的数据库、模型架构和常用的图像转换,例如旋转、平移、裁剪和镜像得等对图像几何的变换,也包括对图像色彩、饱和度和亮度的变换。

# Torchvison
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

这里将使用 pytorch Lightning 这个附加的框架,注意力机制是 transformer 的核心,只有很好理解了注意机制,注意力机制是最近引入一个新的层,我们可能都知道全连接层、卷积层,现在有多了个一个注意力层,注意力层特别适合序列任务,关于注意力是如何实现的,可以根据,就是将计算元素与其周围元素进行该元素与其他元素的相关性,然后根据相关性将其他元素信息聚合到该元素形成一个具有上下文信息新的元素。

其实关于注意力实现也是五花八门,这里我们实现注意力还是最基础的,根据查询项和键值项计算出权重,然后再用权重和值做一个加权求和,等价于对

  • 查询项(Query)
  • 键值项(key)
  • 值(Value)
  • 计算相关性分值的函数(Score function)
αi=exp(fattn(keyi,query))jexp(fattn(keyj,query))  ,out=iαivalue\alpha_i = \frac{\exp(f_{attn}(key_i,query))}{\sum_j \exp(f_{attn}(key_j,query))}\;,out=\sum_i \alpha_i \cdot value
def scaled_dot_product(q, k, v, mask=None):
    # 也就是输出序列每一个元素的维度
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q,k.transpose(-2,-1))
    attn_logits = attn_logits / math.sqrt(d_k)
    
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask ==0,-9e15)
    attention = F.softmax(attn_logits,dim=-1)
    values = torch.matmul(attention,v)
    return values,attention

核心概念就是自注意机制,这里自注意力机制采用 scaled dot production 注意力机制,目的是计算其他元素到该元素相关性,从而确定该元素有多少信息量将会参与到该元素中,换句话说就是这个元素将会从其他元素吸收多少信息量来增强自己信息。

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V) = softmax \left( \frac{QK^T}{\sqrt{d_k}} \right)V

关于矩阵计算我们需要明确每一个参与元素维度,QRT×dkQ \in \mathbb{R}^{T \times d_k}KRT×dkK \in \mathbb{R}^{T \times d_k}VRT×dvV \in \mathbb{R}^{T \times d_v} 这里 TT 表示序列长度

attn_logits = torch.matmul(q,k.transpose(-2,-1)) 实现的事 QKTQK^T 这一个步骤,因为通常输入为 (batchsize,seq,dmodel)(batch_size,seq,d_{model}) 所以对后两个维度做转置

attn_logits = attn_logits / math.sqrt(d_k) 为什么要除以 dk\sqrt{d_k} 这就是一个缩放因子 1/\sartdk1/\sart{d_k} ,之前初始化参数我们了解到希望每一层输入和输入能够保持一个方差为 1 这样利于信息的传递

qiN(0,σ2),kiN(0,σ2)Var(i=1dkqiki)=σ2dkq_i \sim N(0,\sigma^2),k_i \sim N(0,\sigma^2) \leftarrow Var \left( \sum_{i=1}^{d_k} q_i \cdot k_i \right) = \sigma^2 \cdot d_k

if mask is not None:
    attn_logits = attn_logits.masked_fill(mask ==0,-9e15)

这里 mask 控制在序列中哪些元素会参加到注意力计算,这是因为在解码器端,需要从左向右一个一个输入词,所以当输入某一个值词,该词右侧的元素是不需要参加计算的所以需要将其遮罩,通过给他一个很大负数来进行遮罩。

import pytorch_lightning as pl

seq_len,d_k = 3,2 pl.seed_everything(42) q = torch.randn(seq_len,d_k) k = torch.randn(seq_len,d_k) v = torch.randn(seq_len,d_k)

values, attention = scaled_dot_product(q, k, v) print("Q\n",q) print("K\n",k) print("V\n",v) print("Values\n",values) print("Attention\n",attention)

seq_len, d_k = 3,2

多头注意力实现

scaled dot product 可以计算序列中元素之间的相关性,然后根据相关性进行信息的聚合。而我们像从不同方面(角度)来聚合信息,但是仅是做一次 scaled dot product 来计算元素间相关性远远不够的,其实所谓多头,就是我们计算多次注意力然后将计算结果进行拼接

Multihead(Q,K,V)=Concat(head1,,headh)WO  whereheadi=Attention(QWiQ,KWiK,VWiV)Multihead(Q,K,V) = Concat(head_1,\cdots,head_h)W^O \; where\, head_i = Attention(QW_i^Q,KW_i^K,VW_i^V)

对于多头注意力机制模块来说也就是将输入 batch  size×seq  length×embed  dimbatch\;size \times seq\;length \times embed\;dim 转换为 batch  size×  seqlength×head  dimbatch\;size \times\;seq_length \times head\;dim 这里 headdimhead_dim 就是 embed  dim/num  headsembed\;dim/num\;heads 然后经过 scaled dot product 来计算出 batchsize×seq  length×headdimbatch_size \times seq\;length \times head_dim 再去在 embeding 维度上进行拼接,从而得到

我们将输入分别经过 h 个 query-key-value 组合,

这里这里可以学习参数为 WW_{}

class MultiheadAttention(nn.Module):
    """
    
    """
    def __init__(self, input_dim, embed_dim, num_heads):
        #也就是我们 embedding 维度需要能够被 num_heads 整除
        assert embed_dim % num_heads == 0,"Embedding dimension must be 0 modulo number of heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        #输入维度为 input_dim,然后线性变换为 3*embed_dim 也就是 K,Q,V
        
        self.qkv_proj = nn.Linear(input_dim,3*embed_dim)
        # 
        self.o_proj = nn.Linear(embed_dim,embed_dim)
        # 将参数进行一次均值初始化
        self._reset_parameters()
        
    def _reset_parameters(self):
        # 对参数进行一次初始化
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        
    def forward(self, x, mask=None,return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        # (batch_size,seq_length,embed_dim*3)
        qkv = self.qkv_proj(x)
        
        #将线性输出分离出 Q、K 和 V
        # 添加一个 head nubmer 维度 也就是 (batch_size,seq_length,self.num_header)
        # 将self.embed_dim 切分
        qkv = qkv.reshape(batch_size,seq_length,self.num_heads,3*self.head_dim)
        # (batch_size, num_head, seq_lenth, 3*head_dim)
        qkv = qkv.permute(0,2,1,3)
        #在 head_dim 维度上进行切分为 Q、K 和 V
        q, k, v= qkv.chunk(3,dim=-1)
        
        #确定 value 输出
        values, attention = scaled_dot_product(q,k,v,mask=mask)
        #调整交换 head num 和 seq_length
        values = values.permute(0,2,1,3)
        #在 embeding 维度上进行拼接
        values = values.reshape(batch_size,seq_length,embed_dim)
        o = self.o_proj(values)
        
        if return_attention:
            return o, attention
        else:
            return o