点击缩放注意力推导
点击缩放评分公式为:a(q,k) / d ** -1
这里提出一个问题,为什么要除以根号d,以及这个d表示的什么
现在我们来推导一个这个a(q,k)的方差
- 假设满足以下条件
- 假设Qi,Ki都是独立随机变量
- 每个Qi,Ki的均值都为0,方差为1
- 点积的计算是 q*k
- 根据方差的性质:独立随机变量之和的方差等于方差之和
- var(Qi,Ki) = E((Qi * Ki) ** 2) - E(Qi * Ei) ** 2 (平方的均值减均值的平方)
- 由于Q,K独立,则E(QiKi) = E(Qi) * E(Ki) = 0
- E(Qi * Ki) = 0,
- E(Qi ** 2) = Var(q) + E(q) ** 2 = 1
- 同理 E(Ki ** 2) = 1,则 E((Qi * Ki) ** 2) = 1 * 1 = 1
- 可得 Var(Qi,Ki) = 1 - 0 = 1
- 又因为 V(Q * K)是d个V(Qi * Ki)相加得到的,所以等于d
为什么要除以根号d?
- var((q * k) / d ** -1) = 1/d * var(q * k) = 1
- 那为什么要这样呢,就是为了保持方差与d的无关性。
- 带入softmax()中,按照[10,100]举例,0.009与0.909差距还是很大的,所以消除d的影响减小方差,就是为了避免相近的权重差距过大。
下面是代码
from d2l.torch import masked_softmax
import math
class DotProductAttention(nn.Module):
"""缩放点积注意力"""
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)