缩放点击注意力推导

71 阅读2分钟

点击缩放注意力推导

点击缩放评分公式为:a(q,k) / d ** -1

这里提出一个问题,为什么要除以根号d,以及这个d表示的什么

现在我们来推导一个这个a(q,k)的方差

  1. 假设满足以下条件
  2. 假设Qi,Ki都是独立随机变量
  3. 每个Qi,Ki的均值都为0,方差为1
  4. 点积的计算是 q*k
  5. 根据方差的性质:独立随机变量之和的方差等于方差之和
  6. var(Qi,Ki) = E((Qi * Ki) ** 2) - E(Qi * Ei) ** 2 (平方的均值减均值的平方)
  7. 由于Q,K独立,则E(QiKi) = E(Qi) * E(Ki) = 0
  8. E(Qi * Ki) = 0,
  9. E(Qi ** 2) = Var(q) + E(q) ** 2 = 1
  10. 同理 E(Ki ** 2) = 1,则 E((Qi * Ki) ** 2) = 1 * 1 = 1
  11. 可得 Var(Qi,Ki) = 1 - 0 = 1
  12. 又因为 V(Q * K)是d个V(Qi * Ki)相加得到的,所以等于d
为什么要除以根号d?
  1. var((q * k) / d ** -1) = 1/d * var(q * k) = 1
  2. 那为什么要这样呢,就是为了保持方差与d的无关性。
  3. 带入softmax()中,按照[10,100]举例,0.009与0.909差距还是很大的,所以消除d的影响减小方差,就是为了避免相近的权重差距过大。

下面是代码

from d2l.torch import masked_softmax
import math

#TODO:缩放点击注意力的实现
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        # queries的形状:(batch_size,查询的个数,d)
        # keys的形状:(batch_size,“键-值”对的个数,d)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)#遮庇的效果,指定有效长度的序列
        return torch.bmm(self.dropout(self.attention_weights), values)