[NLP] Transformer Basics


Basic idea of self-attention:

[Figure 1]
The input is a sequence of vectors $x_1, x_2, \ldots, x_n$; the output is a sequence of vectors $y_1, y_2, \ldots, y_n$.

How the outputs are computed:

S1: take the transpose of $x_i$, written $x_i'$

S2: raw weight $w_{ij}' = x_i' x_j$

S3: softmax the weights into $[0, 1]$: $w_{ij} = \mathrm{softmax}(w_{ij}')$, normalized over $j$

S4: $y_i = \sum_{j} w_{ij} x_j$

Note: self-attention by itself ignores word order; which words end up related to which is determined by the task. The $w_{ij}$ are not parameters; they come from the relations between the input vectors themselves.

Back to Figure 1: $y_2 = w_{21}x_1 + w_{22}x_2 + w_{23}x_3 + w_{24}x_4$
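To make S1-S4 concrete, here is a small sketch (toy numbers picked only for illustration) that computes $y_2$ the slow way, one dot product at a time; the vectorized version follows next:

import torch
import torch.nn.functional as F

# toy input: a sequence of 4 vectors, each 3-dimensional (values are arbitrary)
x = torch.tensor([[1., 0., 1.],
                  [0., 2., 0.],
                  [1., 1., 1.],
                  [0., 1., 2.]])

i = 1                                                          # y_2 (0-based index 1)
w_raw = torch.stack([x[i] @ x[j] for j in range(x.size(0))])   # S1/S2: w'_{2j} = x_2^T x_j
w = F.softmax(w_raw, dim=0)                                    # S3: normalize over j
y2 = sum(w[j] * x[j] for j in range(x.size(0)))                # S4: y_2 = sum_j w_{2j} x_j
print(y2)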

How do we compute $x_i' x_j$ in PyTorch?

import torch
import torch.nn.functional as F
# assume we have some tensor x with size (b, t, k)
x = ...
raw_weights = torch.bmm(x, x.transpose(1, 2))  # w'_{ij} = x_i^T x_j for every pair (i, j): (b, t, t)
weights = F.softmax(raw_weights, dim=2)        # normalize each row over j
y = torch.bmm(weights, x)                      # y_i = sum_j w_{ij} x_j: (b, t, k)

Note that the computation here uses torch.bmm; for comparison, torch.mm performs a plain matrix multiplication of two 2-D tensors.


Two things to note about torch.bmm: both inputs must be 3-D tensors, and dimension 0 (the batch dimension) does not take part in the multiplication.
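A quick shape check makes the difference concrete (sizes are arbitrary): torch.mm multiplies two 2-D matrices, while torch.bmm multiplies a batch of matrix pairs and carries dimension 0 through unchanged:

import torch

a2d = torch.randn(3, 4)
b2d = torch.randn(4, 5)
print(torch.mm(a2d, b2d).shape)     # torch.Size([3, 5])   -- plain matrix product

a3d = torch.randn(2, 3, 4)          # a batch of 2 matrices
b3d = torch.randn(2, 4, 5)
print(torch.bmm(a3d, b3d).shape)    # torch.Size([2, 3, 5]) -- the batch dim is not multiplied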

Introducing query, key, and value:

[Figure 2]
Because $x_i$ plays three different roles, we introduce a query, a key, and a value, defined as:

$q_i = W_q x_i$

$k_i = W_k x_i$

$v_i = W_v x_i$

$w_{ij}' = q_i^{T} k_j$

$w_{ij} = \mathrm{softmax}(w_{ij}')$

$y_i = \sum_{j} w_{ij} v_j$

Back to Figure 2: $y_2$ is now determined jointly by the queries, keys, and values.

And because large dot products cause problems (they push the softmax into a region with very small gradients), the weights need to be scaled:

$w_{ij}' = \frac{q_i^{T} k_j}{\sqrt{k}}$

where $k$ is the embedding size.
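Putting the q/k/v projections and the scaling together, a minimal single-head sketch could look like this (class and variable names here are my own, not from the post):

import math

import torch
import torch.nn.functional as F
from torch import nn

class SingleHeadAttention(nn.Module):
    def __init__(self, k):
        super().__init__()
        # W_q, W_k, W_v as learned linear maps from k to k dimensions
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tokeys    = nn.Linear(k, k, bias=False)
        self.tovalues  = nn.Linear(k, k, bias=False)

    def forward(self, x):                                    # x: (b, t, k)
        q, k_, v = self.toqueries(x), self.tokeys(x), self.tovalues(x)
        w_raw = torch.bmm(q, k_.transpose(1, 2)) / math.sqrt(x.size(2))  # w'_{ij} = q_i^T k_j / sqrt(k)
        w = F.softmax(w_raw, dim=2)                          # each row of w sums to 1
        return torch.bmm(w, v)                               # y_i = sum_j w_{ij} v_j

x = torch.randn(2, 4, 8)
print(SingleHeadAttention(8)(x).shape)                       # torch.Size([2, 4, 8])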

Introducing multi-head attention:

"Susan gave Mary the roses" does not mean the same thing as "Mary gave Susan the roses", so a single attention pattern is not enough; we let the model learn several of them.

The simplest scheme: if the input is 256-dimensional and there are 8 attention heads, each head works on 32 dimensions; for the $r$-th head, $W_q^{r}, W_k^{r}, W_v^{r}$ are each $32 \times 32$.
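One way to read the 32x32 claim is the "chunk per head" scheme: cut each 256-dim vector into 8 chunks of 32 and give every head its own small projection matrices. A sketch of that reading (my interpretation; the full code later in this post instead projects the whole input once per head):

import torch
from torch import nn

emb, heads = 256, 8
s = emb // heads                       # 32 dimensions per head

# per-head projection matrices W_q^r (W_k^r and W_v^r would be built the same way), each 32x32
W_q = nn.ModuleList(nn.Linear(s, s, bias=False) for _ in range(heads))

x = torch.randn(1, 10, emb)            # (b, t, emb)
chunks = x.view(1, 10, heads, s)       # cut each 256-dim vector into 8 chunks of 32
q = torch.stack([W_q[r](chunks[:, :, r, :]) for r in range(heads)], dim=2)
print(q.shape)                         # torch.Size([1, 10, 8, 32])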

Below is the code along with the tensor shape at each step:

The input x has shape (b, t, k).


The shapes of q, k, and v:

S1: nn.Linear(k,k*heads)——>(b,t,k*heads)
S2: .view(b,t,h,k)——>(b,t,h,k)


Move the head dimension next to the batch dimension, so the heads can be folded into the batch:

S1: transpose(1,2)——>(b,h,t,k)

S2: .view(b*h,t,k)——>(b*h,t,k)

Then scale and carry out the computation between the queries, keys, and values:

S1: torch.bmm(queries,keys.transpose(1,2))——>(b*h,t,k)(b*h,k,t)——>(b*h,t,t)
S2:  torch.bmm(dot, values).view(b, h, t, k)——>(b*h,t,t)(b*h,t,k)——>(b*h,t,k)——>(b,h,t,k)

Finally, merge the h and k dimensions back together:

S1:out.transpose(1, 2).contiguous().view(b, t, h * k)——>(b,t,h,k)——>(b,t,h*k)
S2: nn.Linear(heads * k, k)——>(b,t,k)
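The whole reshape pipeline above can be traced with a few asserts (a standalone sketch with arbitrary sizes, independent of the class below):

import torch
from torch import nn

b, t, k, h = 2, 5, 16, 4
x = torch.randn(b, t, k)

toqueries = nn.Linear(k, k * h, bias=False)
q = toqueries(x)                                          # (b, t, k*h)
q = q.view(b, t, h, k)                                    # (b, t, h, k)
q = q.transpose(1, 2).contiguous().view(b * h, t, k)      # fold heads into the batch: (b*h, t, k)
assert q.shape == (b * h, t, k)

keys = q                                                  # reuse q just to check the bmm shapes
dot = torch.bmm(q, keys.transpose(1, 2))                  # (b*h, t, k) x (b*h, k, t) -> (b*h, t, t)
assert dot.shape == (b * h, t, t)

out = torch.bmm(dot, keys).view(b, h, t, k)               # (b*h, t, t) x (b*h, t, k) -> (b, h, t, k)
out = out.transpose(1, 2).contiguous().view(b, t, h * k)  # (b, t, h*k)
assert nn.Linear(h * k, k)(out).shape == (b, t, k)        # unify the heads back to k dims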

Complete SelfAttention code:

import math

import torch
import torch.nn.functional as F
from torch import nn

# NOTE: mask_ and util.contains_nan are helper functions not defined in this post; they are
# assumed to live in an accompanying utility module (a sketch of a possible mask_ appears
# further down).

class SelfAttention(nn.Module):
    def __init__(self, emb, heads=8, mask=False):
        """
        :param emb: embedding dimension of the input (k in the notation above)
        :param heads: number of attention heads
        :param mask: if truthy, mask out future positions so position i cannot attend to j > i
        """

        super().__init__()

        self.emb = emb
        self.heads = heads
        self.mask = mask

        self.tokeys = nn.Linear(emb, emb * heads, bias=False)
        self.toqueries = nn.Linear(emb, emb * heads, bias=False)
        self.tovalues = nn.Linear(emb, emb * heads, bias=False)

        self.unifyheads = nn.Linear(heads * emb, emb)

    def forward(self, x):

        b, t, e = x.size()
        h = self.heads
        assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

        keys    = self.tokeys(x)   .view(b, t, h, e)
        queries = self.toqueries(x).view(b, t, h, e)
        values  = self.tovalues(x) .view(b, t, h, e)

        # compute scaled dot-product self-attention

        # - fold heads into the batch dimension
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
        values = values.transpose(1, 2).contiguous().view(b * h, t, e)

        # - get dot product of queries and keys, and scale
        dot = torch.bmm(queries, keys.transpose(1, 2))
        dot = dot / math.sqrt(e) # dot contains b*h t-by-t matrices with raw self-attention logits

        assert dot.size() == (b*h, t, t), f'Matrix has size {dot.size()}, expected {(b*h, t, t)}.'

        if self.mask: # mask out the upper half of the dot matrix, excluding the diagonal
            mask_(dot, maskval=float('-inf'), mask_diagonal=False)

        dot = F.softmax(dot, dim=2) # dot now has row-wise self-attention probabilities

        assert not util.contains_nan(dot[:, 1:, :]) # only the first row may contain nan

        if self.mask == 'first':
            dot = dot.clone()
            dot[:, :1, :] = 0.0
            # - The first row of the first attention matrix is entirely masked out, so the softmax operation results
            #   in a division by zero. We set this row to zero by hand to get rid of the NaNs

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, e)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)

        return self.unifyheads(out)

First question: what is the mask in the code for?

It hides the information from positions after time t, which is why a triangular matrix is used.
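The mask_ helper called in the code is not shown in this post. A minimal sketch of what such a helper might look like (the name and call signature are taken from the call above; the implementation is my guess): set every entry above the diagonal of each t-by-t attention matrix to -inf before the softmax, so position i cannot attend to j > i.

import torch

def mask_(matrices, maskval=0.0, mask_diagonal=True):
    # set all elements above the diagonal (and on it, if mask_diagonal is True)
    # of each matrix in the batch to maskval
    h, w = matrices.size(-2), matrices.size(-1)
    indices = torch.triu_indices(h, w, offset=0 if mask_diagonal else 1)
    matrices[..., indices[0], indices[1]] = maskval

# quick check: a single 4x4 matrix of ones, future positions set to -inf
m = torch.ones(1, 4, 4)
mask_(m, maskval=float('-inf'), mask_diagonal=False)
print(m)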

Building the transformer block:

[Figure 3]

As Figure 3 shows, a transformer block consists of (1) self-attention, (2) layer norm, (3) an MLP, and (4) another layer norm, with skip (residual) connections around the attention and the MLP. The transformer block code can be written like this:

class TransformerBlock(nn.Module):
    def __init__(self, k, heads):
        super().__init__()

        self.attention = SelfAttention(k, heads=heads)

        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        self.ff = nn.Sequential(
            nn.Linear(k, 4 * k),
            nn.ReLU(),
            nn.Linear(4 * k, k))

    def forward(self, x):
        attended = self.attention(x)
        x = self.norm1(attended + x)          # residual connection + layer norm

        fedforward = self.ff(x)
        return self.norm2(fedforward + x)     # residual connection + layer norm
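Stacking a few of these blocks with nn.Sequential already gives the body of a small encoder. A usage sketch (sizes are arbitrary; the stand-in util class below is hypothetical, only there so that the NaN check inside SelfAttention runs without the author's utility module):

import torch
from torch import nn

# hypothetical stand-in for the util module referenced inside SelfAttention
class util:
    @staticmethod
    def contains_nan(tensor):
        return bool(torch.isnan(tensor).any())

k, heads, depth = 16, 4, 2
encoder = nn.Sequential(*(TransformerBlock(k, heads=heads) for _ in range(depth)))

x = torch.randn(8, 10, k)        # (batch, sequence length, embedding dim)
print(encoder(x).shape)          # torch.Size([8, 10, 16])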

Second question: what is the feedforward layer for?

From the material I found, the feedforward sub-layer has three steps: a linear map, a ReLU, and another linear map. It projects the data into a higher-dimensional space and then back down, which lets the model learn more abstract features, and the activation function adds the non-linearity that gives it extra expressive power.
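One detail worth noting: the two linear layers act on the last dimension only, so the feedforward step is applied to every position independently (it mixes channels, not time steps). A quick illustration with arbitrary sizes:

import torch
from torch import nn

k = 16
ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

x = torch.randn(2, 10, k)                           # (batch, positions, k)
out_all = ff(x)                                     # applied to all positions at once
out_one = ff(x[:, 3, :])                            # applied to a single position
print(torch.allclose(out_all[:, 3, :], out_one))    # True: same result either way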