多头注意力机制简介

271 阅读1分钟

多头注意力机制

$W_Q, W_K, W_V$ 为权重矩阵, $X$ 为输入

$Q = XW_Q$

$K = XW_K$

$V = XW_V$

$Q, K, V := reshape(Q), reshape(K), reshape(V)$

$Q, K, V$ 三个矩阵由 (batch_size, seq_len, embedding_dim) 重塑为 (batch_size, seq_len, num_heads, head_dim) 之后继续操作

其中 embedding_dim = num_heads * head_dim, 即 embedding_dim 必须能被 num_heads 整除才可以拆分

$Attention(Q_i, K_i, V_i) = softmax(\frac{Q_i K_i^T}{\sqrt{d_k}}) V_i$

$MultiheadAttention(Q, K, V) = concat(Attention_0, \ldots, Attention_n) W^O$

其中 $Attention_i = Attention(Q_i, K_i, V_i)$

$d_k$ 为 head_dim 大小的常量, $W^O$ 为权重矩阵

手动实现一下大致是这样

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected into queries, keys and values, split
    into ``n_heads`` heads of size ``d_k = d_model // n_heads``, attended
    per head with ``softmax(Q K^T / sqrt(d_k)) V``, concatenated back to
    ``d_model`` features and passed through a final linear projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of attention heads; must evenly divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not evenly
            divide ``d_model``.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of a cryptic view() error
        # the first time forward() tries to split the heads.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads
        # Input projections for queries, keys and values.
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are concatenated.
        self.Wo = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape (batch_size, seq_len_q, d_model).
            k: Key tensor of shape (batch_size, seq_len_k, d_model).
            v: Value tensor of shape (batch_size, seq_len_k, d_model).
            mask: Optional tensor broadcastable to
                (batch_size, n_heads, seq_len_q, seq_len_k). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attn_weights)`` where ``output`` has shape
            (batch_size, seq_len_q, d_model) and ``attn_weights`` has shape
            (batch_size, n_heads, seq_len_q, seq_len_k). The returned weights
            are the post-dropout weights.
        """
        q = self._split_heads(self.Wq(q))
        k = self._split_heads(self.Wk(k))
        v = self._split_heads(self.Wv(v))

        # Scale by sqrt(d_k), not d_k, so the logits keep roughly unit
        # variance and the softmax does not become overly flat.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # A large negative logit becomes ~0 probability after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, n_heads, seq_len_q, seq_len_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (batch_size, n_heads, seq_len_q, d_k)
        weighted_sum = torch.matmul(attn_weights, v)

        # Concatenate heads: move the head axis next to the feature axis and
        # merge them back into d_model -> (batch_size, seq_len_q, d_model).
        weighted_sum = weighted_sum.transpose(1, 2)
        weighted_sum = weighted_sum.reshape(weighted_sum.size(0), -1, self.d_model)

        # Final linear transformation, shape unchanged.
        output = self.Wo(weighted_sum)

        return output, attn_weights

    def _split_heads(self, x):
        """Reshape (batch_size, seq_len, d_model) to (batch_size, n_heads, seq_len, d_head)."""
        batch_size, seq_len, d_model = x.size()
        d_head = d_model // self.n_heads
        x = x.view(batch_size, seq_len, self.n_heads, d_head)
        return x.permute(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, d_head)


# Smoke test: run self-attention (q = k = v) on a random batch.
multi_head_attention = MultiHeadAttention(d_model=512, n_heads=16, dropout=0.1)
x = torch.rand(2, 512, 512)  # (batch_size, seq_len, d_model)
output, attn_weights = multi_head_attention(q=x, k=x, v=x)
print(output.shape)  # (batch_size, seq_len, d_model)
print(attn_weights.shape)  # (batch_size, n_heads, seq_len, seq_len)