[手写系列] 1.attention

1,381 阅读1分钟

本文已参与[新人创作礼]活动,一起开启掘金创作之路。

1.  self_attention_layer 

注意:

  1. q、k维度相同,v维度可以不同(=输出维度)

  2. 矩阵乘法np.matmul(m,n)。 矩阵点乘 np.multiply(m, n)。

1.1 手写 self_attention_layer (没写反向,之后补充)

X样本是按列的,所以左乘,然后softmax时也要按列

# 手写版本
import numpy as np
from numpy.random import randn
d = 256
n = 32
x = randn(d,n)  #256*32
Wq = randn(d,d) #256*256
Wk = randn(d,d) #256*256
Wv = randn(d,d) #256*256
q = Wq @ x #256*32 一个词一列
k = Wk @ x #⬆️
v = Wv @ x #⬆️
A = k.T @ q #32*32
A /= np.sqrt(d) #/根号下d
def softmax(x):
    e_x = np.exp(x-np.max(x)) #防溢出,最大值归一化
    return e_x/e_x.sum(axis=0)  #axis=0是列,是第一个维度;axis=1是行,是第二个维度 (q是列维度,k.T*q也是列维度)
A_hat = softmax(A)
output = v @ A_hat

1.2 pytorch版本的attention (没写反向,之后补充)

from math import sqrt
import torch
import torch.nn as nn

class Self_Attention(nn.Module):
    def __init__(self,input_dim,dim_k,dim_v):
        super(Self_Attention,self).__init__()
        self.q = nn.Linear(input_dim,dim_k)
        self.k = nn.Linear(input_dim,dim_k)
        self.v = nn.Linear(input_dim,dim_v)
        self.__norm_fact = 1/sqrt(dim_k)
    def forward(self,x): #b*seq_len*input_dim  #注意,一行是一个词。
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)
        atten = nn.Softmax(dim=1)(torch.bmm(Q,K.permute(0,2,1))) #K.permute(0,2,1)相当于转置
        output = torch.bmm(atten,V)
        return output

at = Self_Attention(16,256,256) #初始化类 input_dim,dim_k,dim_v=16,256,256
x = torch.rand(4,10,16)
output = at.forward(x) #调用类的函数